"""Through Focus MTF
This module provides a class for performing through-focus MTF
analysis, calculating the MTF at various focal planes for a given
spatial frequency, wavelength, and set of fields.
Kramer Harrison, 2025
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import make_interp_spline
import optiland.backend as be
from optiland.analysis.through_focus import ThroughFocusAnalysis
from optiland.mtf.sampled import SampledMTF
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
[docs]
class ThroughFocusMTF(ThroughFocusAnalysis):
    """
    Performs Modulation Transfer Function (MTF) analysis across a range of focal
    positions.

    This class calculates the MTF at a specified spatial frequency for both
    tangential and sagittal orientations at multiple focal planes around the
    nominal focus of an optical system.

    The results include tangential and sagittal MTF values for each analyzed
    field at each focal step.

    Args:
        optic: The optiland.optic.Optic object to analyze.
        spatial_frequency (float): The single spatial frequency (in cycles/mm)
            at which to calculate MTF. The calculation will be performed
            for both tangential (fx = spatial_frequency, fy = 0) and
            sagittal (fx = 0, fy = spatial_frequency) orientations.
        delta_focus (float, optional): The increment of focal shift in mm.
            Defaults to 0.1.
        num_steps (int, optional): The number of focal planes to analyze
            before and after the nominal focus. Must be an odd integer.
            Defaults to 5.
        fields (list[tuple[float, float]] | str, optional): Fields for
            analysis. If "all", uses all fields from `optic.fields`.
            Defaults to "all".
        wavelength (float | str, optional): The wavelength (in µm) for
            analysis. If "primary", uses the primary wavelength from
            `optic.primary_wavelength`. Defaults to "primary".
        num_rays (int, optional): The number of rays across the pupil in 1D
            for the SampledMTF calculation. Defaults to 128.
    """

    # Allowed range for ``num_steps``. Not referenced inside this class;
    # presumably consumed by the base class or by callers — TODO confirm.
    MAX_STEPS = 15
    MIN_STEPS = 1

    def __init__(
        self,
        optic,
        spatial_frequency: float,
        delta_focus: float = 0.1,
        num_steps: int = 5,
        fields: list[tuple[float, float]] | str = "all",
        wavelength: float | str = "primary",
        num_rays: int = 128,
    ):
        # ``_perform_analysis_at_focus`` reads these attributes, so they are set
        # before ``super().__init__`` in case the base constructor runs the
        # analysis immediately (presumably it does — TODO confirm).
        self.spatial_frequency = spatial_frequency
        self.num_rays = num_rays
        self.wavelength = (
            optic.primary_wavelength if wavelength == "primary" else wavelength
        )
        super().__init__(
            optic,
            delta_focus=delta_focus,
            num_steps=num_steps,
            fields=fields,
            wavelengths=[self.wavelength],
        )

    def _perform_analysis_at_focus(self):
        """
        Performs the MTF analysis at the current focal position for all fields.

        This method is called by the base class for each focal step. It
        calculates the tangential and sagittal MTF values for the specified
        spatial frequency for each field defined in `self.fields`.

        Returns:
            list[dict[str, float]]: A list of dictionaries, where each
            dictionary corresponds to a field and contains the tangential
            and sagittal MTF values, e.g.,
            [{'tangential': 0.5, 'sagittal': 0.45},  # Field 1
             {'tangential': 0.3, 'sagittal': 0.28}]  # Field 2
        """
        # The probe frequencies are the same for every field; build them once.
        freq_tan = (self.spatial_frequency, 0.0)
        freq_sag = (0.0, self.spatial_frequency)

        results_at_this_focus = []
        for fp in self.fields:
            sampled_mtf = SampledMTF(
                optic=self.optic,
                field=fp.coord,
                wavelength=self.wavelength,
                num_rays=self.num_rays,
                distribution="uniform",
                zernike_terms=37,
                zernike_type="fringe",
            )
            mtf_t = sampled_mtf.calculate_mtf([freq_tan])[0]
            mtf_s = sampled_mtf.calculate_mtf([freq_sag])[0]
            results_at_this_focus.append({"tangential": mtf_t, "sagittal": mtf_s})
        return results_at_this_focus

    def _mtf_values_for_field(self, field_index: int, orientation: str) -> np.ndarray:
        """Return one field's MTF values (``orientation`` key) for every focal step."""
        return be.to_numpy(
            be.asarray([step[field_index][orientation] for step in self.results])
        )

    @staticmethod
    def _select_spline_order(defocus: np.ndarray) -> int:
        """Choose the interpolating-spline degree for sorted defocus samples.

        Returns 3 (cubic) for at least 4 points and 1 (linear) for 2-3 points.
        Returns 0 (no spline; plot the raw samples) when there are fewer than
        2 points, or when the samples are not strictly increasing (e.g.
        duplicate focal positions). ``make_interp_spline`` would reject such
        samples.
        """
        num_points = len(defocus)
        if num_points < 2 or np.any(np.diff(defocus) <= 0):
            return 0
        # A spline of degree k needs at least k + 1 points.
        return 3 if num_points >= 4 else 1

    def view(
        self,
        fig_to_plot_on: Figure | None = None,
        figsize: tuple[float, float] = (12, 4),
    ) -> tuple[Figure, Axes]:
        """
        Visualizes the through-focus Modulation Transfer Function (MTF) results for
        each analyzed field.

        This method generates a plot of tangential and sagittal MTF values as a
        function of defocus for each field position. If enough data points are
        available, spline smoothing is applied to the MTF data to produce
        smoother curves. Otherwise, raw data points are plotted. The plot
        displays MTF at the spatial frequency specified during initialization.

        Parameters
        ----------
        fig_to_plot_on : plt.Figure, optional
            An existing matplotlib Figure to plot on. If provided, the figure is
            cleared and the plot is embedded in it.
            If None (default), a new figure will be created.
        figsize : tuple of float, optional
            Size of the figure to create if `fig_to_plot_on` is None.
            Default is (12, 4).

        Returns
        -------
        tuple[Figure, Axes]
            The matplotlib Figure and Axes objects containing the plot.

        Notes
        -----
        - Defocus samples are sorted before plotting, so a negative focus
          increment is handled.
        - Spline smoothing uses cubic splines if at least 4 data points are
          available and linear splines for 2-3 points. Raw data is plotted if
          fewer points are present or if focal positions repeat.
        - The legend displays the field coordinates (Hx, Hy) for each curve.
        - The plot includes grid lines and is formatted for clarity.
        """
        is_gui_embedding = fig_to_plot_on is not None
        if is_gui_embedding:
            current_fig = fig_to_plot_on
            current_fig.clear()
            ax = current_fig.add_subplot(111)
        else:
            current_fig, ax = plt.subplots(figsize=figsize)

        np_positions = be.to_numpy(be.asarray(self.positions))
        np_nominal_focus = be.to_numpy(be.asarray(self.nominal_focus))
        defocus_values_np = np_positions - np_nominal_focus

        # make_interp_spline requires strictly increasing abscissas, so sort the
        # focal samples. This keeps the plot valid for a negative delta_focus.
        order = np.argsort(defocus_values_np, kind="stable")
        defocus_sorted = defocus_values_np[order]

        # Every field shares the same defocus grid, so the spline degree and the
        # dense evaluation grid are computed once.
        k = self._select_spline_order(defocus_sorted)
        defocus_smooth = (
            np.linspace(defocus_sorted[0], defocus_sorted[-1], 256) if k > 0 else None
        )

        for i_field, fp in enumerate(self.fields):
            Hx, Hy = fp.coord
            color = f"C{i_field}"
            mtf_t_values = self._mtf_values_for_field(i_field, "tangential")[order]
            mtf_s_values = self._mtf_values_for_field(i_field, "sagittal")[order]

            if k == 0:
                # Too few (or repeated) focal positions for a spline: plot samples.
                ax.plot(
                    defocus_sorted,
                    np.clip(mtf_t_values, 0, 1),
                    linestyle="-",
                    marker="o",
                    markersize=4,
                    color=color,
                    label=f"Hx: {Hx:.1f}, Hy: {Hy:.1f}, Tangential (raw)",
                )
                ax.plot(
                    defocus_sorted,
                    np.clip(mtf_s_values, 0, 1),
                    linestyle="--",
                    marker="x",
                    markersize=4,
                    color=color,
                    label=f"Hx: {Hx:.1f}, Hy: {Hy:.1f}, Sagittal (raw)",
                )
            else:
                mtf_t_smooth = make_interp_spline(
                    defocus_sorted, mtf_t_values, k=k, check_finite=False
                )(defocus_smooth)
                mtf_s_smooth = make_interp_spline(
                    defocus_sorted, mtf_s_values, k=k, check_finite=False
                )(defocus_smooth)
                # Spline overshoot can leave [0, 1]; clip to the physical range.
                ax.plot(
                    defocus_smooth,
                    np.clip(mtf_t_smooth, 0, 1),
                    linestyle="-",
                    color=color,
                    label=f"Hx: {Hx:.1f}, Hy: {Hy:.1f}, Tangential",
                )
                ax.plot(
                    defocus_smooth,
                    np.clip(mtf_s_smooth, 0, 1),
                    linestyle="--",
                    color=color,
                    label=f"Hx: {Hx:.1f}, Hy: {Hy:.1f}, Sagittal",
                )

        ax.set_title(
            f"Through-Focus MTF at {self.spatial_frequency} "
            f"cycles/mm, λ={self.wavelength:.3f} µm"
        )
        ax.set_xlabel("Defocus (mm)")
        ax.set_ylabel("MTF")
        # Identical limits (a single focal plane) make matplotlib warn about a
        # singular transform; in that case leave the x-range to autoscaling.
        if defocus_sorted.size and defocus_sorted[-1] > defocus_sorted[0]:
            ax.set_xlim([defocus_sorted[0], defocus_sorted[-1]])
        ax.set_ylim([0, 1.05])
        ax.legend(bbox_to_anchor=(1.05, 0.5), loc="center left")
        ax.grid(True, linestyle=":", alpha=0.5)
        current_fig.tight_layout()
        if is_gui_embedding and hasattr(current_fig, "canvas"):
            current_fig.canvas.draw_idle()
        return current_fig, ax