"""Backend protocol interface for JAX/NumPy array operations.
This module defines the abstract interface that both NumPyBackend and
JAXBackend must implement, ensuring consistent API across backends.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable
if TYPE_CHECKING:
from collections.abc import Callable
import numpy as np
# Type variable standing in for the backend-specific array type used in the
# BackendProtocol signatures below. NOTE(review): presumably a NumPy ndarray
# for NumPyBackend and a JAX array for JAXBackend — confirm against those
# implementations.
ArrayType = TypeVar("ArrayType")
@runtime_checkable
class BackendProtocol(Protocol[ArrayType]):
    """Protocol defining the backend interface for array operations.

    Both NumPyBackend and JAXBackend implement this protocol, providing
    a unified API for array computations that can run on CPU or GPU.

    The protocol is generic in ``ArrayType``, the backend's concrete array
    type. Binding the type variable at class level lets methods that only
    *produce* arrays (``zeros``, ``ones``, ``arange``, ``array``,
    ``from_numpy``, ...) carry a meaningful return type; with a
    method-scoped type variable those signatures are ill-formed for type
    checkers. A bare ``BackendProtocol`` annotation is equivalent to
    ``BackendProtocol[Any]``, so existing annotations keep working.

    Every method body is ``...``: this class only declares the interface.
    Because it is ``runtime_checkable``, ``isinstance(obj, BackendProtocol)``
    checks that ``obj`` provides all of the members declared here.

    Attributes
    ----------
    name : str
        Backend identifier ("numpy" or "jax")
    supports_gpu : bool
        Whether backend supports GPU computation
    supports_jit : bool
        Whether backend supports JIT compilation
    supports_grad : bool
        Whether backend supports automatic differentiation
    pi : float
        Mathematical constant π
    """

    @property
    def name(self) -> str:
        """Backend identifier ('numpy' or 'jax')."""
        ...

    @property
    def supports_gpu(self) -> bool:
        """Whether backend supports GPU computation."""
        ...

    @property
    def supports_jit(self) -> bool:
        """Whether backend supports JIT compilation."""
        ...

    @property
    def supports_grad(self) -> bool:
        """Whether backend supports automatic differentiation."""
        ...

    @property
    def pi(self) -> float:
        """Mathematical constant π."""
        ...

    # =========================================================================
    # Array Creation
    # =========================================================================

    def zeros(self, shape: tuple[int, ...], dtype: Any = None) -> ArrayType:
        """Create an array of the given ``shape`` filled with zeros."""
        ...

    def ones(self, shape: tuple[int, ...], dtype: Any = None) -> ArrayType:
        """Create an array of the given ``shape`` filled with ones."""
        ...

    def arange(
        self,
        start: float,
        stop: float | None = None,
        step: float = 1,
        dtype: Any = None,
    ) -> ArrayType:
        """Create an array with evenly spaced values.

        NOTE(review): ``stop=None`` presumably follows the NumPy convention
        in which ``start`` is then used as the stop value; confirm against
        the implementations.
        """
        ...

    def linspace(self, start: float, stop: float, num: int) -> ArrayType:
        """Create an array of ``num`` linearly spaced values."""
        ...

    def logspace(self, start: float, stop: float, num: int) -> ArrayType:
        """Create an array of ``num`` logarithmically spaced values."""
        ...

    def meshgrid(self, *xi: ArrayType, indexing: str = "xy") -> tuple[ArrayType, ...]:
        """Create coordinate matrices from coordinate vectors.

        ``indexing`` defaults to ``"xy"``; it is passed to the backend's
        meshgrid.
        """
        ...

    def zeros_like(self, x: ArrayType, dtype: Any = None) -> ArrayType:
        """Create a zero-filled array with the same shape as ``x``."""
        ...

    def ones_like(self, x: ArrayType, dtype: Any = None) -> ArrayType:
        """Create a ones-filled array with the same shape as ``x``."""
        ...

    def full(
        self, shape: tuple[int, ...], fill_value: float, dtype: Any = None
    ) -> ArrayType:
        """Create an array of the given ``shape`` filled with ``fill_value``."""
        ...

    def array(self, data: Any, dtype: Any = None) -> ArrayType:
        """Create a backend array from ``data``."""
        ...

    # =========================================================================
    # Trigonometric Functions
    # =========================================================================

    def sin(self, x: ArrayType) -> ArrayType:
        """Element-wise sine."""
        ...

    def cos(self, x: ArrayType) -> ArrayType:
        """Element-wise cosine."""
        ...

    def arctan(self, x: ArrayType) -> ArrayType:
        """Element-wise arctangent."""
        ...

    def arctan2(self, y: ArrayType, x: ArrayType) -> ArrayType:
        """Element-wise arctangent of y/x, handling quadrants."""
        ...

    def hypot(self, x: ArrayType, y: ArrayType) -> ArrayType:
        """Element-wise sqrt(x^2 + y^2)."""
        ...

    def deg2rad(self, x: ArrayType) -> ArrayType:
        """Convert degrees to radians."""
        ...

    def rad2deg(self, x: ArrayType) -> ArrayType:
        """Convert radians to degrees."""
        ...

    # =========================================================================
    # Rounding and Modular Arithmetic
    # =========================================================================

    def mod(self, x: ArrayType, y: ArrayType | float) -> ArrayType:
        """Element-wise modulo; ``y`` may be an array or a scalar."""
        ...

    def floor(self, x: ArrayType) -> ArrayType:
        """Element-wise floor."""
        ...

    def ceil(self, x: ArrayType) -> ArrayType:
        """Element-wise ceiling."""
        ...

    def round(self, x: ArrayType, decimals: int = 0) -> ArrayType:
        """Round to the given number of ``decimals`` (default 0)."""
        ...

    # =========================================================================
    # Statistical Functions
    # =========================================================================

    def mean(self, x: ArrayType, axis: int | None = None) -> ArrayType:
        """Compute the mean along ``axis``."""
        ...

    def std(self, x: ArrayType, axis: int | None = None) -> ArrayType:
        """Compute the standard deviation along ``axis``."""
        ...

    def nanmean(self, x: ArrayType, axis: int | None = None) -> ArrayType:
        """Compute the mean along ``axis``, ignoring NaN values."""
        ...

    def nanmin(self, x: ArrayType, axis: int | None = None) -> ArrayType:
        """Compute the minimum along ``axis``, ignoring NaN values."""
        ...

    def nanmax(self, x: ArrayType, axis: int | None = None) -> ArrayType:
        """Compute the maximum along ``axis``, ignoring NaN values."""
        ...

    def percentile(self, x: ArrayType, q: float, axis: int | None = None) -> ArrayType:
        """Compute the ``q``-th percentile along ``axis``.

        NOTE(review): ``q`` is presumably on the 0-100 scale as in NumPy;
        confirm against the implementations.
        """
        ...

    def sum(self, x: ArrayType, axis: int | None = None) -> ArrayType:
        """Compute the sum along ``axis``."""
        ...

    def min(self, x: ArrayType, axis: int | None = None) -> ArrayType:
        """Compute the minimum along ``axis``."""
        ...

    def max(self, x: ArrayType, axis: int | None = None) -> ArrayType:
        """Compute the maximum along ``axis``."""
        ...

    # =========================================================================
    # Binning Functions
    # =========================================================================

    def digitize(self, x: ArrayType, bins: ArrayType) -> ArrayType:
        """Return the indices of the ``bins`` to which each value in ``x`` belongs."""
        ...

    def bincount(
        self,
        x: ArrayType,
        weights: ArrayType | None = None,
        minlength: int = 0,
    ) -> ArrayType:
        """Count occurrences of each value in ``x``.

        Parameters
        ----------
        x : ArrayType
            Values to count.
        weights : ArrayType or None
            Optional per-element weights.
        minlength : int
            Minimum length of the output (default 0).
        """
        ...

    def unique(
        self,
        x: ArrayType,
        return_inverse: bool = False,
        size: int | None = None,
    ) -> ArrayType | tuple[ArrayType, ...]:
        """Find the unique elements of ``x``.

        Returns either a single array or a tuple of arrays; presumably the
        tuple form is used when ``return_inverse`` is true.

        NOTE(review): ``size`` presumably fixes the output length so the
        result has a static shape (as JAX requires under JIT); confirm
        against the implementations.
        """
        ...

    # =========================================================================
    # Boolean/Masking Functions
    # =========================================================================

    def logical_and(self, x: ArrayType, y: ArrayType) -> ArrayType:
        """Element-wise logical AND."""
        ...

    def logical_or(self, x: ArrayType, y: ArrayType) -> ArrayType:
        """Element-wise logical OR."""
        ...

    def logical_not(self, x: ArrayType) -> ArrayType:
        """Element-wise logical NOT."""
        ...

    def where(self, condition: ArrayType, x: ArrayType, y: ArrayType) -> ArrayType:
        """Return elements chosen from ``x`` or ``y`` depending on ``condition``."""
        ...

    def nonzero(self, x: ArrayType, size: int | None = None) -> tuple[ArrayType, ...]:
        """Return the indices of the non-zero elements, one array per dimension.

        NOTE(review): ``size`` presumably fixes the output length for
        static-shape backends (JAX); confirm against the implementations.
        """
        ...

    def isnan(self, x: ArrayType) -> ArrayType:
        """Test element-wise for NaN."""
        ...

    def isfinite(self, x: ArrayType) -> ArrayType:
        """Test element-wise for finite values."""
        ...

    # =========================================================================
    # Array Manipulation
    # =========================================================================

    def clip(self, x: ArrayType, a_min: float, a_max: float) -> ArrayType:
        """Clip the values of ``x`` to the range [``a_min``, ``a_max``]."""
        ...

    def stack(self, arrays: list[ArrayType], axis: int = 0) -> ArrayType:
        """Stack ``arrays`` along a new ``axis``."""
        ...

    def concatenate(self, arrays: list[ArrayType], axis: int = 0) -> ArrayType:
        """Concatenate ``arrays`` along an existing ``axis``."""
        ...

    def copy(self, x: ArrayType) -> ArrayType:
        """Return a copy of ``x``."""
        ...

    def reshape(self, x: ArrayType, shape: tuple[int, ...]) -> ArrayType:
        """Reshape ``x`` to ``shape``."""
        ...

    def transpose(self, x: ArrayType, axes: tuple[int, ...] | None = None) -> ArrayType:
        """Permute the dimensions of ``x`` according to ``axes``."""
        ...

    def flatten(self, x: ArrayType) -> ArrayType:
        """Flatten ``x`` to 1D."""
        ...

    # =========================================================================
    # Mathematical Functions
    # =========================================================================

    def exp(self, x: ArrayType) -> ArrayType:
        """Element-wise exponential."""
        ...

    def log(self, x: ArrayType) -> ArrayType:
        """Element-wise natural logarithm."""
        ...

    def log10(self, x: ArrayType) -> ArrayType:
        """Element-wise base-10 logarithm."""
        ...

    def sqrt(self, x: ArrayType) -> ArrayType:
        """Element-wise square root."""
        ...

    def abs(self, x: ArrayType) -> ArrayType:
        """Element-wise absolute value."""
        ...

    def power(self, x: ArrayType, y: float | ArrayType) -> ArrayType:
        """Element-wise power ``x ** y``; ``y`` may be a scalar or an array."""
        ...

    # =========================================================================
    # Type Conversion
    # =========================================================================

    def to_numpy(self, x: ArrayType) -> np.ndarray:
        """Convert a backend array to a NumPy ndarray."""
        ...

    def from_numpy(self, x: np.ndarray) -> ArrayType:
        """Convert a NumPy ndarray to a backend array."""
        ...

    def astype(self, x: ArrayType, dtype: Any) -> ArrayType:
        """Cast ``x`` to ``dtype``."""
        ...

    # =========================================================================
    # JIT Compilation
    # =========================================================================

    def jit(
        self,
        func: Callable,
        static_argnums: tuple[int, ...] | None = None,
    ) -> Callable:
        """JIT compile ``func`` (no-op for NumPy).

        ``static_argnums`` lists argument positions to treat as compile-time
        constants. NOTE(review): presumably ignored by the NumPy backend.
        """
        ...

    # =========================================================================
    # Gradient Computation (JAX only)
    # =========================================================================

    def grad(
        self,
        func: Callable,
        argnums: int | tuple[int, ...] = 0,
    ) -> Callable:
        """Return a function computing the gradient of ``func`` (raises for NumPy).

        ``argnums`` selects which positional argument(s) to differentiate
        with respect to (default 0).
        """
        ...

    def value_and_grad(
        self,
        func: Callable,
        argnums: int | tuple[int, ...] = 0,
    ) -> Callable:
        """Return a function computing both the value and the gradient of ``func``.

        ``argnums`` selects the differentiated argument(s), as in ``grad``.
        """
        ...

    # =========================================================================
    # Batch Processing
    # =========================================================================

    def vmap(
        self,
        func: Callable,
        in_axes: int | tuple[int | None, ...] = 0,
        out_axes: int = 0,
    ) -> Callable:
        """Vectorize ``func`` over a batch dimension.

        ``in_axes`` gives the batch axis of each input (``None`` presumably
        meaning "not batched"); ``out_axes`` gives the batch axis of the
        output.
        """
        ...

    def scan(
        self,
        func: Callable,
        init: ArrayType,
        xs: ArrayType,
        length: int | None = None,
    ) -> tuple[ArrayType, ArrayType]:
        """Scan over the leading dimension of ``xs`` while carrying state.

        NOTE(review): presumably mirrors ``jax.lax.scan``. In that case
        ``func(carry, x)`` returns ``(new_carry, y)``, and this method
        returns ``(final_carry, stacked_ys)``. Confirm against the
        implementations.
        """
        ...

    def fori_loop(
        self,
        lower: int,
        upper: int,
        body_fun: Callable,
        init_val: ArrayType,
    ) -> ArrayType:
        """Execute ``body_fun`` in a loop from ``lower`` to ``upper``.

        NOTE(review): presumably mirrors ``jax.lax.fori_loop``. In that case
        ``body_fun(i, val)`` returns the next ``val``, and the final value
        is returned. Confirm against the implementations.
        """
        ...