Performance Tuning
This guide covers performance tuning for XPCS Viewer: JIT compilation, GPU acceleration, HDF5 connection pooling, array conversion at I/O boundaries, fitting, and a profiling workflow.
JIT Compilation (JAX Backend)
When the JAX backend is active (XPCS_USE_JAX=true or auto-detected),
functions decorated with @jax.jit or called via backend.jit() are
compiled to XLA on first invocation. Subsequent calls use the cached compiled
version.
Typical speedups:
| Operation | First Call | Subsequent Calls | Speedup |
|---|---|---|---|
| Q-map computation (1024x1024) | ~200 ms | ~5 ms | 40x |
| Partition generation | ~100 ms | ~3 ms | 33x |
| Model function evaluation | ~50 ms | ~0.5 ms | 100x |
Best practices for JIT:
Avoid Python control flow that depends on array values. JAX traces functions abstractly. Use jax.lax.cond instead of if/else on traced values.
Use static_argnums for non-array arguments. When a function takes integer or string arguments that affect control flow, mark them as static:
@jax.jit
def compute(data, mode):  # 'mode' changes code path
    ...
# Better:
@functools.partial(jax.jit, static_argnums=(1,))
def compute(data, mode):
    ...
Warm up JIT caches at startup. For interactive workflows, call JIT-compiled functions once with representative data during initialization to avoid compilation latency on the first user interaction.
Monitor compilation via logging. Set PYXPCS_LOG_LEVEL=DEBUG to see @log_timing output for JIT-compiled functions. First-call times that are 10-100x longer than subsequent calls indicate JIT compilation overhead.
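The practices above can be sketched in plain JAX. This is a minimal illustration, not XPCS Viewer code: the function names clip_or_scale and reduce_data are hypothetical, and the timing at the end only shows how a first (compiling) call can be compared with a cached one.

```python
import functools
import time

import jax
import jax.numpy as jnp


@jax.jit
def clip_or_scale(x):
    # Branch on a traced value with lax.cond rather than a Python if/else.
    # Both branches must return the same shape and dtype.
    return jax.lax.cond(
        jnp.mean(x) > 0.5,
        lambda v: jnp.clip(v, 0.0, 1.0),
        lambda v: v * 2.0,
        x,
    )


@functools.partial(jax.jit, static_argnums=(1,))
def reduce_data(x, mode):
    # 'mode' is static, so an ordinary Python branch is safe here.
    # Each distinct value of 'mode' triggers its own compilation.
    if mode == "sum":
        return jnp.sum(x)
    return jnp.mean(x)


# Warm-up: the first call compiles, the second reuses the cached version.
data = jnp.ones((256, 256))

start = time.perf_counter()
clip_or_scale(data).block_until_ready()
first_call_s = time.perf_counter() - start

start = time.perf_counter()
clip_or_scale(data).block_until_ready()
second_call_s = time.perf_counter() - start

print(f"first call: {first_call_s * 1e3:.2f} ms, "
      f"second call: {second_call_s * 1e3:.2f} ms")
```

The block_until_ready() calls matter for timing because JAX dispatches work asynchronously; without them the measurement may stop before the computation has finished.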
GPU Acceleration
Enable GPU computation for large datasets:
export XPCS_USE_JAX=true
export XPCS_USE_GPU=true
When GPU helps:
Large detector images (>1024x1024 pixels)
Batch fitting across many Q-bins
Two-time correlation computation on large frame sequences
When GPU does not help (or hurts):
Small arrays (<10,000 elements): Host-to-device transfer overhead dominates
Single-array operations: Not enough parallelism to saturate the GPU
I/O-bound workflows: Time is spent in HDF5 reads, not computation
Memory management:
Control GPU memory allocation via XPCS_GPU_MEMORY_FRACTION:
# Use only 50% of GPU memory (share with other processes)
export XPCS_GPU_MEMORY_FRACTION=0.5
The DeviceManager reports GPU status:
from xpcsviewer.backends import DeviceManager
dm = DeviceManager()
print(f"GPU available: {dm.gpu_available}")
print(f"GPU enabled: {dm.is_gpu_enabled}")
print(f"Current device: {dm.current_device}")
print(f"All devices: {dm.available_devices}")
Explicit device placement:
import numpy as np
from xpcsviewer.backends import DeviceManager
dm = DeviceManager()
numpy_array = np.zeros((1024, 1024))  # Example input array
gpu_array = dm.place_on_device(numpy_array)  # Moves to configured device
HDF5 Connection Pooling
The HDF5ConnectionPool in fileIO/hdf_reader.py caches open file
handles to avoid repeated open/close cycles during interactive analysis.
Monitoring pool performance:
from xpcsviewer.io import HDF5Facade
facade = HDF5Facade()
stats = facade.get_pool_stats()
# {
# "pool_size": 3,
# "cache_hits": 150,
# "cache_misses": 5,
# "cache_hit_ratio": 0.968,
# }
Target metrics:
Cache hit ratio: > 95% in typical interactive use
Average read latency: < 10 ms for cached connections
Pool size: 3–5 concurrent files for normal workflows
Clearing the pool:
Release all pooled connections before application shutdown or when switching to a different dataset directory:
facade.clear_pool()
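To make the pooling idea concrete, here is a minimal sketch of a least-recently-used handle pool. It is an assumption-laden illustration, not the HDF5ConnectionPool implementation: the class HandlePoolSketch and its methods are hypothetical, ordinary binary files stand in for HDF5 files, and the hit ratio is assumed to be hits / (hits + misses). That definition is at least consistent with the sample output above, where 150 / (150 + 5) ≈ 0.968.

```python
import os
import tempfile
from collections import OrderedDict


class HandlePoolSketch:
    """Tiny LRU pool of open file handles (illustrative only)."""

    def __init__(self, max_size=3):
        self.max_size = max_size
        self._handles = OrderedDict()  # path -> open handle, oldest first
        self.hits = 0
        self.misses = 0

    def get(self, path):
        if path in self._handles:
            self.hits += 1
            self._handles.move_to_end(path)  # mark as most recently used
            return self._handles[path]
        self.misses += 1
        if len(self._handles) >= self.max_size:
            _, oldest = self._handles.popitem(last=False)  # evict LRU entry
            oldest.close()
        handle = open(path, "rb")
        self._handles[path] = handle
        return handle

    def stats(self):
        total = self.hits + self.misses
        return {
            "pool_size": len(self._handles),
            "cache_hits": self.hits,
            "cache_misses": self.misses,
            "cache_hit_ratio": self.hits / total if total else 0.0,
        }

    def clear(self):
        # Close every pooled handle, e.g. before shutdown.
        while self._handles:
            _, handle = self._handles.popitem()
            handle.close()


with tempfile.TemporaryDirectory() as tmp_dir:
    paths = []
    for i in range(2):
        path = os.path.join(tmp_dir, f"scan_{i}.bin")
        with open(path, "wb") as f:
            f.write(b"data")
        paths.append(path)

    pool = HandlePoolSketch(max_size=3)
    for _ in range(10):          # revisit the same two files repeatedly
        for path in paths:
            pool.get(path)
    pool_stats = pool.stats()
    pool.clear()                 # release handles before the files are removed

print(pool_stats)
```

Only the first access to each file is a miss, so revisiting a small working set of files keeps the hit ratio high. Access patterns that cycle through more files than the pool can hold evict handles before they are reused.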
I/O Boundary Optimization
Array conversion at I/O boundaries (ensure_numpy()) is the most frequent
performance-sensitive operation. The I/O adapters provide monitoring to
identify bottlenecks.
Enable monitoring:
from xpcsviewer.backends import get_backend, create_adapters
backend = get_backend()
pyqt_adapter, hdf5_adapter, mpl_adapter = create_adapters(
    backend,
    enable_monitoring=True,
)
# ... use adapters in analysis workflow ...
# Check statistics
stats = pyqt_adapter.get_stats()
print(f"Conversions: {stats['conversion_count']}")
print(f"Avg time: {stats['average_conversion_time_ms']:.3f} ms")
Typical conversion overhead:
| Array Size | CPU (NumPy -> NumPy) | GPU (JAX -> NumPy) |
|---|---|---|
| 100x100 (10K elements) | < 0.01 ms | ~0.1 ms |
| 1024x1024 (1M elements) | < 0.01 ms | ~0.5 ms |
| 4096x4096 (16M elements) | < 0.01 ms | ~5 ms |
When the backend is NumPy, ensure_numpy() is essentially free (returns the
input array if it is already a writable NumPy array). When the backend is JAX,
data must be copied from the device (GPU or CPU-backed JAX arrays) to a
standard NumPy array.
Minimize unnecessary conversions:
Convert at the I/O boundary, not inside computation loops.
Avoid converting the same array multiple times; store the NumPy result.
For small arrays used only for display, the overhead is negligible.
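The first two points can be illustrated with a small counting experiment. This is a sketch under stated assumptions: ensure_numpy_sketch is a hypothetical stand-in for ensure_numpy() (it just wraps np.asarray and counts calls), and a nested Python list stands in for a non-NumPy backend array that has to be copied on conversion.

```python
import numpy as np

calls = {"n": 0}  # counts conversions so the two patterns can be compared


def ensure_numpy_sketch(arr):
    """Hypothetical stand-in for ensure_numpy(); not the library function."""
    calls["n"] += 1
    return np.asarray(arr)  # returns the same object if arr is already an ndarray


# A nested list stands in for a backend array that must be copied to NumPy.
backend_frames = [[float(i + j) for j in range(4)] for i in range(3)]

# Conversion inside the loop: one copy per iteration.
calls["n"] = 0
total_in_loop = 0.0
for row in range(3):
    total_in_loop += float(ensure_numpy_sketch(backend_frames)[row].sum())
in_loop_conversions = calls["n"]

# Conversion hoisted to the boundary: one copy, reused afterwards.
calls["n"] = 0
host_frames = ensure_numpy_sketch(backend_frames)
total_hoisted = 0.0
for row in range(3):
    total_hoisted += float(host_frames[row].sum())
hoisted_conversions = calls["n"]

print(in_loop_conversions, hoisted_conversions)
```

Both loops compute the same total, but the hoisted version converts once instead of once per iteration. Converting an array that is already a NumPy array returns the same object, which is why the NumPy-to-NumPy case is essentially free.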
Fitting Performance
NLSQ warm-start:
The NLSQ solver provides fast point estimates to initialize the Bayesian sampler. Use presets to control the speed/robustness tradeoff:
| Preset | Typical Time | Use Case |
|---|---|---|
|  | 1–5 ms | Quick preview, known-good data |
|  | 5–50 ms | Default for production use |
|  | 50–500 ms | Multi-modal or difficult optimization landscapes |
|  | varies | Large parameter spaces (>10 parameters) |
NumPyro NUTS tuning:
Adjust SamplerConfig for the speed/accuracy tradeoff:
from xpcsviewer.fitting.results import SamplerConfig
# Fast exploration (fewer samples, fewer chains)
fast_config = SamplerConfig(
    num_warmup=200,
    num_samples=500,
    num_chains=2,
)
# Production quality
production_config = SamplerConfig(
    num_warmup=500,
    num_samples=1000,
    num_chains=4,
    target_accept_prob=0.8,
    random_seed=42,  # Reproducibility
)
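As a rough way to compare the two configurations, the sampler's work can be approximated by the number of iterations it runs. The helper below is a hypothetical back-of-the-envelope sketch, not part of XPCS Viewer. It ignores that NUTS iterations vary in cost and that chains may run in parallel, so treat the numbers as a proxy rather than a timing prediction.

```python
def nuts_iterations(num_warmup, num_samples, num_chains):
    """Return (total iterations, retained posterior draws) for a sampler setup."""
    total = (num_warmup + num_samples) * num_chains
    retained = num_samples * num_chains
    return total, retained


fast_total, fast_draws = nuts_iterations(200, 500, 2)
prod_total, prod_draws = nuts_iterations(500, 1000, 4)

print(f"fast:       {fast_total} iterations, {fast_draws} retained draws")
print(f"production: {prod_total} iterations, {prod_draws} retained draws")
print(f"production / fast iterations: {prod_total / fast_total:.2f}")
```

Under this proxy the production configuration runs 6000 iterations against 1400 for the fast one, roughly 4.3 times as many, and retains 4000 draws instead of 1000.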
Batch fitting tips:
Fit multiple Q-bins in parallel using the threading manager.
Check FitDiagnostics.converged to skip diagnostic plots for well-converged fits.
Use nlsq_optimize() alone (without NUTS) for exploratory analysis where full Bayesian inference is not needed.
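The parallel-fitting tip can be sketched with the standard library. Everything here is a stand-in: concurrent.futures.ThreadPoolExecutor replaces the application's threading manager, and fit_decay_rate is a hypothetical closed-form fit of a single exponential decay, used only so the example runs without the fitting module.

```python
import math
from concurrent.futures import ThreadPoolExecutor


def fit_decay_rate(times, values):
    """Least-squares rate for values ~ exp(-rate * t).

    Fits log(values) = -rate * t with a line through the origin.
    """
    numerator = sum(t * math.log(v) for t, v in zip(times, values))
    denominator = sum(t * t for t in times)
    return -numerator / denominator


# Synthetic decay curves, one per hypothetical Q-bin.
times = [0.1 * k for k in range(1, 21)]
true_rates = [0.5, 1.0, 2.0, 4.0]
curves = [[math.exp(-rate * t) for t in times] for rate in true_rates]

# Submit one fit per Q-bin; map() returns results in input order.
with ThreadPoolExecutor(max_workers=4) as executor:
    fitted_rates = list(executor.map(lambda c: fit_decay_rate(times, c), curves))

for true_rate, fitted in zip(true_rates, fitted_rates):
    print(f"true rate {true_rate:.2f} -> fitted {fitted:.6f}")
```

Threads give a real speedup mainly when each fit spends its time in code that releases the Python GIL, such as compiled numerical kernels. A pure-Python fit like this stand-in shows the structure but would not run faster in threads.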
Profiling Workflow
Identify the bottleneck using @log_timing output:
export PYXPCS_LOG_LEVEL=DEBUG
python -m xpcsviewer
Look for methods that exceed their threshold_ms (logged at WARNING).
Check I/O adapter stats for excessive conversions:
stats = pyqt_adapter.get_stats()
if stats['average_conversion_time_ms'] > 1.0:
    print("Conversion overhead detected")
Check HDF5 pool stats for cache misses:
pool_stats = facade.get_pool_stats()
if pool_stats['cache_hit_ratio'] < 0.9:
    print("Poor cache hit ratio -- check file access patterns")
Check JIT compilation by comparing first-call vs. subsequent-call times in DEBUG logs. First calls that are >100x slower indicate JIT overhead that may benefit from warm-up.
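For readers unfamiliar with threshold-based timing logs, the following sketch shows one way such a decorator could work. It is not the implementation of @log_timing: the name log_timing_sketch, its signature, and the message format are assumptions made for illustration, based only on the behavior described above (timings at DEBUG, threshold overruns at WARNING).

```python
import functools
import logging
import time

logger = logging.getLogger("timing_sketch")


def log_timing_sketch(threshold_ms=100.0):
    """Log a call's duration; escalate to WARNING above threshold_ms."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                elapsed_ms = (time.perf_counter() - start) * 1000.0
                level = (
                    logging.WARNING if elapsed_ms > threshold_ms else logging.DEBUG
                )
                logger.log(level, "%s took %.3f ms", func.__name__, elapsed_ms)

        return wrapper

    return decorator


@log_timing_sketch(threshold_ms=1.0)
def slow_step():
    time.sleep(0.01)  # about 10 ms, well above the 1 ms threshold
    return "done"


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    slow_step()
```

With a scheme like this, filtering the log for WARNING entries from the timing logger gives a quick list of the calls that exceeded their thresholds.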