"""Image simulation engine (``optiland.analysis.image_simulation.engine``)."""

from __future__ import annotations

import numpy as np
from scipy.ndimage import zoom

import optiland.backend as be

from .distortion_warper import DistortionWarper
from .psf_basis_generator import PSFBasisGenerator
from .simulator import SpatiallyVariableSimulator


class ImageSimulationEngine:
    """
    Master engine for performing full image simulation including spatially
    variable blur, geometric distortion, and lateral color.

    Args:
        optic (Optic): The optical system model.
        source_image (ArrayLike | str): The input source image (H, W, 3),
            (H, W, 4) or (H, W), or a path to an image file readable by
            ``matplotlib.image.imread``. Expected to be in RGB format if 3
            channels; a fourth (alpha) channel is discarded.
        config (dict, optional): Configuration overrides merged into the
            defaults. Recognized keys:
            - wavelengths (list[float]): Wavelengths (um) simulated per
              channel; the default approximates R, G, B.
            - psf_grid_shape (tuple): (ny, nx) for PSF basis generation.
            - psf_size (int): Pixel size for PSFs.
            - num_rays (int): Number of rays for PSF generation.
            - n_components (int): Number of EigenPSFs.
            - oversample (int): Upsampling factor for simulation accuracy.
            - padding (int): Pixel padding (guard band) to avoid edge
              artifacts.
    """

    def __init__(self, optic, source_image, config=None):
        self.optic = optic
        self.simulated_image = None

        # Load image if given a path string
        if isinstance(source_image, str):
            import matplotlib.image as mpimg

            img = mpimg.imread(source_image)
        else:
            img = source_image

        # Convert to a backend array and normalize the layout to (C, H, W)
        img = be.array(img)
        if img.ndim == 3 and img.shape[2] == 4:
            # Drop the alpha channel (RGBA -> RGB) for both file-loaded and
            # directly supplied arrays; otherwise an (H, W, 4) array would
            # never be transposed and would be misread as channel-first.
            img = img[:, :, :3]
        if img.ndim == 3 and img.shape[2] == 3:
            # (H, W, 3) -> (3, H, W)
            img = be.transpose(img, (2, 0, 1))
        elif img.ndim == 2:
            # Monochromatic/Grayscale -> (1, H, W)
            img = img[None, :, :]
        self.source_image = img

        # Default config
        self.config = {
            "wavelengths": [0.65, 0.55, 0.45],  # R, G, B standard approx
            "psf_grid_shape": (5, 5),
            "psf_size": 128,
            "num_rays": 64,  # Optimized for performance (was 128)
            "n_components": 3,
            "oversample": 1,
            "padding": 64,
        }
        if config:
            self.config.update(config)

    def run(self):
        """
        Executes the simulation pipeline.

        The source image is padded and optionally upsampled. Each channel is
        blurred with a spatially variable PSF model and warped with the
        distortion map for its wavelength. The result is then downsampled
        and cropped back to the source size.

        Returns:
            be.ndarray: The simulated image with shape (H, W, C), where C is
            the number of simulated channels (C may be 1). Values are defined
            by the input dynamic range, clamped below at 0.

        Raises:
            ValueError: If ``config["wavelengths"]`` is empty.
        """
        wavelengths = list(self.config["wavelengths"])
        if not wavelengths:
            raise ValueError(
                "config['wavelengths'] must contain at least one wavelength"
            )

        # 1. Preprocessing: pad (guard band) and upsample
        processed_input, pad_info = self._preprocess(self.source_image)
        num_channels, height, width = processed_input.shape

        if num_channels == 1 and len(wavelengths) == 3:
            # Grayscale input with 3 wavelengths: reuse the single channel
            # for each wavelength to produce an RGB result.
            input_channels = [processed_input[0]] * 3
        else:
            # Match input channels to wavelengths 1-to-1; surplus channels
            # or surplus wavelengths are ignored.
            input_channels = [
                processed_input[c]
                for c in range(min(num_channels, len(wavelengths)))
            ]

        # 2. Simulation loop per channel
        processed_channels = [
            self._simulate_channel(wave, channel_img, (height, width))
            for wave, channel_img in zip(wavelengths, input_channels, strict=False)
        ]
        final_output = be.stack(processed_channels, axis=0)

        # 3. Postprocessing: downsample and crop away the padding
        result = self._postprocess(final_output, pad_info)

        # Return (H, W, C) for image format compatibility
        if result.ndim == 3:
            result = be.transpose(result, (1, 2, 0))

        self.simulated_image = result
        return result

    def _simulate_channel(self, wave, channel_img, shape):
        """Blur and warp one channel image at wavelength ``wave`` (um).

        ``shape`` is the (H, W) size of the padded/upsampled channel; the
        PSF coefficient maps and the distortion map are generated at it.
        """
        # A. Basis generation
        gen = PSFBasisGenerator(
            self.optic,
            wavelength=wave,
            grid_shape=self.config["psf_grid_shape"],
            num_rays=self.config["num_rays"],
            psf_grid_size=self.config["psf_size"],
        )
        eigen_psfs, coeffs, mean_psf = gen.generate_basis(
            n_components=self.config["n_components"]
        )
        # Resize coeffs to image size
        coeffs_resized = gen.resize_coefficient_map(coeffs, shape)

        # B. Convolution (blur)
        sim = SpatiallyVariableSimulator()
        blurred = sim.simulate(channel_img, eigen_psfs, coeffs_resized, mean_psf)

        # C. Distortion (warp); the map is generated per wavelength, which
        # is how lateral color is handled.
        warper = DistortionWarper(self.optic)
        dist_map = warper.generate_distortion_map(wave, shape)
        return warper.warp_image(blurred, dist_map)

    def view(self, force_rerun=False):
        """
        Visualizes the original and simulated images side-by-side.

        Runs the simulation if it hasn't been run yet or if force_rerun is
        True. Single-channel images are displayed with a gray colormap.

        Returns:
            tuple: The matplotlib ``(fig, ax)`` pair.
        """
        if self.simulated_image is None or force_rerun:
            self.run()

        import matplotlib.pyplot as plt

        # Source is stored as (C, H, W); convert to (H, W, C) for display
        src = self.source_image
        if src.ndim == 3:
            src = be.transpose(src, (1, 2, 0))

        src_np = self._to_display(be.to_numpy(src))
        sim_np = self._to_display(be.to_numpy(self.simulated_image))

        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        panels = (
            (ax[0], src_np, "Original Image"),
            (ax[1], sim_np, "Simulated Image"),
        )
        for axis, image, title in panels:
            axis.imshow(image, cmap="gray" if image.ndim == 2 else None)
            axis.set_title(title)
            axis.axis("off")

        plt.tight_layout()
        plt.show()
        return fig, ax

    @staticmethod
    def _to_display(image):
        """Prepare a NumPy image, (H, W, C) or (H, W), for ``imshow``.

        A trailing singleton channel is squeezed to (H, W) so it is drawn
        with the gray colormap. If the maximum exceeds 2.0 the data is
        assumed to be in the 0-255 range and divided by 255. The result is
        clipped to [0, 1].
        """
        arr = np.asarray(image)
        if arr.ndim == 3 and arr.shape[-1] == 1:
            arr = arr[..., 0]
        # Guard the max() call: it raises on an empty array.
        if arr.size and arr.max() > 2.0:
            arr = arr / 255.0
        return np.clip(arr, 0, 1)

    def _preprocess(self, image):
        """Pads and upsamples a (C, H, W) image.

        Spatial axes are padded by ``config["padding"]`` pixels using
        reflection. If ``config["oversample"] > 1``, they are then upsampled
        by that factor with linear interpolation.

        Returns:
            tuple: ``(processed_image, (pad, scale))``; the second item is
            consumed by ``_postprocess``.
        """
        pad = self.config["padding"]
        # Padding: ((0,0), (pad, pad), (pad, pad)) for (C, H, W)
        image_np = be.to_numpy(image)
        padded_np = np.pad(image_np, ((0, 0), (pad, pad), (pad, pad)), mode="reflect")

        # Upsampling
        scale = self.config["oversample"]
        if scale > 1:
            upsampled_np = zoom(padded_np, (1, scale, scale), order=1)
        else:
            upsampled_np = padded_np

        return be.array(upsampled_np), (pad, scale)

    def _postprocess(self, image, pad_info):
        """Downsamples and crops the image.

        Undoes ``_preprocess``: downsamples by ``1 / scale`` with linear
        interpolation (when ``scale > 1``), crops the ``pad`` border back to
        the source (H, W), and clamps negative values to 0.
        """
        pad, scale = pad_info

        # Downsample
        if scale > 1:
            image_np = be.to_numpy(image)
            downsampled_np = zoom(image_np, (1, 1 / scale, 1 / scale), order=1)
            image = be.array(downsampled_np)

        target_h, target_w = self.source_image.shape[-2:]
        start_y = pad
        start_x = pad
        crop = image[:, start_y : start_y + target_h, start_x : start_x + target_w]

        # Ensure values are within valid range (prevent small negative values)
        crop = be.maximum(crop, 0.0)
        return crop