"""Thin film analysis class.
This module provides a thin film analysis class for optical response
calculations using the transfer matrix method (TMM).
Corentin Nannini, 2025
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
import optiland.backend as be
from optiland.colorimetry import core as color_core
from optiland.colorimetry.plotting import plot_cie_1931_chromaticity_diagram
if TYPE_CHECKING:
from .stack import ThinFilmStack
import matplotlib.pyplot as plt
# Physical constants
SPEED_OF_LIGHT = 299792458.0 # m/s
PLANCK_CONSTANT = 6.62607015e-34 # J⋅s
ELEMENTARY_CHARGE = 1.602176634e-19 # C
PLANCK_EV = PLANCK_CONSTANT / ELEMENTARY_CHARGE # eV⋅s ≈ 4.135667696e-15
# Type definitions
Pol = Literal["s", "p", "u"]
PolInput = Pol | list[Pol]
PlotType = Literal["R", "T", "A"]
Array: TypeAlias = Any # be.ndarray
WavelengthUnit = Literal[
"um", "nm", "frequency", "energy", "wavenumber", "relative_wavenumber"
]
AngleUnit = Literal["deg", "rad"]
[docs]
class SpectralAnalyzer:
    """Class for analyzing thin film stacks optical response (R/T/A).

    Provides line plots versus wavelength or angle of incidence, 2D maps
    versus both, and colorimetric analysis (XYZ, xyY, sRGB) of the
    reflectance or transmittance spectrum of a stack.

    Attributes:
        stack (ThinFilmStack): The thin film stack to be analyzed.
    """

    def __init__(self, stack: ThinFilmStack) -> None:
        """Initialize the SpectralAnalyzer with a ThinFilmStack.

        Args:
            stack (ThinFilmStack): The thin film stack to be analyzed.
        """
        self.stack = stack

    # ------------------------------------------------------------------
    # Input normalization helpers
    # ------------------------------------------------------------------
    def _normalize_polarizations(self, polarization: PolInput) -> list[Pol]:
        """Convert polarization input to a list of polarizations.

        Args:
            polarization: A single polarization string or a sequence of them.

        Returns:
            A new list of polarizations (the caller's sequence is never
            returned directly, so it cannot be mutated through the result).
        """
        if isinstance(polarization, str):
            return [polarization]
        return list(polarization)

    @staticmethod
    def _normalize_quantities(
        to_plot: PlotType | list[PlotType],
    ) -> list[PlotType]:
        """Convert the ``to_plot`` input to a validated list of quantities.

        Args:
            to_plot: A single quantity string or a sequence of them.

        Returns:
            List of quantities, each one of 'R', 'T' or 'A'.

        Raises:
            ValueError: If any entry is not 'R', 'T' or 'A'.
        """
        quantities = [to_plot] if isinstance(to_plot, str) else list(to_plot)
        for quantity in quantities:
            if quantity not in ("R", "T", "A"):
                raise ValueError("to_plot must be 'R', 'T', 'A' or a list of these")
        return quantities

    def _get_single_polarization(self, polarization: PolInput) -> Pol:
        """Normalize and validate a single polarization selection.

        Raises:
            ValueError: If the input does not contain exactly one polarization.
        """
        polarizations = self._normalize_polarizations(polarization)
        if len(polarizations) != 1:
            raise ValueError("Color analysis requires a single polarization")
        return polarizations[0]

    # ------------------------------------------------------------------
    # Plotting helpers
    # ------------------------------------------------------------------
    def _get_line_style(self, pol_idx: int) -> dict[str, Any]:
        """Get line style for different polarizations.

        Styles cycle (solid, dashed, dotted) when there are more than three
        polarizations.
        """
        styles = [
            {"linestyle": "-", "alpha": 0.8},  # solid for first
            {"linestyle": "--", "alpha": 0.8},  # dashed for second
            {"linestyle": ":", "alpha": 0.8},  # dotted for third
        ]
        return styles[pol_idx % len(styles)]

    @staticmethod
    def _plot_array(values):
        """Convert backend arrays/tensors to matplotlib-safe array-like values.

        Tries ``be.to_numpy`` first. If that fails and the object exposes
        ``detach`` and ``cpu``, it is converted with
        ``detach().cpu().tolist()``; otherwise the input is returned as is.
        """
        try:
            return be.to_numpy(values)
        except Exception:
            # Deliberate best-effort fallback: plotting should not fail just
            # because the backend conversion is unavailable for this object.
            if hasattr(values, "detach") and hasattr(values, "cpu"):
                return values.detach().cpu().tolist()
            return values

    @staticmethod
    def _plot_scalar(value) -> float:
        """Convert backend scalar/tensor to plain python float for plotting."""
        try:
            return float(value)
        except Exception:
            arr = SpectralAnalyzer._plot_array(value)
            if isinstance(arr, list):
                return float(arr[0])
            return float(arr)

    def _finish_line_plot(self, ax: plt.Axes, x_plot: Any, xlabel: str) -> None:
        """Apply the labels, limits, grid and legend shared by the line views."""
        ax.set_xlabel(xlabel)
        ax.set_ylabel("Power fraction")
        ax.set_xlim(self._plot_scalar(min(x_plot)), self._plot_scalar(max(x_plot)))
        ax.set_ylim(0, 1)
        ax.grid(True, alpha=0.3)
        ax.legend()

    @staticmethod
    def _as_axes_grid(axs: Any, nrows: int, ncols: int) -> Any:
        """Normalize user-supplied axes so they can be indexed as ``[row][col]``.

        Rows correspond to quantities and columns to polarizations.

        Args:
            axs: A single Axes, a flat or nested list/tuple of Axes, or an
                array of Axes exposing ``tolist`` (such as the object array
                returned by ``plt.subplots``).
            nrows: Number of quantities to plot.
            ncols: Number of polarizations to plot.

        Returns:
            A nested sequence indexable as ``axs[row][col]``.
        """
        # Arrays of Axes become (nested) lists; a single Axes has no tolist().
        if hasattr(axs, "tolist"):
            axs = axs.tolist()
        if not isinstance(axs, (list, tuple)):
            return [[axs]]
        rows = axs if isinstance(axs, list) else list(axs)
        if rows and isinstance(rows[0], (list, tuple)):
            return rows
        # Flat sequence: one Axes per quantity when there is a single
        # polarization column, otherwise one Axes per polarization.
        if ncols == 1 and nrows > 1:
            return [[ax] for ax in rows]
        return [rows]

    # ------------------------------------------------------------------
    # Unit conversions
    # ------------------------------------------------------------------
    def _convert_to_wavelength_um(
        self, values: float | Array, unit: WavelengthUnit
    ) -> Array:
        """Convert input values to wavelength in micrometers.

        Args:
            values: Input values in the specified unit
            unit: Unit of the input values

        Returns:
            Wavelength values in micrometers (at least 1D)

        Raises:
            ValueError: If relative_wavenumber is requested but no reference
                wavelength is set, or if the unit is unknown
        """
        values = be.atleast_1d(values)
        if unit == "um":
            return values
        if unit == "nm":
            return values / 1000.0
        if unit == "frequency":  # Hz to um
            # λ (m) = c / ν, then convert to um
            return (SPEED_OF_LIGHT / values) * 1e6
        if unit == "energy":  # eV to um
            # E = hf = hc/λ, so λ (m) = (h * c) / E, then convert to um
            return (PLANCK_EV * SPEED_OF_LIGHT / values) * 1e6
        if unit == "wavenumber":  # cm⁻¹ to um
            # k = 1/λ (cm⁻¹), so λ (cm) = 1/k, then convert to um
            return 1e4 / values
        if unit == "relative_wavenumber":
            if self.stack.reference_wl_um is None:
                raise ValueError("reference_wl_um must be set for relative_wavenumber")
            # k_rel = k / k_ref, so k = k_rel / λ_ref and λ = λ_ref / k_rel
            return self.stack.reference_wl_um / values
        raise ValueError(f"Unknown wavelength unit: {unit}")

    def _get_wavelength_axis_label(self, unit: WavelengthUnit) -> str:
        """Get the appropriate axis label for wavelength unit."""
        labels = {
            "um": r"$\lambda$ ($\mu$m)",
            "nm": r"$\lambda$ (nm)",
            "frequency": r"$\nu$ (Hz)",
            "energy": r"$E$ (eV)",
            "wavenumber": r"$k$ (cm$^{-1}$)",
            "relative_wavenumber": r"$k/k_{\mathrm{ref}}$",
        }
        return labels[unit]

    def _convert_wavelength_for_plotting(
        self, wavelength_um: Array, unit: WavelengthUnit
    ) -> Array:
        """Convert wavelength in um to the desired unit for plotting.

        Raises:
            ValueError: If relative_wavenumber is requested but no reference
                wavelength is set, or if the unit is unknown
        """
        if unit == "um":
            return wavelength_um
        if unit == "nm":
            return wavelength_um * 1000.0
        if unit == "frequency":  # um to Hz
            return SPEED_OF_LIGHT / (wavelength_um * 1e-6)
        if unit == "energy":  # um to eV
            return (PLANCK_EV * SPEED_OF_LIGHT) / (wavelength_um * 1e-6)
        if unit == "wavenumber":  # um to cm⁻¹
            return 1e4 / wavelength_um
        if unit == "relative_wavenumber":
            if self.stack.reference_wl_um is None:
                raise ValueError("reference_wl_um must be set for relative_wavenumber")
            return self.stack.reference_wl_um / wavelength_um
        raise ValueError(f"Unknown wavelength unit: {unit}")

    def _convert_angle_to_radians(
        self, angles: float | Array, unit: AngleUnit
    ) -> Array:
        """Convert angles to radians.

        Raises:
            ValueError: If the unit is neither 'rad' nor 'deg'
        """
        angles = be.atleast_1d(angles)
        if unit == "rad":
            return angles
        if unit == "deg":
            return be.deg2rad(angles)
        raise ValueError(f"Unknown angle unit: {unit}")

    # ------------------------------------------------------------------
    # Plot views
    # ------------------------------------------------------------------
    def wavelength_view(
        self,
        wavelength_values: float | Array,
        wavelength_unit: WavelengthUnit = "um",
        aoi: float = 0.0,
        aoi_unit: AngleUnit = "deg",
        polarization: PolInput = "u",
        to_plot: PlotType | list[PlotType] = "R",
        ax: plt.Axes | None = None,
    ) -> tuple[plt.Figure, plt.Axes]:
        """Plot R/T/A vs wavelength (or equivalent units).

        Args:
            wavelength_values: Wavelength values in the specified unit
            wavelength_unit: Unit of wavelength values
            aoi: Angle of incidence (scalar)
            aoi_unit: Unit of the angle
            polarization: Polarization type(s) - single string or list
            to_plot: Quantity(ies) to plot
            ax: Optional matplotlib Axes

        Returns:
            Tuple of (figure, axes)

        Raises:
            ValueError: If to_plot contains anything other than 'R', 'T', 'A'
        """
        # Convert inputs
        wl_um = self._convert_to_wavelength_um(wavelength_values, wavelength_unit)
        aoi_rad = float(self._convert_angle_to_radians(aoi, aoi_unit).item())

        # Normalize and validate before any figure is created
        polarizations = self._normalize_polarizations(polarization)
        quantities = self._normalize_quantities(to_plot)

        # Convert wavelength back for plotting x-axis
        x_values = self._convert_wavelength_for_plotting(wl_um, wavelength_unit)
        x_plot = self._plot_array(x_values)

        if ax is None:
            fig, ax = plt.subplots()
        else:
            fig = ax.figure

        # One computation and one line style per polarization
        for pol_idx, pol in enumerate(polarizations):
            rta_data = self.stack.compute_rtRTA(wl_um, aoi_rad, pol)
            line_style = self._get_line_style(pol_idx)
            for quantity in quantities:
                ax.plot(
                    x_plot,
                    self._plot_array(rta_data[quantity].flatten()),
                    label=f"{quantity}, {pol}-pol, AOI={aoi}{aoi_unit}",
                    **line_style,
                )

        self._finish_line_plot(
            ax, x_plot, self._get_wavelength_axis_label(wavelength_unit)
        )
        return fig, ax

    def angular_view(
        self,
        aoi_values: float | Array,
        aoi_unit: AngleUnit = "deg",
        wavelength: float = 0.55,
        wavelength_unit: WavelengthUnit = "um",
        polarization: PolInput = "u",
        to_plot: PlotType | list[PlotType] = "R",
        ax: plt.Axes | None = None,
    ) -> tuple[plt.Figure, plt.Axes]:
        """Plot R/T/A vs angle of incidence.

        Args:
            aoi_values: Angle of incidence values in the specified unit
            aoi_unit: Unit of the angle values
            wavelength: Wavelength value (scalar)
            wavelength_unit: Unit of the wavelength
            polarization: Polarization type(s) - single string or list
            to_plot: Quantity(ies) to plot
            ax: Optional matplotlib Axes

        Returns:
            Tuple of (figure, axes)

        Raises:
            ValueError: If to_plot contains anything other than 'R', 'T', 'A'
        """
        # Convert inputs
        aoi_rad = self._convert_angle_to_radians(aoi_values, aoi_unit)
        wl_um = float(
            self._convert_to_wavelength_um(wavelength, wavelength_unit).item()
        )

        # Normalize and validate before any figure is created
        polarizations = self._normalize_polarizations(polarization)
        quantities = self._normalize_quantities(to_plot)

        # The x-axis shows the angles in the unit supplied by the caller
        x_values = be.atleast_1d(aoi_values)
        x_plot = self._plot_array(x_values)

        # Symbol used in the legend: axis label text before the unit "(...)"
        wl_axis_label = self._get_wavelength_axis_label(wavelength_unit)
        wl_symbol = wl_axis_label.split("(")[0].strip()

        if ax is None:
            fig, ax = plt.subplots()
        else:
            fig = ax.figure

        # One computation and one line style per polarization
        for pol_idx, pol in enumerate(polarizations):
            rta_data = self.stack.compute_rtRTA(wl_um, aoi_rad, pol)
            line_style = self._get_line_style(pol_idx)
            for quantity in quantities:
                ax.plot(
                    x_plot,
                    self._plot_array(rta_data[quantity].flatten()),
                    label=f"{quantity}, {pol}-pol, {wl_symbol}={wavelength}",
                    **line_style,
                )

        xlabel = r"AOI (°)" if aoi_unit == "deg" else r"AOI (rad)"
        self._finish_line_plot(ax, x_plot, xlabel)
        return fig, ax

    def map_view(
        self,
        wavelength_values: float | Array,
        wavelength_unit: WavelengthUnit = "um",
        aoi_values: float | Array | None = None,
        aoi_unit: AngleUnit = "deg",
        polarization: PolInput = "u",
        to_plot: PlotType | list[PlotType] = "R",
        colormap: str = "viridis",
        fig: plt.Figure | None = None,
        axs: plt.Axes | list[plt.Axes] | None = None,
    ) -> tuple[plt.Figure, plt.Axes | list[plt.Axes]]:
        """Plot 2D maps of R/T/A vs wavelength and angle of incidence.

        Subplots are organized with quantities as rows and polarizations as
        columns.

        Args:
            wavelength_values: Wavelength values in the specified unit
            wavelength_unit: Unit of wavelength values
            aoi_values: Angle of incidence values in the specified unit.
                Defaults to 81 points from 0 to 80 degrees (expressed in
                aoi_unit) when None.
            aoi_unit: Unit of the angle values
            polarization: Polarization type(s) - single string or list
            to_plot: Quantity(ies) to plot
            colormap: Name of the matplotlib colormap used for the maps
            fig: Optional matplotlib Figure. When given without axs, the
                subplots are created inside this figure.
            axs: Optional matplotlib Axes (single, flat or nested list, or
                array of Axes). When given without fig, the figure is taken
                from the first Axes.

        Returns:
            Tuple of (figure, axes). The axes are a single Axes for one
            quantity and one polarization, a 1D collection when only one of
            them has several entries, and a 2D collection otherwise.

        Raises:
            ValueError: If to_plot contains anything other than 'R', 'T', 'A'
        """
        # Default AOI range if not provided
        if aoi_values is None:
            aoi_values = (
                be.linspace(0, 80, 81)
                if aoi_unit == "deg"
                else be.linspace(0, be.deg2rad(80), 81)
            )

        # Convert inputs
        wl_um = self._convert_to_wavelength_um(wavelength_values, wavelength_unit)
        aoi_rad = self._convert_angle_to_radians(aoi_values, aoi_unit)

        # Normalize and validate before any figure is created
        polarizations = self._normalize_polarizations(polarization)
        quantities = self._normalize_quantities(to_plot)

        # Convert back for plotting axes
        wl_plot = self._convert_wavelength_for_plotting(wl_um, wavelength_unit)
        aoi_plot = be.atleast_1d(aoi_values)

        # Coordinate grids for pcolormesh.
        # NOTE(review): the grids must match the shape of the computed R/T/A
        # arrays; confirm the indexing convention of be.meshgrid per backend.
        aoi_grid, wl_grid = be.meshgrid(aoi_plot, wl_plot)
        WL_plot = self._plot_array(wl_grid)
        AOI_plot = self._plot_array(aoi_grid)

        # Create or adopt the figure and axes (rows: quantities, cols: pols)
        nrows = len(quantities)
        ncols = len(polarizations)
        if axs is None:
            # squeeze=False always yields a 2D array, so [row][col] indexing
            # needs no special-casing.
            if fig is None:
                fig, axs = plt.subplots(
                    nrows, ncols, figsize=(6 * ncols, 4 * nrows), squeeze=False
                )
            else:
                axs = fig.subplots(nrows, ncols, squeeze=False)
        else:
            axs = self._as_axes_grid(axs, nrows, ncols)
            if fig is None:
                fig = axs[0][0].figure

        xlabel = self._get_wavelength_axis_label(wavelength_unit)
        ylabel = r"AOI (°)" if aoi_unit == "deg" else r"AOI (rad)"

        # Compute once per polarization, then draw every requested quantity
        for pol_idx, pol in enumerate(polarizations):
            rta_data = self.stack.compute_rtRTA(wl_um, aoi_rad, pol)
            for qty_idx, quantity in enumerate(quantities):
                ax_i = axs[qty_idx][pol_idx]
                im = ax_i.pcolormesh(
                    WL_plot,
                    AOI_plot,
                    self._plot_array(rta_data[quantity]),
                    shading="auto",
                    vmin=0,
                    vmax=1,
                    cmap=colormap,
                )
                ax_i.set_xlabel(xlabel)
                ax_i.set_ylabel(ylabel)
                ax_i.set_title(f"{quantity}, {pol}-pol")
                fig.colorbar(im, ax=ax_i, label="Power fraction")

        fig.tight_layout()

        # Return format depends on the number of plots
        if nrows == 1 and ncols == 1:
            return fig, axs[0][0]
        if nrows == 1:
            return fig, axs[0]  # axes for the different polarizations
        if ncols == 1:
            return fig, [row[0] for row in axs]  # axes for the quantities
        return fig, axs  # 2D collection of axes

    # ------------------------------------------------------------------
    # Colorimetry
    # ------------------------------------------------------------------
    def _get_rt_spectrum(
        self,
        wavelength_values: float | Array,
        wavelength_unit: WavelengthUnit = "um",
        aoi: float = 0.0,
        aoi_unit: AngleUnit = "deg",
        polarization: PolInput = "u",
        quantity: Literal["R", "T"] = "R",
    ) -> tuple[list[float], list[float]]:
        """Return normalized power spectrum (R or T) in nm.

        Returns:
            Tuple of (wavelengths in nm, flattened R or T values), both as
            plain lists.

        Raises:
            ValueError: If quantity is not 'R' or 'T', or if more than one
                polarization is given.
        """
        if quantity not in ("R", "T"):
            raise ValueError("quantity must be 'R' or 'T'")
        wl_um = self._convert_to_wavelength_um(wavelength_values, wavelength_unit)
        aoi_rad = float(self._convert_angle_to_radians(aoi, aoi_unit).item())
        pol = self._get_single_polarization(polarization)
        rta_data = self.stack.compute_rtRTA(wl_um, aoi_rad, pol)
        wl_nm = self._convert_wavelength_for_plotting(wl_um, "nm")
        values = rta_data[quantity].flatten()
        return wl_nm.tolist(), values.tolist()

    def _spectrum_to_xyz(
        self,
        wavelength_values: float | Array,
        wavelength_unit: WavelengthUnit,
        aoi: float,
        aoi_unit: AngleUnit,
        polarization: PolInput,
        quantity: Literal["R", "T"],
        observer: Literal["2deg", "10deg"],
        illuminant: list[float] | None,
    ) -> Any:
        """Compute the R or T spectrum and convert it with spectrum_to_xyz.

        Returns:
            The result of ``color_core.spectrum_to_xyz`` (unpacked by callers
            as X, Y, Z).
        """
        wavelengths_nm, values = self._get_rt_spectrum(
            wavelength_values=wavelength_values,
            wavelength_unit=wavelength_unit,
            aoi=aoi,
            aoi_unit=aoi_unit,
            polarization=polarization,
            quantity=quantity,
        )
        return color_core.spectrum_to_xyz(
            wavelengths=wavelengths_nm,
            values=values,
            illuminant=illuminant,
            observer=observer,
        )

    def spectrum_to_xyY(
        self,
        wavelength_values: float | Array,
        wavelength_unit: WavelengthUnit = "um",
        aoi: float = 0.0,
        aoi_unit: AngleUnit = "deg",
        polarization: PolInput = "u",
        quantity: Literal["R", "T"] = "R",
        observer: Literal["2deg", "10deg"] = "2deg",
        illuminant: list[float] | None = None,
    ) -> tuple[float, float, float]:
        """Compute xyY chromaticity from a normalized power spectrum.

        Args:
            wavelength_values: Wavelength values in the specified unit.
            wavelength_unit: Unit of wavelength values.
            aoi: Angle of incidence (scalar).
            aoi_unit: Unit of the angle.
            polarization: Polarization type (exactly one).
            quantity: Quantity to analyze ('R' or 'T').
            observer: CIE standard observer ('2deg' or '10deg').
            illuminant: Optional custom illuminant spectrum.

        Returns:
            tuple[float, float, float]: (x, y, Y) chromaticity coordinates.

        Raises:
            ValueError: If quantity is not 'R' or 'T', or if more than one
                polarization is given.
        """
        X, Y, Z = self._spectrum_to_xyz(
            wavelength_values,
            wavelength_unit,
            aoi,
            aoi_unit,
            polarization,
            quantity,
            observer,
            illuminant,
        )
        x, y, luminance = color_core.xyz_to_xyY(X, Y, Z)
        return float(x), float(y), float(luminance)

    def analyze_color(
        self,
        wavelength_values: float | Array,
        wavelength_unit: WavelengthUnit = "um",
        aoi: float = 0.0,
        aoi_unit: AngleUnit = "deg",
        polarization: PolInput = "u",
        quantity: Literal["R", "T"] = "R",
        observer: Literal["2deg", "10deg"] = "2deg",
        illuminant: list[float] | None = None,
    ) -> dict[str, tuple[float, float, float] | tuple[int, int, int]]:
        """Return XYZ, xyY, and sRGB for a thin-film spectrum.

        Args:
            wavelength_values: Wavelength values in the specified unit.
            wavelength_unit: Unit of wavelength values.
            aoi: Angle of incidence (scalar).
            aoi_unit: Unit of the angle.
            polarization: Polarization type (exactly one).
            quantity: Quantity to analyze ('R' or 'T').
            observer: CIE standard observer ('2deg' or '10deg').
            illuminant: Optional custom illuminant spectrum.

        Returns:
            dict: Dictionary containing 'xyz', 'xyY', and 'sRGB' values.

        Raises:
            ValueError: If quantity is not 'R' or 'T', or if more than one
                polarization is given.
        """
        X, Y, Z = self._spectrum_to_xyz(
            wavelength_values,
            wavelength_unit,
            aoi,
            aoi_unit,
            polarization,
            quantity,
            observer,
            illuminant,
        )
        # Keep the tristimulus Y separate from the value returned by
        # xyz_to_xyY so that 'xyz' and sRGB are always derived from the
        # original tristimulus values.
        x, y, luminance = color_core.xyz_to_xyY(X, Y, Z)
        r, g, b = color_core.xyz_to_srgb(X, Y, Z)
        return {
            "xyz": (float(X), float(Y), float(Z)),
            "xyY": (float(x), float(y), float(luminance)),
            "sRGB": (int(r), int(g), int(b)),
        }

    def plot_color_on_cie_1931(
        self,
        wavelength_values: float | Array,
        wavelength_unit: WavelengthUnit = "um",
        aoi: float = 0.0,
        aoi_unit: AngleUnit = "deg",
        polarization: PolInput = "u",
        quantity: Literal["R", "T"] = "R",
        observer: Literal["2deg", "10deg"] = "2deg",
        illuminant: list[float] | None = None,
        ax: plt.Axes | None = None,
        color: Literal["no", "contour", "fill"] = "contour",
        marker: str = "o",
        marker_size: float = 6.0,
        marker_color: str = "black",
    ) -> tuple[plt.Figure, plt.Axes]:
        """Plot the chromaticity point on a CIE 1931 diagram.

        Args:
            wavelength_values: Wavelength values in the specified unit.
            wavelength_unit: Unit of wavelength values.
            aoi: Angle of incidence (scalar).
            aoi_unit: Unit of the angle.
            polarization: Polarization type (exactly one).
            quantity: Quantity to analyze ('R' or 'T').
            observer: CIE standard observer ('2deg' or '10deg').
            illuminant: Optional custom illuminant spectrum.
            ax: Optional matplotlib Axes passed to the diagram plotter.
            color: Color option passed to the diagram plotter.
            marker: Marker style of the chromaticity point.
            marker_size: Marker size of the chromaticity point.
            marker_color: Marker color of the chromaticity point.

        Returns:
            Tuple of (figure, axes)

        Raises:
            ValueError: If quantity is not 'R' or 'T', or if more than one
                polarization is given.
        """
        # Compute the point first so that invalid input raises before
        # anything is drawn.
        x, y, _ = self.spectrum_to_xyY(
            wavelength_values=wavelength_values,
            wavelength_unit=wavelength_unit,
            aoi=aoi,
            aoi_unit=aoi_unit,
            polarization=polarization,
            quantity=quantity,
            observer=observer,
            illuminant=illuminant,
        )
        x_plot = self._plot_scalar(x)
        y_plot = self._plot_scalar(y)

        fig, ax = plot_cie_1931_chromaticity_diagram(ax=ax, color=color)
        ax.plot(
            x_plot, y_plot, marker=marker, markersize=marker_size, color=marker_color
        )
        ax.text(
            x_plot + 0.02,
            y_plot + 0.02,
            f"x={x_plot:.4f}, y={y_plot:.4f}",
            fontsize=9,
            ha="left",
            va="bottom",
        )
        return fig, ax