ADR-001: JAX Migration and Backend Abstraction

Status

Accepted

Context

XPCS Viewer originally used NumPy exclusively for all array computations. As datasets grew larger and fitting workloads became more demanding, several pain points emerged:

  1. Performance ceiling: NumPy computations on large 2D detector images (1024x1024+) and multi-Q-bin correlation analysis were becoming the bottleneck in interactive workflows.

  2. No GPU acceleration: Synchrotron beamlines increasingly provide GPU-equipped workstations, but NumPy cannot utilize them.

  3. No automatic differentiation: The Bayesian fitting pipeline requires gradients for NUTS sampling. With NumPy, gradient computation had to be done numerically (finite differences), which is slow and imprecise.

  4. No JIT compilation: Repeated Q-map computations and partition generation could not be compiled and cached.

The team evaluated JAX as the primary computational backend because:

  • JAX provides a NumPy-compatible API (jax.numpy), minimizing migration effort.

  • JAX supports JIT compilation via XLA, which can yield speedups on the order of 5-10x for repeated computations once the initial compilation cost has been paid.

  • JAX supports automatic differentiation (jax.grad, jax.value_and_grad), which NumPyro requires for Hamiltonian Monte Carlo (NUTS) sampling.

  • JAX runs the same code on CPU, GPU, or TPU: arrays are placed on the default device automatically, and jax.device_put allows explicit placement.
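As a rough illustration of the JIT and automatic-differentiation points above, the following sketch compiles a least-squares loss together with its gradient. The single-exponential model, the function names, and the parameter values are assumptions chosen for illustration; they are not taken from the XPCS Viewer code base.

```python
import jax
import jax.numpy as jnp

# Enable double precision before any arrays are created.
jax.config.update("jax_enable_x64", True)


def single_exp(tau, contrast, gamma, baseline):
    """Hypothetical g2 model: baseline + contrast * exp(-2 * gamma * tau)."""
    return baseline + contrast * jnp.exp(-2.0 * gamma * tau)


def sum_sq_residuals(params, tau, g2_obs):
    """Sum of squared residuals between the model and observed data."""
    contrast, gamma, baseline = params
    resid = single_exp(tau, contrast, gamma, baseline) - g2_obs
    return jnp.sum(resid**2)


# value_and_grad differentiates with respect to the first argument (params);
# jit compiles the combined function on first call and caches it afterwards.
loss_and_grad = jax.jit(jax.value_and_grad(sum_sq_residuals))

tau = jnp.logspace(-3, 1, 50)
true_params = jnp.array([0.3, 2.0, 1.0])
g2_obs = single_exp(tau, *true_params)  # noise-free synthetic data

loss, grads = loss_and_grad(jnp.array([0.25, 1.5, 1.0]), tau, g2_obs)
```

The gradient comes back with the same shape as params, with no finite-difference step size to tune. Subsequent calls with inputs of the same shape and dtype reuse the compiled function.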

Initially, JAX was an optional dependency to support lightweight installations. As the project matured, JAX and its ecosystem (NumPyro, optimistix, interpax, optax) became core dependencies (listed in pyproject.toml under dependencies), since the Bayesian fitting pipeline and JIT-compiled Q-map computation are now integral to the application. The backend abstraction layer remains useful for testing (forcing NumPy behavior) and for graceful degradation when JAX is installed but GPU hardware is unavailable.

Decision

We introduced a backend abstraction layer (xpcsviewer/backends/) that provides a unified BackendProtocol interface for array operations, with two concrete implementations: JAXBackend and NumPyBackend.

Architecture

xpcsviewer/backends/
  _base.py            # BackendProtocol (runtime_checkable Protocol)
  _jax_backend.py     # JAXBackend: full JIT, grad, GPU support
  _numpy_backend.py   # NumPyBackend: CPU-only fallback (no-op JIT, raises on grad)
  _device.py          # DeviceManager singleton: GPU/CPU selection, memory config
  _conversions.py     # ensure_numpy(), ensure_backend_array() at I/O boundaries
  io_adapter.py       # PyQtGraphAdapter, HDF5Adapter, MatplotlibAdapter
  __init__.py         # get_backend(), set_backend(), auto-detection
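To make the relationship between _base.py and _numpy_backend.py concrete, here is a minimal sketch of a runtime-checkable protocol and a NumPy fallback that satisfies it without inheriting from it. Only a handful of methods are shown, and their exact signatures are assumptions; the real BackendProtocol is described as defining 60+ methods.

```python
from typing import Any, Callable, Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class BackendProtocol(Protocol):
    """Illustrative subset of a backend interface (signatures are assumed)."""

    @property
    def supports_grad(self) -> bool: ...

    def zeros(self, shape: tuple[int, ...]) -> Any: ...

    def sum(self, x: Any) -> Any: ...

    def jit(self, fn: Callable) -> Callable: ...

    def grad(self, fn: Callable) -> Callable: ...


class NumPyBackend:
    """CPU-only fallback. Note: it does not subclass BackendProtocol."""

    supports_grad = False

    def zeros(self, shape):
        return np.zeros(shape, dtype=np.float64)

    def sum(self, x):
        return np.sum(x)

    def jit(self, fn):
        # No compiler available: return the function unchanged (no-op JIT).
        return fn

    def grad(self, fn):
        raise NotImplementedError("NumPyBackend does not support autodiff")


backend = NumPyBackend()
# Structural subtyping: the check passes because the members exist,
# not because of any inheritance relationship.
conforms = isinstance(backend, BackendProtocol)
```

A runtime isinstance check against a Protocol only verifies that the named members are present. It does not validate signatures or return types, so static type checking remains the stronger guarantee.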

Key Design Choices

  1. Protocol-based interface (typing.Protocol with @runtime_checkable): Both backends implement BackendProtocol, which defines 60+ methods covering array creation, mathematical operations, statistical functions, JIT, grad, vmap, scan, and fori_loop. This uses structural subtyping – no base class inheritance required.

  2. Auto-detection with environment override: get_backend() auto-detects JAX availability. Users can force a backend via XPCS_USE_JAX=true|false|auto. When JAX is requested but not installed, a clear ImportError with install instructions is raised.

  3. float64 by default: Scientific computing requires double precision, but JAX defaults to 32-bit floats. The backend therefore enables JAX's 64-bit mode (JAX_ENABLE_X64=true) on first use so that float64 arrays are supported.

  4. Singleton DeviceManager: DeviceManager uses double-checked locking for thread-safe singleton creation. It reads XPCS_USE_GPU, XPCS_GPU_FALLBACK, and XPCS_GPU_MEMORY_FRACTION from the environment and configures the JAX runtime accordingly.

  5. ensure_numpy() at I/O boundaries: All external libraries (h5py, PyQtGraph, Matplotlib) require NumPy arrays. The ensure_numpy() function in _conversions.py handles all array types (JAX arrays, read-only NumPy arrays, lists) and guarantees a writable NumPy ndarray.

  6. I/O Adapters: PyQtGraphAdapter, HDF5Adapter, and MatplotlibAdapter wrap ensure_numpy() with optional performance monitoring (conversion count, timing statistics). These centralize boundary conversions instead of scattering ensure_numpy() calls across modules.
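The boundary-conversion idea in items 5 and 6 can be sketched as follows. This is a simplified stand-in, not the code in _conversions.py or io_adapter.py: the TimedAdapter class and its attribute names are hypothetical, and the conversion relies only on np.asarray, which accepts lists, NumPy arrays, and objects exposing __array__ (as JAX arrays do).

```python
import time

import numpy as np


def ensure_numpy(x):
    """Return a writable numpy.ndarray for any array-like input."""
    arr = np.asarray(x)
    if not arr.flags.writeable:
        # Read-only inputs are copied; the copy is always writable.
        arr = arr.copy()
    # Already-writable ndarrays are returned as-is, without a copy.
    return arr


class TimedAdapter:
    """Hypothetical adapter that counts and times boundary conversions."""

    def __init__(self):
        self.conversions = 0
        self.total_seconds = 0.0

    def to_numpy(self, x):
        start = time.perf_counter()
        out = ensure_numpy(x)
        self.total_seconds += time.perf_counter() - start
        self.conversions += 1
        return out


frozen = np.arange(4.0)
frozen.flags.writeable = False  # stand-in for a read-only array

adapter = TimedAdapter()
out = adapter.to_numpy(frozen)
out[0] = 99.0  # safe: out is a writable copy, frozen is untouched
```

Routing every conversion through one adapter object gives a single place to measure how often, and at what cost, arrays cross the boundary.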

Environment Variables

  Variable                  Default  Description
  XPCS_USE_JAX              auto     true, false, or auto
  XPCS_USE_GPU              false    Enable GPU device
  XPCS_GPU_FALLBACK         true     Fall back to CPU if GPU unavailable
  XPCS_GPU_MEMORY_FRACTION  0.9      Max GPU memory fraction
  JAX_PLATFORMS             (unset)  Force JAX platform (cpu, gpu)
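The sketch below shows one way a singleton could read the GPU-related variables listed above, using the double-checked locking pattern mentioned under Key Design Choices. It is not the actual DeviceManager: the attribute names, the set of accepted truthy strings, and the range validation are assumptions, and the step that applies the settings to the JAX runtime is deliberately omitted.

```python
import os
import threading


def _env_bool(name, default):
    """Parse a boolean environment variable (accepted spellings are assumed)."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


class DeviceManager:
    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        if cls._instance is None:  # first check, without the lock
            with cls._lock:
                if cls._instance is None:  # second check, under the lock
                    inst = super().__new__(cls)
                    inst._load_env()
                    cls._instance = inst
        return cls._instance

    def _load_env(self):
        self.use_gpu = _env_bool("XPCS_USE_GPU", False)
        self.gpu_fallback = _env_bool("XPCS_GPU_FALLBACK", True)
        frac = float(os.environ.get("XPCS_GPU_MEMORY_FRACTION", "0.9"))
        if not 0.0 < frac <= 1.0:
            raise ValueError("XPCS_GPU_MEMORY_FRACTION must be in (0, 1]")
        self.gpu_memory_fraction = frac


# Demo only: set a value before the first instantiation reads the environment.
os.environ["XPCS_GPU_MEMORY_FRACTION"] = "0.5"
first = DeviceManager()
second = DeviceManager()  # returns the same object as first
```

Because the environment is read once, when the instance is created, later changes to these variables have no effect in this sketch. A real implementation might forward the memory fraction to JAX through a setting such as XLA_PYTHON_CLIENT_MEM_FRACTION, but that mapping is an assumption here.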

Consequences

What became easier

  • Performance: JIT-compiled Q-map computation runs ~10x faster on repeated calls (the first call pays the compilation cost). GPU acceleration provides further speedups for large datasets.

  • Bayesian fitting: NumPyro NUTS sampling works natively with JAX arrays and automatic differentiation, eliminating the need for numerical gradients.

  • Backend flexibility: The NumPy backend remains available via XPCS_USE_JAX=false for debugging and deterministic testing, even though JAX is installed by default.

  • Testing: Both backends can be tested independently. The reset_backend() function allows tests to switch backends between test cases.
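The backend-switching workflow can be sketched as below. The functions here are simplified stand-ins that return a backend name rather than a backend object. The caching strategy, the get_backend_name helper, and the error messages are assumptions; only the XPCS_USE_JAX values and the existence of reset_backend() come from this record.

```python
import importlib.util
import os

_backend_name = None  # module-level cache, cleared by reset_backend()


def _jax_installed():
    return importlib.util.find_spec("jax") is not None


def get_backend_name():
    """Resolve the backend from XPCS_USE_JAX (true, false, or auto)."""
    global _backend_name
    if _backend_name is not None:
        return _backend_name
    mode = os.environ.get("XPCS_USE_JAX", "auto").strip().lower()
    if mode not in {"true", "false", "auto"}:
        raise ValueError(f"XPCS_USE_JAX must be true, false, or auto; got {mode!r}")
    if mode == "false":
        _backend_name = "numpy"
    elif _jax_installed():
        _backend_name = "jax"
    elif mode == "true":
        raise ImportError("XPCS_USE_JAX=true but JAX is not installed; try `pip install jax`.")
    else:
        _backend_name = "numpy"  # auto mode without JAX falls back to NumPy
    return _backend_name


def reset_backend():
    """Forget the cached choice so the next lookup re-reads the environment."""
    global _backend_name
    _backend_name = None


def test_numpy_backend_can_be_forced():
    previous = os.environ.get("XPCS_USE_JAX")
    os.environ["XPCS_USE_JAX"] = "false"
    reset_backend()
    try:
        assert get_backend_name() == "numpy"
    finally:
        # Restore the environment so other tests are unaffected.
        if previous is None:
            os.environ.pop("XPCS_USE_JAX", None)
        else:
            os.environ["XPCS_USE_JAX"] = previous
        reset_backend()
```

In a pytest suite, the environment handling in the test would more naturally be done with the monkeypatch fixture, with reset_backend() called from a fixture teardown.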

What became more difficult

  • Debugging: JAX’s tracing semantics (functional purity requirements) can produce confusing error messages when Python control flow depends on array values.

  • Dependency size: JAX + jaxlib adds ~500MB to the installation. Since JAX is now a core dependency, all installations carry this cost.

  • I/O boundary discipline: Every interaction with external libraries must go through ensure_numpy(). Missing conversions produce runtime errors (JAX arrays are not accepted by h5py or PyQtGraph).

  • NumPy backend limitations: NumPyBackend.grad() and value_and_grad() raise NotImplementedError. Code that requires gradients must check backend.supports_grad first.
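A caller that needs gradients might guard on supports_grad roughly as follows. The NumPyBackend stub, the make_grad helper, and the finite-difference fallback are all hypothetical; this record only states that the check is required, not what the caller should do when gradients are unavailable.

```python
import numpy as np


class NumPyBackend:
    """Minimal stub with the two members this example needs."""

    supports_grad = False

    def grad(self, fn):
        raise NotImplementedError("NumPyBackend does not support autodiff")


def finite_difference_grad(fn, eps=1e-6):
    """Central-difference gradient of a scalar function of a 1-D array."""

    def grad_fn(x):
        x = np.asarray(x, dtype=np.float64)
        g = np.empty_like(x)
        for i in range(x.size):
            step = np.zeros_like(x)
            step[i] = eps
            g[i] = (fn(x + step) - fn(x - step)) / (2.0 * eps)
        return g

    return grad_fn


def make_grad(backend, fn):
    """Use the backend's autodiff when available, else fall back."""
    if getattr(backend, "supports_grad", False):
        return backend.grad(fn)
    return finite_difference_grad(fn)


def loss(x):
    return float(np.sum(x**2))


grad_fn = make_grad(NumPyBackend(), loss)
g = grad_fn([1.0, 2.0, 3.0])  # approximately [2, 4, 6]
```

The fallback needs two function evaluations per parameter and carries the precision limits noted in the Context section, so it is better suited to debugging than to NUTS sampling. Raising a clear error instead of falling back would be an equally reasonable policy.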