Source code for xpcsviewer.fitting.legacy

"""Legacy fitting utilities for backward compatibility.

Provides the same API as the old xpcsviewer.helper.fitting module,
migrated to use the new JAX-accelerated backend where possible.

These functions maintain the same interface for xpcs_file.py compatibility
while leveraging the new fitting infrastructure internally.
"""

from __future__ import annotations

from collections.abc import Callable
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from typing import Any, cast

import numpy as np
from nlsq import curve_fit
from numpy.typing import NDArray

from xpcsviewer.backends._conversions import ensure_numpy
from xpcsviewer.utils.log_utils import log_timing
from xpcsviewer.utils.logging_config import get_logger

logger = get_logger(__name__)


# Model factory functions for nlsq.curve_fit.
# Each factory imports jax.numpy directly because NLSQ JIT-traces model
# functions with JAX regardless of the xpcsviewer backend setting.


def make_single_exp() -> Callable[..., NDArray[np.floating[Any]]]:
    """Build and return a fresh single-exponential model closure.

    The returned callable evaluates ``cts * exp(-2 * x / tau) + bkg``.
    ``jax.numpy`` is imported here rather than going through the xpcsviewer
    backend because ``nlsq.curve_fit`` JIT-traces model functions with JAX
    regardless of the xpcsviewer backend setting.
    """
    import jax.numpy as jnp

    def _single_exp(
        x: NDArray[np.floating[Any]], tau: float, bkg: float, cts: float
    ) -> NDArray[np.floating[Any]]:
        # Decay term first, then scale by the contrast and lift by the baseline.
        decay = jnp.exp(-2 * jnp.asarray(x) / tau)
        return cts * decay + bkg  # type: ignore[return-value]

    return _single_exp


def make_double_exp() -> Callable[..., NDArray[np.floating[Any]]]:
    """Build and return a fresh double-exponential model closure.

    The returned callable evaluates
    ``cts1 * exp(-2 * x / tau1) + cts2 * exp(-2 * x / tau2) + bkg``.
    ``jax.numpy`` is imported here rather than going through the xpcsviewer
    backend because ``nlsq.curve_fit`` JIT-traces model functions with JAX
    regardless of the xpcsviewer backend setting.
    """
    import jax.numpy as jnp

    def _double_exp(
        x: NDArray[np.floating[Any]],
        tau1: float,
        bkg: float,
        cts1: float,
        tau2: float,
        cts2: float,
    ) -> Any:
        xa: Any = jnp.asarray(x)
        # One decay term per relaxation time, summed on top of the baseline.
        first_mode = cts1 * jnp.exp(-2 * xa / tau1)
        second_mode = cts2 * jnp.exp(-2 * xa / tau2)
        return first_mode + second_mode + bkg

    return _double_exp


def make_single_exp_all() -> Callable[..., NDArray[np.floating[Any]]]:
    """Build and return a fresh ``single_exp_all`` model closure.

    The returned callable evaluates ``a * exp(-2 * x / b_) + c + d``.
    ``jax.numpy`` is imported here rather than going through the xpcsviewer
    backend because ``nlsq.curve_fit`` JIT-traces model functions with JAX
    regardless of the xpcsviewer backend setting.
    """
    import jax.numpy as jnp

    def _single_exp_all(
        x: NDArray[np.floating[Any]], a: float, b_: float, c: float, d: float
    ) -> NDArray[np.floating[Any]]:
        decay = jnp.exp(-2 * jnp.asarray(x) / b_)
        return a * decay + c + d  # type: ignore[return-value]

    return _single_exp_all


def make_double_exp_all() -> Callable[..., NDArray[np.floating[Any]]]:
    """Build and return a fresh ``double_exp_all`` model closure.

    The returned callable evaluates
    ``a * exp(-2 * x / b_) + c * exp(-2 * x / d) + e + f``.
    ``jax.numpy`` is imported here rather than going through the xpcsviewer
    backend because ``nlsq.curve_fit`` JIT-traces model functions with JAX
    regardless of the xpcsviewer backend setting.
    """
    import jax.numpy as jnp

    def _double_exp_all(
        x: NDArray[np.floating[Any]],
        a: float,
        b_: float,
        c: float,
        d: float,
        e: float,
        f: float,
    ) -> Any:
        xa: Any = jnp.asarray(x)
        # Two decay terms followed by the two additive offsets.
        first_mode = a * jnp.exp(-2 * xa / b_)
        second_mode = c * jnp.exp(-2 * xa / d)
        return first_mode + second_mode + e + f

    return _double_exp_all


