Backend Selection: JAX vs NumPy

This tutorial explains the backend abstraction layer, how to switch between JAX and NumPy backends, and when to use each for optimal performance.

Architecture Overview

xpcsviewer uses a backend abstraction that provides a unified API for array operations. All computation code uses the backend instead of importing numpy or jax.numpy directly. This allows:

  • Automatic fallback: JAX if available, NumPy otherwise

  • GPU acceleration: JAX backend enables JIT compilation and GPU offloading

  • Gradient computation: Needed for advanced fitting and optimization

  • Same API: Switch backends without changing analysis code

The interface that every backend implements is defined by BackendProtocol.
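
To make the idea of a backend protocol concrete, here is a minimal sketch of how such an abstraction can be structured. It is not the actual BackendProtocol: the names BackendSketch, NumpyBackendSketch, and single_exponential are hypothetical, and only a small subset of the attributes and methods used in this tutorial is modelled.

```python
from typing import Any, Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class BackendSketch(Protocol):
    """Hypothetical subset of a backend interface (not the real BackendProtocol)."""

    name: str
    supports_jit: bool

    def linspace(self, start: float, stop: float, num: int) -> Any: ...
    def exp(self, x: Any) -> Any: ...


class NumpyBackendSketch:
    """Toy NumPy implementation that satisfies BackendSketch structurally."""

    name = "numpy"
    supports_jit = False

    def linspace(self, start, stop, num):
        return np.linspace(start, stop, num)

    def exp(self, x):
        return np.exp(x)


def single_exponential(backend: BackendSketch, tau: float, n: int = 5):
    """Analysis code touches only the backend, never numpy directly."""
    t = backend.linspace(0.0, 1.0, n)
    return backend.exp(-2 * t / tau)


decay = single_exponential(NumpyBackendSketch(), tau=0.5)
```

Because single_exponential depends only on the protocol, a JAX-based class exposing the same methods could be passed in without touching the analysis code.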

Checking the Active Backend

from xpcsviewer.backends import get_backend

backend = get_backend()
print(f"Backend: {backend.name}")
print(f"GPU support: {backend.supports_gpu}")
print(f"JIT support: {backend.supports_jit}")
print(f"Grad support: {backend.supports_grad}")

Switching Backends

Programmatic Selection

from xpcsviewer.backends import set_backend, get_backend

# Force NumPy backend
set_backend("numpy")
print(f"Now using: {get_backend().name}")

# Force JAX backend (requires JAX installation)
set_backend("jax")
print(f"Now using: {get_backend().name}")

Environment Variable Control

Set XPCS_USE_JAX=1 before importing xpcsviewer to enable JAX:

# Enable JAX backend
export XPCS_USE_JAX=1

# Force CPU-only JAX (useful for debugging)
export JAX_PLATFORMS=cpu

# Then run your script
python my_analysis.py

If XPCS_USE_JAX is not set, the backend auto-detects: JAX if importable, NumPy as fallback.
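
The precedence described above can be sketched as a small function. This is an illustration only: resolve_backend_name is a hypothetical helper, and the handling of values other than "1" (treated here as opting out of JAX) and of a missing JAX installation (raising ImportError) are assumptions, not documented xpcsviewer behaviour.

```python
import importlib.util
import os


def resolve_backend_name(env=None, jax_importable=None):
    """Return "jax" or "numpy" following the precedence described above.

    Hypothetical helper written for illustration; both arguments exist only
    to make the logic testable without touching the real environment.
    """
    env = os.environ if env is None else env
    if jax_importable is None:
        # find_spec returns None when the top-level package cannot be found.
        jax_importable = importlib.util.find_spec("jax") is not None

    flag = env.get("XPCS_USE_JAX")
    if flag is None:
        # Variable unset: auto-detect, preferring JAX when it can be imported.
        return "jax" if jax_importable else "numpy"
    if flag == "1":
        if not jax_importable:
            raise ImportError("XPCS_USE_JAX=1 was set but JAX is not installed")
        return "jax"
    # Assumption: any other value opts out of JAX.
    return "numpy"
```

Calling resolve_backend_name() with no arguments inspects the current process environment, which is why the variable has to be exported before the package is imported.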

Using the Backend API

Basic Array Operations

from xpcsviewer.backends import get_backend

backend = get_backend()

# Array creation
x = backend.linspace(0, 10, 100)
y = backend.zeros((3, 3))
z = backend.arange(0, 100, 0.5)

# Math operations
result = backend.sin(x)
result = backend.exp(-x)
result = backend.sqrt(backend.sum(x ** 2))

# Statistics
mean_val = backend.mean(x)
std_val = backend.std(x)
min_val = backend.min(x)

# Array manipulation
stacked = backend.stack([x, result])
reshaped = backend.reshape(x, (10, 10))

All backend methods mirror the NumPy/JAX API and return arrays in the active backend’s format.

JIT Compilation (JAX Only)

from xpcsviewer.backends import get_backend

backend = get_backend()

if backend.supports_jit:
    @backend.jit
    def compute_g2(t, tau, baseline, contrast):
        return baseline + contrast * backend.exp(-2 * t / tau)

    # First call compiles; subsequent calls are fast
    t = backend.linspace(0.01, 1000, 10000)
    g2 = compute_g2(t, 5.0, 1.0, 0.3)

Gradient Computation (JAX Only)

from xpcsviewer.backends import get_backend

backend = get_backend()

if backend.supports_grad:
    def loss_fn(params, x, y_obs):
        tau, baseline, contrast = params
        y_pred = baseline + contrast * backend.exp(-2 * x / tau)
        return backend.sum((y_pred - y_obs) ** 2)

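    # Synthetic data for illustration (generated with tau = 4.0)
    x = backend.linspace(0.01, 100, 200)
    y_obs = 1.0 + 0.3 * backend.exp(-2 * x / 4.0)

    # Get loss value and gradients simultaneously
    value_and_grad_fn = backend.value_and_grad(loss_fn)
    params = backend.array([5.0, 1.0, 0.3])
    loss, grads = value_and_grad_fn(params, x, y_obs)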
    print(f"Loss: {loss}, Gradients: {grads}")
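
When the gradient-capable backend is unavailable, or when automatic gradients need a sanity check, the same loss can be differentiated by hand and compared against central finite differences. The sketch below uses plain NumPy. The helper names analytic_grad and finite_difference_grad, the synthetic data, and the starting parameters are all made up for illustration.

```python
import numpy as np


def loss_fn(params, x, y_obs):
    tau, baseline, contrast = params
    y_pred = baseline + contrast * np.exp(-2 * x / tau)
    return np.sum((y_pred - y_obs) ** 2)


def analytic_grad(params, x, y_obs):
    """Hand-derived gradient of loss_fn w.r.t. (tau, baseline, contrast)."""
    tau, baseline, contrast = params
    decay = np.exp(-2 * x / tau)
    resid = baseline + contrast * decay - y_obs
    # d/dtau exp(-2x/tau) = exp(-2x/tau) * (2x / tau**2)
    d_tau = np.sum(2 * resid * contrast * decay * (2 * x / tau**2))
    d_baseline = np.sum(2 * resid)
    d_contrast = np.sum(2 * resid * decay)
    return np.array([d_tau, d_baseline, d_contrast])


def finite_difference_grad(params, x, y_obs, eps=1e-6):
    """Central differences: one pair of loss evaluations per parameter."""
    grad = np.zeros_like(params)
    for i in range(params.size):
        step = np.zeros_like(params)
        step[i] = eps
        upper = loss_fn(params + step, x, y_obs)
        lower = loss_fn(params - step, x, y_obs)
        grad[i] = (upper - lower) / (2 * eps)
    return grad


x = np.linspace(0.01, 50.0, 200)
y_obs = 1.0 + 0.3 * np.exp(-2 * x / 5.0)  # synthetic data, true tau = 5
params = np.array([4.0, 1.05, 0.25])      # deliberately off-target guess

g_exact = analytic_grad(params, x, y_obs)
g_fd = finite_difference_grad(params, x, y_obs)
```

The two results should agree to several significant digits. Finite differences cost two loss evaluations per parameter, whereas reverse-mode automatic differentiation returns the full gradient in a single pass, which is the main reason to prefer the JAX path for models with many parameters.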

I/O Boundary Conversions

JAX arrays are not directly compatible with HDF5, PyQtGraph, or Matplotlib. Always convert at I/O boundaries using ensure_numpy():

from xpcsviewer.backends import get_backend, ensure_numpy
import matplotlib.pyplot as plt

backend = get_backend()

# Compute with backend (may be JAX or NumPy)
x = backend.linspace(0, 10, 100)
y = backend.sin(x)

# Convert at the plotting boundary
plt.plot(ensure_numpy(x), ensure_numpy(y))
plt.show()

The ensure_numpy() function:

  • Returns NumPy arrays unchanged (no copy)

  • Copies JAX arrays to host memory

  • Handles lists and other array-like objects

  • Ensures the result is writable (important for HDF5)
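
The behaviour listed above can be approximated with NumPy alone. The following is a sketch under stated assumptions, not the xpcsviewer implementation: ensure_numpy_sketch is a hypothetical name, and it assumes the input either is a NumPy array or supports conversion through np.asarray (which JAX arrays do). Note that in this sketch a read-only NumPy input is copied as well, since that is the simplest way to guarantee a writable result.

```python
import numpy as np


def ensure_numpy_sketch(obj):
    """Hypothetical stand-in for ensure_numpy(), written with NumPy only."""
    # np.asarray returns an existing ndarray as-is and converts lists,
    # scalars, and objects that implement __array__.
    arr = np.asarray(obj)
    # Views of immutable buffers can be read-only; copy to get a writable
    # array.
    if not arr.flags.writeable:
        arr = arr.copy()
    return arr


a = np.arange(3.0)
same = ensure_numpy_sketch(a)                # no copy for a writable ndarray
from_list = ensure_numpy_sketch([1, 2, 3])   # list becomes an ndarray

frozen = np.arange(3.0)
frozen.setflags(write=False)
thawed = ensure_numpy_sketch(frozen)         # read-only input is copied
```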

I/O Adapters

For structured I/O, use the dedicated adapter classes:

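from xpcsviewer.backends import get_backend
from xpcsviewer.backends.io_adapter import PyQtGraphAdapter, HDF5Adapter

# Example input; with the JAX backend active this is a JAX array
jax_array = get_backend().linspace(0, 1, 100)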

# For PyQtGraph plots
pg_adapter = PyQtGraphAdapter()
plot_data = pg_adapter.to_pyqtgraph(jax_array)

# For HDF5 writing
h5_adapter = HDF5Adapter()
h5_data = h5_adapter.to_hdf5(jax_array)

When to Use Each Backend

Use NumPy When:

  • Running on machines without JAX/GPU

  • Debugging (NumPy runs eagerly, so tracebacks point at the failing line instead of JIT-traced code)

  • Small datasets where JIT compilation overhead dominates

  • Interfacing with libraries that require NumPy arrays

  • Running on macOS without GPU support

Use JAX When:

  • Fitting many Q bins (JIT + vmap parallelism)

  • Running Bayesian inference (NumPyro requires JAX)

  • Needing gradient computation for optimization

  • Working with large datasets (> 10,000 delay points)

  • GPU hardware is available
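
The "many Q bins" case can be illustrated with jax.vmap, which maps a single-bin function over a batch of parameters, and jax.jit, which compiles the batched function. The sketch below calls JAX directly, not the xpcsviewer backend, and falls back to NumPy broadcasting when JAX is not installed. The model function, the relaxation times, and the baseline and contrast values are made up for illustration.

```python
import numpy as np

try:
    import jax
    import jax.numpy as jnp
    HAVE_JAX = True
except ImportError:  # keep the sketch runnable without JAX
    HAVE_JAX = False


def g2_model(xp, t, tau, baseline, contrast):
    """Single-exponential g2; xp is either numpy or jax.numpy."""
    return baseline + contrast * xp.exp(-2 * t / tau)


t = np.linspace(0.01, 1000.0, 200)
taus = np.array([1.0, 5.0, 25.0, 125.0])  # one relaxation time per Q bin

if HAVE_JAX:
    t_dev = jnp.asarray(t)

    def model_one_bin(tau):
        return g2_model(jnp, t_dev, tau, 1.0, 0.3)

    # vmap adds a leading batch axis over tau; jit compiles the batched call.
    batched = jax.jit(jax.vmap(model_one_bin))
    g2_all = np.asarray(batched(jnp.asarray(taus)))
else:
    # NumPy equivalent: broadcast (n_bins, 1) against (1, n_points).
    g2_all = g2_model(np, t[None, :], taus[:, None], 1.0, 0.3)
```

Either branch produces an array with one row per Q bin. The JAX branch pays a one-time compilation cost, so the advantage shows up when the batched function is called repeatedly, as in an optimization loop.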

Performance Comparison

import time
from xpcsviewer.backends import set_backend, get_backend

def benchmark_g2_compute(backend, n_points=100000, n_trials=10):
    t = backend.linspace(0.01, 1000, n_points)
    tau = backend.array(5.0)

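    def compute():
        result = backend.exp(-2 * t / tau)
        # JAX dispatches work asynchronously; wait for the result so the
        # timing covers the computation and not just the dispatch.
        # NumPy arrays have no such method and are already complete.
        if hasattr(result, "block_until_ready"):
            result.block_until_ready()
        return result

    # Warm-up (the first JAX call includes compilation)
    compute()

    start = time.perf_counter()
    for _ in range(n_trials):
        compute()
    elapsed = (time.perf_counter() - start) / n_trials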

    return elapsed

# NumPy benchmark
set_backend("numpy")
np_time = benchmark_g2_compute(get_backend())
print(f"NumPy: {np_time*1000:.2f} ms")

# JAX benchmark (if available)
try:
    set_backend("jax")
    jax_time = benchmark_g2_compute(get_backend())
    print(f"JAX:   {jax_time*1000:.2f} ms")
    print(f"Speedup: {np_time/jax_time:.1f}x")
except (ImportError, ValueError):
    print("JAX not available for comparison")

Troubleshooting

JAX Not Detected

# Check if JAX is importable
try:
    import jax
    print(f"JAX version: {jax.__version__}")
    print(f"JAX devices: {jax.devices()}")
except ImportError:
    print("JAX is not installed. Install with: pip install jax jaxlib")

Memory Issues with JAX

# Limit JAX GPU memory pre-allocation
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.5

# Force CPU-only for debugging
export JAX_PLATFORMS=cpu

Float64 Precision

JAX defaults to float32. xpcsviewer configures float64 precision automatically via jax.config.update("jax_enable_x64", True). Verify:

from xpcsviewer.backends import get_backend

backend = get_backend()
x = backend.linspace(0, 1, 10)
print(f"Default dtype: {x.dtype}")  # Should be float64
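
A short NumPy illustration of why the float64 setting matters: a correlation signal far out in the tail can sit well below float32 resolution around a baseline of 1. The value of signal below is a made-up example, not taken from real data.

```python
import numpy as np

baseline = 1.0
signal = 1e-8  # hypothetical residual correlation in the tail

g2_32 = np.float32(baseline) + np.float32(signal)
g2_64 = np.float64(baseline) + np.float64(signal)

# float32 resolves about 1.2e-7 around 1.0, so the signal is rounded away.
lost_in_float32 = bool(g2_32 == np.float32(baseline))
# float64 resolves about 2.2e-16 around 1.0, so the signal survives.
kept_in_float64 = bool(g2_64 > baseline)
```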

Next Steps