# Source code for xpcsviewer.backends

"""Backend abstraction layer for JAX/NumPy array operations.

This module provides a unified interface for array computations that can
run on either JAX (with GPU support) or NumPy (CPU fallback).

Public API:
    get_backend() -> BackendProtocol
    set_backend(name: str) -> None
    reset_backend() -> None
    BackendProtocol
    DeviceManager
    DeviceConfig
    DeviceType
    ensure_numpy(array) -> np.ndarray
    ensure_backend_array(array) -> BackendArray
    HDF5Adapter, MatplotlibAdapter, PyQtGraphAdapter, create_adapters

Environment Variables:
    XPCS_USE_JAX: 'true', 'false', or 'auto' (default: 'auto')
    XPCS_USE_GPU: 'true' or 'false' (default: 'false')
    XPCS_GPU_FALLBACK: 'true' or 'false' (default: 'true')
    XPCS_GPU_MEMORY_FRACTION: float 0.0-1.0 (default: 0.9)
"""

from __future__ import annotations

import os

# Ensure float64 is enabled before any JAX import elsewhere in the codebase.
# This must happen at module scope (import time) so that lazy JIT decoration
# in other modules picks up the correct precision setting.
# setdefault() leaves an explicit user-provided JAX_ENABLE_X64 untouched.
os.environ.setdefault("JAX_ENABLE_X64", "true")

from typing import TYPE_CHECKING

# Imported only for static type checkers; the runtime binding of
# BackendProtocol comes from the re-export near the end of this module.
if TYPE_CHECKING:
    from ._base import BackendProtocol

# Module-level state
# Active backend returned by get_backend(); None until set_backend() installs
# one, and cleared again by reset_backend() to force re-detection.
_current_backend: BackendProtocol | None = None
# Set to True by _configure_jax() once JAX was imported and configured, so the
# configuration runs only once; stays False while JAX cannot be imported.
_jax_configured: bool = False


def _configure_jax() -> None:
    """Configure JAX settings (float64, memory) on first use.

    Enables 64-bit floats and, when ``XPCS_GPU_MEMORY_FRACTION`` holds a
    value in ``(0.0, 1.0]``, forwards it to XLA through
    ``XLA_PYTHON_CLIENT_MEM_FRACTION``. A value already present in the
    environment is never overridden.

    The function is a no-op after the first successful configuration. It
    quietly returns, without marking itself configured, when JAX is not
    installed.

    An unparsable ``XPCS_GPU_MEMORY_FRACTION`` triggers a ``RuntimeWarning``
    and the documented default (0.9) is used instead of raising.
    """
    global _jax_configured
    if _jax_configured:
        return

    # Enable float64 for scientific computing precision
    os.environ.setdefault("JAX_ENABLE_X64", "true")

    # Resolve the GPU memory fraction *before* importing JAX so the XLA
    # variable is in place as early as possible (XLA reads it when its
    # client is created).
    default_fraction = 0.9
    raw_fraction = os.environ.get("XPCS_GPU_MEMORY_FRACTION", str(default_fraction))
    try:
        memory_fraction = float(raw_fraction)
    except ValueError:
        # A typo in the env var should not abort backend selection.
        import warnings

        warnings.warn(
            f"Ignoring invalid XPCS_GPU_MEMORY_FRACTION={raw_fraction!r}; "
            f"using default {default_fraction}",
            RuntimeWarning,
            stacklevel=2,
        )
        memory_fraction = default_fraction

    # 1.0 lies inside the documented 0.0-1.0 range. Zero, negatives and
    # values above 1 are ignored; NaN fails both comparisons and is ignored.
    if 0.0 < memory_fraction <= 1.0:
        os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", str(memory_fraction))

    try:
        import jax
    except ImportError:
        return  # JAX not available; leave _jax_configured False

    jax.config.update("jax_enable_x64", True)
    _jax_configured = True