# ---------------------------------------------------------------------------
# Backward-compatible module-level functions (kept for callers that import them
# directly by name). Each one delegates to a factory-built closure that is
# created on first use and cached in the module-level slots below.
# ---------------------------------------------------------------------------
# Lazily-filled caches for the model closures: each slot stays ``None`` until
# the matching public model function is first called, and
# reset_legacy_closures() sets all four back to ``None``.
_single_exp_fn: Callable[..., NDArray[np.floating[Any]]] | None = None
_double_exp_fn: Callable[..., NDArray[np.floating[Any]]] | None = None
_single_exp_all_fn: Callable[..., NDArray[np.floating[Any]]] | None = None
_double_exp_all_fn: Callable[..., NDArray[np.floating[Any]]] | None = None


def reset_legacy_closures() -> None:
    """Drop every cached model closure (call after reset_backend()).

    Sets the four module-level cache slots back to ``None`` so that the next
    call to ``single_exp``, ``double_exp``, ``single_exp_all`` or
    ``double_exp_all`` rebuilds its closure through the matching factory.
    """
    global _single_exp_fn, _double_exp_fn, _single_exp_all_fn, _double_exp_all_fn
    _single_exp_fn = _double_exp_fn = None
    _single_exp_all_fn = _double_exp_all_fn = None


def single_exp(
    x: NDArray[np.floating[Any]], tau: float, bkg: float, cts: float
) -> NDArray[np.floating[Any]]:
    """Evaluate the single-exponential G2 model at ``x``.

    The model closure is built by ``make_single_exp()`` on the first call and
    kept in the module-level ``_single_exp_fn`` slot for later calls. The
    closure uses ``jax.numpy`` directly for NLSQ JIT compatibility.
    """
    global _single_exp_fn
    model = _single_exp_fn
    if model is None:
        # First use (or first use after a reset): build and cache the closure.
        model = _single_exp_fn = make_single_exp()
    return model(x, tau, bkg, cts)
def double_exp(
    x: NDArray[np.floating[Any]],
    tau1: float,
    bkg: float,
    cts1: float,
    tau2: float,
    cts2: float,
) -> NDArray[np.floating[Any]]:
    """Evaluate the double-exponential G2 model at ``x``.

    The model closure is built by ``make_double_exp()`` on the first call and
    kept in the module-level ``_double_exp_fn`` slot for later calls. The
    closure uses ``jax.numpy`` directly for NLSQ JIT compatibility.
    """
    global _double_exp_fn
    model = _double_exp_fn
    if model is None:
        # First use (or first use after a reset): build and cache the closure.
        model = _double_exp_fn = make_double_exp()
    return model(x, tau1, bkg, cts1, tau2, cts2)
def single_exp_all(
    x: NDArray[np.floating[Any]], a: float, b_: float, c: float, d: float
) -> NDArray[np.floating[Any]]:
    """Evaluate the single-exponential model with all parameters at ``x``.

    The model closure is built by ``make_single_exp_all()`` on the first call
    and kept in the module-level ``_single_exp_all_fn`` slot for later calls.
    The closure uses ``jax.numpy`` directly for NLSQ JIT compatibility.
    """
    global _single_exp_all_fn
    model = _single_exp_all_fn
    if model is None:
        # First use (or first use after a reset): build and cache the closure.
        model = _single_exp_all_fn = make_single_exp_all()
    return model(x, a, b_, c, d)
def double_exp_all(
    x: NDArray[np.floating[Any]],
    a: float,
    b_: float,
    c: float,
    d: float,
    e: float,
    f: float,
) -> NDArray[np.floating[Any]]:
    """Evaluate the double-exponential model with all parameters at ``x``.

    The model closure is built by ``make_double_exp_all()`` on the first call
    and kept in the module-level ``_double_exp_all_fn`` slot for later calls.
    The closure uses ``jax.numpy`` directly for NLSQ JIT compatibility.
    """
    global _double_exp_all_fn
    model = _double_exp_all_fn
    if model is None:
        # First use (or first use after a reset): build and cache the closure.
        model = _double_exp_all_fn = make_double_exp_all()
    return model(x, a, b_, c, d, e, f)
