"""Wavelength Module
This module defines the `Wavelength` and `WavelengthGroup` classes for
managing wavelengths in optical simulations. The `Wavelength` class represents
a single wavelength, allowing for its value to be defined in various units and
converted to microns for internal consistency. The `WavelengthGroup` class
manages collections of `Wavelength` objects, providing functionality to work
with multiple wavelengths simultaneously.
Kramer Harrison, 2024
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import optiland.backend as be
if TYPE_CHECKING:
from optiland._types import ScalarOrArray
[docs]
class Wavelength:
    """Represents a wavelength value with support for unit conversion.

    The value is stored exactly as given together with its unit, and a copy
    converted to micrometers (um) is cached. The public ``value`` and ``unit``
    properties always report that micrometer representation, while
    ``to_dict`` serializes the original value and unit.

    Methods:
        _convert_to_um(): Converts the wavelength value to microns.
    """

    # Multiplicative factors converting a value in each supported unit to um.
    _UM_PER_UNIT = {"nm": 0.001, "um": 1, "mm": 1000, "cm": 10000, "m": 1000000}

    def __init__(
        self,
        value: ScalarOrArray,
        is_primary: bool = True,
        unit: str = "um",
        weight: float = 1.0,
    ):
        """Initializes a Wavelength instance.

        Args:
            value (ScalarOrArray): The value of the wavelength, expressed in
                ``unit``.
            is_primary (bool): Indicates whether the wavelength is a primary
                wavelength. Defaults to True.
            unit (str): The unit of the wavelength value (case-insensitive).
                Supported units: 'nm', 'um', 'mm', 'cm', 'm'. Defaults to 'um'.
            weight (float): Non-negative relative importance scalar. A weight
                of 0.0 excludes the wavelength from optimization and weighted
                analysis. Defaults to 1.0.

        Raises:
            ValueError: If ``unit`` is unsupported or ``weight`` is negative
                or NaN.
        """
        self._value = value
        self.is_primary = is_primary
        self._unit = unit.lower()
        self._value_in_um = self._convert_to_um()
        self.weight = weight  # uses the validated setter

    @property
    def weight(self) -> float:
        """float: Non-negative relative importance scalar for this wavelength."""
        return self._weight

    @weight.setter
    def weight(self, value: float) -> None:
        weight = float(value)
        # `not weight >= 0` also rejects NaN, which a plain `weight < 0` would
        # silently accept and then propagate into every weighted computation.
        if not weight >= 0:
            raise ValueError(f"Wavelength weight must be non-negative, got {value}.")
        self._weight = weight

    @property
    def value(self) -> float:
        """float: the value of the wavelength, converted to micrometers"""
        return self._value_in_um

    @property
    def unit(self) -> str:
        """str: the unit of the wavelength (always 'um', matching ``value``)"""
        return "um"

    @unit.setter
    def unit(self, new_unit: str):
        """Sets the unit in which the stored raw value is expressed.

        The raw value is reinterpreted in the new unit and the cached
        micrometer value is recomputed.

        Args:
            new_unit: The new unit to set for the wavelength.

        Raises:
            ValueError: If ``new_unit`` is unsupported. The instance is left
                unchanged in that case.
        """
        candidate = new_unit.lower()
        # Validate before mutating so a bad unit cannot leave `_unit` and
        # `_value_in_um` out of sync (which would also break to_dict/from_dict).
        if candidate not in self._UM_PER_UNIT:
            raise ValueError("Unsupported unit for conversion to microns.")
        self._unit = candidate
        self._value_in_um = self._convert_to_um()

    def _convert_to_um(self) -> float:
        """Converts the wavelength value to micrometers (um) based on the
        current unit.

        Returns:
            float: The converted wavelength value in micrometers.

        Raises:
            ValueError: If the current unit is not supported for conversion to
                micrometers. Supported units: 'nm', 'um', 'mm', 'cm', 'm'.
        """
        if self._unit not in self._UM_PER_UNIT:
            raise ValueError("Unsupported unit for conversion to microns.")
        return self._value * self._UM_PER_UNIT[self._unit]

    def to_dict(self) -> dict:
        """Get a dictionary representation of the wavelength.

        Returns:
            A dictionary with the raw ``value`` in its original ``unit``,
            ``is_primary`` and ``weight``.
        """
        return {
            "value": self._value,
            "is_primary": self.is_primary,
            "unit": self._unit,
            "weight": self.weight,
        }

    @classmethod
    def from_dict(cls, data: dict) -> Wavelength:
        """Create a Wavelength instance from a dictionary representation.

        Args:
            data: A dictionary containing the wavelength data. Keys 'value',
                'is_primary' and 'unit' are required; 'weight' is optional
                and defaults to 1.0.

        Returns:
            A new Wavelength instance created from the data.

        Raises:
            ValueError: If a required key is missing.
        """
        required_keys = {"value", "is_primary", "unit"}
        if not required_keys.issubset(data):
            missing = required_keys - data.keys()
            raise ValueError(f"Missing required keys: {missing}")
        return cls(
            value=data["value"],
            is_primary=data["is_primary"],
            unit=data["unit"],
            weight=data.get("weight", 1.0),
        )
[docs]
class WavelengthGroup:
    """Represents an ordered group of wavelengths, each with a weight.

    A non-empty group is meant to contain exactly one primary wavelength. The
    first wavelength added becomes primary automatically, ``add`` with
    ``is_primary=True`` demotes every other wavelength, and ``remove``
    promotes the first remaining wavelength if no primary is left.

    Attributes:
        wavelengths (list[Wavelength]): The Wavelength objects, in insertion
            order.

    Properties:
        num_wavelengths: The number of wavelengths in the group.
        primary_index: The index of the primary wavelength (settable).
        primary_wavelength: The primary Wavelength object.
        weights: Tuple of the per-wavelength weights.

    Methods:
        add(value, is_primary=False, unit='um', weight=1.0): Adds a new
            wavelength to the group.
        remove(index): Removes the wavelength at ``index``.
        get_wavelength(wavelength_number): Returns the value (in microns) of
            a specific wavelength.
        get_wavelengths(): Returns a list of all the wavelength values (in
            microns) in the group.
        to_dict() / from_dict(data): Serialize / deserialize the group.
    """

    def __init__(self):
        self.wavelengths: list[Wavelength] = []

    @property
    def weights(self) -> tuple[float, ...]:
        """tuple[float, ...]: The weight of each wavelength, in group order."""
        return tuple(wave.weight for wave in self.wavelengths)

    @property
    def num_wavelengths(self) -> int:
        """int: The number of wavelengths in the group."""
        return len(self.wavelengths)

    def __getitem__(self, index):
        return self.wavelengths[index]

    def __iter__(self):
        return iter(self.wavelengths)

    def __len__(self):
        return len(self.wavelengths)

    @property
    def primary_index(self) -> int:
        """int: The index of the first wavelength marked as primary.

        Raises:
            StopIteration: If no wavelength is marked as primary (for example
                when the group is empty).
        """
        return next(i for i, w in enumerate(self.wavelengths) if w.is_primary)

    @primary_index.setter
    def primary_index(self, index: int):
        # Mark the wavelength at `index` as the only primary one.
        if not 0 <= index < len(self.wavelengths):
            raise ValueError("Index out of range")
        for idx, wavelength in enumerate(self.wavelengths):
            wavelength.is_primary = idx == index

    @property
    def primary_wavelength(self) -> Wavelength:
        """Wavelength: The primary wavelength (see ``primary_index``)."""
        return self.wavelengths[self.primary_index]

    def add(
        self,
        value: float,
        is_primary: bool = False,
        unit: str = "um",
        weight: float = 1.0,
    ):
        """Adds a new wavelength to the end of the group.

        Args:
            value: The value of the wavelength, expressed in ``unit``.
            is_primary: Indicates if the wavelength is primary. Default is
                False. Setting it demotes all existing wavelengths. The first
                wavelength added to an empty group is always primary.
            unit: The unit of the wavelength. Default is 'um'.
            weight: The weight of the wavelength. Default is 1.0.
        """
        if is_primary:
            for wavelength in self.wavelengths:
                wavelength.is_primary = False
        # An empty group must still end up with a primary wavelength.
        if self.num_wavelengths == 0:
            is_primary = True
        self.wavelengths.append(Wavelength(value, is_primary, unit, weight))

    def remove(self, index: int) -> None:
        """Remove a wavelength from the group.

        If no remaining wavelength is marked primary afterwards (i.e. the
        removed one was the primary), the first remaining wavelength is
        promoted. This keeps ``primary_index`` and ``primary_wavelength``
        usable on a non-empty group.

        Args:
            index: The index of the wavelength to remove.

        Raises:
            IndexError: If ``index`` is out of range.
        """
        self.wavelengths.pop(index)
        if self.wavelengths and not any(w.is_primary for w in self.wavelengths):
            # Mirror `add`: the first wavelength is primary when no other is.
            self.wavelengths[0].is_primary = True

    def get_wavelength(self, wavelength_number: int) -> float:
        """Get the value of a specific wavelength.

        Args:
            wavelength_number: The index of the desired wavelength.

        Returns:
            The value of the specified wavelength, in micrometers.
        """
        return self.wavelengths[wavelength_number].value

    def get_wavelengths(self) -> list[float]:
        """Returns a list of wavelength values.

        Returns:
            A list of wavelength values in micrometers, in group order.
        """
        return [wave.value for wave in self.wavelengths]

    def to_dict(self) -> dict:
        """Get a dictionary representation of the wavelength group.

        Returns:
            A dictionary with a single "wavelengths" key holding each
            wavelength's ``to_dict()`` output, in group order.
        """
        return {"wavelengths": [wave.to_dict() for wave in self.wavelengths]}

    @classmethod
    def from_dict(cls, data) -> WavelengthGroup:
        """Create a WavelengthGroup instance from a dictionary representation.

        Each entry of ``data["wavelengths"]`` is passed as keyword arguments to
        ``add``, so its keys must be a subset of 'value', 'is_primary', 'unit'
        and 'weight'.

        Args:
            data: A dictionary containing the wavelength group data.

        Returns:
            A new WavelengthGroup instance created from the data.

        Raises:
            ValueError: If the "wavelengths" key is missing.
        """
        if "wavelengths" not in data:
            raise ValueError('Missing required key: "wavelengths"')
        new_group = cls()
        for wave_data in data["wavelengths"]:
            new_group.add(**wave_data)
        return new_group
[docs]
def add_wavelengths(
    wavelength_group: WavelengthGroup,
    min_value: float,
    max_value: float,
    num_wavelengths: int,
    unit: str = "um",
    *,
    sampling: str = "chebyshev",
    scale: str = "log",
):
    """Append sampled wavelengths spanning ``[min_value, max_value]`` to a group.

    The nodes are first generated on the unit interval, strictly inside
    (0, 1), according to ``sampling``. They are then mapped onto the
    wavelength range in the space selected by ``scale``. Every added
    wavelength has weight 1.0. The middle sample is marked primary only if the
    group does not already contain a primary wavelength, so an existing
    primary choice is kept.

    Args:
        wavelength_group: The group to which the new wavelengths are appended.
        min_value: Minimum wavelength value (in ``unit``). Must be positive.
        max_value: Maximum wavelength value (in ``unit``). Must be positive.
        num_wavelengths: The number of wavelengths to be added.
            Has to be an odd positive integer.
        unit: The unit of the wavelength. Default is 'um'.
        sampling: The sampling algorithm used (case-insensitive). Defaults to
            'chebyshev'. Currently supported options are:
            'chebyshev' - Chebyshev nodes of the first kind
            'uniform' - uniformly spaced nodes across the specified range
        scale: Space in which the nodes are sampled (case-insensitive).
            Defaults to 'log'. Currently supported options are:
            'log' / 'logarithmic' - nodes sampled in the logarithm of wavelength.
            'frequency' / 'freq' - nodes sampled in the frequency domain.
            'wavelength' - nodes sampled in the wavelength domain. Not
            recommended.

    Raises:
        ValueError: If ``num_wavelengths`` is not an odd positive integer, if
            either bound is not positive, or if ``scale`` or ``sampling`` is
            unknown.
    """
    if (
        not isinstance(num_wavelengths, int)
        or num_wavelengths % 2 == 0
        or num_wavelengths <= 0
    ):
        raise ValueError("num_wavelengths must be an odd positive integer")
    if min_value <= 0 or max_value <= 0:
        raise ValueError("min_value and max_value must be positive")

    scale = scale.lower()
    if scale in {"freq", "frequency"}:
        scale = "frequency"
    elif scale == "wavelength":
        scale = "wavelength"
    elif scale in {"log", "logarithmic"}:
        scale = "log"
    else:
        raise ValueError(f"Unknown scale: {scale!r}")

    # Reject unknown sampling up front: otherwise the raw indices 1..n would be
    # used as nodes and produce wavelengths far outside [min_value, max_value].
    sampling = sampling.lower()
    if sampling not in {"chebyshev", "uniform"}:
        raise ValueError(f"Unknown sampling: {sampling!r}")

    nodes = be.arange(1.0, num_wavelengths + 1.0)
    if sampling == "chebyshev":
        # Chebyshev nodes of the first kind, mapped from [-1, 1] onto (0, 1)
        # in ascending order.
        nodes = 0.5 * (1.0 - be.cos((2 * nodes - 1) * be.pi / (2 * num_wavelengths)))
    else:
        # Cell midpoints (k - 0.5) / n for k = 1..n.
        nodes = (nodes - 0.5) / num_wavelengths

    if scale == "log":
        span = be.log2(max_value / min_value)
        values = [min_value * 2 ** (span * node) for node in nodes]
    else:
        # Interpolate linearly in lambda**power (1/lambda for frequency), then
        # map back to wavelength.
        power = -1.0 if scale == "frequency" else 1.0
        lower = min_value**power
        upper = max_value**power
        span = upper - lower
        values = [(lower + span * node) ** (1.0 / power) for node in nodes]

    # Avoid creating a second primary wavelength in a group that already has
    # one (the first primary is the one `primary_index` reports).
    has_primary = any(w.is_primary for w in wavelength_group.wavelengths)
    middle = num_wavelengths // 2
    for i, value in enumerate(values):
        is_primary = not has_primary and i == middle
        wavelength_group.wavelengths.append(Wavelength(value, is_primary, unit, 1.0))