JAX Backend Architecture

This page explains the design and performance model of XPCS Viewer’s backend abstraction layer, which provides transparent switching between JAX and NumPy computational backends.

Why a Backend Abstraction?

XPCS Viewer performs intensive numerical computation: Q-map generation over million-pixel detectors, correlation function fitting across dozens of Q-bins, and two-time correlation on thousands of frames. JAX provides JIT compilation, automatic differentiation, and optional GPU acceleration that can deliver 10-100x speedups for these workloads.

However, the codebase must also work in environments where JAX is not installed, where JAX is available but GPU hardware is not, or where users want deterministic NumPy behavior for debugging. The backend abstraction makes this transparent: analysis code calls backend.exp(x) and gets the right implementation regardless of whether the active backend is JAX or NumPy.

Architecture Overview

Application Code
       |
       v
get_backend()  -->  BackendProtocol (60+ methods)
       |                   |                |
       v                   v                v
  JAXBackend         NumPyBackend     DeviceManager
  - jax.numpy        - numpy          - GPU/CPU selection
  - jax.jit          - no-op jit      - Memory config
  - jax.grad         - raises         - Device placement
       |
       v
ensure_numpy()  -->  I/O Adapters
                     - PyQtGraphAdapter
                     - HDF5Adapter
                     - MatplotlibAdapter

The Protocol Pattern

The abstraction uses Python’s typing.Protocol with @runtime_checkable rather than an abstract base class. This means:

  • No inheritance required: Each backend independently implements the same interface through structural subtyping (duck typing with contracts).

  • No JAX import at definition time: The protocol is defined using only standard Python types. JAXBackend imports JAX lazily.

  • Runtime checking: isinstance(backend, BackendProtocol) works at runtime, enabling defensive checks.
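A minimal sketch of the pattern, assuming a hypothetical three-method slice of the protocol. MiniBackendProtocol and MiniNumPyBackend are illustrative names, not the actual XPCS Viewer classes, and the real protocol has far more members:

```python
from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class MiniBackendProtocol(Protocol):
    """Hypothetical slice of a backend protocol (the real one has 60+ methods)."""

    name: str

    def exp(self, x): ...

    def zeros(self, shape): ...

    def jit(self, fn): ...


class MiniNumPyBackend:
    """Satisfies MiniBackendProtocol structurally: note there is no base class."""

    name = "numpy"

    def exp(self, x):
        return np.exp(x)

    def zeros(self, shape):
        return np.zeros(shape)

    def jit(self, fn):
        return fn  # no-op: there is nothing to compile with NumPy


backend = MiniNumPyBackend()
assert isinstance(backend, MiniBackendProtocol)  # structural check at runtime
print(backend.exp(np.array([0.0, 1.0])))
```

Note that a runtime isinstance check against a Protocol only verifies that the named attributes exist; it does not check method signatures, so static type checking remains the stronger guarantee.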

The protocol groups its methods into categories: array creation, trigonometry, statistics, binning, boolean/masking, math, and functional transformations (jit, grad, vmap, scan, fori_loop).

Backend Selection

Backend selection follows a priority chain:

  1. Explicit environment variable: XPCS_USE_JAX=true forces JAX; XPCS_USE_JAX=false forces NumPy.

  2. Auto-detection (XPCS_USE_JAX=auto, the default): Attempts to import JAX. If successful, uses JAXBackend; otherwise falls back to NumPyBackend with no error.

  3. Programmatic override: set_backend("numpy") in code or tests.

Once selected, the backend is cached as a module-level singleton. reset_backend() clears the cache (used in tests to switch backends between test cases).
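The selection chain and singleton cache can be sketched as follows. The function names get_backend, set_backend, and reset_backend come from the text above, but their bodies here are assumptions, and the two backend classes are placeholder stubs rather than the real implementations:

```python
import importlib.util
import os


class NumPyBackendStub:
    name = "numpy"


class JAXBackendStub:
    name = "jax"  # a real JAX backend would import jax lazily inside its methods


_BACKENDS = {"numpy": NumPyBackendStub, "jax": JAXBackendStub}
_backend = None  # module-level singleton cache


def _select_name() -> str:
    setting = os.environ.get("XPCS_USE_JAX", "auto").strip().lower()
    if setting == "true":
        return "jax"
    if setting == "false":
        return "numpy"
    # "auto": use JAX only if it is importable; never raise when it is missing
    return "jax" if importlib.util.find_spec("jax") is not None else "numpy"


def get_backend():
    global _backend
    if _backend is None:
        _backend = _BACKENDS[_select_name()]()
    return _backend


def set_backend(name: str):
    """Programmatic override, e.g. set_backend("numpy") in tests."""
    global _backend
    _backend = _BACKENDS[name]()
    return _backend


def reset_backend() -> None:
    """Drop the cached singleton so the next get_backend() re-selects."""
    global _backend
    _backend = None
```

In a test, calling reset_backend() before changing XPCS_USE_JAX ensures the new setting is actually read.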

JAX Configuration

When JAX is first used, XPCS Viewer configures it for scientific computing:

  • JAX_ENABLE_X64=true: Enables float64 support. Without this, JAX defaults to float32, which is insufficient for the precision requirements of correlation function fitting.

  • XLA_PYTHON_CLIENT_MEM_FRACTION: Set from XPCS_GPU_MEMORY_FRACTION to control the fraction of GPU memory that XLA preallocates.

These environment variables must be set before JAX is imported, which is why the configuration happens in _configure_jax() called from get_backend().
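A minimal sketch of that ordering constraint, assuming a hypothetical helper named configure_jax_env (it is not the actual _configure_jax body). It only exports environment variables and never imports JAX itself:

```python
import os
import sys
import warnings


def configure_jax_env() -> None:
    """Hypothetical sketch: export JAX/XLA settings before the first `import jax`."""
    if "jax" in sys.modules:
        # JAX reads JAX_ENABLE_X64 when it is imported, so changes now may be ignored
        warnings.warn("jax is already imported; environment settings may not take effect")
    # setdefault keeps any value the user exported explicitly
    os.environ.setdefault("JAX_ENABLE_X64", "true")
    fraction = os.environ.get("XPCS_GPU_MEMORY_FRACTION")
    if fraction is not None:
        os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", fraction)


os.environ["XPCS_GPU_MEMORY_FRACTION"] = "0.5"
configure_jax_env()
# import jax  # safe to import from here on
```

If JAX has already been imported, float64 can still be enabled at startup with jax.config.update("jax_enable_x64", True), but the memory fraction has to be in place before the GPU client is created.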

JIT Compilation

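JIT (Just-In-Time) compilation is the primary performance benefit of the JAX backend. When a function is decorated with @jax.jit or wrapped via backend.jit(fn), JAX traces it with abstract values to build an XLA computation graph, then compiles that graph to optimized machine code. Tracing and compilation happen on the first call for each new combination of input shapes and dtypes; later calls with matching inputs reuse the compiled executable, which is why first calls are much slower than subsequent ones.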

Typical JIT speedups in XPCS Viewer

Operation                       First call   Subsequent calls   Speedup
------------------------------  ----------   ----------------   -------
Q-map computation (1024x1024)   ~200 ms      ~5 ms              40x
Partition generation            ~100 ms      ~3 ms              33x
Model function evaluation       ~50 ms       ~0.5 ms            100x

Caching strategy: The Q-map module uses a dict-based JIT cache (_JIT_CACHE) because JAX arrays are not hashable and cannot be used as lru_cache keys. The first call for a given scattering geometry type creates and caches the JIT-compiled function; subsequent calls reuse it.

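Tracing constraints: JIT-compiled functions must be functionally pure (no side effects such as mutating global state), and they cannot use Python-level control flow that depends on array values (e.g., if x > 0), because during tracing those values are abstract placeholders rather than concrete numbers. The codebase uses static_argnums for arguments that control code paths (like string mode selectors) and JAX primitives for array-dependent control flow: jax.lax.cond for branching and jax.lax.fori_loop for loops.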
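The three techniques can be illustrated with a short JAX sketch (requires JAX). The functions normalize, safe_log, and geometric_sum are hypothetical examples, not functions from the XPCS Viewer codebase:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def normalize(x, mode: str):
    # `mode` is static: this Python branch runs at trace time, once per distinct value
    if mode == "max":
        return x / jnp.max(x)
    return x / jnp.sum(x)


@jax.jit
def safe_log(x):
    # `x` is traced, so `if x > 0:` would fail; lax.cond selects a branch on device
    return jax.lax.cond(x > 0, jnp.log, lambda v: jnp.full_like(v, -jnp.inf), x)


@jax.jit
def geometric_sum(r):
    # Sum of r**k for k = 0..9 via Horner's rule, as a compiled loop
    return jax.lax.fori_loop(0, 10, lambda i, acc: acc * r + 1.0, jnp.zeros_like(r))


print(normalize(jnp.array([1.0, 2.0, 4.0]), "max"))
print(safe_log(jnp.array(-1.0)))
```

Each distinct value of a static argument triggers a separate compilation, so static arguments suit small sets of mode strings rather than arbitrary numbers. For purely elementwise selection, jnp.where is often a simpler alternative to lax.cond.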

NumPy Backend Behavior

When JAX is not active, the NumPy backend provides the same API with no-op transformations:

  • backend.jit(fn) returns fn unchanged (no compilation).

  • backend.grad() and backend.value_and_grad() raise NotImplementedError. Code that requires gradients must check backend.supports_grad first.

  • backend.vmap(fn) falls back to a Python loop.

  • backend.scan() and backend.fori_loop() use standard Python for loops.

This means NumPy backend code is functionally identical but slower for repeated computations (no compilation cache) and cannot perform automatic differentiation (required for Bayesian NUTS sampling).
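A minimal sketch of these NumPy-side stand-ins, assuming a hypothetical class NumPyTransformsSketch. The scan and fori_loop signatures follow the jax.lax contracts; vmap is simplified to a single argument mapped over its leading axis:

```python
import numpy as np


class NumPyTransformsSketch:
    """Hypothetical NumPy stand-ins for jit/grad/vmap/scan/fori_loop."""

    supports_grad = False

    def jit(self, fn, **kwargs):
        return fn  # nothing to compile; options such as static_argnums are ignored

    def grad(self, fn, *args, **kwargs):
        raise NotImplementedError("grad requires the JAX backend")

    def value_and_grad(self, fn, *args, **kwargs):
        raise NotImplementedError("value_and_grad requires the JAX backend")

    def vmap(self, fn):
        # Simplified: one argument, mapped over its leading axis with a Python loop
        def batched(xs):
            return np.stack([fn(x) for x in xs])

        return batched

    def scan(self, f, init, xs):
        # Same contract as jax.lax.scan: f(carry, x) -> (carry, y)
        carry, ys = init, []
        for x in xs:
            carry, y = f(carry, x)
            ys.append(y)
        return carry, np.stack(ys)

    def fori_loop(self, lower, upper, body_fun, init_val):
        # Same contract as jax.lax.fori_loop: body_fun(i, val) -> val
        val = init_val
        for i in range(lower, upper):
            val = body_fun(i, val)
        return val


backend = NumPyTransformsSketch()
if not backend.supports_grad:
    print("gradient-based fitting unavailable; use a gradient-free method")
```

Keeping the same call contracts is what lets analysis code stay backend-agnostic: only the speed and the availability of gradients differ.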

I/O Boundaries

External libraries (h5py, PyQtGraph, Matplotlib) require NumPy arrays. When the JAX backend is active, arrays must be converted at every I/O boundary.

The Boundary Problem

Missing ensure_numpy() calls cause runtime errors or silent failures:

  • h5py raises TypeError when given a JAX array

  • PyQtGraph may render nothing, or fail without a clear error message

  • Matplotlib raises ValueError on non-NumPy input

These errors are easy to introduce during development because code works fine with the NumPy backend but fails with JAX.

I/O Adapters

To centralize and monitor boundary conversions, three adapter classes wrap ensure_numpy():

  • PyQtGraphAdapter: to_pyqtgraph(array) / from_pyqtgraph(array) for all GUI visualization

  • HDF5Adapter: to_hdf5(array) / from_hdf5(array) for file I/O

  • MatplotlibAdapter: to_matplotlib(*arrays) for static plots

Each adapter optionally tracks conversion statistics (count, average time, slowest conversion) for performance profiling. When the backend is NumPy, conversions are essentially free (the input array is already NumPy).
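A minimal sketch of one adapter with optional statistics. HDF5AdapterSketch and ConversionStats are hypothetical names, and the ensure_numpy body here is an assumed stand-in: np.asarray accepts NumPy input unchanged and also converts JAX arrays, which implement the __array__ protocol (copying from the device when needed):

```python
import time
from dataclasses import dataclass

import numpy as np


def ensure_numpy(array) -> np.ndarray:
    """Assumed stand-in for the real ensure_numpy(): convert anything array-like."""
    return np.asarray(array)


@dataclass
class ConversionStats:
    count: int = 0
    total_s: float = 0.0
    slowest_s: float = 0.0

    @property
    def average_s(self) -> float:
        return self.total_s / self.count if self.count else 0.0

    def record(self, elapsed: float) -> None:
        self.count += 1
        self.total_s += elapsed
        self.slowest_s = max(self.slowest_s, elapsed)


class HDF5AdapterSketch:
    """Illustrative adapter: the single place arrays cross into h5py."""

    def __init__(self, track_stats: bool = False):
        self.stats = ConversionStats() if track_stats else None

    def to_hdf5(self, array) -> np.ndarray:
        start = time.perf_counter()
        out = ensure_numpy(array)
        if self.stats is not None:
            self.stats.record(time.perf_counter() - start)
        return out
```

Routing every conversion through one method means a profiling session only needs to read adapter.stats instead of instrumenting each call site.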

Conversion overhead for JAX arrays:

Array size   CPU JAX     GPU JAX
----------   ---------   -------
100x100      < 0.01 ms   ~0.1 ms
1024x1024    < 0.01 ms   ~0.5 ms
4096x4096    < 0.01 ms   ~5 ms

GPU data transfer dominates the cost for GPU arrays. CPU JAX arrays share memory with NumPy and convert with near-zero overhead.

Device Management

The DeviceManager is a thread-safe singleton that controls GPU/CPU device selection:

  • Configuration: Reads XPCS_USE_GPU, XPCS_GPU_FALLBACK, and XPCS_GPU_MEMORY_FRACTION from the environment.

  • Graceful fallback: If GPU is requested but unavailable and XPCS_GPU_FALLBACK=true (the default), falls back to CPU with a warning instead of crashing.

  • Device placement: place_on_device(array) moves arrays to the configured device via jax.device_put.

  • Thread safety: Uses double-checked locking for singleton creation to avoid race conditions in multi-threaded GUI applications.
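A minimal sketch of a thread-safe singleton with double-checked locking and CPU fallback, assuming a hypothetical class DeviceManagerSketch. The environment variable names come from the text above; the default of XPCS_USE_GPU=false is an assumption. It runs with or without JAX installed:

```python
import os
import threading
import warnings


def _gpu_available() -> bool:
    """True if JAX is importable and reports at least one GPU device."""
    try:
        import jax

        return len(jax.devices("gpu")) > 0
    except Exception:  # JAX missing, or no GPU backend could be initialized
        return False


class DeviceManagerSketch:
    """Hypothetical thread-safe singleton; not the actual DeviceManager."""

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        # Fast path: once the instance exists, no lock is taken
        if cls._instance is None:
            with cls._lock:
                # Re-check: another thread may have created it while we waited
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance._configure()
                    # Publish only after configuration has succeeded
                    cls._instance = instance
        return cls._instance

    def _configure(self) -> None:
        # Assumed defaults: GPU off unless requested, fallback on (as documented)
        want_gpu = os.environ.get("XPCS_USE_GPU", "false").lower() == "true"
        fallback = os.environ.get("XPCS_GPU_FALLBACK", "true").lower() == "true"
        if want_gpu and not _gpu_available():
            if not fallback:
                raise RuntimeError("XPCS_USE_GPU=true but no GPU is available")
            warnings.warn("GPU requested but unavailable; falling back to CPU")
            want_gpu = False
        self.platform = "gpu" if want_gpu else "cpu"

    def place_on_device(self, array):
        try:
            import jax
        except ImportError:
            return array  # NumPy-only environment: nothing to place
        return jax.device_put(array, jax.devices(self.platform)[0])
```

Assigning cls._instance only after _configure() returns ensures that threads on the lock-free fast path never see a half-initialized manager.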

When to Use GPU

GPU acceleration helps most for:

  • Large detector images (>1024x1024 pixels) during Q-map computation

  • Batch fitting across many Q-bins (parallelizable across GPU cores)

  • Two-time correlation on large frame sequences

GPU does not help (and may hurt due to transfer overhead) for:

  • Small arrays (<10,000 elements)

  • Single scalar operations

  • I/O-bound workflows (time spent in HDF5 reads, not computation)
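The guidance above can be condensed into a rule of thumb. prefer_gpu is a hypothetical helper, not an XPCS Viewer API, and the 10,000-element threshold is taken from the list above rather than measured:

```python
import math

SMALL_ARRAY_ELEMENTS = 10_000  # below this, transfer overhead tends to dominate


def prefer_gpu(shape: tuple, io_bound: bool = False, gpu_available: bool = True) -> bool:
    """Hypothetical heuristic encoding the GPU guidance above."""
    if not gpu_available or io_bound:
        return False  # no device, or the bottleneck is HDF5 reads, not compute
    # A scalar has shape (), whose element count is 1, so it stays on CPU
    return math.prod(shape) >= SMALL_ARRAY_ELEMENTS
```

A 1024x1024 detector frame (about one million elements) passes this check, while a 50x50 region of interest or a scalar does not.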

Automatic Differentiation

JAX’s grad and value_and_grad transformations are essential for the Bayesian fitting pipeline. NumPyro’s NUTS sampler requires gradients of the log-posterior to perform Hamiltonian Monte Carlo sampling.

The backend exposes this capability through:

  • backend.supports_grad: Boolean capability check

  • backend.grad(fn): Returns a function that computes the gradient

  • backend.value_and_grad(fn): Returns both the function value and its gradient in a single forward/backward pass

The fitting module’s model functions (single_exp_func, double_exp_func, stretched_exp_func) are decorated with @_maybe_jit, which applies jax.jit when JAX is available. Because jax.grad and jax.value_and_grad compose with jax.jit, JAX can differentiate these compiled functions directly.
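A minimal sketch of the jit-plus-autodiff combination (requires JAX). maybe_jit is a hypothetical stand-in for _maybe_jit, and the single-exponential form g2(t) = baseline + contrast * exp(-2t / tau) is an assumed model, not necessarily the exact form of single_exp_func. Production code would first check backend.supports_grad:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64, matching the setup described above


def maybe_jit(fn):
    """Hypothetical stand-in for _maybe_jit: jit when JAX imports, else return fn."""
    try:
        import jax as _jax
    except ImportError:
        return fn
    return _jax.jit(fn)


@maybe_jit
def single_exp(t, baseline, contrast, tau):
    # Assumed model form: g2(t) = baseline + contrast * exp(-2 t / tau)
    return baseline + contrast * jnp.exp(-2.0 * t / tau)


def loss(params, t, g2_obs):
    baseline, contrast, tau = params
    residual = single_exp(t, baseline, contrast, tau) - g2_obs
    return jnp.sum(residual**2)


t = jnp.logspace(-4, 0, 50)
g2_obs = single_exp(t, 1.0, 0.3, 0.01)  # synthetic, noise-free data
params = jnp.array([1.0, 0.25, 0.02])

# One forward/backward pass yields the loss and d(loss)/d(params)
value, grads = jax.value_and_grad(loss)(params, t, g2_obs)
```

A sampler such as NUTS applies the same mechanism to the log-posterior instead of a squared-error loss.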