# Source code for backend.utils

"""
Utility functions for working with different backends.

To add support for a new backend, add a conversion function to the CONVERTERS
list.

Kramer Harrison, 2024
"""

from __future__ import annotations

import importlib
from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    from numpy.typing import NDArray
    from torch import Tensor

    from optiland._types import ScalarOrArrayT


# Conversion functions for backends
def torch_to_numpy(obj: Tensor) -> NDArray:
    """Converts a PyTorch tensor to a NumPy array.

    Args:
        obj: The object to convert.

    Returns:
        NDArray: The result of ``obj.detach().cpu().numpy()``.

    Raises:
        TypeError: If PyTorch is not installed or ``obj`` is not a
            ``torch.Tensor``. ``to_numpy`` relies on this exception to move
            on to the next converter.
    """
    # ``import importlib`` alone does not guarantee that the ``util``
    # submodule has been loaded, so import it explicitly before use.
    import importlib.util

    # Only import torch when it is actually installed; it is an optional
    # dependency.
    if importlib.util.find_spec("torch"):
        import torch

        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu().numpy()
    raise TypeError(f"Expected a torch.Tensor, got {type(obj)}")
# Registry of backend conversion functions. ``to_numpy`` tries these in order
# for any object it does not handle itself; a converter signals "not my type"
# by raising TypeError, which makes ``to_numpy`` move on to the next entry.
CONVERTERS = [torch_to_numpy]
[docs] def to_numpy(obj: ScalarOrArrayT) -> NDArray: """Converts input scalar or array to NumPy array, regardless of backend.""" if isinstance(obj, np.ndarray): return obj elif isinstance(obj, int | float | np.number): return np.array(obj) # Handle lists: Iterate and convert elements individually elif isinstance(obj, list | tuple): # Recursively call to_numpy on each element to handle tensors correctly # This will use the CONVERTERS loop for tensor elements within the list # Then, construct a 1D numpy array from the processed scalar elements. processed_elements = [] for item in obj: converted = to_numpy( item ) # Handles tensor detach, returns ndarray or scalar # Extract scalar value if it's a 0-dim or 1-element array if isinstance(converted, np.ndarray) and converted.size == 1: processed_elements.append(converted.item()) # Handle if it was already converted to a Python/Numpy scalar elif isinstance(converted, int | float | np.number): processed_elements.append(converted) else: raise TypeError( f"List element conversion resulted in non-scalar " f"type: {type(converted)}" ) return np.array(processed_elements, dtype=float) # Ensure 1D float array for converter in CONVERTERS: try: return converter(obj) except TypeError: continue raise TypeError(f"Unsupported object type: {type(obj)}")
def is_torch_tensor(obj) -> bool:
    """Checks if an object is a PyTorch tensor.

    Args:
        obj: The object to check.

    Returns:
        bool: True if the object is a PyTorch tensor, False otherwise
        (including when PyTorch is not installed).
    """
    # ``import importlib`` alone does not guarantee that the ``util``
    # submodule has been loaded, so import it explicitly before use.
    import importlib.util

    # torch is optional: only import it when it can be found.
    if importlib.util.find_spec("torch"):
        import torch

        return isinstance(obj, torch.Tensor)
    return False