Source code for xpcsviewer.fitting.results

"""Result dataclasses for Bayesian fitting.

This module defines the data structures for fit results, diagnostics,
and sampler configuration.
"""

from __future__ import annotations

import logging
import sys
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

logger = logging.getLogger(__name__)

import numpy as np

if TYPE_CHECKING:
    import pandas as pd
    from nlsq.diagnostics import ModelHealthReport
    from nlsq.result import CurveFitResult
    from numpy.typing import ArrayLike
    from xarray import DataTree


def safe_version(package_name: str) -> str:
    """Safely retrieve package version for reproducibility tracking.

    Per Technical Guidelines, fit artifacts must include software versions.
    This function never raises: every failure path returns ``"unknown"``.

    Lookup order:

    1. Installed distribution metadata via :func:`importlib.metadata.version`.
    2. The ``__version__`` attribute of the importable module. Distribution
       names may contain hyphens (e.g. ``"xpcs-toolkit"``), which are not
       valid in module names, so hyphens are translated to underscores for
       this step.

    Parameters
    ----------
    package_name : str
        Distribution (or module) name of the package to get the version for.

    Returns
    -------
    str
        Version string, or ``"unknown"`` if the version cannot be determined.
        Non-string ``__version__`` values (e.g. tuples) are converted with
        ``str()``.
    """
    try:
        from importlib.metadata import version

        found = version(package_name)
        # Broken metadata can yield None/empty; fall through to __version__.
        if found:
            return str(found)
    except Exception:
        # Deliberate best-effort: missing metadata (PackageNotFoundError) or a
        # broken install must never abort result export.
        pass

    try:
        import importlib

        module = importlib.import_module(package_name.replace("-", "_"))
        # getattr is inside the guard because a module-level __getattr__
        # may itself raise.
        module_version = getattr(module, "__version__", None)
        return "unknown" if module_version is None else str(module_version)
    except Exception:
        return "unknown"


@dataclass
class SamplerConfig:
    """Settings handed to the NUTS sampler.

    Attributes
    ----------
    num_warmup : int
        Warmup (burn-in) samples; must be positive (default: 500)
    num_samples : int
        Posterior samples drawn per chain; must be positive (default: 1000)
    num_chains : int
        Number of MCMC chains; must be positive (default: 4)
    target_accept_prob : float
        NUTS target acceptance probability, strictly inside (0, 1)
        (default: 0.8)
    max_tree_depth : int
        Maximum NUTS tree depth; must be positive (default: 10)
    random_seed : int or None
        Random seed for reproducibility, or None (default: None)

    Raises
    ------
    ValueError
        Raised by ``__post_init__`` for the first invalid setting, checked in
        declaration order.
    """

    num_warmup: int = 500
    num_samples: int = 1000
    num_chains: int = 4
    target_accept_prob: float = 0.8
    max_tree_depth: int = 10
    random_seed: int | None = None

    def __post_init__(self) -> None:
        """Reject non-positive counts and out-of-range acceptance targets."""
        # Same order as the field declarations, so the first offending setting
        # is the one reported when several are invalid.
        for attr_name in ("num_warmup", "num_samples", "num_chains"):
            attr_value = getattr(self, attr_name)
            if attr_value <= 0:
                raise ValueError(f"{attr_name} must be positive, got {attr_value}")

        accept = self.target_accept_prob
        # Written as a negated chained comparison so NaN is rejected too.
        if not 0 < accept < 1:
            raise ValueError(f"target_accept_prob must be in (0, 1), got {accept}")

        depth = self.max_tree_depth
        if depth <= 0:
            raise ValueError(f"max_tree_depth must be positive, got {depth}")