@log_timing(threshold_ms=100)
def fit_with_fixed(
    base_func: Callable[..., NDArray[np.floating[Any]]],
    x: NDArray[np.floating[Any]],
    y: NDArray[np.floating[Any]],
    sigma: NDArray[np.floating[Any]],
    bounds: NDArray[np.floating[Any]],
    fit_flag: NDArray[np.bool_],
    fit_x: NDArray[np.floating[Any]],
    p0: NDArray[np.floating[Any]] | None = None,
    **kwargs: Any,
) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]:
    """Fit each column of ``y`` with some parameters held fixed (nlsq.curve_fit).

    Parameters
    ----------
    base_func : callable
        Model function called as ``base_func(x, *params)`` with the full
        parameter list.
    x : array
        Input data.
    y : array
        Output data; one fit is run per column (``y[:, n]``).
    sigma : array
        Error bars. Column ``n`` is used when ``sigma`` has more than one
        dimension, otherwise the same array is used for every column.
    bounds : array-like
        ``(lower_bounds, upper_bounds)``, indexed as ``bounds[:, param]``.
    fit_flag : array
        Flags marking which parameters are fitted; the rest are held fixed.
        Coerced to a boolean mask, so 0/1 integer flags are accepted.
    fit_x : array
        X values at which the fitted curves are evaluated.
    p0 : array, optional
        Initial values for all parameters. Fixed parameters keep this value;
        when omitted, the midpoint of ``bounds`` is used.
    **kwargs
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    tuple
        ``(fit_line, fit_val)``. ``fit_line`` has one row per column of ``y``
        evaluated on ``fit_x``. ``fit_val[n, 0, :]`` holds the parameter
        values and ``fit_val[n, 1, :]`` their errors (0 for fixed parameters).
        If a column cannot be fitted, its free parameters fall back to the
        midpoint of ``bounds``, its fixed parameters keep their fixed value
        and all errors are 0.
    """
    # Ensure numpy arrays at nlsq boundary
    x = ensure_numpy(x)
    y = ensure_numpy(y)
    sigma = ensure_numpy(sigma)
    fit_x = ensure_numpy(fit_x)

    # Force a real boolean mask: an integer 0/1 array would otherwise be
    # treated as fancy (positional) indices and select the wrong columns.
    fit_flag = np.asarray(fit_flag, dtype=bool)
    fix_flag = np.logical_not(fit_flag)
    bounds = np.asarray(bounds)
    num_args = len(fit_flag)
    num_cols = y.shape[1]

    # Process boundaries for fitting parameters only
    bounds_fit = bounds[:, fit_flag]

    # Preserve full initial guess before slicing to free params.
    # Fixed parameters use their p0 value (data-driven) or the midpoint
    # of bounds when p0 is absent - NOT the upper bound.
    if p0 is not None:
        full_p0 = np.array(p0)
    else:
        full_p0 = np.mean(bounds, axis=0)

    # Slice to free parameters for optimizer
    p0 = full_p0[fit_flag]

    fit_val = np.zeros((num_cols, 2, num_args))

    # Fixed parameter values: use initial guess (semantically correct)
    fixed_values = full_p0.tolist()
    fit_indices = np.flatnonzero(fit_flag).tolist()

    # Values reported when a column cannot be fitted: bounds midpoint for the
    # free parameters, but fixed parameters stay at their fixed value.
    fallback_params = np.mean(bounds, axis=0)
    fallback_params[fix_flag] = full_p0[fix_flag]

    # Create wrapper function for fixed parameters
    # Uses a Python list (not np.array) to avoid TracerArrayConversionError
    # when nlsq JIT-traces this function with JAX tracers as fit_params.
    def wrapper_func(x_data, *fit_params):
        full_params = list(fixed_values)
        for idx, val in zip(fit_indices, fit_params):
            full_params[idx] = val
        return base_func(x_data, *full_params)

    # Fit each column
    for n in range(num_cols):
        try:
            sigma_col = sigma[:, n] if sigma.ndim > 1 else sigma

            popt, pcov = curve_fit(
                wrapper_func,
                x,
                y[:, n],
                sigma=sigma_col,
                p0=p0,
                bounds=(bounds_fit[0], bounds_fit[1]),
                max_nfev=5000,
            )

            fit_val[n, 0, fit_flag] = popt
            pcov_diag = np.diag(pcov)
            if np.any(pcov_diag < 0):
                logger.warning(
                    f"Column {n}: Negative diagonal elements in covariance matrix"
                )
            # Negative variances become NaN here and are replaced just below;
            # silence numpy's own warning since one was already logged.
            with np.errstate(invalid="ignore"):
                errors = np.sqrt(pcov_diag)

            if np.any(~np.isfinite(errors)):
                errors = np.where(np.isfinite(errors), errors, np.abs(popt) * 0.1)

            fit_val[n, 1, fit_flag] = errors
            fit_val[n, 0, fix_flag] = full_p0[fix_flag]
            fit_val[n, 1, fix_flag] = 0

        except Exception as e:
            logger.warning(f"Fitting failed for column {n}: {e}")
            fit_val[n, 0, :] = fallback_params
            fit_val[n, 1, :] = 0

    # Generate fit lines
    fit_line = np.zeros((num_cols, len(fit_x)))
    for n in range(num_cols):
        fit_line[n] = ensure_numpy(base_func(fit_x, *fit_val[n, 0, :]))

    return fit_line, fit_val
