# Source code for analysis.spot_diagram.reference

"""Spot Diagram Reference Strategies

This module provides configurable reference point strategies for centering
spot diagram data. The reference point used for centering can be either the
chief ray intersection or the geometric centroid of the traced rays.

Kramer Harrison, 2025
"""

from __future__ import annotations

import abc
from enum import Enum
from typing import TYPE_CHECKING

import optiland.backend as be

if TYPE_CHECKING:
    from optiland._types import BEArray
    from optiland.optic import Optic

    from .core import SpotData


# TODO: Update to StrEnum when Python 3.10 support is dropped
class SpotReferenceType(str, Enum):
    """Defines the available reference point types for spot centering.

    Because of the ``str`` mixin, members compare equal to their string
    values, and ``SpotReferenceType("centroid")`` looks a member up by value.
    """

    CHIEF_RAY = "chief_ray"  # center on the chief ray intersection
    CENTROID = "centroid"  # center on the geometric centroid of traced rays

    def __str__(self) -> str:
        # Return the plain value, matching ``enum.StrEnum``. Without this
        # override a ``(str, Enum)`` mixin formats inconsistently across
        # versions: f"{member}" yields "centroid" on Python 3.10 but
        # "SpotReferenceType.CENTROID" on 3.11+.
        return self.value
class SpotReferenceStrategy(abc.ABC):
    """Abstract base class for spot diagram reference strategies.

    A strategy returns, for each field, the (x, y) point used to center that
    field's spot diagram data.
    """

    @abc.abstractmethod
    def get_centers(
        self,
        data: list[list[SpotData]],
        ref_wavelength_index: int,
        optic: Optic,
        fields: list[tuple[float, float]],
        wavelength: float,
        coordinates: str,
    ) -> list[tuple[BEArray, BEArray]]:
        """Computes the (x, y) center for each field.

        Args:
            data: Nested list of spot data, ordered [field][wavelength].
            ref_wavelength_index: Index of the reference wavelength.
            optic: The optical system being analyzed.
            fields: List of (Hx, Hy) field coordinates.
            wavelength: Reference wavelength value in micrometers.
            coordinates: Coordinate system ('global' or 'local').

        Returns:
            A list of (x, y) center tuples, one per field.
        """


class CentroidReference(SpotReferenceStrategy):
    """Centers spots using the geometric centroid of rays at the reference wavelength."""

    def get_centers(
        self,
        data: list[list[SpotData]],
        ref_wavelength_index: int,
        optic: Optic,
        fields: list[tuple[float, float]],
        wavelength: float,
        coordinates: str,
    ) -> list[tuple[BEArray, BEArray]]:
        """Computes centroids from the mean of ray intersections.

        Only ``data`` and ``ref_wavelength_index`` are used. The remaining
        arguments exist to satisfy the shared strategy interface.

        Returns:
            One ``(mean_x, mean_y)`` tuple per field, taken from the spot
            data at ``ref_wavelength_index``.
        """
        return [
            (
                be.mean(field_data[ref_wavelength_index].x),
                be.mean(field_data[ref_wavelength_index].y),
            )
            for field_data in data
        ]


class ChiefRayReference(SpotReferenceStrategy):
    """Centers spots using the chief ray intersection at the reference wavelength."""

    def get_centers(
        self,
        data: list[list[SpotData]],
        ref_wavelength_index: int,
        optic: Optic,
        fields: list[tuple[float, float]],
        wavelength: float,
        coordinates: str,
    ) -> list[tuple[BEArray, BEArray]]:
        """Computes centers from chief ray (Px=0, Py=0) intersections.

        One chief ray is traced per field at ``wavelength``.

        - When ``coordinates`` is ``"local"``, the intersection is transformed
          with ``optic.image_surface`` as the target frame.
        - Any other value returns the traced coordinates unchanged.

        ``data`` and ``ref_wavelength_index`` are unused.

        Returns:
            One ``(x, y)`` tuple per entry in ``fields``, in the same order.
        """
        use_local = coordinates == "local"
        if use_local:
            # Imported lazily, and only when needed, so that centering in
            # global coordinates does not depend on the visualization package.
            from optiland.visualization.system.utils import transform

        centers = []
        for H_x, H_y in fields:
            ray = optic.trace_generic(
                Hx=H_x, Hy=H_y, Px=0, Py=0, wavelength=wavelength
            )
            if use_local:
                x, y, _ = transform(
                    ray.x, ray.y, ray.z, optic.image_surface, is_global=True
                )
            else:
                x, y = ray.x, ray.y
            # Coordinates come back array-valued; the first element is the
            # single traced chief ray.
            centers.append((x.ravel()[0], y.ravel()[0]))
        return centers


def create_reference_strategy(
    reference: str | SpotReferenceType,
) -> SpotReferenceStrategy:
    """Factory function to create a reference strategy from a type string.

    Args:
        reference: The reference type, either a SpotReferenceType enum value
            or a string ('chief_ray' or 'centroid').

    Returns:
        The corresponding SpotReferenceStrategy instance.

    Raises:
        ValueError: If the reference type is not recognized. The message
            lists the accepted values.
    """
    try:
        ref = SpotReferenceType(reference)
    except ValueError as err:
        # The Enum's default message does not say what *is* valid.
        valid = ", ".join(repr(member.value) for member in SpotReferenceType)
        raise ValueError(
            f"Unknown spot reference type {reference!r}; expected one of {valid}."
        ) from err

    strategies: dict[SpotReferenceType, type[SpotReferenceStrategy]] = {
        SpotReferenceType.CHIEF_RAY: ChiefRayReference,
        SpotReferenceType.CENTROID: CentroidReference,
    }
    return strategies[ref]()