Source code for mtf.fft

"""Fast Fourier Transform Modulation Transfer Function (FFTMTF) Module.

This module provides the ScalarFFTMTF, VectorialFFTMTF (via factory), and
FFTMTF factory class for computing the MTF of an optical system using FFT
techniques.

Kramer Harrison, 2025
"""

from __future__ import annotations

import optiland.backend as be
from optiland.psf.fft import ScalarFFTPSF, calculate_grid_size
from optiland.utils import get_working_FNO

from .base import BaseMTF


class ScalarFFTMTF(BaseMTF):
    """Modulation Transfer Function computed with the scalar FFT method.

    The MTF is derived from scalar FFT point spread functions and can be
    plotted per field. This class targets unpolarized optical systems; the
    `FFTMTF` factory picks between the scalar and vectorial implementations
    based on the optic's polarization state.

    Args:
        optic (Optic): The optic for which to calculate the MTF.
        fields (str or list, optional): The field coordinates for which to
            calculate the MTF. Defaults to 'all'.
        wavelength (str or float, optional): The wavelength of light to use
            for the MTF calculation. Defaults to 'primary'.
        num_rays (int, optional): The number of rays to use for the PSF
            calculation, which is then used for MTF. Defaults to 128.
        grid_size (int or None, optional): The size of the grid used for the
            PSF/MTF calculation. If `None`, both the ray count and the grid
            size are taken from `calculate_grid_size(num_rays)`. Defaults to
            `None`.
        max_freq (str or float, optional): The maximum frequency for the MTF
            calculation. The string 'cutoff' derives it from the wavelength
            and the F-number returned by `_get_fno()`. Defaults to 'cutoff'.
        strategy (str): The calculation strategy to use. Supported options
            are "chief_ray", "centroid_sphere", and "best_fit_sphere".
            Defaults to "chief_ray".
        remove_tilt (bool): If True, removes tilt and piston from the OPD
            data. Defaults to False.
        **kwargs: Additional keyword arguments passed to the strategy.

    Attributes:
        num_rays (int): The number of rays used for the MTF calculation.
        grid_size (int): The size of the grid used for the MTF calculation.
        max_freq (float): The maximum frequency for the MTF calculation.
        FNO (list): One `get_working_FNO` result per resolved field.
        psf (list): List of PSF data for each field.
        mtf (list): List of MTF data ([tangential, sagittal]) for each field.
        freq_tang (list): Per-field tangential frequency axes, each holding
            `grid_size // 2` points.
        freq_sag (list): Per-field sagittal frequency axes, each holding
            `grid_size // 2` points.
        freq (list): Backward-compatible alias of `freq_tang`.
    """

    def __init__(
        self,
        optic,
        fields: str | list = "all",
        wavelength: str | float = "primary",
        num_rays=128,
        grid_size=None,
        max_freq="cutoff",
        strategy="chief_ray",
        remove_tilt=False,
        **kwargs,
    ):
        # NOTE(review): the sampling attributes are set ahead of
        # super().__init__ -- BaseMTF presumably triggers the PSF computation,
        # which reads them; confirm against BaseMTF.
        self.num_rays, self.grid_size = (
            calculate_grid_size(num_rays)
            if grid_size is None
            else (num_rays, grid_size)
        )

        super().__init__(optic, fields, wavelength, strategy, remove_tilt, **kwargs)

        self.FNO = []
        for field in self.resolved_fields:
            self.FNO.append(
                get_working_FNO(self.optic, field, self.resolved_wavelength)
            )

        if max_freq == "cutoff":
            # Cutoff frequency 1 / (lambda * F/#); the 1e-3 factor rescales the
            # wavelength (presumably micrometres -> millimetres; confirm).
            cutoff_fno = self._get_fno()
            max_freq = 1 / (self.resolved_wavelength * 1e-3 * cutoff_fno)
        self.max_freq = max_freq

        field_ids = range(len(self.resolved_fields))
        self.freq_tang = [
            be.arange(self.grid_size // 2) * self._get_mtf_units_tang(idx)
            for idx in field_ids
        ]
        self.freq_sag = [
            be.arange(self.grid_size // 2) * self._get_mtf_units_sag(idx)
            for idx in field_ids
        ]
        # Legacy name kept for callers that expect a single `freq` attribute;
        # it refers to the tangential axes.
        self.freq = self.freq_tang

    def _calculate_psf(self):
        """Compute one scalar FFT PSF per resolved field and store them.

        The result is written to `self.psf`. The scalar implementation
        (`ScalarFFTPSF`) is used explicitly, together with the resolved
        wavelength, sampling, and strategy settings held on this instance.
        """
        psf_per_field = []
        for field in self.resolved_fields:
            scalar_psf = ScalarFFTPSF(
                self.optic,
                field,
                self.resolved_wavelength,
                self.num_rays,
                self.grid_size,
                self.strategy,
                self.remove_tilt,
                **self.strategy_kwargs,
            )
            psf_per_field.append(scalar_psf.psf)
        self.psf = psf_per_field

    def _plot_field_mtf(self, ax, field_index, mtf_field_data, color):
        """Draw the tangential and sagittal MTF curves of one field.

        Args:
            ax (matplotlib.axes.Axes): The matplotlib axes object.
            field_index (int): The index of the current field in
                self.resolved_fields.
            mtf_field_data (list): A list containing tangential and sagittal
                MTF data (be.ndarray) for the field.
            color (str): The color to use for plotting this field.
        """
        field = self.resolved_fields[field_index]
        field_tag = f"Hx: {field[0]:.1f}, Hy: {field[1]:.1f}, "

        # Tangential first (solid), then sagittal (dashed); each curve is drawn
        # against its own per-field frequency axis.
        curves = (
            (self.freq_tang, mtf_field_data[0], "Tangential", "-"),
            (self.freq_sag, mtf_field_data[1], "Sagittal", "--"),
        )
        for freq_axes, values, orientation, dash in curves:
            ax.plot(
                be.to_numpy(freq_axes[field_index]),
                be.to_numpy(values),
                label=field_tag + orientation,
                color=color,
                linestyle=dash,
            )

    def _generate_mtf_data(self):
        """Build normalized tangential and sagittal MTF slices per field.

        Each PSF is transformed with a 2D FFT and shifted with ``fftshift``,
        and the magnitude is taken. The value at index
        ``(grid_size // 2, grid_size // 2)`` is used as the DC reference: the
        slices start there, run outward, and are divided by it so that
        MTF(0) = 1.

        Returns:
            list: A list of MTF data for each field. Each element is a list
            ``[tangential_mtf, sagittal_mtf]`` where each is a 1-D array of
            length ``grid_size // 2`` with values in ``[0, 1]``.
        """
        center = self.grid_size // 2
        mtf_per_field = []
        for psf in self.psf:
            spectrum = be.abs(be.fft.fftshift(be.fft.fft2(psf)))

            # Slices from the DC bin outward, limited to `center` samples:
            # along the first axis (tangential) and the second (sagittal).
            slices = (
                spectrum[center:, center][:center],
                spectrum[center, center:][:center],
            )

            dc_value = spectrum[center, center]
            if dc_value == 0:
                # No DC reference to divide by: report an all-zero MTF.
                scaled = [be.zeros_like(cut) for cut in slices]
            else:
                scaled = [cut / dc_value for cut in slices]

            # Clamp floating-point overshoot back into the physical range.
            mtf_per_field.append([be.clip(cut, 0.0, 1.0) for cut in scaled])
        return mtf_per_field

    def _get_mtf_units_tang(self, k):
        """Tangential frequency step (cycles/mm) with image-plane correction.

        The step obtained from the per-field F-number ``self.FNO[k]`` is
        multiplied by ``FNO_on / FNO_off``, where ``FNO_on`` comes from
        ``_get_fno()``:

            df_tang = df_chief * (FNO_on / FNO_off)

        When both F-numbers are equal the correction is unity.

        Args:
            k (int): Field index.

        Returns:
            float: Tangential frequency step in cycles/mm.
        """
        fno_on = self._get_fno()
        fno_off = self.FNO[k]
        sample_extent = (
            (self.num_rays - 1) * self.resolved_wavelength * 1e-3 * fno_off
        )
        chief_step = 1 / sample_extent
        return chief_step * (fno_on / fno_off)

    def _get_mtf_units_sag(self, k):
        """Sagittal frequency step (cycles/mm).

        The per-field F-number ``self.FNO[k]`` is used directly, without the
        on-axis/off-axis ratio applied to the tangential step.

        Args:
            k (int): Field index.

        Returns:
            float: Sagittal frequency step in cycles/mm.
        """
        fno_off = self.FNO[k]
        sample_extent = (
            (self.num_rays - 1) * self.resolved_wavelength * 1e-3 * fno_off
        )
        return 1 / sample_extent
class FFTMTF:
    """Factory that builds either a vectorial or a scalar FFT MTF object.

    The optic's ``polarization_state`` decides which implementation is
    instantiated: when it is not ``None`` a `VectorialFFTMTF` is returned,
    otherwise a `ScalarFFTMTF`. All arguments are forwarded unchanged, by
    keyword, to the selected class.

    Args:
        optic (Optic): The optical system object.
        fields (str or list, optional): The field coordinates for which to
            calculate the MTF. Defaults to 'all'.
        wavelength (str or float, optional): The wavelength of light to use.
            Defaults to 'primary'.
        num_rays (int, optional): The number of rays to use for the PSF
            calculation. Defaults to 128.
        grid_size (int or None, optional): The FFT grid size. Defaults to
            None.
        max_freq (str or float, optional): The maximum frequency for the MTF
            calculation. Defaults to 'cutoff'.
        strategy (str): The wavefront calculation strategy. Defaults to
            "chief_ray".
        remove_tilt (bool): If True, removes tilt from OPD. Defaults to False.
        **kwargs: Additional keyword arguments passed to the strategy.
    """

    def __new__(
        cls,
        optic,
        fields: str | list = "all",
        wavelength: str | float = "primary",
        num_rays=128,
        grid_size=None,
        max_freq="cutoff",
        strategy="chief_ray",
        remove_tilt=False,
        **kwargs,
    ):
        if optic.polarization_state is None:
            mtf_class = ScalarFFTMTF
        else:
            # Imported lazily so the vectorial module is only loaded when a
            # polarized optic actually requires it.
            from optiland.mtf.vectorial_fft import VectorialFFTMTF

            mtf_class = VectorialFFTMTF

        return mtf_class(
            optic=optic,
            fields=fields,
            wavelength=wavelength,
            num_rays=num_rays,
            grid_size=grid_size,
            max_freq=max_freq,
            strategy=strategy,
            remove_tilt=remove_tilt,
            **kwargs,
        )