Source code for xpcsviewer.fitting.visualization

"""Fitting visualization module (FR-013 to FR-021).

This module provides visualization functions for NLSQ and Bayesian
fitting results, including uncertainty bands, diagnostic plots,
and publication-quality output.

NLSQ 0.6.0 Enhanced Features:
- Prediction interval visualization
- R², RMSE, AIC/BIC display on plots
- Confidence interval annotations
- Model comparison with information criteria
"""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import matplotlib.pyplot as plt
    from matplotlib.figure import Figure

    from .results import FitResult, NLSQResult


# Publication style preset (FR-019).
# Keys are matplotlib rcParams names; ``apply_publication_style`` installs them.
PUBLICATION_STYLE = {
    # Typography
    "font.family": "serif",
    "font.size": 10,
    # Axes frame and background grid
    "axes.grid": True,
    "axes.linewidth": 0.8,
    "grid.alpha": 0.3,
    # Resolution on screen and on disk; "tight" crops the saved bounding box
    "figure.dpi": 300,
    "savefig.dpi": 300,
    "savefig.bbox": "tight",
}


def apply_publication_style():
    """Install the publication-quality preset into matplotlib (FR-019).

    Every entry of ``PUBLICATION_STYLE`` is written into
    ``matplotlib.pyplot.rcParams``. rcParams are process-global, so the
    change applies to figures created after this call.
    """
    from matplotlib import pyplot as _pyplot

    global_params = _pyplot.rcParams
    global_params.update(PUBLICATION_STYLE)


def validate_pcov(pcov, param_names=None) -> tuple[bool, str]:
    """Validate a covariance matrix before computing uncertainty bands (FR-021).

    Checks, in order:

    - ``pcov`` is not None
    - ``pcov`` is a square 2-D matrix
    - all entries are finite (no inf/nan)
    - the matrix is symmetric (within a scale-aware tolerance)
    - the matrix is positive semi-definite (within a scale-aware tolerance)

    Parameters
    ----------
    pcov : array_like or None
        Covariance matrix to validate.
    param_names : list of str, optional
        Parameter names used to identify offending parameters in error
        messages. If omitted, or if its length does not match the matrix
        size, zero-based indices are reported instead.

    Returns
    -------
    is_valid : bool
        True if the covariance is valid.
    message : str
        Validation message (error description if invalid).
    """
    import numpy as np

    if pcov is None:
        return False, "Covariance matrix is None"

    pcov = np.asarray(pcov)
    if pcov.ndim != 2 or pcov.shape[0] != pcov.shape[1]:
        return False, f"Covariance matrix must be square 2-D, got shape {pcov.shape}"

    n = pcov.shape[0]
    if param_names is not None and len(param_names) == n:
        labels = [str(name) for name in param_names]
    else:
        labels = [str(i) for i in range(n)]

    finite = np.isfinite(pcov)
    if not np.all(finite):
        # Report every parameter whose row holds a non-finite entry.
        bad = [labels[i] for i in np.flatnonzero(~np.all(finite, axis=1))]
        return False, (
            "Covariance matrix contains inf or nan values "
            f"(affected parameters: {', '.join(bad)})"
        )

    # The largest absolute entry sets the scale of the numerical tolerances,
    # so the checks behave the same for tiny and for huge parameter variances.
    scale = float(np.max(np.abs(pcov), initial=0.0))

    # eigvalsh only reads one triangle, so asymmetry must be checked explicitly.
    if not np.allclose(pcov, pcov.T, rtol=1e-5, atol=1e-8 * scale):
        return False, "Covariance matrix is not symmetric"

    try:
        eigenvalues = np.linalg.eigvalsh(pcov)
    except np.linalg.LinAlgError:
        return False, "Failed to compute eigenvalues of covariance matrix"

    # Allow round-off-sized negative eigenvalues. The tolerance grows with the
    # matrix scale but never drops below the historical absolute floor of 1e-10.
    tolerance = 1e-10 * max(1.0, scale)
    if np.any(eigenvalues < -tolerance):
        return False, "Covariance matrix is not positive semi-definite"

    return True, "Covariance matrix is valid"


