# Source code for xpcsviewer.fitting.sampler

"""NumPyro NUTS sampler with NLSQ warm-start.

This module provides the MCMC sampling functionality using NumPyro's
NUTS sampler with JAX-accelerated NLSQ warm-start.
"""

from __future__ import annotations

import copy
import logging
import time
from typing import TYPE_CHECKING, Literal

import numpy as np

from xpcsviewer.utils.log_utils import log_timing

from .models import (
    double_exp_func,
    double_exp_model,
    power_law_func,
    power_law_model,
    single_exp_func,
    single_exp_model,
    stretched_exp_func,
    stretched_exp_model,
)
from .nlsq import nlsq_optimize
from .results import FitDiagnostics, FitResult, SamplerConfig

if TYPE_CHECKING:
    from numpy.typing import ArrayLike

logger = logging.getLogger(__name__)

# Check availability of optional dependencies independently (T-10).
# Splitting the monolithic try/except allows partial functionality when only
# some packages are missing (e.g., arviz absent but jax+numpyro present).
JAX_AVAILABLE = False
NUMPYRO_AVAILABLE = False
ARVIZ_AVAILABLE = False
try:
    import jax
    import jax.numpy as jnp

    JAX_AVAILABLE = True
except ImportError:
    pass
try:
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS
    from numpyro.infer.initialization import init_to_value

    NUMPYRO_AVAILABLE = True
except ImportError:
    pass
try:
    import arviz as az

    ARVIZ_AVAILABLE = True
except ImportError:
    pass


def check_numpyro() -> None:
    """Ensure the optional Bayesian-fitting stack is importable.

    Raises
    ------
    ImportError
        If JAX, NumPyro, or ArviZ failed to import when this module was
        loaded. The message names every missing package and gives a
        ``pip install`` command for them.
    """
    availability = (
        ("jax", JAX_AVAILABLE),
        ("numpyro", NUMPYRO_AVAILABLE),
        ("arviz", ARVIZ_AVAILABLE),
    )
    absent = [package for package, importable in availability if not importable]
    if not absent:
        return
    raise ImportError(
        f"Bayesian fitting requires JAX, NumPyro, and ArviZ. "
        f"Missing: {', '.join(absent)}. "
        f"Install with: pip install {' '.join(absent)}"
    )
