Source code for xpcsviewer.module.g2mod

"""G2 correlation analysis module.

Provides multi-tau correlation analysis with single and double exponential
fitting for XPCS time correlation functions.

Functions:
    get_data: Extract G2 correlation data from XpcsFile objects
    fit_g2: Fit G2 data with exponential models
"""

# Standard library imports
import logging

# Third-party imports
import numpy as np
import pyqtgraph as pg

from xpcsviewer.backends._conversions import ensure_numpy
from xpcsviewer.plothandler.plot_constants import MATPLOTLIB_COLORS_RGB as colors

# Local imports
from xpcsviewer.utils.logging_config import get_logger

# Set the pyqtgraph-wide default foreground colour to a dark gray (RGB 80, 80, 80).
pg.setConfigOption("foreground", pg.mkColor(80, 80, 80))
# pg.setConfigOption("background", 'w')
# Module-level logger shared by every function in this file.
logger = get_logger(__name__)


def _log_array_dims(prefix: str, arr) -> None:
    """Emit a DEBUG record describing *arr* for data-flow tracing.

    Nothing is formatted or logged unless DEBUG is enabled on the module
    logger. ``None`` is reported as such, objects exposing ``shape`` are
    reported with their shape and dtype, and lists/tuples with their length.
    Any other object is silently ignored.

    Args:
        prefix: Text placed in front of the description.
        arr: Object to describe (array-like, list, tuple or None).
    """
    if not logger.isEnabledFor(logging.DEBUG):
        return

    if arr is None:
        logger.debug(f"{prefix}: None")
        return

    if hasattr(arr, "shape"):
        logger.debug(f"{prefix}: shape={arr.shape}, dtype={arr.dtype}")
        return

    if isinstance(arr, (list, tuple)):
        logger.debug(f"{prefix}: list of {len(arr)} items")


# https://www.geeksforgeeks.org/pyqtgraph-symbols/
# PyQtGraph marker symbol names. The plotting functions in this module pick one
# per dataset by index modulo the list length, so markers repeat after 12 sets.
symbols = ["o", "t", "t1", "t2", "t3", "s", "p", "h", "star", "+", "d", "x"]


def get_data(xf_list, q_range=None, t_range=None):
    """Collect G2 correlation data from several XpcsFile objects.

    Each file is asked for its G2 values, errors, delay times, Q-values and
    labels through ``get_g2_data``, with the optional Q and delay-time
    filters forwarded unchanged.

    Args:
        xf_list: List of XpcsFile objects with correlation analysis data.
        q_range: Optional (min, max) Q-value filter in inverse Angstroms.
        t_range: Optional (min, max) delay time filter in seconds.

    Returns:
        A 5-tuple ``(q, tel, g2, g2_err, labels)`` in which every element is
        a list holding one entry per file. When at least one file's ``atype``
        contains neither ``"Multitau"`` nor ``"Twotime"``, nothing is read
        and ``(False, None, None, None, None)`` is returned instead.

    Example:
        >>> q, tel, g2, g2_err, labels = get_data(xf_list, q_range=(0.01, 0.1))
        >>> if q is not False:
        ...     print(f"Loaded {len(q)} files, {g2[0].shape[1]} Q-bins")
    """
    logger.debug(
        f"get_data: entry with {len(xf_list)} files, q_range={q_range}, t_range={t_range}"
    )

    # Validate up front: every file needs a correlation analysis type
    # before any (potentially expensive) data access happens.
    correlation_types = ["Multitau", "Twotime"]
    atypes = [xf.atype for xf in xf_list]
    for atype in atypes:
        if not any(part in correlation_types for part in atype):
            logger.debug("get_data: exit early - no correlation analysis")
            return False, None, None, None, None

    num_files = len(xf_list)
    q, tel, g2, g2_err, labels = [], [], [], [], []
    for xf in xf_list:
        file_q, file_tel, file_g2, file_g2_err, file_labels = xf.get_g2_data(
            qrange=q_range, trange=t_range
        )
        q.append(file_q)
        tel.append(file_tel)
        g2.append(file_g2)
        g2_err.append(file_g2_err)
        labels.append(file_labels)

    _log_array_dims("get_data g2", g2[0] if g2 else None)
    logger.debug(f"get_data: exit with {num_files} datasets")
    return q, tel, g2, g2_err, labels
def get_g2_stability_data(xf_obj, q_range=None, t_range=None):
    """
    Fetch G2 stability data from one XpcsFile object.

    Parameters
    ----------
    xf_obj : XpcsFile
        Object exposing ``atype`` and ``get_g2_stability_data``.
    q_range : tuple or None
        Q-range filter (qmin, qmax), forwarded as ``qrange``.
    t_range : tuple or None
        Time range filter (tmin, tmax), forwarded as ``trange``.

    Returns
    -------
    tuple
        ``(q, tel, g2, g2_err, qbin_labels, labels)`` as produced by the
        file object, or ``(False, None, None, None, None, None)`` when
        ``"Multitau"`` is not part of ``xf_obj.atype``.
    """
    if "Multitau" in xf_obj.atype:
        # Unpacking enforces that the file object yields exactly six items.
        (
            q_vals,
            delays,
            g2_vals,
            g2_errs,
            qbin_labels,
            frame_labels,
        ) = xf_obj.get_g2_stability_data(qrange=q_range, trange=t_range)
        return q_vals, delays, g2_vals, g2_errs, qbin_labels, frame_labels

    return False, None, None, None, None, None
