"""Encircled Energy Analysis
This module provides an encircled energy analysis for optical systems.
Kramer Harrison, 2024
"""
from __future__ import annotations

import numbers
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt

import optiland.backend as be

from .spot_diagram import SpotData, SpotDiagram

if TYPE_CHECKING:
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure
class EncircledEnergy(SpotDiagram):
    """Class representing the Encircled Energy analysis of a given optic.

    For each field, the encircled energy at radius ``r`` is the summed ray
    intensity of all rays landing within ``r`` of the centered spot. It is
    plotted as a function of ``r``.

    Args:
        optic (Optic): The optic for which the Encircled Energy analysis is
            performed.
        fields (str or tuple, optional): The fields for which the analysis is
            performed. Defaults to 'all'.
        wavelength (str or float, optional): The wavelength at which the
            analysis is performed: 'primary', 'all', or a single real number.
            Defaults to 'primary'.
        num_rays (int, optional): The number of rays used for the analysis.
            Note that this value is forwarded to the base class as its
            ``num_rings`` argument. Defaults to 100000.
        distribution (str, optional): The distribution of rays.
            Defaults to 'random'.
        num_points (int, optional): The number of radial sample points used
            for plotting the Encircled Energy curve. Defaults to 256.

    Raises:
        TypeError: If ``wavelength`` is not 'primary', 'all', or a real number.
    """

    def __init__(
        self,
        optic,
        fields="all",
        wavelength="primary",
        num_rays=100_000,
        distribution="random",
        num_points=256,
    ):
        self.num_points = num_points

        # ``bool`` is a subclass of ``int`` and must not be mistaken for a
        # wavelength. ``numbers.Real`` also admits NumPy real scalars
        # (e.g. float32, int64), which a plain ``float | int`` check rejects.
        is_number = isinstance(wavelength, numbers.Real) and not isinstance(
            wavelength, bool
        )
        if is_number:
            # A single numeric wavelength is wrapped in a list for the base
            # class.
            processed_wavelengths = [wavelength]
        elif isinstance(wavelength, str) and wavelength in ("primary", "all"):
            # Named selections are resolved by the base class.
            processed_wavelengths = wavelength
        else:
            raise TypeError(
                f"Unsupported wavelength: {wavelength}. "
                "Expected 'primary', 'all', or a number."
            )

        super().__init__(
            optic,
            fields=fields,
            wavelengths=processed_wavelengths,
            num_rings=num_rays,
            distribution=distribution,
        )

    def view(
        self,
        fig_to_plot_on: Figure | None = None,
        figsize: tuple[float, float] = (7, 4.5),
        *,
        show: bool = True,
    ) -> tuple[Figure, Axes]:
        """Plot the Encircled Energy curve for every field.

        Args:
            fig_to_plot_on (plt.Figure, optional): An existing figure to draw
                into (e.g. one embedded in a GUI). It is cleared first. If
                None, a new figure is created. Defaults to None.
            figsize (tuple, optional): The size of the figure if a new one is
                created. Defaults to (7, 4.5).
            show (bool): If True (default) and a new figure was created,
                calls ``plt.show()``. Set False for headless use.

        Returns:
            tuple: A tuple containing the figure and axes objects.
        """
        is_gui_embedding = fig_to_plot_on is not None
        if is_gui_embedding:
            current_fig = fig_to_plot_on
            current_fig.clear()
            ax = current_fig.add_subplot(111)
        else:
            current_fig, ax = plt.subplots(figsize=figsize)

        data = self._center_spots(self.data)
        # The largest geometric spot radius over all fields sets one shared
        # radial range, so every curve is sampled on the same grid.
        geometric_size = self.geometric_spot_radius()
        axis_lim = be.max(geometric_size)
        for k, field_data in enumerate(data):
            self._plot_field(ax, field_data, self.fields[k], axis_lim, self.num_points)

        ax.legend(bbox_to_anchor=(1.05, 0.5), loc="center left")
        ax.set_xlabel("Radius (mm)")
        ax.set_ylabel("Encircled Energy (-)")
        ax.set_title(self._wavelength_title())
        ax.set_xlim((0, None))
        ax.set_ylim((0, None))
        ax.grid(True)
        current_fig.tight_layout()

        if is_gui_embedding and hasattr(current_fig, "canvas"):
            # Schedule a redraw of the embedding canvas instead of blocking.
            current_fig.canvas.draw_idle()
        if show and not is_gui_embedding:
            plt.show()
        return current_fig, ax

    def _wavelength_title(self):
        """Build the plot title listing the analysed wavelength(s) in µm.

        Returns:
            str: ``"Wavelength: X µm"`` for a single wavelength, otherwise
            ``"Wavelengths: X, Y, ... µm"``.
        """
        # Entries may be objects exposing ``.value`` or plain numbers (a
        # numeric ``wavelength`` argument is forwarded as ``[wavelength]``).
        values = [getattr(wl, "value", wl) for wl in self.wavelengths]
        if len(values) == 1:
            return f"Wavelength: {values[0]:.4f} µm"
        joined = ", ".join(f"{v:.4f}" for v in values)
        return f"Wavelengths: {joined} µm"

    def centroid(self):
        """Calculate the centroid of the Encircled Energy.

        Only the first spot-data entry of each field is used, and the
        uncentered ray coordinates are averaged.

        Returns:
            list: A list of ``(x, y)`` tuples of mean ray coordinates, one per
            field.
        """
        return [
            (be.mean(field_data[0].x), be.mean(field_data[0].y))
            for field_data in self.data
        ]

    def _plot_field(self, ax, field_data, field, axis_lim, num_points, buffer=1.2):
        """Plot the Encircled Energy curve(s) for a specific field.

        One curve is drawn per entry of ``field_data``. All curves share the
        field label.

        Args:
            ax (matplotlib.axes.Axes): The axes on which to plot the curve.
            field_data (list): List of SpotData objects for this field.
            field: Field object whose ``coord`` attribute gives the normalized
                field coordinates ``(Hx, Hy)``.
            axis_lim (float): Maximum spot radius used to size the radial
                range.
            num_points (int): Number of radial sample points.
            buffer (float, optional): Factor applied to ``axis_lim`` to get
                the largest sampled radius. Defaults to 1.2.
        """
        r_max = axis_lim * buffer
        r_step = be.linspace(0, r_max, num_points)
        # The sampled radii and the legend label are the same for every curve.
        r_np = be.to_numpy(r_step)
        Hx, Hy = field.coord
        label = f"Hx: {Hx:.3f}, Hy: {Hy:.3f}"

        for points in field_data:
            # energy and intensity are used interchangeably here
            energy = points.intensity
            radii = be.sqrt(points.x**2 + points.y**2)

            def vectorized_ee(r):
                # Capturing the loop variables is safe: this closure is
                # evaluated immediately below, within the same iteration.
                # NaN intensities are ignored by nansum.
                return be.nansum(energy[radii <= r])  # noqa: B023

            ee = be.vectorize(vectorized_ee)(r_step)
            ax.plot(r_np, be.to_numpy(ee), label=label)

    def _generate_field_data(
        self,
        field,
        wavelength,
        num_rays=100,
        distribution="hexapolar",
        coordinates="local",
    ):
        """Trace one field/wavelength and collect the image-surface rays.

        Args:
            field (tuple): Field coordinates, unpacked into ``optic.trace``.
            wavelength (float): The wavelength.
            num_rays (int, optional): The number of rays. Defaults to 100.
            distribution (str, optional): The distribution of rays.
                Defaults to 'hexapolar'.
            coordinates (str): Coordinate system choice. Accepted for
                signature compatibility but ignored.

        Returns:
            SpotData: SpotData object containing the x, y, and intensity
            arrays read from the last surface.
        """
        self.optic.trace(*field, wavelength, num_rays, distribution)
        x = self.optic.surfaces.x[-1, :]
        y = self.optic.surfaces.y[-1, :]
        intensity = self.optic.surfaces.intensity[-1, :]
        return SpotData(x=x, y=y, intensity=intensity)