def _extract_config(kwargs: dict) -> SamplerConfig: """Extract SamplerConfig from kwargs or use defaults.""" if "sampler_config" in kwargs: return kwargs["sampler_config"] return SamplerConfig( num_warmup=kwargs.get("num_warmup", 500), num_samples=kwargs.get("num_samples", 1000), num_chains=kwargs.get("num_chains", 4), target_accept_prob=kwargs.get("target_accept_prob", 0.8), max_tree_depth=kwargs.get("max_tree_depth", 10), random_seed=kwargs.get("random_seed"), ) def _run_mcmc( model, model_args: tuple, config: SamplerConfig, init_params: dict[str, float] | None = None, ) -> tuple[MCMC, dict]: """Run MCMC sampling with optional warm-start initialization. Parameters ---------- model : callable NumPyro model function model_args : tuple Positional arguments for the model config : SamplerConfig Sampler configuration init_params : dict, optional Initial parameter values in **constrained** (natural) space, e.g. from NLSQ warm-start. Passed to NumPyro's ``init_to_value`` which handles the constrained→unconstrained transform internally and falls back to ``init_to_uniform`` for any sample sites not in ``init_params``. """ check_numpyro() # Set random seed if config.random_seed is not None: seed = config.random_seed rng_key = jax.random.PRNGKey(seed) logger.info(f"MCMC PRNG seed (user-specified): {seed}") else: # Use time-based seed for non-deterministic runs (BUG-020). # BUG-E fix: Store actual seed in config for reproducibility. # Previously, the time-based seed was only logged but not stored, # so FitResult.to_dict() exported random_seed=null. seed = int(time.time_ns() % 2**31) rng_key = jax.random.PRNGKey(seed) logger.info(f"MCMC PRNG seed (time-based): {seed}") config.random_seed = seed # BUG-A fix: Use init_to_value() for warm-start initialization. # This accepts constrained (natural) values from NLSQ and handles # the transform to unconstrained space internally. It also falls # back to init_to_uniform for any sample sites not in init_params # (e.g. 
'sigma' when yerr is None). # # Previously, constrained NLSQ values were passed directly as # init_params to mcmc.run(), but NUTS expects unconstrained values. # This caused NUTS to start at exp(value) instead of value for # LogNormal priors, making the warm-start ineffective. if init_params is not None: init_strategy = init_to_value(values=init_params) else: init_strategy = None # Use NUTS default (init_to_uniform) # Configure NUTS sampler kernel_kwargs = { "target_accept_prob": config.target_accept_prob, "max_tree_depth": config.max_tree_depth, } if init_strategy is not None: kernel_kwargs["init_strategy"] = init_strategy kernel = NUTS(model, **kernel_kwargs) # Create MCMC instance — use vectorized (vmap) on single-device systems # to run chains in parallel, or pmap when multiple devices are available. chain_method = ( "vectorized" if jax.local_device_count() < config.num_chains else "parallel" ) logger.info( "MCMC config: warmup=%d, samples=%d, chains=%d, chain_method=%s", config.num_warmup, config.num_samples, config.num_chains, chain_method, ) mcmc = MCMC( kernel, num_warmup=config.num_warmup, num_samples=config.num_samples, num_chains=config.num_chains, chain_method=chain_method, ) # BUG-D fix: Request num_steps to compute max_treedepth_reached. # Previously hardcoded to 0 with the comment "NumPyro doesn't track # this directly", but num_steps IS available via extra_fields. mcmc.run(rng_key, *model_args, extra_fields=("num_steps", "energy")) samples = mcmc.get_samples() first_key = next(iter(samples)) n_draws = len(samples[first_key]) logger.info( "MCMC completed: %d total draws (%d chains x %d samples), chain_method=%s", n_draws, config.num_chains, config.num_samples, chain_method, ) return mcmc, samples
def compute_bfmi(arviz_data) -> float | None:
    """Return the mean Bayesian fraction of missing information (BFMI).

    Parameters
    ----------
    arviz_data : DataTree
        ArviZ data produced from the MCMC run.

    Returns
    -------
    float | None
        The average of the per-chain values from ``az.bfmi``, or None if
        that computation raises (the failure is logged as a warning).

    Notes
    -----
    A mean below 0.2 triggers a warning, per the Technical Guidelines.
    """
    try:
        mean_bfmi = float(np.mean(az.bfmi(arviz_data)))
        if mean_bfmi < 0.2:
            logger.warning(
                f"Low BFMI ({mean_bfmi:.3f}) indicates poor posterior exploration. "
                f"Consider reparameterization or increasing warmup."
            )
        return mean_bfmi
    except Exception as exc:
        # Best-effort diagnostic: a failure here must not abort the fit.
        logger.warning(f"Failed to compute BFMI: {exc}")
        return None
def _ess_to_int(value, name: str, column: str) -> int:
    """Convert one ArviZ ESS estimate to int, mapping non-finite values to 0.

    ArviZ reports NaN when it cannot estimate ESS. ``int(nan)`` raises
    ValueError, which would discard the entire fit, so an undefined ESS is
    reported as 0 effective samples and logged instead.
    """
    value = float(value)
    if np.isfinite(value):
        return int(value)
    logger.warning("%s for %s is undefined (%s); reporting 0", column, name, value)
    return 0


