# Source code for optic.extended_source_optic

"""Extended Source Optic Module

This module defines the ExtendedSourceOptic class, a wrapper around the core
Optic class that enables extended source ray tracing. This design keeps the
core Optic class unchanged while providing source-based ray tracing and
visualization functionality.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal

import optiland.backend as be

if TYPE_CHECKING:
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure

    from optiland.optic.optic import Optic
    from optiland.rays import RealRays
    from optiland.sources.base import BaseSource


class ExtendedSourceOptic:
    """Wrapper for Optic that enables extended source ray tracing.

    This class wraps a standard Optic instance and provides source-based ray
    tracing and visualization methods without modifying the core Optic API.
    It delegates all standard Optic attributes and methods to the underlying
    optic instance.

    Args:
        optic (Optic): The optical system to wrap.
        source (BaseSource): The extended source for ray generation.

    Attributes:
        optic (Optic): The underlying optical system.
        source (BaseSource): The extended source for generating rays.
    """

    # Names stored on the wrapper itself. Every other attribute read or
    # write is forwarded to the wrapped optic.
    _OWN_ATTRS = frozenset({"optic", "source"})

    def __init__(self, optic: Optic, source: BaseSource):
        # Use object.__setattr__ to bypass the delegating __setattr__ so
        # these two attributes are stored on the wrapper, not the optic.
        object.__setattr__(self, "optic", optic)
        object.__setattr__(self, "source", source)

    def __getattr__(self, name: str):
        """Delegate attribute access to the underlying optic.

        Python only calls this after normal attribute lookup fails, which
        allows transparent access to all Optic properties and methods that
        are not defined on ExtendedSourceOptic.

        Args:
            name (str): The attribute name to look up.

        Returns:
            The attribute value from the underlying optic.

        Raises:
            AttributeError: If the attribute is not found on the optic, or
                if the wrapper's own ``optic``/``source`` attribute has not
                been set (e.g. on an instance created through ``__new__``
                by ``copy`` or ``pickle``).
        """
        if name in self._OWN_ATTRS:
            # Never forward the wrapper's own attributes. Evaluating
            # ``self.optic`` while it is missing would re-enter this method
            # endlessly and end in RecursionError instead of AttributeError.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.optic, name)

    def __setattr__(self, name: str, value: Any):
        """Delegate attribute setting to the underlying optic.

        Attributes specific to ExtendedSourceOptic ('optic', 'source') are
        set locally. All others are set on the underlying optic.
        """
        if name in self._OWN_ATTRS:
            object.__setattr__(self, name, value)
        else:
            setattr(self.optic, name, value)

    def __repr__(self) -> str:
        optic_name = self.optic.name or "Unnamed"
        source_type = type(self.source).__name__
        return f"ExtendedSourceOptic(optic='{optic_name}', source={source_type})"

    def trace(self, num_rays: int = 1000) -> tuple[RealRays, dict]:
        """Trace rays generated from the extended source through the system.

        This method generates rays using the attached source and traces them
        through the optical system's surfaces.

        Args:
            num_rays (int): The number of rays to generate and trace.
                Defaults to 1000.

        Returns:
            tuple: A tuple containing:
                - RealRays: The traced rays (final positions and directions).
                - dict: Ray path data with 'x', 'y', 'z' arrays of shape
                  (num_surfaces, num_rays).
        """
        # Generate rays from the source
        rays = self.source.generate_rays(num_rays)

        # Trace the rays through the optical system
        traced_rays = self.optic.surfaces.trace(rays)

        # Collect the per-surface intersection coordinates after the trace
        ray_path = {
            "x": self.optic.surfaces.x,
            "y": self.optic.surfaces.y,
            "z": self.optic.surfaces.z,
        }
        return traced_rays, ray_path

    def draw(
        self,
        num_rays: int = 1000,
        figsize: tuple[float, float] | None = (10, 4),
        xlim: tuple[float, float] | None = None,
        ylim: tuple[float, float] | None = None,
        title: str | None = None,
        projection: Literal["XY", "XZ", "YZ"] = "YZ",
        ax: Axes | None = None,
    ) -> tuple[Figure, Axes]:
        """Draw a 2D view of the optical system with rays from the source.

        This method traces rays from the attached source and visualizes them
        through the optical system alongside the rendered surfaces.

        Args:
            num_rays (int, optional): The number of rays to generate and
                trace. Defaults to 1000.
            figsize (tuple[float, float] | None, optional): The size of the
                figure. If None, the active theme's "figure.figsize" is used.
                Only used when a new figure is created. Defaults to (10, 4).
            xlim (tuple[float, float] | None, optional): The x-axis limits of
                the plot. Defaults to None.
            ylim (tuple[float, float] | None, optional): The y-axis limits of
                the plot. Defaults to None.
            title (str | None, optional): The title of the plot. Defaults to
                None, which auto-generates a title from the source type.
            projection (Literal["XY", "XZ", "YZ"], optional): The projection
                plane. Defaults to "YZ".
            ax (matplotlib.axes.Axes, optional): The axes to plot on. If None,
                a new figure and axes are created. Defaults to None.

        Returns:
            tuple[Figure, Axes]: A tuple containing the matplotlib Figure and
            Axes objects of the plot.
        """
        import matplotlib.pyplot as plt

        from optiland.visualization.system.rays import Rays2D
        from optiland.visualization.system.system import OpticalSystem
        from optiland.visualization.themes import get_active_theme

        # Generate and trace rays from the source
        traced_rays, ray_path = self.trace(num_rays)

        # Set up the figure
        theme = get_active_theme()
        params = theme.parameters
        if figsize is None:
            figsize = params["figure.figsize"]

        if ax is None:
            fig, ax = plt.subplots(figsize=figsize)
            fig.set_facecolor(params["figure.facecolor"])
        else:
            fig = ax.get_figure()
        ax.set_facecolor(params["axes.facecolor"])

        # Surface extents are needed for drawing. If fields are defined,
        # trace a few traditional rays (first field, primary wavelength) to
        # establish them.
        rays_2d = Rays2D(self.optic)
        if len(self.optic.fields.fields) > 0:
            field = self.optic.fields.get_field_coords()[0]
            wavelength = self.optic.wavelengths.primary_wavelength.value
            rays_2d._trace(field, wavelength, 3, "line_y")

        # Draw the optical system surfaces
        system = OpticalSystem(self.optic, rays_2d, projection="2d")
        system.plot(ax, theme=theme, projection=projection)

        # Plot the rays from the source
        self._plot_source_rays(ax, ray_path, traced_rays, projection=projection)

        self._style_axes(ax, params, projection, title, xlim, ylim)

        plt.tight_layout()
        return fig, ax

    def _style_axes(self, ax, params, projection, title, xlim, ylim) -> None:
        """Apply axis labels, theme colors, title, limits and grid to ``ax``."""
        # (x label, y label) per projection; anything else is treated as XY.
        xlabel, ylabel = {
            "YZ": ("Z [mm]", "Y [mm]"),
            "XZ": ("Z [mm]", "X [mm]"),
        }.get(projection, ("X [mm]", "Y [mm]"))
        ax.set_xlabel(xlabel, color=params["axes.labelcolor"])
        ax.set_ylabel(ylabel, color=params["axes.labelcolor"])

        ax.tick_params(axis="x", colors=params["xtick.color"])
        ax.tick_params(axis="y", colors=params["ytick.color"])
        for spine in ax.spines.values():
            spine.set_color(params["axes.edgecolor"])

        ax.axis("image")
        if title:
            ax.set_title(title, color=params["text.color"])
        else:
            ax.set_title(
                f"Optical System with {type(self.source).__name__}",
                color=params["text.color"],
            )
        # Explicit limits are applied after axis("image") so they take effect.
        if xlim:
            ax.set_xlim(xlim)
        if ylim:
            ax.set_ylim(ylim)

        ax.grid(
            visible=True,
            color=params["grid.color"],
            alpha=params["grid.alpha"],
        )

    def _plot_source_rays(
        self,
        ax,
        ray_path: dict,
        traced_rays: RealRays,
        projection: str = "YZ",
    ) -> None:
        """Plot rays from the extended source on the given axis.

        Args:
            ax: The matplotlib axis to plot on.
            ray_path: Dictionary with 'x', 'y', 'z' arrays of ray paths, each
                of shape (num_surfaces, num_rays).
            traced_rays: The traced rays object for intensity information.
            projection: The projection plane ('YZ', 'XZ', or 'XY'). Any other
                value is drawn as 'XY'.
        """
        x_coords = be.to_numpy(ray_path["x"])
        y_coords = be.to_numpy(ray_path["y"])
        z_coords = be.to_numpy(ray_path["z"])

        # Pick the (horizontal, vertical) arrays for the requested plane.
        if projection == "YZ":
            horizontal, vertical = z_coords, y_coords
        elif projection == "XZ":
            horizontal, vertical = z_coords, x_coords
        else:  # XY
            horizontal, vertical = x_coords, y_coords

        # Only (num_surfaces, num_rays) data can be drawn as ray paths;
        # validate the arrays that are actually plotted.
        if horizontal.ndim != 2 or vertical.ndim != 2:
            return

        # Skip rays that were completely blocked. Written as ~(i <= 0) so a
        # NaN intensity is still drawn, as with a per-ray "<= 0" skip test.
        intensities = be.to_numpy(traced_rays.i)
        visible = ~(intensities <= 0)
        if not visible.any():
            return

        # A single call with 2-D arrays draws one line per column (ray),
        # instead of one ax.plot call per ray.
        ax.plot(
            horizontal[:, visible],
            vertical[:, visible],
            "b-",
            alpha=0.3,
            linewidth=0.5,
        )

    # --- Methods that don't apply for extended source optics ---

    def trace_generic(self, *args, **kwargs):
        """Not available for ExtendedSourceOptic.

        Raises:
            NotImplementedError: Always raised. Use trace(num_rays) instead.
        """
        raise NotImplementedError(
            "trace_generic() is not available for ExtendedSourceOptic. "
            "Use trace(num_rays) instead."
        )