@dataclass
class FitDiagnostics:
    """Convergence diagnostics for MCMC sampling.

    Attributes
    ----------
    r_hat : dict[str, float]
        Gelman-Rubin statistic per parameter
    ess_bulk : dict[str, int]
        Bulk ESS per parameter
    ess_tail : dict[str, int]
        Tail ESS per parameter
    divergences : int
        Number of divergent transitions
    max_treedepth_reached : int
        Count of max treedepth events (reported only; not part of
        ``converged``)
    bfmi : float | None
        Bayesian Fraction of Missing Information (mean across chains).
        Added per Technical Guidelines for Bayesian inference compliance.
        ``None`` means "not computed".

    Properties (computed)
    ---------------------
    converged : bool
        True if all diagnostics pass thresholds (see below). This is a
        computed ``@property``, not a stored field.

    Convergence Thresholds
    ----------------------
    r_hat <= 1.01 : All parameters must converge
    ess_bulk >= 400 : Sufficient effective samples
    ess_tail >= 400 : Sufficient tail samples
    divergences == 0 : No divergent transitions
    bfmi >= 0.2 : Adequate exploration (if computed)

    A NaN value for any checked statistic counts as a failure. Empty
    dictionaries impose no constraint.
    """

    r_hat: dict[str, float] = field(default_factory=dict)
    ess_bulk: dict[str, int] = field(default_factory=dict)
    ess_tail: dict[str, int] = field(default_factory=dict)
    divergences: int = 0
    max_treedepth_reached: int = 0
    bfmi: float | None = None  # NEW: Per Technical Guidelines

    def __post_init__(self) -> None:
        """Validate diagnostic invariants (BUG-025).

        Enforces:
        - divergences >= 0 (negative divergences are physically meaningless)
        - all ess_bulk values are non-negative
        - all ess_tail values are non-negative

        Raises
        ------
        ValueError
            If any of the invariants above is violated.
        """
        if self.divergences < 0:
            raise ValueError(
                f"divergences must be non-negative, got {self.divergences}"
            )
        for param, ess in self.ess_bulk.items():
            if ess < 0:
                raise ValueError(f"ess_bulk['{param}'] must be non-negative, got {ess}")
        for param, ess in self.ess_tail.items():
            if ess < 0:
                raise ValueError(f"ess_tail['{param}'] must be non-negative, got {ess}")

    @property
    def converged(self) -> bool:
        """Check if all diagnostics pass thresholds.

        Each check is phrased as the *passing* condition and then negated.
        NaN compares False to everything, so a NaN statistic (e.g. R-hat of a
        degenerate chain) is reported as a failure instead of silently passing.
        """
        # R-hat threshold
        if any(not (r <= 1.01) for r in self.r_hat.values()):
            return False
        # ESS thresholds
        if any(not (e >= 400) for e in self.ess_bulk.values()):
            return False
        if any(not (e >= 400) for e in self.ess_tail.values()):
            return False
        # No divergences
        if self.divergences > 0:
            return False
        # BFMI check (per Technical Guidelines); skipped when not computed.
        if self.bfmi is not None and not (self.bfmi >= 0.2):
            return False
        return True