def pg_plot_stability(
    hdl,
    xf_obj,
    q_range,
    t_range,
    y_range,
    y_auto=False,
    q_auto=False,
    t_auto=False,
    num_col=4,
    offset=0,
    show_label=False,
    plot_type="multiple",
    marker_size=5,
    **kwargs,
):
    """
    Plot G2 stability data showing frame-by-frame correlation analysis.

    Parameters
    ----------
    hdl : GraphicsLayoutWidget
        PyQtGraph layout widget for plotting; must provide
        ``adjust_canvas_size``, ``clear`` and ``addPlot``.
    xf_obj : XpcsFile
        XpcsFile object with g2_partial data
    q_range : tuple
        Q-range for filtering (ignored when ``q_auto`` is True)
    t_range : tuple
        Time range for filtering and for the x-axis limits
        (ignored when ``t_auto`` is True)
    y_range : tuple
        Y-axis range (ignored when ``y_auto`` is True)
    y_auto, q_auto, t_auto : bool
        Auto-range flags
    num_col : int
        Maximum number of columns in layout
    offset : float
        Y-offset between curves
    show_label : bool
        Show legend labels
    plot_type : str
        Plot type: 'multiple', 'single', 'single-combined'
    marker_size : int
        Symbol size
    **kwargs : dict
        Accepted for call compatibility; not used by this function.

    Returns
    -------
    None
        The function only draws on ``hdl``. It returns early, after logging
        a warning, when no stability data is available.
    """
    if q_auto:
        q_range = None
    if t_auto:
        t_range = None
    if y_auto:
        y_range = None

    _q, tel, g2, g2_err, qbin_labels, labels = get_g2_stability_data(
        xf_obj, q_range=q_range, t_range=t_range
    )

    # Handle case where data is not available
    if g2 is False or g2 is None:
        logger.warning("G2 stability data not available for this file")
        return

    num_figs, _num_lines = compute_geometry(g2, plot_type)
    num_data, num_qval = len(g2), g2[0].shape[1]

    # col and row for the 2d layout
    col = min(num_figs, num_col)
    row = (num_figs + col - 1) // col

    hdl.adjust_canvas_size(num_col=col, num_row=row)
    hdl.clear()

    # The x axis is drawn in log mode, so the requested limits are converted
    # to log10 units. t0_range stays None when there is nothing to apply,
    # which keeps the later setRange call from touching an unbound name.
    t0_range = None
    if t_range is not None and np.size(t_range) > 0:
        with np.errstate(divide="ignore", invalid="ignore"):
            t_range_positive = np.asarray(t_range)
            # Non-positive limits have no logarithm; clamp them to eps.
            t_range_positive = np.where(
                t_range_positive > 0, t_range_positive, np.finfo(float).eps
            )
            t0_range = np.log10(t_range_positive)

    axes = []
    for n in range(num_figs):
        i_col = n % col
        i_row = n // col
        t = hdl.addPlot(row=i_row, col=i_col)
        axes.append(t)
        if show_label:
            legend = t.addLegend(labelTextSize="6pt")
            legend.anchor(itemPos=(1, 0), parentPos=(1, 0), offset=(0, 0))

        t.setMouseEnabled(x=False, y=y_auto)
        # Set axis labels once during setup (hoisted out of the m×n inner loop)
        t.setLabel("bottom", "tau (s)")
        t.setLabel("left", "g2")

    color_len = len(colors)
    symbol_len = len(symbols)

    for m in range(num_data):
        # No fitting happens here, so the baseline is always 1.0.
        baseline_offset = np.ones(num_qval)
        # Colour and marker cycle with the frame index.
        color = colors[m % color_len]
        symbol = symbols[m % symbol_len]

        for n in range(num_qval):
            label = None
            if plot_type == "multiple":
                ax = axes[n]
                label = f"frame={int(labels[m])}"
                if m == 0:
                    ax.setTitle(qbin_labels[n])
            elif plot_type == "single":
                ax = axes[m]
                # overwrite color; use the same color for the same set;
                color = colors[n % color_len]
                if m == 0 or n == 0:
                    ax.setTitle(str(labels[m]))
            elif plot_type == "single-combined":
                ax = axes[0]
                label = str(labels[m]) + qbin_labels[n]

            # All frames share one delay-time axis.
            x = tel
            # normalize baseline
            y = g2[m][:, n] - baseline_offset[n] + 1.0 + m * offset
            y_err = g2_err[m][:, n]

            pg_plot_one_g2(
                ax,
                x,
                y,
                y_err,
                color,
                label=label,
                symbol=symbol,
                symbol_size=marker_size,
            )

            if not y_auto and y_range is not None:
                ax.setRange(yRange=y_range)
            if not t_auto and t0_range is not None:
                ax.setRange(xRange=t0_range)

    return
