JAX Backend Architecture¶
This page explains the design and performance model of XPCS Viewer’s backend abstraction layer, which provides transparent switching between JAX and NumPy computational backends.
Why a Backend Abstraction?¶
XPCS Viewer performs intensive numerical computation: Q-map generation over million-pixel detectors, correlation function fitting across dozens of Q-bins, and two-time correlation on thousands of frames. JAX provides JIT compilation, automatic differentiation, and optional GPU acceleration that can deliver 10-100x speedups for these workloads.
However, the codebase must also run where JAX is not installed, where JAX is
installed but no GPU is present, and where users want deterministic NumPy
behavior for debugging. The backend abstraction makes this transparent:
analysis code calls backend.exp(x) and gets the right implementation
regardless of whether the active backend is JAX or NumPy.
Architecture Overview¶
                      Application Code
                             |
                             v
    get_backend() --> BackendProtocol (60+ methods)
                        |                |              |
                        v                v              v
                    JAXBackend     NumPyBackend   DeviceManager
                    - jax.numpy    - numpy        - GPU/CPU selection
                    - jax.jit      - no-op jit    - Memory config
                    - jax.grad     - raises       - Device placement
                        |
                        v
    ensure_numpy() --> I/O Adapters
                       - PyQtGraphAdapter
                       - HDF5Adapter
                       - MatplotlibAdapter
The Protocol Pattern¶
The abstraction uses Python’s typing.Protocol with @runtime_checkable
rather than an abstract base class. This means:
No inheritance required: Each backend independently implements the same interface through structural subtyping (duck typing with contracts).
No JAX import at definition time: The protocol is defined using only standard Python types. JAXBackend imports JAX lazily.
Runtime checking: isinstance(backend, BackendProtocol) works at runtime, enabling defensive checks.
The protocol groups its methods into categories, including array creation, trigonometry, statistics, binning, boolean/masking, math, and functional transformations (jit, grad, vmap, scan, fori_loop).
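The structural-subtyping idea can be sketched with a tiny protocol. This is a minimal illustration under stated assumptions: the real BackendProtocol has 60+ methods, and MiniBackendProtocol, MiniNumPyBackend, and normalized_decay are hypothetical names. The sketch shows that a backend class with no base class still passes the runtime isinstance check, and that analysis code can be written purely against the protocol.

```python
from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class MiniBackendProtocol(Protocol):
    """Hypothetical three-method subset of a backend protocol."""

    def exp(self, x): ...
    def sum(self, x, axis=None): ...
    def jit(self, fn): ...


class MiniNumPyBackend:
    """Satisfies MiniBackendProtocol structurally: note there is no base class."""

    def exp(self, x):
        return np.exp(x)

    def sum(self, x, axis=None):
        return np.sum(x, axis=axis)

    def jit(self, fn):
        return fn  # no-op: NumPy has no compiler


def normalized_decay(backend: MiniBackendProtocol, t, tau):
    # Analysis code talks only to the protocol, never to numpy or jax directly.
    y = backend.exp(-t / tau)
    return y / backend.sum(y)


backend = MiniNumPyBackend()
assert isinstance(backend, MiniBackendProtocol)  # structural runtime check
print(normalized_decay(backend, np.arange(4.0), 2.0))
```

An isinstance check against a runtime_checkable protocol only verifies that the named attributes exist. It does not validate signatures or return types, so it works as a cheap defensive check rather than a full contract test.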
Backend Selection¶
Backend selection follows a priority chain:
Explicit environment variable: XPCS_USE_JAX=true forces JAX; XPCS_USE_JAX=false forces NumPy.
Auto-detection (XPCS_USE_JAX=auto, the default): Attempts to import JAX. If successful, uses JAXBackend; otherwise falls back to NumPyBackend with no error.
Programmatic override: set_backend("numpy") in code or tests.
Once selected, the backend is cached as a module-level singleton.
reset_backend() clears the cache (used in tests to switch backends
between test cases).
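The selection chain and singleton cache can be sketched as follows. This is a minimal sketch under stated assumptions: JAXBackend and NumPyBackend are empty placeholder classes, and the module layout is hypothetical. One deliberate deviation from the description above: this sketch uses importlib.util.find_spec to check whether JAX is installed rather than trying an actual import, so JAX is not imported before its environment variables are configured.

```python
import importlib.util
import os

_BACKEND = None  # module-level singleton cache


class NumPyBackend:  # placeholder for the real backend class
    name = "numpy"


class JAXBackend:  # placeholder for the real backend class
    name = "jax"


def _jax_installed() -> bool:
    # find_spec checks availability without importing JAX.
    return importlib.util.find_spec("jax") is not None


def _make(name: str):
    if name not in ("jax", "numpy"):
        raise ValueError(f"unknown backend {name!r}")
    return JAXBackend() if name == "jax" else NumPyBackend()


def get_backend():
    global _BACKEND
    if _BACKEND is None:
        choice = os.environ.get("XPCS_USE_JAX", "auto").strip().lower()
        if choice == "true":
            name = "jax"
        elif choice == "false":
            name = "numpy"
        else:  # "auto": prefer JAX, fall back to NumPy silently
            name = "jax" if _jax_installed() else "numpy"
        _BACKEND = _make(name)
    return _BACKEND


def set_backend(name: str):
    """Programmatic override, e.g. set_backend("numpy") in tests."""
    global _BACKEND
    _BACKEND = _make(name)
    return _BACKEND


def reset_backend() -> None:
    """Clear the cache so the next get_backend() re-reads the environment."""
    global _BACKEND
    _BACKEND = None
```

Because the result is cached, changing XPCS_USE_JAX after the first get_backend() call has no effect until reset_backend() clears the cache.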
JAX Configuration¶
Before JAX is imported for the first time, XPCS Viewer configures it for scientific computing:
JAX_ENABLE_X64=true: Enables float64 support. Without this, JAX defaults to float32, which is insufficient for the precision requirements of correlation function fitting.
XLA_PYTHON_CLIENT_MEM_FRACTION: Set from XPCS_GPU_MEMORY_FRACTION to control GPU memory allocation.
These environment variables must be set before JAX is imported, which is
why the configuration happens in _configure_jax() called from
get_backend().
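A function like _configure_jax() might look roughly like the sketch below, assuming it only manipulates environment variables. configure_jax_env is a hypothetical name. Two choices here belong to this sketch and are not documented behavior: using setdefault, so that a value the user exported explicitly wins, and range-checking the memory fraction.

```python
import os
import sys
import warnings


def configure_jax_env() -> None:
    """Hypothetical sketch: set JAX/XLA environment variables before JAX is imported."""
    if "jax" in sys.modules:
        # JAX reads these variables when it initializes; setting them later may be ignored.
        warnings.warn("JAX already imported; environment configuration may have no effect")
    # float64 support; setdefault keeps any value the user already exported.
    os.environ.setdefault("JAX_ENABLE_X64", "true")
    fraction = os.environ.get("XPCS_GPU_MEMORY_FRACTION")
    if fraction is not None:
        value = float(fraction)
        if not 0.0 < value <= 1.0:
            raise ValueError(f"XPCS_GPU_MEMORY_FRACTION must be in (0, 1], got {fraction!r}")
        # Translate the application-level setting into XLA's own variable.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(value)
```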
JIT Compilation¶
JIT (Just-In-Time) compilation is the primary performance benefit of the JAX
backend. When a function is decorated with @jax.jit or called via
backend.jit(fn), JAX traces the function with abstract values to build an
XLA computation graph, then compiles it to optimized machine code.
| Operation | First call | Subsequent calls | Speedup (first vs. subsequent) |
|---|---|---|---|
| Q-map computation (1024x1024) | ~200 ms | ~5 ms | 40x |
| Partition generation | ~100 ms | ~3 ms | 33x |
| Model function evaluation | ~50 ms | ~0.5 ms | 100x |
Caching strategy: The Q-map module uses a dict-based JIT cache
(_JIT_CACHE) because JAX arrays are not hashable and cannot be used as
lru_cache keys. The first call for a given scattering geometry type
creates and caches the JIT-compiled function; subsequent calls reuse it.
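A dict-based cache keyed by a hashable geometry string can be sketched as below. Arrays are passed to the cached function, never used as keys. Under stated assumptions: _transmission_qmap, _RAW_FNS, and get_qmap_fn are hypothetical names, the Q-map uses a simple transmission-geometry formula |q| = 2 k0 sin(2θ/2) chosen for illustration, and the sketch falls back to a no-op "jit" over NumPy when JAX is not installed.

```python
import numpy as np

try:  # use JAX when available, otherwise a no-op "jit" over NumPy
    import jax
    import jax.numpy as xp

    _jit = jax.jit
except ImportError:
    xp = np

    def _jit(fn):
        return fn


def _transmission_qmap(x_pix, y_pix, pixel_size, det_dist, k0):
    # Radial distance on the detector, in the same length unit as det_dist.
    r = xp.hypot(x_pix, y_pix) * pixel_size
    two_theta = xp.arctan2(r, det_dist)  # scattering angle
    # |q| = 2 * k0 * sin(two_theta / 2), with k0 = 2 * pi / wavelength
    return 2.0 * k0 * xp.sin(0.5 * two_theta)


_RAW_FNS = {"transmission": _transmission_qmap}
_JIT_CACHE: dict = {}  # geometry name (hashable str) -> compiled function


def get_qmap_fn(geometry: str):
    fn = _JIT_CACHE.get(geometry)
    if fn is None:
        if geometry not in _RAW_FNS:
            raise ValueError(f"unknown geometry {geometry!r}")
        fn = _jit(_RAW_FNS[geometry])  # wrap once, reuse afterwards
        _JIT_CACHE[geometry] = fn
    return fn


# Usage: a 1024x1024 grid of pixel offsets from the beam center.
yy, xx = np.mgrid[-512:512, -512:512].astype(np.float64)
qmap = get_qmap_fn("transmission")(xx, yy, 75e-6, 5.0, 2 * np.pi / 1e-10)
```

jax.jit also keeps its own internal cache keyed on input shapes and dtypes. Calling the cached function with a new detector shape therefore still triggers a retrace and recompile; the dict only avoids re-wrapping the Python function on every call.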
Tracing constraints: JIT-compiled functions must be functionally pure (no
side effects) and cannot use Python-level control flow that depends on array
values (e.g., if x > 0), because during tracing those values are abstract.
The codebase uses static_argnums for arguments that control code paths (like
string mode selectors) and JAX primitives for array-dependent control flow
(jax.lax.cond for branches, jax.lax.fori_loop for loops).
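The three patterns named above can be shown in a short JAX sketch. It requires JAX, and the function names apply_mode, branch_on_total, and repeated_double are hypothetical. A string mode selector is marked static so that ordinary Python branching on it is allowed. A value-dependent branch goes through jax.lax.cond, and a fixed-count loop goes through jax.lax.fori_loop.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(1,))
def apply_mode(x, mode: str):
    # `mode` is static: Python `if` on it is fine, and each distinct value
    # gets its own trace and compiled executable.
    if mode == "log":
        return jnp.log1p(x)
    return x


@jax.jit
def branch_on_total(x):
    # A Python `if x.sum() > 0:` would fail here because x is an abstract tracer,
    # so the data-dependent branch is expressed with lax.cond instead.
    return jax.lax.cond(x.sum() > 0, lambda v: v - 1.0, lambda v: v + 1.0, x)


@jax.jit
def repeated_double(x):
    # lax.fori_loop(lower, upper, body_fun, init_val); body_fun(i, val) -> val
    return jax.lax.fori_loop(0, 3, lambda i, v: v * 2.0, x)
```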
NumPy Backend Behavior¶
When JAX is not active, the NumPy backend provides the same API with no-op transformations:
backend.jit(fn) returns fn unchanged (no compilation).
backend.grad() and backend.value_and_grad() raise NotImplementedError. Code that requires gradients must check backend.supports_grad first.
backend.vmap(fn) falls back to a Python loop.
backend.scan() and backend.fori_loop() use standard Python for loops.
This means NumPy backend code is functionally identical but slower for repeated computations (no compilation cache) and cannot perform automatic differentiation (required for Bayesian NUTS sampling).
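The behavior listed above can be sketched as a class. This is a minimal sketch, not the actual NumPyBackend: the class name NumPyBackendSketch is hypothetical, vmap here only supports mapping over the leading axis (in_axes=0), and the scan and fori_loop signatures mirror jax.lax.scan(f, init, xs) and jax.lax.fori_loop(lower, upper, body_fun, init_val).

```python
import numpy as np


class NumPyBackendSketch:
    """Hypothetical sketch of NumPy-backend functional transformations."""

    supports_grad = False

    def jit(self, fn, **kwargs):
        return fn  # no compilation; options such as static_argnums are ignored

    def grad(self, fn, *args, **kwargs):
        raise NotImplementedError("NumPy backend has no automatic differentiation")

    def value_and_grad(self, fn, *args, **kwargs):
        raise NotImplementedError("NumPy backend has no automatic differentiation")

    def vmap(self, fn, in_axes=0):
        if in_axes != 0:
            raise NotImplementedError("this sketch only supports in_axes=0")

        def mapped(*args):
            # Python loop over the leading axis of every argument.
            return np.stack([fn(*slices) for slices in zip(*args)])

        return mapped

    def fori_loop(self, lower, upper, body_fun, init_val):
        val = init_val
        for i in range(lower, upper):
            val = body_fun(i, val)
        return val

    def scan(self, f, init, xs):
        # f(carry, x) -> (new_carry, y); the ys are stacked like lax.scan does.
        carry, ys = init, []
        for x in xs:
            carry, y = f(carry, x)
            ys.append(y)
        return carry, np.stack(ys)
```

Callers that need gradients branch on supports_grad before calling grad. For example, a Bayesian fit can refuse to start, or switch to a gradient-free method, when it is False.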
I/O Boundaries¶
External libraries (h5py, PyQtGraph, Matplotlib) require NumPy arrays. When the JAX backend is active, arrays must be converted at every I/O boundary.
The Boundary Problem¶
Missing ensure_numpy() calls produce runtime failures:
h5py raises TypeError when given a JAX array
PyQtGraph renders nothing or crashes silently
Matplotlib raises ValueError on non-NumPy input
These errors are easy to introduce during development because code works fine with the NumPy backend but fails with JAX.
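A helper like ensure_numpy() can be very small. The sketch below assumes it accepts any array-like and relies on np.asarray, which calls the object's __array__ method (JAX arrays implement it). The optional dtype parameter is an illustrative addition, not a documented signature. NumPy inputs that already have the right dtype come back unchanged, so the call is cheap on the NumPy backend.

```python
import numpy as np


def ensure_numpy(array, dtype=None):
    """Hypothetical sketch: convert NumPy, JAX, or list input to an np.ndarray."""
    if isinstance(array, np.ndarray) and (dtype is None or array.dtype == dtype):
        return array  # already NumPy: return as-is, no copy
    # np.asarray uses __array__ on JAX arrays; for GPU-resident arrays this
    # performs a device-to-host transfer.
    return np.asarray(array, dtype=dtype)
```

At a boundary, the call wraps the value just before it leaves the computational code, for example dataset[...] = ensure_numpy(result) before an h5py write.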
I/O Adapters¶
To centralize and monitor boundary conversions, three adapter classes wrap
ensure_numpy():
PyQtGraphAdapter: to_pyqtgraph(array) / from_pyqtgraph(array) for all GUI visualization
HDF5Adapter: to_hdf5(array) / from_hdf5(array) for file I/O
MatplotlibAdapter: to_matplotlib(*arrays) for static plots
Each adapter optionally tracks conversion statistics (count, average time, slowest conversion) for performance profiling. When the backend is NumPy, conversions are essentially free (the input array is already NumPy).
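Optional conversion statistics can be sketched with a small dataclass that tracks count, average time, and slowest conversion. The names ConversionStats and HDF5AdapterSketch are hypothetical. The sketch only implements the to_hdf5 direction and times the conversion with time.perf_counter. Tracking is off unless requested, so the untracked path adds no timing overhead.

```python
import time
from dataclasses import dataclass

import numpy as np


@dataclass
class ConversionStats:
    count: int = 0
    total_s: float = 0.0
    slowest_s: float = 0.0

    def record(self, elapsed: float) -> None:
        self.count += 1
        self.total_s += elapsed
        self.slowest_s = max(self.slowest_s, elapsed)

    @property
    def average_s(self) -> float:
        return self.total_s / self.count if self.count else 0.0


class HDF5AdapterSketch:
    """Hypothetical adapter: converts at the file I/O boundary, optionally timing it."""

    def __init__(self, track_stats: bool = False):
        self.stats = ConversionStats() if track_stats else None

    def to_hdf5(self, array):
        start = time.perf_counter()
        out = np.asarray(array)  # device-to-host copy for GPU JAX arrays
        if self.stats is not None:
            self.stats.record(time.perf_counter() - start)
        return out
```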
Conversion overhead for JAX arrays:
| Array size | CPU JAX | GPU JAX |
|---|---|---|
| 100x100 | < 0.01 ms | ~0.1 ms |
| 1024x1024 | < 0.01 ms | ~0.5 ms |
| 4096x4096 | < 0.01 ms | ~5 ms |
GPU data transfer dominates the cost for GPU arrays. CPU JAX arrays share memory with NumPy and convert with near-zero overhead.
Device Management¶
The DeviceManager is a thread-safe singleton that controls GPU/CPU
device selection:
Configuration: Reads XPCS_USE_GPU, XPCS_GPU_FALLBACK, and XPCS_GPU_MEMORY_FRACTION from the environment.
Graceful fallback: If GPU is requested but unavailable and XPCS_GPU_FALLBACK=true (the default), falls back to CPU with a warning instead of crashing.
Device placement: place_on_device(array) moves arrays to the configured device via jax.device_put.
Thread safety: Uses double-checked locking for singleton creation to avoid race conditions in multi-threaded GUI applications.
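The singleton and fallback behavior described above can be sketched like this. Under stated assumptions: DeviceManagerSketch is a hypothetical name, XPCS_USE_GPU is assumed to default to "false", and GPU detection is done by checking jax.devices() for a device whose platform is "gpu". The instance is fully configured before it is published to _instance, so a thread taking the lock-free fast path never sees a half-initialized object.

```python
import os
import threading
import warnings


class DeviceManagerSketch:
    """Hypothetical thread-safe singleton mirroring the DeviceManager description."""

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        if cls._instance is None:  # first check, lock-free fast path
            with cls._lock:
                if cls._instance is None:  # second check, under the lock
                    instance = super().__new__(cls)
                    instance._configure()  # configure before publishing
                    cls._instance = instance
        return cls._instance

    def _configure(self):
        env = os.environ
        self.use_gpu = env.get("XPCS_USE_GPU", "false").lower() == "true"  # assumed default
        self.gpu_fallback = env.get("XPCS_GPU_FALLBACK", "true").lower() == "true"
        fraction = env.get("XPCS_GPU_MEMORY_FRACTION")
        self.memory_fraction = float(fraction) if fraction else None
        self.platform = self._select_platform()

    def _select_platform(self) -> str:
        if not self.use_gpu:
            return "cpu"
        if self._gpu_available():
            return "gpu"
        if self.gpu_fallback:
            warnings.warn("GPU requested but unavailable; falling back to CPU")
            return "cpu"
        raise RuntimeError("GPU requested but unavailable and XPCS_GPU_FALLBACK=false")

    @staticmethod
    def _gpu_available() -> bool:
        try:
            import jax

            return any(d.platform == "gpu" for d in jax.devices())
        except Exception:  # JAX missing or no usable backend
            return False

    def place_on_device(self, array):
        import jax  # only needed when actually placing arrays

        return jax.device_put(array, jax.devices(self.platform)[0])
```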
When to Use GPU¶
GPU acceleration helps most for:
Large detector images (>1024x1024 pixels) during Q-map computation
Batch fitting across many Q-bins (parallelizable across GPU cores)
Two-time correlation on large frame sequences
GPU does not help (and may hurt due to transfer overhead) for:
Small arrays (<10,000 elements)
Single scalar operations
I/O-bound workflows (time spent in HDF5 reads, not computation)
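The rules of thumb above can be folded into a simple decision helper. This is a hypothetical sketch: worth_using_gpu and its io_bound flag are illustrative names, and the 10,000-element threshold is taken from the list above as a rough cutoff rather than a tuned value.

```python
import math

MIN_GPU_ELEMENTS = 10_000  # from the "small arrays" guideline above


def worth_using_gpu(shape, io_bound: bool = False) -> bool:
    """Hypothetical heuristic: offload only large, compute-bound workloads."""
    if io_bound:
        return False  # time goes to HDF5 reads, not computation
    # math.prod(()) == 1, so scalar shapes count as a single element.
    return math.prod(shape) >= MIN_GPU_ELEMENTS
```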
Automatic Differentiation¶
JAX’s grad and value_and_grad transformations are essential for the
Bayesian fitting pipeline. NumPyro’s NUTS sampler requires gradients of the
log-posterior to perform Hamiltonian Monte Carlo sampling.
The backend exposes this capability through:
backend.supports_grad: Boolean capability check
backend.grad(fn): Returns a function that computes the gradient
backend.value_and_grad(fn): Returns both the function value and its gradient in a single forward/backward pass
The fitting module’s model functions (single_exp_func,
double_exp_func, stretched_exp_func) are decorated with
@_maybe_jit, which applies jax.jit when JAX is available. These
JIT-compiled functions are automatically differentiable by JAX’s tracing
mechanism.
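A decorator in the spirit of @_maybe_jit, together with gradient use, can be sketched as below. Under stated assumptions: single_exp_model is a hypothetical stand-in for single_exp_func, using the common XPCS form baseline + beta * exp(-2t/tau) rather than the project's actual signature, and sse_loss and loss_and_grad are illustrative names. The model uses jax.numpy when JAX is importable, because jax.jit cannot trace plain numpy calls. Gradients are only attempted on that path.

```python
import numpy as np

try:
    import jax
    import jax.numpy as xp

    HAS_JAX = True
except ImportError:
    jax = None
    xp = np
    HAS_JAX = False


def _maybe_jit(fn):
    """Apply jax.jit when JAX is importable; otherwise return fn unchanged."""
    return jax.jit(fn) if HAS_JAX else fn


@_maybe_jit
def single_exp_model(t, tau, beta, baseline):
    # Hypothetical g2-style model: baseline + beta * exp(-2 t / tau)
    return baseline + beta * xp.exp(-2.0 * t / tau)


def sse_loss(tau, t, y):
    resid = single_exp_model(t, tau, 1.0, 1.0) - y
    return xp.sum(resid**2)


def loss_and_grad(tau, t, y):
    if not HAS_JAX:
        raise NotImplementedError("gradients require the JAX backend")
    # Differentiates straight through the jitted model function.
    return jax.value_and_grad(sse_loss)(tau, t, y)
```

jax.jit and jax.grad compose in either order, so a jitted model can sit inside a differentiated loss. This is what allows NUTS to reuse the same compiled model functions.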