Source code for optiland.analysis.field_curvature

"""Field Curvature Analysis

This module provides a field curvature analysis for optical systems.

Kramer Harrison, 2024
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np

import optiland.backend as be

from .base import BaseAnalysis

if TYPE_CHECKING:
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure


class FieldCurvature(BaseAnalysis):
    """Field curvature analysis of an optic.

    For every analysed wavelength, pairs of closely spaced (parabasal) rays
    are traced at ``num_points`` normalized field heights from 0 to 1. The
    axial offset at which each pair crosses is recorded separately for the
    tangential (y-z) and sagittal (x-z) planes.

    Args:
        optic (Optic): The optic object to analyze.
        wavelengths (str or list, optional): The wavelengths to analyze.
            Defaults to 'all'.
        num_points (int, optional): Number of field samples. Defaults to 128.

    Attributes:
        optic (Optic): The optic object being analyzed.
        wavelengths (list): The wavelengths being analyzed; each entry is
            read through its ``value`` attribute.
        num_points (int): Number of field samples.
        data (list): One ``[tangential, sagittal]`` pair per wavelength, as
            built by ``_generate_data`` (presumably stored by
            ``BaseAnalysis`` -- confirm against the base class).
    """

    def __init__(self, optic, wavelengths="all", num_points=128):
        # Set before delegating: the base initializer presumably triggers
        # _generate_data(), which reads num_points -- TODO confirm.
        self.num_points = num_points
        super().__init__(optic, wavelengths)

    def view(
        self,
        fig_to_plot_on: Figure | None = None,
        figsize: tuple[float, float] = (8, 5.5),
        *,
        show: bool = True,
    ) -> tuple[Figure, Axes]:
        """Plot tangential and sagittal field curvature for each wavelength.

        Args:
            fig_to_plot_on (plt.Figure, optional): Existing figure to draw
                into (it is cleared first). If None, a new figure is created.
                Defaults to None.
            figsize (tuple[float, float], optional): Size of a newly created
                figure. Ignored when ``fig_to_plot_on`` is given. Defaults to
                (8, 5.5).
            show (bool): If True (default), calls ``plt.show()`` -- only when
                a new figure was created. Set False for headless use.

        Returns:
            tuple: The figure and the axes that were drawn on.
        """
        embedded = fig_to_plot_on is not None
        if embedded:
            figure = fig_to_plot_on
            figure.clear()
            axes = figure.add_subplot(111)
        else:
            figure, axes = plt.subplots(figsize=figsize)

        # Vertical axis: field samples from 0 up to the maximum field.
        field_axis = be.to_numpy(
            be.linspace(0, self.optic.fields.max_field, self.num_points)
        )

        for idx, wp in enumerate(self.wavelengths):
            curves = self.data[idx]
            # (index into the data pair, matplotlib format string, label);
            # solid line for tangential, dashed for sagittal, same colour.
            series = (
                (0, f"C{idx}", f"{wp.value:.4f} µm, Tangential"),
                (1, f"C{idx}--", f"{wp.value:.4f} µm, Sagittal"),
            )
            for component, fmt, legend_text in series:
                axes.plot(
                    be.to_numpy(curves[component]),
                    field_axis,
                    fmt,
                    zorder=10,
                    label=legend_text,
                )

        axes.set_xlabel("Image Plane Delta (mm)")
        axes.set_ylabel("Field")
        axes.set_ylim([0, self.optic.fields.max_field])

        # Make the x-range symmetric about zero so the reference line is
        # centred; skipped when the range is essentially degenerate.
        x_low, x_high = axes.get_xlim()
        half_span = max(abs(x_low), abs(x_high))
        if half_span > 1e-9:
            axes.set_xlim([-half_span, half_span])

        axes.set_title("Field Curvature")
        axes.axvline(x=0, color="k", linewidth=0.5)
        axes.legend(bbox_to_anchor=(1.05, 0.5), loc="center left")
        axes.grid(True)
        figure.tight_layout()

        if embedded and hasattr(figure, "canvas"):
            # Caller owns the figure: just request a redraw.
            figure.canvas.draw_idle()
        if show and not embedded:
            plt.show()

        return figure, axes

    def _generate_data(self):
        """Compute the field curvature curves for every wavelength.

        Returns:
            list: One ``[tangential, sagittal]`` pair per wavelength, each
                element being the array returned by the corresponding
                ``_intersection_parabasal_*`` method.
        """
        return [
            [
                self._intersection_parabasal_tangential(wp.value),
                self._intersection_parabasal_sagittal(wp.value),
            ]
            for wp in self.wavelengths
        ]

    def _intersection_parabasal_tangential(self, wavelength, delta=1e-5):
        """Calculate the intersection of parabasal rays in tangential plane.

        Args:
            wavelength (float): The wavelength of the light.
            delta (float, optional): Offset in normalized pupil y coordinates
                applied as -delta/+delta to each pair of parabasal rays.
                Defaults to 1e-5.

        Returns:
            Backend array with one intersection offset per field sample.
        """
        return self._parabasal_focus_shift(wavelength, delta, sagittal=False)

    def _intersection_parabasal_sagittal(self, wavelength, delta=1e-5):
        """Calculate the intersection of parabasal rays in sagittal plane.

        Args:
            wavelength (float): The wavelength of the light.
            delta (float, optional): Offset in normalized pupil x coordinates
                applied as -delta/+delta to each pair of parabasal rays.
                Defaults to 1e-5.

        Returns:
            Backend array with one intersection offset per field sample.
        """
        return self._parabasal_focus_shift(wavelength, delta, sagittal=True)

    def _parabasal_focus_shift(self, wavelength, delta, sagittal):
        """Trace parabasal ray pairs and locate where each pair crosses.

        Rays are laid out as interleaved pairs: even indices carry the
        -delta pupil offset, odd indices the +delta offset, both at the same
        field height.

        Args:
            wavelength (float): The wavelength of the light.
            delta (float): Normalized pupil offset of each ray in a pair.
            sagittal (bool): True to offset rays in pupil x and intersect in
                the x-z plane; False to offset in pupil y and intersect in
                the y-z plane.

        Returns:
            Backend array: for each field sample, the z-distance from the
                first ray's position on the last surface to the crossing
                point of the pair.
        """
        n_rays = 2 * self.num_points

        field_x = be.zeros(n_rays)
        field_y = be.repeat(be.linspace(0, 1, self.num_points), 2)
        pupil_offsets = be.tile(be.array([-delta, delta]), self.num_points)
        pupil_zeros = be.zeros(n_rays)
        if sagittal:
            pupil_x, pupil_y = pupil_offsets, pupil_zeros
        else:
            pupil_x, pupil_y = pupil_zeros, pupil_offsets

        # The trace result is read back from optic.surfaces below.
        self.optic.trace_generic(
            field_x, field_y, pupil_x, pupil_y, wavelength=wavelength
        )

        surfaces = self.optic.surfaces
        if sagittal:
            transverse_cosine, transverse_pos = surfaces.L, surfaces.x
        else:
            transverse_cosine, transverse_pos = surfaces.M, surfaces.y
        axial_cosine = surfaces.N
        axial_pos = surfaces.z

        # Last surface only; even columns = first ray, odd = second ray.
        d1 = transverse_cosine[-1, ::2]
        n1 = axial_cosine[-1, ::2]
        d2 = transverse_cosine[-1, 1::2]
        n2 = axial_cosine[-1, 1::2]
        p1 = transverse_pos[-1, ::2]
        z1 = axial_pos[-1, ::2]
        p2 = transverse_pos[-1, 1::2]
        z2 = axial_pos[-1, 1::2]

        # Solve (p1, z1) + t*(d1, n1) == (p2, z2) + s*(d2, n2) for t.
        # Operand order is kept as-is: the rays are nearly identical, so the
        # result is sensitive to floating-point rounding. No guard exists for
        # parallel rays (zero denominator).
        t = (d2 * z1 - d2 * z2 - n2 * p1 + n2 * p2) / (d1 * n2 - d2 * n1)
        return t * n1