"""Bayesian fitting module for XPCS correlation analysis.
This module provides Bayesian parameter estimation using NumPyro NUTS
sampler with JAX-accelerated NLSQ warm-start.
NLSQ 0.6.0 Enhanced Features:
- R², adjusted R², RMSE, MAE, AIC, BIC metrics on NLSQResult
- Confidence intervals for parameters
- Prediction intervals (accounting for observation noise)
- Automatic bounds inference (auto_bounds)
- Numerical stability checks (stability)
- Fallback strategies for difficult problems (fallback)
- Model health diagnostics (compute_diagnostics)
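Public API:
fit_single_exp(x, y, yerr=None, **kwargs) -> FitResult
fit_double_exp(x, y, yerr=None, **kwargs) -> FitResult
fit_stretched_exp(x, y, yerr=None, **kwargs) -> FitResult
fit_power_law(q, tau, tau_err=None, **kwargs) -> FitResult
nlsq_fit(model_fn, x, y, yerr, p0, bounds, workflow="auto_global", *,
auto_bounds=False, stability=False, fallback=False,
compute_diagnostics=False, show_progress=False) -> NLSQResult
grad(func, argnums=0) / value_and_grad(func, argnums=0)
minimize_with_grad(objective, initial_params, max_iterations=500,
tolerance=1e-8, learning_rate=0.01)
SamplerConfig
FitResult
NLSQResult
FitDiagnostics
Visualization helpers, Bayesian batch-assembly/export helpers and legacy
fitting functions are also re-exported; see ``__all__`` for the full list.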
Models:
Single exponential: y = baseline + contrast * exp(-2 * x / tau)
Double exponential: y = baseline + c1*exp(-2x/tau1) + c2*exp(-2x/tau2)
Stretched exponential: y = baseline + contrast * exp(-(2x/tau)^beta)
Power law: tau = tau0 * q^(-alpha)
"""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from numpy.typing import ArrayLike
from .results import FitResult, NLSQResult
[docs]
def fit_single_exp(
    x: ArrayLike,
    y: ArrayLike,
    yerr: ArrayLike | None = None,
    **kwargs,
) -> FitResult:
    """Run a Bayesian fit of the single-exponential decay model.

    Model: y = baseline + contrast * exp(-2 * x / tau)

    Parameters
    ----------
    x : array_like
        Delay times.
    y : array_like
        G2 correlation values.
    yerr : array_like, optional
        Measurement uncertainties; forwarded unchanged (``None`` by default).
    **kwargs
        Sampler configuration forwarded verbatim (see SamplerConfig).

    Returns
    -------
    FitResult
        Posterior samples for tau, baseline, contrast.
    """
    # The sampler module is imported at call time rather than at package
    # import time (presumably to keep ``import`` of this package lightweight).
    from .sampler import run_single_exp_fit as _run_fit

    fit_result = _run_fit(x, y, yerr, **kwargs)
    return fit_result
[docs]
def fit_double_exp(
    x: ArrayLike,
    y: ArrayLike,
    yerr: ArrayLike | None = None,
    **kwargs,
) -> FitResult:
    """Run a Bayesian fit of the double-exponential decay model.

    Model: y = baseline + c1*exp(-2x/tau1) + c2*exp(-2x/tau2)

    Parameters
    ----------
    x : array_like
        Delay times.
    y : array_like
        G2 correlation values.
    yerr : array_like, optional
        Measurement uncertainties; forwarded unchanged (``None`` by default).
    **kwargs
        Sampler configuration forwarded verbatim (see SamplerConfig).

    Returns
    -------
    FitResult
        Posterior samples for tau1, tau2, baseline, contrast1, contrast2.
    """
    # Deferred import of the sampler implementation (resolved per call).
    from .sampler import run_double_exp_fit as _run_fit

    fit_result = _run_fit(x, y, yerr, **kwargs)
    return fit_result
[docs]
def fit_stretched_exp(
    x: ArrayLike,
    y: ArrayLike,
    yerr: ArrayLike | None = None,
    **kwargs,
) -> FitResult:
    """Run a Bayesian fit of the stretched-exponential (KWW) model.

    Model: y = baseline + contrast * exp(-(2 * x / tau)^beta)

    Parameters
    ----------
    x : array_like
        Delay times.
    y : array_like
        G2 correlation values.
    yerr : array_like, optional
        Measurement uncertainties; forwarded unchanged (``None`` by default).
    **kwargs
        Sampler configuration forwarded verbatim (see SamplerConfig).

    Returns
    -------
    FitResult
        Posterior samples for tau, baseline, contrast, beta.
    """
    # Deferred import of the sampler implementation (resolved per call).
    from .sampler import run_stretched_exp_fit as _run_fit

    fit_result = _run_fit(x, y, yerr, **kwargs)
    return fit_result
[docs]
def fit_power_law(
    q: ArrayLike,
    tau: ArrayLike | FitResult,
    tau_err: ArrayLike | None = None,
    **kwargs,
) -> FitResult:
    """Run a Bayesian fit of the power-law Q-dependence of relaxation time.

    Model: tau = tau0 * q^(-alpha)

    Parameters
    ----------
    q : array_like
        Q values.
    tau : array_like or FitResult
        Relaxation times, or a FitResult carrying tau samples.
    tau_err : array_like, optional
        Uncertainties on the tau values (from G2 fitting); forwarded to the
        sampler as the ``tau_err`` keyword.
    **kwargs
        Sampler configuration forwarded verbatim (see SamplerConfig).

    Returns
    -------
    FitResult
        Posterior samples for tau0, alpha.
    """
    # Deferred import of the sampler implementation (resolved per call).
    from .sampler import run_power_law_fit as _run_fit

    fit_result = _run_fit(q, tau, tau_err=tau_err, **kwargs)
    return fit_result
[docs]
def nlsq_fit(
    model_fn,
    x: ArrayLike,
    y: ArrayLike,
    yerr: ArrayLike | None,
    p0: dict[str, float],
    bounds: dict[str, tuple[float, float]],
    workflow: str = "auto_global",
    *,
    auto_bounds: bool = False,
    stability: str | bool = False,
    fallback: bool = False,
    compute_diagnostics: bool = False,
    show_progress: bool = False,
) -> NLSQResult:
    """Fit ``model_fn`` by JAX-accelerated nonlinear least squares (NLSQ 0.6.0).

    All arguments are forwarded unchanged to ``.nlsq.nlsq_optimize``.

    Parameters
    ----------
    model_fn : callable
        Model taking x and the parameter values; expected to use JAX ops.
    x : array_like
        Independent variable.
    y : array_like
        Dependent variable.
    yerr : array_like or None
        Measurement uncertainties.
    p0 : dict
        Initial guess, ``{name: value}``.
    bounds : dict
        Parameter bounds, ``{name: (min, max)}``.
    workflow : {'auto', 'auto_global', 'hpc'}, optional
        NLSQ workflow (default ``'auto_global'``):

        - ``'auto'``: fast single start
        - ``'auto_global'``: robust multi-start
        - ``'hpc'``: streaming for large datasets
    auto_bounds : bool, optional
        Enable automatic bounds inference (default False).
    stability : {'auto', 'check', False}, optional
        Numerical stability checks (default False):

        - ``'auto'``: check and apply fixes
        - ``'check'``: check and warn only
        - ``False``: skip checks
    fallback : bool, optional
        Enable fallback strategies for difficult problems (default False).
    compute_diagnostics : bool, optional
        Compute model-health diagnostics (default False).
    show_progress : bool, optional
        Display a progress bar (default False).

    Returns
    -------
    NLSQResult
        Result with R², RMSE, AIC, BIC, confidence intervals, predictions
        and, optionally, model diagnostics.

    Examples
    --------
    Basic usage:

    >>> result = nlsq_fit(model_fn, x, y, yerr, p0, bounds)
    >>> print(f"R² = {result.r_squared:.4f}")
    >>> print(result.summary())

    Robust fitting with diagnostics:

    >>> result = nlsq_fit(
    ...     model_fn, x, y, yerr, p0, bounds,
    ...     workflow='auto_global',
    ...     stability='auto',
    ...     fallback=True,
    ...     compute_diagnostics=True,
    ... )
    """
    from typing import Any

    from .nlsq import nlsq_optimize

    # Typed as ``dict[str, Any]`` so the plain ``str``/``bool`` options are
    # accepted wherever ``nlsq_optimize`` declares narrower types.
    optimizer_options: dict[str, Any] = {
        "workflow": workflow,
        "auto_bounds": auto_bounds,
        "stability": stability,
        "fallback": fallback,
        "compute_diagnostics": compute_diagnostics,
        "show_progress": show_progress,
    }
    return nlsq_optimize(model_fn, x, y, yerr, p0, bounds, **optimizer_options)
# =============================================================================
# Gradient-Based Optimization API (FR-011: Auto-differentiation)
# =============================================================================
[docs]
def value_and_grad(func, argnums=0):
    """Build a function that returns both the value and gradient of *func*.

    Delegates to the active backend's ``value_and_grad`` (from
    ``xpcsviewer.backends.get_backend()``), which must be the JAX backend.

    Parameters
    ----------
    func : callable
        Differentiable function returning a scalar.
    argnums : int or tuple of int, optional
        Positional argument(s) to differentiate with respect to (default 0).

    Returns
    -------
    callable
        Function returning a ``(value, gradient)`` tuple.

    Raises
    ------
    RuntimeError
        If the active backend's name is not ``"jax"``.

    Examples
    --------
    >>> def loss(params, x, y):
    ...     pred = params[0] * x + params[1]
    ...     return jnp.sum((pred - y) ** 2)
    >>> val_grad = value_and_grad(loss)
    >>> value, gradient = val_grad(params, x, y)
    """
    from xpcsviewer.backends import get_backend

    active_backend = get_backend()
    if active_backend.name == "jax":
        return active_backend.value_and_grad(func, argnums=argnums)
    raise RuntimeError(
        "value_and_grad requires JAX backend. Set XPCS_USE_JAX=1 to enable."
    )
[docs]
def grad(func, argnums=0):
    """Build a gradient function for a scalar-valued *func*.

    Delegates to the active backend's ``grad`` (from
    ``xpcsviewer.backends.get_backend()``), which must be the JAX backend.

    Parameters
    ----------
    func : callable
        Differentiable function returning a scalar.
    argnums : int or tuple of int, optional
        Positional argument(s) to differentiate with respect to (default 0).

    Returns
    -------
    callable
        Function that computes gradients.

    Raises
    ------
    RuntimeError
        If the active backend's name is not ``"jax"``.

    Examples
    --------
    >>> def loss(params, x, y):
    ...     pred = params[0] * x + params[1]
    ...     return jnp.sum((pred - y) ** 2)
    >>> gradient_fn = grad(loss)
    >>> gradient = gradient_fn(params, x, y)
    """
    from xpcsviewer.backends import get_backend

    active_backend = get_backend()
    if active_backend.name == "jax":
        return active_backend.grad(func, argnums=argnums)
    raise RuntimeError("grad requires JAX backend. Set XPCS_USE_JAX=1 to enable.")
[docs]
def minimize_with_grad(
    objective,
    initial_params,
    max_iterations: int = 500,
    tolerance: float = 1e-8,
    learning_rate: float = 0.01,
):
    """Minimize a differentiable objective with simple gradient descent.

    Thin wrapper that forwards every argument to
    ``xpcsviewer.simplemask.calibration.minimize_with_grad``. For more
    sophisticated optimization, consider optimistix or scipy.optimize.

    Parameters
    ----------
    objective : callable
        Differentiable objective, ``f(params) -> scalar``.
    initial_params : array_like
        Starting parameter values.
    max_iterations : int, optional
        Iteration cap (default 500).
    tolerance : float, optional
        Convergence tolerance on the change in loss (default 1e-8).
    learning_rate : float, optional
        Step size (default 0.01).

    Returns
    -------
    tuple
        ``(optimal_params, diagnostics_dict)``; per the delegated
        implementation's contract, diagnostics contain ``iterations``,
        ``losses``, ``converged`` and ``final_loss``.

    Raises
    ------
    RuntimeError
        Documented as raised when the JAX backend is unavailable. This
        wrapper performs no backend check itself; the delegated
        implementation is expected to raise it (NOTE(review): confirm).

    Examples
    --------
    >>> def loss(params):
    ...     return jnp.sum((params - target) ** 2)
    >>> optimal, diag = minimize_with_grad(loss, initial_guess)
    >>> print(f"Converged: {diag['converged']}, Loss: {diag['final_loss']}")
    """
    from xpcsviewer.simplemask.calibration import minimize_with_grad as _impl

    solver_settings = {
        "max_iterations": max_iterations,
        "tolerance": tolerance,
        "learning_rate": learning_rate,
    }
    return _impl(objective, initial_params, **solver_settings)
# Bayesian batch assembly
from .bayesian_assembly import assemble_fit_summary
# Re-export legacy fitting functions for xpcs_file.py compatibility
from .legacy import (
double_exp,
double_exp_all,
fit_with_fixed,
fit_with_fixed_parallel,
fit_with_fixed_sequential,
robust_curve_fit,
sequential_fitting,
single_exp,
single_exp_all,
vectorized_parameter_estimation,
vectorized_residual_analysis,
)
# Re-export public classes
from .results import FitDiagnostics, FitResult, NLSQResult, SamplerConfig
# Re-export visualization functions (FR-013 to FR-021, NLSQ 0.6.0)
from .visualization import (
PUBLICATION_STYLE,
apply_publication_style,
compute_prediction_interval,
compute_uncertainty_band,
generate_arviz_diagnostics,
plot_comparison,
plot_nlsq_fit,
plot_posterior_predictive,
save_figure,
validate_pcov,
)
# Bayesian all-Q visualization and export
from .viz import export_bayesian_csv, export_bayesian_diagnostics, plot_bayesian_all_q
# Public API of the package, assembled group by group. ``+=`` extends the
# same list in place, so the final contents and order equal a single literal;
# this ``__all__ +=`` form is one pyright recognizes for library exports.

# Fitting functions
__all__ = [
    "fit_single_exp",
    "fit_double_exp",
    "fit_stretched_exp",
    "fit_power_law",
    "nlsq_fit",
]
# Result classes
__all__ += [
    "SamplerConfig",
    "FitResult",
    "NLSQResult",
    "FitDiagnostics",
]
# Gradient-based optimization API (FR-011)
__all__ += [
    "grad",
    "value_and_grad",
    "minimize_with_grad",
]
# Visualization functions (FR-013 to FR-021, NLSQ 0.6.0)
__all__ += [
    "PUBLICATION_STYLE",
    "apply_publication_style",
    "validate_pcov",
    "compute_uncertainty_band",
    "compute_prediction_interval",
    "plot_nlsq_fit",
    "generate_arviz_diagnostics",
    "plot_posterior_predictive",
    "plot_comparison",
    "save_figure",
]
# Bayesian batch assembly
__all__ += ["assemble_fit_summary"]
# Bayesian all-Q visualization and export
__all__ += [
    "plot_bayesian_all_q",
    "export_bayesian_csv",
    "export_bayesian_diagnostics",
]
# Legacy fitting functions (re-exported for xpcs_file.py compatibility)
__all__ += [
    "single_exp",
    "double_exp",
    "single_exp_all",
    "double_exp_all",
    "fit_with_fixed",
    "fit_with_fixed_parallel",
    "fit_with_fixed_sequential",
    "robust_curve_fit",
    "sequential_fitting",
    "vectorized_parameter_estimation",
    "vectorized_residual_analysis",
]