def _build_fit_result(
    mcmc: MCMC,
    samples: dict,
    nlsq_init: dict[str, float],
    param_names: list[str],
    config: SamplerConfig | None = None,
    x: np.ndarray | None = None,
) -> FitResult:
    """Assemble a FitResult (samples, summary, diagnostics) from an MCMC run.

    Parameters
    ----------
    mcmc : MCMC
        Finished NumPyro MCMC object.
    samples : dict
        Output of ``mcmc.get_samples()``.
    nlsq_init : dict
        NLSQ warm-start estimate, stored on the result as-is.
    param_names : list[str]
        Parameters to report, in the model function's signature order.
    config : SamplerConfig, optional
        Needed to count draws that saturated ``max_tree_depth``. Without it
        the count is 0.
    x : ndarray, optional
        Independent variable, stored on the result.
    """
    # Keep param_names order so dict iteration matches the model signature.
    samples_np = {k: np.asarray(samples[k]) for k in param_names if k in samples}

    arviz_data = az.from_numpyro(mcmc)
    summary = az.summary(arviz_data, var_names=param_names, round_to="none")

    r_hat: dict[str, float] = {}
    ess_bulk: dict[str, int] = {}
    ess_tail: dict[str, int] = {}
    for name in param_names:
        if name not in summary.index:
            continue
        r_hat[name] = float(summary.loc[name, "r_hat"])
        ess_bulk[name] = _ess_to_int(summary.loc[name, "ess_bulk"], name, "ess_bulk")
        ess_tail[name] = _ess_to_int(summary.loc[name, "ess_tail"], name, "ess_tail")

    extra = mcmc.get_extra_fields()
    num_divergent = int(np.sum(extra["diverging"]))

    # BUG-D fix: a draw whose num_steps reaches 2**max_tree_depth - 1 ran out
    # of tree-depth budget (num_steps is requested by _run_mcmc).
    max_treedepth_reached = 0
    if "num_steps" in extra and config is not None:
        num_steps = np.asarray(extra["num_steps"])
        max_steps = 2**config.max_tree_depth - 1
        max_treedepth_reached = int(np.sum(num_steps >= max_steps))
        if max_treedepth_reached > 0:
            total_samples = len(num_steps)
            pct = 100.0 * max_treedepth_reached / max(total_samples, 1)
            msg = (
                f"{max_treedepth_reached}/{total_samples} samples "
                f"({pct:.1f}%) hit max_tree_depth={config.max_tree_depth}"
            )
            if pct > 1.0:
                logger.warning("%s. Consider increasing max_tree_depth.", msg)
            else:
                logger.debug(msg)

    diagnostics = FitDiagnostics(
        r_hat=r_hat,
        ess_bulk=ess_bulk,
        ess_tail=ess_tail,
        divergences=num_divergent,
        max_treedepth_reached=max_treedepth_reached,
        bfmi=compute_bfmi(arviz_data),  # per Technical Guidelines
    )
    return FitResult(
        samples=samples_np,
        param_names=param_names,
        summary=summary,
        diagnostics=diagnostics,
        nlsq_init=nlsq_init,
        arviz_data=arviz_data,
        config=config,
        x=x,
    )
@log_timing(threshold_ms=60_000)
def run_single_exp_fit(
    x: ArrayLike,
    y: ArrayLike,
    yerr: ArrayLike | None = None,
    stability: Literal["auto", "check", False] = "auto",
    auto_bounds: bool = False,
    **kwargs,
) -> FitResult:
    """Fit a single-exponential G2 model with NUTS, warm-started from NLSQ.

    Parameters
    ----------
    x : array_like
        Delay times.
    y : array_like
        G2 correlation values.
    yerr : array_like, optional
        Measurement uncertainties, forwarded to NLSQ and to the model.
    stability : {'auto', 'check', False}, optional
        NLSQ stability mode (default: 'auto').
    auto_bounds : bool, optional
        Let NLSQ infer bounds automatically (default: False).
    **kwargs
        Sampler configuration: ``sampler_config=SamplerConfig(...)`` or its
        individual fields.

    Returns
    -------
    FitResult
        Posterior samples for tau, baseline, contrast.

    Notes
    -----
    When NLSQ reports an unhealthy fit, MCMC starts from the default
    initialization and ``num_warmup`` is raised by 50%.
    """
    check_numpyro()

    x = np.asarray(x)
    y = np.asarray(y)
    yerr = None if yerr is None else np.asarray(yerr)
    config = _extract_config(kwargs)

    logger.info("Running NLSQ warm-start for single exponential fit")
    initial_guess = {"tau": 1.0, "baseline": 1.0, "contrast": 0.3}
    limits = {
        "tau": (1e-6, 1e6),
        "baseline": (0.0, 2.0),
        "contrast": (0.0, 1.0),
    }
    nlsq_result = nlsq_optimize(
        single_exp_func,
        x,
        y,
        yerr,
        initial_guess,
        limits,
        stability=stability,
        auto_bounds=auto_bounds,
        compute_diagnostics=True,  # health diagnostics decide warm-start use
    )
    nlsq_init = nlsq_result.params

    warm_start = None
    if nlsq_result.is_healthy:
        warm_start = nlsq_init
    else:
        logger.warning(
            f"NLSQ warm-start unreliable (health_score={nlsq_result.health_score}); "
            "falling back to init_to_uniform with boosted warmup"
        )
        config.num_warmup = int(config.num_warmup * 1.5)

    model_args = (
        jnp.asarray(x),
        jnp.asarray(y),
        None if yerr is None else jnp.asarray(yerr),
    )

    logger.info("Running NUTS sampling")
    mcmc, samples = _run_mcmc(
        single_exp_model, model_args, config, init_params=warm_start
    )
    return _build_fit_result(
        mcmc,
        samples,
        nlsq_init,
        ["tau", "baseline", "contrast"],
        config=config,
        x=x,
    )
