"""Zernike Fit Module
This module contains the ZernikeFit class, which can be used to fit Zernike
polynomial to a set of points. This is commonly used for wavefront calculations,
but the class can be used for any fitting operation requiring Zernike polynomials.
Kramer Harrison, 2025
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Literal
import matplotlib.pyplot as plt
import numpy as np
from optiland import backend as be
from optiland.zernike import ZernikeFringe, ZernikeNoll, ZernikeStandard
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from optiland._types import BEArray, ScalarOrArray
ZERNIKE_CLASSES: dict[str, type[ZernikeFringe | ZernikeStandard | ZernikeNoll]] = {
"fringe": ZernikeFringe,
"standard": ZernikeStandard,
"noll": ZernikeNoll,
}
[docs]
class ZernikeFit:
"""
Fit Zernike polynomials to wavefront or arbitrary data points.
This class constructs a linear design matrix of Zernike basis functions and solves
for the coefficients via least squares.
Args:
x (array-like): X-coordinates of data points.
y (array-like): Y-coordinates of data points.
z (array-like): Values at (x, y) to fit.
zernike_type (str): Type of Zernike basis: 'fringe', 'standard', or 'noll'.
num_terms (int): Number of Zernike terms to include in the fit.
Attributes:
x (array-like): Flattened x-coordinates.
y (array-like): Flattened y-coordinates.
z (array-like): Flattened target values.
radius (array-like): Radial coordinate of each point.
phi (array-like): Azimuthal coordinate of each point.
num_pts (int): Number of data points.
zernike (BaseZernike): Zernike basis instance with fitted coefficients.
"""
def __init__(
self,
x: ScalarOrArray,
y: ScalarOrArray,
z: ScalarOrArray,
zernike_type: Literal["fringe", "standard", "noll"] = "fringe",
num_terms: int = 36,
):
# Convert inputs to backend tensors and flatten
self.x = be.asarray(x).reshape(-1)
self.y = be.asarray(y).reshape(-1)
self.z = be.asarray(z).reshape(-1)
if self.x.shape != self.y.shape or self.x.shape != self.z.shape:
raise ValueError("`x`, `y`, and `z` must have the same number of elements.")
self.num_terms = num_terms
self.num_pts = int(be.size(self.x))
# Compute polar coordinates
self.radius = be.sqrt(self.x**2 + self.y**2)
self.phi = be.arctan2(self.y, self.x)
# Validate Zernike type and instantiate basis
if zernike_type not in ZERNIKE_CLASSES:
raise ValueError(
f"Invalid Zernike type '{zernike_type}'. "
f"Choose from: {list(ZERNIKE_CLASSES)}"
)
self.zernike_type = zernike_type
self.zernike: ZernikeFringe | ZernikeStandard | ZernikeNoll = ZERNIKE_CLASSES[
zernike_type
](be.ones([num_terms]))
# Fit coefficients
self._fit()
@property
def coeffs(self) -> BEArray:
"""
Tensor: The fitted Zernike coefficients.
"""
return self.zernike.coeffs
def _fit(self):
"""
Build design matrix of Zernike basis functions and solve linear least squares.
"""
# Build design matrix A
A = be.stack(self.zernike.terms(self.radius, self.phi), axis=1)
# Solve linear least squares A c = z
try:
solution = be.linalg.lstsq(A, self.z, rcond=None)
coeffs = solution[0]
except (AttributeError, TypeError):
# Fallback via pseudoinverse
pinv = be.linalg.pinv(A)
coeffs = be.matmul(pinv, self.z)
# Assign coefficients to the main zernike instance
self.zernike.coeffs = coeffs
[docs]
def view(
self,
fig_to_plot_on: Figure | None = None,
projection: str = "2d",
num_points: int = 128,
figsize: tuple[float, float] = (7, 5.5),
z_label: str = "OPD (waves)",
) -> tuple[Figure, Axes]:
"""
Visualize the fitted Zernike surface.
Args:
fig_to_plot_on (plt.Figure, optional): Figure to plot on.
If None, a new figure is created.
projection (str): '2d' for image plot, '3d' for surface plot.
num_points (int): Grid resolution for display.
figsize (tuple): Figure size in inches.
defaults to (7, 5.5).
z_label (str): Label for the z-axis or colorbar.
defaults to 'OPD (waves)'.
Returns:
tuple: A tuple containing the figure and axes objects.
Raises:
ValueError: If `projection` is not '2d' or '3d'.
"""
is_gui_embedding = fig_to_plot_on is not None
if is_gui_embedding:
current_fig = fig_to_plot_on
current_fig.clear()
ax = (
current_fig.add_subplot(111)
if projection == "2d"
else current_fig.add_subplot(111, projection="3d")
)
else:
current_fig, ax = (
plt.subplots(figsize=figsize)
if projection == "2d"
else plt.subplots(subplot_kw={"projection": "3d"}, figsize=figsize)
)
# Create grid in unit circle
grid_x, grid_y = be.meshgrid(
be.linspace(-1.0, 1.0, num_points),
be.linspace(-1.0, 1.0, num_points),
)
grid_r = be.sqrt(grid_x**2 + grid_y**2)
grid_phi = be.arctan2(grid_y, grid_x)
grid_z = self.zernike.poly(grid_r, grid_phi)
# Mask outside unit circle
grid_z = be.where(grid_r > 1.0, be.nan, grid_z)
# Convert to NumPy for plotting
x_np = be.to_numpy(grid_x)
y_np = be.to_numpy(grid_y)
z_np = be.to_numpy(grid_z)
if projection == "2d":
self._plot_2d(current_fig, ax, z_np, z_label=z_label)
elif projection == "3d":
self._plot_3d(current_fig, ax, x_np, y_np, z_np, z_label=z_label)
else:
raise ValueError("`projection` must be '2d' or '3d'.")
if is_gui_embedding and hasattr(current_fig, "canvas"):
current_fig.canvas.draw_idle()
return current_fig, ax
[docs]
def view_residual(
self,
fig_to_plot_on: Figure | None = None,
figsize: tuple[float, float] = (7, 5.5),
z_label: str = "Residual (waves)",
):
"""
Scatter plot of residuals between fitted surface and original data.
Args:
fig_to_plot_on (plt.Figure, optional): Figure to plot on.
If None, a new figure is created.
figsize (tuple): Figure size in inches.
Defaults to (7, 5.5).
z_label (str): Label for the colorbar.
Defaults to 'Residual (waves)'.
Returns:
tuple: A tuple containing the figure and axes objects.
"""
# Compute fitted values and residuals
fitted = self.zernike.poly(self.radius, self.phi)
residuals = fitted - self.z
rms = be.sqrt(be.mean(residuals**2))
is_gui_embedding = fig_to_plot_on is not None
if is_gui_embedding:
current_fig = fig_to_plot_on
current_fig.clear()
ax = current_fig.add_subplot(111)
else:
current_fig, ax = plt.subplots(figsize=figsize)
sc = ax.scatter(
be.to_numpy(self.x),
be.to_numpy(self.y),
c=be.to_numpy(residuals),
marker="o",
edgecolors="none",
)
ax.set_xlabel("Pupil X")
ax.set_ylabel("Pupil Y")
ax.set_title(f"Residuals (RMS={rms:.3f})")
cbar = plt.colorbar(sc)
cbar.ax.set_ylabel(z_label, rotation=270, labelpad=15)
if is_gui_embedding and hasattr(current_fig, "canvas"):
current_fig.canvas.draw_idle()
return current_fig, ax
def _plot_2d(self, fig: Figure, ax: Axes, z: np.ndarray, z_label: str) -> None:
"""Plot a 2D representation of the given data.
Args:
z (numpy.ndarray): The data to be plotted.
figsize (tuple, optional): The size of the figure
(default is (7, 5.5)).
z_label (str, optional): The label for the colorbar
(default is 'OPD (waves)').
"""
im = ax.imshow(np.flipud(z), extent=[-1, 1, -1, 1])
ax.set_xlabel("Pupil X")
ax.set_ylabel("Pupil Y")
ax.set_title(f"Zernike {self.zernike_type.capitalize()} Fit")
cbar = plt.colorbar(im)
cbar.ax.set_ylabel(z_label, rotation=270, labelpad=15)
def _plot_3d(
self,
fig: Figure,
ax: Axes,
x: np.ndarray,
y: np.ndarray,
z: np.ndarray,
z_label: str,
) -> None:
"""Plot a 3D surface plot of the given data.
Args:
fig (Figure): The figure to plot on.
ax (Axes): The axes to plot on.
x (numpy.ndarray): Array of x-coordinates.
y (numpy.ndarray): Array of y-coordinates.
z (numpy.ndarray): Array of z-coordinates.
z_label (str, optional): Label for the z-axis.
"""
surf = ax.plot_surface(
x,
y,
z,
rstride=1,
cstride=1,
cmap="viridis",
linewidth=0,
antialiased=False,
)
ax.set_xlabel("Pupil X")
ax.set_ylabel("Pupil Y")
ax.set_zlabel(z_label)
ax.set_title(f"Zernike {self.zernike_type.capitalize()} Fit")
fig.colorbar(surf, ax=ax, shrink=0.5, aspect=10, pad=0.15)
fig.tight_layout()