def compute_uncertainty_band(model, x_pred, popt, pcov, confidence=0.95):
    """Compute a confidence band for the fitted curve via error propagation (FR-016).

    The Jacobian ``J`` of the model with respect to the parameters is
    estimated with central finite differences. The pointwise variance of the
    fitted curve is ``diag(J @ pcov @ J.T)`` (first-order propagation). The
    band is ``y_fit ± z * sigma``, with ``z`` the two-sided standard-normal
    quantile for ``confidence``.

    Parameters
    ----------
    model : callable
        Model function called as ``model(x, *params)``.
    x_pred : array_like
        X values for prediction.
    popt : array_like
        Fitted parameters (cast to float, so integer inputs are handled).
    pcov : array_like
        Parameter covariance matrix (n_params x n_params).
    confidence : float
        Confidence level, strictly between 0 and 1 (default: 0.95 for 95% CI).

    Returns
    -------
    y_fit : ndarray
        Fitted curve values.
    y_lower : ndarray
        Lower bound of confidence band.
    y_upper : ndarray
        Upper bound of confidence band.

    Raises
    ------
    ValueError
        If ``confidence`` is not strictly between 0 and 1.
    """
    import numpy as np
    from scipy import stats

    if not 0.0 < confidence < 1.0:
        raise ValueError(f"confidence must be in (0, 1), got {confidence!r}")

    x_pred = np.asarray(x_pred)
    # Work in floating point. Perturbing an integer array in place would
    # truncate the step to zero and produce an all-zero Jacobian.
    popt = np.asarray(popt, dtype=float)
    pcov = np.asarray(pcov, dtype=float)

    y_fit = model(x_pred, *popt)

    n_params = len(popt)
    n_points = len(x_pred)
    jacobian = np.zeros((n_points, n_params))

    # A step of ~eps**(1/3), relative to each parameter's magnitude, balances
    # truncation and round-off error for central differences. An absolute
    # step would vanish for large-valued parameters.
    rel_step = np.cbrt(np.finfo(float).eps)
    for i in range(n_params):
        step = rel_step * max(1.0, abs(popt[i]))
        popt_plus = popt.copy()
        popt_minus = popt.copy()
        popt_plus[i] += step
        popt_minus[i] -= step
        # Divide by the step actually realised in floating point.
        jacobian[:, i] = (
            model(x_pred, *popt_plus) - model(x_pred, *popt_minus)
        ) / (popt_plus[i] - popt_minus[i])

    # diag(J @ pcov @ J.T) without forming the n_points x n_points matrix.
    variance = np.sum((jacobian @ pcov) * jacobian, axis=1)
    variance = np.maximum(variance, 0)  # clip tiny negatives from round-off
    sigma = np.sqrt(variance)

    z = stats.norm.ppf((1 + confidence) / 2)
    y_lower = y_fit - z * sigma
    y_upper = y_fit + z * sigma
    return y_fit, y_lower, y_upper


def compute_prediction_interval(model, x_pred, popt, pcov, residuals, confidence=0.95):
    """Compute a prediction interval including residual variance (NLSQ 0.6.0).

    Prediction intervals are wider than confidence intervals. They account
    for both parameter uncertainty AND observation noise:

        y_fit ± t * sqrt(se_fit**2 + s**2)

    - ``se_fit`` is the standard error of the fitted curve, obtained from the
      confidence band of ``compute_uncertainty_band``.
    - ``s`` is the residual standard deviation with ``n - p`` degrees of
      freedom.
    - ``t`` is the two-sided Student-t quantile for those degrees of freedom.

    Both variance terms use the same critical value.

    Parameters
    ----------
    model : callable
        Model function taking x and params.
    x_pred : ndarray
        X values for prediction.
    popt : ndarray
        Fitted parameters.
    pcov : ndarray
        Parameter covariance matrix.
    residuals : array_like
        Fit residuals (y - y_pred) from training data; lists are accepted.
    confidence : float
        Confidence level, strictly between 0 and 1 (default: 0.95 for 95% PI).

    Returns
    -------
    y_fit : ndarray
        Fitted curve values.
    pi_lower : ndarray
        Lower bound of prediction interval.
    pi_upper : ndarray
        Upper bound of prediction interval.

    Raises
    ------
    ValueError
        If ``confidence`` is not strictly between 0 and 1, or ``residuals``
        is None.
    """
    import numpy as np
    from scipy import stats

    if not 0.0 < confidence < 1.0:
        raise ValueError(f"confidence must be in (0, 1), got {confidence!r}")
    if residuals is None:
        raise ValueError("residuals are required to compute a prediction interval")

    # asarray so that plain lists support elementwise ** below.
    residuals = np.ravel(np.asarray(residuals, dtype=float))

    y_fit, ci_lower, ci_upper = compute_uncertainty_band(
        model, x_pred, popt, pcov, confidence
    )
    # The confidence band is y_fit ± z * se_fit (normal quantile); recover se_fit.
    z = stats.norm.ppf((1 + confidence) / 2)
    se_fit = (ci_upper - ci_lower) / (2 * z)

    # Residual standard deviation with n - p degrees of freedom (at least 1).
    n_data = residuals.size
    n_params = len(popt)
    dof = max(1, n_data - n_params)
    sigma_residuals = np.sqrt(np.sum(residuals**2) / dof)

    t_value = stats.t.ppf((1 + confidence) / 2, dof)
    pi_half_width = t_value * np.sqrt(se_fit**2 + sigma_residuals**2)

    pi_lower = y_fit - pi_half_width
    pi_upper = y_fit + pi_half_width
    return y_fit, pi_lower, pi_upper