def _double_exp_init_params(nlsq_init: dict[str, float]) -> dict[str, float]:
    """Map NLSQ double-exp params onto the MCMC (tau1, tau2_factor) parameterization.

    double_exp_model enforces tau2 = tau1 * (1 + tau2_factor) with
    tau2_factor > 0, so tau1 must be the faster decay (BUG-022). NLSQ may
    return tau1 > tau2. Each contrast belongs to its own decay time, so the
    (tau, contrast) pairs are swapped together; sorting only the taus would
    attach contrast1 to the wrong component.
    """
    fast = (nlsq_init["tau1"], nlsq_init["contrast1"])
    slow = (nlsq_init["tau2"], nlsq_init["contrast2"])
    if fast[0] > slow[0]:
        fast, slow = slow, fast
    tau_fast, contrast_fast = fast
    tau_slow, contrast_slow = slow
    # Clamp tau2_factor to avoid extreme values from the NLSQ warm-start.
    tau2_factor = max(0.01, min(tau_slow / tau_fast - 1, 1000.0))
    return {
        "tau1": tau_fast,
        "tau2_factor": tau2_factor,
        "baseline": nlsq_init["baseline"],
        "contrast1": contrast_fast,
        "contrast2": contrast_slow,
    }


@log_timing(threshold_ms=60_000)
def run_double_exp_fit(
    x: ArrayLike,
    y: ArrayLike,
    yerr: ArrayLike | None = None,
    stability: Literal["auto", "check", False] = "auto",
    auto_bounds: bool = False,
    **kwargs,
) -> FitResult:
    """Run double exponential fit with NLSQ warm-start.

    Parameters
    ----------
    x : array_like
        Delay times
    y : array_like
        G2 correlation values
    yerr : array_like, optional
        Measurement uncertainties
    stability : str, optional
        NLSQ stability mode: 'auto', 'check', or False (default: 'auto')
    auto_bounds : bool, optional
        Use NLSQ auto-bounds inference (default: False)
    **kwargs
        Sampler configuration

    Returns
    -------
    FitResult
        Posterior samples for tau1, tau2, baseline, contrast1, contrast2.
        ``nlsq_init`` holds the raw (unsorted) NLSQ estimate.
    """
    check_numpyro()

    x = np.asarray(x)
    y = np.asarray(y)
    if yerr is not None:
        yerr = np.asarray(yerr)

    config = _extract_config(kwargs)
    param_names = ["tau1", "tau2", "baseline", "contrast1", "contrast2"]

    logger.info("Running NLSQ warm-start for double exponential fit")
    p0 = {
        "tau1": 0.1,
        "tau2": 10.0,
        "baseline": 1.0,
        "contrast1": 0.15,
        "contrast2": 0.15,
    }
    bounds = {
        "tau1": (1e-6, 1e6),
        "tau2": (1e-6, 1e6),
        "baseline": (0.0, 2.0),
        "contrast1": (0.0, 1.0),
        "contrast2": (0.0, 1.0),
    }
    nlsq_result = nlsq_optimize(
        double_exp_func,
        x,
        y,
        yerr,
        p0,
        bounds,
        stability=stability,
        auto_bounds=auto_bounds,
        compute_diagnostics=True,
    )
    nlsq_init = nlsq_result.params

    use_warm_start = nlsq_result.is_healthy
    if not use_warm_start:
        health_score = nlsq_result.health_score
        logger.warning(
            f"NLSQ warm-start unreliable (health_score={health_score}); "
            "falling back to init_to_uniform with boosted warmup"
        )
        config.num_warmup = int(config.num_warmup * 1.5)

    x_jax = jnp.asarray(x)
    y_jax = jnp.asarray(y)
    yerr_jax = jnp.asarray(yerr) if yerr is not None else None

    # Warm-start from NLSQ when healthy, otherwise uniform initialization.
    logger.info("Running NUTS sampling")
    init_params = _double_exp_init_params(nlsq_init) if use_warm_start else None
    mcmc, samples = _run_mcmc(
        double_exp_model,
        (x_jax, y_jax, yerr_jax),
        config,
        init_params=init_params,
    )

    return _build_fit_result(mcmc, samples, nlsq_init, param_names, config=config, x=x)