def compute_geometry(g2, plot_type):
    """Work out how many panels and curves a G2 plot needs.

    Args:
        g2: List of 2-D G2 arrays, each shaped ``(n_delay, n_q)``. Only the
            list length and the second dimension of the first array are read.
        plot_type: Layout mode. ``"multiple"`` gives one panel per Q-bin,
            ``"single"`` one panel per file, and ``"single-combined"`` a
            single panel holding every curve.

    Returns:
        tuple[int, int]: ``(num_figs, num_lines)`` - the number of subplot
        panels and the number of curves drawn in each panel.

    Raises:
        ValueError: If *plot_type* is not one of the three layout modes.
    """
    # Reject unknown modes before touching g2, as the layout is undefined.
    if plot_type not in ("multiple", "single", "single-combined"):
        raise ValueError("plot_type not support.")

    n_files = len(g2)
    n_qbins = g2[0].shape[1]

    layouts = {
        "multiple": (n_qbins, n_files),
        "single": (n_files, n_qbins),
        "single-combined": (1, n_qbins * n_files),
    }
    return layouts[plot_type]
def pg_plot(
    hdl,
    xf_list,
    q_range,
    t_range,
    y_range,
    y_auto=False,
    q_auto=False,
    t_auto=False,
    num_col=4,
    rows=None,
    offset=0,
    show_fit=False,
    show_label=False,
    bounds=None,
    fit_flag=None,
    plot_type="multiple",
    subtract_baseline=True,
    marker_size=5,
    label_size=4,
    fit_func="single",
    robust_fitting=False,
    enable_diagnostics=False,
    **kwargs,
):
    """Plot G2 correlation data using PyQtGraph with optional curve fitting.

    Renders multi-panel G2 plots with configurable layout, fitting overlays,
    and baseline subtraction. Nothing is drawn (a warning is logged instead)
    when ``get_data`` reports that correlation data is unavailable.

    Args:
        hdl: PyQtGraph plot handler (GraphicsLayoutWidget).
        xf_list: List of XpcsFile objects containing G2 data.
        q_range: (min, max) Q-value range in inverse Angstroms, or None.
        t_range: (min, max) delay time range in seconds, or None.
        y_range: (min, max) y-axis range for G2 values, or None.
        y_auto: If True, auto-scale y-axis.
        q_auto: If True, ignore q_range and use all Q-bins.
        t_auto: If True, ignore t_range and use all delay times.
        num_col: Number of plot columns in the grid layout.
        rows: List of file indices used to pick colour/marker per dataset,
            or None to use the position in ``xf_list``.
        offset: Vertical offset between datasets for visibility.
        show_fit: If True, fit each file and overlay fitted curves.
        show_label: If True, display legend labels.
        bounds: Fitting bounds, forwarded to the fit routine.
        fit_flag: Flags selecting which parameters to fit, forwarded as is.
        plot_type: Layout mode: ``'multiple'`` (one panel per Q-bin),
            ``'single'`` (one panel per file), or ``'single-combined'``.
        subtract_baseline: If True, subtract fitted baseline from data.
        marker_size: Size of data point markers in pixels.
        label_size: Accepted for call compatibility; not used here.
        fit_func: Fitting model name, forwarded to the fit routine and used
            to choose parameter names for the annotation.
        robust_fitting: If True, call ``fit_g2_robust`` instead of ``fit_g2``.
        enable_diagnostics: Forwarded to ``fit_g2_robust``.
        **kwargs: ``force_refit`` (bool) is extracted and passed to the fit
            routine; the remaining items are forwarded to ``fit_g2_robust``.

    Example:
        >>> pg_plot(hdl, xf_list, q_range=(0.01, 0.1),
        ...         t_range=(1e-4, 10), y_range=(0.9, 1.5),
        ...         show_fit=True, fit_func='single')
    """
    if q_auto:
        q_range = None
    if t_auto:
        t_range = None
    if y_auto:
        y_range = None

    _q, tel, g2, g2_err, labels = get_data(xf_list, q_range=q_range, t_range=t_range)

    # get_data signals "no correlation analysis" with (False, None, ...);
    # an empty file list yields empty lists. Neither can be plotted.
    if g2 is None or len(g2) == 0:
        logger.warning("G2 correlation data not available; nothing to plot")
        return

    num_figs, _num_lines = compute_geometry(g2, plot_type)
    num_data, num_qval = len(g2), g2[0].shape[1]

    # col and rows for the 2d layout
    col = min(num_figs, num_col)
    row = (num_figs + col - 1) // col

    if rows is None or len(rows) == 0:
        rows = list(range(len(xf_list)))

    hdl.adjust_canvas_size(num_col=col, num_row=row)
    hdl.clear()

    # a bug in pyqtgraph; the log scale in x-axis doesn't apply, so the
    # x limits are handed over in log10 units. t0_range stays None when no
    # limits were requested so that setRange is skipped instead of failing.
    t0_range = None
    if t_range is not None and np.size(t_range) > 0:
        with np.errstate(divide="ignore", invalid="ignore"):
            # Only take log10 of positive values; clamp the rest to eps.
            t_range_positive = np.asarray(t_range)
            t_range_positive = np.where(
                t_range_positive > 0, t_range_positive, np.finfo(float).eps
            )
            t0_range = np.log10(t_range_positive)

    axes = []
    for n in range(num_figs):
        i_col = n % col
        i_row = n // col
        t = hdl.addPlot(row=i_row, col=i_col)
        axes.append(t)
        if show_label:
            t.addLegend(offset=(-1, 1), labelTextSize="9pt", verSpacing=-10)

        t.setMouseEnabled(x=False, y=y_auto)
        # Set axis labels once during setup (hoisted out of the m×n inner loop)
        t.setLabel("bottom", "tau (s)")
        t.setLabel("left", "g2")

    color_len = len(colors)
    symbol_len = len(symbols)

    for m in range(num_data):
        # default base line to be 1.0; used for non-fitting or fit error cases
        baseline_offset = np.ones(num_qval)
        fit_summary = None

        if show_fit:
            # force_refit is consumed here and must not be forwarded twice.
            force_refit = kwargs.get("force_refit", False)

            if robust_fitting:
                # Use robust fitting with diagnostics
                fit_summary = xf_list[m].fit_g2_robust(
                    q_range,
                    t_range,
                    bounds,
                    fit_flag,
                    fit_func,
                    enable_diagnostics=enable_diagnostics,
                    force_refit=force_refit,
                    **{k: v for k, v in kwargs.items() if k != "force_refit"},
                )
            else:
                # Use traditional fitting
                fit_summary = xf_list[m].fit_g2(
                    q_range,
                    t_range,
                    bounds,
                    fit_flag,
                    fit_func,
                    force_refit=force_refit,
                )

            if fit_summary is not None and subtract_baseline:
                # Check if fitting was successful by validating fit_val
                if (
                    fit_summary["fit_val"] is not None
                    and len(fit_summary["fit_val"]) > 0
                ):
                    # Extract baseline: model is a*exp(-2x/b) + c + d, so
                    # baseline = c + d (params 2 and 3)
                    try:
                        baseline_offset = (
                            fit_summary["fit_val"][:, 0, 2]
                            + fit_summary["fit_val"][:, 0, 3]
                        )
                    except (IndexError, TypeError):
                        # Fallback to default baseline if shape doesn't match
                        baseline_offset = np.ones(num_qval)

        color = colors[rows[m] % color_len]
        symbol = symbols[rows[m] % symbol_len]

        for n in range(num_qval):
            label = None
            if plot_type == "multiple":
                ax = axes[n]
                label = xf_list[m].label
                if m == 0:
                    ax.setTitle(labels[m][n])
            elif plot_type == "single":
                ax = axes[m]
                # overwrite color; use the same color for the same set;
                color = colors[n % color_len]
                if n == 0:
                    ax.setTitle(xf_list[m].label)
            elif plot_type == "single-combined":
                ax = axes[0]
                label = xf_list[m].label + labels[m][n]

            x = tel[m]
            # normalize baseline
            y = g2[m][:, n] - baseline_offset[n] + 1.0 + m * offset
            y_err = g2_err[m][:, n]

            pg_plot_one_g2(
                ax,
                x,
                y,
                y_err,
                color,
                label=label,
                symbol=symbol,
                symbol_size=marker_size,
            )

            # Apply explicit limits only when a limit actually exists;
            # setRange needs at least one concrete range argument.
            if not y_auto and y_range is not None:
                ax.setRange(yRange=y_range)
            if not t_auto and t0_range is not None:
                ax.setRange(xRange=t0_range)

            if show_fit and fit_summary is not None:
                # Check if we have valid fit_line data for this q-index
                if (
                    fit_summary["fit_line"] is not None
                    and n < fit_summary["fit_line"].shape[0]
                    and fit_summary.get("fit_x") is not None
                ):
                    # Get fitted y-values from the numpy array
                    y_fit = fit_summary["fit_line"][n] + m * offset
                    # normalize baseline
                    y_fit = y_fit - baseline_offset[n] + 1.0
                    # Use the correct x-values that were used for fitting
                    fit_x = fit_summary["fit_x"]

                    ax.plot(
                        fit_x,
                        y_fit,
                        pen=pg.mkPen(color, width=2.5),
                    )

                    # Add fitted parameter text annotation
                    if (
                        fit_summary.get("fit_val") is not None
                        and n < fit_summary["fit_val"].shape[0]
                    ):
                        _add_fit_param_annotation(
                            ax, fit_summary["fit_val"][n], fit_func, color
                        )
