Source code for sources.visualization

"""Sources Visualization Module

This module provides visualization tools for extended sources, allowing users
to validate and visualize their source definitions before running full
optical system traces.

The SourceViewer class creates a 2x3 grid of plots showing:
1. XY spatial distribution and LM angular distribution, color-coded by
   normalized radiance
2. Spatial and angular cross-sections
3. XZ and YZ ray propagation paths

This helps users verify the spatial and angular properties of their sources.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np

import optiland.backend as be
from optiland.visualization.base import BaseViewer

if TYPE_CHECKING:
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure

    from optiland.sources.base import BaseSource


class SourceViewer(BaseViewer):
    """Visualize an extended source as a 2x3 grid of diagnostic plots.

    The viewer samples rays from the source and plots their spatial and
    angular distributions, 1D cross-sections of both, and short XZ / YZ
    propagation paths, so a source definition can be checked before it is
    used in a full trace.

    Args:
        source (BaseSource): The extended source to be visualized.

    Attributes:
        source (BaseSource): The extended source being visualized.

    Methods:
        view(num_rays, propagation_distance, figsize, cross_spatial,
            cross_angular): Creates the source visualization.
    """

    def __init__(self, source: BaseSource):
        """Initialize the SourceViewer with a source.

        Args:
            source (BaseSource): The extended source to visualize.
        """
        self.source = source

    def view(
        self,
        num_rays: int = 5000,
        propagation_distance: float = 0.1,
        figsize: tuple[float, float] = (20, 8),
        cross_spatial: tuple[str, str] = ("x", "y"),
        cross_angular: tuple[str, str] = ("L", "M"),
    ) -> tuple[Figure, list[Axes]]:
        """Create a six-panel visualization of the extended source.

        Panel layout (2 rows x 3 columns):
            1. Column 1: XY spatial distribution (top) and LM angular
               distribution (bottom), color-coded by normalized radiance.
            2. Column 2: spatial (top) and angular (bottom) cross-sections.
            3. Column 3: XZ (top) and YZ (bottom) ray propagation paths.

        Args:
            num_rays (int, optional): Number of rays requested from the
                source. Defaults to 5000.
            propagation_distance (float, optional): Distance in mm to
                propagate rays for the path panels. Defaults to 0.1.
            figsize (tuple[float, float], optional): Figure size
                (width, height) in inches. Defaults to (20, 8).
            cross_spatial (tuple[str, str], optional): Names shown in the
                legend and x-axis label of the spatial cross-section panel.
                The plotted data are always the x and y coordinates.
                Defaults to ("x", "y").
            cross_angular (tuple[str, str], optional): Names shown in the
                legend and x-axis label of the angular cross-section panel.
                The plotted data are always the L and M direction cosines.
                Defaults to ("L", "M").

        Returns:
            tuple[Figure, list[Axes]]: The Matplotlib figure and a flat,
            row-major list of its 6 axes.
        """
        rays = self.source.generate_rays(num_rays)

        fig, axes = plt.subplots(2, 3, figsize=figsize)

        x = be.to_numpy(rays.x)
        y = be.to_numpy(rays.y)
        z = be.to_numpy(rays.z)
        L = be.to_numpy(rays.L)
        M = be.to_numpy(rays.M)
        N = be.to_numpy(rays.N)

        # --- Radiance used for color-coding only (separate from rays.i) ---
        # Imported locally, as in the original implementation.
        # NOTE(review): presumably this avoids an import cycle - confirm.
        from optiland.sources.smf import SMFSource

        if isinstance(self.source, SMFSource):
            # Gaussian profile in both position and direction.
            w_spatial = self.source.sigma_spatial_mm * 2  # 1/e² radius in mm
            w_angular = self.source.sigma_angular_rad * 2  # 1/e² half-angle
            exponent = -2.0 * (
                (rays.x / w_spatial) ** 2
                + (rays.y / w_spatial) ** 2
                + (rays.L / w_angular) ** 2
                + (rays.M / w_angular) ** 2
            )
            radiance = be.exp(exponent)
        else:
            # Any other source type: fall back to the per-ray power.
            radiance = rays.i

        radiance_norm = self._normalize_to_peak(be.to_numpy(radiance))

        # Column 1: spatial and angular scatter plots.
        self._plot_distribution(
            fig,
            axes[0, 0],
            x,
            y,
            radiance_norm,
            "X [mm]",
            "Y [mm]",
            f"{type(self.source).__name__}\nSpatial Distribution",
        )
        self._plot_distribution(
            fig,
            axes[1, 0],
            L,
            M,
            radiance_norm,
            "L (Direction Cosine)",
            "M (Direction Cosine)",
            "Angular Distribution",
        )

        # Column 2: cross-sections.
        self._plot_cross_sections(
            axes[0, 1],
            x,
            y,
            radiance_norm,
            cross_spatial,
            ["X [mm]", "Y [mm]"],
            "Spatial Cross-Sections",
            spatial=True,
        )
        self._plot_cross_sections(
            axes[1, 1],
            L,
            M,
            radiance_norm,
            cross_angular,
            ["L", "M"],
            "Angular Cross-Sections",
            spatial=False,
        )

        # Column 3: propagation views.
        self._plot_ray_propagation(
            axes[0, 2],
            x,
            z,
            L,
            N,
            radiance_norm,
            propagation_distance,
            "Z [mm]",
            "X [mm]",
            "XZ Ray Propagation",
        )
        self._plot_ray_propagation(
            axes[1, 2],
            y,
            z,
            M,
            N,
            radiance_norm,
            propagation_distance,
            "Z [mm]",
            "Y [mm]",
            "YZ Ray Propagation",
        )

        # Operate on this figure explicitly instead of pyplot's implicit
        # "current figure".
        fig.tight_layout()
        return fig, axes.flatten().tolist()

    @staticmethod
    def _normalize_to_peak(values: np.ndarray) -> np.ndarray:
        """Scale ``values`` so the maximum is 1.

        The input is returned unchanged when it is empty or when its maximum
        is not strictly positive, so no division by zero can occur.

        Args:
            values (np.ndarray): Values to normalize.

        Returns:
            np.ndarray: ``values / max(values)``, or ``values`` itself.
        """
        if values.size == 0:
            return values
        peak = np.max(values)
        return values / peak if peak > 0 else values

    @staticmethod
    def _plot_distribution(
        fig: Figure,
        ax: Axes,
        horizontal: np.ndarray,
        vertical: np.ndarray,
        radiance: np.ndarray,
        xlabel: str,
        ylabel: str,
        title: str,
    ) -> None:
        """Draw one radiance-colored scatter panel with its colorbar.

        Args:
            fig (Figure): Figure that owns ``ax``; receives the colorbar.
            ax (Axes): Matplotlib axes to plot on.
            horizontal (np.ndarray): Values for the horizontal axis.
            vertical (np.ndarray): Values for the vertical axis.
            radiance (np.ndarray): Normalized radiance used for coloring.
            xlabel (str): Label for x-axis.
            ylabel (str): Label for y-axis.
            title (str): Plot title.
        """
        points = ax.scatter(
            horizontal, vertical, c=radiance, s=5, alpha=0.6, cmap="viridis"
        )
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.set_aspect("equal", adjustable="box")
        ax.grid(alpha=0.3)
        colorbar = fig.colorbar(points, ax=ax, fraction=0.046, pad=0.04)
        colorbar.set_label("Normalized Radiance")

    def _plot_cross_sections(
        self,
        ax: Axes,
        coord1: np.ndarray,
        coord2: np.ndarray,
        intensity: np.ndarray,
        axes_labels: tuple[str, str],
        axis_units: list[str],
        title: str,
        spatial: bool = True,
        num_bins: int = 50,
    ) -> None:
        """Plot peak-normalized 1D histograms of two coordinates on one axes.

        Args:
            ax (Axes): Matplotlib axes to plot on.
            coord1 (np.ndarray): First coordinate array (blue curve).
            coord2 (np.ndarray): Second coordinate array (red curve).
            intensity (np.ndarray): Per-ray weights; only used when
                ``spatial`` is False.
            axes_labels (tuple[str, str]): Names of the two coordinates, used
                for the legend and the x-axis label.
            axis_units (list[str]): Axis label text; only the first entry is
                used, appended to the x-axis label.
            title (str): Plot title.
            spatial (bool): True for an unweighted density histogram, False
                for an intensity-weighted histogram.
            num_bins (int): Number of histogram bins for both branches.
        """
        if spatial:
            # Unweighted density: the rays are already distributed according
            # to the beam profile, so weighting would count it twice.
            hist1, bins1 = np.histogram(coord1, bins=num_bins, density=True)
            hist2, bins2 = np.histogram(coord2, bins=num_bins, density=True)
        else:
            total = np.sum(intensity)
            weights = intensity / total if total > 0 else None
            # Passing the bin *count* gives exactly ``num_bins`` bins over
            # [min, max], matching the spatial branch. Hand-building
            # ``num_bins`` edges with linspace yields one bin fewer. NumPy
            # also widens the range itself when all samples are identical.
            hist1, bins1 = np.histogram(coord1, bins=num_bins, weights=weights)
            hist2, bins2 = np.histogram(coord2, bins=num_bins, weights=weights)

        hist1 = self._normalize_to_peak(hist1)
        hist2 = self._normalize_to_peak(hist2)

        # Plot each histogram value at the centre of its bin.
        bin_centers1 = (bins1[:-1] + bins1[1:]) / 2
        bin_centers2 = (bins2[:-1] + bins2[1:]) / 2

        ax.plot(bin_centers1, hist1, "b-", linewidth=2, label=f"{axes_labels[0]}")
        ax.plot(bin_centers2, hist2, "r-", linewidth=2, label=f"{axes_labels[1]}")
        ax.set_xlabel(f"{axes_labels[0]} / {axes_labels[1]} {axis_units[0]}")
        ax.set_ylabel("Normalized Intensity")
        ax.set_title(title)
        ax.grid(alpha=0.3)
        ax.legend()

    def _plot_ray_propagation(
        self,
        ax: Axes,
        coord1: np.ndarray,
        coord2: np.ndarray,
        dir1: np.ndarray,
        dir2: np.ndarray,
        intensity: np.ndarray,
        distance: float,
        xlabel: str,
        ylabel: str,
        title: str,
    ) -> None:
        """Plot ray propagation in a 2D plane with intensity-based coloring.

        ``coord2`` / ``dir2`` are drawn along the horizontal axis and
        ``coord1`` / ``dir1`` along the vertical axis.

        Args:
            ax (Axes): Matplotlib axes to plot on.
            coord1 (np.ndarray): Starting coordinates for first dimension.
            coord2 (np.ndarray): Starting coordinates for second dimension.
            dir1 (np.ndarray): Direction cosines for first dimension.
            dir2 (np.ndarray): Direction cosines for second dimension.
            intensity (np.ndarray): Ray intensities for color mapping; passed
                straight to the viridis colormap.
            distance (float): Propagation distance in mm.
            xlabel (str): Label for x-axis.
            ylabel (str): Label for y-axis.
            title (str): Plot title.
        """
        # End points after straight-line propagation.
        end_coord1 = coord1 + dir1 * distance
        end_coord2 = coord2 + dir2 * distance

        # Draw at most 1000 evenly spaced rays to keep the panel readable.
        num_rays = len(coord1)
        if num_rays > 1000:
            indices = np.linspace(0, num_rays - 1, 1000, dtype=int)
            coord1 = coord1[indices]
            coord2 = coord2[indices]
            end_coord1 = end_coord1[indices]
            end_coord2 = end_coord2[indices]
            intensity_subset = intensity[indices]
        else:
            intensity_subset = intensity

        colors = plt.cm.viridis(intensity_subset)

        # One faint line segment per ray, from its origin to its end point.
        for start1, start2, stop1, stop2, color in zip(
            coord1, coord2, end_coord1, end_coord2, colors
        ):
            ax.plot(
                [start2, stop2],
                [start1, stop1],
                color=color,
                alpha=0.1,
                linewidth=0.5,
            )

        # These markers sit at the propagated end points, so the legend entry
        # says so (it previously read "Ray Origins").
        ax.scatter(
            end_coord2,
            end_coord1,
            c=intensity_subset,
            s=3,
            alpha=0.8,
            cmap="viridis",
            label="Ray End Points",
        )

        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(alpha=0.3)
        ax.legend()