@log_timing(threshold_ms=60_000)
def run_stretched_exp_fit(
    x: ArrayLike,
    y: ArrayLike,
    yerr: ArrayLike | None = None,
    stability: Literal["auto", "check", False] = "auto",
    auto_bounds: bool = False,
    **kwargs,
) -> FitResult:
    """Run stretched exponential fit with NLSQ warm-start.

    Parameters
    ----------
    x : array_like
        Delay times
    y : array_like
        G2 correlation values
    yerr : array_like, optional
        Measurement uncertainties
    stability : str, optional
        NLSQ stability mode: 'auto', 'check', or False (default: 'auto')
    auto_bounds : bool, optional
        Use NLSQ auto-bounds inference (default: False)
    **kwargs
        Sampler configuration

    Returns
    -------
    FitResult
        Posterior samples for tau, baseline, contrast, beta. ``nlsq_init``
        holds the unmodified NLSQ estimate; the beta clamp applies only to
        the MCMC starting point.
    """
    check_numpyro()

    x = np.asarray(x)
    y = np.asarray(y)
    if yerr is not None:
        yerr = np.asarray(yerr)

    config = _extract_config(kwargs)
    param_names = ["tau", "baseline", "contrast", "beta"]

    logger.info("Running NLSQ warm-start for stretched exponential fit")
    p0 = {"tau": 1.0, "baseline": 1.0, "contrast": 0.3, "beta": 0.8}
    bounds = {
        "tau": (1e-6, 1e6),
        "baseline": (0.0, 2.0),
        "contrast": (0.0, 1.0),
        "beta": (0.01, 0.99),
    }
    nlsq_result = nlsq_optimize(
        stretched_exp_func,
        x,
        y,
        yerr,
        p0,
        bounds,
        stability=stability,
        auto_bounds=auto_bounds,
        compute_diagnostics=True,
    )
    nlsq_init = nlsq_result.params

    use_warm_start = nlsq_result.is_healthy
    if not use_warm_start:
        health_score = nlsq_result.health_score
        logger.warning(
            f"NLSQ warm-start unreliable (health_score={health_score}); "
            "falling back to init_to_uniform with boosted warmup"
        )
        config.num_warmup = int(config.num_warmup * 1.5)

    x_jax = jnp.asarray(x)
    y_jax = jnp.asarray(y)
    yerr_jax = jnp.asarray(yerr) if yerr is not None else None

    init_params = None
    if use_warm_start:
        # JAX-N-06: clamp beta away from 0/1, where the stretched exp model
        # is numerically unstable for NUTS. Clamp a copy: mutating
        # nlsq_init would overwrite the NLSQ result's params and make
        # FitResult.nlsq_init report a value NLSQ never produced.
        init_params = dict(nlsq_init)
        if "beta" in init_params:
            init_params["beta"] = max(0.05, min(0.95, init_params["beta"]))

    # Uniform init if NLSQ failed, warm-start otherwise.
    logger.info("Running NUTS sampling")
    mcmc, samples = _run_mcmc(
        stretched_exp_model,
        (x_jax, y_jax, yerr_jax),
        config,
        init_params=init_params,
    )

    return _build_fit_result(mcmc, samples, nlsq_init, param_names, config=config, x=x)
def _power_law_p0(
    q: np.ndarray,
    tau: np.ndarray,
    bounds: dict[str, tuple[float, float]],
) -> dict[str, float]:
    """Return an NLSQ starting point for ``tau = tau0 * q**(-alpha)``.

    Fits ``log(tau) = log(tau0) - alpha*log(q)`` by linear regression over
    points where both ``q`` and ``tau`` are finite and positive. Falls back
    to ``tau0=1``, ``alpha=2`` when fewer than two distinct usable ``q``
    values remain or the regression is not finite. Both values are clipped
    into ``bounds``.
    """
    tau0_lo, tau0_hi = bounds["tau0"]
    alpha_lo, alpha_hi = bounds["alpha"]

    # Inf/NaN values would make the regression NaN, and a single distinct q
    # makes the line fit degenerate, so both cases use the fallback.
    usable = np.isfinite(q) & np.isfinite(tau) & (q > 0) & (tau > 0)
    if np.unique(q[usable]).size >= 2:
        slope, intercept = np.polyfit(np.log(q[usable]), np.log(tau[usable]), 1)
        if np.isfinite(slope) and np.isfinite(intercept):
            p0_alpha = float(np.clip(-slope, alpha_lo, alpha_hi))
            p0_tau0 = float(np.clip(np.exp(intercept), tau0_lo, tau0_hi))
            logger.info(
                "Power law p0 from log-log regression: tau0=%.3g, alpha=%.3f",
                p0_tau0,
                p0_alpha,
            )
            return {"tau0": p0_tau0, "alpha": p0_alpha}

    return {
        "tau0": float(np.clip(1.0, tau0_lo, tau0_hi)),
        "alpha": float(np.clip(2.0, alpha_lo, alpha_hi)),
    }


