Source code for optiland.wavefront.opd

"""
This module defines the OPD class.

Kramer Harrison, 2024
"""

from __future__ import annotations

from typing import TYPE_CHECKING, TypedDict, cast

import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import griddata

import optiland.backend as be
from optiland.utils import resolve_wavelength

from .wavefront import Wavefront

if TYPE_CHECKING:
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure
    from mpl_toolkits.mplot3d import Axes3D
    from numpy.typing import NDArray

    from optiland._types import DistributionType, PlotProjection
    from optiland.optic.optic import Optic
    from optiland.wavefront.strategy import WavefrontStrategyType


class OPDData(TypedDict):
    """Typed dictionary holding an interpolated OPD map.

    This is the return type of ``OPD.generate_opd_map``, which fills all
    three keys with NumPy arrays produced on the same ``np.meshgrid`` grid.
    """

    # Pupil X coordinates of the interpolation grid.
    x: NDArray
    # Pupil Y coordinates of the interpolation grid.
    y: NDArray
    # Interpolated map values at each (x, y) grid node.
    z: NDArray
class OPD(Wavefront):
    """Represents an Optical Path Difference (OPD) wavefront.

    The class is a single-field, single-wavelength specialization of
    ``Wavefront``: the constructor wraps ``field`` and the resolved
    ``wavelength`` in one-element lists before delegating to the parent.

    Args:
        optic (Optic): The optic object.
        field (tuple[float, float]): The field at which to calculate the
            OPD.
        wavelength (str | float): The wavelength of the wavefront. Resolved
            through ``resolve_wavelength``; can be 'primary' or a float
            value.
        num_rays (int, optional): Pupil sampling density forwarded to
            ``Wavefront`` (number of rays, or rings for a hexapolar
            distribution). Defaults to 15.
        distribution (DistributionType, optional): The pupil sampling
            distribution forwarded to ``Wavefront``. Defaults to
            "hexapolar".
        strategy (str): The calculation strategy to use. Supported options
            are "chief_ray", "centroid", and "best_fit". Defaults to
            "chief_ray".
        remove_tilt (bool): If True, removes tilt and piston from the OPD
            data. Defaults to False.
        **kwargs: Additional keyword arguments forwarded unchanged to
            ``Wavefront.__init__``.
            NOTE(review): earlier documentation listed an ``afocal`` flag;
            it is not an explicit parameter of this constructor, so it can
            only be supplied through ``**kwargs`` -- confirm that
            ``Wavefront`` accepts it.

    Attributes (inherited from ``Wavefront`` and read by this class):
        fields: Sequence whose first entry is the field used here.
        wavelengths: Sequence whose first entry is the wavelength used here.
        distribution: Pupil sampling distribution exposing ``x`` and ``y``.

    Methods:
        view(fig_to_plot_on=None, projection='2d', num_points=256,
            figsize=(7, 5.5)): Visualizes the OPD wavefront.
        rms(): Calculates the root mean square (RMS) of the OPD wavefront.
        generate_opd_map(num_points=256): Interpolates the OPD samples onto
            a regular grid.
    """

    def __init__(
        self,
        optic: Optic,
        field: tuple[float, float],
        wavelength: str | float,
        num_rays: int = 15,
        distribution: DistributionType = "hexapolar",
        strategy: WavefrontStrategyType = "chief_ray",
        remove_tilt: bool = False,
        **kwargs,
    ) -> None:
        resolved_wavelength = resolve_wavelength(optic, wavelength)
        super().__init__(
            optic,
            fields=[field],
            wavelengths=[resolved_wavelength],
            num_rays=num_rays,
            distribution=distribution,
            strategy=strategy,
            remove_tilt=remove_tilt,
            **kwargs,
        )

    def view(
        self,
        fig_to_plot_on: Figure | None = None,
        projection: PlotProjection = "2d",
        num_points: int = 256,
        figsize: tuple[float, float] = (7, 5.5),
    ) -> tuple[Figure, Axes]:
        """Visualizes the OPD wavefront.

        Args:
            fig_to_plot_on (plt.Figure, optional): The figure to plot on.
                If given, it is cleared and redrawn; if None, a new figure
                is created.
            projection (str, optional): The projection type, '2d' or '3d'.
                Defaults to '2d'.
            num_points (int, optional): The number of points for
                interpolation. Defaults to 256.
            figsize (tuple, optional): The figure size, used only when a
                new figure is created. Defaults to (7, 5.5).

        Returns:
            tuple: A tuple containing the figure and axes objects.

        Raises:
            ValueError: If the projection is not '2d' or '3d'.
        """
        # Validate before touching any figure: otherwise an invalid
        # projection would clear the caller's figure (or leave a freshly
        # created pyplot figure open) and run the interpolation for nothing.
        if projection not in ("2d", "3d"):
            raise ValueError('OPD projection must be "2d" or "3d".')

        is_3d = projection == "3d"
        is_gui_embedding = fig_to_plot_on is not None

        if is_gui_embedding:
            current_fig = cast("Figure", fig_to_plot_on)
            current_fig.clear()
            ax = (
                current_fig.add_subplot(111, projection="3d")
                if is_3d
                else current_fig.add_subplot(111)
            )
        else:
            current_fig, ax = (
                plt.subplots(figsize=figsize, subplot_kw={"projection": "3d"})
                if is_3d
                else plt.subplots(figsize=figsize)
            )

        opd_map = self.generate_opd_map(num_points)

        if is_3d:
            self._plot_3d(fig=current_fig, ax=ax, data=opd_map)
        else:
            self._plot_2d(data=opd_map, ax=ax)

        # An embedded figure is owned by the caller's canvas, so request a
        # redraw instead of relying on pyplot.
        if is_gui_embedding and hasattr(current_fig, "canvas"):
            current_fig.canvas.draw_idle()
        return current_fig, ax

    def rms(self) -> be.ndarray:
        """Calculates the root mean square (RMS) of the OPD wavefront.

        Only samples whose intensity is strictly positive contribute.

        Returns:
            be.ndarray: The RMS value, ``sqrt(mean(opd**2))``, as returned
            by the active backend.

        Raises:
            ValueError: If no sample has an intensity greater than zero.
        """
        data = self.get_data(self.fields[0], self.wavelengths[0])
        mask = data.intensity > 0
        if not be.any(mask):
            raise ValueError(
                "No valid rays with non-zero intensity for RMS calculation."
            )
        opd = data.opd[mask]
        return be.sqrt(be.mean(opd**2))

    def _plot_2d(self, ax: Axes, data: dict[str, NDArray]) -> None:
        """Plots the 2D visualization of the OPD wavefront.

        Args:
            ax (Axes): The axes to draw the image on.
            data (dict[str, np.ndarray]): The OPD map data, where keys are
                'x', 'y', 'z' and values are NumPy arrays suitable for
                plotting. Only 'z' is used here.
        """
        # data["z"] is already a NumPy array, so np.flipud is safe; the
        # rows are flipped before display.
        im = ax.imshow(np.flipud(data["z"]), extent=(-1, 1, -1, 1))

        ax.set_xlabel("Pupil X")
        ax.set_ylabel("Pupil Y")
        ax.set_title(f"OPD Map: RMS={self.rms():.3f} waves")

        # Attach the colorbar through the figure that owns `ax`.
        # plt.colorbar() goes through pyplot's *current* figure, which is
        # not necessarily the one being drawn when a caller-supplied figure
        # is used.
        cbar = ax.figure.colorbar(im, ax=ax)
        cbar.ax.get_yaxis().labelpad = 15
        cbar.ax.set_ylabel("OPD (waves)", rotation=270)

    def _plot_3d(self, fig: Figure, ax: Axes3D, data: dict[str, NDArray]) -> None:
        """Plots the 3D visualization of the OPD wavefront.

        Args:
            fig (Figure): The figure that owns ``ax``; receives the
                colorbar and is laid out with ``tight_layout``.
            ax (Axes3D): The 3D axes to draw the surface on.
            data (dict[str, np.ndarray]): The OPD map data, where keys are
                'x', 'y', 'z' and values are NumPy arrays suitable for
                plotting.
        """
        surf = ax.plot_surface(
            data["x"],
            data["y"],
            data["z"],
            rstride=1,
            cstride=1,
            cmap="viridis",
            linewidth=0,
            antialiased=False,
        )

        ax.set_xlabel("Pupil X")
        ax.set_ylabel("Pupil Y")
        ax.set_zlabel("OPD (waves)")
        ax.set_title(f"OPD Map: RMS={self.rms():.3f} waves")

        fig.colorbar(surf, ax=ax, shrink=0.5, aspect=10, pad=0.15)
        fig.tight_layout()

    def generate_opd_map(self, num_points: int = 256) -> OPDData:
        """Generates the OPD map data.

        Samples with zero (or negative) intensity are discarded, the
        remaining OPD values are multiplied by their intensity, and the
        result is interpolated with cubic ``griddata`` onto a square
        ``num_points`` x ``num_points`` grid spanning [-1, 1] on both axes.

        Args:
            num_points (int, optional): The number of points for
                interpolation along each axis of the grid. Defaults to 256.

        Returns:
            dict[str, np.ndarray]: A dictionary containing the interpolated
            OPD map, with keys 'x', 'y', and 'z'. The values are NumPy
            arrays. Per scipy's default ``fill_value``, grid nodes outside
            the convex hull of the samples are NaN.
        """
        wavefront_data = self.get_data(self.fields[0], self.wavelengths[0])

        # Interpolation and plotting are NumPy/SciPy based, so leave the
        # backend representation here.
        x = be.to_numpy(self.distribution.x)
        y = be.to_numpy(self.distribution.y)
        z = be.to_numpy(wavefront_data.opd)
        intensity = be.to_numpy(wavefront_data.intensity)

        # Ignore zero intensity points
        mask = intensity > 0
        x = x[mask]
        y = y[mask]
        z = z[mask]
        intensity = intensity[mask]

        x_interp, y_interp = np.meshgrid(
            np.linspace(-1, 1, num_points),
            np.linspace(-1, 1, num_points),
        )

        points = np.column_stack((x.flatten(), y.flatten()))
        # The interpolated quantity is the OPD scaled by the ray intensity.
        values = z.flatten() * intensity.flatten()

        z_interp = griddata(points, values, (x_interp, y_interp), method="cubic")

        return OPDData(x=x_interp, y=y_interp, z=z_interp)