"""Base PSF Module
This module provides a base class for Point Spread Function (PSF) calculations.
Kramer Harrison, 2025
"""
from __future__ import annotations
from abc import abstractmethod
from typing import TYPE_CHECKING
from warnings import warn
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
from matplotlib.colors import LogNorm
from scipy.ndimage import zoom
import optiland.backend as be
from optiland.utils import get_working_FNO, resolve_wavelength
from optiland.wavefront import Wavefront
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from mpl_toolkits.mplot3d import Axes3D
from optiland.fields import Field
from optiland.optic import Optic
[docs]
def replace_nonpositive(image, min_value=1e-9):
    """Return ``image`` with every entry <= 0 replaced by a positive fill value.

    The fill value is the smallest strictly positive entry of ``image``. If the
    image contains no positive entry at all, ``min_value`` is used instead.

    Args:
        image: Array (backend or numpy) to process.
        min_value: Fill value used when ``image`` has no positive entries
            (default: 1e-9).

    Returns:
        A new array with the non-positive entries replaced, or ``image``
        itself (the same object) when it has no non-positive entries.
    """
    nonpositive = image <= 0
    if not be.any(nonpositive):
        # Nothing to replace: hand back the input untouched.
        return image

    positive = image > 0
    fill_value = be.min(image[positive]) if be.any(positive) else min_value
    return be.where(nonpositive, fill_value, image)
