Source code for surfaces.factories.interaction_model_factory

"""Interaction Model Factory

This module contains the InteractionModelFactory class, which is used to create
interaction model objects based on the given parameters.

Kramer Harrison, 2025
"""

from __future__ import annotations

from typing import TYPE_CHECKING

from optiland.interactions.diffractive_model import DiffractiveInteractionModel
from optiland.interactions.phase_interaction_model import PhaseInteractionModel
from optiland.interactions.refractive_reflective_model import RefractiveReflectiveModel
from optiland.interactions.thin_lens_interaction_model import ThinLensInteractionModel

if TYPE_CHECKING:
    from collections.abc import Callable  # pragma: no cover

    from optiland.coatings import BaseCoating
    from optiland.interactions.base import BaseInteractionModel
    from optiland.scatter import BaseBSDF
    from optiland.surfaces import Surface


def _build_refractive_reflective(
    parent_surface: Surface | None,
    is_reflective: bool,
    coating: BaseCoating | None,
    bsdf: BaseBSDF | None,
    **_,
) -> RefractiveReflectiveModel:
    """Build the standard refractive/reflective interaction model.

    Args:
        parent_surface: Surface that will own the model (may be ``None``).
        is_reflective: Whether the surface reflects rather than refracts.
        coating: Optional coating applied at the surface.
        bsdf: Optional scattering model applied at the surface.
        **_: Extra keyword arguments are accepted and discarded so this
            builder can be invoked with the same call shape as every other
            registered builder.

    Returns:
        A new ``RefractiveReflectiveModel``.
    """
    model_kwargs = dict(
        parent_surface=parent_surface,
        is_reflective=is_reflective,
        coating=coating,
        bsdf=bsdf,
    )
    return RefractiveReflectiveModel(**model_kwargs)


def _build_thin_lens(
    parent_surface: Surface | None,
    is_reflective: bool,
    coating: BaseCoating | None,
    bsdf: BaseBSDF | None,
    focal_length: float | None = None,
    **_,
) -> ThinLensInteractionModel:
    if focal_length is None:
        raise ValueError("Focal length is required for thin lens.")
    return ThinLensInteractionModel(
        parent_surface=parent_surface,
        focal_length=focal_length,
        is_reflective=is_reflective,
        coating=coating,
        bsdf=bsdf,
    )


def _build_diffractive(
    parent_surface: Surface | None,
    is_reflective: bool,
    coating: BaseCoating | None,
    bsdf: BaseBSDF | None,
    **_,
) -> DiffractiveInteractionModel:
    """Build a diffractive interaction model.

    Args:
        parent_surface: Surface that will own the model (may be ``None``).
        is_reflective: Whether the surface is reflective.
        coating: Optional coating applied at the surface.
        bsdf: Optional scattering model applied at the surface.
        **_: Extra keyword arguments are accepted and discarded.

    Returns:
        A new ``DiffractiveInteractionModel``.
    """
    model = DiffractiveInteractionModel(
        bsdf=bsdf,
        coating=coating,
        is_reflective=is_reflective,
        parent_surface=parent_surface,
    )
    return model


def _build_phase(
    parent_surface: Surface | None,
    is_reflective: bool,
    coating: BaseCoating | None,
    bsdf: BaseBSDF | None,
    phase_profile=None,
    **_,
) -> PhaseInteractionModel:
    if phase_profile is None:
        raise ValueError("phase_profile is required for phase interaction.")
    return PhaseInteractionModel(
        parent_surface=parent_surface,
        phase_profile=phase_profile,
        is_reflective=is_reflective,
        coating=coating,
        bsdf=bsdf,
    )


# Maps each accepted ``interaction_type`` string to the builder that
# constructs it. InteractionModelFactory.create calls every builder with the
# same keywords (parent_surface, is_reflective, coating, bsdf, plus any extra
# kwargs), which is why each builder swallows unknown keywords via ``**_``.
# InteractionModelFactory.register adds or replaces entries at runtime.
_INTERACTION_REGISTRY: dict[str, Callable] = dict(
    refractive_reflective=_build_refractive_reflective,
    thin_lens=_build_thin_lens,
    diffractive=_build_diffractive,
    phase=_build_phase,
)


class InteractionModelFactory:
    """A factory class for creating interaction model objects.

    Builders are looked up by name in the module-level
    ``_INTERACTION_REGISTRY``. New builders can be added with
    :meth:`register`; models are constructed with :meth:`create`.
    """

    @classmethod
    def register(
        cls,
        name: str,
        builder: Callable,
        *,
        overwrite: bool = False,
    ) -> None:
        """Register a new interaction model builder.

        Args:
            name: The string key used when specifying interaction_type.
            builder: A callable with signature
                ``(parent_surface, is_reflective, coating, bsdf, **kwargs)``
                that returns a ``BaseInteractionModel`` instance.
            overwrite: Allow replacing an existing registration.

        Raises:
            TypeError: If builder is not callable.
            ValueError: If name is already registered and overwrite is False.
        """
        # Reject bad builders here; otherwise the failure only surfaces on
        # the first create() call as an opaque "object is not callable".
        if not callable(builder):
            raise TypeError(
                f"Builder for interaction model '{name}' must be callable, "
                f"got {type(builder).__name__}."
            )
        if name in _INTERACTION_REGISTRY and not overwrite:
            raise ValueError(
                f"Interaction model '{name}' is already registered. "
                "Pass overwrite=True to replace it."
            )
        _INTERACTION_REGISTRY[name] = builder

    def create(
        self,
        parent_surface: Surface | None,
        interaction_type: str,
        is_reflective: bool,
        coating: BaseCoating | None,
        bsdf: BaseBSDF | None,
        **kwargs,
    ) -> BaseInteractionModel:
        """Creates an interaction model object based on the given parameters.

        Args:
            parent_surface: The parent surface (hooked up later in
                Surface.__init__).
            interaction_type (str): The type of interaction model to create.
            is_reflective (bool): Indicates whether the surface is reflective.
            coating (Optional[BaseCoating]): The coating of the surface.
            bsdf (Optional[BaseBSDF]): The BSDF of the surface.
            **kwargs: Additional keyword arguments forwarded to the builder
                (e.g. ``focal_length`` for thin_lens, ``phase_profile`` for
                phase).

        Returns:
            BaseInteractionModel: The created interaction model object.

        Raises:
            ValueError: If the interaction_type is unknown. Builders may also
                raise their own errors; the built-in thin_lens and phase
                builders raise ValueError when their required keyword is
                missing.
        """
        try:
            builder = _INTERACTION_REGISTRY[interaction_type]
        except KeyError:
            # Insertion order (not sorted) so this cannot itself fail if a
            # caller registered keys of mixed, non-comparable types.
            available = ", ".join(repr(key) for key in _INTERACTION_REGISTRY)
            raise ValueError(
                f"Unknown interaction_type: {interaction_type!r}. "
                f"Available types: {available}"
            ) from None
        return builder(
            parent_surface=parent_surface,
            is_reflective=is_reflective,
            coating=coating,
            bsdf=bsdf,
            **kwargs,
        )