Performance Tuning

This guide covers performance optimization techniques for XPCS Viewer, including JIT compilation, GPU acceleration, and connection pooling.

JIT Compilation (JAX Backend)

When the JAX backend is active (XPCS_USE_JAX=true or auto-detected), functions decorated with @jax.jit or called via backend.jit() are compiled to XLA on first invocation. Subsequent calls with the same input shapes and dtypes reuse the cached compiled version. Calling with a new shape or dtype triggers a fresh compilation.
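The caching rule (compile once per distinct input shape and dtype, then reuse) can be illustrated with a toy, pure-Python model. This is an analogy, not JAX internals: `ToyArray`, `ToyJit`, and `compile_count` are hypothetical names, and no real compilation happens.

```python
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass(frozen=True)
class ToyArray:
    """Hypothetical stand-in for an array: only shape and dtype matter to the cache."""
    shape: tuple[int, ...]
    dtype: str


@dataclass
class ToyJit:
    """Toy model of a JIT wrapper that 'compiles' once per (shape, dtype) signature."""
    fn: Callable[[ToyArray], Any]
    cache: dict[tuple[tuple[int, ...], str], Callable[[ToyArray], Any]] = field(
        default_factory=dict
    )
    compile_count: int = 0

    def __call__(self, x: ToyArray) -> Any:
        key = (x.shape, x.dtype)
        if key not in self.cache:
            # First call with this signature: pay the (simulated) compile cost.
            self.compile_count += 1
            self.cache[key] = self.fn
        # Later calls with the same signature reuse the cached entry.
        return self.cache[key](x)


def qmap_like(x: ToyArray) -> int:
    # Placeholder computation: number of pixels in the image.
    n = 1
    for dim in x.shape:
        n *= dim
    return n


jitted = ToyJit(qmap_like)
jitted(ToyArray((1024, 1024), "float32"))  # compiles
jitted(ToyArray((1024, 1024), "float32"))  # cache hit
jitted(ToyArray((2048, 2048), "float32"))  # new shape, so it recompiles
print(jitted.compile_count)  # 2
```

In real JAX code the same effect shows up as a slow first call for each new detector size or dtype.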

Typical speedups:

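==============================  ==========  ================  =======
Operation                       First Call  Subsequent Calls  Speedup
==============================  ==========  ================  =======
Q-map computation (1024x1024)   ~200 ms     ~5 ms             40x
Partition generation            ~100 ms     ~3 ms             33x
Model function evaluation       ~50 ms      ~0.5 ms           100x
==============================  ==========  ================  =======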

Best practices for JIT:

  1. Avoid Python control flow that depends on array values. JAX traces functions abstractly. Use jax.lax.cond instead of if/else on traced values.

  2. Use ``static_argnums`` for non-array arguments. When a function takes integer or string arguments that affect control flow, mark them as static. Static arguments must be hashable, and each distinct value triggers a separate compilation:

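    import functools

    import jax

    # Problematic: 'mode' is traced (or rejected outright if it is a string),
    # so Python control flow such as `if mode == ...` fails at trace time.
    @jax.jit
    def compute(data, mode):
        ...

    # Better: 'mode' is static, so JAX compiles one version per distinct value.
    @functools.partial(jax.jit, static_argnums=(1,))
    def compute(data, mode):
        ...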
    
  3. Warm up JIT caches at startup. For interactive workflows, call JIT-compiled functions once with representative data during initialization to avoid compilation latency on the first user interaction.

  4. Monitor compilation via logging. Set PYXPCS_LOG_LEVEL=DEBUG to see @log_timing output for JIT-compiled functions. First-call times that are 10-100x longer than subsequent calls indicate JIT compilation overhead.
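Practices 3 and 4 can be combined in a small start-up routine that calls a function once (absorbing compilation) and then measures steady-state latency. This is a minimal sketch using only the standard library: `warm_up` and `fake_jitted` are hypothetical names, and `fake_jitted` simulates compilation cost by sleeping. For a real JAX function, pass a wrapper such as `lambda x: f(x).block_until_ready()`, because JAX dispatches work asynchronously and an unblocked call would time only the dispatch.

```python
import statistics
import time
from typing import Any, Callable


def warm_up(fn: Callable[..., Any], *args: Any, repeats: int = 5) -> tuple[float, float]:
    """Call fn once to trigger compilation, then time steady-state calls.

    Returns (first_call_ms, median_subsequent_ms).
    """
    start = time.perf_counter()
    fn(*args)  # first call: includes any JIT compilation
    first_ms = (time.perf_counter() - start) * 1e3

    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        samples.append((time.perf_counter() - start) * 1e3)
    return first_ms, statistics.median(samples)


# Hypothetical stand-in for a JIT-compiled function: the first call "compiles".
_compiled: set[str] = set()


def fake_jitted(key: str) -> str:
    if key not in _compiled:
        time.sleep(0.05)  # simulated 50 ms compilation
        _compiled.add(key)
    return key


first_ms, steady_ms = warm_up(fake_jitted, "qmap-1024")
print(f"first call {first_ms:.1f} ms, steady state {steady_ms:.3f} ms")
```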

GPU Acceleration

Enable GPU computation for large datasets:

export XPCS_USE_JAX=true
export XPCS_USE_GPU=true

When GPU helps:

  • Large detector images (>1024x1024 pixels)

  • Batch fitting across many Q-bins

  • Two-time correlation computation on large frame sequences

When GPU does not help (or hurts):

  • Small arrays (<10,000 elements): Host-to-device transfer overhead dominates

  • Single-array operations: Not enough parallelism to saturate the GPU

  • I/O-bound workflows: Time is spent in HDF5 reads, not computation
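These rules of thumb can be encoded as a simple dispatch check. This is a minimal sketch: `should_use_gpu` and its parameters are hypothetical (not part of the xpcsviewer API), and the 10,000-element cutoff is the rough figure from the list above, which is worth re-measuring on your own hardware.

```python
import math


def should_use_gpu(
    shape: tuple[int, ...],
    gpu_available: bool,
    batch_size: int = 1,
    io_bound: bool = False,
    min_elements: int = 10_000,
) -> bool:
    """Heuristic: route work to the GPU only when it is large enough to pay off."""
    if not gpu_available or io_bound:
        # No GPU, or time is dominated by HDF5 reads: device placement won't help.
        return False
    total = math.prod(shape) * batch_size
    # Small workloads: host-to-device transfer overhead dominates.
    return total >= min_elements


print(should_use_gpu((1024, 1024), gpu_available=True))  # True
print(should_use_gpu((50, 50), gpu_available=True))      # False (2,500 elements)
```

Counting the batch dimension lets many small Q-bin fits qualify together even when each individual array is below the cutoff.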

Memory management:

Control GPU memory allocation via XPCS_GPU_MEMORY_FRACTION:

# Use only 50% of GPU memory (share with other processes)
export XPCS_GPU_MEMORY_FRACTION=0.5
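To get a feel for what a given fraction buys, the arithmetic below estimates how many detector frames fit in the budget. It is a back-of-the-envelope sketch under assumptions: the 16 GiB card, float32 frames, and the `frames_that_fit` helper are hypothetical, and real workloads also need room for intermediate buffers.

```python
def frames_that_fit(
    total_gpu_bytes: int,
    memory_fraction: float,
    frame_shape: tuple[int, int] = (1024, 1024),
    bytes_per_element: int = 4,  # float32
) -> int:
    """Estimate how many frames fit in the GPU memory budget (ignores temporaries)."""
    if not 0.0 < memory_fraction <= 1.0:
        raise ValueError("memory_fraction must be in (0, 1]")
    budget = int(total_gpu_bytes * memory_fraction)
    frame_bytes = frame_shape[0] * frame_shape[1] * bytes_per_element
    return budget // frame_bytes


GIB = 1024**3
# 16 GiB * 0.5 = 8 GiB; each 1024x1024 float32 frame is 4 MiB.
print(frames_that_fit(16 * GIB, 0.5))  # 2048
```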

The DeviceManager reports GPU status:

from xpcsviewer.backends import DeviceManager

dm = DeviceManager()
print(f"GPU available: {dm.gpu_available}")
print(f"GPU enabled: {dm.is_gpu_enabled}")
print(f"Current device: {dm.current_device}")
print(f"All devices: {dm.available_devices}")

Explicit device placement:

import numpy as np
from xpcsviewer.backends import DeviceManager

dm = DeviceManager()
numpy_array = np.zeros((1024, 1024), dtype=np.float32)  # example input
gpu_array = dm.place_on_device(numpy_array)  # Moves to the configured device

HDF5 Connection Pooling

The HDF5ConnectionPool in fileIO/hdf_reader.py caches open file handles to avoid repeated open/close cycles during interactive analysis.
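The general idea of a handle pool can be sketched as an LRU cache of open file objects. This is not the HDF5ConnectionPool implementation: the class name `HandlePool`, its `opener` argument, and the least-recently-used eviction policy are assumptions for illustration, and the statistics keys are only modeled on the ones `get_pool_stats()` reports.

```python
from collections import OrderedDict
from typing import Any, Callable


class HandlePool:
    """Minimal LRU pool of open file handles (illustrative sketch)."""

    def __init__(self, opener: Callable[[str], Any], max_size: int = 5) -> None:
        self._opener = opener  # called on a cache miss to open a file
        self._max_size = max_size
        self._handles: "OrderedDict[str, Any]" = OrderedDict()
        self.hits = 0
        self.misses = 0

    def get(self, path: str) -> Any:
        if path in self._handles:
            self.hits += 1
            self._handles.move_to_end(path)  # mark as most recently used
            return self._handles[path]
        self.misses += 1
        handle = self._opener(path)
        self._handles[path] = handle
        if len(self._handles) > self._max_size:
            _, oldest = self._handles.popitem(last=False)  # evict least recently used
            oldest.close()
        return handle

    def stats(self) -> dict:
        total = self.hits + self.misses
        return {
            "pool_size": len(self._handles),
            "cache_hits": self.hits,
            "cache_misses": self.misses,
            "cache_hit_ratio": self.hits / total if total else 0.0,
        }

    def clear(self) -> None:
        # Close every pooled handle, e.g. before shutdown or a directory switch.
        for handle in self._handles.values():
            handle.close()
        self._handles.clear()


class FakeFile:
    """Stand-in for an HDF5 file object; only close() is required."""

    def __init__(self, path: str) -> None:
        self.path = path
        self.closed = False

    def close(self) -> None:
        self.closed = True


pool = HandlePool(FakeFile, max_size=3)
for p in ["a.h5", "b.h5", "a.h5", "a.h5"]:
    pool.get(p)
print(pool.stats())
```

With real files, `opener` could be `lambda p: h5py.File(p, "r")`; in the sketch any object with a `close()` method works.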

Monitoring pool performance:

from xpcsviewer.io import HDF5Facade

facade = HDF5Facade()
stats = facade.get_pool_stats()
# {
#     "pool_size": 3,
#     "cache_hits": 150,
#     "cache_misses": 5,
#     "cache_hit_ratio": 0.968,
# }

Target metrics:

  • Cache hit ratio: > 95% in typical interactive use

  • Average read latency: < 10 ms for cached connections

  • Pool size: 3–5 concurrent files for normal workflows

Clearing the pool:

Release all pooled connections before application shutdown or when switching to a different dataset directory:

facade.clear_pool()

I/O Boundary Optimization

Array conversion at I/O boundaries (ensure_numpy()) is the most frequent performance-sensitive operation. The I/O adapters provide monitoring to identify bottlenecks.

Enable monitoring:

from xpcsviewer.backends import get_backend, create_adapters

backend = get_backend()
pyqt_adapter, hdf5_adapter, mpl_adapter = create_adapters(
    backend,
    enable_monitoring=True,
)

# ... use adapters in analysis workflow ...

# Check statistics
stats = pyqt_adapter.get_stats()
print(f"Conversions: {stats['conversion_count']}")
print(f"Avg time: {stats['average_conversion_time_ms']:.3f} ms")
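To see what such monitoring amounts to, here is a minimal stand-alone sketch that wraps any conversion function and records the two statistics used above. `MonitoredConverter` is a hypothetical name, not the adapter implementation; in practice the wrapped function would be something like `ensure_numpy`, while the demo uses `list` so it runs anywhere.

```python
import time
from typing import Any, Callable


class MonitoredConverter:
    """Wrap a conversion function and record call count and average latency."""

    def __init__(self, convert: Callable[[Any], Any]) -> None:
        self._convert = convert
        self._count = 0
        self._total_ms = 0.0

    def __call__(self, value: Any) -> Any:
        start = time.perf_counter()
        result = self._convert(value)
        self._total_ms += (time.perf_counter() - start) * 1e3
        self._count += 1
        return result

    def get_stats(self) -> dict[str, float]:
        avg = self._total_ms / self._count if self._count else 0.0
        return {
            "conversion_count": self._count,
            "average_conversion_time_ms": avg,
        }


to_list = MonitoredConverter(list)  # stand-in converter for the demo
data = (1.0, 2.0, 3.0)
for _ in range(3):
    to_list(data)
print(to_list.get_stats()["conversion_count"])  # 3
```

Because every call is counted, a conversion that accidentally runs once per loop iteration shows up immediately as an inflated `conversion_count`.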

Typical conversion overhead:

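========================  ====================  ==================
Array Size                CPU (NumPy -> NumPy)  GPU (JAX -> NumPy)
========================  ====================  ==================
100x100 (10K elements)    < 0.01 ms             ~0.1 ms
1024x1024 (1M elements)   < 0.01 ms             ~0.5 ms
4096x4096 (16M elements)  < 0.01 ms             ~5 ms
========================  ====================  ==================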

When the backend is NumPy, ensure_numpy() is essentially free: it returns the input array unchanged if it is already a writable NumPy array. When the backend is JAX, the data must be copied out of the JAX array, whether it resides on the GPU or is CPU-backed, into a standard NumPy array.
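The described behaviour can be approximated in a few lines. This is a sketch under assumptions: the name `ensure_numpy_sketch` is hypothetical, and xpcsviewer's actual `ensure_numpy()` may handle more cases. It relies only on the fact that `np.array()` accepts JAX arrays (copying them into host memory) as well as other array-likes.

```python
import numpy as np


def ensure_numpy_sketch(x):
    """Return a writable NumPy array, copying only when necessary."""
    if isinstance(x, np.ndarray) and x.flags.writeable:
        return x  # NumPy backend: no copy, effectively free
    # JAX arrays (GPU or CPU-backed), read-only arrays, lists, etc.:
    # np.array copies into a new, writable host array.
    return np.array(x)


a = np.zeros((100, 100))
print(ensure_numpy_sketch(a) is a)  # True: returned unchanged

ro = np.zeros(3)
ro.flags.writeable = False
out = ensure_numpy_sketch(ro)
print(out is ro, out.flags.writeable)  # False True
```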

Minimize unnecessary conversions:

  • Convert at the I/O boundary, not inside computation loops.

  • Avoid converting the same array multiple times; store the NumPy result.

  • For small arrays used only for display, the overhead is negligible.

Fitting Performance

NLSQ warm-start:

The NLSQ solver provides fast point estimates to initialize the Bayesian sampler. Use presets to control the speed/robustness tradeoff:

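======  ============  ================================================
Preset  Typical Time  Use Case
======  ============  ================================================
fast    1–5 ms        Quick preview, known-good data
robust  5–50 ms       Default for production use
global  50–500 ms     Multi-modal or difficult optimization landscapes
large   varies        Large parameter spaces (>10 parameters)
======  ============  ================================================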
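The preset table can be read as a decision rule. The helper below is a hypothetical sketch, not an xpcsviewer function: it only returns one of the preset names listed, and how that name is passed to the NLSQ solver is not shown because that API is not described in this guide.

```python
def choose_preset(
    n_params: int,
    multimodal: bool = False,
    quick_preview: bool = False,
) -> str:
    """Pick an NLSQ preset name following the preset table (illustrative)."""
    if n_params > 10:
        return "large"   # large parameter spaces
    if multimodal:
        return "global"  # multi-modal or difficult landscapes
    if quick_preview:
        return "fast"    # quick preview, known-good data
    return "robust"      # default for production use


print(choose_preset(3))                      # robust
print(choose_preset(3, quick_preview=True))  # fast
print(choose_preset(12))                     # large
```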

NumPyro NUTS tuning:

Adjust SamplerConfig for the speed/accuracy tradeoff:

from xpcsviewer.fitting.results import SamplerConfig

# Fast exploration (fewer samples, fewer chains)
fast_config = SamplerConfig(
    num_warmup=200,
    num_samples=500,
    num_chains=2,
)

# Production quality
production_config = SamplerConfig(
    num_warmup=500,
    num_samples=1000,
    num_chains=4,
    target_accept_prob=0.8,
    random_seed=42,  # Reproducibility
)
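A quick way to compare the two configurations is to count sampler iterations: each chain runs num_warmup + num_samples steps, and the retained posterior draws number num_samples x num_chains. The sketch below uses a hypothetical stand-in dataclass, `SamplerBudget`, that only mirrors the size fields, so it runs without xpcsviewer installed.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SamplerBudget:
    """Hypothetical stand-in mirroring SamplerConfig's size fields."""
    num_warmup: int
    num_samples: int
    num_chains: int

    @property
    def posterior_draws(self) -> int:
        return self.num_samples * self.num_chains

    @property
    def total_iterations(self) -> int:
        # Every chain runs its warm-up iterations plus its sampling iterations.
        return (self.num_warmup + self.num_samples) * self.num_chains


fast = SamplerBudget(num_warmup=200, num_samples=500, num_chains=2)
production = SamplerBudget(num_warmup=500, num_samples=1000, num_chains=4)
print(fast.posterior_draws, fast.total_iterations)              # 1000 1400
print(production.posterior_draws, production.total_iterations)  # 4000 6000
```

By this measure the production configuration runs about 4.3 times as many iterations as the fast one. Actual NUTS wall time also depends on how many leapfrog steps each iteration takes, so treat the ratio as a rough guide.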

Batch fitting tips:

  1. Fit multiple Q-bins in parallel using the threading manager.

  2. Check FitDiagnostics.converged to skip diagnostic plots for well-converged fits.

  3. Use nlsq_optimize() alone (without NUTS) for exploratory analysis where full Bayesian inference is not needed.
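Tips 1 and 2 can be sketched with the standard library's `concurrent.futures.ThreadPoolExecutor`, swapped in here for xpcsviewer's threading manager, whose API is not shown in this guide. `fit_q_bin` and `QBinFit` are hypothetical placeholders. Threads give real parallelism only where the fit spends its time in code that releases the GIL.

```python
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass


@dataclass
class QBinFit:
    """Hypothetical result type; 'converged' plays the role of FitDiagnostics.converged."""
    q_index: int
    tau: float
    converged: bool


def fit_q_bin(q_index: int) -> QBinFit:
    # Placeholder for a real per-Q-bin fit (e.g. an NLSQ call).
    tau = 1.0 / (1 + q_index)
    return QBinFit(q_index=q_index, tau=tau, converged=q_index != 3)


def fit_all(q_indices: list[int], max_workers: int = 4) -> list[QBinFit]:
    # map() preserves input order, so results line up with q_indices.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(fit_q_bin, q_indices))


results = fit_all(list(range(8)))
# Only fits that did not converge need diagnostic plots.
needs_review = [r.q_index for r in results if not r.converged]
print(needs_review)  # [3]
```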

Profiling Workflow

  1. Identify the bottleneck using @log_timing output:

    export PYXPCS_LOG_LEVEL=DEBUG
    python -m xpcsviewer
    

    Look for methods that exceed their threshold_ms (logged at WARNING).

  2. Check I/O adapter stats for excessive conversions:

    stats = pyqt_adapter.get_stats()
    if stats['average_conversion_time_ms'] > 1.0:
        print("Conversion overhead detected")
    
  3. Check HDF5 pool stats for cache misses:

    pool_stats = facade.get_pool_stats()
    if pool_stats['cache_hit_ratio'] < 0.9:
        print("Poor cache hit ratio -- check file access patterns")
    
  4. Check JIT compilation by comparing first-call vs. subsequent-call times in DEBUG logs. First calls that are roughly 10-100x slower than subsequent calls indicate JIT overhead that may benefit from warm-up.
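Steps 2 to 4 can be collected into one check over the statistics described in this guide. `triage` is a hypothetical helper: it assumes only the dictionary keys shown earlier (`average_conversion_time_ms`, `cache_hit_ratio`), reuses the thresholds from the snippets in steps 2 and 3, and flags JIT overhead when the first call is at least 10x slower than steady state.

```python
def triage(
    adapter_stats: dict[str, float],
    pool_stats: dict[str, float],
    first_call_ms: float,
    steady_call_ms: float,
) -> list[str]:
    """Return human-readable findings for the profiling checks (illustrative)."""
    findings = []
    if adapter_stats.get("average_conversion_time_ms", 0.0) > 1.0:
        findings.append("conversion overhead: convert once at the I/O boundary")
    if pool_stats.get("cache_hit_ratio", 1.0) < 0.9:
        findings.append("poor HDF5 cache hit ratio: check file access patterns")
    if steady_call_ms > 0 and first_call_ms / steady_call_ms >= 10:
        findings.append("JIT compilation overhead: consider warming up at startup")
    return findings


report = triage(
    {"conversion_count": 120, "average_conversion_time_ms": 0.4},
    {"cache_hit_ratio": 0.968},
    first_call_ms=200.0,
    steady_call_ms=5.0,
)
print(report)  # ['JIT compilation overhead: consider warming up at startup']
```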