Source code for analysis.jones_pupil

"""Jones Pupil Analysis

This module provides a Jones pupil analysis for optical systems.

Kramer Harrison, 2025
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np

import optiland.backend as be
from optiland.analysis.base import BaseAnalysis
from optiland.rays import PolarizationState

if TYPE_CHECKING:
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure

    from optiland.optic import Optic


class JonesPupil(BaseAnalysis):
    """Generates and plots Jones pupil maps.

    This class computes the spatially resolved Jones matrix at the exit pupil
    (or image plane) as a function of normalized pupil coordinates. It
    visualizes the real and imaginary parts of the Jones matrix elements
    (Jxx, Jxy, Jyx, Jyy).

    Attributes:
        optic: Instance of the optic object to be assessed.
        field: Field at which data is generated (Hx, Hy).
        wavelengths: Wavelengths at which data is generated.
        grid_size: The side length of the square grid of rays (NxN).
        data: Jones matrix data in a list, ordered like ``wavelengths``.
            Each entry is a dict with keys ``"Px"`` and ``"Py"`` (flattened
            normalized pupil coordinates, ``grid_size**2`` samples each) and
            ``"J"`` (Jones matrices of shape ``(grid_size**2, 2, 2)``).
    """

    def __init__(
        self,
        optic: Optic,
        field: tuple[float, float] = (0, 0),
        wavelengths: str | list = "all",
        grid_size: int = 65,
    ):
        """Initializes the JonesPupil analysis.

        Args:
            optic: An instance of the optic object to be assessed.
            field: The normalized field coordinates (Hx, Hy) at which to
                generate data. Defaults to (0, 0).
            wavelengths: Wavelengths at which to generate data. If 'all', all
                defined wavelengths are used. Defaults to "all".
            grid_size: The number of points along one dimension of the pupil
                grid. Defaults to 65.
        """
        # _generate_data reads both attributes, so set them before the base
        # initializer runs (it presumably calls _generate_data to populate
        # self.data — confirm in BaseAnalysis).
        self.field = field
        self.grid_size = grid_size
        super().__init__(optic, wavelengths)

    def view(
        self,
        fig_to_plot_on: Figure | None = None,
        figsize: tuple[float, float] = (16, 8),
    ) -> tuple[Figure, list[Axes]]:
        """Displays the Jones pupil plots.

        Only one wavelength is plotted: the optic's primary wavelength if it
        is among the analyzed wavelengths, otherwise the first analyzed one.
        Samples outside the unit pupil circle are shown blank (NaN) in both
        the real and the imaginary panels.

        Args:
            fig_to_plot_on: An existing Matplotlib figure to plot on. If None,
                a new figure is created. Defaults to None.
            figsize: The figure size for the output window.
                Defaults to (16, 8).

        Returns:
            A tuple containing the Matplotlib figure and a list of its axes.
        """
        wl_idx = self._primary_wavelength_index()
        data_fw = self.data[wl_idx]

        if fig_to_plot_on is not None:
            fig = fig_to_plot_on
            fig.clear()
        else:
            fig = plt.figure(figsize=figsize)

        # 2 rows (Real, Imag), 4 columns (Jxx, Jxy, Jyx, Jyy)
        axs = fig.subplots(2, 4, sharex=True, sharey=True)

        n = self.grid_size
        px = np.asarray(be.to_numpy(data_fw["Px"])).reshape(n, n)
        py = np.asarray(be.to_numpy(data_fw["Py"])).reshape(n, n)
        inside = px**2 + py**2 <= 1.0

        J = data_fw["J"]
        elements = [
            ("Jxx", J[:, 0, 0]),
            ("Jxy", J[:, 0, 1]),
            ("Jyx", J[:, 1, 0]),
            ("Jyy", J[:, 1, 1]),
        ]

        for col, (name, values) in enumerate(elements):
            vals = np.asarray(be.to_numpy(values)).reshape(n, n)
            parts = (("Re", np.real(vals)), ("Im", np.imag(vals)))
            for row, (prefix, part) in enumerate(parts):
                # Mask AFTER extracting the real/imag part: NaN stored in a
                # complex array becomes nan+0j, whose imaginary part is 0, so
                # masking the complex values would leave the Im panels
                # unmasked. np.where also returns a fresh array, so the
                # stored self.data is never written through a view.
                masked = np.where(inside, part, np.nan)
                ax = axs[row, col]
                im = ax.pcolormesh(
                    px, py, masked, shading="nearest", cmap="viridis"
                )
                ax.set_title(f"{prefix}({name})")
                ax.set_aspect("equal")
                fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        # Labels
        for ax in axs[:, 0]:
            ax.set_ylabel("Py")
        for ax in axs[-1, :]:
            ax.set_xlabel("Px")

        field_val = self.field
        wl_val = self.wavelengths[wl_idx].value
        fig.suptitle(
            f"Jones Pupil - Field: {field_val}, Wavelength: {wl_val:.4f} µm"
        )
        fig.tight_layout()
        return fig, fig.get_axes()

    def _primary_wavelength_index(self) -> int:
        """Returns the index of the primary wavelength, or 0 if absent."""
        wl_values = [wp.value for wp in self.wavelengths]
        primary_wl = self.optic.primary_wavelength
        return wl_values.index(primary_wl) if primary_wl in wl_values else 0

    def _generate_data(self):
        """Generates Jones matrix data at ``self.field`` for each wavelength.

        Rays are traced on a full ``grid_size x grid_size`` square grid in
        normalized pupil coordinates spanning [-1, 1] on both axes. Samples
        outside the unit circle are kept here and masked only when plotting.

        Returns:
            list[dict]: One entry per wavelength, in the order of
            ``self.wavelengths``. See ``_generate_single_data`` for the keys.
        """
        x = be.linspace(-1.0, 1.0, self.grid_size)
        y = be.linspace(-1.0, 1.0, self.grid_size)
        Px_grid, Py_grid = be.meshgrid(x, y)
        Px = Px_grid.flatten()
        Py = Py_grid.flatten()

        Hx, Hy = self.field
        return [
            self._generate_single_data(Hx, Hy, Px, Py, wp.value)
            for wp in self.wavelengths
        ]

    def _generate_single_data(self, Hx, Hy, Px, Py, wavelength):
        """Generates Jones matrix data for a single field and wavelength.

        Args:
            Hx: Normalized field x coordinate.
            Hy: Normalized field y coordinate.
            Px: Flattened normalized pupil x coordinates (one per ray).
            Py: Flattened normalized pupil y coordinates (one per ray).
            wavelength: Wavelength value passed to ``trace_generic``.

        Returns:
            dict: ``{"Px": Px, "Py": Py, "J": J}``, where ``J`` has shape
            ``(N, 2, 2)`` with rows ``[Jxx, Jxy]`` and ``[Jyx, Jyy]``.

        Raises:
            RuntimeError: If the traced rays carry no polarization matrix.
        """
        original_pol = self.optic.polarization
        ignore_pol = original_pol == "ignore"
        if ignore_pol:
            # Temporarily enable polarization to get PolarizedRays.
            self.optic.updater.set_polarization(PolarizationState())
        try:
            rays = self.optic.trace_generic(
                Hx=Hx, Hy=Hy, Px=Px, Py=Py, wavelength=wavelength
            )
        finally:
            # Restore the user's setting even if the trace fails.
            if ignore_pol:
                self.optic.updater.set_polarization("ignore")

        if not hasattr(rays, "p"):
            # Should not happen given the polarization toggle above.
            raise RuntimeError("Ray tracing did not return polarized rays.")

        # Ray direction vectors, re-normalized defensively.
        k = be.stack([rays.L, rays.M, rays.N], axis=1)
        k_norm = be.linalg.norm(k, axis=1)
        k = k / be.unsqueeze_last(k_norm)

        u, v = self._local_basis(k)

        # Project the global polarization matrix p (shape (N, 3, 3)) onto the
        # local basis: P . x_in is column 0 of p, P . y_in is column 1.
        #   Jxx = u . (P . x_in)    Jxy = u . (P . y_in)
        #   Jyx = v . (P . x_in)    Jyy = v . (P . y_in)
        P_x_in = rays.p[:, :, 0]  # Shape (N, 3)
        P_y_in = rays.p[:, :, 1]  # Shape (N, 3)

        Jxx = be.sum(u * P_x_in, axis=1)
        Jxy = be.sum(u * P_y_in, axis=1)
        Jyx = be.sum(v * P_x_in, axis=1)
        Jyy = be.sum(v * P_y_in, axis=1)

        # Stack into (N, 2, 2)
        row1 = be.stack([Jxx, Jxy], axis=1)
        row2 = be.stack([Jyx, Jyy], axis=1)
        J = be.stack([row1, row2], axis=1)

        return {"Px": Px, "Py": Py, "J": J}

    @staticmethod
    def _local_basis(k):
        """Builds the transverse basis vectors (u, v) for ray directions ``k``.

        ``v`` is perpendicular to both ``k`` and the global X axis, and
        ``u = v x k`` completes the frame. For ``k`` along +Z this yields
        ``u = +X`` and ``v = +Y``.

        A small epsilon keeps the normalization finite when ``k`` is parallel
        to X. In that degenerate case both ``u`` and ``v`` come out as zero
        vectors rather than NaN.

        Args:
            k: Unit direction vectors, shape (N, 3).

        Returns:
            tuple: ``(u, v)``, each of shape (N, 3).
        """
        x_axis = be.broadcast_to(be.array([1.0, 0.0, 0.0]), k.shape)

        v = be.cross(k, x_axis)
        v_norm = be.linalg.norm(v, axis=1)
        v = v / be.unsqueeze_last(v_norm + 1e-15)

        u = be.cross(v, k)
        u_norm = be.linalg.norm(u, axis=1)
        u = u / be.unsqueeze_last(u_norm + 1e-15)

        return u, v