Backend Selection: JAX vs NumPy
This tutorial explains the backend abstraction layer, how to switch between JAX and NumPy backends, and when to use each for optimal performance.
Architecture Overview
xpcsviewer uses a backend abstraction that provides a unified API for
array operations. All computation code uses the backend instead of
importing numpy or jax.numpy directly. This allows:
Automatic fallback: JAX if available, NumPy otherwise
GPU acceleration: JAX backend enables JIT compilation and GPU offloading
Gradient computation: Needed for advanced fitting and optimization
Same API: Switch backends without changing analysis code
The backend protocol is defined in BackendProtocol.
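For readers unfamiliar with structural typing, the sketch below shows how a backend interface of this kind can be expressed with `typing.Protocol`. It is a minimal illustration, not the real `BackendProtocol`. The class names `MiniBackend` and `NumpyMiniBackend` are hypothetical, and only a few of the attributes and methods used in this tutorial are included.

```python
from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class MiniBackend(Protocol):
    """Hypothetical subset of a backend interface (not xpcsviewer's BackendProtocol)."""

    name: str
    supports_gpu: bool
    supports_jit: bool
    supports_grad: bool

    def linspace(self, start, stop, num): ...
    def exp(self, x): ...
    def sum(self, x): ...


class NumpyMiniBackend:
    """Satisfies MiniBackend structurally: no inheritance is required."""

    name = "numpy"
    supports_gpu = False
    supports_jit = False
    supports_grad = False

    def linspace(self, start, stop, num):
        return np.linspace(start, stop, num)

    def exp(self, x):
        return np.exp(x)

    def sum(self, x):
        return np.sum(x)


backend = NumpyMiniBackend()
print(isinstance(backend, MiniBackend))  # True
print(backend.sum(backend.exp(backend.linspace(0.0, 0.0, 4))))  # 4.0
```

Because the check is structural, any object exposing the same attributes and methods (for example, a JAX-backed class) would pass the same `isinstance` test.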
Checking the Active Backend
```python
from xpcsviewer.backends import get_backend

backend = get_backend()
print(f"Backend: {backend.name}")
print(f"GPU support: {backend.supports_gpu}")
print(f"JIT support: {backend.supports_jit}")
print(f"Grad support: {backend.supports_grad}")
```
Switching Backends
Programmatic Selection
```python
from xpcsviewer.backends import set_backend, get_backend

# Force NumPy backend
set_backend("numpy")
print(f"Now using: {get_backend().name}")

# Force JAX backend (requires JAX installation)
set_backend("jax")
print(f"Now using: {get_backend().name}")
```
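A `set_backend`/`get_backend` pair is commonly implemented as a small module-level registry. The sketch below is hypothetical: `_REGISTRY`, `set_backend_sketch`, and `get_backend_sketch` are invented for illustration and store array modules rather than backend objects. It raises `ValueError` for unknown names and `ImportError` when the package is missing. Those exception types match what the benchmark in this tutorial catches, but treating that as xpcsviewer's actual behavior is an assumption.

```python
import importlib

# Hypothetical registry: backend name -> module path of its array namespace
_REGISTRY = {"numpy": "numpy", "jax": "jax.numpy"}
_active = None


def set_backend_sketch(name):
    """Activate a backend by name (illustration only)."""
    global _active
    if name not in _REGISTRY:
        raise ValueError(f"Unknown backend {name!r}; choose from {sorted(_REGISTRY)}")
    # import_module raises ImportError (ModuleNotFoundError) if the package is missing
    _active = (name, importlib.import_module(_REGISTRY[name]))


def get_backend_sketch():
    """Return (name, array module), defaulting to NumPy on first use."""
    if _active is None:
        set_backend_sketch("numpy")
    return _active


name, xp = get_backend_sketch()
print(name, xp.sum(xp.arange(4)))  # numpy 6
```

A failed `set_backend_sketch` call leaves the previously active backend in place, so code that catches the exception can keep running on NumPy.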
Environment Variable Control
Set XPCS_USE_JAX=1 before importing xpcsviewer to enable JAX:
```bash
# Enable JAX backend
export XPCS_USE_JAX=1

# Force CPU-only JAX (useful for debugging)
export JAX_PLATFORMS=cpu

# Then run your script
python my_analysis.py
```
If XPCS_USE_JAX is not set, the backend auto-detects: JAX if importable,
NumPy as fallback.
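The selection rule described above can be written as a pure function. This is an illustration under assumptions, not xpcsviewer's code. The name `choose_backend` is hypothetical. Two behaviors are guesses because the text does not cover them: values of `XPCS_USE_JAX` other than `1` select NumPy, and `XPCS_USE_JAX=1` without JAX installed falls back to NumPy.

```python
import importlib.util
import os


def choose_backend(environ=None):
    """Return "jax" or "numpy" following the documented selection rule (sketch)."""
    environ = os.environ if environ is None else environ
    # find_spec returns None for a top-level package that is not installed
    jax_importable = importlib.util.find_spec("jax") is not None
    flag = environ.get("XPCS_USE_JAX")
    if flag is None:
        # Unset: auto-detect, JAX if importable, NumPy as fallback
        return "jax" if jax_importable else "numpy"
    # Assumption: "1" requests JAX (falling back if it is missing); any other value means NumPy
    return "jax" if flag == "1" and jax_importable else "numpy"


print(choose_backend({"XPCS_USE_JAX": "0"}))  # numpy
print(choose_backend({}))  # "jax" or "numpy", depending on whether JAX is installed
```

Passing the environment as a parameter keeps the rule testable without mutating `os.environ`.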
Using the Backend API
Basic Array Operations
```python
from xpcsviewer.backends import get_backend

backend = get_backend()

# Array creation
x = backend.linspace(0, 10, 100)
y = backend.zeros((3, 3))
z = backend.arange(0, 100, 0.5)

# Math operations
sin_x = backend.sin(x)
decay = backend.exp(-x)
norm = backend.sqrt(backend.sum(x ** 2))  # scalar: Euclidean norm of x

# Statistics
mean_val = backend.mean(x)
std_val = backend.std(x)
min_val = backend.min(x)

# Array manipulation
stacked = backend.stack([x, sin_x])       # two length-100 arrays -> shape (2, 100)
reshaped = backend.reshape(x, (10, 10))   # 100 elements -> 10 x 10
```
All backend methods mirror the NumPy/JAX API and return arrays in the active backend’s format.
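The "same API" idea can be seen without xpcsviewer: a function written against a NumPy-compatible namespace runs unchanged on `numpy` or `jax.numpy`. The sketch passes that namespace explicitly as a hypothetical `xp` argument, whereas xpcsviewer code obtains it via `get_backend()`. The JAX part runs only if JAX is installed.

```python
import numpy as np


def g2_model(xp, t, tau, baseline, contrast):
    """Single-exponential g2 model written only in terms of the array namespace xp."""
    return baseline + contrast * xp.exp(-2 * t / tau)


t = np.linspace(0.0, 10.0, 5)
g2_np = g2_model(np, t, tau=5.0, baseline=1.0, contrast=0.3)
print(g2_np[0])  # 1.3 at t = 0

try:
    import jax.numpy as jnp

    # Same function, JAX namespace; convert back to NumPy to compare
    g2_jax = np.asarray(g2_model(jnp, jnp.asarray(t), 5.0, 1.0, 0.3))
    print(np.allclose(g2_np, g2_jax, rtol=1e-6))
except ImportError:
    g2_jax = None
```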
JIT Compilation (JAX Only)
```python
from xpcsviewer.backends import get_backend

backend = get_backend()
if backend.supports_jit:
    @backend.jit
    def compute_g2(t, tau, baseline, contrast):
        return baseline + contrast * backend.exp(-2 * t / tau)

    # First call compiles; later calls with the same input shapes/dtypes reuse the compiled code
    t = backend.linspace(0.01, 1000, 10000)
    g2 = compute_g2(t, 5.0, 1.0, 0.3)
```
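"Compiles once" means once per combination of input shapes and dtypes. The sketch below uses `jax.jit` directly, on the assumption that `backend.jit` behaves like it. It counts how often the Python function body is traced, and `trace_count` exists purely for illustration. Like the tutorial code, it is guarded so it also runs where JAX is not installed.

```python
try:
    import jax
    import jax.numpy as jnp
except ImportError:
    jax = None

trace_count = 0

if jax is not None:

    @jax.jit
    def compute_g2(t, tau, baseline, contrast):
        global trace_count
        trace_count += 1  # Python side effect: runs only while JAX traces the function
        return baseline + contrast * jnp.exp(-2 * t / tau)

    t100 = jnp.linspace(0.01, 1000, 100)
    compute_g2(t100, 5.0, 1.0, 0.3)  # traces and compiles
    compute_g2(t100, 7.5, 1.0, 0.2)  # same shapes/dtypes, new values: cached, no retrace
    compute_g2(jnp.linspace(0.01, 1000, 200), 5.0, 1.0, 0.3)  # new shape: retraces
    print(trace_count)  # 2
```

In practice, keeping delay-time arrays at a fixed length across calls avoids repeated compilation.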
Gradient Computation (JAX Only)
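```python
from xpcsviewer.backends import get_backend

backend = get_backend()
if backend.supports_grad:
    def loss_fn(params, x, y_obs):
        tau, baseline, contrast = params
        y_pred = baseline + contrast * backend.exp(-2 * x / tau)
        return backend.sum((y_pred - y_obs) ** 2)

    # Synthetic data for illustration (true tau=4.0, baseline=1.0, contrast=0.25)
    x = backend.linspace(0.01, 100, 500)
    y_obs = 1.0 + 0.25 * backend.exp(-2 * x / 4.0)

    # Get loss value and gradients (with respect to params) simultaneously
    value_and_grad_fn = backend.value_and_grad(loss_fn)
    params = backend.array([5.0, 1.0, 0.3])
    loss, grads = value_and_grad_fn(params, x, y_obs)
    print(f"Loss: {loss}, Gradients: {grads}")
```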
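As a sanity check on what `value_and_grad` returns, the gradient of this squared-error loss can be derived by hand and compared against central finite differences. The NumPy sketch below uses hypothetical function names and made-up synthetic data. For the same inputs, JAX's result should agree with `analytic_grad` up to floating-point error.

```python
import numpy as np


def loss(params, x, y_obs):
    tau, baseline, contrast = params
    y_pred = baseline + contrast * np.exp(-2 * x / tau)
    return np.sum((y_pred - y_obs) ** 2)


def analytic_grad(params, x, y_obs):
    """Hand-derived gradient, ordered like params: (tau, baseline, contrast)."""
    tau, baseline, contrast = params
    decay = np.exp(-2 * x / tau)
    r = baseline + contrast * decay - y_obs                   # residuals
    d_tau = np.sum(2 * r * contrast * decay * 2 * x / tau**2)  # d(decay)/dtau = decay * 2x / tau^2
    d_baseline = np.sum(2 * r)
    d_contrast = np.sum(2 * r * decay)
    return np.array([d_tau, d_baseline, d_contrast])


def fd_grad(params, x, y_obs, h=1e-6):
    """Central finite differences, one parameter at a time."""
    grad = np.zeros_like(params)
    for k in range(params.size):
        step = np.zeros_like(params)
        step[k] = h
        grad[k] = (loss(params + step, x, y_obs) - loss(params - step, x, y_obs)) / (2 * h)
    return grad


x = np.linspace(0.01, 100, 500)
y_obs = 1.0 + 0.25 * np.exp(-2 * x / 4.0)  # synthetic data: tau=4, baseline=1, contrast=0.25
params = np.array([5.0, 1.0, 0.3])

print(np.allclose(analytic_grad(params, x, y_obs), fd_grad(params, x, y_obs), rtol=1e-5))
```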
I/O Boundary Conversions
JAX arrays are immutable and may live on an accelerator, so do not pass them directly to HDF5 (h5py), PyQtGraph, or Matplotlib.
Always convert at I/O boundaries using ensure_numpy():
```python
from xpcsviewer.backends import get_backend, ensure_numpy
import matplotlib.pyplot as plt

backend = get_backend()

# Compute with backend (may be JAX or NumPy)
x = backend.linspace(0, 10, 100)
y = backend.sin(x)

# Convert at the plotting boundary
plt.plot(ensure_numpy(x), ensure_numpy(y))
plt.show()
```
The ensure_numpy() function:
Returns NumPy arrays unchanged (no copy)
Copies JAX arrays to host memory
Handles lists and other array-like objects
Ensures the result is writable (important for HDF5)
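The four properties above can be approximated in a few lines. The sketch below is a hypothetical stand-in named `to_writable_numpy`, not the actual `ensure_numpy` source. It relies on `np.asarray` accepting lists and, through the `__array__` protocol, JAX arrays. Arrays obtained this way may be read-only, which is why writability is checked explicitly.

```python
import numpy as np


def to_writable_numpy(x):
    """Hypothetical approximation of ensure_numpy (not xpcsviewer's implementation)."""
    arr = np.asarray(x)          # no copy for ndarrays; host transfer for JAX arrays
    if not arr.flags.writeable:  # read-only input: copy so downstream writers can modify it
        arr = arr.copy()
    return arr


a = np.arange(3.0)
print(to_writable_numpy(a) is a)  # True: unchanged, no copy

frozen = np.arange(3.0)
frozen.flags.writeable = False
print(to_writable_numpy(frozen).flags.writeable)  # True: a writable copy
```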
I/O Adapters
For structured I/O, use the dedicated adapter classes:
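```python
from xpcsviewer.backends import get_backend
from xpcsviewer.backends.io_adapter import PyQtGraphAdapter, HDF5Adapter

jax_array = get_backend().linspace(0, 1, 100)  # any array produced by the active backend
# For PyQtGraph plots
pg_adapter = PyQtGraphAdapter()
plot_data = pg_adapter.to_pyqtgraph(jax_array)
# For HDF5 writing
h5_adapter = HDF5Adapter()
h5_data = h5_adapter.to_hdf5(jax_array)
```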
When to Use Each Backend
Use NumPy When:
Running on machines without JAX/GPU
Debugging (NumPy tracebacks are easier to read than errors raised from traced or JIT-compiled JAX code)
Small datasets where JIT compilation overhead dominates
Interfacing with libraries that require NumPy arrays
Running on macOS without GPU support
Use JAX When:
Fitting many Q bins (JIT + vmap parallelism)
Running Bayesian inference (NumPyro requires JAX)
Needing gradient computation for optimization
Working with large datasets (> 10,000 delay points)
GPU hardware is available
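The first bullet relies on `jax.vmap`: a model written for one Q bin is mapped over a batch of parameter sets, and `jax.jit` compiles the batched version once. The sketch below calls JAX directly rather than going through the xpcsviewer backend, whose support for `vmap` is not covered here. It checks the result against plain NumPy broadcasting, and the parameter values are made up.

```python
import numpy as np

try:
    import jax
    import jax.numpy as jnp
except ImportError:
    jax = None

t = np.logspace(-2, 3, 50)              # delay times, shared by all Q bins
params = np.array([[1.0, 0.30, 5.0],    # one row per Q bin: baseline, contrast, tau
                   [1.0, 0.20, 50.0],
                   [1.0, 0.25, 500.0]])

# NumPy reference via broadcasting: (3, 1) parameters against (1, 50) delays
ref = params[:, 0:1] + params[:, 1:2] * np.exp(-2 * t[None, :] / params[:, 2:3])

if jax is not None:
    def g2_one_bin(p, t):
        baseline, contrast, tau = p[0], p[1], p[2]
        return baseline + contrast * jnp.exp(-2 * t / tau)

    # Map over parameter rows (axis 0) and reuse the same t for every bin
    g2_all_bins = jax.jit(jax.vmap(g2_one_bin, in_axes=(0, None)))
    out = np.asarray(g2_all_bins(jnp.asarray(params), jnp.asarray(t)))
    print(out.shape, np.allclose(out, ref, rtol=1e-5))  # (3, 50) True
else:
    out = ref
```

The vmapped function keeps the single-bin model readable while JAX handles the batching, which is the pattern the bullet refers to.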
Performance Comparison
```python
import time

from xpcsviewer.backends import set_backend, get_backend


def _block(x):
    # JAX dispatches work asynchronously; wait for the result before reading the clock
    return x.block_until_ready() if hasattr(x, "block_until_ready") else x


def benchmark_g2_compute(backend, n_points=100000, n_trials=10):
    t = backend.linspace(0.01, 1000, n_points)
    tau = backend.array(5.0)
    # Warm-up: absorbs one-time compilation and dispatch overhead
    _block(backend.exp(-2 * t / tau))
    start = time.perf_counter()
    for _ in range(n_trials):
        _block(backend.exp(-2 * t / tau))
    elapsed = (time.perf_counter() - start) / n_trials
    return elapsed


# NumPy benchmark
set_backend("numpy")
np_time = benchmark_g2_compute(get_backend())
print(f"NumPy: {np_time*1000:.2f} ms")

# JAX benchmark (if available)
try:
    set_backend("jax")
    jax_time = benchmark_g2_compute(get_backend())
    print(f"JAX: {jax_time*1000:.2f} ms")
    print(f"Speedup: {np_time/jax_time:.1f}x")
except (ImportError, ValueError):
    print("JAX not available for comparison")
```
Troubleshooting
JAX Not Detected
```python
# Check if JAX is importable
try:
    import jax
    print(f"JAX version: {jax.__version__}")
    print(f"JAX devices: {jax.devices()}")
except ImportError:
    print("JAX is not installed. Install with: pip install jax jaxlib")
```
Memory Issues with JAX
```bash
# Limit JAX GPU memory pre-allocation
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.5

# Force CPU-only for debugging
export JAX_PLATFORMS=cpu
```
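These variables are read by JAX, so the safe rule is to set them before JAX is first imported, which may happen indirectly when xpcsviewer is imported. A sketch of doing the same from the top of a Python script follows. The `sys.modules` check is an illustrative addition, not an xpcsviewer feature.

```python
import os
import sys
import warnings

# Must run before JAX is first imported (directly or via xpcsviewer)
if "jax" in sys.modules:
    warnings.warn("JAX is already imported; these settings may not take effect")

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"  # pre-allocate at most 50% of GPU memory
os.environ["JAX_PLATFORMS"] = "cpu"                    # optional: force CPU for debugging

# ...only now import xpcsviewer / jax
```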
Float64 Precision
JAX defaults to float32. xpcsviewer configures float64 precision automatically
via jax.config.update("jax_enable_x64", True). Verify:
```python
from xpcsviewer.backends import get_backend

backend = get_backend()
x = backend.linspace(0, 1, 10)
print(f"Default dtype: {x.dtype}")  # Should be float64
```
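Why the precision setting matters: near a g2 baseline of 1.0, float32 cannot represent changes much smaller than its machine epsilon (about 1.19e-7). The NumPy check below makes that concrete. The guarded part applies the `jax_enable_x64` switch mentioned above and confirms that new JAX arrays then default to float64.

```python
import numpy as np

print(np.finfo(np.float32).eps)  # ~1.19e-07
print(np.finfo(np.float64).eps)  # ~2.22e-16

# A 1e-8 change on top of a baseline of 1.0 vanishes in float32 but survives in float64
lost_in_f32 = np.float32(1.0) + np.float32(1e-8) == np.float32(1.0)
lost_in_f64 = np.float64(1.0) + np.float64(1e-8) == np.float64(1.0)
print(lost_in_f32, lost_in_f64)  # True False

try:
    import jax

    jax.config.update("jax_enable_x64", True)  # the switch the paragraph above refers to
    import jax.numpy as jnp

    jax_dtype = jnp.linspace(0.0, 1.0, 10).dtype
    print(jax_dtype)  # float64
except ImportError:
    jax_dtype = None
```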
Next Steps
Fitting G2 Correlation Functions – Use JAX acceleration for fitting
Cookbook: Common Patterns and Recipes – Backend-aware analysis patterns
Getting Started with XPCS Viewer – Loading data and basic analysis