def plot_nlsq_fit(
    result: NLSQResult,
    model,
    x_data,
    y_data,
    x_pred=None,
    confidence=0.95,
    ax=None,
    show_metrics: bool = True,
    show_prediction_interval: bool = False,
    xlabel: str = "x",
    ylabel: str = "y",
    title: str | None = None,
) -> plt.Axes:
    """Plot NLSQ fit with uncertainty band (FR-017, FR-021, NLSQ 0.6.0).

    Bands are drawn only when ``result.pcov_valid`` is true and every
    requested band could be computed. Otherwise a warning is logged, no band
    is drawn, and the fit line is labelled "Fit (uncertainty unavailable)" in
    the legend. The fit line is drawn exactly once in every case.

    Parameters
    ----------
    result : NLSQResult
        Output from nlsq_fit().
    model : callable
        Model function, called as ``model(x, *params)`` with params in the
        order of ``result.params``.
    x_data, y_data : ndarray
        Original data.
    x_pred : ndarray, optional
        X values for the prediction curve (default: 200 points spanning x_data).
    confidence : float
        Confidence level for band (default: 0.95).
    ax : matplotlib.axes.Axes, optional
        Axes to plot on (creates new if None).
    show_metrics : bool, optional
        Display R², RMSE, and χ² on the plot (default: True).
        Uses NLSQ 0.6.0 enhanced metrics from NLSQResult.
    show_prediction_interval : bool, optional
        Show prediction interval in addition to confidence interval.
        Prediction intervals account for observation noise (default: False).
    xlabel : str, optional
        X-axis label (default: "x").
    ylabel : str, optional
        Y-axis label (default: "y").
    title : str, optional
        Plot title (default: None).

    Returns
    -------
    ax : matplotlib.axes.Axes
    """
    import logging

    import matplotlib.pyplot as plt
    import numpy as np

    logger = logging.getLogger(__name__)

    if ax is None:
        _, ax = plt.subplots()

    x_data = np.asarray(x_data)
    y_data = np.asarray(y_data)

    if x_pred is None:
        x_pred = np.linspace(x_data.min(), x_data.max(), 200)

    # Data points, drawn above the bands.
    ax.scatter(x_data, y_data, c="k", s=20, alpha=0.7, label="Data", zorder=3)

    # Parameter vector in the iteration order of result.params.
    popt = np.array(list(result.params.values()))
    y_fit = model(x_pred, *popt)

    # Percent label; ":g" avoids int() truncation (int(0.29 * 100) == 28).
    level = f"{confidence * 100:g}%"

    # Compute every band before drawing. A failure part-way through then
    # leaves no partial band on the axes.
    bands = []  # (lower, upper, alpha, label) in drawing order
    if result.pcov_valid:
        try:
            # Prediction interval (wider, includes observation noise)
            if show_prediction_interval:
                _, pi_lower, pi_upper = compute_prediction_interval(
                    model, x_pred, popt, result.covariance, result.residuals, confidence
                )
                bands.append((pi_lower, pi_upper, 0.15, f"{level} PI"))
            # Confidence interval (parameter uncertainty only)
            _, ci_lower, ci_upper = compute_uncertainty_band(
                model, x_pred, popt, result.covariance, confidence
            )
            bands.append((ci_lower, ci_upper, 0.3, f"{level} CI"))
        except Exception as e:
            logger.warning("Failed to compute uncertainty band: %s", e)
            bands = []
    else:
        logger.warning("Covariance invalid: %s", result.pcov_message)

    for lower, upper, band_alpha, band_label in bands:
        ax.fill_between(
            x_pred, lower, upper, alpha=band_alpha, color="C0", label=band_label
        )

    # Fit curve on top of the bands, drawn exactly once.
    fit_label = "Fit" if bands else "Fit (uncertainty unavailable)"
    ax.plot(x_pred, y_fit, "C0-", lw=2, label=fit_label)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if title:
        ax.set_title(title)

    # Add metrics annotation (NLSQ 0.6.0 enhanced)
    if show_metrics:
        metrics_text = (
            f"R² = {result.r_squared:.4f}\n"
            f"RMSE = {result.rmse:.2e}\n"
            f"χ²ᵣ = {result.chi_squared:.3f}"
        )
        # Upper right corner of the axes, with some padding.
        ax.text(
            0.97,
            0.97,
            metrics_text,
            transform=ax.transAxes,
            fontsize=9,
            verticalalignment="top",
            horizontalalignment="right",
            bbox={"boxstyle": "round", "facecolor": "white", "alpha": 0.8},
        )

    ax.legend(loc="best")
    return ax