def _add_fit_param_annotation(ax, fit_val, fit_func, color):
    """Add fitted parameter text annotation to plot.

    One line is written per parameter as ``name = value ± error``. Every
    relaxation time (any name starting with ``τ``, including ``τ2``) is
    printed in scientific notation, because a fixed ``.3f`` format would
    show small time constants as ``0.000``; all other parameters use three
    decimals. A non-finite or non-positive error is shown as ``--``.

    Args:
        ax: PyQtGraph plot axis
        fit_val: Fitted values array [2, n_params] - row 0 holds the values,
            row 1 the error estimates
        fit_func: Fitting function type; ``'single'`` selects the
            single-exponential names, anything else the double-exponential
        color: Color for the text
    """
    # Format parameters based on fit function
    if fit_func == "single":
        param_names = ["τ", "bkg", "cts", "d"]
    else:  # double exponential
        param_names = ["τ1", "bkg", "cts1", "τ2", "cts2"]

    # Never index past the parameters that were actually fitted.
    n_params = min(len(param_names), fit_val.shape[1])

    param_text_lines = []
    for p_idx in range(n_params):
        name = param_names[p_idx]
        value = fit_val[0, p_idx]  # fitted value
        error = fit_val[1, p_idx]  # error estimate

        is_time_constant = name.startswith("τ")
        value_text = f"{value:.3e}" if is_time_constant else f"{value:.3f}"

        # Format error gracefully
        if np.isfinite(error) and error > 0:
            error_text = f"{error:.2e}" if is_time_constant else f"{error:.3f}"
        else:
            error_text = "--"

        param_text_lines.append(f"{name} = {value_text} ± {error_text}")

    param_text = "\n".join(param_text_lines)

    # Add text item to plot
    text_item = pg.TextItem(param_text, anchor=(0, 1), color=color)

    # Position text in data coordinates (top-left of visible area)
    viewbox = ax.getViewBox()
    if viewbox is not None:
        # Get current view range
        [[xmin, xmax], [ymin, ymax]] = viewbox.viewRange()
        # Position at 5% from left edge, 95% from bottom (top area)
        text_x = xmin + 0.05 * (xmax - xmin)
        text_y = ymin + 0.95 * (ymax - ymin)
        text_item.setPos(text_x, text_y)
    else:
        # Fallback position
        text_item.setPos(-5, 1.5)

    ax.addItem(text_item)