@dataclass
class NLSQResult:
    """Result from nonlinear least squares fitting.

    This class wraps NLSQ 0.6.0's CurveFitResult and delegates statistical
    properties to the native result for accuracy and consistency. When
    ``native_result`` is ``None`` (or lacks an attribute), the legacy
    ``_``-prefixed storage fields are used instead.

    Attributes
    ----------
    params : dict[str, float]
        Point estimates for each parameter
    converged : bool
        Whether optimization converged
    chi_squared : float
        Reduced chi-squared statistic
    pcov_valid : bool
        Covariance validity flag (FR-021)
    pcov_message : str
        Validation message describing covariance status
    native_result : CurveFitResult, optional
        NLSQ 0.6.0 native result object for property delegation
    is_fallback : bool
        True when this result is a fallback due to fitting failure (BUG-003)
    _param_names : list[str]
        Ordered parameter names for covariance indexing. Defaults to the key
        order of ``params``.

    Properties (delegated to native_result when available)
    ------------------------------------------------------
    r_squared : float
        Coefficient of determination (R²). Range: (-∞, 1], where 1 is perfect fit.
    adj_r_squared : float
        Adjusted R² accounting for number of parameters.
    rmse : float
        Root mean squared error. Lower is better.
    mae : float
        Mean absolute error. Robust to outliers.
    aic : float
        Akaike Information Criterion. Lower is better for model selection.
    bic : float
        Bayesian Information Criterion. Penalizes complexity more than AIC.
    residuals : ndarray
        Fit residuals as numpy array.
    predictions : ndarray
        Model predictions at input x values.
    covariance : ndarray
        Parameter covariance matrix (n_params x n_params).
    confidence_intervals : dict[str, tuple[float, float]]
        Parameter confidence intervals at 95% level.
    diagnostics : ModelHealthReport or None
        NLSQ model health diagnostics (if compute_diagnostics=True).
    is_healthy : bool
        Whether the fit passes all health checks.
    health_score : int
        Health score (0-100).
    condition_number : float
        Condition number from identifiability diagnostics.
    """

    # Core fields (required)
    params: dict[str, float]
    chi_squared: float
    converged: bool

    # Covariance validation
    pcov_valid: bool = True
    pcov_message: str = ""

    # Native result for delegation (NEW - T018)
    native_result: CurveFitResult | None = None

    # Sentinel flag: True when this result is a fallback due to fitting failure (BUG-003)
    is_fallback: bool = False

    # Parameter names for covariance indexing (NEW - T019)
    _param_names: list[str] = field(default_factory=list)

    # Legacy storage fields (used when native_result is None for backward compat)
    _covariance: np.ndarray | None = field(default=None, repr=False)
    _residuals: np.ndarray | None = field(default=None, repr=False)
    # BUG-023: default to NaN so failed fits are distinguishable from poor-but-converged fits.
    # Callers can detect failure with math.isnan(result.r_squared) instead of checking == 0.0.
    _r_squared: float = field(default=float("nan"), repr=False)
    _adj_r_squared: float = field(default=float("nan"), repr=False)
    _rmse: float = field(default=float("nan"), repr=False)
    _mae: float = field(default=float("nan"), repr=False)
    _aic: float = field(default=float("nan"), repr=False)
    _bic: float = field(default=float("nan"), repr=False)
    _confidence_intervals: dict[str, tuple[float, float]] = field(
        default_factory=dict, repr=False
    )
    _predictions: np.ndarray | None = field(default=None, repr=False)

    def __post_init__(self) -> None:
        """Initialize param names from params dict if not provided."""
        if not self._param_names:
            self._param_names = list(self.params.keys())

    def _setter_is_noop(self, name: str, detail: str = "") -> bool:
        """Return True (after logging a warning) if a legacy setter must be skipped.

        Setters are no-ops while ``native_result`` is active (BUG-039) because
        the getters read from ``native_result`` instead of the legacy fields.

        Parameters
        ----------
        name : str
            Public property name, used in the warning message.
        detail : str
            Optional extra sentence appended to the warning message.
        """
        if self.native_result is None:
            return False
        message = (
            f"NLSQResult.{name} setter called while delegation to "
            "native_result is active; this assignment is a no-op."
        )
        if detail:
            message = f"{message} {detail}"
        logger.warning(message)
        return True

    # T020: r_squared property delegation
    @property
    def r_squared(self) -> float:
        """Coefficient of determination (R²)."""
        if self.native_result is not None and hasattr(self.native_result, "r_squared"):
            return self.native_result.r_squared
        return self._r_squared

    @r_squared.setter
    def r_squared(self, value: float) -> None:
        """Set r_squared (for backward compat initialization).

        Note: This setter is a no-op when native_result is active (BUG-039).
        The getter then returns native_result.r_squared.
        """
        if self._setter_is_noop(
            "r_squared", "The getter will return native_result.r_squared."
        ):
            return
        self._r_squared = value

    # T021: adj_r_squared property delegation
    @property
    def adj_r_squared(self) -> float:
        """Adjusted R² accounting for number of parameters."""
        if self.native_result is not None and hasattr(
            self.native_result, "adj_r_squared"
        ):
            return self.native_result.adj_r_squared
        return self._adj_r_squared

    @adj_r_squared.setter
    def adj_r_squared(self, value: float) -> None:
        """Set adj_r_squared (for backward compat initialization).

        Note: This setter is a no-op when native_result is active (BUG-039).
        """
        if self._setter_is_noop("adj_r_squared"):
            return
        self._adj_r_squared = value

    # T022: rmse property delegation
    @property
    def rmse(self) -> float:
        """Root mean squared error."""
        if self.native_result is not None and hasattr(self.native_result, "rmse"):
            return self.native_result.rmse
        return self._rmse

    @rmse.setter
    def rmse(self, value: float) -> None:
        """Set rmse (for backward compat initialization).

        Note: This setter is a no-op when native_result is active (BUG-039).
        """
        if self._setter_is_noop("rmse"):
            return
        self._rmse = value

    # T023: mae property delegation
    @property
    def mae(self) -> float:
        """Mean absolute error."""
        if self.native_result is not None and hasattr(self.native_result, "mae"):
            return self.native_result.mae
        return self._mae

    @mae.setter
    def mae(self, value: float) -> None:
        """Set mae (for backward compat initialization).

        Note: This setter is a no-op when native_result is active (BUG-039).
        """
        if self._setter_is_noop("mae"):
            return
        self._mae = value

    # T024: aic property delegation
    @property
    def aic(self) -> float:
        """Akaike Information Criterion."""
        if self.native_result is not None and hasattr(self.native_result, "aic"):
            return self.native_result.aic
        return self._aic

    @aic.setter
    def aic(self, value: float) -> None:
        """Set aic (for backward compat initialization).

        Note: This setter is a no-op when native_result is active (BUG-039).
        """
        if self._setter_is_noop("aic"):
            return
        self._aic = value

    # T025: bic property delegation
    @property
    def bic(self) -> float:
        """Bayesian Information Criterion."""
        if self.native_result is not None and hasattr(self.native_result, "bic"):
            return self.native_result.bic
        return self._bic

    @bic.setter
    def bic(self, value: float) -> None:
        """Set bic (for backward compat initialization).

        Note: This setter is a no-op when native_result is active (BUG-039).
        """
        if self._setter_is_noop("bic"):
            return
        self._bic = value

    # T026: residuals property delegation
    @property
    def residuals(self) -> np.ndarray:
        """Fit residuals as numpy array (empty array when unavailable)."""
        if self.native_result is not None and hasattr(self.native_result, "residuals"):
            return np.asarray(self.native_result.residuals)
        if self._residuals is not None:
            return self._residuals
        return np.array([])

    @residuals.setter
    def residuals(self, value: np.ndarray) -> None:
        """Set residuals (for backward compat initialization).

        Note: This setter is a no-op when native_result is active (BUG-039).
        """
        if self._setter_is_noop("residuals"):
            return
        self._residuals = value

    # T027: predictions property delegation
    @property
    def predictions(self) -> np.ndarray | None:
        """Model predictions at input x values."""
        if self.native_result is not None and hasattr(
            self.native_result, "predictions"
        ):
            return np.asarray(self.native_result.predictions)
        return self._predictions

    @predictions.setter
    def predictions(self, value: np.ndarray | None) -> None:
        """Set predictions (for backward compat initialization).

        Note: This setter is a no-op when native_result is active (BUG-039).
        """
        if self._setter_is_noop("predictions"):
            return
        self._predictions = value

    # T035: covariance property delegation
    @property
    def covariance(self) -> np.ndarray:
        """Parameter covariance matrix.

        Uses ``native_result.pcov`` when present and not None. Otherwise falls
        back to the stored legacy matrix, and finally to an
        ``(n_params, n_params)`` zero matrix.
        """
        # A native pcov of None previously became a 0-d object array, which
        # broke indexing in get_param_uncertainty/summary.
        native_pcov = getattr(self.native_result, "pcov", None)
        if native_pcov is not None:
            return np.asarray(native_pcov)
        if self._covariance is not None:
            return self._covariance
        n = len(self.params)
        return np.zeros((n, n))

    @covariance.setter
    def covariance(self, value: np.ndarray) -> None:
        """Set covariance (for backward compat initialization).

        Note: This setter is a no-op when native_result is active (BUG-039).
        """
        if self._setter_is_noop("covariance"):
            return
        self._covariance = value

    # T029: confidence_intervals property delegation
    @property
    def confidence_intervals(self) -> dict[str, tuple[float, float]]:
        """Parameter confidence intervals at 95% level."""
        if self.native_result is not None and hasattr(
            self.native_result, "confidence_intervals"
        ):
            ci = self.native_result.confidence_intervals
            # Handle both property (dict) and method (callable) on native_result
            if callable(ci):
                ci = ci()
            if isinstance(ci, dict):
                return ci
        return self._confidence_intervals

    @confidence_intervals.setter
    def confidence_intervals(self, value: dict[str, tuple[float, float]]) -> None:
        """Set confidence_intervals (for backward compat initialization).

        Note: This setter is a no-op when native_result is active (BUG-039).
        """
        if self._setter_is_noop("confidence_intervals"):
            return
        self._confidence_intervals = value

    # T31: diagnostics property delegation
    @property
    def diagnostics(self) -> ModelHealthReport | None:
        """NLSQ model health diagnostics, or None when unavailable."""
        if self.native_result is not None and hasattr(
            self.native_result, "diagnostics"
        ):
            return self.native_result.diagnostics
        return None

    # T032: is_healthy property
    @property
    def is_healthy(self) -> bool:
        """Whether the fit passes all health checks.

        Uses an attribute-based check against the NLSQ diagnostics API (BUG-040).
        Avoids the fragile str(status) == "healthy" string comparison by:
        1. Checking diagnostics.is_healthy if available (NLSQ native); it may be
           a plain attribute/property or a method, or
        2. Using health_score >= 70 as a numeric fallback, or
        3. Checking enum value via .value or .name attribute comparison.
        """
        diagnostics = self.diagnostics
        if diagnostics is None:
            return False  # Unknown health != healthy

        # Primary: use native is_healthy attribute if available (NLSQ 0.6.0)
        if hasattr(diagnostics, "is_healthy"):
            healthy = diagnostics.is_healthy
            # bool() of a bound method is always True, so call it when it is
            # a method (mirrors the confidence_intervals handling).
            if callable(healthy):
                healthy = healthy()
            return bool(healthy)

        # Secondary: use numeric health_score (avoids string comparison)
        if hasattr(diagnostics, "health_score"):
            return int(diagnostics.health_score) >= 70

        # Fallback: use enum value/name comparison (more robust than str())
        status = diagnostics.status
        if hasattr(status, "value"):
            return status.value == "healthy"  # type: ignore[comparison-overlap]
        if hasattr(status, "name"):
            return status.name.lower() == "healthy"

        # Last resort: string comparison (kept for extreme backward compat)
        return str(status).lower() in ("healthy", "modelstatus.healthy")

    # T033: health_score property
    @property
    def health_score(self) -> int:
        """Health score (0-100); 0 when diagnostics are unavailable."""
        if self.diagnostics is not None:
            return int(self.diagnostics.health_score)
        return 0  # Unknown health != perfect health

    # T034: condition_number property
    @property
    def condition_number(self) -> float:
        """Condition number from identifiability diagnostics (1.0 if unavailable)."""
        if (
            self.diagnostics is not None
            and self.diagnostics.identifiability is not None
        ):
            return self.diagnostics.identifiability.condition_number
        return 1.0  # Default to well-conditioned

    def get_param_uncertainty(self, param: str) -> float:
        """Get standard error for a parameter from the covariance matrix.

        Parameters
        ----------
        param : str
            Parameter name

        Returns
        -------
        float
            Standard error (sqrt of the diagonal covariance element), or NaN
            when that variance is negative or NaN (an invalid covariance).

        Raises
        ------
        KeyError
            If ``param`` is not a fitted parameter.
        """
        if param not in self.params:
            raise KeyError(f"Parameter '{param}' not found in params")
        # Covariance rows/columns follow _param_names (T019), which can differ
        # from the dict insertion order when supplied explicitly.
        names = self._param_names or list(self.params.keys())
        if param not in names:
            raise KeyError(f"Parameter '{param}' not found in params")
        idx = names.index(param)
        variance = float(self.covariance[idx, idx])
        # `variance >= 0` is False for NaN too; avoids sqrt RuntimeWarnings.
        return float(np.sqrt(variance)) if variance >= 0 else float("nan")

    def get_confidence_interval(
        self, param: str, alpha: float = 0.05
    ) -> tuple[float, float]:
        """Get confidence interval for a parameter.

        The cached ``confidence_intervals`` (95% level) are used only when
        ``alpha == 0.05``; any other ``alpha`` is computed from the covariance
        using a normal approximation.

        Parameters
        ----------
        param : str
            Parameter name
        alpha : float
            Significance level (default: 0.05 for 95% CI)

        Returns
        -------
        tuple[float, float]
            (lower, upper) bounds of confidence interval

        Raises
        ------
        KeyError
            If ``param`` is not a fitted parameter.
        """
        # Cached intervals are 95% intervals; returning them for another alpha
        # would silently give the wrong coverage.
        if alpha == 0.05:
            intervals = self.confidence_intervals
            if param in intervals:
                return intervals[param]

        # Compute from covariance if not cached
        from scipy import stats

        param_value = self.params[param]
        std_err = self.get_param_uncertainty(param)
        z = stats.norm.ppf(1 - alpha / 2)
        return (param_value - z * std_err, param_value + z * std_err)

    # T030: get_prediction_interval method delegation
    def get_prediction_interval(
        self, x: ArrayLike, alpha: float = 0.05
    ) -> tuple[np.ndarray, np.ndarray]:
        """Get prediction interval at new x values.

        Delegates to native_result.prediction_interval() when available.
        Otherwise, if stored predictions have the same length as ``x``, returns
        ``predictions ± 2 * rmse`` (rough approximation). If neither applies,
        returns float arrays of NaN shaped like ``x``.

        Parameters
        ----------
        x : array_like
            X values at which to compute prediction intervals
        alpha : float
            Significance level (default: 0.05 for 95% PI)

        Returns
        -------
        tuple[ndarray, ndarray]
            (lower, upper) bounds of prediction interval as numpy arrays
        """
        x = np.asarray(x)

        if self.native_result is not None:
            pi = np.asarray(self.native_result.prediction_interval(x=x, alpha=alpha))
            # (n, 2) layout: one (lower, upper) row per x value
            if pi.ndim == 2 and pi.shape[-1] == 2:
                return pi[:, 0], pi[:, 1]
            # (2, n) layout, or any other layout whose first axis holds
            # (lower, upper)
            return pi[0], pi[1]

        # Fallback: return simple prediction +/- 2*rmse (rough approximation).
        # Scalar x has no len(), so it skips straight to the NaN result.
        predictions = self.predictions
        if predictions is not None and x.ndim > 0 and len(predictions) == len(x):
            predictions = np.asarray(predictions)
            margin = 2 * self.rmse
            return predictions - margin, predictions + margin

        # No prediction interval available. np.full keeps a float dtype even
        # for integer x (full_like would cast NaN to a garbage integer).
        return np.full(x.shape, np.nan), np.full(x.shape, np.nan)

    def summary(self) -> str:
        """Generate a formatted summary of the fit results.

        Returns
        -------
        str
            Formatted summary string
        """
        lines = ["NLSQ Fit Results", "=" * 50]

        # Convergence status
        status = "Converged" if self.converged else "Did not converge"
        lines.append(f"Status: {status}")
        lines.append("")

        # Fit quality metrics
        lines.append("Fit Quality:")
        lines.append(f" R²: {self.r_squared:.6f}")
        lines.append(f" Adjusted R²: {self.adj_r_squared:.6f}")
        lines.append(f" RMSE: {self.rmse:.6e}")
        lines.append(f" MAE: {self.mae:.6e}")
        lines.append(f" χ² (reduced): {self.chi_squared:.4f}")
        lines.append("")

        # Model selection criteria
        lines.append("Model Selection:")
        lines.append(f" AIC: {self.aic:.2f}")
        lines.append(f" BIC: {self.bic:.2f}")
        lines.append("")

        # Parameters with uncertainties
        lines.append("Parameters:")
        for name, value in self.params.items():
            std_err = self.get_param_uncertainty(name)
            ci = self.get_confidence_interval(name)
            lines.append(f" {name}: {value:.6e} ± {std_err:.6e}")
            lines.append(f" 95% CI: [{ci[0]:.6e}, {ci[1]:.6e}]")

        # Covariance validity
        lines.append("")
        if self.pcov_valid:
            lines.append("Covariance: Valid")
        else:
            lines.append(f"Covariance: Invalid - {self.pcov_message}")

        return "\n".join(lines)

    def plot(self, model, x_data, y_data, **kwargs):
        """Plot fit with uncertainty band.

        Parameters
        ----------
        model : callable
            Model function
        x_data, y_data : array-like
            Original data
        **kwargs
            Additional arguments for visualization

        Returns
        -------
        matplotlib.axes.Axes
        """
        from .visualization import plot_nlsq_fit

        return plot_nlsq_fit(self, model, x_data, y_data, **kwargs)
