"""
Tools for working with Q (Forbes) polynomials.
code adapted in its majority from the prysm package - (https://github.com/brandondube/prysm)
Manuel Fragata Mendes, 2025
Copyright notice:
Copyright (c) 2017 Brandon Dube
"""
from __future__ import annotations
from collections import defaultdict
from functools import cache
from scipy import special
import optiland.backend as be
def kronecker(i: int, j: int) -> int:
"""The Kronecker delta function."""
return 1 if i == j else 0
def _trim_trailing_zeros(coefs):
"""Drop trailing exact-zero entries from a coefficient sequence.
Only *trailing* exact-zero entries are removed;
interior zeros are preserved. Returns an empty list when all entries
are zero or the input is empty/``None``.
Args:
coefs: A sequence of scalar coefficients.
Returns:
list: A list of coefficients with trailing zeros removed.
"""
if coefs is None:
return []
try:
n = len(coefs)
except TypeError:
# Scalar (or 0-d tensor) — wrap in a single-element list, trimming
# if it is zero.
try:
return [] if bool(coefs == 0) else [coefs]
except Exception:
return [coefs]
while n > 0:
v = coefs[n - 1]
try:
is_zero = bool(v == 0)
except Exception:
is_zero = False
if not is_zero:
break
n -= 1
if n == 0:
return []
if isinstance(coefs, list | tuple):
return list(coefs[:n])
# NumPy array or Torch tensor — keep entries by reference so that any
# autograd graph attached to surviving entries is preserved.
return [coefs[i] for i in range(n)]
@cache
def gamma_func(n: int, m: int) -> float:
"""Recursive gamma function for Q2D polynomials."""
if n == 1 and m == 2:
return 3 / 8
if n == 1 and m > 2:
mm1 = m - 1
numerator = 2 * mm1 + 1
denominator = 2 * (mm1 - 1)
return (numerator / denominator) * gamma_func(1, mm1)
nm1 = n - 1
num = (nm1 + 1) * (2 * m + 2 * nm1 - 1)
den = (m + nm1 - 2) * (2 * nm1 + 1)
return (num / den) * gamma_func(nm1, m)
# -----------------------------------------------------------------------------
# Forbes slope-orthogonal (Q^bfs) polynomial basis functions
# -----------------------------------------------------------------------------
# NOTE: The "qbfs" suffix in function names below is a historical identifier
# from Forbes' 2007 paper, where these polynomials were called Q^bfs for
# "best-fit sphere." However, the modern formulation (Forbes 2011) uses these
# same polynomial basis functions with a general conic reference surface
# (conic constant k may be nonzero). The "qbfs" naming is retained here for
# code stability and to match the original literature, but users should NOT
# infer that a spherical reference is required or used.
# -----------------------------------------------------------------------------
@cache
def g_qbfs(n_minus_1: int) -> float:
"""Recurrence coefficient g for Q-BFS polynomials."""
if n_minus_1 == 0:
return -1 / 2
n_minus_2 = n_minus_1 - 1
return -(1 + g_qbfs(n_minus_2) * h_qbfs(n_minus_2)) / f_qbfs(n_minus_1)
@cache
def h_qbfs(n_minus_2: int) -> float:
"""Recurrence coefficient h for Q-BFS polynomials."""
n = n_minus_2 + 2
return -n * (n - 1) / (2 * f_qbfs(n_minus_2))
@cache
def f_qbfs(n: int) -> float:
"""Recurrence coefficient f for Q-BFS polynomials."""
if n == 0:
return 2.0
if n == 1:
return 19**0.5 / 2
term1 = float(n * (n + 1) + 3)
term2 = g_qbfs(n - 1) ** 2
term3 = h_qbfs(n - 2) ** 2
return (term1 - term2 - term3) ** 0.5
def change_basis_qbfs_to_pn(cs: list[float], _no_trim: bool = False) -> be.array:
"""
Changes the basis of Q-BFS coefficients to orthonormal Pn coefficients.
Trailing exact-zero entries are removed before basis conversion;
they would not contribute to the polynomial sum, and keeping them
drives the Clenshaw recurrence to needlessly high order.
Args:
cs: Q-BFS coefficient sequence.
_no_trim: When True, ``cs`` is assumed already trimmed (trailing
zeros removed). This avoids the ``bool(v == 0)`` element checks,
which force a device-to-host synchronization on CUDA tensors.
"""
if not _no_trim:
cs = _trim_trailing_zeros(cs)
m = len(cs) - 1
if m < 0:
return be.array(cs)
bs_list = [0.0] * (m + 1)
f_m = f_qbfs(m)
if not isinstance(f_m, (int | float)):
cs = be.stack(cs)
bs_list[m] = cs[m] / f_m
if m == 0:
return be.array(bs_list) if be.get_backend() != "torch" else be.stack(bs_list)
g = g_qbfs(m - 1)
f = f_qbfs(m - 1)
bs_list[m - 1] = (cs[m - 1] - g * bs_list[m]) / f
for i in range(m - 2, -1, -1):
g = g_qbfs(i)
h = h_qbfs(i)
f = f_qbfs(i)
bs_list[i] = (cs[i] - g * bs_list[i + 1] - h * bs_list[i + 2]) / f
return be.array(bs_list) if be.get_backend() != "torch" else be.stack(bs_list)
def _initialize_alphas_q(cs, x, alphas, j=0):
"""Initializes the alpha array for Clenshaw's algorithm."""
if alphas is not None:
return alphas
n_modes = max(len(cs), 2)
shape = (n_modes, *be.shape(x)) if hasattr(x, "shape") else (n_modes,)
if j != 0:
shape = (j + 1, *shape)
zeros = be.zeros(shape)
if be.get_backend() == "torch":
zeros.requires_grad = False
return zeros
def _clenshaw_qbfs_recurrence(bs, usq, alphas):
"""Backend-agnostic Clenshaw recurrence calculation for Q-BFS."""
m = len(bs) - 1
if m < 0:
return alphas
prefix = 2 - 4 * usq
alphas[m] = bs[m]
if m > 0:
alphas[m - 1] = bs[m - 1] + prefix * alphas[m]
for i in range(m - 2, -1, -1):
alphas[i] = bs[i] + prefix * alphas[i + 1] - alphas[i + 2]
return alphas
def clenshaw_qbfs(
cs: list[float], usq: be.array, alphas: be.array = None, _no_trim: bool = False
):
"""Computes the sum of Q-BFS polynomials using Clenshaw's algorithm.
Trailing exact-zero coefficients are trimmed before evaluation unless
``_no_trim`` is set (the caller has already prepared a trimmed sequence).
"""
if not _no_trim:
cs = _trim_trailing_zeros(cs)
bs = change_basis_qbfs_to_pn(cs, _no_trim=True)
m = len(bs) - 1
if m < 0:
return be.zeros_like(usq) if hasattr(usq, "shape") else 0.0
if be.get_backend() == "torch":
s, _, _ = _clenshaw_qbfs_functional(bs, usq)
if alphas is not None:
alphas_res = _clenshaw_qbfs_recurrence(bs, usq, be.empty_like(alphas))
alphas[...] = alphas_res
return s
alphas = _initialize_alphas_q(cs, usq, alphas)
alphas = _clenshaw_qbfs_recurrence(bs, usq, alphas)
return 2 * (alphas[0] + alphas[1]) if m > 0 else 2 * alphas[0]
def _clenshaw_qbfs_functional(bs, usq):
"""Pure-functional Clenshaw that returns (S, alpha0, alpha1)."""
m = len(bs) - 1
if m < 0:
zeros = be.zeros_like(usq)
return zeros, zeros, zeros
prefix = 2 - 4 * usq
b_curr = bs[m] + usq * 0
b_next = be.zeros_like(b_curr)
for n in range(m - 1, -1, -1):
b_new = bs[n] + prefix * b_curr - b_next
b_next, b_curr = b_curr, b_new
alpha0, alpha1 = b_curr, b_next
s = 2 * (alpha0 + alpha1) if m > 0 else 2 * alpha0
return s, alpha0, alpha1
def clenshaw_qbfs_der(cs, usq, j=1, alphas=None, _no_trim: bool = False):
"""Computes derivatives of Q-BFS polynomials using Clenshaw's method.
Trailing exact-zero coefficients are trimmed before evaluation unless
``_no_trim`` is set. When the derivative order ``j`` exceeds the trimmed
polynomial degree the higher-order alpha tables are returned as zero.
"""
if not _no_trim:
cs = _trim_trailing_zeros(cs)
if be.get_backend() == "torch":
return _clenshaw_qbfs_der_functional(cs, usq, j)
m = len(cs) - 1
alphas = _initialize_alphas_q(cs, usq, alphas, j=j)
if m < 0:
return alphas
clenshaw_qbfs(cs, usq, alphas=alphas[0])
prefix = 2 - 4 * usq
for jj in range(1, j + 1):
if m - jj < 0:
continue
alphas[jj][m - jj] = -4 * jj * alphas[jj - 1][m - jj + 1]
if m - jj - 1 >= 0:
alphas[jj][m - jj - 1] = (
prefix * alphas[jj][m - jj] - 4 * jj * alphas[jj - 1][m - jj]
)
for n in range(m - jj - 2, -1, -1):
alphas[jj][n] = (
prefix * alphas[jj][n + 1]
- alphas[jj][n + 2]
- 4 * jj * alphas[jj - 1][n + 1]
)
return alphas
def _clenshaw_qbfs_der_functional(cs, usq, j=1):
"""Pure-functional Clenshaw for Q-BFS derivatives (PyTorch backend)."""
m = len(cs) - 1
if m < 0:
shape = (
(j + 1, len(cs), *be.shape(usq))
if hasattr(usq, "shape")
else (j + 1, len(cs))
)
return be.zeros(shape)
bs = change_basis_qbfs_to_pn(cs, _no_trim=True)
prefix = 2 - 4 * usq
# functional implementation of the base case (j=0)
alphas_j0_list = [be.zeros_like(usq) for _ in range(m + 1)]
if m >= 0:
# first scalar coefficient is broadcast to the full
# tensor size
alphas_j0_list[m] = bs[m] + be.zeros_like(usq)
if m >= 1:
alphas_j0_list[m - 1] = bs[m - 1] + prefix * alphas_j0_list[m]
for i in range(m - 2, -1, -1):
alphas_j0_list[i] = (
bs[i] + prefix * alphas_j0_list[i + 1] - alphas_j0_list[i + 2]
)
all_alphas_tensors = [be.stack(alphas_j0_list)]
prev_alphas_j_list = alphas_j0_list
for jj in range(1, j + 1):
alphas_jj_list = [be.zeros_like(usq) for _ in range(m + 1)]
if m - jj >= 0:
alphas_jj_list[m - jj] = -4 * jj * prev_alphas_j_list[m - jj + 1]
if m - jj - 1 >= 0:
alphas_jj_list[m - jj - 1] = (
prefix * alphas_jj_list[m - jj] - 4 * jj * prev_alphas_j_list[m - jj]
)
for n in range(m - jj - 2, -1, -1):
alphas_jj_list[n] = (
prefix * alphas_jj_list[n + 1]
- alphas_jj_list[n + 2]
- 4 * jj * prev_alphas_j_list[n + 1]
)
all_alphas_tensors.append(be.stack(alphas_jj_list))
prev_alphas_j_list = alphas_jj_list
return be.stack(all_alphas_tensors)
def compute_z_qbfs(
coefs: list[float], usq: be.array, _no_trim: bool = False
) -> be.array:
"""Sag-only Q-BFS polynomial sum (no derivative table built).
Equivalent to the first return value of :func:`compute_z_zprime_qbfs`
but skips the j=1 Clenshaw pass entirely. Use this from ``sag()``
code paths where the derivative is not needed.
Args:
coefs: Q-BFS coefficient sequence (trailing zeros are trimmed).
usq: Squared normalized radius ``u**2``.
_no_trim: When True, ``coefs`` is assumed already trimmed.
Returns:
be.array: The raw Q-BFS polynomial sum at each ``usq`` sample.
"""
if not _no_trim:
coefs = _trim_trailing_zeros(coefs)
if len(coefs) == 0:
return be.zeros_like(usq) if hasattr(usq, "shape") else be.array(0.0)
return clenshaw_qbfs(coefs, usq, _no_trim=True)
def compute_z_zprime_qbfs(
coefs: list[float], u: be.array, usq: be.array, _no_trim: bool = False
) -> tuple[be.array, be.array]:
"""Computes the raw Q-BFS polynomial sum and its derivative w.r.t. u."""
if not _no_trim:
coefs = _trim_trailing_zeros(coefs)
if len(coefs) == 0:
zeros = be.zeros_like(u)
return zeros, zeros
alphas = clenshaw_qbfs_der(coefs, usq, j=1, _no_trim=True)
if len(coefs) > 1:
s = 2 * (alphas[0, 0] + alphas[0, 1])
ds_dusq = 2 * (alphas[1, 0] + alphas[1, 1])
else:
s = 2 * alphas[0, 0]
ds_dusq = 2 * alphas[1, 0]
ds_du = ds_dusq * 2 * u
return s, ds_du
# q2d polynomials logic
@cache
def _g_q2d_raw(n: int, m: int) -> float:
"""Raw G coefficient for Q2D polynomials."""
if n == 0:
num = special.factorial2(2 * m - 1)
den = 2 ** (m + 1) * special.factorial(m - 1)
return num / den
if n > 0 and m == 1:
t1num = (2 * n**2 - 1) * (n**2 - 1)
t1den = 8 * (4 * n**2 - 1)
term1 = -t1num / t1den
term2 = 1 / 24 * kronecker(n, 1)
return term1 - term2
nt1 = 2 * n * (m + n - 1) - m
nt2 = (n + 1) * (2 * m + 2 * n - 1)
num = nt1 * nt2
dt1 = (m + 2 * n - 2) * (m + 2 * n - 1)
dt2 = (m + 2 * n) * (2 * n + 1)
den = dt1 * dt2
term1 = -num / den
return term1 * gamma_func(n, m)
@cache
def _f_q2d_raw(n: int, m: int) -> float:
"""Raw F coefficient for Q2D polynomials."""
if n == 0 and m == 1:
return 0.25
if n == 0:
num = m**2 * special.factorial2(2 * m - 3)
den = 2 ** (m + 1) * special.factorial(m - 1)
return num / den
if n > 0 and m == 1:
t1num = 4 * (n - 1) ** 2 * n**2 + 1
t1den = 8 * (2 * n - 1) ** 2
term1 = t1num / t1den
term2 = 11 / 32 * kronecker(n, 1)
return term1 + term2
chi = m + n - 2
nt1 = 2 * n * chi * (3 - 5 * m + 4 * n * chi)
nt2 = m**2 * (3 - m + 4 * n * chi)
num = nt1 + nt2
dt1 = (m + 2 * n - 3) * (m + 2 * n - 2)
dt2 = (m + 2 * n - 1) * (2 * n - 1)
den = dt1 * dt2
term1 = num / den
return term1 * gamma_func(n, m)
@cache
def g_q2d(n: int, m: int) -> float:
"""Recurrence coefficient g for Q2D polynomials."""
return _g_q2d_raw(n, m) / f_q2d(n, m)
@cache
def f_q2d(n: int, m: int) -> float:
"""Recurrence coefficient f for Q2D polynomials."""
if n == 0:
return _f_q2d_raw(n=0, m=m) ** 0.5
return (_f_q2d_raw(n, m) - g_q2d(n - 1, m) ** 2) ** 0.5
def change_basis_q2d_to_pnm(
cns: list[float], m: int, _no_trim: bool = False
) -> be.array:
"""
Changes the basis of Q2D coefficients to orthonormal Pnm coefficients.
"""
if not _no_trim:
cns = _trim_trailing_zeros(cns)
m = abs(m)
n_max = len(cns) - 1
if n_max < 0:
return be.array(cns)
ds_list = [be.array(0.0)] * (n_max + 1)
ds_list[n_max] = cns[n_max] / f_q2d(n_max, m)
for n in range(n_max - 1, -1, -1):
ds_list[n] = (cns[n] - g_q2d(n, m) * ds_list[n + 1]) / f_q2d(n, m)
return be.stack(ds_list)
_ABC_Q2D_SPECIAL_CASES = {
(1, 0): (2, -1, 0),
(1, 1): (-4 / 3, -8 / 3, -11 / 3),
(1, 2): (9 / 5, -24 / 5, 0),
(2, 0): (3, -2, 0),
(3, 0): (5, -4, 0),
}
@cache
def abc_q2d(n: int, m: int) -> tuple[float, float, float]:
"""Recurrence coefficients A, B, C for Q2D Clenshaw algorithm."""
d = (4 * n**2 - 1) * (m + n - 2) * (m + 2 * n - 3)
if d == 0:
d = 1e-99
term1 = (2 * n - 1) * (m + 2 * n - 2)
term2 = 4 * n * (m + n - 2) + (m - 3) * (2 * m - 1)
a = (term1 * term2) / d
num_b = -2 * (2 * n - 1) * (m + 2 * n - 3) * (m + 2 * n - 2) * (m + 2 * n - 1)
b = num_b / d
num_c = n * (2 * n - 3) * (m + 2 * n - 1) * (2 * m + 2 * n - 3)
c = num_c / d
return a, b, c
def abc_q2d_clenshaw(n: int, m: int) -> tuple[float, float, float]:
"""Provides A, B, C coefficients for Clenshaw, handling special cases."""
return _ABC_Q2D_SPECIAL_CASES.get((m, n), abc_q2d(n, m))
def q2d_sum_from_alphas(alphas: be.array, m: int, num_coeffs: int) -> be.array:
"""
Computes the final sum from the alpha coefficients returned by Clenshaw's
method, applying the special summation rule for m=1.
The m==1 correction reads ``alphas[3]``; the read is also guarded by
the actual alpha-table length, so a caller passing a stale
``num_coeffs`` does not over-index.
"""
s = 0.5 * alphas[0]
# special case for m=1, as in Forbes' papers
if m == 1 and num_coeffs - 1 > 2 and be.shape(alphas)[0] > 3:
s -= 2 / 5 * alphas[3]
return s
def _get_s_and_s_prime(alphas, m, num_coeffs):
"""Helper to compute S and S' from alpha derivatives for Q2D."""
s = q2d_sum_from_alphas(alphas[0], m, num_coeffs)
s_prime = q2d_sum_from_alphas(alphas[1], m, num_coeffs)
return s, s_prime
def _compute_m_gt0_components(ams, bms, u, t, usq, _no_trim: bool = False):
"""Computes the sum and derivatives for all m>0 components.
Per-m terms are accumulated incrementally rather than collected into
lists and ``be.stack``-ed, which avoids allocating a ``(n_modes, *shape)``
intermediate per accumulator (a measurable kernel-launch/allocation cost
on CUDA). The running ``a + b`` adds are out-of-place, so the PyTorch
autograd graph is preserved.
"""
poly_sum = dr_sum = dt_sum = None
for m_idx, (a_coef, b_coef) in enumerate(zip(ams, bms, strict=False)):
m = m_idx + 1
# Trim trailing zeros independently per family so the m==1 special
# case in q2d_sum_from_alphas reads a correctly sized alpha table
# and does not over index when the user supplied vector ends in
# zeros. Skipped when the caller passes already-prepared families.
if not _no_trim:
a_coef = _trim_trailing_zeros(a_coef)
b_coef = _trim_trailing_zeros(b_coef)
s_a, s_b, s_prime_a, s_prime_b = 0, 0, 0, 0
if a_coef:
alphas_a = clenshaw_q2d_der(a_coef, m, usq, j=1, _no_trim=True)
s_a, s_prime_a = _get_s_and_s_prime(alphas_a, m, len(a_coef))
if b_coef:
alphas_b = clenshaw_q2d_der(b_coef, m, usq, j=1, _no_trim=True)
s_b, s_prime_b = _get_s_and_s_prime(alphas_b, m, len(b_coef))
um = u**m
cost = be.cos(m * t)
sint = be.sin(m * t)
poly_term = um * (cost * s_a + sint * s_b)
umm1 = u ** (m - 1) if m > 0 else be.ones_like(u)
two_usq = 2 * usq
aterm = cost * (two_usq * s_prime_a + m * s_a)
bterm = sint * (two_usq * s_prime_b + m * s_b)
dr_term = umm1 * (aterm + bterm)
dt_term = m * um * (-s_a * sint + s_b * cost)
poly_sum = poly_term if poly_sum is None else poly_sum + poly_term
dr_sum = dr_term if dr_sum is None else dr_sum + dr_term
dt_sum = dt_term if dt_sum is None else dt_sum + dt_term
zeros = be.zeros_like(u)
return (
poly_sum if poly_sum is not None else zeros,
dr_sum if dr_sum is not None else zeros,
dt_sum if dt_sum is not None else zeros,
)
def _harmonic_powers(X, Y, m_max):
"""Compute Re/Im parts of (X + iY)**k for k = 0 .. m_max.
Iteratively applies the complex-multiplication recurrence
H_c[k+1] = X * H_c[k] - Y * H_s[k]
H_s[k+1] = X * H_s[k] + Y * H_c[k]
Backend-agnostic and singularity-free at the origin (no division
by ``r``); autograd-safe (purely functional list construction).
Args:
X: Normalized x-coordinate (be.array).
Y: Normalized y-coordinate (be.array).
m_max: Highest azimuthal order needed (inclusive).
Returns:
tuple[list[be.array], list[be.array]]: ``(H_c, H_s)`` each of
length ``m_max + 1``.
"""
ones = be.ones_like(X) if hasattr(X, "shape") else be.array(1.0)
zeros = be.zeros_like(X) if hasattr(X, "shape") else be.array(0.0)
H_c = [ones]
H_s = [zeros]
for _ in range(m_max):
c_prev, s_prev = H_c[-1], H_s[-1]
H_c.append(X * c_prev - Y * s_prev)
H_s.append(X * s_prev + Y * c_prev)
return H_c, H_s
def _q2d_cartesian_eval(X, Y, cm0, ams, bms, _no_trim: bool = False):
"""Evaluate the Q2D polynomial sum P(X, Y) and its Cartesian derivatives.
P is the dimensionless polynomial part of the Forbes Q2D departure,
P(X, Y) = u**2 * (1 - u**2) * S_cm0(u**2)
+ sum_{m>=1} [ Re((X + iY)**m) * S_a_m(u**2)
+ Im((X + iY)**m) * S_b_m(u**2) ],
with ``u**2 = X**2 + Y**2``. Derivatives are computed in normalized
Cartesian coordinates via harmonic powers, so the result is regular
at ``X = Y = 0`` (no polar ``1/r`` artifact). The caller is
responsible for applying the conic-correction factor, the base-sag
derivative, and the chain rule from ``X = x / R_n`` to physical
coordinates.
Args:
X: Normalized x-coordinate (be.array).
Y: Normalized y-coordinate (be.array).
cm0: m=0 (Qbfs-style) coefficient sequence; trimmed internally.
ams: Per-m cosine coefficient families (index ``i`` is m == i+1).
bms: Per-m sine coefficient families, same layout as ``ams``.
Returns:
tuple: ``(P, dP_dX, dP_dY)`` as be.arrays, broadcast to
``X`` / ``Y`` shape.
"""
usq = X * X + Y * Y
# m == 0 envelope: P_m0 = u^2 (1 - u^2) * S_cm0(u^2)
if not _no_trim:
cm0 = _trim_trailing_zeros(cm0)
if cm0:
# Reuse derivative Clenshaw to get S and dS/du^2 in one sweep.
alphas_m0 = clenshaw_qbfs_der(cm0, usq, j=1, _no_trim=True)
if len(cm0) > 1:
s_cm0 = 2 * (alphas_m0[0, 0] + alphas_m0[0, 1])
dsdu2_cm0 = 2 * (alphas_m0[1, 0] + alphas_m0[1, 1])
else:
s_cm0 = 2 * alphas_m0[0, 0]
dsdu2_cm0 = 2 * alphas_m0[1, 0]
env = usq * (1 - usq)
P_m0 = env * s_cm0
# d/dX [ u^2(1-u^2) * S ] = 2X * [ (1 - 2u^2) S + u^2(1-u^2) dS/du^2 ]
radial_chain = (1 - 2 * usq) * s_cm0 + env * dsdu2_cm0
dP_m0_dX = 2 * X * radial_chain
dP_m0_dY = 2 * Y * radial_chain
else:
zeros = be.zeros_like(usq)
P_m0 = zeros
dP_m0_dX = zeros
dP_m0_dY = zeros
# m >= 1: harmonic powers + per-m radial Clenshaw.
m_max = max(len(ams), len(bms))
if m_max == 0:
return P_m0, dP_m0_dX, dP_m0_dY
H_c, H_s = _harmonic_powers(X, Y, m_max)
# Accumulate the m>0 contributions incrementally to avoid building three
# term lists and the corresponding ``be.stack`` / ``be.sum`` intermediates
# (heavy allocation on CUDA float32 for dense freeforms). Out-of-place adds
# keep the autograd graph intact.
P_mgt0 = dPx_mgt0 = dPy_mgt0 = None
for m_idx in range(m_max):
m = m_idx + 1
a_coef = ams[m_idx] if m_idx < len(ams) else []
b_coef = bms[m_idx] if m_idx < len(bms) else []
if not _no_trim:
a_coef = _trim_trailing_zeros(a_coef)
b_coef = _trim_trailing_zeros(b_coef)
if not a_coef and not b_coef:
continue
s_a = s_b = dsdu2_a = dsdu2_b = 0.0
if a_coef:
alphas_a = clenshaw_q2d_der(a_coef, m, usq, j=1, _no_trim=True)
s_a = q2d_sum_from_alphas(alphas_a[0], m, len(a_coef))
dsdu2_a = q2d_sum_from_alphas(alphas_a[1], m, len(a_coef))
if b_coef:
alphas_b = clenshaw_q2d_der(b_coef, m, usq, j=1, _no_trim=True)
s_b = q2d_sum_from_alphas(alphas_b[0], m, len(b_coef))
dsdu2_b = q2d_sum_from_alphas(alphas_b[1], m, len(b_coef))
Hc_m = H_c[m]
Hs_m = H_s[m]
Hc_mm1 = H_c[m - 1]
Hs_mm1 = H_s[m - 1]
P_term = Hc_m * s_a + Hs_m * s_b
# d/dX [ H_c[m] S_a + H_s[m] S_b ]
# = m*H_c[m-1]*S_a + H_c[m]*2X*dS_a + m*H_s[m-1]*S_b + H_s[m]*2X*dS_b
dPx_term = m * (Hc_mm1 * s_a + Hs_mm1 * s_b) + 2 * X * (
Hc_m * dsdu2_a + Hs_m * dsdu2_b
)
# d/dY [ H_c[m] S_a + H_s[m] S_b ]
# = -m*H_s[m-1]*S_a + H_c[m]*2Y*dS_a + m*H_c[m-1]*S_b + H_s[m]*2Y*dS_b
dPy_term = m * (-Hs_mm1 * s_a + Hc_mm1 * s_b) + 2 * Y * (
Hc_m * dsdu2_a + Hs_m * dsdu2_b
)
P_mgt0 = P_term if P_mgt0 is None else P_mgt0 + P_term
dPx_mgt0 = dPx_term if dPx_mgt0 is None else dPx_mgt0 + dPx_term
dPy_mgt0 = dPy_term if dPy_mgt0 is None else dPy_mgt0 + dPy_term
if P_mgt0 is None:
P_mgt0 = be.zeros_like(usq)
dPx_mgt0 = be.zeros_like(usq)
dPy_mgt0 = be.zeros_like(usq)
return P_m0 + P_mgt0, dP_m0_dX + dPx_mgt0, dP_m0_dY + dPy_mgt0
def _compute_m_gt0_sag_only(ams, bms, u, t, usq, _no_trim: bool = False):
"""Sag-only counterpart of :func:`_compute_m_gt0_components`.
Skips the derivative Clenshaw pass and the radial / azimuthal
derivative accumulators; only the m>0 polynomial sum is returned.
Terms are accumulated incrementally (see
:func:`_compute_m_gt0_components`).
"""
poly_sum = None
for m_idx, (a_coef, b_coef) in enumerate(zip(ams, bms, strict=False)):
m = m_idx + 1
if not _no_trim:
a_coef = _trim_trailing_zeros(a_coef)
b_coef = _trim_trailing_zeros(b_coef)
s_a, s_b = 0, 0
if a_coef:
alphas_a = clenshaw_q2d(a_coef, m, usq, _no_trim=True)
s_a = q2d_sum_from_alphas(alphas_a, m, len(a_coef))
if b_coef:
alphas_b = clenshaw_q2d(b_coef, m, usq, _no_trim=True)
s_b = q2d_sum_from_alphas(alphas_b, m, len(b_coef))
um = u**m
cost = be.cos(m * t)
sint = be.sin(m * t)
poly_term = um * (cost * s_a + sint * s_b)
poly_sum = poly_term if poly_sum is None else poly_sum + poly_term
return poly_sum if poly_sum is not None else be.zeros_like(u)
def compute_z_q2d(cm0, ams, bms, u, t, _no_trim: bool = False):
"""Sag-only Q2D polynomial sum (no derivative table built).
Returns the pair ``(poly_sum_m0, poly_sum_m_gt0)`` — the same first
and third entries as :func:`compute_z_zprime_q2d` would return, but
without the j=1 Clenshaw pass over each per-m family. Use this from
``sag()`` code paths where the derivative is not needed.
Args:
cm0: m==0 Qbfs-style coefficient sequence.
ams: Per-m cosine coefficient families (index ``i`` is m == i+1).
bms: Per-m sine coefficient families, same layout as ``ams``.
u: Normalized radius.
t: Azimuth in radians.
Returns:
tuple: ``(poly_sum_m0, poly_sum_m_gt0)``.
"""
usq = u * u
zeros = be.zeros_like(u)
if not _no_trim:
cm0 = _trim_trailing_zeros(cm0)
poly_sum_m0 = zeros if not cm0 else compute_z_qbfs(cm0, usq, _no_trim=True)
poly_sum_m_gt0 = _compute_m_gt0_sag_only(ams, bms, u, t, usq, _no_trim=_no_trim)
return poly_sum_m0, poly_sum_m_gt0
[docs]
def compute_z_zprime_q2d(cm0, ams, bms, u, t, _no_trim: bool = False):
"""Computes the polynomial sum components for a Q2D surface."""
usq = u * u
zeros = be.zeros_like(u)
if not _no_trim:
cm0 = _trim_trailing_zeros(cm0)
poly_sum_m0, d_poly_sum_m0_du = zeros, zeros
if cm0:
poly_sum_m0, d_poly_sum_m0_du = compute_z_zprime_qbfs(
cm0, u, usq, _no_trim=True
)
poly_sum_m_gt0, dr_m_gt0, dt_m_gt0 = _compute_m_gt0_components(
ams, bms, u, t, usq, _no_trim=_no_trim
)
return poly_sum_m0, d_poly_sum_m0_du, poly_sum_m_gt0, dr_m_gt0, dt_m_gt0
[docs]
def q2d_nm_coeffs_to_ams_bms(nms: list[tuple[int, int]], coefs: list[float]):
"""Converts a list of (n, m) indexed coefficients to grouped a_m and b_m lists."""
cms = []
ac = defaultdict(list)
bc = defaultdict(list)
for (n, m), c in zip(nms, coefs, strict=False):
if m == 0:
if n >= len(cms):
cms.extend([0.0] * (n - len(cms) + 1))
cms[n] = c
continue
target_dict = ac if m > 0 else bc
m_abs = abs(m)
if n >= len(target_dict[m_abs]):
target_dict[m_abs].extend([0.0] * (n - len(target_dict[m_abs]) + 1))
target_dict[m_abs][n] = c
max_m = 0
if ac:
max_m = max(max_m, max(ac.keys()))
if bc:
max_m = max(max_m, max(bc.keys()))
ams_ret = [ac.get(i, []) for i in range(1, max_m + 1)]
bms_ret = [bc.get(i, []) for i in range(1, max_m + 1)]
return cms, ams_ret, bms_ret
def clenshaw_q2d(cns, m, usq, alphas=None, _no_trim: bool = False):
"""Evaluates the Q2D Clenshaw alpha table for azimuthal order ``m``."""
if not _no_trim:
cns = _trim_trailing_zeros(cns)
if be.get_backend() == "torch":
ds = change_basis_q2d_to_pnm(cns, m, _no_trim=True)
all_alphas_list = _clenshaw_q2d_functional(ds, m, usq)
if not all_alphas_list:
return _initialize_alphas_q(cns, usq, alphas)
result_tensor = be.stack(all_alphas_list)
if alphas is not None:
alphas[...] = result_tensor
return alphas
return result_tensor
ds = change_basis_q2d_to_pnm(cns, m, _no_trim=True)
alphas = _initialize_alphas_q(ds, usq, alphas)
n_max = len(ds) - 1
if n_max < 0:
return alphas
alphas[n_max] = ds[n_max]
if n_max > 0:
a, b, _ = abc_q2d_clenshaw(n_max - 1, m)
alphas[n_max - 1] = ds[n_max - 1] + (a + b * usq) * alphas[n_max]
for n in range(n_max - 2, -1, -1):
a, b, _ = abc_q2d_clenshaw(n, m)
_, _, c = abc_q2d_clenshaw(n + 1, m)
alphas[n] = ds[n] + (a + b * usq) * alphas[n + 1] - c * alphas[n + 2]
return alphas
def _clenshaw_q2d_functional(ds, m, usq):
"""Pure-functional Clenshaw for Q2D polynomials."""
n_max = len(ds) - 1
if n_max < 0:
return []
all_alphas = [be.zeros_like(usq) for _ in range(n_max + 1)]
if n_max >= 0:
all_alphas[n_max] = ds[n_max] + usq * 0
if n_max >= 1:
a, b, _ = abc_q2d_clenshaw(n_max - 1, m)
all_alphas[n_max - 1] = ds[n_max - 1] + (a + b * usq) * all_alphas[n_max]
for n in range(n_max - 2, -1, -1):
a, b, _ = abc_q2d_clenshaw(n, m)
_, _, c = abc_q2d_clenshaw(n + 1, m)
all_alphas[n] = (
ds[n] + (a + b * usq) * all_alphas[n + 1] - c * all_alphas[n + 2]
)
return all_alphas
def clenshaw_q2d_der(cns, m, usq, j=1, alphas=None, _no_trim: bool = False):
"""Computes derivatives of Q-2D polynomials using Clenshaw's method."""
if not _no_trim:
cns = _trim_trailing_zeros(cns)
if be.get_backend() == "torch":
return _clenshaw_q2d_der_functional(cns, m, usq, j)
n_max = len(cns) - 1
alphas = _initialize_alphas_q(cns, usq, alphas, j=j)
if n_max < 0:
return alphas
clenshaw_q2d(cns, m, usq, alphas[0])
for jj in range(1, j + 1):
if n_max - jj < 0:
continue
_, b, _ = abc_q2d_clenshaw(n_max - jj, m)
alphas[jj][n_max - jj] = jj * b * alphas[jj - 1][n_max - jj + 1]
for n in range(n_max - jj - 1, -1, -1):
a, b, _ = abc_q2d_clenshaw(n, m)
_, _, c = abc_q2d_clenshaw(n + 1, m)
alphas[jj][n] = (
jj * b * alphas[jj - 1][n + 1]
+ (a + b * usq) * alphas[jj][n + 1]
- c * alphas[jj][n + 2]
)
return alphas
def _clenshaw_q2d_der_functional(cns, m, usq, j=1):
"""Pure-functional Clenshaw for Q-2D derivatives (PyTorch backend)."""
n_max = len(cns) - 1
if n_max < 0:
shape = (
(j + 1, len(cns), *be.shape(usq))
if hasattr(usq, "shape")
else (j + 1, len(cns))
)
return be.zeros(shape)
ds = change_basis_q2d_to_pnm(cns, m, _no_trim=True)
alphas_j0_list = _clenshaw_q2d_functional(ds, m, usq)
all_alphas_tensors = [be.stack(alphas_j0_list)]
prev_alphas_j_list = alphas_j0_list
for jj in range(1, j + 1):
alphas_jj_list = [be.zeros_like(usq) for _ in range(n_max + 1)]
if n_max - jj >= 0:
_, b, _ = abc_q2d_clenshaw(n_max - jj, m)
alphas_jj_list[n_max - jj] = jj * b * prev_alphas_j_list[n_max - jj + 1]
for n in range(n_max - jj - 1, -1, -1):
a, b, _ = abc_q2d_clenshaw(n, m)
_, _, c = abc_q2d_clenshaw(n + 1, m)
alphas_jj_list[n] = (
jj * b * prev_alphas_j_list[n + 1]
+ (a + b * usq) * alphas_jj_list[n + 1]
- c * alphas_jj_list[n + 2]
)
all_alphas_tensors.append(be.stack(alphas_jj_list))
prev_alphas_j_list = alphas_jj_list
return be.stack(all_alphas_tensors)