"""Base Field Definition Module
This module defines the abstract base class for field types in optical systems.
Kramer Harrison, 2025
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, ClassVar
if TYPE_CHECKING:
from optiland import Optic
from optiland._types import BEArray, ScalarOrArray
[docs]
class BaseFieldDefinition(ABC):
"""Abstract base class for defining how fields map to ray properties."""
_registry: ClassVar[dict[str, type[BaseFieldDefinition]]] = {}
[docs]
@classmethod
def register(cls, name: str):
"""Class decorator to register a field type by name.
Args:
name: The string key used to look up this field type.
Returns:
A decorator that registers the subclass and returns it unchanged.
"""
def decorator(subclass: type[BaseFieldDefinition]) -> type[BaseFieldDefinition]:
cls._registry[name] = subclass
return subclass
return decorator
[docs]
@classmethod
def create(cls, field_type: str) -> BaseFieldDefinition:
"""Instantiate a field definition by its registered name.
Args:
field_type: The registered name of the field type.
Returns:
A new instance of the corresponding field definition.
Raises:
ValueError: If ``field_type`` is not in the registry.
"""
if field_type not in cls._registry:
raise ValueError(f"Invalid field type: {field_type}.")
return cls._registry[field_type]()
[docs]
@abstractmethod
def get_ray_origins(
self,
optic: Optic,
Hx: ScalarOrArray,
Hy: ScalarOrArray,
Px: ScalarOrArray,
Py: ScalarOrArray,
vx: ScalarOrArray,
vy: ScalarOrArray,
) -> tuple[ScalarOrArray, ScalarOrArray, ScalarOrArray]:
"""Calculate the initial positions for rays originating at the object.
Args:
Hx: Normalized x field coordinate.
Hy: Normalized y field coordinate.
Px: x-coordinate of the pupil point.
Py: y-coordinate of the pupil point.
vx: Vignetting factor in the x-direction.
vy: Vignetting factor in the y-direction.
Returns:
A tuple containing the x, y, and z coordinates of the
object position.
"""
pass # pragma: no cover
[docs]
@abstractmethod
def get_paraxial_object_position(
self, optic: Optic, Hy: ScalarOrArray, y1: ScalarOrArray, EPL: ScalarOrArray
) -> tuple[BEArray, BEArray]:
"""Calculate the position of the object in the paraxial optical system.
Args:
Hy: The normalized field height.
y1: The initial y-coordinate of the ray.
EPL: The entrance pupil location.
Returns:
A tuple containing the y and z coordinates of the object
position.
"""
pass # pragma: no cover
[docs]
@abstractmethod
def scale_chief_ray_for_field(
self,
optic: Optic,
y_obj_unit: ScalarOrArray,
u_obj_unit: ScalarOrArray,
y_img_unit: ScalarOrArray,
) -> ScalarOrArray:
"""Calculates the scaling factor for a unit chief ray based on the field
definition.
This is used in the paraxial chief_ray calculation. It uses the results
of a forward and backward "unit" trace from the stop to determine the
final scaling factor.
Args:
optic: The optical system.
y_obj_unit: The object-space height of the unit ray.
u_obj_unit: The object-space angle of the unit ray.
y_img_unit: The image-space height of the unit ray.
Returns:
The scaling factor.
"""
pass # pragma: no cover
[docs]
def to_dict(self) -> dict:
"""Convert the field definition to a dictionary.
Returns:
dict: A dictionary representation of the field definition.
"""
return {"field_type": self.__class__.__name__}
[docs]
@classmethod
def from_dict(cls, field_def_dict: dict) -> BaseFieldDefinition:
"""Create a field definition from a dictionary.
Args:
field_def_dict (dict): A dictionary representation of the field
definition.
Returns:
BaseFieldDefinition: A field definition object created from the
dictionary.
Raises:
ValueError: If ``field_type`` is missing or not in the registry.
"""
if "field_type" not in field_def_dict:
raise ValueError("Missing required keys: field_type")
# Ensure subclasses are imported so their @register decorators run.
from optiland.fields.field_types import ( # noqa: F401
AngleField,
ObjectHeightField,
ParaxialImageHeightField,
RealImageHeightField,
)
class_name = field_def_dict["field_type"]
# Registry keys are class names (e.g. "AngleField"); look up by name.
for _key, klass in cls._registry.items():
if klass.__name__ == class_name:
return klass()
raise ValueError(f"Unknown field definition: {class_name}")