Source code for backend.base

"""
AbstractBackend ABC, @passthrough decorator, and BackendCapabilityError.

Kramer Harrison, 2025
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Literal

if TYPE_CHECKING:
    from collections.abc import Callable, Sequence
    from types import ModuleType


class BackendCapabilityError(Exception):
    """Signal that the active backend cannot perform the requested operation.

    Capability-gated features (gradient control, device selection, tensor
    conversion, ...) raise this error on backends that do not provide them.
    The message names the backend that is currently active.

    Example::

        be.grad_mode.enable()  # on the numpy backend
        # BackendCapabilityError: grad_mode requires a backend that supports
        # gradients. Current backend: 'numpy'. Try: be.set_backend('torch')
    """
def passthrough(*func_names: str) -> Callable[[type], type]:
    """Build a class decorator that injects library passthrough methods.

    For every name in ``func_names`` the decorated class gains a method that
    forwards its arguments to ``self._lib.<name>(*args, **kwargs)``. A name is
    skipped when the class already has an attribute of that name (defined or
    inherited), so explicit implementations always take priority.

    The generated methods carry a ``__name__``, ``__qualname__``,
    ``__module__`` and ``__doc__`` that match the decorated class. As a
    result they show up correctly in ``help()``, tracebacks and generated
    API docs.

    Args:
        *func_names: Names of functions to forward to the backend library.

    Returns:
        A class decorator that injects the passthrough methods and returns
        the same class object.
    """

    def _make(cls: type, n: str) -> Callable[..., Any]:
        # A dedicated factory binds ``n`` per method. A bare closure created in
        # the loop would late-bind and forward every method to the last name.
        def method(self: AbstractBackend, *args: Any, **kwargs: Any) -> Any:
            return getattr(self._lib, n)(*args, **kwargs)

        method.__name__ = n
        # Use __qualname__/__module__ of the target class so nested/local
        # classes and subclasses in other modules get correct identifiers.
        method.__qualname__ = f"{cls.__qualname__}.{n}"
        method.__module__ = cls.__module__
        method.__doc__ = f"Passthrough to ``self._lib.{n}(*args, **kwargs)``."
        return method

    def decorator(cls: type) -> type:
        for name in func_names:
            # Explicit (or inherited) definitions win over generated ones.
            if not hasattr(cls, name):
                setattr(cls, name, _make(cls, name))
        return cls

    return decorator
@passthrough(
    # Trigonometric
    "sin", "cos", "tan",
    "arcsin", "arccos", "arctan", "arctan2",
    "sinh", "cosh", "tanh",
    # Math
    "exp", "log", "log2", "log10", "sqrt",
    "abs", "sign", "floor", "ceil", "hypot",
    # Angle conversion (np.deg2rad / torch.deg2rad share the same name)
    "deg2rad", "rad2deg",
    # Floating-point classification checks
    "isnan", "isinf", "isfinite",
    # Element-wise logic
    "logical_and", "logical_or", "logical_not",
    # Complex
    "conj", "real",
    # Linear algebra (identical API in np and torch)
    "outer", "einsum", "dot",
    # NaN-safe reductions (np.nansum/torch.nansum, np.nanmean/torch.nanmean)
    "nansum", "nanmean",
    # Stack variants with matching names
    "vstack", "column_stack",
    # Sorting / searching, plus rounding
    "searchsorted", "round",
    # Array info — torch backend overrides these with its own implementations
    "shape", "size", "copy", "isscalar", "load",
    # dtype / machine-epsilon info
    "finfo",
    # Closeness / comparison (np.allclose / torch.allclose)
    "allclose",
    # Complex helpers
    "imag",
    # Sign copying
    "copysign",
)
class AbstractBackend(ABC):
    """Contract that every numerical backend must fulfil.

    A concrete backend subclasses this class and implements each abstract
    method below. The ``@passthrough`` decorator on this class additionally
    installs concrete methods (``sin``, ``exp``, ``einsum``, ...) that forward
    straight to ``self._lib.<name>(...)``. A subclass may override any of
    them.

    Attributes:
        _lib: The wrapped array library module (``numpy`` or ``torch``).
    """

    _lib: ModuleType  # assigned by each subclass: np or torch

    # --- Identity -------------------------------------------------------

    @property
    @abstractmethod
    def name(self) -> str:
        """Short identifier of this backend, e.g. ``'numpy'`` or ``'torch'``."""

    # --- Submodule proxies ----------------------------------------------

    @property
    def linalg(self) -> Any:
        """The ``linalg`` submodule of the wrapped library."""
        return self._lib.linalg

    @property
    def fft(self) -> Any:
        """The ``fft`` submodule of the wrapped library."""
        return self._lib.fft

    @property
    def random(self) -> Any:
        """The ``random`` submodule of the wrapped library."""
        return self._lib.random

    # --- Capability flags -----------------------------------------------

    @property
    def supports_gradients(self) -> bool:
        """Whether automatic differentiation is available (``False`` here)."""
        return False

    @property
    def supports_gpu(self) -> bool:
        """Whether GPU acceleration is available (``False`` here)."""
        return False

    # --- Precision ------------------------------------------------------

    @abstractmethod
    def set_precision(self, precision: Literal["float32", "float64"]) -> None:
        """Choose the floating-point precision used by this backend.

        Args:
            precision: ``'float32'`` or ``'float64'``.
        """

    @abstractmethod
    def get_precision(self) -> int:
        """Report the active precision as an integer bit width (32 or 64)."""

    # --- Capability-gated, torch-only features --------------------------
    # Every default below raises BackendCapabilityError.
    # TorchBackend overrides these.

    @property
    def grad_mode(self) -> Any:
        """Gradient-computation controller (torch only).

        Raises:
            BackendCapabilityError: Always, in this base implementation.
        """
        message = (
            "grad_mode requires a backend that supports gradients. "
            f"Current backend: '{self.name}'. Try: be.set_backend('torch')"
        )
        raise BackendCapabilityError(message)

    @property
    def autograd(self) -> Any:
        """Autograd submodule (torch only).

        Raises:
            BackendCapabilityError: Always, in this base implementation.
        """
        message = f"autograd is not supported by backend '{self.name}'."
        raise BackendCapabilityError(message)

    def set_device(self, device: str) -> None:
        """Select the compute device (torch only).

        Args:
            device: Device string such as ``'cpu'`` or ``'cuda'``.

        Raises:
            BackendCapabilityError: Always, on non-torch backends.
        """
        message = f"set_device is not supported by backend '{self.name}'."
        raise BackendCapabilityError(message)

    def get_device(self) -> str:
        """Report the active compute device (torch only).

        Raises:
            BackendCapabilityError: Always, on non-torch backends.
        """
        message = f"get_device is not supported by backend '{self.name}'."
        raise BackendCapabilityError(message)

    def get_complex_precision(self) -> Any:
        """Complex dtype paired with the active precision (torch only).

        Raises:
            BackendCapabilityError: Always, on non-torch backends.
        """
        message = (
            f"get_complex_precision is not supported by backend '{self.name}'."
        )
        raise BackendCapabilityError(message)

    def to_tensor(self, data: Any, device: Any = None) -> Any:
        """Convert ``data`` into a tensor at the active precision (torch only).

        Args:
            data: Input to convert.
            device: Optional target device; ignored by this base version.

        Raises:
            BackendCapabilityError: Always, on non-torch backends.
        """
        message = f"to_tensor is not supported by backend '{self.name}'."
        raise BackendCapabilityError(message)

    # --- Array creation -------------------------------------------------

    @abstractmethod
    def array(self, x: Any) -> Any:
        """Build a backend array/tensor from ``x``."""

    @abstractmethod
    def zeros(self, shape: Sequence[int], dtype: Any = None) -> Any:
        """New array of ``shape`` filled with zeros."""

    @abstractmethod
    def ones(self, shape: Sequence[int], dtype: Any = None) -> Any:
        """New array of ``shape`` filled with ones."""

    @abstractmethod
    def full(self, shape: Sequence[int], fill_value: Any, dtype: Any = None) -> Any:
        """New array of ``shape`` with every element set to ``fill_value``."""

    @abstractmethod
    def linspace(self, start: float, stop: float, num: int = 50) -> Any:
        """``num`` evenly spaced samples over ``[start, stop]``."""

    @abstractmethod
    def arange(self, *args: Any, **kwargs: Any) -> Any:
        """Evenly spaced values within an interval (arange semantics)."""

    @abstractmethod
    def zeros_like(self, x: Any) -> Any:
        """Array of zeros matching the shape and type of ``x``."""

    @abstractmethod
    def ones_like(self, x: Any) -> Any:
        """Array of ones matching the shape and type of ``x``."""

    @abstractmethod
    def full_like(self, x: Any, fill_value: Any) -> Any:
        """Array shaped like ``x`` with every element set to ``fill_value``."""

    @abstractmethod
    def empty(self, shape: Sequence[int]) -> Any:
        """New array of ``shape`` whose contents are left uninitialized."""

    @abstractmethod
    def empty_like(self, x: Any) -> Any:
        """Uninitialized array with the same shape as ``x``."""

    @abstractmethod
    def eye(self, n: int) -> Any:
        """``n`` x ``n`` identity matrix."""

    @abstractmethod
    def asarray(self, x: Any, **kwargs: Any) -> Any:
        """Convert ``x`` to a backend array, avoiding a copy when possible."""

    # --- Array utilities ------------------------------------------------

    @abstractmethod
    def cast(self, x: Any) -> Any:
        """Convert ``x`` to the active floating-point precision."""

    @abstractmethod
    def is_array_like(self, x: Any) -> bool:
        """Whether ``x`` is a list, a tuple, or a backend array."""

    @abstractmethod
    def arange_indices(self, start: Any, stop: Any = None, step: int = 1) -> Any:
        """Integer array of indices."""

    @abstractmethod
    def ravel(self, x: Any) -> Any:
        """Flatten ``x`` into a 1D float array."""

    # --- Shape and indexing ---------------------------------------------

    @abstractmethod
    def transpose(self, x: Any, axes: Sequence[int] | None = None) -> Any:
        """Permute the dimensions of ``x``."""

    @abstractmethod
    def reshape(self, x: Any, shape: Sequence[int]) -> Any:
        """``x`` viewed/copied into ``shape``."""

    @abstractmethod
    def atleast_1d(self, x: Any) -> Any:
        """``x`` as an array with one or more dimensions."""

    @abstractmethod
    def atleast_2d(self, x: Any) -> Any:
        """``x`` as an array with two or more dimensions."""

    @abstractmethod
    def as_array_1d(self, data: Any) -> Any:
        """Force ``data`` into a 1D array."""

    @abstractmethod
    def stack(self, xs: Sequence[Any], axis: int = 0) -> Any:
        """Join arrays along a newly inserted axis."""

    @abstractmethod
    def concatenate(self, arrays: Sequence[Any], axis: int = 0) -> Any:
        """Join arrays along an existing axis."""

    @abstractmethod
    def flip(self, x: Any) -> Any:
        """Reverse the element order of ``x`` along axis 0."""

    @abstractmethod
    def roll(self, x: Any, shift: Any, axis: Any = ()) -> Any:
        """Cyclically shift elements of ``x`` along ``axis``."""

    @abstractmethod
    def repeat(self, x: Any, repeats: int) -> Any:
        """Repeat the elements of ``x``."""

    @abstractmethod
    def broadcast_to(self, x: Any, shape: Sequence[int]) -> Any:
        """Broadcast ``x`` to ``shape``."""

    @abstractmethod
    def tile(self, x: Any, dims: Any) -> Any:
        """Construct an array by tiling ``x``."""

    @abstractmethod
    def expand_dims(self, x: Any, axis: int) -> Any:
        """Insert a new axis into the shape of ``x``."""

    @abstractmethod
    def meshgrid(self, *arrays: Any) -> tuple[Any, ...]:
        """Coordinate matrices built from coordinate vectors."""

    @abstractmethod
    def unsqueeze_last(self, x: Any) -> Any:
        """Append a trailing dimension to ``x``."""

    # --- Reductions and math with semantic mismatches -------------------

    @abstractmethod
    def sum(self, x: Any, axis: int | None = None) -> Any:
        """Sum of the elements of ``x`` over ``axis``."""

    @abstractmethod
    def mean(self, x: Any, axis: int | None = None, keepdims: bool = False) -> Any:
        """Arithmetic mean of ``x``, ignoring NaNs."""

    @abstractmethod
    def std(self, x: Any, axis: int | None = None) -> Any:
        """Standard deviation of ``x`` along ``axis``."""

    @abstractmethod
    def max(self, x: Any) -> Any:
        """Largest value in ``x``."""

    @abstractmethod
    def min(self, x: Any) -> Any:
        """Smallest value in ``x``."""

    @abstractmethod
    def argmin(self, x: Any, axis: int | None = None) -> Any:
        """Indices of the minimum values along ``axis``."""

    @abstractmethod
    def argwhere(self, x: Any) -> Any:
        """Indices of the non-zero elements of ``x``."""

    @abstractmethod
    def clip(self, x: Any, a_min: Any, a_max: Any) -> Any:
        """Limit the values of ``x`` to ``[a_min, a_max]``."""

    @abstractmethod
    def where(self, condition: Any, x: Any, y: Any) -> Any:
        """Pick elements from ``x`` where ``condition`` holds, else from ``y``."""

    @abstractmethod
    def maximum(self, a: Any, b: Any) -> Any:
        """Element-wise maximum of ``a`` and ``b``."""

    @abstractmethod
    def minimum(self, a: Any, b: Any) -> Any:
        """Element-wise minimum of ``a`` and ``b``."""

    @abstractmethod
    def fmax(self, a: Any, b: Any) -> Any:
        """Element-wise maximum of ``a`` and ``b`` that ignores NaNs."""

    @abstractmethod
    def power(self, x: Any, y: Any) -> Any:
        """``x`` raised to the power ``y``."""

    @abstractmethod
    def diff(self, x: Any, n: int = 1, axis: int = -1, **kwargs: Any) -> Any:
        """``n``-th discrete difference of ``x`` along ``axis``."""

    @abstractmethod
    def all(self, x: Any) -> bool:
        """Whether every element of ``x`` is truthy."""

    @abstractmethod
    def any(self, x: Any) -> bool:
        """Whether at least one element of ``x`` is truthy."""

    @abstractmethod
    def nanmax(self, x: Any, axis: int | None = None, keepdim: bool = False) -> Any:
        """Maximum of ``x``, ignoring NaNs.

        Note: this keyword is spelled ``keepdim`` (torch style), whereas
        ``mean`` uses ``keepdims``.
        """

    @abstractmethod
    def sort(self, x: Any, axis: int = -1) -> Any:
        """Sorted copy of ``x``."""

    @abstractmethod
    def isclose(self, a: Any, b: Any, rtol: float = 1e-5, atol: float = 1e-8) -> Any:
        """Boolean array marking where ``a`` and ``b`` are close."""

    @abstractmethod
    def radians(self, x: Any) -> Any:
        """Convert angles in ``x`` from degrees to radians."""

    @abstractmethod
    def degrees(self, x: Any) -> Any:
        """Convert angles in ``x`` from radians to degrees."""

    # --- Linear algebra -------------------------------------------------

    @abstractmethod
    def matmul(self, a: Any, b: Any) -> Any:
        """Matrix product of ``a`` and ``b``."""

    @abstractmethod
    def cross(
        self,
        a: Any,
        b: Any,
        axisa: int = -1,
        axisb: int = -1,
        axisc: int = -1,
        axis: int | None = None,
    ) -> Any:
        """Cross product of two vectors."""

    @abstractmethod
    def batched_chain_matmul3(self, a: Any, b: Any, c: Any) -> Any:
        """Compute ``a @ b @ c`` using a promoted dtype."""

    @abstractmethod
    def matrix_vector_multiply_and_squeeze(self, p: Any, E: Any) -> Any:
        """Compute ``p @ E[..., newaxis]``, then squeeze the trailing axis."""

    @abstractmethod
    def mult_p_E(self, p: Any, E: Any) -> Any:
        """Complex matrix-vector product used for polarized fields."""

    @abstractmethod
    def lstsq(self, a: Any, b: Any) -> Any:
        """Least-squares solution ``x`` of ``a @ x = b``."""

    @abstractmethod
    def to_complex(self, x: Any) -> Any:
        """Convert ``x`` to complex128."""

    # --- Interpolation --------------------------------------------------

    @abstractmethod
    def nearest_nd_interpolator(self, points: Any, values: Any, x: Any, y: Any) -> Any:
        """Nearest-neighbour interpolation on an N-D dataset."""

    @abstractmethod
    def interp(self, x: Any, xp: Any, fp: Any) -> Any:
        """One-dimensional linear interpolation."""

    @abstractmethod
    def grid_sample(
        self,
        input: Any,
        grid: Any,
        mode: str = "bilinear",
        padding_mode: str = "zeros",
        align_corners: bool = False,
    ) -> Any:
        """Sample ``input`` on ``grid`` with bilinear/nearest interpolation."""

    # --- Polynomial -----------------------------------------------------

    @abstractmethod
    def polyfit(self, x: Any, y: Any, degree: int) -> Any:
        """Least-squares polynomial fit of the given ``degree``."""

    @abstractmethod
    def polyval(self, coeffs: Any, x: Any) -> Any:
        """Evaluate the polynomial ``coeffs`` at ``x``."""

    # --- Signal processing ----------------------------------------------

    @abstractmethod
    def fftconvolve(self, in1: Any, in2: Any, mode: str = "full") -> Any:
        """Convolve ``in1`` with ``in2`` using FFTs."""

    # --- Random number generation ---------------------------------------

    @abstractmethod
    def default_rng(self, seed: int | None = None) -> Any:
        """Random number generator seeded with ``seed``."""

    @abstractmethod
    def random_uniform(
        self,
        low: float = 0.0,
        high: float = 1.0,
        size: Any = None,
        generator: Any = None,
    ) -> Any:
        """Uniform random samples drawn from ``[low, high)``."""

    @abstractmethod
    def rand(self, *size: int) -> Any:
        """Uniform random values on ``[0, 1)``."""

    @abstractmethod
    def random_normal(
        self,
        loc: float = 0.0,
        scale: float = 1.0,
        size: Any = None,
        generator: Any = None,
    ) -> Any:
        """Random samples from a normal (Gaussian) distribution."""

    @abstractmethod
    def sobol_sampler(
        self,
        dim: int,
        num_samples: int,
        scramble: bool = True,
        seed: int | None = None,
    ) -> Any:
        """Quasi-random samples generated from Sobol sequences."""

    @abstractmethod
    def erfinv(self, x: Any) -> Any:
        """Inverse error function of ``x``."""

    # --- Miscellaneous --------------------------------------------------

    @abstractmethod
    def factorial(self, n: Any) -> Any:
        """Factorial of ``n``."""

    @abstractmethod
    def path_contains_points(self, vertices: Any, points: Any) -> Any:
        """Boolean mask of ``points`` lying inside the polygon ``vertices``."""

    @abstractmethod
    def pad(
        self,
        tensor: Any,
        pad_width: Any,
        mode: str = "constant",
        constant_values: float | None = 0,
    ) -> Any:
        """Pad ``tensor`` by ``pad_width``."""

    @abstractmethod
    def vectorize(self, pyfunc: Callable[..., Any]) -> Callable[..., Any]:
        """Wrap a scalar function so it applies over array inputs."""

    @abstractmethod
    def errstate(self, **kwargs: Any) -> Any:
        """Context manager controlling floating-point error handling."""

    @abstractmethod
    def histogram(self, x: Any, bins: Any = 10) -> tuple[Any, Any]:
        """Histogram of ``x``."""

    @abstractmethod
    def histogram2d(
        self,
        x: Any,
        y: Any,
        bins: Any,
        weights: Any = None,
    ) -> tuple[Any, Any, Any]:
        """Two-dimensional histogram of ``x`` and ``y``."""