"""Thin film optics stack class with inlined TMM.
This class encapsulates both the stack structure (incident/substrate, layers)
and the numerical Transfer Matrix Method (TMM) to compute complex amplitude
coefficients (r, t) and power coefficients (R, T, A) for s, p and unpolarized
cases.
Corentin Nannini, 2025
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
import optiland.backend as be
from optiland.materials import IdealMaterial
from .core import _tmm_coh
from .layer import Layer
if TYPE_CHECKING:
from optiland.materials import BaseMaterial
import re
import matplotlib.pyplot as plt
Pol = Literal["s", "p", "u"]
PlotType = Literal["R", "T", "A"]
Array: TypeAlias = Any # be.ndarray
[docs]
@dataclass
class ThinFilmStack:
"""Multilayer thin-film stack with inlined TMM calculations.
This class encapsulates both the stack structure (incident/substrate, layers)
and the numerical Transfer Matrix Method (TMM) to compute complex amplitude
coefficients (r, t) and power coefficients (R, T, A) for s, p and unpolarized
cases.
Units and conventions:
- Wavelength in microns (µm) internally; convenience helpers accept nm.
- AOI in radians internally; convenience helpers accept degrees.
- Layers are ordered from the incident side to the substrate side.
Args:
incident_material (BaseMaterial): Incident medium (e.g., air).
substrate_material (BaseMaterial): Substrate medium (e.g., glass).
layers (list[Layer], optional): Ordered layers between incident and
substrate. Defaults to None.
reference_wl_um (float | None, optional): Reference wavelength for
thickness quarter-wave calculations. Defaults to None.
reference_AOI_deg (float | None, optional): Reference angle of
incidence in degrees for thickness quarter-wave calculations.
Defaults to 0 (normal incidence).
Examples:
>>> from optiland.materials import IdealMaterial, Material
>>> from optiland.thin_film import ThinFilmStack
>>> air, glass = IdealMaterial(1.0), IdealMaterial(1.52)
>>> tf = ThinFilmStack(incident_material=air, substrate_material=glass)
>>> # 100 nm SiO2 on glass
>>> SiO2 = Material("SiO2", reference="Gao")
>>> tf.add_layer_nm(SiO2, 100.0)
>>> R = tf.reflectance_nm_deg([550.0], [0.0], polarization="s")
>>> T = tf.transmittance_nm_deg([550.0], [0.0], polarization="s")
>>> A = tf.absorptance_nm_deg([550.0], [0.0], polarization="s")
"""
incident_material: BaseMaterial
substrate_material: BaseMaterial
layers: list[Layer] = field(default_factory=list)
reference_wl_um: float | None = None
reference_AOI_deg: float | None = 0
def __str__(self):
"""Return a concise summary of the stack structure."""
inc_name = getattr(
self.incident_material, "name", self.incident_material.__class__.__name__
)
sub_name = getattr(
self.substrate_material, "name", self.substrate_material.__class__.__name__
)
if not self.layers:
return (
f"ThinFilmStack(incident={inc_name}, substrate={sub_name}, layers=[])"
)
layer_lines: list[str] = []
for i, layer in enumerate(self.layers, start=1):
material_name = getattr(
layer.material, "name", layer.material.__class__.__name__
)
layer_lines.append(
f" {i}. {material_name} ({layer.thickness_um * 1000:.1f} nm)"
)
layers_str = "\n".join(layer_lines)
total_th = sum(layer.thickness_um for layer in self.layers)
return (
f"ThinFilmStack Summary\n"
f"---------------------\n"
f"Incident: {inc_name}\n"
f"Substrate: {sub_name}\n"
f"Layers:\n{layers_str}\n"
f"---------------------\n"
f"Total Thickness: {total_th * 1000:.1f} nm"
)
[docs]
def copy(
self,
incident: BaseMaterial | None = None,
substrate: BaseMaterial | None = None,
):
"""Creates a copy of the stack with optionally new surrounding materials."""
return ThinFilmStack(
incident_material=incident if incident else self.incident_material,
substrate_material=substrate if substrate else self.substrate_material,
layers=self.layers.copy(),
reference_wl_um=self.reference_wl_um,
reference_AOI_deg=self.reference_AOI_deg,
)
# ----- structure helpers -----
[docs]
def add_layer(
self, material: BaseMaterial, thickness_um: float, name: str | None = None
) -> ThinFilmStack:
"""Append a layer to the stack.
Args:
material: Optiland material providing n(λ), k(λ).
thickness_um: Thickness in microns (µm).
name: Optional label.
Returns:
self for chaining.
"""
self.layers.append(Layer(material, thickness_um, name))
return self
[docs]
def add_layer_nm(
self, material: BaseMaterial, thickness_nm: float, name: str | None = None
) -> ThinFilmStack:
"""Append a layer, thickness in nm.
Args:
material: Optiland material providing n(λ), k(λ).
thickness_nm: Thickness in nanometers.
name: Optional label.
"""
return self.add_layer(material, thickness_nm / 1000.0, name)
[docs]
def add_layer_qwot(
self,
material: BaseMaterial,
qwot_thickness: float = 1.0,
name: str | None = None,
) -> ThinFilmStack:
"""Append a quarter-wave optical thickness (QWOT) layer at the reference
wavelength and angle of incidence.
Args:
material: Optiland material providing n(λ), k(λ).
name: Optional label.
Raises:
ValueError: If reference_wl_um is not set.
"""
if self.reference_wl_um is None:
raise ValueError("reference_wl_um must be set for adding QWOT layer")
wl_um = self.reference_wl_um
th_rad = 0.0
if self.reference_AOI_deg is not None:
th_rad = be.deg2rad(self.reference_AOI_deg)
n = float(be.atleast_1d(material.n(wl_um))[0]) # to ensure scalar float
thickness_um = qwot_thickness * wl_um / (4 * n * be.cos(th_rad))
return self.add_layer(thickness_um=thickness_um, material=material, name=name)
# ----- units helpers -----
@staticmethod
def _to_um(wavelength_um_or_nm: float | Array, assume_nm: bool = False):
arr = be.atleast_1d(wavelength_um_or_nm)
return arr / 1000.0 if assume_nm else arr
@staticmethod
def _deg_to_rad(angle_deg: float | Array):
return be.atleast_1d(angle_deg) * (be.pi / 180.0)
# ----- public API: coefficients -----
[docs]
def compute_rtRTA(
self,
wavelength_um: float | Array,
aoi_rad: float | Array = 0.0,
polarization: Pol = "u",
) -> dict[str, Any]:
"""Compute complex and power coefficients over λ×θ grids.
Args:
wavelength_um: Wavelength(s) in microns (scalar or array). Use helpers
for nm.
aoi_rad: Angle(s) of incidence in radians (scalar or array). Use helpers
for degrees.
polarization: 's', 'p' or 'u' (unpolarized averages powers of s and p).
default 'u'.
Returns:
Dict with keys 'r','t','R','T','A'. Shapes are (Nλ, Nθ).
Note:
- For unpolarized 'u', r, t are s-polarization amplitudes; R, T, A are
averaged powers.
"""
wl = be.atleast_1d(wavelength_um)
th = be.atleast_1d(aoi_rad)
if polarization in ("s", "p"):
r, t, R, T, A = _tmm_coh(self, wl[:, None], th[None, :], polarization)
return {"r": r, "t": t, "R": R, "T": T, "A": A}
elif polarization == "u":
rs, ts, Rs, Ts, As = _tmm_coh(self, wl[:, None], th[None, :], "s")
rp, tp, Rp, Tp, Ap = _tmm_coh(self, wl[:, None], th[None, :], "p")
R = 0.5 * (Rs + Rp)
T = 0.5 * (Ts + Tp)
A = 0.5 * (As + Ap)
# Return s-amplitudes for reference; intensities are averaged
return {"r": rs, "t": ts, "R": R, "T": T, "A": A}
else:
raise ValueError("polarization must be 's', 'p' or 'u'")
[docs]
def compute_rtRTA_elementwise(
self,
wavelength_um: float | Array,
aoi_rad: float | Array = 0.0,
polarization: Pol = "u",
) -> dict[str, Any]:
"""Compute complex and power coefficients element-wise (no grid).
Use this when wavelength and aoi have matching shapes (e.g. per-ray).
"""
wl = be.atleast_1d(wavelength_um)
th = be.atleast_1d(aoi_rad)
if polarization in ("s", "p"):
r, t, R, T, A = _tmm_coh(self, wl, th, polarization)
return {"r": r, "t": t, "R": R, "T": T, "A": A}
elif polarization == "u":
rs, ts, Rs, Ts, As = _tmm_coh(self, wl, th, "s")
rp, tp, Rp, Tp, Ap = _tmm_coh(self, wl, th, "p")
R = 0.5 * (Rs + Rp)
T = 0.5 * (Ts + Tp)
A = 0.5 * (As + Ap)
return {"r": rs, "t": ts, "R": R, "T": T, "A": A}
else:
raise ValueError("polarization must be 's', 'p' or 'u'")
[docs]
def compute_rtRAT_nm_deg(
self,
wavelength_nm: float | Array,
aoi_deg: float | Array = 0.0,
polarization: Pol = "u",
) -> dict[str, float | Array]:
"""Same as coefficients() but inputs in nm and degrees."""
wl_um = self._to_um(wavelength_nm, assume_nm=True)
th_rad = self._deg_to_rad(aoi_deg)
return self.compute_rtRTA(wl_um, th_rad, polarization)
# ----- convenience getters -----
[docs]
def reflectance(
self,
wavelength_um: float | Array,
aoi_rad: float | Array = 0.0,
polarization: Pol = "u",
) -> Array:
return self.compute_rtRTA(wavelength_um, aoi_rad, polarization)["R"]
[docs]
def transmittance(
self,
wavelength_um: float | Array,
aoi_rad: float | Array = 0.0,
polarization: Pol = "u",
) -> Array:
return self.compute_rtRTA(wavelength_um, aoi_rad, polarization)["T"]
[docs]
def absorptance(
self,
wavelength_um: float | Array,
aoi_rad: float | Array = 0.0,
polarization: Pol = "u",
) -> Array:
return self.compute_rtRTA(wavelength_um, aoi_rad, polarization)["A"]
[docs]
def reflectance_nm_deg(
self,
wavelength_nm: float | Array,
aoi_deg: float | Array = 0.0,
polarization: Pol = "u",
) -> Array:
return self.compute_rtRAT_nm_deg(wavelength_nm, aoi_deg, polarization)["R"]
[docs]
def transmittance_nm_deg(
self,
wavelength_nm: float | Array,
aoi_deg: float | Array = 0.0,
polarization: Pol = "u",
) -> Array:
return self.compute_rtRAT_nm_deg(wavelength_nm, aoi_deg, polarization)["T"]
[docs]
def absorptance_nm_deg(
self,
wavelength_nm: float | Array,
aoi_deg: float | Array = 0.0,
polarization: Pol = "u",
) -> Array:
return self.compute_rtRAT_nm_deg(wavelength_nm, aoi_deg, polarization)["A"]
[docs]
def RTA(
self,
wavelength_um: float | Array,
aoi_rad: float | Array = 0.0,
polarization: Pol = "u",
) -> tuple[Array, Array, Array]:
"""Return (R, T, A) for given wavelength(s) in µm and AOI(s) in radians."""
rta_data = self.compute_rtRTA(wavelength_um, aoi_rad, polarization)
return (
rta_data["R"],
rta_data["T"],
rta_data["A"],
)
[docs]
def RTA_nm_deg(
self,
wavelength_nm: float | Array,
aoi_deg: float | Array = 0.0,
polarization: Pol = "u",
) -> tuple[Array, Array, Array]:
"""Return (R, T, A) for given wavelength(s) in nm and AOI(s) in degrees."""
rta_data = self.compute_rtRAT_nm_deg(wavelength_nm, aoi_deg, polarization)
return (
rta_data["R"],
rta_data["T"],
rta_data["A"],
)
# ----- insertion / removal helpers -----
[docs]
def insert_layer(
self,
index: int,
material: BaseMaterial,
thickness_um: float,
name: str | None = None,
) -> ThinFilmStack:
"""Insert a layer at an arbitrary position.
Args:
index: Position to insert at (0 = closest to incident).
material: Optiland material providing n(λ), k(λ).
thickness_um: Thickness in microns (µm).
name: Optional label.
Returns:
self for chaining.
"""
self.layers.insert(index, Layer(material, thickness_um, name))
return self
[docs]
def insert_layer_nm(
self,
index: int,
material: BaseMaterial,
thickness_nm: float,
name: str | None = None,
) -> ThinFilmStack:
"""Insert a layer at an arbitrary position, thickness in nm.
Args:
index: Position to insert at (0 = closest to incident).
material: Optiland material providing n(λ), k(λ).
thickness_nm: Thickness in nanometers.
name: Optional label.
Returns:
self for chaining.
"""
return self.insert_layer(index, material, thickness_nm / 1000.0, name)
[docs]
def remove_layer(self, index: int) -> Layer:
"""Remove and return the layer at *index*.
Args:
index: Index of the layer to remove.
Returns:
The removed Layer.
"""
return self.layers.pop(index)
[docs]
def split_layer(self, layer_index: int, position_fraction: float) -> ThinFilmStack:
"""Split a layer into two layers of the same material.
The original layer at *layer_index* is replaced by two layers whose
combined thickness equals the original. Useful for needle insertion
*within* a layer.
Args:
layer_index: Index of the layer to split.
position_fraction: Fraction (0..1) at which to split. 0.3 means
the first sub-layer gets 30 % of the original thickness.
Returns:
self for chaining.
"""
if not 0.0 < position_fraction < 1.0:
raise ValueError("position_fraction must be strictly between 0 and 1")
layer = self.layers[layer_index]
t1 = layer.thickness_um * position_fraction
t2 = layer.thickness_um * (1.0 - position_fraction)
self.layers[layer_index] = Layer(layer.material, t1, layer.name)
self.layers.insert(layer_index + 1, Layer(layer.material, t2, layer.name))
return self
[docs]
def deep_copy(self) -> ThinFilmStack:
"""Create a deep copy with new Layer instances (materials are shared).
Returns:
A new ThinFilmStack with independent layers.
"""
new_layers = [
Layer(layer.material, layer.thickness_um, layer.name)
for layer in self.layers
]
return ThinFilmStack(
incident_material=self.incident_material,
substrate_material=self.substrate_material,
layers=new_layers,
reference_wl_um=self.reference_wl_um,
reference_AOI_deg=self.reference_AOI_deg,
)
def __len__(self):
return len(self.layers)
def __repr__(self):
parts = [layer.name or f"Layer({i})" for i, layer in enumerate(self.layers)]
return f"ThinFilmStack({len(self.layers)} layers: " + " -> ".join(parts) + ")"
[docs]
def plot_structure(self, ax: plt.Axes = None) -> tuple[plt.Figure, plt.Axes]:
"""Plots a schematic representation of the thin film stack structure.
This method visualizes the stack as a series of colored rectangles, each
representing a material layer, the substrate, and the incident medium.
Each rectangle's height corresponds to the physical thickness of the
layer (in micrometers), and colors are assigned uniquely to each material.
The substrate is plotted at the bottom, followed by the stack layers,
and the incident medium at the top. Material names or refractive indices
are used as labels in the legend.
Args:
ax (plt.Axes, optional): The axes on which to plot the structure.
If None, a new figure and axes are created.
Returns:
tuple[plt.Figure, plt.Axes]: The matplotlib Figure and Axes objects
containing the plot.
"""
if ax is None:
fig, ax = plt.subplots()
import matplotlib.colors as mcolors
color_cycle = list(mcolors.TABLEAU_COLORS.values())
def _to_float(value) -> float:
if hasattr(value, "detach") and hasattr(value, "cpu"):
value = value.detach().cpu()
if hasattr(value, "item"):
return float(value.item())
return float(value)
def _get_name(obj):
"""
Get the name of a material or layer, or its refractive index if it's an
IdealMaterial. Because IdealMaterial may not have a name, we use its
refractive index for labeling.
"""
name = getattr(obj, "name", "") or ""
if isinstance(obj, IdealMaterial):
name = f"$n$ = {_to_float(obj.index[0])}"
return name
def _add_rect(y, height, color, label, text=None):
y = _to_float(y)
height = _to_float(height)
ax.add_patch(
plt.Rectangle((0, y), 1, height, color=color, label=label, alpha=0.7)
)
if text is not None:
ax.text(
0.5,
y + height / 2,
text,
ha="center",
va="center",
fontsize=10,
rotation=0,
)
material_names = (
[_get_name(self.incident_material)]
+ [_get_name(layer.material) for layer in self.layers]
+ [_get_name(self.substrate_material)]
)
unique_materials = {
name: color_cycle[i % len(color_cycle)]
for i, name in enumerate(dict.fromkeys(material_names))
}
total_layer_thickness = sum(layer.thickness_um for layer in self.layers)
total_layer_thickness = _to_float(total_layer_thickness)
# Ensure minimum thickness for visualization (avoid singular ylim
# on empty stacks)
if total_layer_thickness == 0:
total_layer_thickness = 1.0
incident_thickness = 0.08 * total_layer_thickness
substrate_thickness = 0.08 * total_layer_thickness
y = -substrate_thickness
# Substrate (bottom, negative y)
_add_rect(
y,
substrate_thickness,
unique_materials[_get_name(self.substrate_material)],
label=_get_name(self.substrate_material),
text=_get_name(self.substrate_material),
)
y = 0
# Layers (middle, positive y)
for _, layer in enumerate(self.layers):
color = unique_materials[_get_name(layer.material)]
label = layer.name or _get_name(layer.material)
if label:
label = re.sub(r"\d+", lambda m: str(int(m.group())), label)
_add_rect(
y,
layer.thickness_um,
color,
label=label,
text=None,
)
y += _to_float(layer.thickness_um)
# Incident medium (top)
_add_rect(
y,
incident_thickness,
unique_materials[_get_name(self.incident_material)],
label=_get_name(self.incident_material),
text=_get_name(self.incident_material),
)
ax.set_xlim(0, 1)
ax.set_ylim(-substrate_thickness, y + incident_thickness)
ax.set_ylabel("Thickness (µm)")
ax.set_xticks([])
handles, labels = ax.get_legend_handles_labels()
by_label = dict(zip(labels, handles, strict=False))
ax.legend(
by_label.values(),
by_label.keys(),
loc="center left",
bbox_to_anchor=(1.05, 0.5),
borderaxespad=0.0,
ncol=1,
)
fig = ax.figure
return fig, ax
[docs]
def plot_structure_thickness(
self, ax: plt.Axes = None
) -> tuple[plt.Figure, plt.Axes]:
"""
Plots the thickness of each layer in the thin film stack as a bar chart.
Each bar represents a layer, with its height corresponding to the layer's
thickness in nanometers.
Bars are colored according to the material of each layer, and a legend is
provided to identify materials.
Args:
ax (plt.Axes, optional): The matplotlib Axes object to plot on.
If None, a new figure and axes will be created.
Returns:
tuple[plt.Figure, plt.Axes]: The matplotlib Figure and Axes objects
containing the plot.
"""
if ax is None:
fig, ax = plt.subplots()
import matplotlib.colors as mcolors
ax.grid(True, alpha=0.3)
color_cycle = list(mcolors.TABLEAU_COLORS.values())
def _to_float(value) -> float:
if hasattr(value, "detach") and hasattr(value, "cpu"):
value = value.detach().cpu()
if hasattr(value, "item"):
return float(value.item())
return float(value)
def _get_name(obj):
name = getattr(obj, "name", "") or ""
if isinstance(obj, IdealMaterial):
name = f"$n$ = {_to_float(obj.index[0])}"
return name
material_names = [_get_name(layer.material) for layer in self.layers]
unique_materials = {
name: color_cycle[i % len(color_cycle)]
for i, name in enumerate(dict.fromkeys(material_names))
}
colors = [unique_materials[_get_name(layer.material)] for layer in self.layers]
thicknesses_nm = [_to_float(layer.thickness_um * 1000) for layer in self.layers]
labels = [layer.name or _get_name(layer.material) for layer in self.layers]
indices = list(range(len(self.layers)))
bars = ax.bar(
indices,
thicknesses_nm,
color=colors,
edgecolor=None,
alpha=0.7,
width=1,
)
ax.set_xlabel("Layer index")
ax.set_ylabel("Thickness (nm)")
# Legend
by_label = {}
for bar, label in zip(bars, labels, strict=False):
if label not in by_label:
by_label[label] = bar
ax.legend(
by_label.values(),
by_label.keys(),
loc="center left",
bbox_to_anchor=(1.05, 0.5),
borderaxespad=0.0,
ncol=1,
)
ax.set_xlim(0.5, len(self.layers) - 0.5)
fig = ax.figure
return fig, ax