ADR-004: Protocol-Based Backend Abstraction Layer
Status
Accepted
Context
When introducing JAX as an optional computational backend (see ADR-001), we needed an abstraction layer that:
Defers JAX initialization: JAX configuration (float64, device selection) must happen before the first JAX operation, so lazy initialization remains important even though JAX is now a core dependency.
Provides an identical API for both backends: Analysis code should not contain if backend == "jax" conditionals.
Supports capability introspection: Code that requires gradients or JIT must be able to check backend capabilities at runtime.
Handles I/O boundary conversions: External libraries (h5py, PyQtGraph, Matplotlib) require NumPy arrays, so conversions must be centralized.
Manages device placement: GPU-enabled workflows need explicit control over which device holds the data.
We considered three approaches:
Abstract base class: Would require JAX imports at class definition time (even for the base class docstrings or type hints).
Duck typing: No enforced contract; easy to drift between backends.
Protocol (structural subtyping): Defines the contract without requiring inheritance; works with isinstance() checks via @runtime_checkable.
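To make the third option concrete, here is a minimal sketch of structural subtyping with @runtime_checkable. MiniBackend, ListBackend, and Incomplete are hypothetical names invented for illustration; the real contract is far larger.

```python
import math
from typing import Protocol, runtime_checkable


@runtime_checkable
class MiniBackend(Protocol):
    """Hypothetical two-method stand-in for the real backend contract."""

    def exp(self, x): ...

    def mean(self, x): ...


class ListBackend:
    """Satisfies MiniBackend structurally; note there is no inheritance."""

    def exp(self, x):
        return [math.exp(v) for v in x]

    def mean(self, x):
        return sum(x) / len(x)


class Incomplete:
    """Lacks mean(), so it does not satisfy the protocol."""

    def exp(self, x):
        return x


print(isinstance(ListBackend(), MiniBackend))  # True
print(isinstance(Incomplete(), MiniBackend))   # False
```

One caveat: a runtime isinstance() check against a Protocol only verifies that the members exist. It does not check signatures or return types, so a static type checker is still needed to catch drift in argument lists.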
Decision
We chose a typing.Protocol-based design with the following components:
BackendProtocol (_base.py)
A @runtime_checkable Protocol defining 60+ abstract methods across 8 categories:
| Category | Methods | Purpose |
|---|---|---|
| Properties | | Capability introspection |
| Array Creation | | Factory methods |
| Trigonometry | | Angle computations for Q-map |
| Statistics | | Data analysis |
| Binning | | Q-bin partition |
| Boolean/Masking | | Mask operations |
| Math | | Model evaluation |
| Functional | | JAX transformations |
NumPyBackend (_numpy_backend.py)
Implements all BackendProtocol methods using NumPy.
supports_gpu = False, supports_jit = False, supports_grad = False.
jit() returns the function unchanged (no-op).
grad() and value_and_grad() raise NotImplementedError.
vmap() uses a simple Python loop fallback.
scan() uses a Python for loop.
fori_loop() uses range().
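The fallbacks listed above could look roughly like the sketch below. It is written as free functions for brevity and is not the project's actual code. The signatures mirror jax.jit, jax.vmap, jax.lax.scan, and jax.lax.fori_loop, and the single-argument vmap is a simplifying assumption.

```python
import numpy as np


def jit(fn, static_argnums=None):
    """No-op: NumPy has no compiler, so return fn unchanged."""
    return fn


def grad(fn):
    """NumPy has no autodiff; fail loudly instead of returning garbage."""
    raise NotImplementedError("grad() requires the JAX backend")


def vmap(fn, in_axes=0):
    """Apply fn to each slice along in_axes with a Python loop, then stack."""
    def mapped(xs):
        return np.stack([fn(x) for x in np.moveaxis(xs, in_axes, 0)])
    return mapped


def scan(f, init, xs):
    """Same contract as jax.lax.scan: f(carry, x) -> (carry, y)."""
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)


def fori_loop(lower, upper, body_fun, init_val):
    """Same contract as jax.lax.fori_loop: body_fun(i, val) -> val."""
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


# Running sum expressed as a scan: the carry is the total so far.
total, running = scan(lambda c, x: (c + x, c + x), 0.0, np.arange(4.0))
# total is 6.0 and running is [0., 1., 3., 6.]
```

These loops preserve semantics, not performance: each iteration runs in the interpreter, so code that leans heavily on vmap or scan will be much slower on the NumPy backend.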
JAXBackend (_jax_backend.py)
Implements all BackendProtocol methods using jax.numpy.
supports_gpu = True (if a GPU device is available), supports_jit = True, supports_grad = True.
jit() delegates to jax.jit with optional static_argnums.
grad() and value_and_grad() delegate to JAX’s automatic differentiation.
vmap() delegates to jax.vmap.
scan() delegates to jax.lax.scan.
fori_loop() delegates to jax.lax.fori_loop.
JAX is imported lazily via module-level variables (_jax, _jnp) and an _ensure_jax() function, avoiding import-time failures.
DeviceManager (_device.py)
A thread-safe singleton for device lifecycle management:
from __future__ import annotations  # lets the class refer to itself in annotations

import threading


class DeviceManager:
    _instance: DeviceManager | None = None
    _lock = threading.RLock()

    def __new__(cls) -> DeviceManager:
        # Double-checked locking: take the lock only while no instance exists
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance._initialized = False
                    cls._instance = instance
        return cls._instance
Key features:
DeviceConfig dataclass: Configurable via DeviceConfig.from_environment(), which reads XPCS_USE_GPU, XPCS_GPU_FALLBACK, and XPCS_GPU_MEMORY_FRACTION.
DeviceInfo dataclass: Stores device type, ID, name, and memory info. Validates non-negative fields (BUG-044).
Graceful GPU fallback: If a GPU is requested but unavailable and allow_gpu_fallback=True, falls back to CPU with a warning.
place_on_device(array): Moves arrays to the configured device via jax.device_put.
reset() classmethod: Destroys the singleton for testing.
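The environment-driven configuration and the fallback rule can be sketched as follows. Only the three environment variable names, from_environment(), and allow_gpu_fallback appear in the text above. The other field names, the defaults, the accepted truthy strings, and the resolve_device() helper are illustrative assumptions.

```python
import os
import warnings
from dataclasses import dataclass


def _env_flag(name: str, default: bool) -> bool:
    """Treat 1/true/yes/on (case-insensitive) as True; unset means default."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


@dataclass(frozen=True)
class DeviceConfig:
    # Field names other than allow_gpu_fallback, and all defaults, are assumed.
    use_gpu: bool = False
    allow_gpu_fallback: bool = True
    gpu_memory_fraction: float = 0.9

    def __post_init__(self):
        if not 0.0 < self.gpu_memory_fraction <= 1.0:
            raise ValueError("gpu_memory_fraction must be in (0, 1]")

    @classmethod
    def from_environment(cls) -> "DeviceConfig":
        return cls(
            use_gpu=_env_flag("XPCS_USE_GPU", False),
            allow_gpu_fallback=_env_flag("XPCS_GPU_FALLBACK", True),
            gpu_memory_fraction=float(
                os.environ.get("XPCS_GPU_MEMORY_FRACTION", "0.9")
            ),
        )


def resolve_device(config: DeviceConfig, gpu_available: bool) -> str:
    """Hypothetical helper: pick 'gpu' or 'cpu', honoring the fallback flag."""
    if config.use_gpu and not gpu_available:
        if not config.allow_gpu_fallback:
            raise RuntimeError("GPU requested but not available")
        warnings.warn("GPU requested but not available; falling back to CPU")
        return "cpu"
    return "gpu" if config.use_gpu else "cpu"
```

Keeping the fallback decision in one function means the warning is emitted from a single place, which makes it easy to assert on in tests.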
Conversion Utilities (_conversions.py)
ensure_numpy(array) # Any array -> writable np.ndarray
ensure_backend_array(array) # Any array -> current backend's array type
is_jax_array(array) # Type check
is_numpy_array(array) # Type check
get_array_backend(array) # Returns "numpy", "jax", or "unknown"
arrays_compatible(a, b) # Same backend check
ensure_numpy() handles:
NumPy arrays (fast path; copies if read-only).
JAX arrays (copies to CPU via np.array()).
Objects with an __array__ method.
Lists and scalars (via np.array()).
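The four cases above reduce to two code paths, as in this simplified sketch. It is a reconstruction from the description, not the project's implementation, and it relies on np.array() copying by default so that JAX arrays, __array__ objects, lists, and scalars all yield a fresh writable array.

```python
import numpy as np


def ensure_numpy(array) -> np.ndarray:
    """Return a writable np.ndarray for any array-like input."""
    if isinstance(array, np.ndarray):
        # Fast path: hand back the same object unless it is read-only.
        return array if array.flags.writeable else array.copy()
    # JAX arrays, objects exposing __array__, lists, and scalars all go
    # through np.array(), which copies the data into a new host array.
    return np.array(array)


frozen = np.arange(3)
frozen.setflags(write=False)
out = ensure_numpy(frozen)
out[0] = 99  # safe: out is a copy, frozen is untouched
```

The read-only check matters for consumers that mutate their input in place, since some sources hand out arrays with the writeable flag cleared.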
I/O Adapters (io_adapter.py)
Three adapter classes wrap ensure_numpy() with domain-specific semantics and optional performance monitoring:
| Adapter | Methods | Use Case |
|---|---|---|
| | | GUI visualization |
| | | File I/O |
| | | Static plots |
Each adapter tracks conversion statistics (get_stats(), reset_stats()) and logs slow conversions (>10ms) at DEBUG level.
The create_adapters() factory function creates all three adapters from a single backend instance.
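The statistics and slow-conversion logging could be structured along these lines. TimedAdapter, to_numpy(), and the stats dictionary keys are hypothetical. Only get_stats(), reset_stats(), the 10 ms threshold, and the DEBUG level are taken from the description above.

```python
import logging
import time

import numpy as np

logger = logging.getLogger(__name__)


class TimedAdapter:
    """Hypothetical adapter: converts to NumPy and records timing."""

    SLOW_THRESHOLD_S = 0.010  # 10 ms, the threshold quoted above

    def __init__(self, name, convert=np.asarray):
        self._name = name
        self._convert = convert  # the real adapters wrap ensure_numpy()
        self.reset_stats()

    def to_numpy(self, array):
        start = time.perf_counter()
        result = self._convert(array)
        elapsed = time.perf_counter() - start
        self._count += 1
        self._total_s += elapsed
        if elapsed > self.SLOW_THRESHOLD_S:
            logger.debug(
                "%s: slow conversion (%.1f ms)", self._name, elapsed * 1e3
            )
        return result

    def get_stats(self):
        return {"conversions": self._count, "total_seconds": self._total_s}

    def reset_stats(self):
        self._count = 0
        self._total_s = 0.0
```

Using time.perf_counter() rather than time.time() gives a monotonic, high-resolution clock, which is what short conversions in the millisecond range need.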
Usage Pattern
import pyqtgraph as pg

from xpcsviewer.backends import get_backend, ensure_numpy

backend = get_backend()  # Auto-detects JAX or NumPy

# Computation (backend-agnostic)
x = backend.linspace(0, 10, 100)
y = backend.exp(-x)

# I/O boundary (plot is assumed to be an existing pyqtgraph plot item)
plot.setData(ensure_numpy(x), ensure_numpy(y))
Consequences
What became easier
Backend-agnostic code: Analysis modules call backend.method() and work identically on JAX and NumPy.
Capability checks: if backend.supports_grad: result = backend.value_and_grad(fn)(x) enables graceful degradation.
Testing: set_backend("numpy") forces NumPy for deterministic tests. reset_backend() restores auto-detection.
I/O monitoring: Adapters track conversion counts and timing, enabling performance profiling of I/O boundaries.
Device management: A single DeviceManager instance coordinates all GPU/CPU decisions.
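As an illustration of the capability check, the sketch below degrades from automatic differentiation to a central finite difference when supports_grad is false. NoGradBackend and value_and_derivative() are hypothetical, the finite-difference fallback is one possible strategy rather than what the project does, and the example handles scalar inputs only.

```python
class NoGradBackend:
    """Hypothetical stand-in for a backend without autodiff."""

    supports_grad = False

    def value_and_grad(self, fn):
        raise NotImplementedError("gradients need the JAX backend")


def value_and_derivative(backend, fn, x, eps=1e-6):
    """Use autodiff when available, else a central finite difference."""
    if backend.supports_grad:
        return backend.value_and_grad(fn)(x)
    derivative = (fn(x + eps) - fn(x - eps)) / (2 * eps)
    return fn(x), derivative


value, slope = value_and_derivative(NoGradBackend(), lambda x: x * x, 3.0)
# value is 9.0 and slope is approximately 6.0
```

The calling code never names a backend. It asks only whether gradients are supported, which is the property that keeps analysis modules free of backend-specific conditionals.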
What became more difficult
Protocol maintenance: Adding a new mathematical operation requires updating BackendProtocol, NumPyBackend, and JAXBackend simultaneously.
No polymorphic dispatch on arrays: Unlike PyTorch’s torch.Tensor methods, backend arrays do not carry their backend. Operations like x.mean() are not available; callers must use backend.mean(x).
JAX semantic constraints: JIT-compiled functions must be functionally pure. Mutable state (e.g., counters, caches) must be refactored into explicit state passing or use JAX’s scan/fori_loop.
Fan-in coupling: 8+ modules depend on get_backend() and 9+ depend on ensure_numpy(). Changes to these functions cascade widely.