"""Result dataclasses for Bayesian fitting.
This module defines the data structures for fit results, diagnostics,
and sampler configuration.
"""
from __future__ import annotations

import logging
import math
import sys
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

import numpy as np

if TYPE_CHECKING:
    import pandas as pd
    from nlsq.diagnostics import ModelHealthReport
    from nlsq.result import CurveFitResult
    from numpy.typing import ArrayLike
    from xarray import DataTree

logger = logging.getLogger(__name__)
def safe_version(package_name: str) -> str:
    """Safely retrieve a package version for reproducibility tracking.

    Per Technical Guidelines, fit artifacts must include software versions.
    This helper never raises: any failure yields ``"unknown"``.

    Lookup order:

    1. Installed distribution metadata via :func:`importlib.metadata.version`.
    2. The ``__version__`` attribute of the importable module. Distribution
       names often use hyphens where module names use underscores, so the
       hyphen-to-underscore variant of ``package_name`` is tried as well.

    Parameters
    ----------
    package_name : str
        Distribution or module name of the package.

    Returns
    -------
    str
        Version string, or ``"unknown"`` if it cannot be determined.
    """
    try:
        from importlib.metadata import version

        return version(package_name)
    except Exception:
        # Missing or unreadable distribution metadata: fall back to the
        # module's __version__ attribute below.
        pass

    from importlib import import_module

    candidates = [package_name]
    normalized = package_name.replace("-", "_")
    if normalized != package_name:
        candidates.append(normalized)

    for module_name in candidates:
        try:
            module_version = getattr(import_module(module_name), "__version__", None)
            if module_version is not None:
                # __version__ is not guaranteed to be a str (tuples / version
                # objects exist); coerce so the declared return type holds.
                return str(module_version)
        except Exception:
            continue
    return "unknown"
@dataclass
class SamplerConfig:
    """Settings that control the NUTS sampler.

    Attributes
    ----------
    num_warmup : int
        Warmup (burn-in) iterations (default: 500).
    num_samples : int
        Posterior draws kept per chain (default: 1000).
    num_chains : int
        Number of MCMC chains (default: 4).
    target_accept_prob : float
        NUTS target acceptance probability, strictly inside (0, 1)
        (default: 0.8).
    max_tree_depth : int
        Maximum NUTS tree depth (default: 10).
    random_seed : int or None
        Seed for reproducible sampling (default: None).

    Raises
    ------
    ValueError
        At construction, when a count is not positive or
        ``target_accept_prob`` is not strictly between 0 and 1.
    """

    num_warmup: int = 500
    num_samples: int = 1000
    num_chains: int = 4
    target_accept_prob: float = 0.8
    max_tree_depth: int = 10
    random_seed: int | None = None

    def __post_init__(self):
        """Reject invalid sampler settings, reporting the first offender."""
        # The three draw/chain counts share one rule and one message shape.
        for label in ("num_warmup", "num_samples", "num_chains"):
            amount = getattr(self, label)
            if amount <= 0:
                raise ValueError(f"{label} must be positive, got {amount}")

        prob = self.target_accept_prob
        # Chained comparison so NaN is rejected as well.
        if not 0 < prob < 1:
            raise ValueError(f"target_accept_prob must be in (0, 1), got {prob}")

        depth = self.max_tree_depth
        if depth <= 0:
            raise ValueError(f"max_tree_depth must be positive, got {depth}")
