Source code for optimization.optimizer.torch.base

"""Torch Base Optimizer

This module contains a base class for all PyTorch-based optimizers.

Kramer Harrison, 2025
"""

from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from types import SimpleNamespace
from typing import TYPE_CHECKING

try:
    import torch
    import torch.optim as optim
    from torch.optim.lr_scheduler import ExponentialLR, LRScheduler
except (ImportError, ModuleNotFoundError):
    torch = None
    optim = None
    ExponentialLR = None
    LRScheduler = None

import optiland.backend as be

from ..base import BaseOptimizer
from ..live_plotter import LiveOptimizationPlotter

if TYPE_CHECKING:
    from collections.abc import Callable

    from ...problem import OptimizationProblem


[docs] class TorchBaseOptimizer(BaseOptimizer, ABC): """ A base class for all PyTorch-based optimizers. This class handles the common setup and optimization loop logic for any optimizer using the PyTorch backend. """ def __init__(self, problem: OptimizationProblem): super().__init__(problem) if be.get_backend() != "torch": raise RuntimeError( f"{self.__class__.__name__} requires the 'torch' backend." ) # Ensure gradients are enabled for PyTorch operations if not be.grad_mode.requires_grad: warnings.warn("Gradient tracking is enabled for PyTorch.", stacklevel=2) be.grad_mode.enable() # Initialize parameters as torch.nn.Parameter objects # Use var.value (scaled) to match the scaled bounds from var.bounds initial_params = [var.value for var in self.problem.variables] self.params = [torch.nn.Parameter(be.array(p)) for p in initial_params] @abstractmethod def _create_optimizer_and_scheduler( self, lr: float, gamma: float ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: """ Creates and returns the specific PyTorch optimizer and learning rate scheduler. Subclasses must implement this method. Args: lr (float): The learning rate. gamma (float): The decay factor for the learning rate. Returns: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: The optimizer and learning rate scheduler. """ raise NotImplementedError def _apply_bounds(self): """ Applies the defined bounds to the parameters in-place. This is called after each optimizer step to enforce constraints. """ with torch.no_grad(): for i, param in enumerate(self.params): var = self.problem.variables[i] min_val, max_val = var.bounds # Clamp the parameter data to the defined bounds if min_val is not None and max_val is not None: param.data.clamp_(min_val, max_val) elif min_val is not None: param.data.clamp_(min=min_val) elif max_val is not None: param.data.clamp_(max=max_val)
[docs] def optimize( self, n_steps: int = 100, lr: float = 1e-2, gamma: float = 0.99, disp: bool = True, plot: bool = False, callback: Callable[[int, float], None] | None = None, ): """ Runs the optimization loop. Args: n_steps (int): The number of optimization steps. lr (float): The learning rate. gamma (float): The decay factor for the learning rate. disp (bool): Whether to display progress. plot: If True, update live plots during optimization. callback (Callable[[int, float], None] | None): A callback function to be called after each step with the current step and loss value. """ optimizer, scheduler = self._create_optimizer_and_scheduler(lr, gamma) live_plotter: LiveOptimizationPlotter | None = None if plot: live_plotter = LiveOptimizationPlotter(self) live_plotter.initialize() with be.grad_mode.temporary_enable(): for i in range(n_steps): optimizer.zero_grad() # 1. Update the model state from the current nn.Parameter values. # Use var.update() which inverse-scales from scaled space for k, param in enumerate(self.params): self.problem.variables[k].update(param) # 2. Update any dependent properties. self.problem.update_optics() # 3. Compute loss from the updated model. loss = self.problem.sum_squared() # 4. Backpropagate and step. loss.backward() optimizer.step() # 5. Enforce constraints on the scaled parameters. self._apply_bounds() # 6. Step the learning rate scheduler. scheduler.step() # 7. Call the user-provided callback if callback: callback(i, loss.item()) if live_plotter is not None: live_plotter.update() # 8. Print loss if display is enabled. if disp and (i % 10 == 0 or i == n_steps - 1): print(f" Step {i + 1:04d}/{n_steps}, Loss: {loss.item():.6f}") # Final update to ensure the model reflects the last optimized state for k, param in enumerate(self.params): self.problem.variables[k].update(param) self.problem.update_optics() if live_plotter is not None: live_plotter.update() live_plotter.finalize() final_loss = self.problem.sum_squared().item() return SimpleNamespace(fun=final_loss, x=[p.item() for p in self.params])