def _fit_single_qvalue(
    args: tuple[Any, ...],
) -> tuple[
    int, NDArray[np.floating[Any]] | None, NDArray[np.floating[Any]] | None, bool
]:
    """Fit one q-value column; worker for the parallel fitting routine.

    ``args`` is the packed tuple
    ``(col_idx, x, y_col, sigma_col, wrapper_func, p0, bounds_fit)``.

    Returns ``(col_idx, popt, errors, True)`` on success, or
    ``(col_idx, None, None, False)`` when the fit raises; the failure is
    logged as a warning instead of propagating.
    """
    col_idx, x, y_col, sigma_col, wrapper_func, p0, bounds_fit = args

    try:
        popt, pcov = curve_fit(
            wrapper_func,
            x,
            y_col,
            sigma=sigma_col,
            p0=p0,
            bounds=(bounds_fit[0], bounds_fit[1]),
            method="trf",
            max_nfev=5000,
        )
        # Standard errors from the covariance diagonal; abs() keeps the
        # square root real when a diagonal entry is negative.
        errors = np.sqrt(np.abs(np.diag(pcov)))
        if np.any(~np.isfinite(errors)):
            # Replace unusable errors with 10% of the parameter magnitude.
            errors = np.where(np.isfinite(errors), errors, np.abs(popt) * 0.1)
    except Exception as e:
        logger.warning(f"Fitting failed for q-value {col_idx}: {e}")
        return col_idx, None, None, False
    else:
        return col_idx, popt, errors, True
