"""JAX replacements for scipy.interpolate functions using interpax.
This module provides JAX-compatible implementations of scipy.interpolate
functions used in SimpleMask for interpolation operations, using the
interpax library for GPU-accelerated interpolation.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Literal
import numpy as np
from xpcsviewer.backends import ensure_numpy
if TYPE_CHECKING:
from numpy.typing import ArrayLike
# Supported interpolation methods, expressed as a Literal type alias of the
# four kind names: "linear", "nearest", "cubic" and "quadratic".
InterpolationKind = Literal["linear", "nearest", "cubic", "quadratic"]
[docs]
class Interp1d:
    """1D interpolation class compatible with scipy.interpolate.interp1d.

    This class provides JAX-compatible 1D interpolation using interpax.
    It stores the interpolation data and provides a callable interface.
    When interpax/JAX cannot be imported, evaluation falls back to
    ``scipy.interpolate.interp1d``.

    Parameters
    ----------
    x : array-like
        1D array of x coordinates (must be strictly increasing, at least
        two elements).
    y : array-like
        1D or ND array of y values. If ND, interpolation is performed
        along the last axis, whose length must equal ``len(x)``.
    kind : str
        Interpolation method: 'linear', 'nearest', 'quadratic', 'cubic'.
        Default is 'linear'. On the interpax path 'quadratic' is
        approximated by 'cubic', 'slinear' maps to 'linear' and 'zero'
        maps to 'nearest'. Kinds without an interpax mapping are evaluated
        with SciPy instead of being silently replaced.
    bounds_error : bool
        If True, raise ValueError for out-of-bounds values.
        If False, use fill_value for out-of-bounds. Default: True.
        Ignored when ``fill_value='extrapolate'``.
    fill_value : float or tuple
        Value for out-of-bounds points. Can be 'extrapolate' to
        extrapolate beyond the data range (this overrides
        ``bounds_error``), or a tuple (below, above) for different
        values below and above the range.

    Raises
    ------
    ValueError
        If ``x`` is not 1-dimensional, has fewer than two elements, is not
        strictly increasing, or if the last axis of ``y`` does not match
        the length of ``x``.

    Examples
    --------
    >>> x = np.array([0, 1, 2, 3])
    >>> y = np.array([0, 1, 4, 9])
    >>> f = Interp1d(x, y, kind='linear')
    >>> float(f(1.5))
    2.5
    """

    # Maps scipy-style ``kind`` names onto interpax methods.
    # interpax supports: "nearest", "linear", "cubic", "cubic2", "cardinal",
    # "catmull-rom".
    _INTERPAX_METHODS = {
        "linear": "linear",
        "nearest": "nearest",
        "cubic": "cubic",
        "quadratic": "cubic",  # Use cubic as approximation
        "slinear": "linear",
        "zero": "nearest",
    }

    def __init__(
        self,
        x: ArrayLike,
        y: ArrayLike,
        kind: str = "linear",
        bounds_error: bool = True,
        fill_value: float | tuple | str = np.nan,
    ):
        """Initialize interpolator."""
        # Only the imports are guarded, so that an ImportError cannot be
        # confused with a failure while converting the input arrays.
        try:
            import interpax
            import jax.numpy as jnp
        except ImportError:
            self._use_interpax = False
            self._use_jax = False
            xp = np
        else:
            self._use_interpax = True
            self._use_jax = True
            self._jnp = jnp
            self._interpax = interpax
            xp = jnp

        self._x = xp.asarray(x)
        self._y = xp.asarray(y)
        self._kind = kind
        self._fill_value = fill_value
        # Guard with isinstance so array/tuple fill values are never compared
        # element-wise against the string.
        self._extrapolate = (
            isinstance(fill_value, str) and fill_value == "extrapolate"
        )
        # Extrapolating and raising on out-of-range input are mutually
        # exclusive (SciPy rejects the combination outright), so a request to
        # extrapolate wins over the bounds_error default.
        self._bounds_error = False if self._extrapolate else bounds_error

        # Validate inputs
        if self._x.ndim != 1:
            raise ValueError("x must be 1-dimensional")
        if len(self._x) < 2:
            raise ValueError("x must have at least 2 elements")
        # Check monotonicity
        if not xp.all(xp.diff(self._x) > 0):
            raise ValueError("x must be strictly increasing")
        if self._y.ndim < 1 or self._y.shape[-1] != self._x.shape[0]:
            raise ValueError(
                "x and y arrays must be equal in length along interpolation axis."
            )

    def __call__(self, x_new: ArrayLike) -> np.ndarray:
        """Evaluate interpolation at new x values.

        Parameters
        ----------
        x_new : array-like
            New x values at which to interpolate

        Returns
        -------
        ndarray
            Interpolated y values with shape ``y.shape[:-1] + x_new.shape``

        Raises
        ------
        ValueError
            If bounds checking is active and any value of ``x_new`` lies
            outside ``[x[0], x[-1]]``.
        """
        # Kinds interpax cannot represent are delegated to SciPy rather than
        # being silently evaluated as linear interpolation.
        if self._use_interpax and self._kind in self._INTERPAX_METHODS:
            return self._interp_interpax(x_new)
        return self._interp_numpy(x_new)

    def _interp_interpax(self, x_new: ArrayLike) -> np.ndarray:
        """Interpolation using interpax library."""
        jnp = self._jnp
        query = jnp.asarray(x_new)
        query_flat = query.ravel()

        below = query_flat < self._x[0]
        above = query_flat > self._x[-1]

        # Handle bounds error check
        if self._bounds_error and bool(jnp.any(jnp.logical_or(below, above))):
            raise ValueError("x_new values out of interpolation range")

        # interpax interpolates along the FIRST axis of the values, whereas
        # this class interpolates along the LAST axis of y: move the axis to
        # the front for the call and back afterwards.
        values = self._interpax.interp1d(
            query_flat,
            self._x,
            jnp.moveaxis(self._y, -1, 0),
            method=self._INTERPAX_METHODS[self._kind],
            # Out-of-range points are only extrapolated on request; otherwise
            # they are either rejected above or overwritten below.
            extrap=self._extrapolate,
        )
        values = jnp.moveaxis(values, 0, -1)

        # Handle fill values for out-of-bounds when not extrapolating
        if not self._bounds_error and not self._extrapolate:
            if isinstance(self._fill_value, tuple):
                fill_below, fill_above = self._fill_value
            else:
                fill_below = fill_above = self._fill_value
            # The masks have shape (Nq,) and broadcast over the leading axes.
            values = jnp.where(below, fill_below, values)
            values = jnp.where(above, fill_above, values)

        out_shape = tuple(self._y.shape[:-1]) + tuple(query.shape)
        return ensure_numpy(values.reshape(out_shape))

    def _interp_numpy(self, x_new: ArrayLike) -> np.ndarray:
        """NumPy/SciPy fallback implementation."""
        from scipy.interpolate import interp1d as scipy_interp1d

        # np.asarray keeps this path usable when the stored arrays are JAX
        # arrays (kinds delegated from the interpax path).
        f = scipy_interp1d(
            np.asarray(self._x),
            np.asarray(self._y),
            kind=self._kind,
            bounds_error=self._bounds_error,
            fill_value="extrapolate" if self._extrapolate else self._fill_value,
        )
        return f(x_new)
[docs]
def interp1d(
    x: ArrayLike,
    y: ArrayLike,
    kind: str = "linear",
    bounds_error: bool = True,
    fill_value: float | tuple | str = np.nan,
) -> Interp1d:
    """Build a callable 1D interpolator.

    Thin factory mirroring the ``scipy.interpolate.interp1d`` call style:
    every argument is forwarded unchanged to :class:`Interp1d`, which uses
    interpax when JAX is available.

    Parameters
    ----------
    x : array-like
        1D array of x coordinates (must be strictly increasing)
    y : array-like
        Array of y values
    kind : str
        Interpolation method: 'linear', 'nearest', 'cubic', etc.
    bounds_error : bool
        If True, raise ValueError for out-of-bounds values
    fill_value : float or tuple or 'extrapolate'
        Value for out-of-bounds points

    Returns
    -------
    Interp1d
        Callable interpolation function
    """
    options = {
        "kind": kind,
        "bounds_error": bounds_error,
        "fill_value": fill_value,
    }
    return Interp1d(x, y, **options)
[docs]
def interp2d_jax(
    xq: ArrayLike,
    yq: ArrayLike,
    x: ArrayLike,
    y: ArrayLike,
    f: ArrayLike,
    method: str = "linear",
    extrap: bool = False,
    fill_value: float = np.nan,
) -> np.ndarray:
    """2D interpolation using interpax.

    Uses ``interpax.interp2d`` when interpax and JAX can be imported and
    falls back to ``scipy.interpolate.RegularGridInterpolator`` otherwise.

    Parameters
    ----------
    xq, yq : array-like
        Query points for interpolation
    x, y : array-like
        Original grid coordinates (1D arrays)
    f : array-like
        Values on original grid (2D array)
    method : str
        Interpolation method: 'linear', 'cubic', etc. In the SciPy
        fallback only 'linear' and 'nearest' are honoured; any other
        method is evaluated as 'linear'.
    extrap : bool
        Whether to extrapolate outside bounds
    fill_value : float
        Value for out-of-bounds points when extrap=False

    Returns
    -------
    ndarray
        Interpolated values at query points
    """
    # Guard only the imports: an ImportError raised later, during the
    # computation itself, must not silently switch backends.
    try:
        import interpax
        import jax.numpy as jnp
    except ImportError:
        interpax = None

    if interpax is not None:
        xq = jnp.asarray(xq)
        yq = jnp.asarray(yq)
        x = jnp.asarray(x)
        y = jnp.asarray(y)
        f = jnp.asarray(f)
        result = interpax.interp2d(xq, yq, x, y, f, method=method, extrap=extrap)
        if not extrap:
            # Apply fill value for out-of-bounds
            x_oob = jnp.logical_or(xq < x[0], xq > x[-1])
            y_oob = jnp.logical_or(yq < y[0], yq > y[-1])
            result = jnp.where(jnp.logical_or(x_oob, y_oob), fill_value, result)
        return ensure_numpy(result)

    from scipy.interpolate import RegularGridInterpolator

    xq_np = np.asarray(xq)
    yq_np = np.asarray(yq)
    interp = RegularGridInterpolator(
        (np.asarray(x), np.asarray(y)),
        np.asarray(f),
        method=method if method in ("linear", "nearest") else "linear",
        bounds_error=False,
        # fill_value=None makes RegularGridInterpolator extrapolate, which is
        # how the extrap flag is honoured on this backend.
        fill_value=None if extrap else fill_value,
    )
    points = np.stack([xq_np.ravel(), yq_np.ravel()], axis=-1)
    return interp(points).reshape(xq_np.shape)