def pg_plot_from_data(
    hdl,
    *,
    q,
    tel,
    g2,
    g2_err,
    labels,
    num_figs,
    fit_results=None,
    y_auto=False,
    q_auto=False,
    t_auto=False,
    num_col=4,
    rows=None,
    offset=0,
    show_fit=False,
    show_label=False,
    y_range=None,
    t_range=None,
    plot_type="multiple",
    subtract_baseline=True,
    marker_size=5,
    fit_func="single",
    **_ignored_kwargs,
):
    """Render pre-fetched G2 data without re-fetching from XpcsFile objects.

    This is the rendering-only counterpart of ``pg_plot``. It accepts the
    data structures already extracted by the async worker (``q``, ``tel``,
    ``g2``, ``g2_err``, ``labels``) and renders them directly, avoiding the
    redundant ``get_data()`` and ``get_xf_list()`` calls that ``vk.plot_g2``
    would otherwise trigger on the main thread (BUG-014).

    Parameters mirror the ``pg_plot`` signature where applicable. Differences:

    Args:
        num_figs: Number of subplot panels to create. A value below 1 means
            there is nothing to lay out and the function returns at once.
        fit_results: Optional list of pre-computed fit summaries, one per
            dataset. Used only when ``show_fit`` is True.
        q_auto: Accepted for signature parity; not used here.
        **_ignored_kwargs: Extra ``pg_plot`` arguments, silently dropped.

    Returns:
        None. The function returns without drawing when ``q`` or ``g2`` is
        empty or ``num_figs`` is not positive.
    """
    if not q or g2 is None or len(g2) == 0:
        return
    # A non-positive panel count would divide by zero in the grid computation.
    if num_figs < 1:
        return

    num_data = len(g2)
    num_qval = g2[0].shape[1] if g2[0].ndim > 1 else 1

    col = min(num_figs, num_col)
    row = (num_figs + col - 1) // col

    if rows is None or len(rows) == 0:
        rows = list(range(num_data))

    hdl.adjust_canvas_size(num_col=col, num_row=row)
    hdl.clear()

    # x limits are applied in log10 units because the x axis is in log mode.
    t0_range = None
    if t_range and not t_auto:
        with np.errstate(divide="ignore", invalid="ignore"):
            t_range_arr = np.asarray(t_range)
            # Non-positive limits have no logarithm; clamp them to eps.
            t_range_arr = np.where(t_range_arr > 0, t_range_arr, np.finfo(float).eps)
            t0_range = np.log10(t_range_arr)

    axes = []
    for n in range(num_figs):
        i_col = n % col
        i_row = n // col
        t = hdl.addPlot(row=i_row, col=i_col)
        axes.append(t)
        if show_label:
            t.addLegend(offset=(-1, 1), labelTextSize="9pt", verSpacing=-10)
        t.setMouseEnabled(x=False, y=y_auto)
        # Set axis labels once during setup (hoisted out of the m×n inner loop)
        t.setLabel("bottom", "tau (s)")
        t.setLabel("left", "g2")

    color_len = len(colors)
    symbol_len = len(symbols)

    for m in range(num_data):
        baseline_offset = np.ones(num_qval)

        # Use pre-computed fit_results from worker if available
        fit_summary = None
        if show_fit and fit_results is not None and m < len(fit_results):
            fit_summary = fit_results[m]

        if fit_summary is not None and subtract_baseline:
            try:
                if (
                    fit_summary.get("fit_val") is not None
                    and len(fit_summary["fit_val"]) > 0
                ):
                    # Baseline is the sum of fit parameters 2 and 3.
                    baseline_offset = (
                        fit_summary["fit_val"][:, 0, 2]
                        + fit_summary["fit_val"][:, 0, 3]
                    )
            except (IndexError, TypeError, KeyError):
                baseline_offset = np.ones(num_qval)

        color = colors[rows[m] % color_len]
        symbol = symbols[rows[m] % symbol_len]

        for n in range(num_qval):
            label = None
            if plot_type == "multiple":
                ax = axes[n]
                if m == 0:
                    ax.setTitle(labels[m][n] if labels and m < len(labels) else "")
            elif plot_type == "single":
                ax = axes[m]
                color = colors[n % color_len]
                if n == 0:
                    ax.setTitle(labels[m][0] if labels and m < len(labels) else "")
            elif plot_type == "single-combined":
                ax = axes[0]
                label = labels[m][n] if labels and m < len(labels) else ""
            else:
                # Unknown layout: spread curves over whatever panels exist.
                ax = axes[n % len(axes)]

            x = tel[m]
            y = g2[m][:, n] - baseline_offset[n] + 1.0 + m * offset
            y_err = g2_err[m][:, n]

            pg_plot_one_g2(
                ax,
                x,
                y,
                y_err,
                color,
                label=label,
                symbol=symbol,
                symbol_size=marker_size,
            )

            if not y_auto and y_range is not None:
                ax.setRange(yRange=y_range)
            if not t_auto and t0_range is not None:
                ax.setRange(xRange=t0_range)

            if show_fit and fit_summary is not None:
                try:
                    if (
                        fit_summary.get("fit_line") is not None
                        and n < fit_summary["fit_line"].shape[0]
                        and fit_summary.get("fit_x") is not None
                    ):
                        y_fit = fit_summary["fit_line"][n] + m * offset
                        y_fit = y_fit - baseline_offset[n] + 1.0
                        fit_x = fit_summary["fit_x"]
                        ax.plot(fit_x, y_fit, pen=pg.mkPen(color, width=2.5))

                        if (
                            fit_summary.get("fit_val") is not None
                            and n < fit_summary["fit_val"].shape[0]
                        ):
                            _add_fit_param_annotation(
                                ax, fit_summary["fit_val"][n], fit_func, color
                            )
                except (IndexError, KeyError, TypeError) as exc:
                    # The overlay is best-effort: keep drawing the data, but
                    # leave a trace so malformed fit results can be diagnosed.
                    logger.debug(
                        f"pg_plot_from_data: skipped fit overlay for dataset {m}, "
                        f"q-index {n}: {exc!r}"
                    )