def plot_posterior_predictive(
    result: FitResult,
    model,
    x_data,
    y_data,
    x_pred=None,
    credible_level=0.95,
    n_draws=100,
    ax=None,
    *,
    max_draws: int | None = None,
    subsample_seed: int | None = None,
    xlabel: str = "x",
    ylabel: str = "y",
    title: str | None = None,
) -> plt.Axes:
    """Plot Bayesian fit with posterior credible interval (FR-014).

    Parameters
    ----------
    result : FitResult
        Output from Bayesian fitting.
    model : callable
        Model function, called as ``model(x, *params)`` with params ordered
        by ``result.param_names`` (falling back to ``result.samples`` keys).
    x_data, y_data : ndarray
        Original data.
    x_pred : ndarray, optional
        X values for prediction (default: 200 points spanning x_data).
    credible_level : float
        Credible interval level (default: 0.95).
    n_draws : int
        Number of posterior draws used for the band when ``max_draws`` is
        None (default: 100). Legacy; prefer ``max_draws``.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on (creates new if None).
    max_draws : int | None
        Maximum posterior draws to use; takes precedence over ``n_draws``.
        If None (default), ``n_draws`` is used instead. To use every posterior
        sample, pass a value >= the number of samples.
    subsample_seed : int | None
        Random seed for reproducible subsampling. Only used when fewer draws
        than available samples are requested.
    xlabel : str, optional
        X-axis label (default: "x").
    ylabel : str, optional
        Y-axis label (default: "y").
    title : str, optional
        Plot title (default: None).

    Returns
    -------
    ax : matplotlib.axes.Axes

    Raises
    ------
    ValueError
        If no posterior draws would be used (no samples, or a requested draw
        count below 1).

    Notes
    -----
    Per Technical Guidelines, posterior subsampling is explicit and logged.
    Whenever fewer draws than available samples are used, the subsampling is
    logged at INFO level together with the seed.
    """
    import logging

    import matplotlib.pyplot as plt
    import numpy as np

    logger = logging.getLogger(__name__)

    if ax is None:
        _, ax = plt.subplots()

    x_data = np.asarray(x_data)
    y_data = np.asarray(y_data)

    if x_pred is None:
        x_pred = np.linspace(x_data.min(), x_data.max(), 200)

    # Plot data
    ax.scatter(x_data, y_data, c="k", s=20, alpha=0.7, label="Data", zorder=3)

    # Use canonical param_names (matching the model signature) to avoid a
    # positional argument mismatch.
    param_names = result.param_names or list(result.samples.keys())
    n_samples = len(result.samples[param_names[0]])

    # max_draws takes precedence over the legacy n_draws parameter.
    requested = max_draws if max_draws is not None else n_draws
    effective_draws = min(requested, n_samples)
    if effective_draws < 1:
        raise ValueError(
            f"No posterior draws to plot (n_samples={n_samples}, "
            f"requested={requested})"
        )

    if effective_draws < n_samples:
        # Log subsampling per Technical Guidelines.
        logger.info(
            f"Subsampling posterior: {n_samples} -> {effective_draws} draws "
            f"(seed={subsample_seed})"
        )
        rng = np.random.default_rng(subsample_seed)
        indices = rng.choice(n_samples, effective_draws, replace=False)
    else:
        # Every sample is used; percentiles/median do not depend on order.
        indices = np.arange(n_samples)

    predictions_arr = np.array(
        [model(x_pred, *[result.samples[name][idx] for name in param_names]) for idx in indices]
    )

    # Equal-tailed credible interval and pointwise median.
    alpha = 1 - credible_level
    lower = np.percentile(predictions_arr, 100 * alpha / 2, axis=0)
    upper = np.percentile(predictions_arr, 100 * (1 - alpha / 2), axis=0)
    median = np.median(predictions_arr, axis=0)

    # Percent label; ":g" avoids int() truncation (int(0.29 * 100) == 28).
    ax.fill_between(
        x_pred,
        lower,
        upper,
        alpha=0.3,
        color="C1",
        label=f"{credible_level * 100:g}% CI",
    )
    ax.plot(x_pred, median, "C1-", lw=2, label="Median fit")

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if title:
        ax.set_title(title)
    ax.legend()
    return ax


