"""Through Focus Spot Diagram Analysis
This module provides a class for performing through-focus spot diagram
analysis, calculating the spot diagram at various focal planes.
Kramer Harrison, 2025
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Literal
import matplotlib.pyplot as plt
import numpy as np
import optiland.backend as be
from optiland.analysis.spot_diagram import SpotDiagram
from optiland.analysis.through_focus import ThroughFocusAnalysis
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from numpy.typing import NDArray
from optiland._types import DistributionType
from optiland.optic import Optic
class ThroughFocusSpotDiagram(ThroughFocusAnalysis):
    """Performs spot diagram analysis over a range of focal planes.

    This class extends `ThroughFocusAnalysis` to compute a full spot diagram
    (via `SpotDiagram`) at each focal position and to visualize the ray
    intersection points as a grid of subplots (one row per field, one column
    per focal position).

    Attributes:
        optic (optiland.optic.Optic): The optical system being analyzed.
        delta_focus (float): The focal shift increment in mm.
        num_steps (int): Number of focal planes analyzed before and after
            the nominal focus.
        fields (list): Resolved list of field objects for analysis; each
            exposes a ``coord`` attribute with the field coordinates.
        wavelengths (list): Resolved list of wavelength objects for analysis;
            each exposes a ``value`` attribute.
        num_rings (int): Number of rings for pupil sampling in the
            `SpotDiagram` calculation.
        distribution (str): Pupil sampling distribution type (e.g.,
            'hexapolar', 'random') for `SpotDiagram`.
        coordinates (Literal["global", "local"]): Coordinate system used for
            spot data generation within `SpotDiagram`.
        results (list): One entry per focal plane, aligned with
            ``self.positions``. Each entry is the ``data`` of a `SpotDiagram`,
            indexed ``[field][wavelength]``; every innermost item exposes
            ``x``, ``y`` and ``intensity`` ray arrays.
    """

    def __init__(
        self,
        optic: Optic,
        delta_focus: float = 0.1,
        num_steps: int = 5,
        fields: str | list = "all",
        wavelengths: str | list = "all",
        num_rings: int = 6,
        distribution: DistributionType = "hexapolar",
        coordinates: Literal["global", "local"] = "local",
    ):
        """Initializes the ThroughFocusSpotDiagram analysis.

        Args:
            optic (optiland.optic.Optic): The optical system to analyze.
            delta_focus (float, optional): The increment of focal shift in mm.
                Defaults to 0.1.
            num_steps (int, optional): The number of focal planes to analyze
                on either side of the nominal focus. Defaults to 5. Must be in
                range [1, 7].
            fields (list[tuple[float,float]] | str, optional): Fields for
                analysis. If "all", uses all fields from `optic.fields`.
                Otherwise, expects a list of field coordinates.
                Defaults to "all".
            wavelengths (list[float] | str, optional): Wavelengths for
                analysis. If "all", uses all wavelengths from
                `optic.wavelengths`. Otherwise, expects a list of
                wavelength values. Defaults to "all".
            num_rings (int, optional): Number of rings for pupil sampling in
                the `SpotDiagram` calculation. Defaults to 6.
            distribution (str, optional): Pupil sampling distribution type for
                `SpotDiagram` (e.g., 'hexapolar', 'random').
                Defaults to "hexapolar".
            coordinates (Literal["global", "local"], optional): Coordinate
                system for spot data generation in `SpotDiagram`.
                Defaults to "local".

        Raises:
            ValueError: If ``coordinates`` is not "global" or "local".
        """
        # These must be set before super().__init__, which may run the
        # analysis and therefore call _perform_analysis_at_focus.
        self.num_rings = num_rings
        self.distribution: DistributionType = distribution
        if coordinates not in ["global", "local"]:
            raise ValueError("Coordinates must be 'global' or 'local'.")
        self.coordinates = coordinates
        super().__init__(
            optic,
            delta_focus=delta_focus,
            num_steps=num_steps,
            fields=fields,
            wavelengths=wavelengths,
        )

    def _perform_analysis_at_focus(self):
        """Computes the spot diagram at the current focal plane.

        This method is called by the base class for each focal step. It
        builds a `SpotDiagram` for the optic's current focal state using the
        configured fields, wavelengths and pupil sampling, and returns its
        ray data unchanged.

        Note:
            `SpotDiagram` is re-instantiated for every focal step, which
            re-traces all rays. Reusing ray tracing across steps could be
            considered if performance becomes a concern.

        Returns:
            list: The ``SpotDiagram.data`` for this focal plane, indexed
            ``[field][wavelength]``; each item exposes ``x``, ``y`` and
            ``intensity`` arrays.
        """
        # SpotDiagram consumes raw coordinates / wavelength values, not the
        # resolved field and wavelength objects held by this analysis.
        fields_raw = [fp.coord for fp in self.fields]
        wavelengths_raw = [wp.value for wp in self.wavelengths]
        spot_diagram_at_focus = SpotDiagram(
            self.optic,
            fields=fields_raw,
            wavelengths=wavelengths_raw,
            num_rings=self.num_rings,
            distribution=self.distribution,
            coordinates=self.coordinates,
        )
        return spot_diagram_at_focus.data

    def view(
        self,
        fig_to_plot_on: Figure | None = None,
        figsize_per_plot: tuple[float, float] = (3, 3),
        buffer: float = 1.05,
        *,
        show: bool = True,
    ) -> tuple[Figure, list[Axes]] | None:
        """Visualizes the through-focus spot diagrams.

        Draws a grid with one row per field and one column per focal
        position, either in a new figure or on a provided (e.g. GUI) figure.
        Every subplot is centred on the primary-wavelength centroid of its
        spot and all subplots share the same axis limits.

        Args:
            fig_to_plot_on: A matplotlib figure to plot on. If None, a new
                figure is created.
            figsize_per_plot: Size of each subplot in inches
                (width, height). Only used when a new figure is created.
                Defaults to (3, 3).
            buffer: Scaling buffer applied to the maximum radius for axis
                limits. Defaults to 1.05.
            show (bool): If True (default) and a new figure was created,
                calls plt.show(). Set False for headless use.

        Returns:
            A tuple ``(figure, axes)`` with the figure and the list of its
            axes, in both standalone and embedded mode. Returns None when
            there is nothing to plot (no results, fields, wavelengths or
            defocus steps).
        """
        is_gui_embedding = fig_to_plot_on is not None

        if not self._validate_view_prerequisites():
            if is_gui_embedding:
                # Clear first so the message is not drawn over a stale plot.
                fig_to_plot_on.clear()
                fig_to_plot_on.text(
                    0.5, 0.5, "No data to display.", ha="center", va="center"
                )
                if hasattr(fig_to_plot_on, "canvas"):
                    fig_to_plot_on.canvas.draw_idle()
            return None

        num_fields = len(self.fields)
        # results[j] is plotted against positions[j], so size the grid from
        # the positions actually iterated below.
        num_cols = len(self.positions)

        if is_gui_embedding:
            current_fig = fig_to_plot_on
            current_fig.clear()
        else:
            current_fig = plt.figure(
                figsize=(
                    num_cols * figsize_per_plot[0],
                    num_fields * figsize_per_plot[1],
                )
            )
        axs = current_fig.subplots(
            num_fields, num_cols, sharex=True, sharey=True, squeeze=False
        )

        global_axis_limit = self._compute_global_axis_limit(buffer)
        x_label, y_label = self._get_plot_axis_labels()
        # Loop-invariant: the nominal focus is the same for every subplot.
        nominal_focus = be.to_numpy(self.nominal_focus).item()
        legend_handles, legend_labels = [], []

        for i, fp in enumerate(self.fields):
            field_coord = fp.coord
            for j, position in enumerate(self.positions):
                ax = axs[i, j]
                data = self.results[j][i]
                defocus = float(position) - nominal_focus
                centroid_x, centroid_y = self._get_spot_centroid(data)
                self._plot_wavelengths(
                    ax,
                    data,
                    centroid_x,
                    centroid_y,
                    i,
                    j,
                    legend_handles,
                    legend_labels,
                )
                self._configure_subplot(
                    ax,
                    field_coord,
                    defocus,
                    i,
                    j,
                    num_fields,
                    x_label,
                    y_label,
                    global_axis_limit,
                )

        self._add_legend(
            current_fig, legend_handles, legend_labels, num_fields, figsize_per_plot
        )
        # Leave room at the bottom for the figure-level legend.
        current_fig.tight_layout(rect=(0, 0.03, 1, 0.97))

        if is_gui_embedding and hasattr(current_fig, "canvas"):
            current_fig.canvas.draw_idle()
        if show and not is_gui_embedding:
            plt.show()
        return current_fig, current_fig.get_axes()

    def _validate_view_prerequisites(self) -> bool:
        """Validates prerequisites before plotting.

        Checks that results, fields and wavelengths are non-empty and that
        ``num_steps`` is non-zero, printing a message when they are not.

        Returns:
            True if plotting can proceed, False otherwise.
        """
        if not self.results:
            print("No data to display. Run analysis first.")
            return False
        if not self.fields or not self.wavelengths or self.num_steps == 0:
            print("No fields, defocus steps, or wavelengths to plot.")
            return False
        return True

    def _create_subplot_grid(
        self, num_fields: int, num_steps: int, figsize_per_plot: tuple[float, float]
    ) -> tuple[Figure, NDArray[np.object_]]:
        """Creates a new figure with a 2D grid of shared-axis subplots.

        Note: `view` builds its grid on a possibly user-supplied figure and
        does not call this helper.

        Args:
            num_fields: Number of rows (fields).
            num_steps: Number of columns (defocus steps).
            figsize_per_plot: Size per subplot in inches.

        Returns:
            tuple: (matplotlib.figure.Figure, 2D ndarray of Axes).
        """
        fig, axs = plt.subplots(
            num_fields,
            num_steps,
            figsize=(num_steps * figsize_per_plot[0], num_fields * figsize_per_plot[1]),
            sharex=True,
            sharey=True,
            squeeze=False,
        )
        return fig, axs

    def _get_plot_axis_labels(self) -> tuple[str, str]:
        """Determines axis labels based on image surface orientation.

        Uses U/V labels when the image surface's effective rotation about its
        first or second axis exceeds 0.01, and X/Y labels otherwise.

        Returns:
            tuple[str, str]: Labels for the X and Y axes.
        """
        cs = self.optic.image_surface.geometry.cs
        orientation = np.abs(be.to_numpy(cs.get_effective_rotation_euler()))
        tol = 0.01
        if orientation[0] > tol or orientation[1] > tol:
            return "U (mm)", "V (mm)"
        return "X (mm)", "Y (mm)"

    def _compute_global_axis_limit(self, buffer: float) -> float:
        """Computes a global axis limit for consistent plot scaling.

        Finds the largest distance of any lit ray (non-zero intensity) from
        the primary-wavelength centroid of its field, across all focal
        steps, fields and wavelengths.

        Args:
            buffer (float): Scaling buffer applied to the max radius.

        Returns:
            float: The maximum radius times ``buffer``, or 0.01 when no lit
            ray lies off-centroid.
        """
        max_r_sq = 0.0
        for data_at_step in self.results:
            for field_data in data_at_step:
                centroid_x, centroid_y = self._get_spot_centroid(field_data)
                for spot_data in field_data:
                    valid = spot_data.intensity != 0
                    if be.any(valid):
                        dx = spot_data.x - centroid_x
                        dy = spot_data.y - centroid_y
                        r_sq = dx[valid] ** 2 + dy[valid] ** 2
                        max_r_sq = max(max_r_sq, be.to_numpy(be.max(r_sq)).item())
        return np.sqrt(max_r_sq) * buffer if max_r_sq > 0 else 0.01

    def _primary_wavelength_index(self, num_available: int) -> int:
        """Returns the position of the primary wavelength in the spot data.

        Spot data is ordered like ``self.wavelengths``, which may be a subset
        of ``optic.wavelengths``. The optic's primary index therefore cannot
        be used directly, so the primary wavelength is located by value. If
        it was not analysed, this falls back to the optic's primary index
        clamped to the available range.

        Args:
            num_available (int): Number of wavelength entries in the spot data.

        Returns:
            int: Index into the per-wavelength spot data.
        """
        primary = float(self.optic.primary_wavelength)
        for k, wp in enumerate(self.wavelengths[:num_available]):
            if np.isclose(float(wp.value), primary):
                return k
        return min(self.optic.wavelengths.primary_index, num_available - 1)

    def _get_spot_centroid(self, field_data: list) -> tuple[float, float]:
        """Computes the centroid of the primary-wavelength spot.

        The centroid is the unweighted mean x/y position of the rays with
        non-zero intensity. If every ray has zero intensity, the centroid is
        (0.0, 0.0).

        Args:
            field_data (list): Spot data items for one field, one per analysed
                wavelength.

        Returns:
            tuple[float, float]: (x, y) centroid in mm.
        """
        spot = field_data[self._primary_wavelength_index(len(field_data))]
        nonzero = spot.intensity != 0
        if not be.any(nonzero):
            return 0.0, 0.0
        cx = be.to_numpy(be.mean(spot.x[nonzero])).item()
        cy = be.to_numpy(be.mean(spot.y[nonzero])).item()
        return cx, cy

    def _plot_wavelengths(
        self,
        ax: Axes,
        field_data: list,
        cx: float,
        cy: float,
        i: int,
        j: int,
        handles: list,
        labels: list,
    ):
        """Plots rays for all wavelengths, centered at the primary centroid.

        Each wavelength keeps the same marker and colour in every subplot.
        A legend entry is recorded the first time a wavelength is plotted
        anywhere in the grid, so wavelengths that are fully vignetted in the
        first subplot still appear in the legend.

        Args:
            ax (matplotlib.axes.Axes): Axis object to draw on.
            field_data (list): Spot data for one field at one defocus step.
            cx (float): Centroid x-coordinate.
            cy (float): Centroid y-coordinate.
            i (int): Field index (row). Retained for signature compatibility.
            j (int): Defocus step index (column). Retained for signature
                compatibility.
            handles (list): Accumulates legend handle objects (mutated).
            labels (list): Accumulates corresponding legend labels (mutated).
        """
        markers = ["o", "s", "^"]
        for k, spot in enumerate(field_data):
            lit = be.to_numpy(spot.intensity) != 0
            if not np.any(lit):
                continue
            x = be.to_numpy(spot.x - cx)
            y = be.to_numpy(spot.y - cy)
            scatter = ax.scatter(
                x[lit],
                y[lit],
                s=10,
                marker=markers[k % len(markers)],
                # Pin the colour to the wavelength index. Relying on the axes
                # colour cycle shifts colours whenever a wavelength is skipped.
                color=f"C{k}",
                alpha=0.7,
            )
            label = f"{self.wavelengths[k].value:.4f} µm"
            if label not in labels:
                handles.append(scatter)
                labels.append(label)

    def _configure_subplot(
        self,
        ax: Axes,
        field: tuple,
        defocus: float,
        i: int,
        j: int,
        num_fields: int,
        x_label: str,
        y_label: str,
        limit: float,
    ):
        """Applies titles, labels, and axis limits to a subplot.

        The defocus value is only shown in the titles of the top row. The
        x label is only added on the bottom row and the y label only on the
        left column.

        Args:
            ax (matplotlib.axes.Axes): Axis to configure.
            field (tuple): Field coordinates (x, y).
            defocus (float): Defocus amount in mm.
            i (int): Field index.
            j (int): Defocus step index.
            num_fields (int): Total number of fields.
            x_label (str): Label for x-axis.
            y_label (str): Label for y-axis.
            limit (float): Symmetric axis limit for both x and y.
        """
        ax.axis("square")
        ax.grid(alpha=0.25)
        title = f"Field: ({field[0]:.2f},{field[1]:.2f})"
        if i == 0:
            title = f"Defocus: {defocus:+.3f} mm\n{title}"
        ax.set_title(title, fontsize=10)
        if i == num_fields - 1:
            ax.set_xlabel(x_label)
        if j == 0:
            ax.set_ylabel(y_label)
        ax.set_xlim(-limit, limit)
        ax.set_ylim(-limit, limit)

    def _add_legend(
        self,
        fig: Figure,
        handles: list,
        labels: list,
        num_fields: int,
        figsize_per_plot: tuple[float, float],
    ):
        """Adds a wavelength legend below the plot grid.

        Does nothing when there are no handles. The vertical offset shrinks
        as the figure grows taller so the legend stays close to the grid.

        Args:
            fig (matplotlib.figure.Figure): Figure object.
            handles (list): Legend handles for plotted wavelengths.
            labels (list): Corresponding labels.
            num_fields (int): Number of fields (rows).
            figsize_per_plot (tuple): Subplot size in inches.
        """
        if handles:
            fig.legend(
                handles,
                labels,
                loc="lower center",
                ncol=min(5, len(labels)),
                bbox_to_anchor=(0.5, -0.02 / (figsize_per_plot[1] * num_fields / 4)),
            )