"""SAXS 1D analysis module.
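Provides plotting, q-binning, background subtraction and normalization of
one-dimensional intensity profiles I(q) from small-angle X-ray scattering data.
Functions:
pg_plot: Plot SAXS 1D intensity profiles with PyQtGraph
offset_intensity: Vertically offset curves for visual separation
vectorized_q_binning: Average intensities into regular q-bins
vectorized_background_subtraction: Subtract a weighted background with error propagation
vectorized_intensity_normalization: Normalize curves (q^2, q^4, max, area)
batch_saxs_analysis: Apply a sequence of operations to several datasets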
"""
# Standard library imports
import logging
# Third-party imports
import numpy as np
import pyqtgraph as pg
from xpcsviewer.backends._conversions import ensure_numpy
from xpcsviewer.utils.logging_config import get_logger
from ..plothandler.plot_constants import get_color_marker
# Module-level logger obtained from xpcsviewer's logging configuration helper.
logger = get_logger(__name__)
def _log_array_dims(prefix: str, arr) -> None:
    """Emit a DEBUG record describing the shape and dtype of *arr*.

    Does nothing unless DEBUG is enabled on the module logger. ``None`` is
    reported explicitly; objects without a ``shape`` attribute are skipped.

    Args:
        prefix: Text placed before the description in the log record.
        arr: Array-like object (anything with ``shape``/``dtype``) or None.
    """
    if not logger.isEnabledFor(logging.DEBUG):
        return
    if arr is None:
        logger.debug(f"{prefix}: None")
        return
    if hasattr(arr, "shape"):
        logger.debug(f"{prefix}: shape={arr.shape}, dtype={arr.dtype}")
# Import-time side effect: set pyqtgraph's global "background" config option
# to "w" (white).
pg.setConfigOption("background", "w")
# Matplotlib legend ``loc`` integer codes (1-10) mapped to their string
# names, following the Matplotlib ``legend`` documentation. Code 0 ("best")
# has no entry in this table.
_MPL_LOC_INT_TO_STR = dict(
    enumerate(
        (
            "upper right",
            "upper left",
            "lower left",
            "lower right",
            "right",  # Matplotlib places this like "center right"
            "center left",
            "center right",
            "lower center",
            "upper center",
            "center",
        ),
        start=1,
    )
)
[docs]
# Normalized Matplotlib location name (lowercase, no spaces/underscores) ->
# (itemPos, parentPos, offset multipliers). Positions are fractional
# coordinates of the legend (itemPos) and of its parent (parentPos), with
# (0, 0) the upper-left and (1, 1) the lower-right corner. The multipliers
# give the direction in which ``padding`` pushes the legend.
_LOC_ANCHOR_MAP = {
    # Corners
    "upperleft": ((0.0, 0.0), (0.0, 0.0), (1, 1)),  # Offset moves down-right
    "upperright": ((1.0, 0.0), (1.0, 0.0), (-1, 1)),  # Offset moves down-left
    "lowerleft": ((0.0, 1.0), (0.0, 1.0), (1, -1)),  # Offset moves up-right
    "lowerright": ((1.0, 1.0), (1.0, 1.0), (-1, -1)),  # Offset moves up-left
    # Centers
    "center": ((0.5, 0.5), (0.5, 0.5), (0, 0)),  # No offset
    "lowercenter": ((0.5, 1.0), (0.5, 1.0), (0, -1)),  # Offset moves up
    "uppercenter": ((0.5, 0.0), (0.5, 0.0), (0, 1)),  # Offset moves down
    # Sides (center aligned on the edge)
    "centerleft": ((0.0, 0.5), (0.0, 0.5), (1, 0)),  # Offset moves right
    "centerright": ((1.0, 0.5), (1.0, 0.5), (-1, 0)),  # Offset moves left
    "right": ((1.0, 0.5), (1.0, 0.5), (-1, 0)),  # Same as 'centerright'
}


def get_pyqtgraph_anchor_params(loc, padding=10):
    """
    Convert a Matplotlib legend ``loc`` into pyqtgraph anchor parameters.

    Calculates the 'itemPos', 'parentPos', and 'offset' needed to position
    a pyqtgraph LegendItem similarly to how Matplotlib places legends
    using the 'loc' parameter.

    Args:
        loc (str or int): Matplotlib location. Accepts the standard strings
            ('upper left', 'center', ...; case-insensitive, spaces and
            underscores ignored) or the integer codes 0-10.
        padding (int): Pixel padding to use for the offset from the anchor
            point. Positive values push the legend inwards from the
            edge/corner. Defaults to 10.

    Returns:
        dict or None: A dictionary with keys 'itemPos', 'parentPos' and
        'offset' suitable for unpacking into ``LegendItem.anchor(**params)``,
        or None for 'best' (code 0), which has no deterministic pyqtgraph
        anchor.

    Raises:
        ValueError: If ``loc`` is not a str/int, or is an unknown location
            name or code.

    Example Usage::

        plot_item = pg.PlotItem()
        legend = plot_item.addLegend()
        # ... plot data ...
        try:
            anchor_params = get_pyqtgraph_anchor_params('lower left', padding=15)
            if anchor_params:
                legend.anchor(**anchor_params)
            else:
                logger.info("Using default legend position for best location")
        except ValueError as e:
            logger.warning(f"Error setting legend position: {e}")
    """
    if isinstance(loc, int):
        if loc == 0:  # Matplotlib's 'best'
            return None
        try:
            loc_name = _MPL_LOC_INT_TO_STR[loc]
        except KeyError:
            raise ValueError(
                f"Invalid Matplotlib integer location code: {loc}"
            ) from None
    elif isinstance(loc, str):
        loc_name = loc
    else:
        raise ValueError(f"Invalid loc type: {type(loc)}. Must be str or int.")

    # Normalize both user strings and names looked up from integer codes
    # ("Upper Right", "upper_right", "upper right" -> "upperright"); the
    # integer table stores spaced names, so this step is required for codes.
    key = loc_name.lower().replace(" ", "").replace("_", "")
    if key == "best":
        return None
    try:
        item_pos, parent_pos, (mult_x, mult_y) = _LOC_ANCHOR_MAP[key]
    except KeyError:
        raise ValueError(
            f"Invalid or unsupported Matplotlib location string: '{loc}'"
        ) from None
    return {
        "itemPos": item_pos,
        "parentPos": parent_pos,
        "offset": (padding * mult_x, padding * mult_y),
    }
[docs]
def offset_intensity(Iq, n, plot_offset=None, yscale=None):
    """Apply a vertical offset to intensity curves for visual separation.

    For a linear y-scale, subtracts ``plot_offset * n`` times each curve's
    peak intensity (NaN points are ignored when finding the peak). For a log
    y-scale, divides by ``10 ** (plot_offset * n)``.

    Args:
        Iq: Intensity array. Can be 1-D ``(n_q,)`` or 2-D ``(n_phi, n_q)``;
            for 2-D input each row uses its own peak.
        n: File index (0-based) used to scale the offset.
        plot_offset: Offset magnitude. ``None`` or 0 means no offset.
        yscale: ``"linear"`` or ``"log"`` selects the offset strategy; any
            other value leaves *Iq* unchanged.

    Returns:
        numpy.ndarray: Offset-adjusted intensity array (same shape as *Iq*).
        *Iq* itself is returned unchanged when no offset applies.
    """
    if plot_offset is None or plot_offset == 0:
        return Iq
    if yscale == "linear":
        import warnings

        Iq = np.asarray(Iq)
        # nanmax so that isolated NaN points do not turn the offset -- and
        # with it every point of the curve -- into NaN (np.max would, even
        # for n == 0 because 0 * NaN is NaN). An all-NaN curve still yields
        # NaN; silence the "All-NaN slice" RuntimeWarning for that case.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            peak = (
                np.nanmax(Iq, axis=-1, keepdims=True)
                if Iq.ndim > 1
                else np.nanmax(Iq)
            )
        return Iq - plot_offset * n * peak
    if yscale == "log":
        return np.asarray(Iq) / np.power(10.0, plot_offset * n)
    return Iq  # Unknown scale: no offset applied
[docs]
def switch_line_builder(hdl, lb_type=None):
    """Delegate to ``hdl.link_line_builder`` with *lb_type*.

    Args:
        hdl: Object exposing a ``link_line_builder(lb_type)`` method.
        lb_type: Value forwarded unchanged as the only positional argument.

    Returns:
        None; the delegate's return value is discarded.
    """
    link = hdl.link_line_builder
    link(lb_type)
[docs]
def plot_line_with_marker(
    plot_item, x, y, index, label, alpha_val, marker_size=6, log_x=False, log_y=False
):
    """Plot one curve as a line plus (for small datasets) unfilled markers.

    Points with non-finite x or y, or with ``x <= 0``, are always dropped;
    points with ``y <= 0`` are additionally dropped when *log_y* is True.
    Curves with more than 2000 remaining points are thinned to about 1000
    samples evenly spaced by index. Markers are drawn only when at most
    1000 points are plotted.

    Args:
        plot_item: pyqtgraph ``PlotItem`` to draw into.
        x: Abscissa values (array-like, same length as *y*).
        y: Ordinate values (array-like).
        index: Curve index; selects the color and marker symbol via
            ``get_color_marker``.
        label: Legend name given to the line.
        alpha_val: Opacity in [0, 1] applied to line and markers.
        marker_size: Marker size passed to ``pg.ScatterPlotItem``.
        log_x: Whether the plot's x-axis is in log mode; marker x
            coordinates are then passed as ``log10(x)``.
        log_y: Same as *log_x* for the y-axis.

    Returns:
        None. Nothing is drawn when *x* and *y* differ in length (a warning
        is logged) or when no valid point remains.
    """
    color_hex, marker = get_color_marker(index, backend="pyqtgraph")
    rgba = (*pg.mkColor(color_hex).getRgb()[:3], int(alpha_val * 255))

    x = np.atleast_1d(np.asarray(x))
    y = np.atleast_1d(np.asarray(y))
    if len(x) != len(y):
        logger.warning(
            f"Dimension mismatch: x has {len(x)} elements, y has {len(y)} elements"
        )
        return

    # x > 0 is required even on a linear x-axis (q <= 0 is never plotted);
    # y > 0 only matters when the y-axis is logarithmic.
    valid_mask = np.isfinite(x) & np.isfinite(y) & (x > 0)
    if log_y:
        valid_mask &= y > 0
    if not np.any(valid_mask):
        return
    x_clean, y_clean = x[valid_mask], y[valid_mask]

    # Thin very long curves to ~1000 samples evenly spaced by index
    # (the same sampling is used for linear and log x-axes).
    n_points = len(x_clean)
    if n_points > 2000:
        indices = np.unique(np.linspace(0, n_points - 1, 1000).astype(int))
        x_plot, y_plot = x_clean[indices], y_clean[indices]
    else:
        x_plot, y_plot = x_clean, y_clean

    # The line gets raw data; ensure NumPy at the PyQtGraph boundary.
    pen_line = pg.mkPen(color=rgba, width=1.5)
    plot_item.plot(ensure_numpy(x_plot), ensure_numpy(y_plot), pen=pen_line, name=label)

    if len(x_plot) <= 1000:  # Only show markers for manageable datasets
        # Markers are added as a separate item with coordinates transformed
        # here -- presumably because the ScatterPlotItem is not switched to
        # log mode by the PlotItem (verify against the pyqtgraph version).
        x_scatter = np.log10(x_plot) if log_x else x_plot
        y_scatter = np.log10(y_plot) if log_y else y_plot
        scatter = pg.ScatterPlotItem(
            x=ensure_numpy(x_scatter),
            y=ensure_numpy(y_scatter),
            symbol=marker,
            size=marker_size,
            pen=pg.mkPen(color=rgba, width=1),
            brush=None,  # Unfilled markers
        )
        plot_item.addItem(scatter)
[docs]
def pg_plot(
    xf_list,
    pg_hdl,
    plot_type=2,
    plot_norm=0,
    plot_offset=0,
    title=None,
    rows=None,
    qmax=10.0,
    qmin=0,
    loc="best",
    marker_size=3,
    sampling=1,
    all_phi=False,
    absolute_crosssection=False,
    subtract_background=False,
    bkg_file=None,
    weight=1.0,
    roi_list=None,
    show_roi=True,
    show_phi_roi=True,
):
    """Plot one-dimensional SAXS intensity profiles using PyQtGraph.

    Renders I(q) curves for one or more XPCS files with support for
    log/linear axis scaling, intensity normalization, background
    subtraction, and multi-phi overlays. The plot widget is cleared first;
    files whose data extraction fails are logged and skipped.

    Args:
        xf_list: List of XpcsFile objects; each must provide
            ``get_saxs1d_data(...)``, ``saxs_1d["labels"]`` and ``label``.
        pg_hdl: PyQtGraph PlotWidget handle for rendering.
        plot_type: Axis scale encoding (x-y order): 0 = lin-lin,
            1 = log-lin, 2 = lin-log, 3 = log-log. Bit 0 selects a log
            x-axis, bit 1 a log y-axis. Default is 2 (log y).
        plot_norm: Normalization mode index. 0 = none, 1 = q^2,
            2 = q^4, 3 = I(0). Any other value disables normalization.
        plot_offset: Per-file vertical offset passed to
            ``offset_intensity``: a fraction of each curve's peak on a
            linear y-axis, or a power-of-ten step on a log y-axis.
            0, negative values, or None mean no offset.
        title: Plot title string, or None for no title.
        rows: File indices to highlight at full opacity; the remaining
            files are drawn at alpha 0.35. None/empty highlights all files.
        qmax: Upper bound of the q-range to display (1/A).
        qmin: Lower bound of the q-range to display (1/A).
        loc: Matplotlib-style legend location (e.g. ``"upper right"`` or
            an integer code). Values that cannot be mapped to a
            deterministic anchor, including ``"best"``, keep pyqtgraph's
            default legend position.
        marker_size: Scatter marker diameter in pixels.
        sampling: Down-sampling factor for the q-points. 1 = no
            down-sampling.
        all_phi: If True and the profile is 2-D ``(n_phi, n_q)``, plot
            every phi-sector curve; otherwise plot only the first curve.
        absolute_crosssection: If True, convert intensity to
            absolute cross-section units (1/cm).
        subtract_background: Enable background subtraction using
            *bkg_file*.
        bkg_file: XpcsFile used as the background reference when
            *subtract_background* is True.
        weight: Multiplicative weight applied to the background
            before subtraction.
        roi_list: Accepted for interface compatibility; currently unused.
        show_roi: Accepted for interface compatibility; currently unused.
        show_phi_roi: Accepted for interface compatibility; currently unused.

    Example:
        >>> pg_plot(
        ...     xf_list=[xf],
        ...     pg_hdl=widget,
        ...     plot_type=3,
        ...     qmin=0.001,
        ...     qmax=1.0,
        ... )
    """
    logger.debug(f"pg_plot: entry with {len(xf_list)} files, q_range=[{qmin}, {qmax}]")
    pg_hdl.clear()
    plot_item = pg_hdl.getPlotItem()
    plot_item.setTitle(title)
    legend = plot_item.addLegend()
    # Best effort: an unmappable ``loc`` keeps the default legend position.
    try:
        anchor_param = get_pyqtgraph_anchor_params(loc, padding=15)
        if anchor_param:
            legend.anchor(**anchor_param)
    except (ValueError, TypeError) as exc:
        logger.debug(f"pg_plot: keeping default legend position for loc={loc!r}: {exc}")

    # Per-file opacity: highlighted rows opaque, the rest faded. np.size
    # avoids the ambiguous truth value of a multi-element ndarray ``rows``.
    num_files = len(xf_list)
    if rows is not None and np.size(rows) > 0:
        alpha = np.full(num_files, 0.35, dtype=np.float32)
        alpha[np.asarray(rows, dtype=np.intp)] = 1.0
    else:
        alpha = np.ones(num_files, dtype=np.float32)

    # Background subtraction only applies when explicitly enabled.
    if not subtract_background:
        bkg_file = None

    norm_methods = (None, "q2", "q4", "I0")
    norm_method = (
        norm_methods[plot_norm] if 0 <= plot_norm < len(norm_methods) else None
    )
    log_x = bool(plot_type % 2)
    log_y = bool(plot_type // 2)
    plot_item.setLogMode(x=log_x, y=log_y)
    yscale = "log" if log_y else "linear"

    # Extract (and optionally offset) every file's profile before drawing.
    xlabel, ylabel = None, None
    plot_data = []
    for n, fi in enumerate(xf_list):
        try:
            q, Iq, xlabel, ylabel = fi.get_saxs1d_data(
                bkg_xf=bkg_file,
                bkg_weight=weight,
                qrange=(qmin, qmax),
                sampling=sampling,
                norm_method=norm_method,
                use_absolute_crosssection=absolute_crosssection,
            )
            if plot_offset is not None and plot_offset > 0:
                Iq = offset_intensity(Iq, n, plot_offset, yscale=yscale)
            plot_data.append((q, Iq, fi, n))
        except Exception as e:
            # Deliberate best effort: one bad file must not abort the plot.
            logger.warning(f"Failed to extract data for file {fi.label}: {e}")

    plot_id = 0  # Running curve index; selects color/marker per curve
    for q, Iq, fi, n in plot_data:
        is_multi_curve = Iq.ndim > 1
        # A 1-D profile is a single curve even when all_phi is requested;
        # iterating its first axis would walk q-points, not phi sectors.
        num_lines = Iq.shape[0] if (all_phi and is_multi_curve) else 1
        for m in range(num_lines):
            curve = Iq[m] if is_multi_curve else Iq
            if not np.any(np.isfinite(curve)):
                continue  # Nothing plottable in this curve
            plot_line_with_marker(
                plot_item,
                q,
                curve,
                plot_id,
                fi.saxs_1d["labels"][m],
                alpha[n],
                marker_size=marker_size,
                log_x=log_x,
                log_y=log_y,
            )
            plot_id += 1

    # Without normalization the y-label reflects the intensity units.
    if plot_norm == 0:
        ylabel = (
            "Intensity (1/cm)"
            if absolute_crosssection
            else "Intensity (photon/pixel/frame)"
        )
    if xlabel:
        plot_item.setLabel("bottom", xlabel)
    if ylabel:
        plot_item.setLabel("left", ylabel)
    plot_item.showGrid(x=True, y=True, alpha=0.3)
[docs]
def vectorized_q_binning(q_values, intensities, q_min, q_max, num_bins):
    """
    Average intensities into ``num_bins`` equal-width q-bins on [q_min, q_max).

    Each sample is assigned to a bin with ``np.digitize``; samples below
    ``q_min`` or at/above ``q_max`` are discarded. Per-bin sums are
    accumulated with ``np.add.at`` (no per-bin Python loop) and divided by
    the bin population; empty bins stay 0.

    Args:
        q_values: Q-values array, aligned with the last axis of *intensities*.
        intensities: Intensity array [q_points] or [phi_slices, q_points].
        q_min, q_max: Q-range for binning.
        num_bins: Number of bins.

    Returns:
        Tuple of (bin_centers, binned_intensity, bin_counts); the counts
        are float64.
    """
    edges = np.linspace(q_min, q_max, num_bins + 1)
    centers = 0.5 * (edges[:-1] + edges[1:])

    # For increasing edges, digitize returns i with edges[i-1] <= q < edges[i];
    # shift to 0-based bin numbers and keep only those inside the grid.
    which_bin = np.digitize(q_values, edges) - 1
    in_range = (which_bin >= 0) & (which_bin < num_bins)
    kept_bins = which_bin[in_range]

    counts = np.bincount(kept_bins, minlength=num_bins).astype(np.float64)[:num_bins]
    occupied = counts > 0

    if intensities.ndim == 1:
        totals = np.zeros(num_bins, dtype=np.float64)
        np.add.at(totals, kept_bins, intensities[in_range])
        totals[occupied] /= counts[occupied]
        return centers, totals, counts

    # One row per phi slice; accumulate every row in a single call.
    n_rows = intensities.shape[0]
    totals = np.zeros((n_rows, num_bins), dtype=np.float64)
    np.add.at(totals, (slice(None), kept_bins), intensities[..., in_range])
    totals[:, occupied] /= counts[occupied]
    return centers, totals, counts
[docs]
def vectorized_background_subtraction(foreground_data, background_data, weight=1.0):
    """
    Subtract a weighted background from a foreground profile, propagating errors.

    When the q-grids differ, the background intensity and error are
    linearly interpolated onto the foreground q-grid with
    ``bounds_error=False, fill_value=0``, i.e. foreground points outside the
    background's q-range get a zero background. A missing error array
    (``None``) is treated as zero.

    Args:
        foreground_data: Tuple of (q, I_fg, I_err_fg); I_err_fg may be None.
        background_data: Tuple of (q_bg, I_bg, I_err_bg); I_err_bg may be None.
        weight: Background scaling weight.

    Returns:
        Tuple of (q, I_fg - weight * I_bg, I_err_propagated), where
        ``I_err_propagated = sqrt(I_err_fg**2 + (weight * I_err_bg)**2)``.
    """
    q, I_fg, I_err_fg = foreground_data
    q_bg, I_bg, I_err_bg = background_data

    # Bring the background onto the foreground q-grid if needed.
    if not np.array_equal(q, q_bg):
        from xpcsviewer.backends.scipy_replacements import interp1d

        interp_func = interp1d(
            q_bg, I_bg, kind="linear", bounds_error=False, fill_value=0
        )
        I_bg_interp = interp_func(q)
        if I_err_bg is not None:
            interp_err_func = interp1d(
                q_bg, I_err_bg, kind="linear", bounds_error=False, fill_value=0
            )
            I_err_bg_interp = interp_err_func(q)
        else:
            I_err_bg_interp = np.zeros_like(q)
    else:
        I_bg_interp = I_bg
        I_err_bg_interp = I_err_bg if I_err_bg is not None else np.zeros_like(q)

    I_subtracted = I_fg - weight * I_bg_interp

    # Quadrature sum of independent errors. Without a foreground error this
    # reduces to |weight * err_bg|; abs keeps it non-negative (as the sqrt
    # branch is) even for a negative weight.
    if I_err_fg is not None:
        I_err_propagated = np.sqrt(I_err_fg**2 + (weight * I_err_bg_interp) ** 2)
    else:
        I_err_propagated = np.abs(weight * I_err_bg_interp)
    return q, I_subtracted, I_err_propagated
[docs]
def vectorized_intensity_normalization(
    q_values, intensities, method="none", q_ref=None
):
    """
    Normalize one or more intensity curves.

    Curves run along the last axis of *intensities*, so 1-D ``(n_q,)``,
    2-D ``(n_phi, n_q)`` and higher-rank ``(..., n_q)`` inputs are handled
    uniformly.

    Args:
        q_values: 1-D q array whose length matches the last axis of
            *intensities*.
        intensities: Intensity array ``(n_q,)`` or ``(..., n_q)``.
        method: 'none' (unchanged), 'q2' (I * q**2), 'q4' (I * q**4),
            'max' (each curve divided by its maximum) or 'area' (each curve
            divided by its trapezoidal integral over q). Unrecognized
            methods return *intensities* unchanged.
        q_ref: Currently unused; accepted for API compatibility.

    Returns:
        Normalized intensity array with the same shape as *intensities*
        (the input object itself for 'none' or unknown methods).
    """
    if method == "none":
        return intensities
    if method == "q2":
        return intensities * q_values**2
    if method == "q4":
        return intensities * q_values**4
    if method == "max":
        # keepdims keeps one maximum per curve, broadcast back over q.
        return intensities / np.max(intensities, axis=-1, keepdims=True)
    if method == "area":
        # np.trapezoid is NumPy >= 2.0; older releases only have np.trapz.
        trapezoid = getattr(np, "trapezoid", None) or np.trapz
        areas = trapezoid(intensities, q_values, axis=-1)
        return intensities / np.expand_dims(areas, -1)
    return intensities
[docs]
def batch_saxs_analysis(data_list, operations):
    """
    Apply the same sequence of operations to several SAXS datasets.

    Args:
        data_list: Iterable of (q, I) tuples. ``q`` is a 1-D array; ``I`` is
            ``(n_q,)`` or ``(..., n_q)`` with q along the last axis.
        operations: Sequence of operation dicts applied in order:

            - ``{"type": "normalize", "method": ...}``: see
              ``vectorized_intensity_normalization``.
            - ``{"type": "smooth", "sigma": 1.0}``: Gaussian smoothing via
              xpcsviewer's ``gaussian_filter1d`` backend (sigma defaults
              to 1.0).
            - ``{"type": "trim", "q_range": (q_min, q_max)}``: keep points
              with ``q_min <= q <= q_max`` (inclusive) in both q and I.

            Operations with any other ``type`` are ignored.

    Returns:
        List of processed (q, I) tuples, one per input dataset (empty for
        empty input). Each input ``I`` is copied before processing.
    """
    processed_data = []
    for q, intensity in data_list:
        processed_q = q
        processed_I = intensity.copy()
        for op in operations:
            op_type = op["type"]
            if op_type == "normalize":
                processed_I = vectorized_intensity_normalization(
                    processed_q, processed_I, method=op["method"]
                )
            elif op_type == "smooth":
                from xpcsviewer.backends.scipy_replacements import gaussian_filter1d

                processed_I = gaussian_filter1d(processed_I, sigma=op.get("sigma", 1.0))
            elif op_type == "trim":
                q_min, q_max = op["q_range"]
                mask = (processed_q >= q_min) & (processed_q <= q_max)
                processed_q = processed_q[mask]
                # Mask the q axis (last axis), not the leading phi axis.
                processed_I = processed_I[..., mask]
        processed_data.append((processed_q, processed_I))
    return processed_data