Source code for optimization.optimizer.scipy.base

"""Optiland Scipy Optimization Module

This module contains classes for various optimization algorithms that can be
used to solve optimization problems defined in the OptimizationProblem class.
This module provides a generic optimizer class and several specific optimizers
that utilize different algorithms from the SciPy library.

Kramer Harrison, 2024
"""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any

import optiland.backend as be
from scipy import optimize

from ..base import BaseOptimizer
from ..live_plotter import LiveOptimizationPlotter

if TYPE_CHECKING:
    from collections.abc import Callable

    from ...problem import OptimizationProblem


class OptimizerGeneric(BaseOptimizer):
    """Generic optimizer class for solving optimization problems.

    Drives ``scipy.optimize.minimize`` with the problem's sum-of-squares
    merit function as the objective.

    Args:
        problem (OptimizationProblem): The optimization problem to be solved.

    Attributes:
        problem (OptimizationProblem): The optimization problem to be solved.
        _x (list): Stack of variable-value snapshots. One snapshot is pushed
            at the start of every ``optimize`` call and consumed by ``undo``.

    Methods:
        optimize(maxiter=1000, disp=True, tol=1e-3): Optimize the problem
            using the specified parameters.
        undo(): Undo the last optimization step.
        _fun(x): Internal function to evaluate the objective function.

    Raises:
        AssertionError: If any variable of the problem holds a string value
            (i.e. a glass material was declared as a variable).
    """

    def __init__(self, problem: OptimizationProblem):
        super().__init__(problem)
        self._x = []

        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped when Python runs with ``-O``. The exception type and the
        # message are the same as before.
        if any(isinstance(var.value, str) for var in self.problem.variables):
            raise AssertionError(
                "Glass material(s) have been declared as variable(s). "
                "Please use GlassExpert or remove them."
            )

        # An initial value of exactly 0.0 is treated as "not yet computed".
        if self.problem.initial_value == 0.0:
            self.problem.initial_value = self.problem.sum_squared()

    def optimize(
        self,
        method: str | None = None,
        maxiter: int = 1000,
        disp: bool = True,
        tol: float = 1e-3,
        callback: Callable | None = None,
        plot: bool = False,
    ) -> optimize.OptimizeResult:
        """Optimize the problem using the specified parameters.

        The variable values at entry are pushed onto ``self._x`` so that the
        run can later be reverted with ``undo``.

        Args:
            method (str, optional): The optimization method to use. Default
                is chosen to be one of BFGS, L-BFGS-B, SLSQP, depending on
                whether constraints or bounds given. Follows
                scipy.optimize.minimize method. The string ``"Default"`` is
                treated the same as ``None``.
            maxiter (int, optional): Maximum number of iterations. Default is
                1000.
            disp (bool, optional): Whether to display optimization
                information. Default is True.
            tol (float, optional): Tolerance for convergence. Default is 1e-3.
            callback (callable): A callable called after each iteration. It
                receives whatever arguments SciPy passes to its callback.
            plot: If True, update live plots during optimization.

        Returns:
            result (OptimizeResult): The optimization result.
        """
        if method == "Default":
            method = None

        # Snapshot the starting point for undo() before SciPy touches it.
        x0 = [var.value for var in self.problem.variables]
        self._x.append(x0)
        x0 = be.to_numpy(x0)

        bounds = tuple([var.bounds for var in self.problem.variables])
        options = {"maxiter": maxiter, "disp": disp}

        live_plotter: LiveOptimizationPlotter | None = None
        if plot:
            live_plotter = LiveOptimizationPlotter(self)
            live_plotter.initialize()

        def _wrapped_callback(*args: Any, **kwargs: Any) -> None:
            # Forward to the user callback first, then refresh the live plot.
            if callback is not None:
                callback(*args, **kwargs)
            if live_plotter is not None:
                live_plotter.update()

        # RuntimeWarnings raised while SciPy probes the merit function are
        # silenced for the duration of the minimization only.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            result = optimize.minimize(
                self._fun,
                x0,
                method=method,
                bounds=bounds,
                options=options,
                tol=tol,
                callback=_wrapped_callback,
            )

        # The last function evaluation is not necessarily the lowest.
        # Update all lens variables to their optimized values
        for idvar, var in enumerate(self.problem.variables):
            var.update(result.x[idvar])
        self.problem.update_optics()

        if live_plotter is not None:
            live_plotter.update()
            live_plotter.finalize()

        return result

    def undo(self):
        """Undo the last optimization step.

        Restores the variable values recorded at the start of the most recent
        ``optimize`` call, removes that snapshot from the history and
        refreshes the optics. Does nothing when the history is empty.
        """
        if len(self._x) > 0:
            x0 = self._x[-1]
            for idvar, var in enumerate(self.problem.variables):
                var.update(x0[idvar])
            # Pop only after every variable was restored, so a failure above
            # leaves the snapshot available for another attempt.
            self._x.pop(-1)
            # Keep the optics in step with the restored variables, as
            # optimize() and _fun() do after each variable write.
            self.problem.update_optics()

    def _fun(self, x) -> float:
        """Internal function to evaluate the objective function.

        Args:
            x (array-like): The values of the variables.

        Returns:
            rss (float): The residual sum of squares, or ``1e10`` when the
                merit function evaluates to NaN or raises ``ValueError``.
        """
        # Update all variables to their new values
        for idvar, var in enumerate(self.problem.variables):
            var.update(be.array(x[idvar]))

        # Update optics (e.g., pickups and solves)
        self.problem.update_optics()

        # Compute merit function value
        try:
            rss = self.problem.sum_squared()
            if be.isnan(rss):
                # Large finite penalty steers the optimizer away from NaN.
                return 1e10
            # --- Convert result back to float for SciPy ---
            return be.to_numpy(rss).item()
        except ValueError:
            # Same penalty when the merit function cannot be evaluated.
            return 1e10