from __future__ import annotations
import numpy as np
from scipy.ndimage import zoom
import optiland.backend as be
from optiland.psf.fft import FFTPSF
[docs]
class PSFBasisGenerator:
"""
Generates a basis of EigenPSFs for an optical system using SVD.
This class computes a grid of PSFs and decomposes them to find the
EigenPSFs that efficiently represent the spatially varying blur.
Args:
optic (Optic): The optical system to analyze.
wavelength (float): The wavelength for PSF calculation.
grid_shape (tuple, optional): (ny, nx) size of sampling grid. Default: (5, 5).
num_rays (int, optional): Number of rays for pupil sampling. Default: 128.
psf_grid_size (int, optional): PSF grid size (e.g., 256 for 256x256).
If None, calculated from num_rays.
"""
def __init__(
self, optic, wavelength, grid_shape=(5, 5), num_rays=128, psf_grid_size=None
):
self.optic = optic
self.wavelength = wavelength
self.grid_shape = grid_shape
self.num_rays = num_rays
self.psf_grid_size = psf_grid_size
[docs]
def generate_basis(self, n_components=3):
"""
Computes the EigenPSFs and their corresponding coefficient maps.
Args:
n_components (int): Number of principal components (EigenPSFs) to keep.
Returns:
Returns:
tuple:
- eigen_psfs (be.ndarray): Basis PSFs, shape (n_components, H, W).
- coefficient_grid (be.ndarray): Coefficient maps on low-res grid,
shape (n_components, grid_ny, grid_nx).
- mean_psf (be.ndarray): Average PSF across field, shape (H, W).
"""
# 1. Generate Grid of PSFs
psf_stack = self._compute_psf_grid()
n_psfs, h, w = psf_stack.shape
# 2. Flatten for PCA
# X shape: (n_samples, n_features) -> (n_psfs, h*w)
X = be.reshape(psf_stack, (n_psfs, -1))
# 3. Center the data
mean_psf_flat = be.mean(X, axis=0)
X_centered = X - mean_psf_flat
# 4. SVD Decomposition
# X_centered = U * S * Vt
try:
U, S, Vt = be.linalg.svd(X_centered, full_matrices=False)
except AttributeError:
# Fallback if be.linalg is not directly exposed as expected
if be.get_backend() == "torch":
import torch
U, S, Vt = torch.linalg.svd(X_centered, full_matrices=False)
else:
import numpy as np
U, S, Vt = np.linalg.svd(X_centered, full_matrices=False)
# 5. Extract top N components
eigen_psfs_flat = Vt[:n_components]
coeffs_flat = (
U[:, :n_components] * S[:n_components]
) # (n_samples, n_components)
# 6. Reshape results
eigen_psfs = be.reshape(eigen_psfs_flat, (n_components, h, w))
mean_psf = be.reshape(mean_psf_flat, (h, w))
# Reshape coeffs back to the spatial grid (n_components, grid_ny, grid_nx)
coeffs_t = be.transpose(coeffs_flat, (1, 0))
coefficient_grid = be.reshape(
coeffs_t, (n_components, self.grid_shape[0], self.grid_shape[1])
)
return eigen_psfs, coefficient_grid, mean_psf
def _compute_psf_grid(self):
"""Generates the stack of PSFs across the field of view."""
psfs = []
ny, nx = self.grid_shape
# Iterate over normalized field coordinates [-1, 1]
ys = np.linspace(-1, 1, ny)
xs = np.linspace(-1, 1, nx)
for y in ys:
for x in xs:
field = (x, y)
# Compute PSF
psf_calc = FFTPSF(
optic=self.optic,
field=field,
wavelength=self.wavelength,
num_rays=self.num_rays,
grid_size=self.psf_grid_size,
)
# Normalize sum to 1 to treat as probability distribution
raw_psf = psf_calc.psf
norm_psf = raw_psf / be.sum(raw_psf)
psfs.append(norm_psf)
return be.stack(psfs)
[docs]
@staticmethod
def resize_coefficient_map(coeff_map, target_shape):
"""
Resizes the coefficient map to the target shape using bicubic interpolation.
Args:
coeff_map (be.ndarray): Input map (H_in, W_in) or (C, H_in, W_in).
target_shape (tuple): (H_out, W_out).
Returns:
be.ndarray: Resized map.
"""
# Check backend
backend = be.get_backend()
if backend == "torch":
import torch.nn.functional as F
is_3d = coeff_map.ndim == 3
if not is_3d:
coeff_map = coeff_map.unsqueeze(0) # (1, H, W)
inp = coeff_map.unsqueeze(0)
out = F.interpolate(
inp, size=target_shape, mode="bicubic", align_corners=False
)
out = out.squeeze(0)
if not is_3d:
out = out.squeeze(0)
return out
else:
input_shape = coeff_map.shape
if len(input_shape) == 3:
# (C, H, W) -> zoom last two dims only
h, w = input_shape[-2:]
zoom_h = target_shape[0] / h
zoom_w = target_shape[1] / w
return be.array(
zoom(be.to_numpy(coeff_map), (1, zoom_h, zoom_w), order=1)
)
else:
h, w = input_shape
zoom_h = target_shape[0] / h
zoom_w = target_shape[1] / w
return be.array(zoom(be.to_numpy(coeff_map), (zoom_h, zoom_w), order=1))