@dataclass
class FitDiagnostics:
    """Convergence diagnostics for MCMC sampling.

    Attributes
    ----------
    r_hat : dict[str, float]
        Gelman-Rubin statistic per parameter.
    ess_bulk : dict[str, int]
        Bulk ESS per parameter.
    ess_tail : dict[str, int]
        Tail ESS per parameter.
    divergences : int
        Number of divergent transitions.
    max_treedepth_reached : int
        Count of max treedepth events (reported only; not used by
        ``converged``).
    bfmi : float | None
        Bayesian Fraction of Missing Information (mean across chains), or
        ``None`` when not computed. Added per Technical Guidelines for
        Bayesian inference compliance.

    Properties (computed)
    ---------------------
    converged : bool
        True if all diagnostics pass the thresholds below.
        This is a computed ``@property``, not a stored field.

    Convergence Thresholds
    ----------------------
    r_hat <= 1.01 : every parameter
    ess_bulk >= 400 : every parameter
    ess_tail >= 400 : every parameter
    divergences == 0 : no divergent transitions
    bfmi >= 0.2 : adequate exploration (only checked when bfmi is not None)

    A NaN value never passes a threshold. Empty dicts impose no constraint,
    so a default-constructed instance reports ``converged=True``.
    """

    r_hat: dict[str, float] = field(default_factory=dict)
    ess_bulk: dict[str, int] = field(default_factory=dict)
    ess_tail: dict[str, int] = field(default_factory=dict)
    divergences: int = 0
    max_treedepth_reached: int = 0
    bfmi: float | None = None  # Per Technical Guidelines

    def __post_init__(self) -> None:
        """Validate diagnostic invariants (BUG-025).

        Enforces:
        - divergences >= 0 (negative counts are meaningless)
        - max_treedepth_reached >= 0 (same reasoning as divergences)
        - all ess_bulk values are non-negative
        - all ess_tail values are non-negative

        Raises
        ------
        ValueError
            If any invariant is violated.
        """
        if self.divergences < 0:
            raise ValueError(
                f"divergences must be non-negative, got {self.divergences}"
            )
        if self.max_treedepth_reached < 0:
            raise ValueError(
                "max_treedepth_reached must be non-negative, "
                f"got {self.max_treedepth_reached}"
            )
        for param, ess in self.ess_bulk.items():
            if ess < 0:
                raise ValueError(f"ess_bulk['{param}'] must be non-negative, got {ess}")
        for param, ess in self.ess_tail.items():
            if ess < 0:
                raise ValueError(f"ess_tail['{param}'] must be non-negative, got {ess}")

    @property
    def converged(self) -> bool:
        """Check if all diagnostics pass thresholds.

        Each check is phrased as "value satisfies the threshold" so that NaN
        (for which every comparison is False) counts as a failure. The
        previous ``r > 1.01`` / ``e < 400`` form let NaN diagnostics pass.
        """
        # R-hat threshold
        if not all(r <= 1.01 for r in self.r_hat.values()):
            return False
        # ESS thresholds
        if not all(e >= 400 for e in self.ess_bulk.values()):
            return False
        if not all(e >= 400 for e in self.ess_tail.values()):
            return False
        # No divergences
        if self.divergences > 0:
            return False
        # BFMI check (per Technical Guidelines); None means "not computed"
        if self.bfmi is not None and not self.bfmi >= 0.2:
            return False
        return True