@log_timing(threshold_ms=500)
def fit_with_fixed_parallel(
    base_func: Callable[..., NDArray[np.floating[Any]]],
    x: NDArray[np.floating[Any]],
    y: NDArray[np.floating[Any]],
    sigma: NDArray[np.floating[Any]],
    bounds: NDArray[np.floating[Any]],
    fit_flag: NDArray[np.bool_],
    fit_x: NDArray[np.floating[Any]],
    p0: NDArray[np.floating[Any]] | None = None,
    max_workers: int | None = None,
    use_threads: bool = True,
    **kwargs: Any,
) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]:
    """Parallel version of fit_with_fixed: fit several q-values concurrently.

    Parameters
    ----------
    base_func : callable
        Model function called as ``base_func(x, *params)`` with the full
        parameter list.
    x, y, sigma : array
        Data, values and error bars. One fit is run per column of ``y``;
        column ``n`` of ``sigma`` is used when ``sigma`` has more than one
        dimension, otherwise the same array is used for every column.
    bounds : array-like
        ``(lower_bounds, upper_bounds)``, indexed as ``bounds[:, param]``.
    fit_flag : array
        Flags marking which parameters are fitted; the rest are held fixed.
        Coerced to a boolean mask, so 0/1 integer flags are accepted.
    fit_x : array
        X values at which the fitted curves are evaluated.
    p0 : array, optional
        Initial values for all parameters. Fixed parameters keep this value;
        when omitted, the midpoint of ``bounds`` is used.
    max_workers : int, optional
        Worker count. Defaults to the smaller of the number of q-values and
        the CPU count, but never less than 1.
    use_threads : bool
        Kept for interface compatibility. Threads are always used: the
        fixed-parameter wrapper is a local closure, which cannot be pickled
        for a process pool. Passing ``False`` logs a warning.
    **kwargs
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    tuple
        ``(fit_line, fit_val)`` with the same layout as ``fit_with_fixed``:
        ``fit_val[n, 0, :]`` are parameter values and ``fit_val[n, 1, :]``
        their errors. For a failed column the free parameters fall back to
        the midpoint of ``bounds``, fixed parameters keep their fixed value
        and all errors are 0.
    """
    # Ensure numpy arrays at nlsq boundary
    x = ensure_numpy(x)
    y = ensure_numpy(y)
    sigma = ensure_numpy(sigma)
    fit_x = ensure_numpy(fit_x)

    # Force a real boolean mask: an integer 0/1 array would otherwise be
    # treated as fancy (positional) indices and select the wrong columns.
    fit_flag = np.asarray(fit_flag, dtype=bool)
    fix_flag = np.logical_not(fit_flag)
    bounds = np.asarray(bounds)
    num_args = len(fit_flag)
    num_qvals = y.shape[1]

    bounds_fit = bounds[:, fit_flag]

    # Preserve full initial guess for fixed parameter values
    if p0 is not None:
        full_p0 = np.array(p0)
    else:
        full_p0 = np.mean(bounds, axis=0)
    p0 = full_p0[fit_flag]

    fit_val = np.zeros((num_qvals, 2, num_args))

    # Fixed parameter values: use initial guess (not upper bound)
    fixed_values_par = full_p0.tolist()
    fit_indices_par = np.flatnonzero(fit_flag).tolist()

    # Values reported when a column cannot be fitted: bounds midpoint for the
    # free parameters, but fixed parameters stay at their fixed value.
    fallback_params = np.mean(bounds, axis=0)
    fallback_params[fix_flag] = full_p0[fix_flag]

    # Python list (not np.array) so JAX tracers passed in as fit_params by
    # nlsq are never converted to numpy.
    def wrapper_func(x_data, *fit_params):
        full_params = list(fixed_values_par)
        for idx, val in zip(fit_indices_par, fit_params):
            full_params[idx] = val
        return base_func(x_data, *full_params)

    fit_args = [
        (
            n,
            x,
            y[:, n],
            sigma[:, n] if sigma.ndim > 1 else sigma,
            wrapper_func,
            p0,
            bounds_fit,
        )
        for n in range(num_qvals)
    ]

    if max_workers is None:
        import os

        # Clamp to 1: an executor rejects max_workers=0 (no columns to fit).
        max_workers = max(1, min(num_qvals, os.cpu_count() or 1))

    if not use_threads:
        # A process pool has to pickle every submitted work item, and the
        # local wrapper_func closure cannot be pickled, so that path could
        # never succeed. Degrade to threads instead of crashing.
        logger.warning(
            "Process-based parallel fitting is not supported (the model "
            "wrapper cannot be pickled); using threads instead"
        )

    logger.info(
        f"Starting parallel G2 fitting for {num_qvals} q-values using {max_workers} workers"
    )

    progress_step = max(1, num_qvals // 10)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(_fit_single_qvalue, args) for args in fit_args]

        for completed_fits, future in enumerate(as_completed(futures), start=1):
            col_idx, popt, errors, success = future.result()

            if success:
                fit_val[col_idx, 0, fit_flag] = popt
                fit_val[col_idx, 1, fit_flag] = errors
                fit_val[col_idx, 0, fix_flag] = full_p0[fix_flag]
                fit_val[col_idx, 1, fix_flag] = 0
            else:
                fit_val[col_idx, 0, :] = fallback_params
                fit_val[col_idx, 1, :] = 0

            if completed_fits % progress_step == 0:
                progress = (completed_fits / num_qvals) * 100
                logger.debug(
                    f"Parallel fitting progress: {progress:.1f}% ({completed_fits}/{num_qvals})"
                )

    # Evaluating the model on fit_x is cheap, so a plain loop is used here
    # (same as the non-parallel variants) rather than a second worker pool.
    fit_line = np.zeros((num_qvals, len(fit_x)))
    for n in range(num_qvals):
        fit_line[n] = ensure_numpy(base_func(fit_x, *fit_val[n, 0, :]))

    logger.info(f"Parallel G2 fitting completed for {num_qvals} q-values")
    return fit_line, fit_val
def sequential_fitting(
    func: Callable[..., NDArray[np.floating[Any]]],
    x: NDArray[np.floating[Any]],
    y: NDArray[np.floating[Any]],
    sigma: NDArray[np.floating[Any]] | None = None,
    p0: NDArray[np.floating[Any]] | None = None,
    bounds: tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]] | None = None,
    **kwargs: Any,
) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]], str]:
    """Robust fitting using NLSQ 0.6.0 with multistart and fallback strategies.

    Replaces the legacy TRF -> LM -> DE chain with a single NLSQ call that
    handles fallback automatically.

    Returns ``(popt, pcov, method)`` where ``method`` is one of:

    * ``"nlsq_multistart"`` - parameters and covariance both have the expected
      shape and are finite;
    * ``"nlsq_partial"`` - parameters are valid but the covariance is not, so
      ``1e6 * identity`` is returned in its place;
    * ``"fallback"`` - the fit raised or produced unusable parameters; the
      result is ``p0`` (when its length matches), else the midpoint of
      ``bounds``, else ones, together with ``1e6 * identity``.

    ``max_nfev`` / ``maxfev`` entries in ``kwargs`` are dropped before the
    NLSQ call; all other keyword arguments are forwarded.
    """
    x = ensure_numpy(x)
    y = ensure_numpy(y)
    if sigma is not None:
        sigma = ensure_numpy(sigma)

    # Determine n_params from bounds or p0 (more reliable than inspecting wrapper func)
    if bounds is not None:
        n_params = len(bounds[0])
    elif p0 is not None:
        n_params = len(p0)
    else:
        # Last resort: try to infer from function signature
        # Note: This may fail for wrapper functions with *args
        try:
            n_params = max(func.__code__.co_argcount - 1, 1)  # Minimum 1 parameter
        except AttributeError:
            n_params = 1

    # Filter kwargs for nlsq compatibility
    dropped_keys = ("max_nfev", "maxfev")
    safe_kwargs = {
        key: value for key, value in kwargs.items() if key not in dropped_keys
    }

    popt: NDArray[np.floating[Any]] | None = None
    pcov: NDArray[np.floating[Any]] | None = None

    try:
        result: Any = curve_fit(
            func,
            x,
            y,
            sigma=sigma,
            p0=p0,
            bounds=bounds,
            multistart=True,
            fallback=True,
            stability="auto",
            **safe_kwargs,
        )

        # The result may expose attributes or behave like a (popt, pcov) tuple.
        if hasattr(result, "popt"):
            raw_popt, raw_pcov = result.popt, result.pcov
        else:
            raw_popt, raw_pcov = result[0], result[1]
        popt, pcov = np.asarray(raw_popt), np.asarray(raw_pcov)

        # Validate that popt has the expected shape
        popt_ok = popt.shape == (n_params,) and np.all(np.isfinite(popt))
        if popt_ok:
            pcov_ok = pcov.shape == (n_params, n_params) and np.all(
                np.isfinite(pcov)
            )
            if pcov_ok:
                logger.debug("NLSQ multistart fitting succeeded")
                return popt, pcov, "nlsq_multistart"

            # popt is valid but pcov is not - use popt with fallback covariance
            logger.debug(
                "NLSQ fitting: popt valid but pcov invalid, using fallback pcov"
            )
            return popt, np.eye(n_params) * 1e6, "nlsq_partial"

    except Exception as e:
        logger.debug(f"NLSQ fitting failed: {e}")

    # Fallback: construct valid parameters from bounds or p0
    logger.warning("All fitting methods failed, using fallback parameters")

    fallback_popt = None
    if p0 is not None:
        flat_p0 = np.asarray(p0).flatten()
        if flat_p0.shape[0] == n_params:
            fallback_popt = flat_p0
    if fallback_popt is None:
        # No p0, or p0 has the wrong length: bounds midpoint, else ones.
        if bounds is not None:
            fallback_popt = np.mean(bounds, axis=0)
        else:
            fallback_popt = np.ones(n_params)

    # Final shape validation
    fallback_popt = np.atleast_1d(fallback_popt)
    if fallback_popt.shape[0] != n_params:
        logger.warning(
            f"Fallback popt shape mismatch: got {fallback_popt.shape}, expected ({n_params},)"
        )
        fallback_popt = np.ones(n_params)

    return fallback_popt, np.eye(n_params) * 1e6, "fallback"