def _extract_figure(plot_obj): """Extract matplotlib Figure from ArviZ 1.0 PlotCollection/PlotMatrix or legacy axes.""" # ArviZ 1.0: PlotCollection/PlotMatrix have .viz DataTree with figure if hasattr(plot_obj, "viz"): viz = plot_obj.viz if hasattr(viz, "ds") and "figure" in viz.ds: return viz.ds["figure"].item() if hasattr(viz, "__getitem__"): try: return viz["figure"].item() except (KeyError, AttributeError, ValueError): pass # Legacy ArviZ: axes with .figure attribute if hasattr(plot_obj, "figure"): return plot_obj.figure # Array of axes if hasattr(plot_obj, "__iter__"): import numpy as np axes_flat = np.asarray(plot_obj).flatten() if len(axes_flat) > 0 and hasattr(axes_flat[0], "figure"): return axes_flat[0].figure return None
def _close_saved_figure(fig):
    """Drop a matplotlib Figure from pyplot's registry; no-op for anything else."""
    try:
        import matplotlib.pyplot as plt
        from matplotlib.figure import Figure
    except ImportError:
        return
    if isinstance(fig, Figure):
        plt.close(fig)


def generate_arviz_diagnostics(
    trace,
    var_names=None,
    output_dir=None,
    formats=("pdf", "png"),
    dpi=300,
    prefix="mcmc",
) -> dict:
    """Generate complete ArviZ diagnostic suite (FR-013).

    Plots generated:
    1. Pair plot (parameter correlations, divergences)
    2. Forest plot (HDI intervals)
    3. Energy plot (NUTS E-BFMI diagnostics)
    4. Autocorrelation plot (chain mixing)
    5. Rank plot (convergence check)
    6. ESS plot (effective sample size)

    A plot that raises, or whose figure cannot be extracted, is logged as a
    warning and skipped. When ``output_dir`` is given, each figure is closed
    after saving, so repeated calls do not accumulate open pyplot figures.

    Parameters
    ----------
    trace : DataTree
        ArviZ DataTree object from MCMC. If None, returns {} without
        importing ArviZ.
    var_names : list, optional
        Parameter names to plot (default: all).
    output_dir : str or Path, optional
        Directory for output files, created if missing (None = return
        figures only).
    formats : str or tuple of str
        Output format(s) (default: ("pdf", "png") per FR-018).
    dpi : int
        Resolution for non-PDF formats (default: 300 per FR-018).
    prefix : str
        Filename prefix.

    Returns
    -------
    dict
        Mapping plot_type → figure (if output_dir is None).
        Mapping plot_type_format → file path (if output_dir provided).
    """
    import logging

    logger = logging.getLogger(__name__)

    if trace is None:
        return {}

    import arviz as az

    if isinstance(formats, str):
        # A bare string would otherwise be iterated character by character.
        formats = (formats,)

    if output_dir is not None:
        from pathlib import Path

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    results = {}

    # Define plot functions — ArviZ 1.0 returns PlotCollection/PlotMatrix
    plots = [
        ("pair", lambda: az.plot_pair(trace, var_names=var_names)),
        ("forest", lambda: az.plot_forest(trace, var_names=var_names, combined=True)),
        ("energy", lambda: az.plot_energy(trace)),
        ("autocorr", lambda: az.plot_autocorr(trace, var_names=var_names)),
        ("rank", lambda: az.plot_rank(trace, var_names=var_names)),
        ("ess", lambda: az.plot_ess(trace, var_names=var_names, kind="local")),
    ]

    for plot_name, plot_func in plots:
        try:
            # Extract the matplotlib figure for returning or saving.
            fig = _extract_figure(plot_func())
        except Exception as e:
            logger.warning("Failed to generate %s plot: %s", plot_name, e)
            continue

        if fig is None:
            logger.warning("Could not extract a figure from the %s plot; skipping", plot_name)
            continue

        if output_dir is None:
            results[plot_name] = fig
            continue

        try:
            for fmt in formats:
                filepath = output_dir / f"{prefix}_{plot_name}.{fmt}"
                fig.savefig(filepath, dpi=dpi if fmt != "pdf" else None)
                results[f"{plot_name}_{fmt}"] = str(filepath)
        except Exception as e:
            logger.warning("Failed to generate %s plot: %s", plot_name, e)
        finally:
            # Saved figures are not returned, so release them.
            _close_saved_figure(fig)

    return results


