Source code for distribution

"""Distribution Module

This module provides various classes representing 2D pupil distributions.

Kramer Harrison, 2024
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np

import optiland.backend as be

if TYPE_CHECKING:
    from collections.abc import Callable

    from matplotlib.axes import Axes
    from matplotlib.figure import Figure

    from optiland._types import BEArray, DistributionType


class BaseDistribution(ABC):
    """Abstract parent for 2D pupil point distributions.

    Concrete subclasses populate ``x`` and ``y`` by implementing
    ``generate_points``; ``view`` plots whatever is currently stored.

    Attributes:
        x: The x-coordinates of the generated points.
        y: The y-coordinates of the generated points.
    """

    def __init__(self):
        # Zero-length placeholders so the attributes exist before any
        # points have been generated.
        self.x: BEArray = be.empty(0)
        self.y: BEArray = be.empty(0)

    @abstractmethod
    def generate_points(self, num_points: int):
        """Populate ``x`` and ``y`` according to the distribution.

        Args:
            num_points (int): The number of points to generate.
        """
        # pragma: no cover

    def view(self) -> tuple[Figure, Axes]:
        """Plot the stored points together with a reference unit circle.

        The points are drawn as black stars and the unit circle as a red
        line, on axes with equal scaling.

        Returns:
            A tuple containing the figure and axes of the plot.
        """
        figure, axes = plt.subplots()

        # Stored points, converted from backend arrays for matplotlib.
        axes.plot(be.to_numpy(self.x), be.to_numpy(self.y), "k*")

        # Reference unit circle sampled at 256 angles.
        angles = np.linspace(0, 2 * be.pi, 256)
        circle_x, circle_y = np.cos(angles), np.sin(angles)
        axes.plot(circle_x, circle_y, "r")

        axes.set_xlabel("Normalized Pupil Coordinate X")
        axes.set_ylabel("Normalized Pupil Coordinate Y")
        axes.axis("equal")
        return figure, axes
class LineXDistribution(BaseDistribution):
    """A class representing a line distribution along the x-axis.

    Generates `num_points` along the x-axis.

    Attributes:
        positive_only (bool): Flag indicating whether the distribution should
            be limited to positive values only.
        x: The x-coordinates of the generated points.
        y: The y-coordinates of the generated points.
    """

    def __init__(self, positive_only: bool = False):
        # Let the base class create the x/y attributes so the instance is
        # usable (e.g. by view()) before generate_points() has been called.
        super().__init__()
        self.positive_only = positive_only

    def generate_points(self, num_points: int):
        """Generates points along the x-axis based on the specified parameters.

        The x-coordinates span [0, 1] when `positive_only` is set and
        [-1, 1] otherwise; all y-coordinates are zero.

        Args:
            num_points (int): The number of points to generate.
        """
        start = 0 if self.positive_only else -1
        self.x = be.linspace(start, 1, num_points)
        self.y = be.zeros([num_points])
class LineYDistribution(BaseDistribution):
    """A class representing a line distribution along the y-axis.

    Generates `num_points` along the y-axis.

    Attributes:
        positive_only: Flag indicating whether the distribution should be
            positive-only.
        x: The x-coordinates of the generated points.
        y: The y-coordinates of the generated points.
    """

    def __init__(self, positive_only: bool = False):
        # Let the base class create the x/y attributes so the instance is
        # usable (e.g. by view()) before generate_points() has been called.
        super().__init__()
        self.positive_only = positive_only

    def generate_points(self, num_points: int):
        """Generates points along the line distribution.

        The y-coordinates span [0, 1] when `positive_only` is set and
        [-1, 1] otherwise; all x-coordinates are zero.

        Args:
            num_points (int): The number of points to generate.
        """
        start = 0 if self.positive_only else -1
        self.x = be.zeros([num_points])
        self.y = be.linspace(start, 1, num_points)
class RandomDistribution(BaseDistribution):
    """A class representing a random distribution.

    Generates `num_points` random points within the unit disk.

    Attributes:
        rng (be.Generator): The random number generator from the backend.
        x: The x-coordinates of the generated points.
        y: The y-coordinates of the generated points.
    """

    def __init__(self, seed=None):
        # Let the base class create the x/y attributes so the instance is
        # usable (e.g. by view()) before generate_points() has been called.
        super().__init__()
        self.rng = be.default_rng(seed)

    def generate_points(self, num_points: int):
        """Generates random points.

        A uniform sample `r` and a uniform angle in [0, 2*pi) are drawn (in
        that order) from `self.rng`; the point radius is `sqrt(r)`.

        Args:
            num_points (int): The number of points to generate.
        """
        r = be.random_uniform(size=num_points, generator=self.rng)
        theta = be.random_uniform(0, 2 * be.pi, size=num_points, generator=self.rng)

        # Square-root the uniform sample once and reuse it for both axes.
        radius = be.sqrt(r)
        self.x = radius * be.cos(theta)
        self.y = radius * be.sin(theta)
class UniformDistribution(BaseDistribution):
    """Square grid of points clipped to the unit disk.

    A `num_points` x `num_points` grid covering [-1, 1] on both axes is built
    and only the nodes with `x**2 + y**2 <= 1` are kept. The resulting number
    of points is approximately `num_points^2 * pi / 4`.

    Attributes:
        x: The x-coordinates of the generated points.
        y: The y-coordinates of the generated points.
    """

    def generate_points(self, num_points: int):
        """Generates a grid of points within the unit disk.

        Args:
            num_points (int): The number of points along each axis to
                generate.
        """
        axis = be.linspace(-1, 1, num_points)
        grid_x, grid_y = be.meshgrid(axis, axis)

        # Boolean mask of the grid nodes lying on or inside the unit circle.
        inside = grid_x**2 + grid_y**2 <= 1
        self.x = grid_x[inside]
        self.y = grid_y[inside]
class HexagonalDistribution(BaseDistribution):
    """Concentric-ring (hexagonal) distribution of points.

    A centre point is followed by `num_rings` rings; ring `k` (1-based) holds
    `6 * k` equally spaced points. The total number of points is
    `1 + 3 * num_rings * (num_rings + 1)`, including the center point.

    Attributes:
        x: Array of x-coordinates of the generated points.
        y: Array of y-coordinates of the generated points.
    """

    def generate_points(self, num_rings: int = 6):
        """Generate points in a hexagonal distribution.

        Args:
            num_rings: Number of rings in the hexagonal distribution.
                Defaults to 6.
        """
        # Ring radii, equally spaced from 0 (centre) to 1 (outer ring).
        radii = be.linspace(0, 1, num_rings + 1)

        # Each list starts with the centre point; rings are appended in order.
        x_parts = [be.zeros([1])]
        y_parts = [be.zeros([1])]
        for ring in range(1, num_rings + 1):
            # Drop the last sample so the angle 2*pi does not duplicate 0.
            angles = be.linspace(0, 2 * be.pi, 6 * ring + 1)[:-1]
            x_parts.append(radii[ring] * be.cos(angles))
            y_parts.append(radii[ring] * be.sin(angles))

        self.x = be.concatenate(x_parts)
        self.y = be.concatenate(y_parts)
class CrossDistribution(BaseDistribution):
    """A class representing a cross-shaped distribution.

    Points are placed along the y-axis and the x-axis, each sampled with
    `num_points` values over [-1, 1]. If `num_points` is odd, both arms
    contain the origin, which is kept only once, giving `2 * num_points - 1`
    points. If `num_points` is even and positive, `2 * num_points` points are
    generated. If `num_points` is 0, 0 points are generated.

    Attributes:
        x: Array of x-coordinates of the generated points.
        y: Array of y-coordinates of the generated points.
    """

    def generate_points(self, num_points: int):
        """Generate points in the shape of a cross.

        Args:
            num_points: The number of points to generate in each axis.
        """
        span = be.linspace(-1, 1, num_points)
        zeros = be.zeros([num_points])

        # Horizontal arm: x runs over the span, y is zero.
        horizontal_x, horizontal_y = span, zeros

        # With an odd count the span's midpoint is 0, so (0, 0) would appear
        # in both arms; drop it from the horizontal arm.
        if num_points % 2 == 1:
            centre = num_points // 2
            horizontal_x = be.concatenate((span[:centre], span[centre + 1 :]))
            horizontal_y = be.concatenate((zeros[:centre], zeros[centre + 1 :]))

        # Vertical arm first (x = 0, y = span), then the horizontal arm.
        self.x = be.concatenate((zeros, horizontal_x))
        self.y = be.concatenate((span, horizontal_y))
class GaussianQuadrature(BaseDistribution):
    """GaussianQuadrature class for generating points and weights for Gaussian
    quadrature distribution.

    Generates points for Gaussian quadrature. If `is_symmetric` is true,
    `num_rings` points are generated. If `is_symmetric` is false,
    `3 * num_rings` points are generated.

    Attributes:
        is_symmetric: Indicates whether the distribution is symmetric about y.
            Defaults to False.
        x: The x-coordinates of the generated points.
        y: The y-coordinates of the generated points.

    Reference:
        G. W. Forbes, "Optical system assessment for design: numerical ray
        tracing in the Gaussian pupil," J. Opt. Soc. Am. A 5, 1943-1956 (1988)
    """

    # Tabulated ring radii, keyed by number of rings. Kept as plain floats and
    # converted with the active backend on lookup, so only the requested row
    # is turned into an array.
    _RADII = {
        1: [0.70711],
        2: [0.45970, 0.88807],
        3: [0.33571, 0.70711, 0.94196],
        4: [0.26350, 0.57446, 0.81853, 0.96466],
        5: [0.21659, 0.48038, 0.70711, 0.87706, 0.97626],
        6: [0.18375, 0.41158, 0.61700, 0.78696, 0.91138, 0.98300],
    }

    # Tabulated per-ring weights (before scaling), keyed by number of rings.
    _WEIGHTS = {
        1: [0.5],
        2: [0.25, 0.25],
        3: [0.13889, 0.22222, 0.13889],
        4: [0.08696, 0.16304, 0.16304, 0.08696],
        5: [0.059231, 0.11966, 0.14222, 0.11966, 0.059231],
        6: [0.04283, 0.09019, 0.11698, 0.11698, 0.09019, 0.04283],
    }

    def __init__(self, is_symmetric=False):
        # Let the base class create the x/y attributes so the instance is
        # usable (e.g. by view()) before generate_points() has been called.
        super().__init__()
        self.is_symmetric = is_symmetric

    def generate_points(self, num_rings: int):
        """Generate points for Gaussian quadrature distribution.

        Args:
            num_rings: Number of rings for Gaussian quadrature.

        Raises:
            ValueError: If the number of rings is not between 1 and 6.
        """
        radius = self._get_radius(num_rings)

        if self.is_symmetric:
            # One point per ring, at angle zero.
            theta = be.array([0.0])
        else:
            # Three points per ring; 1.04719755 rad is approximately pi / 3.
            theta = be.array([-1.04719755, 0.0, 1.04719755])

        # Outer product gives every (radius, angle) pair, ring by ring.
        self.x = be.outer(radius, be.cos(theta)).flatten()
        self.y = be.outer(radius, be.sin(theta)).flatten()

    def _get_radius(self, num_rings: int) -> BEArray:
        """Get the radius values for the given number of rings.

        Args:
            num_rings: Number of rings for Gaussian quadrature.

        Returns:
            Radius values for the given number of rings.

        Raises:
            ValueError: If the number of rings is not between 1 and 6.
        """
        if num_rings not in self._RADII:
            raise ValueError("Gaussian quadrature must have between 1 and 6 rings.")
        return be.array(self._RADII[num_rings])

    def get_weights(self, num_rings: int) -> BEArray:
        """Get weights for Gaussian quadrature distribution.

        The tabulated per-ring weights are multiplied by 6 when
        `is_symmetric` is set and by 2 otherwise.

        Args:
            num_rings: Number of rings for Gaussian quadrature.

        Returns:
            Array of weights, one per ring.

        Raises:
            ValueError: If the number of rings is not between 1 and 6.
        """
        if num_rings not in self._WEIGHTS:
            raise ValueError("Gaussian quadrature must have between 1 and 6 rings.")

        scale = 6.0 if self.is_symmetric else 2.0
        return be.array(self._WEIGHTS[num_rings]) * scale
class RingDistribution(BaseDistribution):
    """Points placed on a single ring.

    Generates `num_points` along a single ring at the maximum aperture value
    (radius 1).
    """

    def generate_points(self, num_points: int):
        """Generate points along a ring at the maximum aperture value.

        Args:
            num_points (int): The number of points to generate in each ring.
        """
        full_turn = 2 * be.pi

        # Sample one extra angle and discard it so 2*pi does not repeat 0.
        angles = be.linspace(0, full_turn, num_points + 1)[:-1]

        self.x = be.cos(angles)
        self.y = be.sin(angles)
class SobolDistribution(BaseDistribution):
    """Low-discrepancy (Sobol) distribution of points.

    Generates `num_points` points using a Sobol low-discrepancy sequence
    within the unit disk.

    Attributes:
        seed (int | None): Seed for the Sobol sequence generator.
        x: The x-coordinates of the generated points.
        y: The y-coordinates of the generated points.
    """

    def __init__(self, seed: int | None = None):
        super().__init__()
        self.seed = seed

    def generate_points(self, num_points: int):
        """Generates Sobol points.

        A scrambled 2D Sobol sample is requested from the backend; its first
        column sets the radius (via a square root) and its second column the
        angle.

        Args:
            num_points (int): The number of points to generate.
        """
        samples = be.sobol_sampler(
            dim=2,
            num_samples=num_points,
            scramble=True,
            seed=self.seed,
        )
        radial_sample, angular_sample = samples[:, 0], samples[:, 1]

        radius = be.sqrt(radial_sample)
        angle = 2 * be.pi * angular_sample

        self.x = radius * be.cos(angle)
        self.y = radius * be.sin(angle)
def create_distribution(distribution_type: DistributionType) -> BaseDistribution:
    """Build a distribution instance from its string identifier.

    Args:
        distribution_type: The type of distribution to create.

    Returns:
        An instance of the specified distribution type.

    Raises:
        ValueError: If an invalid distribution type is provided.
    """
    # Maps each identifier to a zero-argument factory: either the class
    # itself or a lambda that supplies constructor arguments.
    factories: dict[
        DistributionType, type[BaseDistribution] | Callable[[], BaseDistribution]
    ] = {
        "line_x": LineXDistribution,
        "line_y": LineYDistribution,
        "positive_line_x": lambda: LineXDistribution(positive_only=True),
        "positive_line_y": lambda: LineYDistribution(positive_only=True),
        "random": RandomDistribution,
        "uniform": UniformDistribution,
        "hexapolar": HexagonalDistribution,
        "cross": CrossDistribution,
        "ring": RingDistribution,
        "sobol": SobolDistribution,
    }

    factory = factories.get(distribution_type)
    if factory is None:
        raise ValueError("Invalid distribution type.")
    return factory()