def _detect_backend() -> str:
    """Detect which backend to use based on environment and availability.

    ``XPCS_USE_JAX`` is read case-insensitively with surrounding whitespace
    ignored:

    * ``'false'`` / ``'0'`` / ``'no'`` -- always use NumPy.
    * ``'true'`` / ``'1'`` / ``'yes'`` -- require JAX.
    * anything else (default ``'auto'``) -- use JAX when importable,
      otherwise NumPy.

    Returns
    -------
    str
        ``'jax'`` or ``'numpy'``.

    Raises
    ------
    ImportError
        If JAX was explicitly requested but cannot be imported.
    """
    use_jax = os.environ.get("XPCS_USE_JAX", "auto").strip().lower()

    # Same spellings as _parse_bool_env, so "0"/"no" disable JAX instead of
    # silently falling through to auto-detection.
    if use_jax in ("false", "0", "no"):
        return "numpy"

    jax_required = use_jax in ("true", "1", "yes")

    try:
        _configure_jax()
        import jax  # noqa: F401  (availability probe only)
    except ImportError as e:
        if jax_required:
            raise ImportError(
                "JAX requested but not installed. "
                "Install with: pip install 'xpcsviewer-gui[jax]'"
            ) from e
        return "numpy"
    return "jax"


def get_backend() -> BackendProtocol:
    """Return the active computation backend, creating it on first use.

    If no backend has been selected yet, one is chosen by environment and
    availability detection and installed through set_backend(). JAX is
    preferred when it is available; NumPy is the fallback.

    Returns
    -------
    BackendProtocol
        The active backend instance.
    """
    # Fast path: a backend has already been installed.
    if _current_backend is not None:
        return _current_backend

    # Lazily pick and install a backend; set_backend() updates the module
    # global that is read back below.
    set_backend(_detect_backend())
    return _current_backend  # type: ignore[return-value]
def set_backend(name: str) -> None:
    """Install the computation backend identified by *name*.

    Parameters
    ----------
    name : str
        Backend name, case-insensitive: 'jax' or 'numpy'.

    Raises
    ------
    ValueError
        If the backend name is not recognized.
    ImportError
        If the JAX backend is requested but not available.
    """
    global _current_backend
    requested = name.lower()

    # Reject unknown names before touching any backend module.
    if requested not in ("jax", "numpy"):
        raise ValueError(
            f"Unknown backend: {requested}. Use 'jax' or 'numpy'."
        )

    if requested == "numpy":
        from ._numpy_backend import NumPyBackend

        _current_backend = NumPyBackend()  # type: ignore[assignment]
        return

    # JAX: apply global JAX settings before the backend module is imported.
    _configure_jax()
    from ._jax_backend import JAXBackend

    _current_backend = JAXBackend()  # type: ignore[assignment]
def reset_backend() -> None:
    """Forget the active backend so the next get_backend() re-detects it.

    Legacy fitting closures that captured the previous backend are
    invalidated as well (JAX-N-07), provided the fitting module can be
    imported.
    """
    global _current_backend
    _current_backend = None

    # JAX-N-07: legacy closures keep a reference to the old backend.
    # The import is deferred to call time to avoid a circular import while
    # this package loads. Any ImportError is deliberately ignored, because
    # the fitting module may not be installed.
    try:
        from xpcsviewer.fitting.legacy import (
            reset_legacy_closures as _invalidate_legacy_closures,
        )

        _invalidate_legacy_closures()
    except ImportError:
        return
# Alias for testing
_reset_backend = reset_backend


def _parse_bool_env(name: str, default: bool = False) -> bool:
    """Parse a boolean environment variable.

    Accepts (case-insensitive, surrounding whitespace ignored):
        'true', '1', 'yes' for True
        'false', '0', 'no' for False

    Parameters
    ----------
    name : str
        Environment variable name
    default : bool
        Default value if not set or invalid

    Returns
    -------
    bool
        Parsed boolean value
    """
    # Strip so values like "true\n" (e.g. from a sourced env file) or " yes "
    # are recognized instead of silently falling back to the default.
    value = os.environ.get(name, "").strip().lower()
    if value in ("true", "1", "yes"):
        return True
    if value in ("false", "0", "no"):
        return False
    return default


# Convenience re-exports
# NOTE(review): placed after the function definitions, presumably to avoid
# circular imports with the submodules -- confirm before moving to the top.
from ._base import BackendProtocol  # noqa: E402
from ._conversions import ensure_backend_array, ensure_numpy  # noqa: E402
from ._device import DeviceConfig, DeviceManager, DeviceType  # noqa: E402
from .io_adapter import (  # noqa: E402
    HDF5Adapter,
    MatplotlibAdapter,
    PyQtGraphAdapter,
    create_adapters,
)

__all__ = [
    "BackendProtocol",
    "DeviceConfig",
    "DeviceManager",
    "DeviceType",
    "HDF5Adapter",
    "MatplotlibAdapter",
    "PyQtGraphAdapter",
    "create_adapters",
    "ensure_backend_array",
    "ensure_numpy",
    "get_backend",
    "reset_backend",
    "set_backend",
]