[docs]
class BasePSF(Wavefront):
    """Base class for Point Spread Function (PSF) calculations.

    Args:
        optic (Optic): The optical system.
        field (tuple): The field as (x, y) at which to compute the PSF.
        wavelength (str | float): The wavelength of light. Can be 'primary' or a
            float value.
        num_rays (int, optional): The number of rays used for wavefront
            computation. Defaults to 128.
        strategy (str): The calculation strategy to use. Supported options are
            "chief_ray", "centroid_sphere", and "best_fit_sphere".
            Defaults to "chief_ray".
        remove_tilt (bool): If True, removes tilt and piston from the OPD data.
            Defaults to True.
        **kwargs: Additional keyword arguments passed to the strategy.

    Attributes:
        psf (ndarray): The computed PSF. This should be set by subclasses.

    Methods:
        view(fig_to_plot_on=None, projection='2d', log=False, figsize=(7, 5.5),
            threshold=0.05, num_points=128): Visualizes the PSF.
    """

    def __init__(
        self,
        optic: Optic,
        field: Field,
        wavelength: str | float,
        num_rays=128,
        strategy="chief_ray",
        remove_tilt=True,
        **kwargs,
    ):
        # Resolve 'primary' (or a numeric value) to a concrete wavelength
        # before handing a single field/wavelength pair to the Wavefront base.
        resolved_wavelength = resolve_wavelength(optic, wavelength)
        super().__init__(
            optic=optic,
            fields=[field],
            wavelengths=[resolved_wavelength],
            num_rays=num_rays,
            distribution="uniform",
            strategy=strategy,
            remove_tilt=remove_tilt,
            **kwargs,
        )
        self.psf = None  # Subclasses must compute and set this

    def view(
        self,
        fig_to_plot_on: Figure | None = None,
        projection: str = "2d",
        log: bool = False,
        figsize: tuple = (7, 5.5),
        threshold: float = 0.05,
        num_points: int = 128,
    ) -> tuple[Figure, Axes]:
        """Visualizes the PSF.

        Args:
            fig_to_plot_on (Figure | None, optional): An existing figure to
                draw into (e.g. one embedded in a GUI). It is cleared before
                plotting and redrawn via ``canvas.draw_idle()`` afterwards. If
                None, a new figure is created with ``plt.subplots``.
                Defaults to None.
            projection (str, optional): The projection type. Can be '2d' or
                '3d'. Defaults to '2d'.
            log (bool, optional): Whether to use a logarithmic scale for the
                intensity. Defaults to False.
            figsize (tuple, optional): The size of a newly created figure.
                Ignored when ``fig_to_plot_on`` is given. Defaults to (7, 5.5).
            threshold (float, optional): The threshold for determining the
                bounds of the PSF for zoomed view. Defaults to 0.05.
            num_points (int, optional): The number of points used for
                interpolating the PSF for smoother visualization. Defaults to 128.

        Returns:
            tuple: A tuple containing the figure and axes objects.

        Raises:
            RuntimeError: If the PSF has not been computed by the subclass.
            ValueError: If the projection is not '2d' or '3d'.
        """
        if self.psf is None:
            raise RuntimeError(
                "PSF has not been computed. Call _compute_psf in subclass."
            )
        # Validate before touching any figure, so an invalid projection neither
        # clears a caller-supplied figure nor leaves a stray pyplot figure open.
        if projection not in ("2d", "3d"):
            raise ValueError('Projection must be "2d" or "3d".')

        subplot_kw = {"projection": "3d"} if projection == "3d" else {}
        is_gui_embedding = fig_to_plot_on is not None
        if is_gui_embedding:
            current_fig = fig_to_plot_on
            current_fig.clear()
            ax = current_fig.add_subplot(111, **subplot_kw)
        else:
            current_fig, ax = plt.subplots(subplot_kw=subplot_kw, figsize=figsize)

        psf_np = be.to_numpy(self.psf)
        min_x, min_y, max_x, max_y = self._find_bounds(psf_np, threshold)
        psf_zoomed = psf_np[min_x:max_x, min_y:max_y]

        oversampling_factor = num_points / psf_zoomed.shape[0]
        if oversampling_factor > 3:
            message = (
                f"The PSF view has a high oversampling factor "
                f"({oversampling_factor:.2f}). Results may be inaccurate."
            )
            warn(message, stacklevel=2)

        # Subclasses should implement _get_psf_units if they want physical units
        # otherwise, pixel units are used.
        if hasattr(self, "_get_psf_units"):
            x_extent, y_extent = self._get_psf_units(psf_zoomed)
            x_label, y_label = "X (µm)", "Y (µm)"
        else:
            # Default to pixel units if not implemented by subclass
            x_extent = psf_zoomed.shape[1]
            y_extent = psf_zoomed.shape[0]
            x_label, y_label = "X (pixels)", "Y (pixels)"

        psf_smooth = self._interpolate_psf(psf_zoomed, num_points)

        plot = self._plot_2d if projection == "2d" else self._plot_3d
        plot(
            current_fig,
            ax,
            psf_smooth,
            log,
            x_extent,
            y_extent,
            figsize,
            x_label,
            y_label,
            psf_zoomed.shape,
        )

        # A caller-supplied (GUI) figure is not managed by pyplot, so it has to
        # be asked to repaint explicitly.
        if is_gui_embedding and hasattr(current_fig, "canvas"):
            current_fig.canvas.draw_idle()
        return current_fig, ax

    def _plot_2d(
        self,
        fig: Figure,
        ax: Axes,
        image: np.ndarray,
        log: bool,
        x_extent: float,
        y_extent: float,
        figsize: tuple,
        x_label: str,
        y_label: str,
        original_size: tuple,
    ) -> None:
        """Plots the PSF in 2D.

        Args:
            fig (Figure): The figure that owns ``ax``. The colorbar and the
                original-size annotation are added to it.
            ax (Axes): The 2D axes to draw into.
            image (numpy.ndarray): The 2D image of the PSF to plot. It is passed
                through ``be.to_numpy`` before display.
            log (bool): If True, apply logarithmic normalization to the image.
            x_extent (float): The full extent of the x-axis, centered on 0.
            y_extent (float): The full extent of the y-axis, centered on 0.
            figsize (tuple): Unused here; kept for signature compatibility.
            x_label (str): Label for the x-axis.
            y_label (str): Label for the y-axis.
            original_size (tuple): The original size of the PSF image before
                interpolation.
        """
        norm = LogNorm() if log else None
        # Replace values <= 0 with smallest non-zero value in image for log scale
        # (LogNorm cannot map them; cubic interpolation may produce small
        # negative values).
        if log and be.any(image <= 0):
            image = replace_nonpositive(image)
        extent = [-x_extent / 2, x_extent / 2, -y_extent / 2, y_extent / 2]
        im = ax.imshow(be.to_numpy(image), norm=norm, extent=extent, origin="lower")
        self._annotate_original_size(fig, original_size)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_title(f"{self.__class__.__name__.replace('PSF', ' PSF')}")
        # Attach the colorbar to ``fig`` explicitly. plt.colorbar goes through
        # pyplot's current figure, which is the wrong figure (or a newly created
        # one) when drawing into a caller-supplied GUI figure.
        cbar = fig.colorbar(im, ax=ax)
        cbar.ax.get_yaxis().labelpad = 15
        cbar.ax.set_ylabel("Relative Intensity (%)", rotation=270)

    def _plot_3d(
        self,
        fig: Figure,
        ax: Axes3D,
        image: np.ndarray,
        log: bool,
        x_extent: float,
        y_extent: float,
        figsize: tuple,
        x_label: str,
        y_label: str,
        original_size: tuple,
    ) -> None:
        """Plots the PSF in 3D.

        Args:
            fig (Figure): The figure that owns ``ax``.
            ax (Axes3D): The 3D axes to draw into.
            image (numpy.ndarray): The PSF image data.
            log (bool): Whether to apply logarithmic scaling to the image. When
                True, the surface height is log10 of the intensity and the
                z-axis/colorbar tick labels are formatted back to linear values.
            x_extent (float): The full extent of the x-axis, centered on 0.
            y_extent (float): The full extent of the y-axis, centered on 0.
            figsize (tuple): Unused here; kept for signature compatibility.
            x_label (str): Label for the x-axis.
            y_label (str): Label for the y-axis.
            original_size (tuple): The original size of the PSF image before
                interpolation.
        """
        x_np = be.to_numpy(be.linspace(-x_extent / 2, x_extent / 2, image.shape[1]))
        y_np = be.to_numpy(be.linspace(-y_extent / 2, y_extent / 2, image.shape[0]))
        X_np, Y_np = np.meshgrid(x_np, y_np)

        # Replace values <= 0 with smallest non-zero value in image for log scale
        if log and be.any(image <= 0):
            image = replace_nonpositive(image)
        image_np = be.to_numpy(image)

        log_formatter = None
        if log:
            image_plot = np.log10(image_np)
            formatter = mticker.FuncFormatter(self._log_tick_formatter)
            ax.zaxis.set_major_formatter(formatter)
            ax.zaxis.set_major_locator(mticker.MaxNLocator(integer=True))
            log_formatter = self._log_colorbar_formatter
        else:
            image_plot = image_np

        surf = ax.plot_surface(
            X_np,
            Y_np,
            image_plot,
            rstride=1,
            cstride=1,
            cmap="viridis",
            linewidth=0,
            antialiased=False,
        )
        self._annotate_original_size(fig, original_size)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_zlabel("Relative Intensity (%)")
        ax.set_title(f"{self.__class__.__name__.replace('PSF', ' PSF')}")
        fig.colorbar(surf, ax=ax, shrink=0.5, aspect=10, pad=0.15, format=log_formatter)
        fig.tight_layout()

    def _log_tick_formatter(self, value, pos=None):
        """Formats tick labels for a logarithmic scale (Z-axis in 3D plot)."""
        return f"$10^{{{int(value)}}}$"

    def _log_colorbar_formatter(self, value, pos=None):
        """Formats tick labels for a logarithmic colorbar."""
        linear_value = 10**value
        return f"{linear_value:.1e}"

    def _annotate_original_size(self, fig: Figure, original_size):
        """Annotates the original size of the zoomed PSF in the bottom right corner."""
        text = f"Original Size: {original_size[0]}×{original_size[1]}"
        fig.text(
            0.99,
            0.01,
            text,
            transform=fig.transFigure,
            fontsize=10,
            verticalalignment="bottom",
            horizontalalignment="right",
            bbox=dict(facecolor="white", alpha=0.8, edgecolor="none"),
        )

    def _interpolate_psf(self, image, n=128):
        """Interpolates the PSF for visualization.

        Uses scipy.ndimage.zoom (cubic spline, ``order=3``) for interpolation.
        The input is converted to NumPy because zoom requires a NumPy array.
        A single zoom factor, derived from the first axis, is applied to both
        axes.

        Args:
            image (ndarray): The input image (can be backend array).
            n (int, optional): The target number of points along the first
                axis of the interpolated grid. Defaults to 128.

        Returns:
            ndarray: The interpolated PSF grid (backend array), or ``image``
            itself when no zoom is needed.
        """
        image_np = be.to_numpy(image)
        zoom_factor = n / image_np.shape[0]
        if zoom_factor == 1:
            return image  # Return original backend array if no zoom
        interpolated_np = zoom(image_np, zoom_factor, order=3)
        return be.array(interpolated_np)

    @staticmethod
    def _find_bounds(psf, threshold=0.25):
        """Finds a square crop window that contains the significant PSF region.

        The window size is the larger of the row and column spans of the
        elements strictly greater than ``threshold``. The window is centered
        on the array center (``shape // 2``), not on the brightest pixel, and
        is clipped to the array bounds.

        Args:
            psf (numpy.ndarray): The 2D PSF matrix.
            threshold (float): Elements strictly greater than this value are
                considered part of the PSF. Default is 0.25.

        Returns:
            tuple: ``(min_x, min_y, max_x, max_y)`` as ints, where ``x``
            indexes axis 0 and ``y`` indexes axis 1. The maxima are exclusive,
            so the result can be used directly as
            ``psf[min_x:max_x, min_y:max_y]``. The window is never empty for a
            non-empty ``psf``. If no element exceeds ``threshold``, the full
            array bounds are returned.
        """
        above = np.argwhere(psf > threshold)
        if above.size == 0:
            return 0, 0, int(psf.shape[0]), int(psf.shape[1])

        row_min, col_min = above.min(axis=0)
        row_max, col_max = above.max(axis=0)
        # argwhere indices are inclusive, so the span is (max - min + 1). This
        # also guarantees a non-empty window when a single element qualifies;
        # an empty window would make view() divide by zero.
        size = int(max(row_max - row_min, col_max - col_min)) + 1
        half = size // 2

        center_x, center_y = psf.shape[0] // 2, psf.shape[1] // 2
        min_x = max(0, center_x - half)
        min_y = max(0, center_y - half)
        max_x = min(psf.shape[0], center_x - half + size)
        max_y = min(psf.shape[1], center_y - half + size)
        return int(min_x), int(min_y), int(max_x), int(max_y)

    @abstractmethod
    def _compute_psf(self):
        """Computes the PSF.

        This method must be implemented by subclasses.
        It should calculate the PSF and store it in self.psf.

        NOTE(review): ``@abstractmethod`` is only enforced if the metaclass of
        ``Wavefront`` is ``ABCMeta``; otherwise the raise below is the guard.
        """
        raise NotImplementedError("Subclasses must implement _compute_psf.")

    def strehl_ratio(self):
        """Computes the Strehl ratio of the PSF.

        The Strehl ratio is the ratio of the peak intensity of the aberrated
        PSF to the peak intensity of the diffraction-limited PSF.

        This method reads the value at the array center (``shape // 2`` on
        each axis) and divides it by 100. It therefore assumes that
        ``self.psf`` is centered in the array and normalized so that the
        diffraction-limited peak would be 100 (i.e. percent units).

        Returns:
            float: The Strehl ratio.

        Raises:
            RuntimeError: If the PSF has not been computed.
        """
        if self.psf is None:
            raise RuntimeError("PSF has not been computed.")
        center_x = self.psf.shape[0] // 2
        center_y = self.psf.shape[1] // 2
        strehl = self.psf[center_x, center_y] / 100
        return float(be.to_numpy(strehl).item())

    def _get_working_FNO(self):
        """Calculates the working F-number of the optical system for the
        single defined field point and given wavelength.

        Delegates to ``optiland.utils.get_working_FNO`` with this PSF's field
        coordinates and wavelength value. The algorithm performed there:

        1. Retrieve the defined given wavelength and field coordinates.
        2. Determine the image-space refractive index 'n' at the given wavelength.
        3. Trace four marginal rays (top, bottom, left, right) at the pupil edges,
           as well as the chief ray.
        4. Compute the angle between each marginal ray and the chief ray.
        5. Calculate the average of the squared numerical apertures from all
           traced marginal rays.
        6. Compute the working F-number as 1 / (2 * sqrt(average_NA_squared)).
        7. Cap the calculated F/# at 10,000 if it exceeds this value.

        Returns:
            float: The working F-number.
        """
        return get_working_FNO(
            self.optic, self.fields[0].coord, self.wavelengths[0].value
        )