Source code for mtf.geometric

"""Geometric Modulation Transfer Function (MTF) Module.

This module provides the GeometricMTF class for computing the MTF
of an optical system based on spot diagram data.

Kramer Harrison, 2025
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt

import optiland.backend as be
from optiland.analysis import SpotDiagram
from optiland.utils import resolve_wavelength

if TYPE_CHECKING:
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure

    from optiland._types import BEArray, DistributionType, ScalarOrArray
    from optiland.optic import Optic


class GeometricMTF(SpotDiagram):
    """Smith, Modern Optical Engineering 3rd edition, Section 11.9

    This class represents the Geometric MTF (Modulation Transfer Function) of
    an optical system. It inherits from the SpotDiagram class.

    Args:
        optic (Optic): The optical system for which to calculate the MTF.
        fields (str or list, optional): The field points at which to calculate
            the MTF. Defaults to 'all'.
        wavelength (str or float, optional): The wavelength at which to
            calculate the MTF. Defaults to 'primary'.
        num_rays (int, optional): The number of rays to trace for each field
            point. Defaults to 100.
        distribution (str, optional): The distribution of rays within each
            field point. Defaults to 'uniform'.
        num_points (int, optional): The number of points to sample in the MTF
            curve. Defaults to 256.
        max_freq (str or float, optional): The maximum frequency to consider
            in the MTF curve. The string 'cutoff' selects the diffraction
            cutoff frequency. Defaults to 'cutoff'.
        scale (bool, optional): Whether to scale the MTF curve using the
            diffraction-limited curve. Defaults to True.

    Attributes:
        num_points (int): The number of points to sample in the MTF curve.
        scale (bool): Whether to scale the MTF curve.
        cutoff_freq (float): The diffraction cutoff frequency,
            1 / (wavelength * F/#), in cycles/mm.
        max_freq (float): The maximum frequency to consider in the MTF curve.
        freq (be.ndarray): The frequency values for the MTF curve.
        mtf (list): The MTF data for each field point. Each element is a list
            containing tangential and sagittal MTF data (`be.ndarray`) for a
            field.
        diff_limited_mtf (be.ndarray): The diffraction-limited MTF curve
            sampled at ``freq``. It is computed whether or not ``scale`` is
            set.

    Methods:
        view(fig_to_plot_on=None, figsize=(12, 4), add_reference=False):
            Plots the MTF curve.
        _generate_mtf_data(): Generates the MTF data for each field point.
        _compute_field_data(xi, v, scale_factor): Computes the MTF data for a
            given field point.
        _plot_field(ax, mtf_data, field, color): Plots the MTF data for a
            given field point.
    """

    def __init__(
        self,
        optic: Optic,
        fields: str | list = "all",
        wavelength: str | float = "primary",
        num_rays=100,
        distribution: DistributionType = "uniform",
        num_points=256,
        max_freq="cutoff",
        scale=True,
    ):
        self.num_points = num_points
        self.scale = scale

        resolved_wavelength = resolve_wavelength(optic, wavelength)

        # wavelength must be converted to mm for frequency units cycles/mm
        self.cutoff_freq = 1 / (resolved_wavelength * 1e-3 * optic.paraxial.FNO())

        # Check the type first: comparing an array/tensor to a string would be
        # evaluated elementwise instead of as a simple sentinel test.
        if isinstance(max_freq, str) and max_freq == "cutoff":
            self.max_freq = self.cutoff_freq
        else:
            # If a specific max_freq is provided, use it directly
            self.max_freq = max_freq

        # Only the single resolved wavelength is traced (monochromatic MTF).
        super().__init__(optic, fields, [resolved_wavelength], num_rays, distribution)

        self.freq = be.linspace(0, self.max_freq, num_points)
        self.mtf, self.diff_limited_mtf = self._generate_mtf_data()

    def view(
        self,
        fig_to_plot_on: Figure | None = None,
        figsize: tuple[float, float] = (12, 4),
        add_reference: bool = False,
    ) -> tuple[Figure, Axes]:
        """Plots the MTF curve.

        Args:
            fig_to_plot_on (plt.Figure, optional): The figure to plot on. If
                provided, the existing figure is cleared and reused. Defaults
                to None, which creates a new figure.
            figsize (tuple, optional): The size of the figure used when a new
                figure is created. Defaults to (12, 4).
            add_reference (bool, optional): Whether to add the diffraction
                limit reference curve. Defaults to False.

        Returns:
            tuple: A tuple containing the matplotlib figure and axes objects.
        """
        is_gui_embedding = fig_to_plot_on is not None

        if is_gui_embedding:
            current_fig = fig_to_plot_on
            current_fig.clear()
            ax = current_fig.add_subplot(111)
        else:
            current_fig, ax = plt.subplots(figsize=figsize)

        for k, (data, field) in enumerate(zip(self.mtf, self.fields)):
            self._plot_field(ax, data, field, color=f"C{k}")

        if add_reference:
            # diff_limited_mtf is always an array sampled at self.freq, so
            # this works regardless of the `scale` setting.
            ax.plot(
                be.to_numpy(self.freq),
                be.to_numpy(self.diff_limited_mtf),
                "k--",
                label="Diffraction Limit",
            )

        ax.legend(bbox_to_anchor=(1.05, 0.5), loc="center left")
        ax.set_xlim([0, be.to_numpy(self.max_freq)])
        ax.set_ylim([0, 1])
        ax.set_xlabel("Frequency (cycles/mm)", labelpad=10)
        ax.set_ylabel("Modulation", labelpad=10)
        current_fig.tight_layout()
        ax.grid(alpha=0.25)

        if is_gui_embedding and hasattr(current_fig, "canvas"):
            current_fig.canvas.draw_idle()
        return current_fig, ax

    def _generate_mtf_data(self):
        """Generates the MTF data for each field point.

        The diffraction-limited curve ``2/pi * (phi - cos(phi) * sin(phi))``
        with ``phi = arccos(freq / cutoff_freq)`` is always computed. When
        ``self.scale`` is True, each geometric MTF is multiplied by it.
        Otherwise it is only returned as a reference curve.

        Returns:
            tuple: ``(mtf, diff_limited_mtf)``. ``mtf`` is a list with one
            ``[tangential, sagittal]`` pair of arrays per field.
            ``diff_limited_mtf`` is the diffraction-limited curve sampled at
            ``self.freq``.
        """
        # Clipping maps frequencies at or beyond the cutoff to phi = 0, which
        # gives a diffraction-limited MTF of exactly zero there.
        ratio = be.clip(self.freq / self.cutoff_freq, 0.0, 1.0)
        phi = be.arccos(ratio)
        diff_limited_mtf = 2 / be.pi * (phi - be.cos(phi) * be.sin(phi))

        scale_factor = diff_limited_mtf if self.scale else 1

        mtf = []
        # TODO: add option for polychromatic MTF
        for field_data in self.data:
            # Only one wavelength is traced (see __init__), hence index 0.
            spot = field_data[0]
            mtf.append(
                [
                    # Tangential from the y spread, sagittal from the x spread.
                    self._compute_field_data(spot.y, self.freq, scale_factor),
                    self._compute_field_data(spot.x, self.freq, scale_factor),
                ],
            )
        return mtf, diff_limited_mtf

    def _compute_field_data(
        self, xi: BEArray, v: BEArray, scale_factor: ScalarOrArray
    ) -> BEArray:
        """Computes the MTF data for a given field point.

        The ray coordinates are binned into a histogram, which is a geometric
        estimate of the line spread function. The modulus of its normalized
        Fourier transform is then evaluated at every frequency in ``v``.

        Args:
            xi (be.ndarray): The coordinate values (x or y) of the field point.
            v (be.ndarray): The frequency values for the MTF curve.
            scale_factor (float or be.ndarray): The scale factor for the MTF
                curve (a scalar or an array broadcastable against ``v``).

        Returns:
            be.ndarray: The MTF data for the field point.
        """
        A, edges = be.histogram(xi, bins=self.num_points + 1)
        x = (edges[1:] + edges[:-1]) / 2  # bin centres

        # The bins have equal width, so the bin width cancels between numerator
        # and denominator. Normalizing by the total count alone is equivalent,
        # and avoids indexing x[1] when there is only one bin.
        total = be.sum(A)

        mtf = be.copy(be.zeros_like(v))  # copy required to maintain gradient
        for k in range(len(v)):
            phase = 2 * be.pi * v[k] * x
            Ac = be.sum(A * be.cos(phase)) / total
            As = be.sum(A * be.sin(phase)) / total
            mtf[k] = be.sqrt(Ac**2 + As**2)
        return mtf * scale_factor

    def _plot_field(self, ax: Axes, mtf_data: list[BEArray], field, color: str):
        """Plots the MTF data for a given field point.

        Args:
            ax (matplotlib.axes.Axes): The matplotlib axes object.
            mtf_data (list[be.ndarray]): The MTF data for the field point,
                containing tangential and sagittal MTF arrays (in that order).
            field: The field being plotted. It must expose ``coord``, whose
                first two entries are the normalized field coordinates
                (Hx, Hy).
            color (str): The color of the plotted lines.
        """
        freq = be.to_numpy(self.freq)
        prefix = f"Hx: {field.coord[0]:.1f}, Hy: {field.coord[1]:.1f}"
        for data, name, linestyle in zip(
            mtf_data, ("Tangential", "Sagittal"), ("-", "--")
        ):
            ax.plot(
                freq,
                be.to_numpy(data),
                label=f"{prefix}, {name}",
                color=color,
                linestyle=linestyle,
            )