"""Spot Diagram Analysis
This module provides the core spot diagram analysis for optical systems,
including data generation, centering, and metrics calculation.
Kramer Harrison, 2024
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal
import optiland.backend as be
from optiland.utils import resolve_fields
from optiland.visualization.system.utils import transform
from ..base import BaseAnalysis
from .plotting import (
calculate_axis_limits,
finalize_plot,
handle_no_fields,
plot_field,
setup_plot_layout,
)
from .reference import SpotReferenceType, create_reference_strategy
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from optiland._types import BEArray, DistributionType
[docs]
@dataclass
class SpotData:
"""Stores the x, y coordinates and intensity of a spot.
Attributes:
x: Array of x-coordinates.
y: Array of y-coordinates.
intensity: Array of intensity values.
"""
x: be.array
y: be.array
intensity: be.array
[docs]
class SpotDiagram(BaseAnalysis):
"""Generates and plots real ray intersection data on the image surface.
This class creates spot diagrams, which are purely geometric plots that give
an indication of the blur produced by aberrations in an optical system.
Attributes:
optic: Instance of the optic object to be assessed.
fields: Fields at which data is generated.
wavelengths: Wavelengths at which data is generated.
num_rings: Number of rings in the pupil distribution for ray tracing.
distribution: The pupil distribution type for ray tracing.
data: Contains spot data in a nested list, ordered by field, then
wavelength.
coordinates: The coordinate system ('global' or 'local') for data and
plotting.
reference: The reference point type used for centering spots.
"""
def __init__(
self,
optic,
fields: str | list = "all",
wavelengths: str | list = "all",
num_rings: int = 6,
distribution: DistributionType = "hexapolar",
coordinates: Literal["global", "local"] = "local",
reference: str | SpotReferenceType = SpotReferenceType.CHIEF_RAY,
):
"""Initializes the SpotDiagram analysis.
Note:
The constructor generates all data that is later used for plotting.
Args:
optic: An instance of the optic object to be assessed.
fields: Fields at which to generate data. If 'all', all defined
field points are used. Defaults to "all".
wavelengths: Wavelengths at which to generate data. If 'all', all
defined wavelengths are used. Defaults to "all".
num_rings: Number of rings in the pupil distribution for ray
tracing. Defaults to 6.
distribution: Pupil distribution type for ray tracing.
Defaults to "hexapolar".
coordinates: Coordinate system for data generation and plotting.
Defaults to "local".
reference: Reference point type for centering spots. Can be
"chief_ray" or "centroid". Defaults to "chief_ray".
Raises:
ValueError: If `coordinates` is not 'global' or 'local'.
ValueError: If `reference` is not a valid SpotReferenceType.
"""
self.fields = resolve_fields(optic, fields) # list[FieldPoint]
if coordinates not in ["global", "local"]:
raise ValueError("Coordinates must be 'global' or 'local'.")
self.coordinates = coordinates
self.num_rings = num_rings
self.distribution: DistributionType = distribution
self._reference_strategy = create_reference_strategy(reference)
super().__init__(optic, wavelengths)
primary_wl_value = self.optic.primary_wavelength
wl_values = [wp.value for wp in self.wavelengths]
if primary_wl_value in wl_values:
self._analysis_ref_wavelength_index = wl_values.index(primary_wl_value)
else:
self._analysis_ref_wavelength_index = 0
[docs]
def view(
self,
fig_to_plot_on: Figure | None = None,
figsize: tuple[float, float] = (12, 4),
add_airy_disk: bool = False,
) -> tuple[Figure, list[Axes]]:
"""Displays the spot diagram plot.
Args:
fig_to_plot_on: An existing Matplotlib figure to plot on. If None,
a new figure is created. Defaults to None.
figsize: The figure size for the output window, applied per row.
Defaults to (12, 4).
add_airy_disk: If True, adds the Airy disk visualization to the
plots. Defaults to False.
Returns:
A tuple containing the Matplotlib figure and a list of its axes.
"""
if not self.fields:
return handle_no_fields(fig_to_plot_on)
centered_data = self._center_spots(self.data)
airy_disk_data = self._prepare_airy_disk_data() if add_airy_disk else None
fig, axs = setup_plot_layout(len(self.fields), fig_to_plot_on, figsize)
axis_lim = calculate_axis_limits(centered_data, self.fields, airy_disk_data)
for i, field_data in enumerate(centered_data):
if i >= len(axs):
break
plot_field(
axs[i],
field_data,
self.wavelengths,
self.fields[i].coord,
axis_lim,
i,
self.optic.image_surface,
airy_disk_data,
)
finalize_plot(fig, axs, len(self.fields), self.wavelengths)
return fig, fig.get_axes()
# --- Calculation Methods ---
[docs]
def angle_from_cosine(self, a: BEArray, b: BEArray) -> float:
"""Calculates the angle in radians between two direction cosine vectors.
Args:
a: The first direction cosine vector.
b: The second direction cosine vector.
Returns:
The angle between the vectors in radians.
"""
a = a / be.linalg.norm(a)
b = b / be.linalg.norm(b)
return be.arccos(be.clip(be.dot(a, b), -1, 1))
[docs]
def f_number(self, n: float, theta: float) -> float:
"""Calculates the physical F-number.
Args:
n: The refractive index of the medium.
theta: The half-angle of the cone of light in radians.
Returns:
The calculated physical F-number.
"""
return 1 / (2 * n * be.sin(theta))
[docs]
def airy_radius(self, n_w: float, wavelength: float) -> float:
"""Calculates the Airy disk radius.
Args:
n_w: The physical F-number.
wavelength: The wavelength of light in micrometers.
Returns:
The Airy disk radius.
"""
return 1.22 * n_w * wavelength
[docs]
def generate_marginal_rays(
self, H_x: float, H_y: float, wavelength: float
) -> tuple:
"""Generates marginal rays at the four cardinal points of the pupil.
Args:
H_x: The x-field coordinate.
H_y: The y-field coordinate.
wavelength: The wavelength for the rays.
Returns:
A tuple containing the traced rays for north, south, east, and west
pupil points.
"""
ray_north = self.optic.trace_generic(
Hx=H_x, Hy=H_y, Px=0, Py=1, wavelength=wavelength
)
ray_south = self.optic.trace_generic(
Hx=H_x, Hy=H_y, Px=0, Py=-1, wavelength=wavelength
)
ray_east = self.optic.trace_generic(
Hx=H_x, Hy=H_y, Px=1, Py=0, wavelength=wavelength
)
ray_west = self.optic.trace_generic(
Hx=H_x, Hy=H_y, Px=-1, Py=0, wavelength=wavelength
)
return ray_north, ray_south, ray_east, ray_west
[docs]
def generate_marginal_rays_cosines(
self, H_x: float, H_y: float, wavelength: float
) -> tuple:
"""Generates direction cosines for each marginal ray of a given field.
Args:
H_x: The x-field coordinate.
H_y: The y-field coordinate.
wavelength: The wavelength for the rays.
Returns:
A tuple of direction cosine vectors for north, south, east, and
west rays.
"""
rays = self.generate_marginal_rays(H_x, H_y, wavelength)
return tuple(be.array([ray.L, ray.M, ray.N]).ravel() for ray in rays)
[docs]
def generate_chief_rays_cosines(self, wavelength: float) -> BEArray:
"""Generates direction cosines for the chief ray of each field.
Args:
wavelength: The wavelength for the rays.
Returns:
An array of shape (num_fields, 3) containing the direction cosines.
"""
cosines = [
be.array([ray.L, ray.M, ray.N]).ravel()
for fp in self.fields
for ray in [
self.optic.trace_generic(
Hx=fp.coord[0], Hy=fp.coord[1], Px=0, Py=0, wavelength=wavelength
)
]
]
return be.stack(cosines, axis=0)
[docs]
def generate_chief_rays_centers(self, wavelength: float) -> BEArray:
"""Generates the (x, y) intersection points for the chief ray of each
field.
Args:
wavelength: The wavelength for the rays.
Returns:
An array of shape (num_fields, 2) with (x, y) coordinates.
"""
centers = [
[ray.x.item(), ray.y.item()]
for fp in self.fields
for ray in [
self.optic.trace_generic(
Hx=fp.coord[0], Hy=fp.coord[1], Px=0, Py=0, wavelength=wavelength
)
]
]
return be.stack(centers, axis=0)
[docs]
def airy_disc_x_y(self, wavelength: float) -> tuple[list[float], list[float]]:
"""Generates the Airy disk radii for the x and y axes for each field.
Args:
wavelength: The wavelength for the calculation.
Returns:
A tuple of two lists: x-axis radii and y-axis radii per field.
"""
chief_cosines = self.generate_chief_rays_cosines(wavelength)
airy_rad_x_list, airy_rad_y_list = [], []
for i, fp in enumerate(self.fields):
H_x, H_y = fp.coord
north, south, east, west = self.generate_marginal_rays_cosines(
H_x, H_y, wavelength
)
chief = chief_cosines[i]
angle_x = (
self.angle_from_cosine(chief, north)
+ self.angle_from_cosine(chief, south)
) / 2
angle_y = (
self.angle_from_cosine(chief, east)
+ self.angle_from_cosine(chief, west)
) / 2
f_num_x = self.f_number(n=1, theta=angle_x)
f_num_y = self.f_number(n=1, theta=angle_y)
# Convert radius from µm to mm
airy_rad_x_list.append(self.airy_radius(f_num_x, wavelength) * 1e-3)
airy_rad_y_list.append(self.airy_radius(f_num_y, wavelength) * 1e-3)
return airy_rad_x_list, airy_rad_y_list
[docs]
def centroid(self) -> list[tuple[BEArray, BEArray]]:
"""Calculates the geometric centroid of each spot for the reference
wavelength.
Returns:
A list of (x, y) centroid coordinates for each field.
"""
ref_idx = self._analysis_ref_wavelength_index
return [
(be.mean(field_data[ref_idx].x), be.mean(field_data[ref_idx].y))
for field_data in self.data
]
[docs]
def geometric_spot_radius(self) -> list[list[BEArray]]:
"""Calculates the maximum geometric spot radius for each spot.
Returns:
A nested list of maximum radii for each field and wavelength.
"""
centered_data = self._center_spots(self.data)
return [
[
be.max(be.sqrt(wave_data.x**2 + wave_data.y**2))
for wave_data in field_data
]
for field_data in centered_data
]
[docs]
def rms_spot_radius(self) -> list[list[BEArray]]:
"""Calculates the root-mean-square (RMS) spot radius for each spot.
Returns:
A nested list of RMS radii for each field and wavelength.
"""
centered_data = self._center_spots(self.data)
return [
[
be.sqrt(be.mean(wave_data.x**2 + wave_data.y**2))
for wave_data in field_data
]
for field_data in centered_data
]
# --- Internal Data Generation Helpers ---
def _get_reference_centers(
self, data: list[list[SpotData]]
) -> list[tuple[BEArray, BEArray]]:
"""Computes the reference centers using the configured strategy.
Args:
data: The spot data to compute centers for.
Returns:
A list of (x, y) center tuples, one per field.
"""
ref_wl = self.wavelengths[self._analysis_ref_wavelength_index].value
fields_coords = [fp.coord for fp in self.fields]
return self._reference_strategy.get_centers(
data,
self._analysis_ref_wavelength_index,
self.optic,
fields_coords,
ref_wl,
self.coordinates,
)
def _center_spots(self, data: list[list[SpotData]]) -> list[list[SpotData]]:
"""Centers spot data around the configured reference point.
Args:
data: The original, uncentered spot data.
Returns:
A deep copy of the data, centered around the reference points.
"""
centers = self._get_reference_centers(data)
centered_data = []
for i, field_list in enumerate(data):
cx, cy = centers[i]
centered_field = [
SpotData(
x=sd.x - cx,
y=sd.y - cy,
intensity=be.copy(sd.intensity),
)
for sd in field_list
]
centered_data.append(centered_field)
return centered_data
def _generate_data(self) -> list[list[SpotData]]:
"""Generates spot data for all configured fields and wavelengths.
Returns:
A nested list of spot intersection data.
"""
return [
[
self._generate_field_data(
fp.coord,
wp.value,
self.num_rings,
self.distribution,
self.coordinates,
)
for wp in self.wavelengths
]
for fp in self.fields
]
def _generate_field_data(
self,
field: tuple[float, float],
wavelength: float,
num_rays: int,
distribution: DistributionType,
coordinates: str,
) -> SpotData:
"""Generates spot data for a single field and wavelength.
Args:
field: The (Hx, Hy) field coordinates.
wavelength: The wavelength for tracing.
num_rays: The number of rays to generate, or number of rings if
distribution is hexapolar.
distribution: The ray distribution pattern.
coordinates: The coordinate system ('local' or 'global').
Returns:
A SpotData object with the traced ray intersection data.
"""
self.optic.trace(*field, wavelength, num_rays, distribution)
surf_group = self.optic.surfaces
x_g, y_g, z_g, i_g = (
surf_group.x[-1, :],
surf_group.y[-1, :],
surf_group.z[-1, :],
surf_group.intensity[-1, :],
)
# Ignore rays with zero intensity
mask = i_g > 0
x_g, y_g, z_g, i_g = x_g[mask], y_g[mask], z_g[mask], i_g[mask]
if coordinates == "local":
x_plot, y_plot, _ = transform(
x_g, y_g, z_g, self.optic.image_surface, is_global=True
)
else:
x_plot, y_plot = x_g, y_g
return SpotData(x=x_plot, y=y_plot, intensity=i_g)
def _prepare_airy_disk_data(self) -> dict:
"""Prepares all necessary data for plotting the Airy disk.
The Airy disk position is determined by the configured reference
strategy: when using chief ray centering, the Airy disk sits at (0,0);
when using centroid centering, it is offset by the difference between
the chief ray and centroid positions.
Returns:
A dictionary containing Airy disk radii and center coordinates
relative to the reference point.
"""
primary_wl_obj = self.optic.wavelengths.primary_wavelength
wl_val = primary_wl_obj.value if primary_wl_obj else self.wavelengths[0]
airy_rad_x, airy_rad_y = self.airy_disc_x_y(wavelength=wl_val)
chief_centers = self.generate_chief_rays_centers(wavelength=wl_val)
reference_centers = self._get_reference_centers(self.data)
# Airy disk is physically at the chief ray position. Compute its
# offset relative to whichever reference was used for centering.
airy_centers = be.to_numpy(chief_centers) - be.to_numpy(
be.stack(reference_centers)
)
return {
"radii_x": be.to_numpy(airy_rad_x),
"radii_y": be.to_numpy(airy_rad_y),
"airy_centers": airy_centers,
}