def pg_plot_one_g2(ax, x, y, dy, color, label, symbol, symbol_size=5):
    """Draw one G2 curve (markers plus error bars) on a log-x axis.

    Points with a non-finite ``x``, ``y`` or ``dy``, or with ``x <= 0``, are
    dropped first. Markers are thinned with a fixed stride once more than
    500 valid points remain; the error bars always use every valid point.
    The function returns without drawing when no valid point is left.

    Args:
        ax: PyQtGraph PlotItem to draw on.
        x: Delay times array (only positive values survive the filter).
        y: G2 correlation values.
        dy: G2 error bars, used as both the upper and lower bar length.
        color: RGB tuple for plot color, e.g. ``(255, 0, 0)``.
        label: Legend label string, or None.
        symbol: PyQtGraph symbol name (e.g. ``'o'``, ``'s'``).
        symbol_size: Marker size in pixels (default 5).
    """
    # PyQtGraph cannot consume JAX arrays, so convert at the I/O boundary.
    from xpcsviewer.backends._conversions import ensure_numpy

    x, y, dy = ensure_numpy(x), ensure_numpy(y), ensure_numpy(dy)

    if len(x) == 0 or len(y) == 0:
        return

    # Keep finite samples only; x must also be positive for the log axis.
    keep = np.isfinite(x) & np.isfinite(y) & np.isfinite(dy) & (x > 0)
    if not np.any(keep):
        return

    tau, g2_vals, g2_errs = x[keep], y[keep], dy[keep]

    error_pen = pg.mkPen(color=color, width=2)
    marker_pen = pg.mkPen(color=color, width=1)

    # The error-bar item is given log10(x) directly, whereas the marker
    # curve below relies on the axis log mode.
    try:
        error_bars = pg.ErrorBarItem(
            x=np.log10(tau),
            y=g2_vals,
            top=g2_errs,
            bottom=g2_errs,
            pen=error_pen,
        )
    except (ValueError, RuntimeWarning):
        return

    # Thin out the markers of dense curves to keep rendering responsive.
    n_points = len(tau)
    if n_points > 500:
        stride = n_points // 250
        marker_x, marker_y = tau[::stride], g2_vals[::stride]
    else:
        marker_x, marker_y = tau, g2_vals

    # Alpha 0 in the brush leaves the markers hollow.
    hollow_brush = pg.mkBrush(color=(*color, 0))
    ax.plot(
        marker_x,
        marker_y,
        pen=None,
        symbol=symbol,
        name=label,
        symbolSize=symbol_size,
        symbolPen=marker_pen,
        symbolBrush=hollow_brush,
    )

    ax.setLogMode(x=True, y=None)
    ax.addItem(error_bars)
    return
