Source code for analysis.mtf_vs_field

"""MTF versus Field Analysis

This module enables the calculation of the Modulation Transfer Function (MTF)
versus field coordinate of an optical system.

Kramer Harrison, 2026
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt

import optiland.backend as be
from optiland.analysis.base import BaseAnalysis
from optiland.mtf.sampled import SampledMTF

if TYPE_CHECKING:
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure

    from optiland.optic import Optic


[docs] class MTFvsField(BaseAnalysis): """MTF versus Field Coordinate. This class is used to analyze the Modulation Transfer Function (MTF) versus the field coordinate of an optical system for specified spatial frequencies. Args: optic (Optic): the optical system. frequencies (list[float]): the spatial frequencies (in cycles/mm) to analyze. num_fields (int): the number of fields in the Y direction. Default is 32. wavelengths (str or list): the wavelengths to be analyzed. Default is 'all'. num_rays (int): the number of rays across the pupil in 1D for the SampledMTF calculation. Default is 128. override_limits (bool): If True, bypasses the limit on the number of frequencies and wavelengths to prevent cluttered plots. Default is False. """ MAX_FREQUENCIES = 5 MAX_WAVELENGTHS = 3 def __init__( self, optic: Optic, frequencies: list[float], num_fields: int = 32, wavelengths: str | list[float] = "all", num_rays: int = 128, override_limits: bool = False, ): self.frequencies = frequencies self.num_fields = num_fields self.num_rays = num_rays self._check_limits(override_limits, wavelengths, optic) # Base Analysis will set self.wavelengths super().__init__(optic, wavelengths) def _check_limits(self, override_limits: bool, wavelengths, optic): """Check to ensure inputs won't produce an overly cluttered plot.""" if override_limits: return if len(self.frequencies) > self.MAX_FREQUENCIES: raise ValueError( f"Number of frequencies ({len(self.frequencies)}) exceeds the " f"recommended limit of {self.MAX_FREQUENCIES} for clean plots. " "Set override_limits=True to bypass this check." ) from optiland.utils import resolve_wavelengths as _resolve_wls resolved_wls = _resolve_wls(optic, wavelengths) num_wl = len(resolved_wls) if num_wl > self.MAX_WAVELENGTHS: raise ValueError( f"Number of wavelengths ({num_wl}) exceeds the recommended " f"limit of {self.MAX_WAVELENGTHS} for clean plots. " "Set override_limits=True to bypass this check." ) def _generate_data(self): """Generate the MTF data across fields, wavelengths, and frequencies.""" fields = [(0.0, float(Hy)) for Hy in be.linspace(0.0, 1.0, self.num_fields)] self._field_coords = be.array(fields) # Pre-build list of frequencies to calculate at once freqs_to_calc = [] for freq in self.frequencies: freqs_to_calc.append((freq, 0.0)) freqs_to_calc.append((0.0, freq)) results = [] for wp in self.wavelengths: wl = wp.value wl_results = [{"tangential": [], "sagittal": []} for _ in self.frequencies] for field in fields: sampled_mtf = SampledMTF( optic=self.optic, field=field, wavelength=wl, num_rays=self.num_rays, distribution="uniform", zernike_terms=37, zernike_type="fringe", ) mtfs = sampled_mtf.calculate_mtf(freqs_to_calc) for i_freq in range(len(self.frequencies)): wl_results[i_freq]["tangential"].append(mtfs[2 * i_freq]) wl_results[i_freq]["sagittal"].append(mtfs[2 * i_freq + 1]) for i_freq in range(len(self.frequencies)): wl_results[i_freq]["tangential"] = be.array( wl_results[i_freq]["tangential"] ) wl_results[i_freq]["sagittal"] = be.array( wl_results[i_freq]["sagittal"] ) results.append(wl_results) return results
[docs] def view( self, fig_to_plot_on: Figure | None = None, figsize: tuple[float, float] = (8, 5), ) -> tuple[Figure, Axes]: """ Plots the MTF versus the field coordinate for each frequency and wavelength. Args: fig_to_plot_on (Figure, optional): An existing matplotlib Figure to plot on. If provided, the plot will be embedded in this figure. If None (default), a new figure will be created. figsize (tuple[float, float], optional): Size of the figure to create if `fig_to_plot_on` is None. Defaults to (8, 5). Returns: tuple[Figure, Axes]: The matplotlib Figure and Axes objects containing the plot. """ is_gui_embedding = fig_to_plot_on is not None if is_gui_embedding: current_fig = fig_to_plot_on current_fig.clear() ax = current_fig.add_subplot(111) else: current_fig, ax = plt.subplots(figsize=figsize) max_field = float(self.optic.fields.max_field) y_coords_normalized = be.to_numpy(self._field_coords[:, 1]) x_plot = y_coords_normalized * max_field # Determine X-axis label field_def = self.optic.fields.field_definition x_label = "Field Coordinate" if field_def is not None: field_name = field_def.__class__.__name__ if "Angle" in field_name: x_label = "Angle (deg)" elif "Height" in field_name: x_label = "Height (mm)" else: # Fallback if no specific type is set but fields exist x_label = "Field Coordinate" axes_color_cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"] for i_wl, wp in enumerate(self.wavelengths): for i_freq, freq in enumerate(self.frequencies): color_idx = (i_wl * len(self.frequencies) + i_freq) % len( axes_color_cycle ) color = axes_color_cycle[color_idx] tan_data = be.to_numpy(self.data[i_wl][i_freq]["tangential"]) sag_data = be.to_numpy(self.data[i_wl][i_freq]["sagittal"]) label_prefix = f"{freq} cyc/mm" if len(self.wavelengths) > 1: label_prefix += f", {wp.value:.4f} µm" ax.plot( x_plot, tan_data, linestyle="-", color=color, label=f"{label_prefix} (Tan)", ) ax.plot( x_plot, sag_data, linestyle="--", color=color, label=f"{label_prefix} (Sag)", ) ax.set_xlabel(x_label) ax.set_ylabel("Modulus of the OTF") ax.legend(bbox_to_anchor=(1.05, 0.5), loc="center left") if max_field > 0: ax.set_xlim(0, max_field) ax.set_ylim(0, 1.05) ax.grid(True, linestyle=":", alpha=0.5) current_fig.tight_layout() if is_gui_embedding and hasattr(current_fig, "canvas"): current_fig.canvas.draw_idle() return current_fig, ax