Source code for optiland.rays.polarized_rays

"""Polarized Rays

This module contains the `PolarizedRays` class, which represents polarized
rays in three-dimensional space. The class inherits from the `RealRays`
class.

Kramer Harrison, 2024
"""

from __future__ import annotations

import optiland.backend as be
from optiland.rays.polarization_state import PolarizationState
from optiland.rays.real_rays import RealRays


class PolarizedRays(RealRays):
    """Represents polarized rays in three-dimensional space.

    Inherits from the `RealRays` class.

    Attributes:
        x (ndarray): The x-coordinates of the rays.
        y (ndarray): The y-coordinates of the rays.
        z (ndarray): The z-coordinates of the rays.
        L (ndarray): The x-components of the direction vectors of the rays.
        M (ndarray): The y-components of the direction vectors of the rays.
        N (ndarray): The z-components of the direction vectors of the rays.
        i (ndarray): The intensity of the rays.
        w (ndarray): The wavelength of the rays.
        opd (ndarray): The optical path length of the rays.
        p (be.ndarray): Array of 3x3 polarization matrices, one per ray.

    Methods:
        get_output_field(E: be.ndarray) -> be.ndarray: Compute the output
            electric field given the input electric field.
        get_exit_fields(state: PolarizationState | None) -> list[be.ndarray]:
            Compute the intensity-scaled exit electric field(s).
        update_intensity(state: PolarizationState | None): Update the ray
            intensity based on the polarization state.
        get_local_basis(k0, k1): Build the local (s, p, k) vectors and the
            projection matrices for an interaction.
        update(jones_matrix: be.ndarray = None): Update the polarization
            matrices after interaction with a surface.
        _get_3d_electric_field(state: PolarizationState) -> be.ndarray: Get
            the 3D electric fields given the polarization state and initial
            rays.
    """

    def __init__(self, x, y, z, L, M, N, intensity, wavelength):
        super().__init__(x, y, z, L, M, N, intensity, wavelength)
        # Accumulated polarization matrix per ray, starting from identity.
        self.p = be.tile(be.eye(3), (be.size(self.x), 1, 1))
        # Snapshot the initial intensity and direction cosines. Copy them from
        # the attributes populated by ``RealRays.__init__`` rather than from
        # the raw constructor arguments, so the snapshots match whatever array
        # conversion the base class applies (e.g. to list or scalar inputs).
        self._i0 = be.copy(self.i)
        self._L0 = be.copy(self.L)
        self._M0 = be.copy(self.M)
        self._N0 = be.copy(self.N)

    def get_output_field(self, E: be.ndarray) -> be.ndarray:
        """Compute the output electric field given the input electric field.

        Args:
            E (be.ndarray): The input 3D electric field(s), e.g. as returned
                by ``_get_3d_electric_field``.

        Returns:
            be.ndarray: The output electric field obtained by applying each
            ray's accumulated polarization matrix ``self.p`` to ``E``.
        """
        return be.mult_p_E(self.p, E)

    def _compute_unscaled_exit_fields(
        self, state: PolarizationState | None
    ) -> list[be.ndarray]:
        """Compute the unscaled exit electric field(s) for the rays.

        Args:
            state (PolarizationState | None): The polarization state. ``None``
                or a state with ``is_polarized`` false is treated as
                unpolarized light.

        Returns:
            list[be.ndarray]: One unscaled 3D field per input state: a single
            entry for polarized light, or two entries (x- then y-polarized
            linear states) for unpolarized light.
        """
        if state is not None and state.is_polarized:
            input_states = [state]
        else:
            # Unpolarized light is represented by two orthogonal linear
            # polarization states whose intensities are later summed.
            input_states = [
                PolarizationState(
                    is_polarized=True, Ex=1.0, Ey=0.0, phase_x=0.0, phase_y=0.0
                ),
                PolarizationState(
                    is_polarized=True, Ex=0.0, Ey=1.0, phase_x=0.0, phase_y=0.0
                ),
            ]
        return [
            self.get_output_field(self._get_3d_electric_field(input_state))
            for input_state in input_states
        ]

    def get_exit_fields(self, state: PolarizationState | None) -> list[be.ndarray]:
        """Compute the exit electric field(s) for the rays.

        Args:
            state (PolarizationState | None): The polarization state. ``None``
                or a non-polarized state is treated as unpolarized light.

        Returns:
            list[be.ndarray]: A list of 3D electric field arrays. For polarized
            light, the list contains a single array scaled by ``sqrt(i0)``,
            where ``i0`` is the initial ray intensity. For unpolarized light,
            the list contains two orthogonal fields, to be superimposed
            incoherently, each scaled by ``sqrt(i0 / 2)`` so the initial
            intensity is split equally between them.
        """
        fields = self._compute_unscaled_exit_fields(state)
        scale_factor = be.unsqueeze_last(be.sqrt(self._i0 / len(fields)))
        return [field * scale_factor for field in fields]

    def update_intensity(self, state: PolarizationState | None):
        """Update ray intensity based on polarization state.

        The new intensity is the sum of ``|E|^2`` over all exit fields (an
        incoherent sum for unpolarized light), weighted by the initial
        intensity and divided by the number of fields.

        Args:
            state (PolarizationState | None): The polarization state of the
                ray; ``None`` is treated as unpolarized light.
        """
        fields = self._compute_unscaled_exit_fields(state)
        intensity = be.zeros_like(self.i)
        for field in fields:
            intensity = intensity + be.sum(be.abs(field) ** 2, axis=1)
        self.i = intensity * self._i0 / len(fields)

    @staticmethod
    def get_local_basis(
        k0: be.ndarray, k1: be.ndarray
    ) -> tuple[be.ndarray, be.ndarray, be.ndarray, be.ndarray, be.ndarray]:
        """Get the local s, p0, p1 vectors and transforming matrices.

        Args:
            k0: (N, 3) array of pre-interaction ray directions.
            k1: (N, 3) array of post-interaction ray directions.

        Returns:
            tuple: (s, p0, p1, o_in, o_out) where s, p0, p1 are (N, 3) vectors
            and o_in, o_out are the (N, 3, 3) projection matrices. ``o_in``
            has (s, p0, k0) as its rows; ``o_out`` has (s, p1, k1) as its
            columns.
        """
        # s-direction: normal to the plane containing k0 and k1.
        s = be.cross(k0, k1)
        mag = be.linalg.norm(s, axis=1)

        # When k0 is parallel or anti-parallel to k1, the cross product is zero
        # and the plane is undefined. Pick any direction perpendicular to k0:
        # build p from k0 x x_hat, or from k0 x y_hat if k0 lies along x.
        mask = mag == 0
        if be.any(mask):
            k0_par = k0[mask]
            x_hat = be.broadcast_to(be.array([1.0, 0.0, 0.0]), k0_par.shape)
            y_hat = be.broadcast_to(be.array([0.0, 1.0, 0.0]), k0_par.shape)
            p_fallback = be.cross(k0_par, x_hat)
            p_norms = be.linalg.norm(p_fallback, axis=1)
            p_fallback = be.where(
                be.unsqueeze_last(p_norms == 0),
                be.cross(k0_par, y_hat),
                p_fallback,
            )
            s[mask] = be.cross(p_fallback, k0_par)
            mag = be.linalg.norm(s, axis=1)
        s = s / be.unsqueeze_last(mag)

        # p-components before and after the interaction
        p0 = be.cross(k0, s)
        p1 = be.cross(k1, s)

        # orthogonal transformation matrices
        o_in = be.stack((s, p0, k0), axis=1)
        o_out = be.stack((s, p1, k1), axis=2)
        return s, p0, p1, o_in, o_out

    def update(self, jones_matrix: be.ndarray | None = None):
        """Update polarization matrices after interaction with surface.

        Args:
            jones_matrix (be.ndarray, optional): Jones matrix representing the
                interaction with the surface, applied between the input and
                output (s, p, k) bases. If not provided, the polarization
                matrix is computed assuming an identity matrix.
        """
        # NOTE: ``L0``/``M0``/``N0`` (no underscore) hold the directions before
        # the latest interaction. They are not set in this class (presumably
        # maintained by ``RealRays``) and differ from the initial directions
        # ``_L0``/``_M0``/``_N0`` captured in ``__init__``.
        k0 = be.stack([self.L0, self.M0, self.N0]).T
        k1 = be.stack([self.L, self.M, self.N]).T
        _, _, _, o_in, o_out = self.get_local_basis(k0, k1)

        # polarization matrix contributed by this surface
        if jones_matrix is None:
            p = be.matmul(o_out, o_in)
        else:
            p = be.batched_chain_matmul3(o_out, jones_matrix, o_in)

        # accumulate onto the rays' existing polarization matrices
        self.p = be.matmul(p, self.p)

    def _get_3d_electric_field(self, state: PolarizationState) -> be.ndarray:
        """Get 3D electric fields given polarization state and initial rays.

        The field is expressed in a transverse basis built from the *initial*
        ray directions k: ``p = normalize(k x x_hat)`` and ``s = p x k``.
        ``Ex``/``phase_x`` are applied along ``s`` and ``Ey``/``phase_y``
        along ``p``.

        Args:
            state (PolarizationState): The polarization state of the rays.

        Returns:
            be.ndarray: The complex 3D electric fields, one row per ray.

        Raises:
            ValueError: If any initial k-vector is parallel to the x-axis.
        """
        k = be.stack([self._L0, self._M0, self._N0]).T

        # TODO - efficiently handle case when k parallel to x-axis
        x = be.broadcast_to(be.array([1.0, 0.0, 0.0]), k.shape)
        p = be.cross(k, x)
        norms = be.linalg.norm(p, axis=1)
        if be.any(norms == 0):
            raise ValueError("k-vector parallel to x-axis is not currently supported.")
        p = p / be.unsqueeze_last(norms)
        s = be.cross(p, k)

        E = (
            state.Ex * be.exp(1j * state.phase_x) * s
            + state.Ey * be.exp(1j * state.phase_y) * p
        )
        return E