def plot_comparison(
    nlsq_result: NLSQResult,
    bayesian_result: FitResult,
    model,
    x_data,
    y_data,
    x_pred=None,
    confidence_level=0.95,
    band_alpha=0.25,
    ax=None,
    *,
    n_draws: int = 100,
    subsample_seed: int | None = None,
) -> plt.Axes:
    """Overlay NLSQ confidence band and Bayesian credible interval (FR-020).

    NLSQ band: drawn when ``nlsq_result.pcov_valid`` is true and the band can
    be computed. Otherwise a warning is logged and the NLSQ line is labelled
    "NLSQ fit (uncertainty unavailable)".

    Bayesian band: uses up to ``n_draws`` posterior draws. When subsampling
    is needed, draws are chosen with ``numpy.random.default_rng(subsample_seed)``,
    which does not touch NumPy's global random state; the subsampling is
    logged at INFO level.

    Parameters
    ----------
    nlsq_result : NLSQResult
        NLSQ fitting result.
    bayesian_result : FitResult
        Bayesian fitting result.
    model : callable
        Model function, called as ``model(x, *params)``.
    x_data, y_data : ndarray
        Original data.
    x_pred : ndarray, optional
        X values for prediction curves (default: 200 points spanning x_data).
    confidence_level : float
        Confidence/credible level (default: 0.95).
    band_alpha : float
        Transparency for bands (default: 0.25).
    ax : matplotlib.axes.Axes, optional
        Axes to plot on (creates new if None).
    n_draws : int
        Maximum posterior draws for the Bayesian band (default: 100).
    subsample_seed : int | None
        Seed for reproducible posterior subsampling (default: None).

    Returns
    -------
    ax : matplotlib.axes.Axes

    Raises
    ------
    ValueError
        If no posterior draws would be used.
    """
    import logging

    import matplotlib.pyplot as plt
    import numpy as np

    logger = logging.getLogger(__name__)

    if ax is None:
        _, ax = plt.subplots()

    x_data = np.asarray(x_data)
    y_data = np.asarray(y_data)

    if x_pred is None:
        x_pred = np.linspace(x_data.min(), x_data.max(), 200)

    # Percent label; ":g" avoids int() truncation (int(0.29 * 100) == 28).
    level = f"{confidence_level * 100:g}%"

    # Plot data
    ax.scatter(x_data, y_data, c="k", s=20, alpha=0.7, label="Data", zorder=5)

    # NLSQ fit
    popt = np.array(list(nlsq_result.params.values()))
    y_nlsq = model(x_pred, *popt)

    nlsq_band = None
    if nlsq_result.pcov_valid:
        try:
            _, y_nlsq_lower, y_nlsq_upper = compute_uncertainty_band(
                model, x_pred, popt, nlsq_result.covariance, confidence_level
            )
            nlsq_band = (y_nlsq_lower, y_nlsq_upper)
        except Exception as e:
            logger.warning("Failed to compute NLSQ uncertainty band: %s", e)
    else:
        logger.warning("NLSQ covariance invalid: %s", nlsq_result.pcov_message)

    if nlsq_band is not None:
        ax.fill_between(
            x_pred,
            *nlsq_band,
            alpha=band_alpha,
            color="C0",
            label=f"NLSQ {level} CI",
        )
        nlsq_label = "NLSQ fit"
    else:
        nlsq_label = "NLSQ fit (uncertainty unavailable)"
    ax.plot(x_pred, y_nlsq, "C0-", lw=2, label=nlsq_label)

    # Bayesian fit — use canonical param_names to avoid positional arg mismatch.
    param_names = bayesian_result.param_names or list(bayesian_result.samples.keys())
    n_samples = len(bayesian_result.samples[param_names[0]])

    n_used = min(n_draws, n_samples)
    if n_used < 1:
        raise ValueError(
            f"No posterior draws to plot (n_samples={n_samples}, n_draws={n_draws})"
        )
    if n_used < n_samples:
        logger.info(
            f"Subsampling posterior: {n_samples} -> {n_used} draws "
            f"(seed={subsample_seed})"
        )
        rng = np.random.default_rng(subsample_seed)
        indices = rng.choice(n_samples, n_used, replace=False)
    else:
        # Every sample is used; percentiles/median do not depend on order.
        indices = np.arange(n_samples)

    predictions_arr = np.array(
        [
            model(x_pred, *[bayesian_result.samples[name][idx] for name in param_names])
            for idx in indices
        ]
    )

    alpha = 1 - confidence_level
    lower = np.percentile(predictions_arr, 100 * alpha / 2, axis=0)
    upper = np.percentile(predictions_arr, 100 * (1 - alpha / 2), axis=0)
    median = np.median(predictions_arr, axis=0)

    ax.fill_between(
        x_pred,
        lower,
        upper,
        alpha=band_alpha,
        color="C1",
        label=f"Bayesian {level} CI",
    )
    ax.plot(x_pred, median, "C1--", lw=2, label="Bayesian median")

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.legend()
    return ax