@log_timing(threshold_ms=60_000)
def run_power_law_fit(
    q: ArrayLike,
    tau: ArrayLike | FitResult,
    tau_err: ArrayLike | None = None,
    stability: Literal["auto", "check", False] = "auto",
    auto_bounds: bool = False,
    bounds: dict[str, tuple[float, float]] | None = None,
    **kwargs,
) -> FitResult:
    """Run power law fit with NLSQ warm-start.

    Parameters
    ----------
    q : array_like
        Q values
    tau : array_like
        Per-Q relaxation times. A single FitResult is rejected (see Raises).
    tau_err : array_like, optional
        Measurement uncertainties on tau values from G2 fitting
    stability : str, optional
        NLSQ stability mode: 'auto', 'check', or False (default: 'auto')
    auto_bounds : bool, optional
        Use NLSQ auto-bounds inference (default: False)
    bounds : dict, optional
        Override NLSQ warm-start bounds. Keys are parameter names
        (``"tau0"``, ``"alpha"``), values are ``(min, max)`` tuples.
        Unspecified keys keep the defaults ``tau0=(1e-6, 1e6)``,
        ``alpha=(0.0, 10.0)``.
    **kwargs
        Sampler configuration

    Returns
    -------
    FitResult
        Posterior samples for tau0, alpha

    Raises
    ------
    TypeError
        If ``tau`` is a FitResult rather than an array of per-Q values.
    """
    check_numpyro()

    q = np.asarray(q)
    if isinstance(tau, FitResult):
        raise TypeError(
            "run_power_law_fit requires per-Q tau values as an array, "
            "not a single FitResult. Pass an array of tau values from "
            "individual per-Q fits instead."
        )
    tau = np.asarray(tau)
    if tau_err is not None:
        tau_err = np.asarray(tau_err)

    config = _extract_config(kwargs)
    param_names = ["tau0", "alpha"]

    logger.info("Running NLSQ warm-start for power law fit")
    merged_bounds = {"tau0": (1e-6, 1e6), "alpha": (0.0, 10.0)}
    if bounds is not None:
        merged_bounds.update(bounds)
    bounds = merged_bounds

    p0 = _power_law_p0(q, tau, bounds)
    nlsq_result = nlsq_optimize(
        power_law_func,
        q,
        tau,
        tau_err,
        p0,
        bounds,
        stability=stability,
        auto_bounds=auto_bounds,
        compute_diagnostics=True,
    )
    nlsq_init = nlsq_result.params

    # For sparse data (typical of power-law / diffusion fits), NLSQ
    # diagnostics are often unavailable (health_score=0) even when the
    # optimizer converged. Accept the warm-start when NLSQ converged with a
    # finite chi², and fall back only on true failure.
    use_warm_start = nlsq_result.is_healthy
    if (
        not use_warm_start
        and nlsq_result.converged
        and np.isfinite(nlsq_result.chi_squared)
    ):
        logger.info(
            "NLSQ diagnostics unavailable (health_score=%s) but fit converged "
            "(chi2=%.3g); accepting warm-start",
            nlsq_result.health_score,
            nlsq_result.chi_squared,
        )
        use_warm_start = True

    if not use_warm_start:
        health_score = nlsq_result.health_score
        logger.warning(
            f"NLSQ warm-start unreliable (health_score={health_score}, "
            f"converged={nlsq_result.converged}); "
            "falling back to init_to_uniform with boosted warmup"
        )
        config.num_warmup = int(config.num_warmup * 1.5)

    q_jax = jnp.asarray(q)
    tau_jax = jnp.asarray(tau)
    tau_err_jax = jnp.asarray(tau_err) if tau_err is not None else None

    # Uniform init if NLSQ failed, warm-start otherwise.
    logger.info("Running NUTS sampling")
    mcmc, samples = _run_mcmc(
        power_law_model,
        (q_jax, tau_jax, tau_err_jax),
        config,
        init_params=nlsq_init if use_warm_start else None,
    )

    return _build_fit_result(mcmc, samples, nlsq_init, param_names, config=config, x=q)