def vectorized_g2_baseline_correction(g2_data, baseline_values):
    """
    Shift every Q-column of a G2 array so that its baseline becomes 1.0.

    Args:
        g2_data: G2 data array [time, q_values]
        baseline_values: Baseline values [q_values]

    Returns:
        ``g2_data - baseline + 1.0`` with the baseline broadcast along the
        time axis.
    """
    # Convert first: the raw NumPy arithmetic below expects host arrays.
    data = ensure_numpy(g2_data)
    baselines = ensure_numpy(baseline_values)

    # A leading axis of length 1 lets one baseline row apply to all time points.
    baseline_row = baselines[np.newaxis, :]
    return data - baseline_row + 1.0
def batch_g2_normalization(g2_data_list, method="max"):
    """
    Batch normalization of multiple G2 datasets using vectorized operations.

    Each dataset is normalized along its first (time) axis:

    - ``'max'``: divide by the per-column maximum.
    - ``'mean'``: divide by the per-column mean.
    - ``'std'``: subtract the per-column mean and divide by the per-column
      standard deviation.

    A zero divisor is replaced by 1.0 so that the column passes through
    unscaled instead of producing inf/NaN. Any other ``method`` returns the
    (converted) data unchanged.

    Args:
        g2_data_list: List of G2 data arrays; shapes may differ between items
        method: Normalization method ('max', 'mean', 'std')

    Returns:
        List of normalized G2 data arrays, one per input, in input order.
        An empty input yields an empty list.
    """
    if not g2_data_list:
        return []

    # Ensure host-resident arrays; np.max/mean/std don't accept JAX arrays.
    arrays = [ensure_numpy(a) for a in g2_data_list]

    # Vectorized fast path for 'max' normalization with small-to-medium batch
    # sizes (B <= 100). Stacking requires identical shapes, so ragged batches
    # must take the per-item path instead of failing inside np.stack.
    first_shape = np.shape(arrays[0])
    uniform_shapes = all(np.shape(a) == first_shape for a in arrays)
    use_vectorized = method == "max" and len(arrays) <= 100 and uniform_shapes

    if use_vectorized:
        # One stack + one max + one divide instead of 2*B separate dispatches.
        g2_stack = np.stack(arrays, axis=0)  # [B, T, Q]
        scale = np.max(g2_stack, axis=1, keepdims=True)  # [B, 1, Q]
        scale = np.where(scale == 0, 1.0, scale)
        result = g2_stack / scale  # [B, T, Q]
        # Return list of views (no extra copy) for backward compatibility.
        return list(result)

    # Per-item path: handles 'mean', 'std', very large B, or non-uniform shapes.
    normalized_data = []
    for g2_data in arrays:
        if method == "max":
            max_vals = np.max(g2_data, axis=0, keepdims=True)
            max_vals = np.where(max_vals == 0, 1.0, max_vals)
            normalized = g2_data / max_vals
        elif method == "mean":
            mean_vals = np.mean(g2_data, axis=0, keepdims=True)
            mean_vals = np.where(mean_vals == 0, 1.0, mean_vals)
            normalized = g2_data / mean_vals
        elif method == "std":
            mean_vals = np.mean(g2_data, axis=0, keepdims=True)
            std_vals = np.std(g2_data, axis=0, keepdims=True)
            std_vals = np.where(std_vals == 0, 1.0, std_vals)
            normalized = (g2_data - mean_vals) / std_vals
        else:
            # Unknown method: pass the data through untouched.
            normalized = g2_data

        normalized_data.append(normalized)

    return normalized_data
def compute_g2_ensemble_statistics(g2_data_list, include_median: bool = False):
    """
    Compute ensemble statistics for multiple G2 datasets using vectorized operations.

    Args:
        g2_data_list: List of G2 data arrays [time, q_values]; all arrays
            must share one shape because they are stacked.
        include_median: If True, compute and include `ensemble_median` in the
            result. The median is noticeably more expensive than the other
            statistics, so it is off by default; enable it only when the
            caller genuinely needs it.

    Returns:
        Dictionary with ensemble statistics. Always contains:
            ensemble_mean, ensemble_std, ensemble_min, ensemble_max,
            ensemble_var (each [time, q_values], reduced over the batch),
            q_mean_values ([q_values], mean over batch and time),
            temporal_correlation ([q_values, batch, batch]; for every
            Q-value the Pearson correlation between the time series of each
            pair of datasets, so the diagonal is 1 for non-constant series.
            A constant series gets correlation 0 rather than NaN).
        Optionally contains:
            ensemble_median (when include_median=True).
    """
    # Ensure host-resident NumPy arrays before stacking / raw np.* statistics.
    g2_data_list = [ensure_numpy(arr) for arr in g2_data_list]

    # Stack all data for vectorized operations
    g2_stack = np.stack(g2_data_list, axis=0)  # [batch, time, q_values]

    # np.median is intentionally excluded from the default path because it
    # dominates this function's runtime.
    stats = {
        "ensemble_mean": np.mean(g2_stack, axis=0),
        "ensemble_std": np.std(g2_stack, axis=0),
        "ensemble_min": np.min(g2_stack, axis=0),
        "ensemble_max": np.max(g2_stack, axis=0),
        "ensemble_var": np.var(g2_stack, axis=0),
        "q_mean_values": np.mean(g2_stack, axis=(0, 1)),  # Mean across time and batch
    }
    if include_median:
        stats["ensemble_median"] = np.median(g2_stack, axis=0)

    # Batched temporal correlation: transpose to (q, batch, time) and compute
    # the correlation matrices for all q-values without a Python loop.
    q_transposed = np.transpose(g2_stack, (2, 0, 1))  # [q, batch, time]

    # Center every time series on its own mean.
    q_mean = np.mean(q_transposed, axis=2, keepdims=True)  # [q, batch, 1]
    q_centered = q_transposed - q_mean  # [q, batch, time]

    # L2 norm of each centered series; zero norms (constant series) are
    # replaced by 1 to avoid division by zero.
    q_norm = np.sqrt(np.sum(q_centered**2, axis=2, keepdims=True))  # [q, batch, 1]
    q_norm = np.where(q_norm == 0, 1.0, q_norm)
    q_normed = q_centered / q_norm  # [q, batch, time]

    # Dot products of unit-norm centered series ARE the Pearson coefficients,
    # so no further division by the number of time points is applied.
    # (q, batch, time) @ (q, time, batch) -> (q, batch, batch)
    corr_batch = np.matmul(q_normed, np.transpose(q_normed, (0, 2, 1)))

    # Return as 3-D ndarray for efficiency; callers that need per-q slices
    # can index directly (corr_batch[q]) without rebuilding a Python list.
    stats["temporal_correlation"] = corr_batch  # shape [num_q, batch, batch]

    return stats
