ADR-001: JAX Migration and Backend Abstraction
Status
Accepted
Context
XPCS Viewer originally used NumPy exclusively for all array computations. As datasets grew larger and fitting workloads became more demanding, several pain points emerged:
Performance ceiling: NumPy computations on large 2D detector images (1024x1024+) and multi-Q-bin correlation analysis were becoming the bottleneck in interactive workflows.
No GPU acceleration: Synchrotron beamlines increasingly provide GPU-equipped workstations, but NumPy cannot utilize them.
No automatic differentiation: The Bayesian fitting pipeline requires gradients for NUTS sampling. With NumPy, gradient computation had to be done numerically (finite differences), which is slow and imprecise.
No JIT compilation: Repeated Q-map computations and partition generation could not be compiled and cached.
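The following sketch makes "slow and imprecise" concrete. It is standalone illustration, not XPCS Viewer code, and `forward_diff_grad` is a hypothetical helper. A forward-difference gradient costs one extra model evaluation per parameter. Its accuracy also depends on a step size `h` that trades truncation error (large `h`) against floating-point round-off (tiny `h`). NUTS needs a gradient at every leapfrog step, so that cost is paid many times per sample.

```python
import math
from typing import Callable, Sequence


def forward_diff_grad(
    f: Callable[[Sequence[float]], float], x: Sequence[float], h: float = 1e-8
) -> list[float]:
    """Approximate the gradient of f at x with forward differences.

    Costs len(x) + 1 evaluations of f: one at x, plus one per perturbed parameter.
    """
    fx = f(x)
    grad = []
    for i in range(len(x)):
        xh = list(x)
        xh[i] += h  # perturb one parameter at a time
        grad.append((f(xh) - fx) / h)
    return grad


# Step-size trade-off for d/dx exp(x) at x = 1, where the exact derivative is e.
for h in (1e-1, 1e-8, 1e-12):
    approx = forward_diff_grad(lambda p: math.exp(p[0]), [1.0], h)[0]
    print(f"h={h:.0e}  abs error={abs(approx - math.e):.1e}")
```

The error is large for `h=1e-1` (truncation), smallest near `h=1e-8`, and grows again for `h=1e-12` (cancellation). Automatic differentiation avoids both the extra evaluations and the step-size choice.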
The team evaluated JAX as the primary computational backend because:
JAX provides a NumPy-compatible API (`jax.numpy`), minimizing migration effort.
JAX supports JIT compilation via XLA, yielding 5-10x speedups on repeated computations.
JAX supports automatic differentiation (`jax.grad`, `jax.value_and_grad`), which NumPyro requires for Hamiltonian Monte Carlo (NUTS) sampling.
JAX runs the same code on CPU, GPU, or TPU, and `jax.device_put` gives explicit control over which device an array lives on.
Initially, JAX was an optional dependency to support lightweight installations. As the project matured, JAX and its ecosystem (NumPyro, optimistix, interpax, optax) became core dependencies (listed in pyproject.toml under dependencies), since the Bayesian fitting pipeline and JIT-compiled Q-map computation are now integral to the application. The backend abstraction layer remains useful for testing (forcing NumPy behavior) and for graceful degradation when JAX is installed but GPU hardware is unavailable.
Decision
We introduced a backend abstraction layer (xpcsviewer/backends/) that provides a unified BackendProtocol interface for array operations, with two concrete implementations:
Architecture
xpcsviewer/backends/
_base.py # BackendProtocol (runtime_checkable Protocol)
_jax_backend.py # JAXBackend: full JIT, grad, GPU support
_numpy_backend.py # NumPyBackend: CPU-only fallback (no-op JIT, raises on grad)
_device.py # DeviceManager singleton: GPU/CPU selection, memory config
_conversions.py # ensure_numpy(), ensure_backend_array() at I/O boundaries
io_adapter.py # PyQtGraphAdapter, HDF5Adapter, MatplotlibAdapter
__init__.py # get_backend(), set_backend(), auto-detection
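The next sketch shows how a `runtime_checkable` protocol provides structural subtyping. It models only a tiny subset of the 60+ methods. The class names `NumPyLikeBackend` and `Incomplete` and all method bodies are illustrative assumptions, not the contents of `_base.py` or `_numpy_backend.py`. The NumPy-style class mirrors the documented behaviour: `jit` is a no-op and `grad` raises `NotImplementedError`.

```python
from typing import Any, Callable, Protocol, runtime_checkable


@runtime_checkable
class BackendProtocol(Protocol):
    """Tiny illustrative subset of the backend interface."""

    name: str
    supports_grad: bool

    def jit(self, fn: Callable[..., Any]) -> Callable[..., Any]: ...

    def grad(self, fn: Callable[..., Any]) -> Callable[..., Any]: ...


class NumPyLikeBackend:
    """CPU fallback sketch; note that it does NOT inherit from BackendProtocol."""

    name = "numpy"
    supports_grad = False

    def jit(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        return fn  # no compiler available: hand the function back unchanged

    def grad(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        raise NotImplementedError("NumPy backend has no automatic differentiation")


class Incomplete:
    name = "broken"  # lacks supports_grad, jit and grad


backend = NumPyLikeBackend()
print(isinstance(backend, BackendProtocol))       # True: structural match
print(isinstance(Incomplete(), BackendProtocol))  # False: members missing
```

A runtime `isinstance` check against a protocol only verifies that the members exist, not their signatures. Static type checking is what enforces the full interface.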
Key Design Choices
Protocol-based interface (`typing.Protocol` with `@runtime_checkable`): Both backends implement `BackendProtocol`, which defines 60+ methods covering array creation, mathematical operations, statistical functions, JIT, grad, vmap, scan, and fori_loop. This relies on structural subtyping, so no base-class inheritance is required.
Auto-detection with environment override: `get_backend()` auto-detects JAX availability. Users can force a backend via `XPCS_USE_JAX=true|false|auto`. When JAX is requested but not installed, a clear `ImportError` with install instructions is raised.
float64 by default: Scientific computing requires double precision, but JAX defaults to 32-bit types. JAX is therefore configured with `JAX_ENABLE_X64=true` on first use so that float64 arrays are supported.
Singleton DeviceManager: `DeviceManager` uses double-checked locking for thread-safe singleton creation. It reads `XPCS_USE_GPU`, `XPCS_GPU_FALLBACK`, and `XPCS_GPU_MEMORY_FRACTION` from the environment and configures the JAX runtime accordingly.
`ensure_numpy()` at I/O boundaries: External libraries (h5py, PyQtGraph, Matplotlib) expect NumPy arrays. The `ensure_numpy()` function in `_conversions.py` accepts JAX arrays, read-only NumPy arrays, and lists, and always returns a writable NumPy ndarray.
I/O Adapters: `PyQtGraphAdapter`, `HDF5Adapter`, and `MatplotlibAdapter` wrap `ensure_numpy()` with optional performance monitoring (conversion count, timing statistics). These centralize boundary conversions instead of scattering `ensure_numpy()` calls across modules.
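Here is a sketch of an `ensure_numpy()`-style conversion plus a counting adapter. It is an assumption-laden illustration, not the code in `_conversions.py` or `io_adapter.py`, and `CountingAdapter` is a hypothetical name. It relies on `np.asarray`, which accepts lists, NumPy arrays, and any object implementing `__array__` (JAX arrays do). That conversion can yield a read-only view, so the sketch copies only when the result is not writable.

```python
import time
from typing import Any

import numpy as np


def ensure_numpy(x: Any, dtype: Any = None) -> np.ndarray:
    """Return a writable NumPy ndarray for lists, ndarrays or __array__ objects."""
    arr = np.asarray(x, dtype=dtype)
    # np.asarray can return a read-only array (e.g. a frozen ndarray, or a view of a
    # JAX buffer); copy only then, so writable ndarrays pass through without a copy.
    if not arr.flags.writeable:
        arr = arr.copy()
    return arr


class CountingAdapter:
    """Hypothetical boundary adapter that records conversion count and total time."""

    def __init__(self) -> None:
        self.conversions = 0
        self.total_seconds = 0.0

    def to_external(self, x: Any) -> np.ndarray:
        start = time.perf_counter()
        out = ensure_numpy(x)
        self.total_seconds += time.perf_counter() - start
        self.conversions += 1
        return out


frozen = np.arange(4.0)
frozen.flags.writeable = False
adapter = CountingAdapter()
out = adapter.to_external(frozen)
out[0] = 42.0  # safe: the adapter returned a writable copy
print(adapter.conversions, frozen[0])  # 1 0.0
```

Copying only when needed keeps the common case (an already-writable ndarray) free. The trade-off is that the returned array may alias the caller's input.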
Environment Variables
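| Variable | Default | Description |
|---|---|---|
| `XPCS_USE_JAX` | | Force the backend: `true`, `false`, or `auto` |
| `XPCS_USE_GPU` | | Enable GPU device |
| `XPCS_GPU_FALLBACK` | | Fall back to CPU if GPU unavailable |
| `XPCS_GPU_MEMORY_FRACTION` | | Max GPU memory fraction |
| | (unset) | Force JAX platform (…) |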
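The following sketch shows the double-checked locking pattern described for `DeviceManager`, reading the three GPU variables listed above. The accessor name `get_instance`, the helper `_env_bool`, the truthy-string parsing, and the fallback defaults are all hypothetical placeholders; the real defaults live in `_device.py`. The sketch only reads configuration and does not touch the JAX runtime.

```python
import os
import threading
from typing import Optional


def _env_bool(name: str, default: bool) -> bool:
    """Hypothetical parser: treat 1/true/yes/on (any case) as True."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


class DeviceManager:
    _instance: Optional["DeviceManager"] = None
    _lock = threading.Lock()

    def __init__(self) -> None:
        # Placeholder defaults for illustration only.
        self.use_gpu = _env_bool("XPCS_USE_GPU", False)
        self.gpu_fallback = _env_bool("XPCS_GPU_FALLBACK", True)
        self.memory_fraction = float(os.environ.get("XPCS_GPU_MEMORY_FRACTION", "0.75"))
        if not 0.0 < self.memory_fraction <= 1.0:
            raise ValueError("XPCS_GPU_MEMORY_FRACTION must be in (0, 1]")

    @classmethod
    def get_instance(cls) -> "DeviceManager":
        # First check without the lock: cheap fast path once the singleton exists.
        if cls._instance is None:
            with cls._lock:
                # Second check under the lock: another thread may have created the
                # instance between our first check and acquiring the lock.
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance
```

Only the first caller pays for locking. Later calls return the cached instance after a single attribute check.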
Consequences
What became easier
Performance: JIT-compiled Q-map computation runs ~10x faster on repeated calls. GPU acceleration provides further speedups for large datasets.
Bayesian fitting: NumPyro NUTS sampling works natively with JAX arrays and automatic differentiation, eliminating the need for numerical gradients.
Backend flexibility: The NumPy backend remains available via `XPCS_USE_JAX=false` for debugging and deterministic testing, even though JAX is installed by default.
Testing: Both backends can be tested independently. The `reset_backend()` function allows tests to switch backends between test cases.
What became more difficult
Debugging: JAX’s tracing semantics (functional purity requirements) can produce confusing error messages when Python control flow (`if`, `while`) depends on array values inside JIT-compiled functions.
Dependency size: JAX + jaxlib adds ~500MB to the installation. Since JAX is now a core dependency, all installations carry this cost.
I/O boundary discipline: Every interaction with external libraries must go through `ensure_numpy()`. Missing conversions produce runtime errors (JAX arrays are not accepted by h5py or PyQtGraph).
NumPy backend limitations: `NumPyBackend.grad()` and `value_and_grad()` raise `NotImplementedError`. Code that requires gradients must check `backend.supports_grad` first.