"""Distribution Module
This module provides various classes representing 2D pupil distributions.
Kramer Harrison, 2024
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
import matplotlib.pyplot as plt
import numpy as np
import optiland.backend as be
if TYPE_CHECKING:
from collections.abc import Callable
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from optiland._types import BEArray, DistributionType
[docs]
class BaseDistribution(ABC):
    """Abstract base for 2D pupil sampling distributions.

    Subclasses implement :meth:`generate_points`, which fills ``x`` and ``y``
    with normalized pupil coordinates. :meth:`view` plots whatever points are
    currently stored, together with a reference unit circle.

    Attributes:
        x: The x-coordinates of the generated points (empty on construction).
        y: The y-coordinates of the generated points (empty on construction).
    """

    def __init__(self):
        # Start with empty coordinate arrays so the attributes exist before
        # any call to generate_points.
        self.x: BEArray = be.empty(0)
        self.y: BEArray = be.empty(0)

    @abstractmethod
    def generate_points(self, num_points: int):
        """Populate ``self.x`` and ``self.y`` with sample points.

        Args:
            num_points (int): The number of points to generate.
        """
        # pragma: no cover

    def view(self) -> tuple[Figure, Axes]:
        """Plot the stored points over a reference unit circle.

        Points are drawn as black stars, the unit circle as a red line, and
        the axes are labelled in normalized pupil coordinates with equal
        aspect scaling.

        Returns:
            A ``(figure, axes)`` tuple for the newly created plot.
        """
        fig, ax = plt.subplots()
        # Backend arrays are converted to NumPy so matplotlib can plot them.
        ax.plot(be.to_numpy(self.x), be.to_numpy(self.y), "k*")

        angles = np.linspace(0, 2 * be.pi, 256)
        ax.plot(np.cos(angles), np.sin(angles), "r")

        ax.set_xlabel("Normalized Pupil Coordinate X")
        ax.set_ylabel("Normalized Pupil Coordinate Y")
        ax.axis("equal")
        return fig, ax
[docs]
class LineXDistribution(BaseDistribution):
    """A class representing a line distribution along the x-axis.

    Generates `num_points` evenly spaced points along the x-axis, spanning
    ``[0, 1]`` when ``positive_only`` is set and ``[-1, 1]`` otherwise. All
    y-coordinates are zero.

    Attributes:
        positive_only (bool): Flag indicating whether the distribution should
            be limited to positive values only.
        x: The x-coordinates of the generated points.
        y: The y-coordinates of the generated points.
    """

    def __init__(self, positive_only: bool = False):
        # Initialize the base class so ``x``/``y`` exist (as empty arrays)
        # even before generate_points is called.
        super().__init__()
        self.positive_only = positive_only

    def generate_points(self, num_points: int):
        """Generates points along the x-axis based on the specified parameters.

        Args:
            num_points (int): The number of points to generate.
        """
        start = 0 if self.positive_only else -1
        self.x = be.linspace(start, 1, num_points)
        self.y = be.zeros([num_points])
[docs]
class LineYDistribution(BaseDistribution):
    """A class representing a line distribution along the y-axis.

    Generates `num_points` evenly spaced points along the y-axis, spanning
    ``[0, 1]`` when ``positive_only`` is set and ``[-1, 1]`` otherwise. All
    x-coordinates are zero.

    Attributes:
        positive_only: Flag indicating whether the distribution should
            be positive-only.
        x: The x-coordinates of the generated points.
        y: The y-coordinates of the generated points.
    """

    def __init__(self, positive_only: bool = False):
        # Initialize the base class so ``x``/``y`` exist (as empty arrays)
        # even before generate_points is called.
        super().__init__()
        self.positive_only = positive_only

    def generate_points(self, num_points: int):
        """Generates points along the line distribution.

        Args:
            num_points (int): The number of points to generate.
        """
        self.x = be.zeros([num_points])
        start = 0 if self.positive_only else -1
        self.y = be.linspace(start, 1, num_points)
[docs]
class RandomDistribution(BaseDistribution):
    """A class representing a random distribution.

    Generates `num_points` random points within the unit disk.

    Attributes:
        rng (be.Generator): The random number generator from the backend.
        x: The x-coordinates of the generated points.
        y: The y-coordinates of the generated points.
    """

    def __init__(self, seed=None):
        # Initialize the base class so ``x``/``y`` exist (as empty arrays)
        # even before generate_points is called.
        super().__init__()
        self.rng = be.default_rng(seed)

    def generate_points(self, num_points: int):
        """Generates random points.

        Args:
            num_points (int): The number of points to generate.
        """
        # NOTE(review): assumes random_uniform with no bounds samples [0, 1).
        r = be.random_uniform(size=num_points, generator=self.rng)
        theta = be.random_uniform(0, 2 * be.pi, size=num_points, generator=self.rng)
        # Using sqrt(r) as the radius spreads points uniformly over the disk's
        # area; a uniform radius would cluster points near the center.
        self.x = be.sqrt(r) * be.cos(theta)
        self.y = be.sqrt(r) * be.sin(theta)
[docs]
class HexagonalDistribution(BaseDistribution):
    """Hexapolar pupil sampling: a center point plus concentric rings.

    Ring ``k`` (1-based) lies at radius ``k / num_rings`` and holds ``6 * k``
    equally spaced points, starting at angle 0. Including the center, the total
    number of points is `1 + 3 * num_rings * (num_rings + 1)`.

    Attributes:
        x: Array of x-coordinates of the generated points.
        y: Array of y-coordinates of the generated points.
    """

    def generate_points(self, num_rings: int = 6):
        """Fill ``x``/``y`` with the hexapolar pattern.

        Args:
            num_rings: Number of rings around the center point.
                Defaults to 6.
        """
        radii = be.linspace(0, 1, num_rings + 1)
        # Seed the pieces with the center point; each ring is appended after it.
        x_parts = [be.zeros([1])]
        y_parts = [be.zeros([1])]
        for ring in range(1, num_rings + 1):
            # 6 * ring + 1 samples over [0, 2*pi]; drop the duplicated endpoint.
            angles = be.linspace(0, 2 * be.pi, 6 * ring + 1)[:-1]
            x_parts.append(radii[ring] * be.cos(angles))
            y_parts.append(radii[ring] * be.sin(angles))
        self.x = be.concatenate(x_parts)
        self.y = be.concatenate(y_parts)
[docs]
class CrossDistribution(BaseDistribution):
    """Cross-shaped distribution made of one vertical and one horizontal arm.

    Each arm holds `num_points` evenly spaced samples over ``[-1, 1]``. The
    vertical arm's points come first in ``x``/``y``, followed by the horizontal
    arm's.

    If `num_points` is odd, it generates `2 * num_points - 1` points, because
    the horizontal arm's middle sample is dropped as a duplicate of the origin.
    If `num_points` is even and positive, it generates `2 * num_points` points.
    If `num_points` is 0, 0 points are generated.

    Attributes:
        x: Array of x-coordinates of the generated points.
        y: Array of y-coordinates of the generated points.
    """

    def generate_points(self, num_points: int):
        """Fill ``x``/``y`` with the two arms of the cross.

        Args:
            num_points: The number of points along each arm before the
                origin de-duplication.
        """
        # Vertical arm: x is zero, y sweeps [-1, 1].
        vert_x = be.zeros([num_points])
        vert_y = be.linspace(-1, 1, num_points)
        # Horizontal arm: x sweeps [-1, 1], y is zero.
        horiz_x = be.linspace(-1, 1, num_points)
        horiz_y = be.zeros([num_points])

        if num_points % 2 == 1:
            # For odd num_points >= 3 the middle sample is the origin, which
            # the vertical arm already contains, so drop it here.
            # NOTE(review): for num_points == 1, linspace(-1, 1, 1) is [-1],
            # so the dropped point is (-1, 0), not the origin, and the result
            # is the single point (0, -1). Confirm this is intended.
            centre = num_points // 2
            horiz_x = be.concatenate((horiz_x[:centre], horiz_x[centre + 1 :]))
            horiz_y = be.concatenate((horiz_y[:centre], horiz_y[centre + 1 :]))

        self.x = be.concatenate((vert_x, horiz_x))
        self.y = be.concatenate((vert_y, horiz_y))
[docs]
class GaussianQuadrature(BaseDistribution):
    """GaussianQuadrature class for generating points and weights for Gaussian
    quadrature distribution.

    Generates points for Gaussian quadrature. If `is_symmetric` is true,
    `num_rings` points are generated along the +x direction. If `is_symmetric`
    is false, `3 * num_rings` points are generated on three spokes at
    -60, 0 and +60 degrees, ordered ring by ring.

    Attributes:
        is_symmetric: Indicates whether the distribution is symmetric about y.
            Defaults to False.
        x: Array of x-coordinates of the generated points.
        y: Array of y-coordinates of the generated points.

    Reference:
        G. W. Forbes, "Optical system assessment for design: numerical ray
        tracing in the Gaussian pupil," J. Opt. Soc. Am. A 5, 1943-1956 (1988)
    """

    def __init__(self, is_symmetric: bool = False):
        # Initialize the base class so ``x``/``y`` exist (as empty arrays)
        # even before generate_points is called.
        super().__init__()
        self.is_symmetric = is_symmetric

    def generate_points(self, num_rings: int):
        """Generate points for Gaussian quadrature distribution.

        Args:
            num_rings: Number of rings for Gaussian quadrature (1 to 6).

        Raises:
            ValueError: If the number of rings is not between 1 and 6.
        """
        radius = self._get_radius(num_rings)
        if self.is_symmetric:
            theta = be.array([0.0])
        else:
            # +/-1.04719755 rad is +/-pi/3 (60 degrees).
            theta = be.array([-1.04719755, 0.0, 1.04719755])
        # Outer product gives a (num_rings, len(theta)) grid; flattening
        # yields the points ring by ring.
        self.x = be.outer(radius, be.cos(theta)).flatten()
        self.y = be.outer(radius, be.sin(theta)).flatten()

    def _get_radius(self, num_rings: int) -> BEArray:
        """Get the radius values for the given number of rings.

        Args:
            num_rings: Number of rings for Gaussian quadrature.

        Returns:
            Radius values for the given number of rings.

        Raises:
            ValueError: If the number of rings is not between 1 and 6.
        """
        # Tables are built per call (not at class level) so the arrays are
        # created with whichever backend is active at call time.
        radius_dict = {
            1: be.array([0.70711]),
            2: be.array([0.45970, 0.88807]),
            3: be.array([0.33571, 0.70711, 0.94196]),
            4: be.array([0.26350, 0.57446, 0.81853, 0.96466]),
            5: be.array([0.21659, 0.48038, 0.70711, 0.87706, 0.97626]),
            6: be.array([0.18375, 0.41158, 0.61700, 0.78696, 0.91138, 0.98300]),
        }
        if num_rings not in radius_dict:
            raise ValueError("Gaussian quadrature must have between 1 and 6 rings.")
        return radius_dict[num_rings]

    def get_weights(self, num_rings: int) -> BEArray:
        """Get weights for Gaussian quadrature distribution.

        One weight is returned per ring (length ``num_rings``) regardless of
        ``is_symmetric``. The tabulated weights are scaled by 6.0 when
        symmetric and by 2.0 otherwise.

        Args:
            num_rings: Number of rings for Gaussian quadrature.

        Returns:
            Array of weights.

        Raises:
            ValueError: If the number of rings is not between 1 and 6.
        """
        weights_dict = {
            1: be.array([0.5]),
            2: be.array([0.25, 0.25]),
            3: be.array([0.13889, 0.22222, 0.13889]),
            4: be.array([0.08696, 0.16304, 0.16304, 0.08696]),
            5: be.array([0.059231, 0.11966, 0.14222, 0.11966, 0.059231]),
            6: be.array([0.04283, 0.09019, 0.11698, 0.11698, 0.09019, 0.04283]),
        }
        if num_rings not in weights_dict:
            raise ValueError("Gaussian quadrature must have between 1 and 6 rings.")
        weights = weights_dict[num_rings]
        # The 6:2 ratio matches one spoke (symmetric) vs three spokes per ring.
        # NOTE(review): presumably each weight applies to every point of its
        # ring; confirm against callers.
        weights = weights * 6.0 if self.is_symmetric else weights * 2.0
        return weights
[docs]
class RingDistribution(BaseDistribution):
    """Points spaced evenly around the pupil edge.

    Generates `num_points` along a single ring at the maximum aperture value
    (radius 1), starting at angle 0.
    """

    def generate_points(self, num_points: int):
        """Fill ``x``/``y`` with equally spaced points on the unit circle.

        Args:
            num_points (int): The number of points to place on the ring.
        """
        # Sample num_points + 1 angles over [0, 2*pi] and drop the last one,
        # which coincides with the first.
        angles = be.linspace(0, 2 * be.pi, num_points + 1)
        angles = angles[:-1]
        self.x, self.y = be.cos(angles), be.sin(angles)
[docs]
class SobolDistribution(BaseDistribution):
    """Quasi-random pupil sampling from a Sobol sequence.

    Generates `num_points` points using a scrambled Sobol low-discrepancy
    sequence, mapped onto the unit disk.

    Attributes:
        seed (int | None): Seed for the Sobol sequence generator.
        x: The x-coordinates of the generated points.
        y: The y-coordinates of the generated points.
    """

    def __init__(self, seed: int | None = None):
        super().__init__()
        self.seed = seed

    def generate_points(self, num_points: int):
        """Fill ``x``/``y`` with Sobol samples mapped to the unit disk.

        Args:
            num_points (int): The number of points to generate.
        """
        samples = be.sobol_sampler(
            dim=2, num_samples=num_points, scramble=True, seed=self.seed
        )
        # Column 0 drives the radius (square-rooted for uniform area density),
        # column 1 drives the polar angle over a full turn.
        radius = be.sqrt(samples[:, 0])
        angle = 2 * be.pi * samples[:, 1]
        self.x = radius * be.cos(angle)
        self.y = radius * be.sin(angle)
[docs]
def create_distribution(distribution_type: DistributionType) -> BaseDistribution:
    """Instantiate the pupil distribution registered under a given key.

    Args:
        distribution_type: Key naming the distribution, e.g. ``"hexapolar"``
            or ``"positive_line_x"``.

    Returns:
        A freshly constructed distribution. Each type is built with its
        default arguments, except the ``positive_line_*`` variants, which set
        ``positive_only=True``.

    Raises:
        ValueError: If ``distribution_type`` is not a recognized key.
    """
    # NOTE(review): UniformDistribution is not defined in the visible part of
    # this module; confirm it is defined or imported elsewhere.
    registry: dict[
        DistributionType, type[BaseDistribution] | Callable[[], BaseDistribution]
    ] = {
        "line_x": LineXDistribution,
        "line_y": LineYDistribution,
        "positive_line_x": lambda: LineXDistribution(positive_only=True),
        "positive_line_y": lambda: LineYDistribution(positive_only=True),
        "random": RandomDistribution,
        "uniform": UniformDistribution,
        "hexapolar": HexagonalDistribution,
        "cross": CrossDistribution,
        "ring": RingDistribution,
        "sobol": SobolDistribution,
    }
    factory = registry.get(distribution_type)
    if factory is None:
        raise ValueError("Invalid distribution type.")
    return factory()