def optimize_g2_error_propagation(g2_data, g2_errors, operations):
    """
    Propagate G2 uncertainties through a sequence of operations.

    The rules applied per operation ``type`` are:

    - ``"scale"``: multiply the errors by ``|factor|``.
    - ``"offset"``: leave the errors unchanged.
    - ``"power"``: multiply by ``|power * x**(power - 1)|``.
    - ``"log"``: divide by ``|x|``.

    ``x`` is always the ``g2_data`` passed in; it is not updated between
    operations. Unrecognised types are skipped.

    Args:
        g2_data: G2 data array [time, q_values]
        g2_errors: G2 error array [time, q_values]; it is copied, never
            modified in place
        operations: List of dicts, each with a ``"type"`` key plus
            ``"factor"`` (scale) or ``"power"`` (power) where required

    Returns:
        Propagated errors
    """
    # Host-resident NumPy arrays are required by the raw np.* calls below.
    data = ensure_numpy(g2_data)
    sigma = ensure_numpy(g2_errors).copy()

    for step in operations:
        kind = step["type"]

        if kind == "scale":
            # sigma_new = |scale| * sigma_old
            sigma *= np.abs(step["factor"])
        elif kind == "power":
            # sigma_new = |n * x^(n-1)| * sigma_old
            exponent = step["power"]
            sigma = np.abs(exponent * np.power(data, exponent - 1)) * sigma
        elif kind == "log":
            # sigma_new = sigma_old / |x|
            sigma = sigma / np.abs(data)
        # "offset" (additive shifts keep the uncertainty) and any other
        # type fall through without changing sigma.

    return sigma
def vectorized_g2_interpolation(tel, g2_data, target_tel):
    """
    Interpolate G2 data onto new time points, one Q-column at a time.

    When ``interpax`` and ``jax`` can be imported, the cubic interpolation is
    mapped over the Q-columns with ``jax.vmap``. If an ImportError occurs,
    the project's ``Interp1d`` wrapper is applied to each column in a Python
    loop instead. Both paths extrapolate outside the original time range and
    accept any array-like input.

    Args:
        tel: Original time points
        g2_data: G2 data [time, q_values]
        target_tel: Target time points for interpolation

    Returns:
        Interpolated G2 data as a NumPy array [len(target_tel), q_values]
    """
    try:
        import interpax
        import jax
        import jax.numpy as jnp

        # Convert to JAX arrays once
        tel_jax = jnp.asarray(tel)
        target_jax = jnp.asarray(target_tel)
        g2_jax = jnp.asarray(g2_data)  # [time, q_values]

        # Define single-column interpolation function
        def _interp_single_q(y_col):
            return interpax.interp1d(
                target_jax, tel_jax, y_col, method="cubic", extrap=True
            )

        # vmap over columns (q-values): in_axes=1, out_axes=1
        interpolated_jax = jax.vmap(_interp_single_q, in_axes=1, out_axes=1)(g2_jax)

        return ensure_numpy(interpolated_jax)

    except ImportError:
        # NumPy fallback: use project's Interp1d wrapper (avoids direct scipy import)
        from xpcsviewer.backends.scipy_replacements.interpolate import Interp1d

        # Coerce once so that lists and other array-likes work here exactly
        # as they do on the JAX path (which goes through jnp.asarray).
        tel_np = np.asarray(tel)
        target_np = np.asarray(target_tel)
        g2_np = np.asarray(g2_data)

        # Interpolate each q-column using the wrapper
        num_q = g2_np.shape[1]
        result = np.empty((len(target_np), num_q))
        for q_idx in range(num_q):
            interp_func = Interp1d(
                tel_np,
                g2_np[:, q_idx],
                kind="cubic",
                bounds_error=False,
                fill_value="extrapolate",
            )
            result[:, q_idx] = interp_func(target_np)

        return result