Source code for visualization.system.rays

"""Rays Visualization Module

This module contains classes for visualizing rays in an optical system.

Kramer Harrison, 2024
"""

from __future__ import annotations

import numpy as np
import vtk

import optiland.backend as be
from optiland.utils import resolve_fields, resolve_wavelengths
from optiland.visualization.system.ray_bundle import RayBundle
from optiland.visualization.system.utils import transform


[docs] class Rays2D: """A class to represent and visualize 2D rays in an optical system. Args: optic (Optic): The optical system to be visualized. Attributes: optic (Optic): The optical system containing surfaces and fields. x (be.ndarray): X-coordinates of the rays. y (be.ndarray): Y-coordinates of the rays. z (be.ndarray): Z-coordinates of the rays. i (be.ndarray): Intensities of the rays. x_extent (be.ndarray): Extents of the x-coordinates for each surface. y_extent (be.ndarray): Extents of the y-coordinates for each surface. r_extent (be.ndarray): Extents of the radii for each surface. Methods: plot(ax, fields='all', wavelengths='primary', num_rays=3, distribution='line_y'): """ def __init__(self, optic): self.optic = optic self.x = None self.y = None self.z = None self.i = None n = optic.surfaces.num_surfaces self.r_extent = be.zeros(n)
[docs] def plot( self, ax, fields="all", wavelengths="primary", num_rays=3, distribution="line_y", reference=None, theme=None, projection="YZ", hide_vignetted=False, ): """Plots the rays for the given fields and wavelengths. Args: ax: The matplotlib axis to plot on. fields: The fields at which to trace the rays. Default is 'all'. wavelengths: The wavelengths at which to trace the rays. Default is 'primary'. num_rays: The number of rays to trace for each field and wavelength. Default is 3. distribution: The distribution of the rays. Default is 'line_y'. reference (str, optional): The reference rays to plot. Options include "chief" and "marginal". Defaults to None. theme (Theme, optional): The theme to apply. Defaults to None. hide_vignetted (bool, optional): If True, rays that vignette at any surface are not shown. Defaults to False. """ # Visualization ignores weights; extract coords and values directly field_points = resolve_fields(self.optic, fields) wl_points = resolve_wavelengths(self.optic, wavelengths) fields_coords = [fp.coord for fp in field_points] wavelengths_vals = [wp.value for wp in wl_points] artists = {} for i, field in enumerate(fields_coords): for j, wavelength in enumerate(wavelengths_vals): # if only one field, use different colors for each wavelength color_idx = i if len(fields_coords) > 1 else j if distribution is None: # trace only for surface extents self._trace(field, wavelength, num_rays, "line_y") else: # trace rays and plot lines self._trace(field, wavelength, num_rays, distribution) artists.update( self._plot_lines( ax, color_idx, field, theme=theme, projection=projection, hide_vignetted=hide_vignetted, ) ) # trace reference rays and plot lines if reference is not None: self._trace_reference(field, wavelength, reference) artists.update( self._plot_lines( ax, color_idx, field, linewidth=1.5, theme=theme, projection=projection, hide_vignetted=hide_vignetted, ) ) return artists
def _process_traced_rays(self): """Processes the traced rays and updates the surface extents.""" self.x = self.optic.surfaces.x self.y = self.optic.surfaces.y self.z = self.optic.surfaces.z self.i = self.optic.surfaces.intensity # update surface extents self._update_surface_extents() def _trace(self, field, wavelength, num_rays, distribution): """Traces rays through the optical system and updates the surface extents. Args: field (tuple): The field coordinates for the ray tracing. wavelength (float): The wavelength of the rays. num_rays (int): The number of rays to trace. distribution (str): The distribution pattern of the rays. Returns: None """ self.optic.trace(*field, wavelength, num_rays, distribution) self._process_traced_rays() def _trace_reference(self, field, wavelength, reference): """Traces reference rays through the optical system. Args: field (tuple): The field coordinates for the ray tracing. wavelength (float): The wavelength of the rays. reference (str): The type of reference rays to trace. Returns: None """ if reference == "chief": self.optic.trace_generic(*field, Px=0, Py=0, wavelength=wavelength) elif reference == "marginal": self.optic.trace_generic(*field, Px=0, Py=1, wavelength=wavelength) else: raise ValueError(f"Invalid ray reference type: {reference}") self._process_traced_rays() def _update_surface_extents(self): """Updates the extents of the surfaces in the optic's surface group.""" r_extent_new = be.copy(be.zeros_like(self.r_extent)) for i, surf in enumerate(self.optic.surfaces): x_surf = self.x[i] y_surf = self.y[i] z_surf = self.z[i] # Convert to local coordinate system x, y, _ = transform(x_surf, y_surf, z_surf, surf, is_global=True) r_extent_new[i] = be.nanmax(be.hypot(x, y)) self.r_extent = be.fmax(self.r_extent, r_extent_new) def _plot_lines( self, ax, color_idx, field, linewidth=1, theme=None, projection="YZ", hide_vignetted=False, ): """Plots multiple lines on the given axis. This method iterates through the rays stored in the object's attributes (self.x, self.y, self.z, self.i) and plots each valid ray on the provided axis. Rays that are outside the aperture (where self.i == 0) are excluded from the plot. Args: ax (matplotlib.axes.Axes): The axis on which to plot the lines. color_idx (int): The index used to determine the color of the lines. field (tuple): The field coordinates for the ray. linewidth (float): The width of the line. theme (Theme, optional): The theme to apply. Defaults to None. projection (str, optional): The projection plane. Must be 'XY', 'XZ', or 'YZ'. Defaults to 'YZ'. Returns: None """ artists = {} bundle_id = f"bundle_{color_idx}" # loop through rays for k in range(self.z.shape[1]): xk = be.to_numpy(self.x[:, k]) yk = be.to_numpy(self.y[:, k]) zk = be.to_numpy(self.z[:, k]) ik = be.to_numpy(self.i[:, k]) if np.any(ik == 0): if hide_vignetted: continue first_zero_idx = np.where(ik == 0)[0][0] xk[first_zero_idx + 1 :] = np.nan yk[first_zero_idx + 1 :] = np.nan zk[first_zero_idx + 1 :] = np.nan artist, ray_bundle = self._plot_single_line( ax, xk, yk, zk, color_idx, field, linewidth, theme=theme, projection=projection, ) ray_bundle.bundle_id = bundle_id artists[artist] = ray_bundle return artists def _plot_single_line( self, ax, x, y, z, color_idx, field, linewidth=1, theme=None, projection="YZ" ): """Plots a single line on the given axes. Args: ax (matplotlib.axes.Axes): The axes on which to plot the line. x (array-like): The x-coordinates of the line. y (array-like): The y-coordinates of the line. z (array-like): The z-coordinates of the line. color_idx (int): The index for the color to use for the line. field (tuple): The field coordinates for the ray. linewidth (float): The width of the line. Default is 1. theme (Theme, optional): The theme to apply. Defaults to None. projection (str, optional): The projection plane. Must be 'XY', 'XZ', or 'YZ'. Defaults to 'YZ'. Returns: None """ if theme: ray_cycle = theme.parameters.get("ray_cycle") color = ray_cycle[color_idx % len(ray_cycle)] else: color = f"C{color_idx}" if projection == "XY": (line,) = ax.plot(x, y, color=color, linewidth=linewidth) elif projection == "XZ": (line,) = ax.plot(z, x, color=color, linewidth=linewidth) else: # YZ (line,) = ax.plot(z, y, color=color, linewidth=linewidth) return line, RayBundle(x, y, z, field)
[docs] class Rays3D(Rays2D): """A class to represent 3D rays for visualization using VTK. Inherits from Rays2D and extends functionality to 3D. Methods: plot(ax, fields='all', wavelengths='primary', num_rays=3, distribution='line_y'): Args: optic: The optical system to be visualized. """
[docs] def plot( self, ax, fields="all", wavelengths="primary", num_rays=3, distribution="line_y", reference=None, theme=None, hide_vignetted=False, ): """Plots the rays for the given fields and wavelengths. Args: ax: The matplotlib axis to plot on. fields: The fields at which to trace the rays. Default is 'all'. wavelengths: The wavelengths at which to trace the rays. Default is 'primary'. num_rays: The number of rays to trace for each field and wavelength. Default is 3. distribution: The distribution of the rays. Default is 'line_y'. reference (str, optional): The reference rays to plot. Options include "chief" and "marginal". Defaults to None. theme (Theme, optional): The theme to apply. Defaults to None. hide_vignetted (bool, optional): If True, rays that vignette at any surface are not shown. Defaults to False. """ # Visualization ignores weights; extract coords and values directly field_points = resolve_fields(self.optic, fields) wl_points = resolve_wavelengths(self.optic, wavelengths) fields_coords = [fp.coord for fp in field_points] wavelengths_vals = [wp.value for wp in wl_points] for i, field in enumerate(fields_coords): for j, wavelength in enumerate(wavelengths_vals): # if only one field, use different colors for each wavelength color_idx = i if len(fields_coords) > 1 else j if distribution is None: # trace only for surface extents self._trace(field, wavelength, num_rays, "line_y") else: # trace rays and plot lines self._trace(field, wavelength, num_rays, distribution) self._plot_lines( ax, color_idx, field, theme=theme, hide_vignetted=hide_vignetted, ) # trace reference rays and plot lines if reference is not None: self._trace_reference(field, wavelength, reference) self._plot_lines( ax, color_idx, field, linewidth=1.5, theme=theme, hide_vignetted=hide_vignetted, )
def __init__(self, optic): super().__init__(optic) # matplotlib default colors converted to RGB self._rgb_colors = [ (0.122, 0.467, 0.706), (1.000, 0.498, 0.055), (0.173, 0.627, 0.173), (0.839, 0.153, 0.157), (0.580, 0.404, 0.741), (0.549, 0.337, 0.294), (0.890, 0.467, 0.761), (0.498, 0.498, 0.498), (0.737, 0.741, 0.133), (0.090, 0.745, 0.812), ] def _plot_lines( self, ax, color_idx, field, linewidth=1, theme=None, hide_vignetted=False ): # loop through rays for k in range(self.z.shape[1]): xk = be.to_numpy(self.x[:, k]) yk = be.to_numpy(self.y[:, k]) zk = be.to_numpy(self.z[:, k]) ik = be.to_numpy(self.i[:, k]) if np.any(ik == 0): if hide_vignetted: continue first_zero_idx = np.where(ik == 0)[0][0] xk[first_zero_idx + 1 :] = np.nan yk[first_zero_idx + 1 :] = np.nan zk[first_zero_idx + 1 :] = np.nan self._plot_single_line( ax, xk, yk, zk, color_idx, field, linewidth, theme=theme ) def _plot_single_line( self, renderer, x, y, z, color_idx, field, linewidth=1, theme=None ): """Plots a single line in 3D space using VTK with the specified coordinates and color index. Args: renderer (vtkRenderer): The VTK renderer to add the line actor to. x (list of float): The x-coordinates of the line. y (list of float): The y-coordinates of the line. z (list of float): The z-coordinates of the line. color_idx (int): The index of the color to use from the _rgb_colors list. field (tuple): The field coordinates for the ray. linewidth (float): The width of the line. Default is 1. theme (Theme, optional): The theme to apply. Defaults to None. """ if theme: from matplotlib.colors import to_rgb ray_cycle = theme.parameters.get("ray_cycle") color = to_rgb(ray_cycle[color_idx % len(ray_cycle)]) else: color = self._rgb_colors[color_idx % 10] for k in range(1, len(x)): if np.isnan(x[k - 1]) or np.isnan(x[k]): continue p0 = [x[k - 1], y[k - 1], z[k - 1]] p1 = [x[k], y[k], z[k]] line_source = vtk.vtkLineSource() line_source.SetPoint1(p0) line_source.SetPoint2(p1) line_mapper = vtk.vtkPolyDataMapper() line_mapper.SetInputConnection(line_source.GetOutputPort()) line_actor = vtk.vtkActor() line_actor.SetMapper(line_mapper) line_actor.GetProperty().SetLineWidth(linewidth) line_actor.GetProperty().SetColor(color) renderer.AddActor(line_actor)