@dataclass
class NLSQResult:
    """Result from nonlinear least squares fitting.

    Wraps NLSQ 0.6.0's ``CurveFitResult`` and delegates statistical
    properties to it when ``native_result`` is set. Without a native result,
    the properties read the private ``_``-prefixed legacy fields, which the
    property setters write (backward-compatible initialization).

    Attributes
    ----------
    params : dict[str, float]
        Point estimates for each parameter.
    chi_squared : float
        Reduced chi-squared statistic.
    converged : bool
        Whether optimization converged.
    pcov_valid : bool
        Covariance validity flag (FR-021).
    pcov_message : str
        Validation message describing covariance status.
    native_result : CurveFitResult, optional
        NLSQ 0.6.0 native result object for property delegation.
    is_fallback : bool
        True when this result is a fallback created after a fitting
        failure (BUG-003).
    _param_names : list[str]
        Ordered parameter names for covariance indexing. Defaults to the key
        order of ``params``.

    Properties (delegated to native_result when available)
    ------------------------------------------------------
    r_squared : float
        Coefficient of determination (R²). Range: (-∞, 1], where 1 is a
        perfect fit. NaN when unavailable (BUG-023).
    adj_r_squared : float
        Adjusted R² accounting for number of parameters.
    rmse : float
        Root mean squared error. Lower is better.
    mae : float
        Mean absolute error. Robust to outliers.
    aic : float
        Akaike Information Criterion. Lower is better for model selection.
    bic : float
        Bayesian Information Criterion. Penalizes complexity more than AIC.
    residuals : ndarray
        Fit residuals as numpy array (empty when unavailable).
    predictions : ndarray or None
        Model predictions at input x values.
    covariance : ndarray
        Parameter covariance matrix (n_params x n_params); zeros when
        unavailable.
    confidence_intervals : dict[str, tuple[float, float]]
        Parameter confidence intervals at the 95% level.
    diagnostics : ModelHealthReport or None
        NLSQ model health diagnostics (if compute_diagnostics=True).
    is_healthy : bool
        Whether the fit passes all health checks (False when unknown).
    health_score : int
        Health score (0-100); 0 when unknown.
    condition_number : float
        Condition number from identifiability diagnostics; 1.0 when unknown.
    """

    # Core fields (required)
    params: dict[str, float]
    chi_squared: float
    converged: bool
    # Covariance validation
    pcov_valid: bool = True
    pcov_message: str = ""
    # Native result for delegation (T018)
    native_result: CurveFitResult | None = None
    # Sentinel flag: True when this result is a fallback due to fitting failure (BUG-003)
    is_fallback: bool = False
    # Parameter names for covariance indexing (T019)
    _param_names: list[str] = field(default_factory=list)
    # Legacy storage fields (used when native_result is None for backward compat)
    _covariance: np.ndarray | None = field(default=None, repr=False)
    _residuals: np.ndarray | None = field(default=None, repr=False)
    # BUG-023: default to NaN so failed fits are distinguishable from poor-but-converged fits.
    # Callers can detect failure with math.isnan(result.r_squared) instead of checking == 0.0.
    _r_squared: float = field(default=float("nan"), repr=False)
    _adj_r_squared: float = field(default=float("nan"), repr=False)
    _rmse: float = field(default=float("nan"), repr=False)
    _mae: float = field(default=float("nan"), repr=False)
    _aic: float = field(default=float("nan"), repr=False)
    _bic: float = field(default=float("nan"), repr=False)
    _confidence_intervals: dict[str, tuple[float, float]] = field(
        default_factory=dict, repr=False
    )
    _predictions: np.ndarray | None = field(default=None, repr=False)

    def __post_init__(self) -> None:
        """Initialize param names from params dict if not provided."""
        if not self._param_names:
            self._param_names = list(self.params.keys())

    def _legacy_write_blocked(self, name: str, detail: str = "") -> bool:
        """Return True, after logging a warning, if a setter must be ignored.

        Setters only update the legacy ``_``-prefixed storage. While
        ``native_result`` is attached, the getters read from it instead, so
        the assignment would have no visible effect (BUG-039).
        """
        if self.native_result is None:
            return False
        logger.warning(
            "NLSQResult.%s setter called while delegation to native_result "
            "is active; this assignment is a no-op.%s",
            name,
            detail,
        )
        return True

    # T020: r_squared property delegation
    @property
    def r_squared(self) -> float:
        """Coefficient of determination (R²)."""
        if self.native_result is not None and hasattr(self.native_result, "r_squared"):
            return self.native_result.r_squared
        return self._r_squared

    @r_squared.setter
    def r_squared(self, value: float) -> None:
        """Set r_squared (for backward compat initialization).

        No-op with a warning while native_result is active (BUG-039); the
        getter then returns native_result.r_squared.
        """
        if self._legacy_write_blocked(
            "r_squared", " The getter will return native_result.r_squared."
        ):
            return
        self._r_squared = value

    # T021: adj_r_squared property delegation
    @property
    def adj_r_squared(self) -> float:
        """Adjusted R² accounting for number of parameters."""
        if self.native_result is not None and hasattr(
            self.native_result, "adj_r_squared"
        ):
            return self.native_result.adj_r_squared
        return self._adj_r_squared

    @adj_r_squared.setter
    def adj_r_squared(self, value: float) -> None:
        """Set adj_r_squared; no-op while native_result is active (BUG-039)."""
        if self._legacy_write_blocked("adj_r_squared"):
            return
        self._adj_r_squared = value

    # T022: rmse property delegation
    @property
    def rmse(self) -> float:
        """Root mean squared error."""
        if self.native_result is not None and hasattr(self.native_result, "rmse"):
            return self.native_result.rmse
        return self._rmse

    @rmse.setter
    def rmse(self, value: float) -> None:
        """Set rmse; no-op while native_result is active (BUG-039)."""
        if self._legacy_write_blocked("rmse"):
            return
        self._rmse = value

    # T023: mae property delegation
    @property
    def mae(self) -> float:
        """Mean absolute error."""
        if self.native_result is not None and hasattr(self.native_result, "mae"):
            return self.native_result.mae
        return self._mae

    @mae.setter
    def mae(self, value: float) -> None:
        """Set mae; no-op while native_result is active (BUG-039)."""
        if self._legacy_write_blocked("mae"):
            return
        self._mae = value

    # T024: aic property delegation
    @property
    def aic(self) -> float:
        """Akaike Information Criterion."""
        if self.native_result is not None and hasattr(self.native_result, "aic"):
            return self.native_result.aic
        return self._aic

    @aic.setter
    def aic(self, value: float) -> None:
        """Set aic; no-op while native_result is active (BUG-039)."""
        if self._legacy_write_blocked("aic"):
            return
        self._aic = value

    # T025: bic property delegation
    @property
    def bic(self) -> float:
        """Bayesian Information Criterion."""
        if self.native_result is not None and hasattr(self.native_result, "bic"):
            return self.native_result.bic
        return self._bic

    @bic.setter
    def bic(self, value: float) -> None:
        """Set bic; no-op while native_result is active (BUG-039)."""
        if self._legacy_write_blocked("bic"):
            return
        self._bic = value

    # T026: residuals property delegation
    @property
    def residuals(self) -> np.ndarray:
        """Fit residuals as numpy array (empty array when unavailable)."""
        # A missing or None native attribute falls back to legacy storage
        # instead of becoming a 0-d object array via np.asarray(None).
        native = getattr(self.native_result, "residuals", None)
        if native is not None:
            return np.asarray(native)
        if self._residuals is not None:
            return self._residuals
        return np.array([])

    @residuals.setter
    def residuals(self, value: np.ndarray) -> None:
        """Set residuals; no-op while native_result is active (BUG-039)."""
        if self._legacy_write_blocked("residuals"):
            return
        self._residuals = value

    # T027: predictions property delegation
    @property
    def predictions(self) -> np.ndarray | None:
        """Model predictions at input x values (None when unavailable)."""
        native = getattr(self.native_result, "predictions", None)
        if native is not None:
            return np.asarray(native)
        return self._predictions

    @predictions.setter
    def predictions(self, value: np.ndarray | None) -> None:
        """Set predictions; no-op while native_result is active (BUG-039)."""
        if self._legacy_write_blocked("predictions"):
            return
        self._predictions = value

    # T035: covariance property delegation
    @property
    def covariance(self) -> np.ndarray:
        """Parameter covariance matrix (zeros when unavailable)."""
        native = getattr(self.native_result, "pcov", None)
        if native is not None:
            return np.asarray(native)
        if self._covariance is not None:
            return self._covariance
        n = len(self.params)
        return np.zeros((n, n))

    @covariance.setter
    def covariance(self, value: np.ndarray) -> None:
        """Set covariance; no-op while native_result is active (BUG-039)."""
        if self._legacy_write_blocked("covariance"):
            return
        self._covariance = value

    # T029: confidence_intervals property delegation
    @property
    def confidence_intervals(self) -> dict[str, tuple[float, float]]:
        """Parameter confidence intervals at 95% level."""
        if self.native_result is not None and hasattr(
            self.native_result, "confidence_intervals"
        ):
            ci = self.native_result.confidence_intervals
            # Handle both property (dict) and method (callable) on native_result
            if callable(ci):
                ci = ci()
            if isinstance(ci, dict):
                return ci
        return self._confidence_intervals

    @confidence_intervals.setter
    def confidence_intervals(self, value: dict[str, tuple[float, float]]) -> None:
        """Set confidence_intervals; no-op while native_result is active (BUG-039)."""
        if self._legacy_write_blocked("confidence_intervals"):
            return
        self._confidence_intervals = value

    # T031: diagnostics property delegation
    @property
    def diagnostics(self) -> ModelHealthReport | None:
        """NLSQ model health diagnostics (None without a native result)."""
        if self.native_result is not None and hasattr(
            self.native_result, "diagnostics"
        ):
            return self.native_result.diagnostics
        return None

    # T032: is_healthy property
    @property
    def is_healthy(self) -> bool:
        """Whether the fit passes all health checks.

        Uses attribute-based checks against the NLSQ diagnostics API (BUG-040)
        instead of a fragile ``str(status) == "healthy"`` comparison:

        1. ``diagnostics.is_healthy`` if available (NLSQ native), else
        2. ``health_score >= 70`` as a numeric fallback, else
        3. the ``status`` enum's ``.value`` / ``.name``, else a string match.

        Unknown health (no diagnostics or no status) is reported as False.
        """
        diagnostics = self.diagnostics
        if diagnostics is None:
            return False  # Unknown health != healthy
        # Primary: use native is_healthy attribute if available (NLSQ 0.6.0)
        if hasattr(diagnostics, "is_healthy"):
            return bool(diagnostics.is_healthy)
        # Secondary: use numeric health_score (avoids string comparison)
        if hasattr(diagnostics, "health_score"):
            return int(diagnostics.health_score) >= 70
        # Fallback: use enum value/name comparison (more robust than str())
        status = getattr(diagnostics, "status", None)
        if status is None:
            return False
        if hasattr(status, "value"):
            return status.value == "healthy"  # type: ignore[comparison-overlap]
        if hasattr(status, "name"):
            return status.name.lower() == "healthy"
        # Last resort: string comparison (kept for extreme backward compat)
        return str(status).lower() in ("healthy", "modelstatus.healthy")

    # T033: health_score property
    @property
    def health_score(self) -> int:
        """Health score (0-100); 0 when diagnostics or the score are missing."""
        score = getattr(self.diagnostics, "health_score", None)
        if score is None:
            return 0  # Unknown health != perfect health
        return int(score)

    # T034: condition_number property
    @property
    def condition_number(self) -> float:
        """Condition number from identifiability diagnostics (1.0 if unknown)."""
        identifiability = getattr(self.diagnostics, "identifiability", None)
        if identifiability is not None:
            return identifiability.condition_number
        return 1.0  # Default to well-conditioned

    def get_param_uncertainty(self, param: str) -> float:
        """Get standard error for a parameter from the covariance matrix.

        The covariance is indexed using ``_param_names``, which is the
        documented covariance order. ``params`` key order is used only if
        ``param`` is absent from ``_param_names``.

        Parameters
        ----------
        param : str
            Parameter name.

        Returns
        -------
        float
            Standard error (sqrt of diagonal covariance element).

        Raises
        ------
        KeyError
            If ``param`` is not in ``params``.
        """
        if param not in self.params:
            raise KeyError(f"Parameter '{param}' not found in params")
        order = self._param_names if param in self._param_names else list(self.params)
        idx = order.index(param)
        return float(np.sqrt(self.covariance[idx, idx]))

    def get_confidence_interval(
        self, param: str, alpha: float = 0.05
    ) -> tuple[float, float]:
        """Get confidence interval for a parameter.

        Cached/native intervals are 95% intervals, so they are returned only
        for the default ``alpha=0.05``. Any other ``alpha`` is computed from
        the covariance with a normal approximation.

        Parameters
        ----------
        param : str
            Parameter name.
        alpha : float
            Significance level (default: 0.05 for 95% CI).

        Returns
        -------
        tuple[float, float]
            (lower, upper) bounds of confidence interval.
        """
        if math.isclose(alpha, 0.05):
            cached = self.confidence_intervals
            if param in cached:
                return cached[param]
        # Compute from covariance (non-default alpha, or not cached)
        from scipy import stats

        param_value = self.params[param]
        std_err = self.get_param_uncertainty(param)
        z = stats.norm.ppf(1 - alpha / 2)
        return (param_value - z * std_err, param_value + z * std_err)

    # T030: get_prediction_interval method delegation
    def get_prediction_interval(
        self, x: ArrayLike, alpha: float = 0.05
    ) -> tuple[np.ndarray, np.ndarray]:
        """Get prediction interval at new x values.

        Delegates to native_result.prediction_interval() when available.
        Otherwise it falls back to ``predictions ± 2 * rmse``. That fallback
        ignores ``alpha`` and is only meaningful when ``x`` is the grid the
        stored predictions were computed on. It is used whenever the lengths
        match. If nothing is available, NaN float arrays shaped like ``x`` are
        returned.

        Parameters
        ----------
        x : array_like
            X values at which to compute prediction intervals.
        alpha : float
            Significance level (default: 0.05 for 95% PI).

        Returns
        -------
        tuple[ndarray, ndarray]
            (lower, upper) bounds of prediction interval as numpy arrays.
        """
        x = np.asarray(x)
        if self.native_result is not None:
            pi = np.asarray(self.native_result.prediction_interval(x=x, alpha=alpha))
            # Handle both (n, 2) array and (2, n) array formats
            if pi.ndim == 2 and pi.shape[-1] == 2:
                return pi[:, 0], pi[:, 1]
            if pi.ndim == 2 and pi.shape[0] == 2:
                return pi[0], pi[1]
            # Fallback: try direct unpack
            return pi[0], pi[1]
        # Fallback: simple prediction +/- 2*rmse (rough approximation)
        predictions = self.predictions
        if predictions is not None and x.ndim > 0:
            predictions = np.asarray(predictions)
            if len(predictions) == len(x):
                margin = 2 * self.rmse
                return predictions - margin, predictions + margin
        # No prediction interval available. Use an explicit float dtype:
        # full_like on an integer x cannot hold NaN.
        return (
            np.full(x.shape, np.nan, dtype=float),
            np.full(x.shape, np.nan, dtype=float),
        )

    def summary(self) -> str:
        """Generate a formatted summary of the fit results.

        Returns
        -------
        str
            Formatted summary string.
        """
        lines = ["NLSQ Fit Results", "=" * 50]
        # Convergence status
        status = "Converged" if self.converged else "Did not converge"
        lines.append(f"Status: {status}")
        lines.append("")
        # Fit quality metrics
        lines.append("Fit Quality:")
        lines.append(f"  R²: {self.r_squared:.6f}")
        lines.append(f"  Adjusted R²: {self.adj_r_squared:.6f}")
        lines.append(f"  RMSE: {self.rmse:.6e}")
        lines.append(f"  MAE: {self.mae:.6e}")
        lines.append(f"  χ² (reduced): {self.chi_squared:.4f}")
        lines.append("")
        # Model selection criteria
        lines.append("Model Selection:")
        lines.append(f"  AIC: {self.aic:.2f}")
        lines.append(f"  BIC: {self.bic:.2f}")
        lines.append("")
        # Parameters with uncertainties
        lines.append("Parameters:")
        for name, value in self.params.items():
            std_err = self.get_param_uncertainty(name)
            ci = self.get_confidence_interval(name)
            lines.append(f"  {name}: {value:.6e} ± {std_err:.6e}")
            lines.append(f"    95% CI: [{ci[0]:.6e}, {ci[1]:.6e}]")
        # Covariance validity
        lines.append("")
        if self.pcov_valid:
            lines.append("Covariance: Valid")
        else:
            lines.append(f"Covariance: Invalid - {self.pcov_message}")
        return "\n".join(lines)

    def plot(self, model, x_data, y_data, **kwargs):
        """Plot fit with uncertainty band.

        Parameters
        ----------
        model : callable
            Model function.
        x_data, y_data : array-like
            Original data.
        **kwargs
            Additional arguments for visualization.

        Returns
        -------
        matplotlib.axes.Axes
        """
        from .visualization import plot_nlsq_fit

        return plot_nlsq_fit(self, model, x_data, y_data, **kwargs)