@log_timing(threshold_ms=500)
def fit_with_fixed_sequential(
    base_func: Callable[..., NDArray[np.floating[Any]]],
    x: NDArray[np.floating[Any]],
    y: NDArray[np.floating[Any]],
    sigma: NDArray[np.floating[Any]],
    bounds: NDArray[np.floating[Any]],
    fit_flag: NDArray[np.bool_],
    fit_x: NDArray[np.floating[Any]],
    p0: NDArray[np.floating[Any]] | None = None,
    **kwargs: Any,
) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]], list[str]]:
    """Fixed-parameter fitting that runs ``sequential_fitting`` on each column.

    Parameters
    ----------
    base_func : callable
        Model function called as ``base_func(x, *params)`` with the full
        parameter list.
    x, y, sigma : array
        Data, values and error bars. One fit is run per column of ``y``;
        column ``n`` of ``sigma`` is used when ``sigma`` has more than one
        dimension, otherwise the same array is used for every column.
    bounds : array-like
        ``(lower_bounds, upper_bounds)``, indexed as ``bounds[:, param]``.
    fit_flag : array
        Flags marking which parameters are fitted; the rest are held fixed.
        Coerced to a boolean mask, so 0/1 integer flags are accepted.
    fit_x : array
        X values at which the fitted curves are evaluated.
    p0 : array, optional
        Initial values for all parameters. Fixed parameters keep this value;
        when omitted, the midpoint of ``bounds`` is used.
    **kwargs
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    tuple
        ``(fit_line, fit_val, fit_methods)``. ``fit_val[n, 0, :]`` are the
        parameter values and ``fit_val[n, 1, :]`` their errors (0 for fixed
        parameters); ``fit_methods[n]`` names the strategy used for column
        ``n``. If a column raises, its free parameters fall back to the
        midpoint of ``bounds``, fixed parameters keep their fixed value, all
        errors are 0 and the method is ``"fallback_error"``.
    """
    # Ensure numpy arrays
    x = ensure_numpy(x)
    y = ensure_numpy(y)
    sigma = ensure_numpy(sigma)
    fit_x = ensure_numpy(fit_x)

    # Force a real boolean mask: an integer 0/1 array would otherwise be
    # treated as fancy (positional) indices and select the wrong columns.
    fit_flag = np.asarray(fit_flag, dtype=bool)
    fix_flag = np.logical_not(fit_flag)
    bounds = np.asarray(bounds)
    num_args = len(fit_flag)
    num_cols = y.shape[1]

    bounds_fit = bounds[:, fit_flag]

    # Preserve full initial guess for fixed parameter values
    if p0 is not None:
        full_p0 = np.array(p0)
    else:
        full_p0 = np.mean(bounds, axis=0)
    p0 = full_p0[fit_flag]

    fit_val = np.zeros((num_cols, 2, num_args))
    fit_methods: list[str] = []

    # Fixed parameter values: use initial guess (not upper bound)
    fixed_values_seq = full_p0.tolist()
    fit_indices_seq = np.flatnonzero(fit_flag).tolist()

    # Values reported when a column cannot be fitted: bounds midpoint for the
    # free parameters, but fixed parameters stay at their fixed value.
    fallback_params = np.mean(bounds, axis=0)
    fallback_params[fix_flag] = full_p0[fix_flag]

    # Python list (not np.array) so JAX tracers passed in as fit_params by
    # nlsq are never converted to numpy.
    def wrapper_func(x_data, *fit_params):
        full_params = list(fixed_values_seq)
        for idx, val in zip(fit_indices_seq, fit_params):
            full_params[idx] = val
        return base_func(x_data, *full_params)

    # Compute expected number of fit parameters (True values in fit_flag)
    n_fit_params = int(np.sum(fit_flag))

    for n in range(num_cols):
        try:
            sigma_col = sigma[:, n] if sigma.ndim > 1 else sigma

            popt, pcov, method_used = sequential_fitting(
                wrapper_func,
                x,
                y[:, n],
                sigma=sigma_col,
                p0=p0,
                bounds=(bounds_fit[0], bounds_fit[1]),
                max_nfev=5000,
            )

            # Validate popt shape before assignment
            popt = np.atleast_1d(popt)
            if popt.shape[0] != n_fit_params:
                logger.warning(
                    f"Column {n}: popt shape mismatch - got {popt.shape}, "
                    f"expected ({n_fit_params},). Using fallback."
                )
                # p0 is always an array at this point (sliced from full_p0).
                popt = p0.copy()
                pcov = np.eye(n_fit_params) * 1e6
                method_used = "fallback_shape_mismatch"

            fit_methods.append(method_used)
            logger.debug(f"Column {n}: fitted using {method_used}")

            fit_val[n, 0, fit_flag] = popt

            # Safely extract diagonal from pcov
            pcov_diag = np.diag(pcov) if pcov.ndim == 2 else np.abs(pcov)
            errors = np.sqrt(np.abs(pcov_diag))

            # Handle non-finite errors
            if not np.all(np.isfinite(errors)):
                errors = np.where(np.isfinite(errors), errors, np.abs(popt) * 0.1)

            fit_val[n, 1, fit_flag] = errors
            fit_val[n, 0, fix_flag] = full_p0[fix_flag]
            fit_val[n, 1, fix_flag] = 0

        except Exception as e:
            logger.warning(f"Sequential fitting failed for column {n}: {e}")
            fit_methods.append("fallback_error")
            fit_val[n, 0, :] = fallback_params
            fit_val[n, 1, :] = 0

    fit_line = np.zeros((num_cols, len(fit_x)))
    for n in range(num_cols):
        fit_line[n] = ensure_numpy(base_func(fit_x, *fit_val[n, 0, :]))

    # Summarise how many columns used each fitting strategy.
    method_counts: dict[str, int] = {}
    for method in fit_methods:
        method_counts[method] = method_counts.get(method, 0) + 1
    logger.info(f"Fitting methods used: {method_counts}")

    return fit_line, fit_val, fit_methods
