from __future__ import annotations
import numpy as np
from scipy.ndimage import zoom
import optiland.backend as be
from .distortion_warper import DistortionWarper
from .psf_basis_generator import PSFBasisGenerator
from .simulator import SpatiallyVariableSimulator
class ImageSimulationEngine:
    """
    Master engine for performing full image simulation including spatially
    variable blur, geometric distortion, and lateral color.

    Args:
        optic (Optic): The optical system model.
        config (dict, optional): Configuration overrides, merged on top of the
            defaults set in ``__init__``. Recognized keys:

            - wavelengths (list[float]): Wavelengths (um), one per output
              channel (e.g. R, G, B). Default ``[0.65, 0.55, 0.45]``.
            - psf_grid_shape (tuple): (ny, nx) for PSF basis generation.
              Default ``(5, 5)``.
            - psf_size (int): Size in pixels of each PSF grid. Default ``128``.
            - num_rays (int): Number of rays for PSF generation. Default ``64``.
            - n_components (int): Number of EigenPSFs. Default ``3``.
            - oversample (int): Upsampling factor for simulation accuracy.
              Default ``1``.
            - padding (int): Pixel padding (guard band) to avoid edge
              artifacts. Default ``64``.
    """

    def __init__(self, optic, config=None):
        self.optic = optic
        # Populated by run(): normalized (B, C, H, W) input and the result.
        self.source_image = None
        self.simulated_image = None
        self.config = {
            "wavelengths": [0.65, 0.55, 0.45],
            "psf_grid_shape": (5, 5),
            "psf_size": 128,
            "num_rays": 64,
            "n_components": 3,
            "oversample": 1,
            "padding": 64,
        }
        if config:
            self.config.update(config)

    def _prepare_source_image(self, source_image):
        """Normalize ``source_image`` to a backend array of shape (B, C, H, W).

        Args:
            source_image: A file path (read with ``matplotlib.image.imread``;
                a 4th alpha channel is dropped) or an array-like of shape
                (H, W), (H, W, 3), or (B, C, H, W) with C in {1, 3}.

        Returns:
            Backend array with shape (B, C, H, W).

        Raises:
            ValueError: If the rank or channel count is unsupported.
        """
        if isinstance(source_image, str):
            import matplotlib.image as mpimg

            image = mpimg.imread(source_image)
            if image.ndim == 3 and image.shape[-1] == 4:
                image = image[:, :, :3]
        else:
            image = source_image
        image = be.array(image)

        if image.ndim == 2:
            return image[None, None, :, :]
        if image.ndim == 3:
            if image.shape[-1] != 3:
                raise ValueError("3D source_image must have shape (H, W, 3).")
            # (H, W, 3) -> (1, 3, H, W)
            return be.transpose(image, (2, 0, 1))[None, :, :, :]
        if image.ndim == 4:
            if image.shape[1] not in (1, 3):
                raise ValueError("4D source_image must have shape (B, C, H, W).")
            return image
        raise ValueError(
            "source_image must have shape (H, W), (H, W, 3), or (B, C, H, W)."
        )

    def run(self, source_image):
        """
        Executes the simulation pipeline.

        For each configured wavelength a PSF basis is generated, the matching
        input channel is blurred with the spatially variable PSF, and the
        result is warped with that wavelength's distortion map (handling
        lateral color).

        Args:
            source_image (ArrayLike | str): The input source image with shape
                (H, W), (H, W, 3), or (B, C, H, W), or a path to an image file.

        Returns:
            be.ndarray: The simulated image batch with shape (B, C, H, W).
            Values follow the input dynamic range (clamped to be non-negative).

        Raises:
            ValueError: If the input shape is unsupported or no wavelengths
                are configured.
        """
        self.source_image = self._prepare_source_image(source_image)

        # 1. Preprocessing: pad and upsample.
        processed_input, pad_info = self._preprocess(self.source_image)
        _, num_channels, height, width = processed_input.shape
        wavelengths = self.config["wavelengths"]

        if num_channels == 1 and len(wavelengths) == 3:
            # Grayscale input with 3 wavelengths -> simulate an RGB result.
            input_channels = [processed_input[:, 0, :, :]] * 3
        else:
            # Otherwise pair channels with wavelengths 1-to-1.
            input_channels = [
                processed_input[:, c, :, :]
                for c in range(min(num_channels, len(wavelengths)))
            ]
        if not input_channels:
            raise ValueError(
                "config['wavelengths'] must contain at least one wavelength."
            )

        # 2. Per-wavelength simulation: blur, then distort.
        sim = SpatiallyVariableSimulator()
        warper = DistortionWarper(self.optic)
        processed_channels = []
        for wave, channel_img in zip(wavelengths, input_channels):
            gen = PSFBasisGenerator(
                self.optic,
                wavelength=wave,
                grid_shape=self.config["psf_grid_shape"],
                num_rays=self.config["num_rays"],
                psf_grid_size=self.config["psf_size"],
            )
            eigen_psfs, coeffs, mean_psf = gen.generate_basis(
                n_components=self.config["n_components"]
            )
            # Coefficient maps are sampled on the PSF grid; resize to image.
            coeffs_resized = gen.resize_coefficient_map(coeffs, (height, width))
            blurred = sim.simulate(channel_img, eigen_psfs, coeffs_resized, mean_psf)
            # Wavelength-specific distortion map (handles lateral color).
            dist_map = warper.generate_distortion_map(wave, (height, width))
            processed_channels.append(warper.warp_image(blurred, dist_map))

        final_output = be.stack(processed_channels, axis=1)

        # 3. Postprocessing: downsample and crop.
        result = self._postprocess(final_output, pad_info)
        self.simulated_image = result
        return result

    @staticmethod
    def _to_display(img):
        """Convert a (C, H, W) backend image to a NumPy array for ``imshow``.

        Values above 2.0 are assumed to be 0-255 and rescaled. The result is
        clipped to [0, 1], and a single channel is squeezed to 2-D.
        """
        arr = be.to_numpy(be.transpose(img, (1, 2, 0)))
        if arr.max() > 2.0:
            arr = arr / 255.0
        arr = np.clip(arr, 0, 1)
        if arr.shape[-1] == 1:
            arr = arr[:, :, 0]
        return arr

    def view(self, index: int = 0, *, show: bool = True):
        """
        Visualizes one original and simulated image side-by-side from the batch.

        Args:
            index (int): Batch index to visualize.
            show (bool): If True (default), calls plt.show(). Set False for
                headless use.

        Returns:
            tuple: ``(fig, ax)`` from ``plt.subplots(1, 2)``.

        Raises:
            RuntimeError: If run() has not been called yet.
            IndexError: If ``index`` is outside the batch.
        """
        if self.source_image is None or self.simulated_image is None:
            raise RuntimeError("Call run(source_image) before view().")
        batch_size = self.source_image.shape[0]
        if index < 0 or index >= batch_size:
            raise IndexError(f"index must be between 0 and {batch_size - 1}.")

        import matplotlib.pyplot as plt

        src_np = self._to_display(self.source_image[index])
        sim_np = self._to_display(self.simulated_image[index])

        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        ax[0].imshow(src_np, cmap="gray" if src_np.ndim == 2 else None)
        ax[0].set_title(f"Original Image [{index}]")
        ax[0].axis("off")
        ax[1].imshow(sim_np, cmap="gray" if sim_np.ndim == 2 else None)
        ax[1].set_title(f"Simulated Image [{index}]")
        ax[1].axis("off")
        plt.tight_layout()
        if show:
            plt.show()
        return fig, ax

    def _preprocess(self, image):
        """Reflect-pad and optionally upsample a (B, C, H, W) image.

        Padding of ``config['padding']`` pixels is applied to H and W only.
        For ``config['oversample'] > 1`` the image is upsampled bilinearly with
        ``scipy.ndimage.zoom`` on a NumPy copy.

        Returns:
            tuple: ``(image, (pad, scale))``; the second item is consumed by
            ``_postprocess``.
        """
        pad = self.config["padding"]
        image = be.pad(
            image,
            ((0, 0), (0, 0), (pad, pad), (pad, pad)),
            mode="reflect",
        )

        scale = self.config["oversample"]
        if scale > 1:
            image_np = be.to_numpy(image)
            upsampled_np = zoom(image_np, (1, 1, scale, scale), order=1)
            image = be.array(upsampled_np)
        return image, (pad, scale)

    @staticmethod
    def _is_torch_tensor(x):
        """Return True if ``x`` is a ``torch.Tensor`` (never imports torch)."""
        import sys

        torch = sys.modules.get("torch")
        return torch is not None and isinstance(x, torch.Tensor)

    def _postprocess(self, image, pad_info):
        """Downsample back to the padded source size and crop the padding.

        Args:
            image: Simulated (B, C, H', W') batch at the oversampled resolution.
            pad_info (tuple): ``(pad, scale)`` as returned by ``_preprocess``.

        Returns:
            Array with the spatial size of ``self.source_image``, clamped to be
            non-negative.
        """
        pad, scale = pad_info
        target_h, target_w = self.source_image.shape[-2:]

        if scale > 1:
            # Resample to the exact pre-upsampling padded size instead of a
            # 1/scale factor whose float rounding can be off by one pixel.
            padded_h = target_h + 2 * pad
            padded_w = target_w + 2 * pad
            if self._is_torch_tensor(image):
                import torch.nn.functional as F

                # align_corners=True matches zoom's default corner-aligned grid
                # (grid_mode=False) used for upsampling in _preprocess, so the
                # torch and NumPy paths resample identically.
                image = F.interpolate(
                    image,
                    size=(padded_h, padded_w),
                    mode="bilinear",
                    align_corners=True,
                )
            else:
                image_np = be.to_numpy(image)
                factors = (
                    1,
                    1,
                    padded_h / image_np.shape[-2],
                    padded_w / image_np.shape[-1],
                )
                image = be.array(zoom(image_np, factors, order=1))

        crop = image[:, :, pad : pad + target_h, pad : pad + target_w]
        # Clamp small negative values introduced by interpolation/convolution.
        return be.maximum(crop, 0.0)