def plot_diagnostics(
    result: NLSQResult,
    model,
    x_data,
    y_data,
    figsize=(10, 8),
) -> Figure:
    """Build a 2x2 diagnostic figure for an NLSQ fit (T081-T085).

    Panels:
    - top-left: residuals ``y_data - model(x_data, *params)`` against x
    - top-right: each parameter ±1σ from the covariance diagonal, or a
      notice when the covariance is invalid or missing
    - bottom-left: health score/status, convergence, covariance validity and
      condition number (optional attributes are shown only when present)
    - bottom-right: R², adjusted R², RMSE, MAE, χ²ᵣ, AIC and BIC

    Parameters
    ----------
    result : NLSQResult
        Output from NLSQ fitting with native_result.
    model : callable
        Model function, called as ``model(x, *params)`` in the order of
        ``result.params``.
    x_data, y_data : ndarray
        Original data.
    figsize : tuple
        Figure size (default: (10, 8)).

    Returns
    -------
    fig : matplotlib.figure.Figure
    """
    import matplotlib.pyplot as plt
    import numpy as np

    xs = np.asarray(x_data)
    ys = np.asarray(y_data)

    fig, axes = plt.subplots(2, 2, figsize=figsize)
    (ax_residuals, ax_ci), (ax_health, ax_metrics) = axes

    estimates = np.array(list(result.params.values()))
    names = list(result.params.keys())

    def text_panel(panel, lines, heading):
        # Shared layout for the two text-only panels.
        panel.text(
            0.1,
            0.8,
            "\n".join(lines),
            transform=panel.transAxes,
            fontsize=11,
            verticalalignment="top",
            family="monospace",
        )
        panel.set_xlim(0, 1)
        panel.set_ylim(0, 1)
        panel.axis("off")
        panel.set_title(heading)

    def draw_parameter_ci(panel):
        if not (result.pcov_valid and result.covariance is not None):
            panel.text(
                0.5,
                0.5,
                "Covariance invalid\n(no CI available)",
                ha="center",
                va="center",
                transform=panel.transAxes,
            )
            panel.set_title("Parameter Confidence Intervals")
            return
        sigma = np.sqrt(np.diag(result.covariance))
        rows = np.arange(len(names))
        panel.barh(rows, 2 * sigma, left=estimates - sigma, height=0.4, color="C0", alpha=0.6)
        panel.scatter(estimates, rows, c="C0", s=50, zorder=5)
        panel.set_yticks(rows)
        panel.set_yticklabels(names)
        panel.set_xlabel("Parameter Value ± 1σ")
        panel.set_title("Parameter Estimates with 68% CI")

    # Top-left: residuals (T082)
    resid = ys - model(xs, *estimates)
    ax_residuals.scatter(xs, resid, c="C0", alpha=0.6, s=20)
    ax_residuals.axhline(0, color="k", linestyle="--", alpha=0.5)
    ax_residuals.set_xlabel("x")
    ax_residuals.set_ylabel("Residuals")
    ax_residuals.set_title("Residuals vs X")

    # Top-right: parameter confidence intervals (T083)
    draw_parameter_ci(ax_ci)

    # Bottom-left: health / diagnostic status (T084)
    score = getattr(result, "health_score", None)
    healthy = getattr(result, "is_healthy", None)
    status_lines = []
    if score is not None:
        status_lines.append(f"Health Score: {score:.2f}")
    if healthy is not None:
        status_lines.append("Status: " + ("Healthy" if healthy else "Issues detected"))
    status_lines.append("✓ Converged" if result.converged else "✗ Did not converge")
    status_lines.append(
        "✓ Covariance valid" if result.pcov_valid else f"✗ {result.pcov_message}"
    )
    condition = getattr(result, "condition_number", None)
    if condition is not None:
        status_lines.append(f"Condition Number: {condition:.1f}")
    text_panel(ax_health, status_lines, "Fit Diagnostics")

    # Bottom-right: summary metrics (T085) as (label, attribute, format spec)
    metric_specs = (
        ("R²", "r_squared", ".6f"),
        ("Adj R²", "adj_r_squared", ".6f"),
        ("RMSE", "rmse", ".2e"),
        ("MAE", "mae", ".2e"),
        ("χ²ᵣ", "chi_squared", ".4f"),
        ("AIC", "aic", ".2f"),
        ("BIC", "bic", ".2f"),
    )
    metric_lines = [
        f"{label}: {format(getattr(result, attr), spec)}"
        for label, attr, spec in metric_specs
    ]
    text_panel(ax_metrics, metric_lines, "Fit Metrics")

    fig.tight_layout()
    return fig


def save_figure(fig, filepath, formats=("pdf", "png"), dpi=300) -> dict:
    """Save a figure in one or more file formats (FR-018).

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to save (any object with a compatible ``savefig`` works).
    filepath : str or Path
        Base path; its extension, if any, is replaced by each format's.
        Missing parent directories are created.
    formats : str or iterable of str
        Output format(s), e.g. ``"png"`` or ``("pdf", "png")``. A leading dot
        (``".svg"``) is accepted.
    dpi : int
        Resolution passed to ``savefig`` for every format except PDF. For PDF
        ``dpi=None`` is passed, leaving the choice to matplotlib.

    Returns
    -------
    dict
        Mapping format (without leading dot) → saved file path.
    """
    from pathlib import Path

    if isinstance(formats, str):
        # A bare string would otherwise be iterated character by character.
        formats = (formats,)

    filepath = Path(filepath)
    filepath.parent.mkdir(parents=True, exist_ok=True)

    results = {}
    for fmt in formats:
        ext = fmt.lstrip(".")
        output_path = filepath.with_suffix(f".{ext}")
        fig.savefig(output_path, dpi=dpi if ext.lower() != "pdf" else None)
        results[ext] = str(output_path)
    return results