"""Base Modulation Transfer Function (FFTMTF) Module.
This module contains the abstract base class for MTF calculations
based on the PSF. This includes, e.g., the FFT-based method
and the Huygen-Fresnel-based method.
Kramer Harrison, 2025
"""
from __future__ import annotations
import abc
from typing import TYPE_CHECKING
import matplotlib.pyplot as plt
import optiland.backend as be
from optiland.utils import get_working_FNO, resolve_fields, resolve_wavelength
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
[docs]
class BaseMTF(abc.ABC):
"""Base class for MTF computations based on a PSF calculation.
Attributes:
optic: The optical system.
fields: Original field point specification (e.g., "all" or list).
wavelength: Original wavelength specification (e.g., "primary" or value).
resolved_fields: List of actual field coordinates (Hx, Hy) to be used.
resolved_wavelength: Actual wavelength value (in µm) to be used.
"""
def __init__(
self,
optic,
fields: str | list,
wavelength: str | float,
strategy="chief_ray",
remove_tilt=False,
**kwargs,
):
"""Initializes BaseMTF and resolves field/wavelength values.
Args:
optic: The optical system.
fields: The field points for MTF calculation. Can be "all" to
use all fields from the optic, or a list of field coordinates.
wavelength: The wavelength for MTF calculation. Can be "primary"
to use the optic's primary wavelength, or a specific
wavelength value (typically in µm).
strategy (str): The calculation strategy to use. Supported options are
"chief_ray", "centroid_sphere", and "best_fit_sphere".
Defaults to "chief_ray".
remove_tilt (bool): If True, removes tilt and piston from the OPD data.
Defaults to False.
**kwargs: Additional keyword arguments passed to the strategy.
"""
self.optic = optic
self.fields = fields
self.wavelength = wavelength
self.strategy = strategy
self.remove_tilt = remove_tilt
self.strategy_kwargs = kwargs
resolved = resolve_fields(optic, fields)
# extract plain (x, y) coords; BaseMTF uses single wavelength (no weighting)
self.resolved_fields = [fp.coord for fp in resolved]
self.resolved_wavelength = resolve_wavelength(optic, wavelength)
self._calculate_psf()
self.mtf = self._generate_mtf_data()
@abc.abstractmethod
def _generate_mtf_data(self):
"""Generates and returns MTF data."""
pass
@abc.abstractmethod
def _plot_field_mtf(self, ax, field_index, mtf_field_data, color):
"""Plots the MTF data for a single field on the given axes.
Args:
ax (matplotlib.axes.Axes): The matplotlib axes object.
field_index (int): The index of the current field.
mtf_field_data (any): The MTF data for this specific field.
Subclasses will define its structure.
color (str): The color to use for plotting this field.
"""
pass
@abc.abstractmethod
def _calculate_psf(self):
"""Calculates and potentially stores the Point Spread Function."""
pass
[docs]
def view(
self,
fig_to_plot_on: Figure | None = None,
figsize: tuple[float, float] = (12, 4),
add_reference: bool = False,
) -> tuple[Figure, Axes]:
"""Visualizes the Modulation Transfer Function (MTF).
This method sets up the plot and iterates through field data,
calling `_plot_field_mtf` for each field's specific plotting.
Subclasses must ensure `self.mtf`, `self.freq`, and `self.max_freq`
are populated before calling this method. `self.resolved_fields`
(from __init__) is also used.
Args:
fig_to_plot_on (plt.Figure, optional): The figure to plot on.
If None, a new figure will be created. Defaults to None.
figsize (tuple, optional): The size of the figure.
Defaults to (12, 4).
add_reference (bool, optional): Whether to overlay the theoretical
diffraction-limited MTF curve for a clear circular aperture.
The reference is computed using the *on-axis* working F/# and
the resolved wavelength. Defaults to False.
Returns:
tuple: A tuple containing the figure and axes objects.
"""
is_gui_embedding = fig_to_plot_on is not None
if is_gui_embedding:
current_fig = fig_to_plot_on
current_fig.clear()
ax = current_fig.add_subplot(111)
else:
current_fig, ax = plt.subplots(figsize=figsize)
for k, field_mtf_item in enumerate(self.mtf):
self._plot_field_mtf(ax, k, field_mtf_item, color=f"C{k}")
if add_reference:
# The reference curve shows the theoretical diffraction-limited OTF for
# a clear circular aperture evaluated at the *on-axis* working F/# and
# the resolved wavelength. For off-axis fields the working F/# may
# differ; the on-axis reference provides a consistent, field-independent
# benchmark.
#
# Formula (incoherent OTF for a circular aperture):
# MTF(u) = (2/π)(arccos(u) − u √(1 − u²)) for u ∈ [0, 1]
# where u = f / f_c and f_c = 1 / (λ · F/#_on-axis).
on_axis_fno = self._get_fno()
cutoff_freq = 1 / (self.resolved_wavelength * 1e-3 * on_axis_fno)
ref_freq = be.linspace(0, cutoff_freq, 256)
ratio = be.clip(ref_freq / cutoff_freq, 0.0, 1.0)
phi = be.arccos(ratio)
diff_limited_mtf = (2 / be.pi) * (phi - be.cos(phi) * be.sin(phi))
ax.plot(
be.to_numpy(ref_freq),
be.to_numpy(diff_limited_mtf),
"k--",
label="Diffraction Limit (on-axis)",
)
ax.legend(bbox_to_anchor=(1.05, 0.5), loc="center left")
ax.set_xlim([0, be.to_numpy(self.max_freq)])
ax.set_ylim([0, 1])
ax.set_xlabel("Frequency (cycles/mm)", labelpad=10)
ax.set_ylabel("Modulation", labelpad=10)
current_fig.tight_layout()
ax.grid(alpha=0.25)
if is_gui_embedding and hasattr(current_fig, "canvas"):
current_fig.canvas.draw_idle()
return current_fig, ax
def _get_fno(self):
"""Calculates the working F-number of the optical system for the
single defined field point and given wavelength.
Returns:
float: The working F-number.
"""
return get_working_FNO(
optic=self.optic,
field=(0, 0), # always calculate on-axis F/#
wavelength=self.resolved_wavelength,
)