def robust_curve_fit(
    func: Callable[..., NDArray[np.floating[Any]]],
    x: NDArray[np.floating[Any]],
    y: NDArray[np.floating[Any]],
    **kwargs: Any,
) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]:
    """Simple wrapper around nlsq.curve_fit with error handling.

    All keyword arguments are forwarded to ``curve_fit``. On success the
    fitted parameters and covariance are returned as numpy arrays.

    If the fit raises, a warning is logged and ``(ones(n), eye(n))`` is
    returned. ``n`` is taken from the length of ``p0`` when one was passed,
    otherwise from the positional-argument count of ``func`` minus one (the
    x argument), and is never less than 1.
    """
    x = ensure_numpy(x)
    y = ensure_numpy(y)

    try:
        result: Any = curve_fit(func, x, y, **kwargs)
        if hasattr(result, "popt"):
            return np.asarray(result.popt), np.asarray(result.pcov)
        return np.asarray(result[0]), np.asarray(result[1])
    except Exception as e:
        logger.warning(f"Curve fitting failed: {e}")

        # Work out the parameter count without letting the error handler
        # itself raise: prefer an explicit p0, then the function signature.
        n_params = 0
        p0 = kwargs.get("p0")
        if p0 is not None:
            n_params = int(np.size(p0))
        if n_params <= 0:
            # Callables without __code__ (partials, callable objects) and
            # *args wrappers cannot be counted this way; guard both cases.
            code = getattr(func, "__code__", None)
            if code is not None:
                n_params = code.co_argcount - 1
        if n_params <= 0:
            n_params = 1  # Minimum 1 parameter

        return np.ones(n_params), np.eye(n_params)
