Source code for ml.wrappers

"""Machine Learning Wrappers

This module contains wrappers for integrating Optiland with PyTorch. The
OpticalSystemModule class is a key component that allows users to define and optimize
optical systems within the PyTorch ecosystem. The `forward` method can be customized or
overridden to implement specific optical system behaviors.

Kramer Harrison, 2025
"""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

try:
    import torch
    import torch.nn as nn
except ImportError:
    torch = None
    nn = None

import optiland.backend as be

if TYPE_CHECKING:
    from collections.abc import Callable

    from optiland.optic import Optic
    from optiland.optimization.problem import OptimizationProblem


[docs] class OpticalSystemModule(nn.Module if nn is not None else object): """ A PyTorch nn.Module that wraps an Optiland OptimizationProblem. This class exposes the optical system's variables as trainable nn.Parameters, allowing the entire system to be integrated as a differentiable layer into larger machine learning models. Args: optic (Optic): The optical system definition. problem (OptimizationProblem): The optimization problem defining variables and objectives. objective_fn (Callable[[], torch.Tensor] | None): An optional callable tha takes no arguments and returns a scalar PyTorch tensor representing the loss or metric to be optimized. If None, problem.sum_squared() is used as default. """ def __init__( self, optic: Optic, problem: OptimizationProblem, objective_fn: Callable[[], torch.Tensor] | None = None, ): super().__init__() if torch is None: raise RuntimeError( "OpticalSystemModule requires the 'torch' package. " "Install PyTorch to use this class." ) if be.get_backend() != "torch": raise RuntimeError("OpticalSystemModule requires the 'torch' backend.") # Ensure gradients are enabled for PyTorch operations if not be.grad_mode.requires_grad: warnings.warn("Gradient tracking is enabled for PyTorch.", stacklevel=2) be.grad_mode.enable() self.optic = optic self.problem = problem # Initialize parameters as torch.nn.Parameter objects initial_params = [var.value for var in self.problem.variables] self.params = nn.ParameterList( [torch.nn.Parameter(be.array(p)) for p in initial_params] ) # Store the original variable definitions self._original_variables = self.problem.variables # Store the user-provided objective function or set a default self.objective_fn = ( objective_fn if objective_fn is not None else self._default_loss ) def _default_loss(self) -> torch.Tensor: """ The default loss function, which computes the sum of squared errors from the provided optimization problem. Returns: torch.Tensor: The computed loss value. """ return self.problem.sum_squared() def _sync_params_to_problem(self): """ Pushes the current tensor values from the nn.Parameters into the problem variables. This operation is part of the computation graph. """ for i, param in enumerate(self.params): var = self._original_variables[i] var.update(param)
[docs] def apply_bounds(self): """ Applies the defined bounds to the parameters in-place. This should be called after each optimizer step to enforce constraints. """ with torch.no_grad(): # Operations here shouldn't be part of the gradient graph for i, param in enumerate(self.params): var = self._original_variables[i] min_val, max_val = var.bounds # Inverse scale the parameter data min_val = ( var.variable.inverse_scale(min_val) if min_val is not None else None ) max_val = ( var.variable.inverse_scale(max_val) if max_val is not None else None ) # Clamp the parameter data to the defined bounds if min_val is not None and max_val is not None: param.data.clamp_(min_val, max_val) elif min_val is not None: param.data.clamp_(min=min_val) elif max_val is not None: param.data.clamp_(max=max_val)
[docs] def forward(self) -> torch.Tensor: """ Defines the forward pass of the optical system. This implementation synchronizes the PyTorch parameters with the Optiland problem variables, updates the optics, and then computes the loss using either the user-provided objective function or the default sum of squared errors. The output is a differentiable scalar tensor. Users are encouraged to customize or override this method to suit their specific optimization objectives. """ # 1. Synchronize the nn.Parameter values with the Optiland problem variables. self._sync_params_to_problem() # 2. Update dependent properties within the optical system self.problem.update_optics() # 3. Compute the objective using the stored objective_fn loss = self.objective_fn() return loss