@dataclass
class FitResult:
    """Container for Bayesian fitting results.

    Attributes
    ----------
    samples : dict[str, ndarray]
        Posterior samples ``{param_name: draws}``. Must be non-empty, and all
        arrays must share one shape (validated on construction).
    param_names : list[str]
        Canonical parameter order matching the model function signature.
        Used by visualization to ensure correct positional argument mapping.
    summary : DataFrame or None
        Summary statistics per parameter.
    diagnostics : FitDiagnostics
        Convergence diagnostics.
    nlsq_init : dict[str, float]
        NLSQ warm-start point estimates.
    arviz_data : DataTree or None
        ArviZ-compatible data for plotting (FR-015).
    config : SamplerConfig or None
        Sampler configuration used for this fit (per Technical Guidelines).
    x : ndarray or None
        Input x data (for reproducibility metadata).
    """

    samples: dict[str, np.ndarray]
    param_names: list[str] = field(default_factory=list)
    summary: pd.DataFrame | None = None
    diagnostics: FitDiagnostics = field(default_factory=FitDiagnostics)
    nlsq_init: dict[str, float] = field(default_factory=dict)
    arviz_data: DataTree | None = None
    config: SamplerConfig | None = None  # Per Technical Guidelines
    x: np.ndarray | None = None  # For data_metadata

    def __post_init__(self) -> None:
        """Validate samples invariants at construction (BUG-024).

        Enforces:
        - samples dict is non-empty (no results without actual posterior draws)
        - all sample arrays have consistent shapes (same number of draws)

        Raises
        ------
        ValueError
            If either invariant is violated.
        """
        if not self.samples:
            raise ValueError(
                "FitResult.samples must be non-empty: "
                "posterior samples dict cannot be empty"
            )
        shapes = {name: np.asarray(arr).shape for name, arr in self.samples.items()}
        shape_values = list(shapes.values())
        if len(shape_values) > 1 and not all(
            s == shape_values[0] for s in shape_values
        ):
            raise ValueError(
                f"FitResult.samples arrays have inconsistent shapes: {shapes}. "
                "All sample arrays must have the same shape."
            )

    def _require_samples(self, param: str) -> np.ndarray:
        """Return the stored draws for ``param`` or raise ``KeyError``."""
        if param not in self.samples:
            raise KeyError(f"Parameter '{param}' not found in samples")
        return self.samples[param]

    @staticmethod
    def _as_builtin(value: Any) -> Any:
        """Convert NumPy/JAX scalars to plain Python values for serialization.

        Objects exposing ``.item()`` (e.g. ``np.int64``) are unwrapped;
        anything else (Python scalars, None) is returned unchanged.
        """
        item = getattr(value, "item", None)
        return item() if callable(item) else value

    def get_mean(self, param: str) -> float:
        """Get posterior mean for parameter.

        Parameters
        ----------
        param : str
            Parameter name.

        Returns
        -------
        float
            Posterior mean.
        """
        return float(np.mean(self._require_samples(param)))

    def get_std(self, param: str) -> float:
        """Get posterior standard deviation for parameter.

        Parameters
        ----------
        param : str
            Parameter name.

        Returns
        -------
        float
            Posterior standard deviation (``np.std`` default, ddof=0).
        """
        return float(np.std(self._require_samples(param)))

    def get_hdi(self, param: str, prob: float = 0.94) -> tuple[float, float]:
        """Get highest density interval for parameter.

        Returns the shortest interval containing ``prob`` of the pooled
        posterior draws. It is computed directly with NumPy using the same
        sorted sliding-window algorithm as ArviZ's ``az.hdi`` for 1-D input.

        The previous implementation called ArviZ inside a catch-all
        ``except``. Any ArviZ failure (missing package, API differences
        between releases, ...) silently degraded the result to the wider
        equal-tailed interval (BUG-C). For skewed posteriors (e.g. LogNormal
        tau) the true HDI can be significantly narrower than the equal-tailed
        interval.

        Parameters
        ----------
        param : str
            Parameter name.
        prob : float
            Probability mass for HDI, in (0, 1] (default: 0.94).

        Returns
        -------
        tuple[float, float]
            (lower, upper) bounds of HDI.

        Raises
        ------
        KeyError
            If ``param`` has no samples.
        ValueError
            If ``prob`` is outside (0, 1] or there are no draws.
        """
        draws = np.sort(np.asarray(self._require_samples(param), dtype=float).ravel())
        if not 0 < prob <= 1:
            raise ValueError(f"prob must be in (0, 1], got {prob}")
        n_draws = draws.size
        if n_draws == 0:
            raise ValueError(f"No samples available for parameter '{param}'")
        # Each candidate interval spans `span` positions in the sorted draws.
        span = int(np.floor(prob * n_draws))
        n_windows = n_draws - span
        if n_windows <= 0:
            # prob == 1: the interval must cover every draw.
            return float(draws[0]), float(draws[-1])
        widths = draws[span:] - draws[:n_windows]
        start = int(np.argmin(widths))
        return float(draws[start]), float(draws[start + span])

    def get_samples(self, param: str) -> np.ndarray:
        """Get posterior samples for parameter.

        Parameters
        ----------
        param : str
            Parameter name.

        Returns
        -------
        ndarray
            The stored posterior samples array (not a copy).
        """
        return self._require_samples(param)

    def predict(self, x: ArrayLike) -> tuple[np.ndarray, np.ndarray]:
        """Generate posterior predictive samples.

        .. note::
           This method requires a model function which is not stored in
           ``FitResult``. Use :meth:`plot_posterior_predictive` with an
           explicit model function, or compute predictions manually from
           :meth:`get_samples`.

        Parameters
        ----------
        x : array-like
            X values for prediction.

        Returns
        -------
        tuple[ndarray, ndarray]
            (mean prediction, std prediction)

        Raises
        ------
        NotImplementedError
            Always raised. Use ``plot_posterior_predictive(model, x, y)``
            or compute predictions from posterior samples directly.
        """
        raise NotImplementedError(
            "FitResult.predict() requires a model function. Use "
            "plot_posterior_predictive() with an explicit model, or compute "
            "predictions from get_samples() directly."
        )

    def to_dict(self) -> dict:
        """Convert to serializable dictionary.

        Per Technical Guidelines, exports include:
        - versions: Package versions for reproducibility
        - sampler_config: Sampler parameters used
        - data_metadata: Data characteristics
        - diagnostics: Including BFMI

        Array-likes are converted with ``np.asarray(...).tolist()``, so list
        samples work too. NumPy/JAX scalars are unwrapped to Python scalars so
        the result can be passed to ``json.dumps``.

        Returns
        -------
        dict
            Dictionary representation with full reproducibility metadata.
        """
        as_builtin = self._as_builtin

        # T026: Build versions dictionary for reproducibility
        versions = {
            "xpcsviewer": safe_version("xpcs-toolkit"),
            "numpyro": safe_version("numpyro"),
            "jax": safe_version("jax"),
            "arviz": safe_version("arviz"),
            "nlsq": safe_version("nlsq"),
            "python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
        }

        # T027: Build sampler_config dictionary
        sampler_config: dict[str, Any] = {}
        if self.config is not None:
            sampler_config = {
                "num_warmup": self.config.num_warmup,
                "num_samples": self.config.num_samples,
                "num_chains": self.config.num_chains,
                "target_accept_prob": self.config.target_accept_prob,
                "max_tree_depth": self.config.max_tree_depth,
                "random_seed": self.config.random_seed,
            }

        # T028: Build data_metadata dictionary
        data_metadata: dict[str, Any] = {}
        if self.x is not None:
            # atleast_1d so a scalar x does not break len(); an empty x
            # has no range instead of crashing in min()/max().
            x_arr = np.atleast_1d(np.asarray(self.x))
            data_metadata = {
                "n_points": len(x_arr),
                "x_range": (
                    [float(x_arr.min()), float(x_arr.max())] if x_arr.size else None
                ),
            }

        diag = self.diagnostics
        return {
            "samples": {k: np.asarray(v).tolist() for k, v in self.samples.items()},
            "nlsq_init": {k: as_builtin(v) for k, v in self.nlsq_init.items()},
            "diagnostics": {
                "r_hat": {k: as_builtin(v) for k, v in diag.r_hat.items()},
                "ess_bulk": {k: as_builtin(v) for k, v in diag.ess_bulk.items()},
                "ess_tail": {k: as_builtin(v) for k, v in diag.ess_tail.items()},
                "divergences": as_builtin(diag.divergences),
                "max_treedepth_reached": as_builtin(diag.max_treedepth_reached),
                "bfmi": as_builtin(diag.bfmi),  # Per Technical Guidelines
                "converged": bool(diag.converged),
            },
            "versions": versions,
            "sampler_config": sampler_config,
            "data_metadata": data_metadata,
        }

    def plot_posterior_predictive(self, model, x_data, y_data, **kwargs):
        """Plot posterior predictive with credible interval.

        Parameters
        ----------
        model : callable
            Model function.
        x_data, y_data : array-like
            Original data.
        **kwargs
            Additional arguments for visualization.

        Returns
        -------
        matplotlib.figure.Figure
        """
        from .visualization import plot_posterior_predictive

        return plot_posterior_predictive(self, model, x_data, y_data, **kwargs)

    def generate_diagnostics(self, output_dir=None, formats=("pdf", "png")) -> dict:
        """Generate ArviZ diagnostic plots.

        Parameters
        ----------
        output_dir : str or Path, optional
            Directory for output files.
        formats : tuple
            Output formats.

        Returns
        -------
        dict
            Mapping of plot type to figure or file path.
        """
        from .visualization import generate_arviz_diagnostics

        return generate_arviz_diagnostics(
            self.arviz_data, output_dir=output_dir, formats=formats
        )