@dataclass
class FitResult:
    """Container for Bayesian fitting results.

    Attributes
    ----------
    samples : dict[str, ndarray]
        Posterior samples {param_name: (n_samples,)}
    param_names : list[str]
        Canonical parameter order matching the model function signature.
        Used by visualization to ensure correct positional argument mapping.
    summary : DataFrame
        Summary statistics per parameter
    diagnostics : FitDiagnostics
        Convergence diagnostics
    nlsq_init : dict[str, float]
        NLSQ warm-start point estimates
    arviz_data : DataTree
        ArviZ-compatible data for plotting (FR-015)
    config : SamplerConfig | None
        Sampler configuration used for this fit (per Technical Guidelines)
    x : ndarray | None
        Input x data (for reproducibility metadata)
    """

    samples: dict[str, np.ndarray]
    param_names: list[str] = field(default_factory=list)
    summary: pd.DataFrame | None = None
    diagnostics: FitDiagnostics = field(default_factory=FitDiagnostics)
    nlsq_init: dict[str, float] = field(default_factory=dict)
    arviz_data: DataTree | None = None
    config: SamplerConfig | None = None  # Per Technical Guidelines
    x: np.ndarray | None = None  # For data_metadata

    def __post_init__(self) -> None:
        """Validate samples invariants at construction (BUG-024).

        Enforces:
        - samples dict is non-empty (no results without actual posterior draws)
        - all sample arrays have consistent shapes (same number of draws)

        Raises
        ------
        ValueError
            If either invariant is violated.
        """
        if not self.samples:
            raise ValueError(
                "FitResult.samples must be non-empty: "
                "posterior samples dict cannot be empty"
            )
        shapes = {name: np.asarray(arr).shape for name, arr in self.samples.items()}
        shape_values = list(shapes.values())
        if len(shape_values) > 1 and not all(
            s == shape_values[0] for s in shape_values
        ):
            raise ValueError(
                f"FitResult.samples arrays have inconsistent shapes: {shapes}. "
                "All sample arrays must have the same shape."
            )

    def get_mean(self, param: str) -> float:
        """Get posterior mean for parameter.

        Parameters
        ----------
        param : str
            Parameter name

        Returns
        -------
        float
            Posterior mean

        Raises
        ------
        KeyError
            If ``param`` has no samples.
        """
        if param not in self.samples:
            raise KeyError(f"Parameter '{param}' not found in samples")
        return float(np.mean(self.samples[param]))

    def get_std(self, param: str) -> float:
        """Get posterior standard deviation for parameter.

        Parameters
        ----------
        param : str
            Parameter name

        Returns
        -------
        float
            Posterior standard deviation (``np.std`` with its default ddof=0)

        Raises
        ------
        KeyError
            If ``param`` has no samples.
        """
        if param not in self.samples:
            raise KeyError(f"Parameter '{param}' not found in samples")
        return float(np.std(self.samples[param]))

    def get_hdi(self, param: str, prob: float = 0.94) -> tuple[float, float]:
        """Get highest density interval for parameter.

        Uses ArviZ's ``az.hdi()`` for the true HDI (shortest interval
        containing ``prob`` probability mass). Falls back to an equal-tailed
        percentile interval when ArviZ is unavailable or its call fails.

        .. note::
            For skewed posteriors (e.g. LogNormal tau), the true HDI can be
            significantly narrower than the equal-tailed interval. The
            previous implementation always used percentiles, which
            over-estimated uncertainty for skewed parameters (BUG-C fix).

        Parameters
        ----------
        param : str
            Parameter name
        prob : float
            Probability mass for HDI (default: 0.94)

        Returns
        -------
        tuple[float, float]
            (lower, upper) bounds of HDI

        Raises
        ------
        KeyError
            If ``param`` has no samples.
        """
        if param not in self.samples:
            raise KeyError(f"Parameter '{param}' not found in samples")
        samples = self.samples[param]

        # BUG-C fix: Use ArviZ for true HDI computation instead of
        # equal-tailed percentile interval. For symmetric distributions
        # they are identical; for skewed distributions (LogNormal tau),
        # the HDI is up to ~35% narrower.
        try:
            import arviz as az

            hdi_result = az.hdi(samples, hdi_prob=prob)
            return (float(hdi_result[0]), float(hdi_result[1]))
        except Exception as exc:  # ImportError is an Exception subclass
            # Deliberate best-effort: degrade to the equal-tailed interval,
            # but leave a trace instead of swallowing the error silently.
            logger.debug(
                "ArviZ HDI unavailable for %r (%s); using equal-tailed interval",
                param,
                exc,
            )

        # Fallback: equal-tailed interval (wider for skewed posteriors)
        alpha = 1 - prob
        lower_pct = 100 * (alpha / 2)
        upper_pct = 100 * (1 - alpha / 2)
        return (
            float(np.percentile(samples, lower_pct)),
            float(np.percentile(samples, upper_pct)),
        )

    def get_samples(self, param: str) -> np.ndarray:
        """Get posterior samples for parameter.

        Parameters
        ----------
        param : str
            Parameter name

        Returns
        -------
        ndarray
            Posterior samples array (the stored object, not a copy)

        Raises
        ------
        KeyError
            If ``param`` has no samples.
        """
        if param not in self.samples:
            raise KeyError(f"Parameter '{param}' not found in samples")
        return self.samples[param]

    def predict(self, x: ArrayLike) -> tuple[np.ndarray, np.ndarray]:
        """Generate posterior predictive samples.

        .. note::
            This method requires a model function which is not stored in
            ``FitResult``. Use :meth:`plot_posterior_predictive` with an
            explicit model function, or compute predictions manually from
            :meth:`get_samples`.

        Parameters
        ----------
        x : array-like
            X values for prediction

        Returns
        -------
        tuple[ndarray, ndarray]
            (mean prediction, std prediction)

        Raises
        ------
        NotImplementedError
            Always raised. Use ``plot_posterior_predictive(model, x, y)``
            or compute predictions from posterior samples directly.
        """
        raise NotImplementedError(
            "FitResult.predict() requires a model function. Use "
            "plot_posterior_predictive() with an explicit model, or compute "
            "predictions from get_samples() directly."
        )

    @staticmethod
    def _to_builtin(value: Any) -> Any:
        """Convert a NumPy scalar to the equivalent Python scalar; pass others through.

        ``json`` cannot serialize NumPy integer or float32 scalars, so exports
        normalize them here.
        """
        return value.item() if isinstance(value, np.generic) else value

    def to_dict(self) -> dict:
        """Convert to serializable dictionary.

        Per Technical Guidelines, exports include:
        - versions: Package versions for reproducibility
        - sampler_config: Sampler parameters used
        - data_metadata: Data characteristics
        - diagnostics: Including BFMI

        Sample arrays are converted with ``np.asarray(...).tolist()``, and
        NumPy scalars become Python scalars, so the result can be passed to
        ``json.dumps``.

        Returns
        -------
        dict
            Dictionary representation with full reproducibility metadata
        """
        to_builtin = self._to_builtin

        # T026: Build versions dictionary for reproducibility
        versions = {
            "xpcsviewer": safe_version("xpcs-toolkit"),
            "numpyro": safe_version("numpyro"),
            "jax": safe_version("jax"),
            "arviz": safe_version("arviz"),
            "nlsq": safe_version("nlsq"),
            "python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
        }

        # T027: Build sampler_config dictionary
        sampler_config: dict[str, Any] = {}
        if self.config is not None:
            sampler_config = {
                "num_warmup": to_builtin(self.config.num_warmup),
                "num_samples": to_builtin(self.config.num_samples),
                "num_chains": to_builtin(self.config.num_chains),
                "target_accept_prob": to_builtin(self.config.target_accept_prob),
                "max_tree_depth": to_builtin(self.config.max_tree_depth),
                "random_seed": to_builtin(self.config.random_seed),
            }

        # T028: Build data_metadata dictionary
        data_metadata: dict[str, Any] = {}
        if self.x is not None:
            # atleast_1d: a scalar x counts as one point instead of raising.
            x_arr = np.atleast_1d(np.asarray(self.x))
            data_metadata = {
                "n_points": len(x_arr),
                # min()/max() raise on empty arrays; record no range instead.
                "x_range": (
                    [float(x_arr.min()), float(x_arr.max())] if x_arr.size else None
                ),
            }

        diag = self.diagnostics
        return {
            # np.asarray so lists / array-likes accepted by __post_init__ work.
            "samples": {k: np.asarray(v).tolist() for k, v in self.samples.items()},
            "nlsq_init": {k: to_builtin(v) for k, v in self.nlsq_init.items()},
            "diagnostics": {
                "r_hat": {k: to_builtin(v) for k, v in diag.r_hat.items()},
                "ess_bulk": {k: to_builtin(v) for k, v in diag.ess_bulk.items()},
                "ess_tail": {k: to_builtin(v) for k, v in diag.ess_tail.items()},
                "divergences": to_builtin(diag.divergences),
                "max_treedepth_reached": to_builtin(diag.max_treedepth_reached),
                "bfmi": to_builtin(diag.bfmi),  # Per Technical Guidelines
                "converged": diag.converged,
            },
            "versions": versions,
            "sampler_config": sampler_config,
            "data_metadata": data_metadata,
        }

    def plot_posterior_predictive(self, model, x_data, y_data, **kwargs):
        """Plot posterior predictive with credible interval.

        Parameters
        ----------
        model : callable
            Model function
        x_data, y_data : array-like
            Original data
        **kwargs
            Additional arguments for visualization

        Returns
        -------
        matplotlib.figure.Figure
        """
        from .visualization import plot_posterior_predictive

        return plot_posterior_predictive(self, model, x_data, y_data, **kwargs)

    def generate_diagnostics(self, output_dir=None, formats=("pdf", "png")) -> dict:
        """Generate ArviZ diagnostic plots.

        Parameters
        ----------
        output_dir : str or Path, optional
            Directory for output files
        formats : tuple
            Output formats

        Returns
        -------
        dict
            Mapping of plot type to figure or file path
        """
        from .visualization import generate_arviz_diagnostics

        return generate_arviz_diagnostics(
            self.arviz_data, output_dir=output_dir, formats=formats
        )