"""System Visualization Module
This module contains the OpticalSystem class for visualizing optical systems.
Kramer Harrison, 2024
"""
from __future__ import annotations
import optiland.backend as be
from optiland.visualization.system.lens import Lens2D, Lens3D
from optiland.visualization.system.mirror import Mirror3D
from optiland.visualization.system.surface import Surface2D, Surface3D
from optiland.visualization.system.utils import transform
class OpticalSystem:
    """A class to represent an optical system for visualization.

    The optical system is decomposed into drawable components (lenses,
    mirrors and standalone surfaces), each of which plots itself.

    Args:
        optic (Optic): The optical system to be used for plotting.
        rays (Rays): The rays interacting with the optical system. Its
            ``r_extent`` entry for each surface is passed to that surface's
            component.
        projection (str): The type of projection for visualization.
            Must be '2d' or '3d'.

    Attributes:
        optic (Optic): The optical system to be used for plotting.
        rays (Rays): The rays interacting with the optical system.
        projection (str): The type of projection for visualization.
            Must be '2d' or '3d'.
        components (list): A list to store the components of the optical
            system.
        component_registry (dict): A registry mapping component names to their
            respective classes for 2D and 3D projections.

    Methods:
        plot(ax):
            Identifies and plots the components of the optical system on the
            given axis (or renderer for 3D plotting).

    Raises:
        ValueError: If ``projection`` is not '2d' or '3d'.
    """

    def __init__(self, optic, rays, projection="2d"):
        self.optic = optic
        self.rays = rays
        self.projection = projection
        self.components = []  # filled by _identify_components()

        if self.projection not in ["2d", "3d"]:
            raise ValueError("Invalid projection type. Must be '2d' or '3d'.")

        # In 2D a mirror is drawn with the plain surface class; only the 3D
        # view has a dedicated mirror class.
        self.component_registry = {
            "lens": {"2d": Lens2D, "3d": Lens3D},
            "mirror": {"2d": Surface2D, "3d": Mirror3D},
            "surface": {"2d": Surface2D, "3d": Surface3D},
        }

    def plot(self, ax, theme=None, projection="YZ", show_apertures=True):
        """Plots the components of the optical system on the given
        axis (or renderer for 3D plotting).

        Args:
            ax: Target axis (2D) or renderer (3D plotting).
            theme: Optional theme forwarded to every component's ``plot``.
            projection (str): Viewing plane forwarded to every component and,
                in 2D, to the aperture markers (which accept 'XY', 'XZ' or
                'YZ').
            show_apertures (bool): If True and the system projection is '2d',
                aperture markers are drawn as well.

        Returns:
            dict: Artists merged from each component's ``plot`` return value
                plus, when drawn, aperture lines mapped to their surface.
        """
        self._identify_components()
        artists = {}
        for component in self.components:
            component_artists = component.plot(ax, theme=theme, projection=projection)
            if component_artists:
                artists.update(component_artists)
        if show_apertures and self.projection == "2d":
            artists.update(self._plot_apertures(ax, projection=projection))
        return artists

    def _identify_components(self):
        """Identifies the components of the optical system and adds them to the
        list of components.

        Surfaces are visited in order and classified as follows:

        * object surface (index 0): added only if it is not at infinity;
        * last (image) surface or a paraxial surface: added as a surface;
        * reflective surface: closes the currently open lens if there is one
          (second-surface mirror), otherwise added as a mirror;
        * ``n[k] > 1``: appended to the currently open lens;
        * ``n[k] == 1`` directly after glass: closes the open lens;
        * standalone phase surface: added as a surface.

        A lens still open after the last surface is added at the end.
        """
        self.components = []
        # Refractive indices at the primary wavelength; the front/back tests
        # below treat n[k] as the index of the medium after surface k.
        n = self.optic.surfaces.n(self.optic.primary_wavelength)
        num_surf = self.optic.surfaces.num_surfaces
        lens_surfaces = []  # surfaces of the lens currently being assembled

        for k, surf in enumerate(self.optic.surfaces):
            extent = self.rays.r_extent[k]

            if k == 0:
                # Object surface
                if not surf.is_infinite:
                    self._add_component("surface", surf, extent)
            elif k == num_surf - 1 or surf.surface_type == "paraxial":
                # Image surface or paraxial surface
                self._add_component("surface", surf, extent)
            elif surf.interaction_model.is_reflective:
                if lens_surfaces:
                    # Second surface mirror (lens + mirror)
                    lens_surfaces.append(self._get_lens_surface(surf, extent))
                    self._add_component("lens", lens_surfaces)
                    lens_surfaces = []
                else:
                    self._add_component("mirror", surf, extent)
            elif n[k] > 1:
                # Front (or internal) surface of a lens
                lens_surfaces.append(self._get_lens_surface(surf, extent))
            elif n[k] == 1 and n[k - 1] > 1 and lens_surfaces:
                # Back surface of a lens
                lens_surfaces.append(self._get_lens_surface(surf, extent))
                self._add_component("lens", lens_surfaces)
                lens_surfaces = []
            elif surf.interaction_model.interaction_type == "phase":
                # Standalone phase surface
                self._add_component("surface", surf, extent)

        # Add the final lens, if any
        if lens_surfaces:
            self._add_component("lens", lens_surfaces)

    def _add_component(self, component_name, *args):
        """Instantiates a registered component and appends it to the list.

        Args:
            component_name (str): Registry key ('lens', 'mirror', 'surface').
            *args: Positional arguments passed to the component class.

        Raises:
            ValueError: If ``component_name`` is not in the registry.
        """
        if component_name in self.component_registry:
            component_class = self.component_registry[component_name][self.projection]
        else:
            raise ValueError(f"Component {component_name} not found in registry.")
        self.components.append(component_class(*args))

    def _get_lens_surface(self, surface, *args):
        """Wraps ``surface`` in the surface class for the current projection.

        Used to build the individual surfaces that make up a lens component.
        """
        surface_class = self.component_registry["surface"][self.projection]
        return surface_class(surface, *args)

    @staticmethod
    def _is_lens_surface(n, idx):
        """Returns True if surface ``idx`` enters glass or exits glass into air.

        This uses the same index tests as the lens grouping in
        ``_identify_components``.
        """
        if idx > 0:
            return n[idx] > 1 or (n[idx] == 1 and n[idx - 1] > 1)
        return n[idx] > 1

    def _aperture_extent(self, idx, surface):
        """Returns ``(x_min, x_max, y_min, y_max)`` of the aperture to mark on
        surface ``idx``, or None if no marker should be drawn.

        Sources, in priority order:

        1. the surface's explicit aperture;
        2. for the stop only: the surface's semi-aperture;
        3. for the stop only: half the system aperture value when
           ``ap_type == "float_by_stop_size"``;
        4. for the stop only: the traced ray extent, when it is positive.
        """
        if surface.aperture is not None:
            return surface.aperture.extent
        if not surface.is_stop:
            # Without an explicit aperture only the stop gets a marker.
            return None
        if surface.semi_aperture is not None:
            r = surface.semi_aperture
            return -r, r, -r, r
        if (
            self.optic.aperture is not None
            and self.optic.aperture.ap_type == "float_by_stop_size"
        ):
            r = 0.5 * self.optic.aperture.value
            return -r, r, -r, r
        if self.rays is not None:
            r = be.to_numpy(self.rays.r_extent[idx]).item()
            if r > 0:
                return -r, r, -r, r
        return None

    def _plot_apertures(self, ax, projection="YZ"):
        """Draws aperture markers for non-lens surfaces and the stop.

        Each marker is a thin line across the aperture, with an arrow head at
        each edge. Stop arrows are black; other aperture arrows are grey.

        Args:
            ax: Matplotlib axis to draw on.
            projection (str): 'XY' (nothing is drawn), 'XZ' or 'YZ'.

        Returns:
            dict: Mapping from each aperture line artist to its surface.

        Raises:
            ValueError: If ``projection`` is not 'XY', 'XZ' or 'YZ'.
        """
        if projection == "XY":
            return {}
        if projection not in ("XZ", "YZ"):
            raise ValueError("Invalid projection type. Must be 'XY', 'XZ', or 'YZ'.")

        stop_color = "black"  # arrow color for stop apertures
        aperture_color = "grey"  # arrow color for other apertures
        eps = 1e-6  # arrows start eps beyond the edge, so only the head shows
        artists = {}
        n = self.optic.surfaces.n(self.optic.primary_wavelength)

        for idx, surface in enumerate(self.optic.surfaces):
            # Lens surfaces are skipped unless they are the stop.
            if self._is_lens_surface(n, idx) and not surface.is_stop:
                continue
            extent = self._aperture_extent(idx, surface)
            if extent is None:
                continue
            x_min, x_max, y_min, y_max = extent

            # Cross-section through the aperture centre. The out-of-plane
            # coordinate is shared by both endpoints, so a tilt about the
            # in-plane axis cannot leak into z and skew the drawn line.
            if projection == "YZ":
                x_mid = 0.5 * (x_min + x_max)
                x_local = be.array([x_mid, x_mid])
                y_local = be.array([y_min, y_max])
            else:  # "XZ"
                y_mid = 0.5 * (y_min + y_max)
                x_local = be.array([x_min, x_max])
                y_local = be.array([y_mid, y_mid])
            z_local = be.array([0.0, 0.0])

            x_global, y_global, z_global = transform(
                x_local, y_local, z_local, surface, is_global=False
            )
            z_global = be.to_numpy(z_global)
            axis_vals = be.to_numpy(x_global if projection == "XZ" else y_global)

            # Draw line for aperture edge
            (line,) = ax.plot(
                z_global,
                axis_vals,
                color="black",
                linewidth=0.3,
            )
            artists[line] = surface

            # Arrow heads pointing at both aperture edges
            facecolor = stop_color if surface.is_stop else aperture_color
            arrowprops = {"arrowstyle": "-|>", "facecolor": facecolor, "linewidth": 0}
            for z_val, axis_val, sign in (
                (z_global[1], axis_vals[1], 1),  # top
                (z_global[0], axis_vals[0], -1),  # bottom
            ):
                ax.annotate(
                    "",
                    xy=(z_val, axis_val),
                    xytext=(z_val, axis_val + sign * eps),
                    arrowprops=arrowprops,
                )
        return artists