"""JAX replacements for scipy.ndimage functions.
This module provides JAX-compatible implementations of scipy.ndimage
functions used in SimpleMask for smoothing and filtering.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import numpy as np
if TYPE_CHECKING:
from numpy.typing import ArrayLike
[docs]
def gaussian_filter(
input_array: ArrayLike,
sigma: float | tuple[float, ...],
mode: str = "reflect",
truncate: float = 4.0,
) -> np.ndarray:
"""Apply Gaussian filter to array.
This is a JAX-compatible implementation that uses convolution with
a Gaussian kernel. Falls back to scipy.ndimage.gaussian_filter when
JAX is not available.
Parameters
----------
input_array : array-like
Input array to filter
sigma : float or tuple of floats
Standard deviation for Gaussian kernel. Can be a single value
for isotropic filtering or a tuple for anisotropic filtering.
mode : str
Boundary mode: 'reflect', 'constant', 'nearest', 'wrap'
(default: 'reflect')
truncate : float
Truncate filter at this many standard deviations (default: 4.0)
Returns
-------
ndarray
Filtered array
"""
try:
import jax
import jax.numpy as jnp
from jax import lax
from xpcsviewer.backends import ensure_numpy
return ensure_numpy(_gaussian_filter_jax(input_array, sigma, mode, truncate))
except ImportError:
# Fall back to scipy
from scipy.ndimage import gaussian_filter as scipy_gaussian
return scipy_gaussian(input_array, sigma, mode=mode, truncate=truncate)
def _gaussian_filter_jax(
input_array: ArrayLike,
sigma: float | tuple[float, ...],
mode: str = "reflect",
truncate: float = 4.0,
) -> Any:
"""JAX implementation of Gaussian filter using separable convolution."""
import jax.numpy as jnp
arr = jnp.asarray(input_array)
ndim = arr.ndim
# Normalize sigma to tuple
if isinstance(sigma, (int, float)):
sigma = (float(sigma),) * ndim
else:
sigma = tuple(float(s) for s in sigma)
if len(sigma) != ndim:
raise ValueError(f"sigma must have {ndim} elements, got {len(sigma)}")
# Map mode to JAX padding mode
padding_modes = {
"reflect": "REFLECT",
"constant": "CONSTANT",
"nearest": "REPLICATE",
"wrap": "CIRCULAR",
}
if mode not in padding_modes:
raise ValueError(f"Unsupported mode: {mode}")
pad_mode = padding_modes[mode]
# Apply 1D Gaussian filter along each axis (separable)
result = arr
for axis in range(ndim):
if sigma[axis] > 0:
result = _gaussian_filter_1d_jax(
result, sigma[axis], axis, pad_mode, truncate
)
# Return JAX array -- ensure_numpy is deferred to the public API
return result
def _gaussian_filter_1d_jax(
arr,
sigma: float,
axis: int,
pad_mode: str,
truncate: float,
):
"""Apply 1D Gaussian filter along specified axis using JAX."""
import jax.numpy as jnp
# Compute kernel size
radius = int(truncate * sigma + 0.5)
if radius == 0:
return arr
# Create 1D Gaussian kernel
x = jnp.arange(-radius, radius + 1, dtype=arr.dtype)
kernel = jnp.exp(-0.5 * (x / sigma) ** 2)
kernel = kernel / jnp.sum(kernel)
# Reshape kernel for the target axis
shape = [1] * arr.ndim
shape[axis] = len(kernel)
kernel = kernel.reshape(shape)
# Pad array
pad_width = [(0, 0)] * arr.ndim
pad_width[axis] = (radius, radius)
if pad_mode == "CONSTANT":
padded = jnp.pad(arr, pad_width, mode="constant", constant_values=0)
elif pad_mode == "REFLECT":
padded = jnp.pad(arr, pad_width, mode="reflect")
elif pad_mode == "REPLICATE":
padded = jnp.pad(arr, pad_width, mode="edge")
elif pad_mode == "CIRCULAR":
padded = jnp.pad(arr, pad_width, mode="wrap")
else:
padded = jnp.pad(arr, pad_width, mode="constant", constant_values=0)
# Convolve
# For 2D case, we use a simple approach with sum over the kernel window
# This is less efficient than lax.conv but more general for N-D
result = _convolve_1d(padded, kernel.flatten(), axis)
return result
def _convolve_1d(arr, kernel, axis: int):
"""1D convolution along specified axis using JAX."""
import jax.numpy as jnp
from jax import lax
kernel_size = len(kernel)
ndim = arr.ndim
# For 2D arrays, use lax.conv_general_dilated
if ndim == 2:
# Add batch and channel dimensions for lax.conv
arr_4d = arr[jnp.newaxis, jnp.newaxis, :, :]
if axis == 0:
kernel_4d = kernel.reshape(1, 1, kernel_size, 1)
else:
kernel_4d = kernel.reshape(1, 1, 1, kernel_size)
result = lax.conv_general_dilated(
arr_4d,
kernel_4d,
window_strides=(1, 1),
padding="VALID",
dimension_numbers=("NCHW", "OIHW", "NCHW"),
)
return result[0, 0]
# Fallback for other dimensions: use jnp.convolve approach
# Move axis to last position, convolve, move back
arr_moved = jnp.moveaxis(arr, axis, -1)
original_shape = arr_moved.shape
# Flatten to 2D for vectorized convolution
flat = arr_moved.reshape(-1, arr_moved.shape[-1])
def convolve_row(row):
return jnp.convolve(row, kernel, mode="valid")
# Use vmap for vectorized convolution
from jax import vmap
result_flat = vmap(convolve_row)(flat)
result_shape = original_shape[:-1] + (result_flat.shape[-1],)
result = result_flat.reshape(result_shape)
return jnp.moveaxis(result, -1, axis)
[docs]
def gaussian_filter1d(
input_array: ArrayLike,
sigma: float,
axis: int = -1,
mode: str = "reflect",
truncate: float = 4.0,
) -> np.ndarray:
"""Apply 1D Gaussian filter along specified axis.
This is a JAX-compatible implementation that uses convolution with
a 1D Gaussian kernel. Falls back to scipy.ndimage.gaussian_filter1d
when JAX is not available.
Parameters
----------
input_array : array-like
Input array to filter
sigma : float
Standard deviation for Gaussian kernel
axis : int
Axis along which to apply filter (default: -1)
mode : str
Boundary mode: 'reflect', 'constant', 'nearest', 'wrap'
(default: 'reflect')
truncate : float
Truncate filter at this many standard deviations (default: 4.0)
Returns
-------
ndarray
Filtered array
"""
try:
import jax.numpy as jnp
arr = jnp.asarray(input_array)
# Map mode to JAX padding mode
padding_modes = {
"reflect": "REFLECT",
"constant": "CONSTANT",
"nearest": "REPLICATE",
"wrap": "CIRCULAR",
}
if mode not in padding_modes:
raise ValueError(f"Unsupported mode: {mode}")
pad_mode = padding_modes[mode]
result = _gaussian_filter_1d_jax(arr, sigma, axis, pad_mode, truncate)
from xpcsviewer.backends import ensure_numpy
return ensure_numpy(result) # ensure_numpy at public API boundary
except ImportError:
from scipy.ndimage import gaussian_filter1d as scipy_gaussian_1d
return scipy_gaussian_1d(
input_array, sigma, axis=axis, mode=mode, truncate=truncate
)
[docs]
def zoom(
input_array: ArrayLike,
zoom_factor: float | tuple[float, ...],
order: int = 1,
mode: str = "constant",
cval: float = 0.0,
) -> np.ndarray:
"""Zoom (resize) array using interpolation.
This is a JAX-compatible implementation using interpax for interpolation.
Falls back to scipy.ndimage.zoom when JAX is not available.
Parameters
----------
input_array : array-like
Input array to zoom
zoom_factor : float or tuple of floats
Zoom factor for each axis. If scalar, applied uniformly.
order : int
Interpolation order (0=nearest, 1=linear, 3=cubic). Default: 1.
mode : str
Boundary mode (default: 'constant')
cval : float
Fill value for constant mode (default: 0.0)
Returns
-------
ndarray
Zoomed array
"""
try:
import interpax
import jax.numpy as jnp
arr = jnp.asarray(input_array)
ndim = arr.ndim
# Normalize zoom_factor to tuple
if isinstance(zoom_factor, (int, float)):
zoom_tuple = (float(zoom_factor),) * ndim
else:
zoom_tuple = tuple(float(z) for z in zoom_factor)
if len(zoom_tuple) != ndim:
raise ValueError(f"zoom_factor must have {ndim} elements")
# Calculate new shape
new_shape = tuple(int(round(s * z)) for s, z in zip(arr.shape, zoom_tuple))
if ndim == 1:
# 1D case
x_old = jnp.arange(arr.shape[0])
x_new = jnp.linspace(0, arr.shape[0] - 1, new_shape[0])
method = "linear" if order == 1 else ("nearest" if order == 0 else "cubic")
result = interpax.interp1d(x_new, x_old, arr, method=method, extrap=cval)
elif ndim == 2:
# 2D case - use bilinear/bicubic interpolation
y_old = jnp.arange(arr.shape[0])
x_old = jnp.arange(arr.shape[1])
y_new = jnp.linspace(0, arr.shape[0] - 1, new_shape[0])
x_new = jnp.linspace(0, arr.shape[1] - 1, new_shape[1])
# Create meshgrid and flatten for interpax (expects 1D query points)
xq, yq = jnp.meshgrid(x_new, y_new)
xq_flat = xq.ravel()
yq_flat = yq.ravel()
method = "linear" if order == 1 else ("nearest" if order == 0 else "cubic")
result_flat = interpax.interp2d(
yq_flat, xq_flat, y_old, x_old, arr, method=method, extrap=cval
)
# Reshape back to 2D
result = result_flat.reshape(new_shape)
else:
# For higher dimensions, fall back to scipy
from scipy.ndimage import zoom as scipy_zoom
return scipy_zoom(
input_array, zoom_factor, order=order, mode=mode, cval=cval
)
from xpcsviewer.backends import ensure_numpy
return ensure_numpy(result)
except ImportError:
from scipy.ndimage import zoom as scipy_zoom
return scipy_zoom(input_array, zoom_factor, order=order, mode=mode, cval=cval)