def vectorized_parameter_estimation(
    x: NDArray[np.floating[Any]],
    y: NDArray[np.floating[Any]],
    model_type: str = "exponential",
) -> tuple | None:
    """Estimate ``(tau, bkg, cts)`` for a single-exponential model.

    Starting values are derived from the data: ``bkg`` from the minimum of
    ``y``, ``cts`` from its range, and ``tau`` from the x value where ``y`` is
    closest to ``min + range / e``. These are then refined with a bounded
    ``curve_fit`` call on ``cts * exp(-2 * x / tau) + bkg``.

    Returns
    -------
    tuple or None
        The fitted parameters as a tuple, or ``None`` when ``model_type`` is
        not ``"exponential"`` or when anything in the estimation raises (the
        reason is logged at debug level).
    """
    x, y = ensure_numpy(x), ensure_numpy(y)

    if model_type != "exponential":
        return None

    try:
        y_min, y_max = np.min(y), np.max(y)
        amp = y_max - y_min
        idx = np.argmin(np.abs(y - (y_min + amp / np.e)))
        # Index 0 gives no usable decay-time estimate; use the middle x then.
        tau = x[idx] if idx > 0 else x[len(x) // 2]

        # NLSQ JIT-traces model functions with JAX internally, so the
        # model must use jax.numpy ops regardless of the xpcsviewer backend.
        import jax.numpy as jnp

        def _model_func(x_val, tau_val, bkg_val, cts_val):
            return cts_val * jnp.exp(-2 * jnp.asarray(x_val) / tau_val) + bkg_val

        result = curve_fit(
            _model_func,
            x,
            y,
            p0=[tau, y_min, amp],
            bounds=(
                [x[1] * 0.1, -np.abs(y_max), amp * 0.1],
                [x[-1] * 10, y_max * 1.1, amp * 10],
            ),
            method="trf",
            # Same evaluation-limit keyword as the other TRF fits in this module.
            max_nfev=5000,
        )
        if hasattr(result, "popt"):
            popt = cast(Any, result).popt
        else:
            popt = result[0]
        return tuple(popt)
    except Exception as e:
        # Best-effort estimate: report why it failed, then signal "no estimate".
        logger.debug(f"Vectorized parameter estimation failed: {e}")
        return None
def vectorized_residual_analysis(
    x: NDArray[np.floating[Any]],
    y_true: NDArray[np.floating[Any]],
    y_pred: NDArray[np.floating[Any]],
) -> dict[str, float | NDArray[np.floating[Any]]]:
    """Summarise the residuals ``y_true - y_pred``.

    ``x`` is part of the signature but is not used in the computation.

    Returns a dict with the keys ``"mean_residual"``, ``"std_residual"``,
    ``"rmse"`` (root of the mean squared residual) and ``"mae"`` (mean
    absolute residual).
    """
    observed = ensure_numpy(y_true)
    predicted = ensure_numpy(y_pred)
    diff = observed - predicted

    summary: dict[str, float | NDArray[np.floating[Any]]] = {}
    summary["mean_residual"] = np.mean(diff)
    summary["std_residual"] = np.std(diff)
    summary["rmse"] = np.sqrt(np.mean(diff**2))
    summary["mae"] = np.mean(np.abs(diff))
    return summary