Source code for xpcsviewer.utils.vectorized_roi

"""
Vectorized ROI (Region of Interest) Calculations for XPCS Viewer

This module provides highly optimized vectorized ROI calculations with advanced
memory management and parallel processing capabilities for XPCS data analysis.
Uses JAX backend for GPU acceleration when available.
"""

import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from enum import Enum
from typing import Any

import numpy as np

from xpcsviewer.backends import get_backend
from xpcsviewer.backends._conversions import ensure_numpy

from .logging_config import get_logger
from .memory_manager import MemoryPressure, get_memory_manager

logger = get_logger(__name__)


[docs] class ROIType(Enum): """Types of ROI calculations.""" PIE = "pie" RING = "ring" PHI_RING = "phi_ring" RECTANGLE = "rectangle" POLYGON = "polygon"
[docs] @dataclass class ROIParameters: """Parameters for ROI calculation.""" roi_type: ROIType parameters: dict[str, Any] label: str = ""
[docs] @dataclass class ROIResult: """Result of ROI calculation.""" x_values: np.ndarray roi_data: np.ndarray roi_type: ROIType parameters: dict[str, Any] metadata: dict[str, Any] processing_time: float memory_used_mb: float
[docs] class VectorizedROICalculator(ABC): """Abstract base class for vectorized ROI calculators. Uses JAX backend for GPU acceleration when available, with automatic vmap for batch processing. """
[docs] def __init__(self, chunk_size_mb: float = 100.0): self.chunk_size_mb = chunk_size_mb self.memory_manager = get_memory_manager() self._backend = get_backend() self._vmap_enabled = self._backend.name == "jax"
[docs] @abstractmethod def calculate_roi_mask( self, geometry_data: dict[str, np.ndarray], roi_params: ROIParameters ) -> np.ndarray: """Calculate the ROI mask for the given parameters."""
[docs] @abstractmethod def process_roi_data( self, saxs_data: np.ndarray, roi_mask: np.ndarray, roi_params: ROIParameters ) -> ROIResult: """Process SAXS data with the ROI mask."""
[docs] def calculate_roi( self, saxs_data: np.ndarray, geometry_data: dict[str, np.ndarray], roi_params: ROIParameters, use_streaming: bool | None = None, ) -> ROIResult: """ Main ROI calculation method with automatic streaming detection. Parameters ---------- saxs_data : np.ndarray SAXS data array geometry_data : Dict[str, np.ndarray] Geometry arrays (qmap, rmap, pmap, mask, etc.) roi_params : ROIParameters ROI calculation parameters use_streaming : bool, optional Whether to use streaming processing (auto-detected if None) Returns ------- ROIResult ROI calculation result """ start_time = time.time() memory_before_mb = self.memory_manager.get_cache_stats()["current_memory_mb"] # Calculate ROI mask roi_mask = self.calculate_roi_mask(geometry_data, roi_params) # Determine if streaming is needed if use_streaming is None: data_size_mb = saxs_data.nbytes / (1024 * 1024) use_streaming = data_size_mb > self.chunk_size_mb * 2 if use_streaming: logger.info( f"Using streaming ROI calculation for {roi_params.roi_type.value}" ) result = self._process_roi_streaming(saxs_data, roi_mask, roi_params) else: logger.debug( f"Using vectorized ROI calculation for {roi_params.roi_type.value}" ) result = self.process_roi_data(saxs_data, roi_mask, roi_params) # Add timing and memory usage processing_time = time.time() - start_time memory_after_mb = self.memory_manager.get_cache_stats()["current_memory_mb"] result.processing_time = processing_time result.memory_used_mb = memory_after_mb - memory_before_mb logger.debug( f"ROI calculation completed: {processing_time:.3f}s, " f"{result.memory_used_mb:+.1f}MB memory" ) return result
def _process_roi_streaming( self, saxs_data: np.ndarray, roi_mask: np.ndarray, roi_params: ROIParameters ) -> ROIResult: """Process ROI calculation using streaming for memory efficiency.""" from .streaming_processor import MemoryEfficientIterator # Create streaming iterator for SAXS data iterator = MemoryEfficientIterator(saxs_data, chunk_size_mb=self.chunk_size_mb) # Initialize accumulation arrays if roi_params.roi_type == ROIType.PIE: # For pie ROI, we accumulate binned q-values qsize = roi_params.parameters.get("qsize", 1000) accumulated_data: Any = np.zeros(qsize) normalization: Any = np.zeros(qsize) elif roi_params.roi_type == ROIType.RING: # For ring ROI, we accumulate angular values phi_num = roi_params.parameters.get("phi_num", 180) accumulated_data = np.zeros(phi_num) normalization = np.zeros(phi_num) else: # Generic accumulation accumulated_data = 0.0 normalization = 0.0 # Process chunks for chunk, chunk_info in iterator: # Extract roi mask for this chunk chunk_mask_slice = tuple(slice(0, dim_size) for dim_size in chunk.shape[1:]) chunk_roi_mask = roi_mask[chunk_mask_slice] # Process this chunk chunk_result = self._process_chunk_vectorized( chunk, chunk_roi_mask, roi_params, chunk_info ) # Accumulate results if isinstance(accumulated_data, np.ndarray): accumulated_data += chunk_result["data"] normalization += chunk_result["norm"] else: accumulated_data += chunk_result["data"] normalization += chunk_result["norm"] # Finalize results if isinstance(accumulated_data, np.ndarray) and normalization.sum() > 0: final_data = accumulated_data / np.maximum(normalization, 1e-10) else: final_data = accumulated_data / max(normalization, 1e-10) # Generate x-values based on ROI type x_values = self._generate_x_values(roi_params) return ROIResult( x_values=x_values, roi_data=final_data, roi_type=roi_params.roi_type, parameters=roi_params.parameters, metadata={"streaming_used": True, "chunks_processed": len(iterator)}, processing_time=0.0, # Will be set by caller memory_used_mb=0.0, # Will be set by caller ) def _process_chunk_vectorized( self, chunk: np.ndarray, roi_mask: np.ndarray, roi_params: ROIParameters, chunk_info, ) -> dict[str, np.ndarray]: """Process a single chunk with vectorized operations.""" # Apply ROI mask using broadcasting masked_chunk = chunk * roi_mask[np.newaxis, ...] if roi_params.roi_type == ROIType.PIE: return self._process_pie_chunk(masked_chunk, roi_mask, roi_params) if roi_params.roi_type == ROIType.RING: return self._process_ring_chunk(masked_chunk, roi_mask, roi_params) # Generic processing data = np.sum(masked_chunk) norm = np.sum(roi_mask) return {"data": data, "norm": norm} def _process_pie_chunk( self, masked_chunk: np.ndarray, roi_mask: np.ndarray, roi_params: ROIParameters ) -> dict[str, np.ndarray]: """Process pie ROI chunk with q-value binning.""" qmap_idx = roi_params.parameters.get("qmap_idx") qsize = roi_params.parameters.get("qsize", 1000) if qmap_idx is None: # Fallback to simple accumulation return {"data": np.sum(masked_chunk), "norm": np.sum(roi_mask)} # Vectorized binning using np.bincount flat_qmap = np.where(roi_mask, qmap_idx, 0).ravel() masked_chunk.ravel() # Process each time frame chunk_data = np.zeros(qsize) chunk_norm = np.zeros(qsize) for t in range(masked_chunk.shape[0]): frame_data = masked_chunk[t].ravel() binned_data = np.bincount(flat_qmap, frame_data, minlength=qsize + 1)[1:] binned_norm = np.bincount( flat_qmap, roi_mask.ravel().astype(float), minlength=qsize + 1 )[1:] chunk_data += binned_data chunk_norm += binned_norm return {"data": chunk_data, "norm": chunk_norm} def _process_ring_chunk( self, masked_chunk: np.ndarray, roi_mask: np.ndarray, roi_params: ROIParameters ) -> dict[str, np.ndarray]: """Process ring ROI chunk with angular binning.""" phi_idx = roi_params.parameters.get("phi_idx") phi_num = roi_params.parameters.get("phi_num", 180) if phi_idx is None: return {"data": np.sum(masked_chunk), "norm": np.sum(roi_mask)} # Vectorized angular binning flat_phi = np.where(roi_mask, phi_idx, 0).ravel() chunk_data = np.zeros(phi_num) chunk_norm = np.zeros(phi_num) for t in range(masked_chunk.shape[0]): frame_data = masked_chunk[t].ravel() binned_data = np.bincount(flat_phi, frame_data, minlength=phi_num + 1)[1:] binned_norm = np.bincount( flat_phi, roi_mask.ravel().astype(float), minlength=phi_num + 1 )[1:] chunk_data += binned_data chunk_norm += binned_norm return {"data": chunk_data, "norm": chunk_norm} def _generate_x_values(self, roi_params: ROIParameters) -> np.ndarray: """Generate x-axis values based on ROI type.""" if roi_params.roi_type == ROIType.PIE: # Q-values for pie ROI qmin = roi_params.parameters.get("qmin", 0.0) qmax = roi_params.parameters.get("qmax", 1.0) qsize = roi_params.parameters.get("qsize", 1000) return np.linspace(qmin, qmax, qsize) if roi_params.roi_type == ROIType.RING: # Angular values for ring ROI phi_min = roi_params.parameters.get("phi_min", 0.0) phi_max = roi_params.parameters.get("phi_max", 360.0) phi_num = roi_params.parameters.get("phi_num", 180) return np.linspace(phi_min, phi_max, phi_num) # Default x-values return np.arange(len(roi_params.parameters.get("default_size", 100)))
[docs] def process_batch_vmap( self, frames: np.ndarray, roi_mask: np.ndarray ) -> np.ndarray: """Process batch of frames using JAX vmap for GPU acceleration. Parameters ---------- frames : np.ndarray Stack of frames [N, H, W] roi_mask : np.ndarray ROI mask [H, W] Returns ------- np.ndarray Summed intensities for each frame [N] """ if self._vmap_enabled: import jax import jax.numpy as jnp # Convert to JAX arrays frames_jax = jnp.asarray(frames) mask_jax = jnp.asarray(roi_mask) # Define single-frame processing function def process_single_frame(frame): return jnp.sum(jnp.where(mask_jax, frame, 0.0)) # Apply vmap for batch processing batch_fn = jax.vmap(process_single_frame) result = batch_fn(frames_jax) # Return as NumPy array return ensure_numpy(result) # NumPy fallback for non-JAX backend return np.sum(frames * roi_mask[np.newaxis, ...], axis=(1, 2))
[docs] class PieROICalculator(VectorizedROICalculator): """Optimized calculator for pie-shaped ROI."""
[docs] def calculate_roi_mask( self, geometry_data: dict[str, np.ndarray], roi_params: ROIParameters ) -> np.ndarray: """Calculate pie ROI mask using vectorized operations.""" pmap = geometry_data["pmap"] mask = geometry_data.get("mask", np.ones_like(pmap)) pmin, pmax = roi_params.parameters["angle_range"] # Handle angle wraparound for pie ROI if pmin < pmax: angle_mask = (pmap >= pmin) & (pmap < pmax) else: # Wraparound case (e.g., pmin=350, pmax=10) angle_mask = (pmap >= pmin) | (pmap < pmax) # Combine with detector mask roi_mask = angle_mask & (mask > 0) logger.debug( f"Pie ROI mask: {np.sum(roi_mask)} pixels selected " f"({np.sum(roi_mask) / roi_mask.size * 100:.1f}%)" ) return roi_mask
[docs] def process_roi_data( self, saxs_data: np.ndarray, roi_mask: np.ndarray, roi_params: ROIParameters ) -> ROIResult: """Process pie ROI with optimized q-value binning.""" qmap_idx = roi_params.parameters["qmap_idx"] qsize = roi_params.parameters["qsize"] qspan = roi_params.parameters.get("qspan", np.arange(qsize + 1)) # Vectorized binning for all time frames at once if saxs_data.ndim == 3: # Multiple time frames roi_data = np.zeros((saxs_data.shape[0], qsize)) # Prepare masked indices flat_qmap = np.where(roi_mask, qmap_idx, 0).ravel() for t in range(saxs_data.shape[0]): frame_data = saxs_data[t].ravel() binned = np.bincount(flat_qmap, frame_data, minlength=qsize + 1) roi_data[t] = binned[1:] # Remove 0th bin # Average over time if multiple frames final_roi_data = np.mean(roi_data, axis=0) else: # Single frame flat_qmap = np.where(roi_mask, qmap_idx, 0).ravel() flat_data = saxs_data.ravel() binned = np.bincount(flat_qmap, flat_data, minlength=qsize + 1) final_roi_data = binned[1:] # Generate q-values x_values = qspan[:-1] if len(qspan) == qsize + 1 else np.arange(qsize) # Apply distance cutoff if specified if "dist" in roi_params.parameters: roi_params.parameters["dist"] qmax = roi_params.parameters.get("qmax") if qmax is not None: qmax_idx = np.searchsorted(x_values, qmax) final_roi_data[qmax_idx:] = np.nan return ROIResult( x_values=ensure_numpy(x_values), roi_data=ensure_numpy(final_roi_data), roi_type=roi_params.roi_type, parameters=roi_params.parameters, metadata={"vectorized": True, "qsize": qsize}, processing_time=0.0, memory_used_mb=0.0, )
[docs] class RingROICalculator(VectorizedROICalculator): """Optimized calculator for ring-shaped ROI."""
[docs] def calculate_roi_mask( self, geometry_data: dict[str, np.ndarray], roi_params: ROIParameters ) -> np.ndarray: """Calculate ring ROI mask using vectorized operations.""" rmap = geometry_data["rmap"] mask = geometry_data.get("mask", np.ones_like(rmap)) rmin, rmax = roi_params.parameters["radius_range"] # Vectorized ring mask calculation ring_mask = (rmap >= rmin) & (rmap < rmax) & (mask > 0) logger.debug( f"Ring ROI mask: {np.sum(ring_mask)} pixels selected " f"({np.sum(ring_mask) / ring_mask.size * 100:.1f}%)" ) return ring_mask
[docs] def process_roi_data( self, saxs_data: np.ndarray, roi_mask: np.ndarray, roi_params: ROIParameters ) -> ROIResult: """Process ring ROI with optimized angular binning.""" pmap = roi_params.parameters["pmap"] phi_num = roi_params.parameters.get("phi_num", 180) # Calculate angular indices for binning pmap_roi = pmap[roi_mask] phi_min, phi_max = np.min(pmap_roi), np.max(pmap_roi) phi_indices = np.floor((pmap - phi_min) / (phi_max - phi_min) * phi_num).astype( int ) phi_indices = np.clip(phi_indices, 0, phi_num - 1) # Create angular binning mask angular_idx = np.where(roi_mask, phi_indices, 0).ravel() # Process SAXS data if saxs_data.ndim == 3: roi_data = np.zeros((saxs_data.shape[0], phi_num)) for t in range(saxs_data.shape[0]): frame_data = saxs_data[t].ravel() binned = np.bincount(angular_idx, frame_data, minlength=phi_num + 1) roi_data[t] = binned[1:] final_roi_data = np.mean(roi_data, axis=0) else: flat_data = saxs_data.ravel() binned = np.bincount(angular_idx, flat_data, minlength=phi_num + 1) final_roi_data = binned[1:] # Generate angular x-values x_values = np.linspace(phi_min, phi_max, phi_num) return ROIResult( x_values=ensure_numpy(x_values), roi_data=ensure_numpy(final_roi_data), roi_type=roi_params.roi_type, parameters=roi_params.parameters, metadata={ "vectorized": True, "phi_num": phi_num, "phi_range": (phi_min, phi_max), }, processing_time=0.0, memory_used_mb=0.0, )
[docs] class ParallelROIProcessor: """Process multiple ROIs in parallel for enhanced performance."""
[docs] def __init__(self, max_workers: int | None = None): self.max_workers = max_workers self.memory_manager = get_memory_manager()
[docs] def calculate_multiple_rois( self, saxs_data: np.ndarray, geometry_data: dict[str, np.ndarray], roi_list: list[ROIParameters], use_parallel: bool = True, ) -> list[ROIResult]: """ Calculate multiple ROIs with optional parallel processing. Parameters ---------- saxs_data : np.ndarray SAXS data array geometry_data : Dict[str, np.ndarray] Geometry arrays roi_list : List[ROIParameters] List of ROI parameters to calculate use_parallel : bool Whether to use parallel processing Returns ------- List[ROIResult] List of ROI calculation results """ if not roi_list: return [] # Check memory pressure pressure = self.memory_manager.get_memory_pressure() if pressure in [MemoryPressure.HIGH, MemoryPressure.CRITICAL]: logger.warning( "High memory pressure detected, using sequential ROI processing" ) use_parallel = False if not use_parallel or len(roi_list) < 2: # Sequential processing return self._calculate_sequential(saxs_data, geometry_data, roi_list) # Parallel processing return self._calculate_parallel(saxs_data, geometry_data, roi_list)
def _calculate_sequential( self, saxs_data: np.ndarray, geometry_data: dict[str, np.ndarray], roi_list: list[ROIParameters], ) -> list[ROIResult]: """Calculate ROIs sequentially.""" results = [] for roi_params in roi_list: calculator = self._get_calculator(roi_params.roi_type) result = calculator.calculate_roi(saxs_data, geometry_data, roi_params) results.append(result) return results def _calculate_parallel( self, saxs_data: np.ndarray, geometry_data: dict[str, np.ndarray], roi_list: list[ROIParameters], ) -> list[ROIResult]: """Calculate ROIs in parallel.""" import os max_workers = self.max_workers or min(len(roi_list), os.cpu_count() or 1) logger.info( f"Processing {len(roi_list)} ROIs in parallel with {max_workers} workers" ) with ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit all ROI calculations future_to_roi = {} for i, roi_params in enumerate(roi_list): calculator = self._get_calculator(roi_params.roi_type) future = executor.submit( calculator.calculate_roi, saxs_data, geometry_data, roi_params ) future_to_roi[future] = i # Collect results in order results: list[ROIResult | None] = [None] * len(roi_list) for future in as_completed(future_to_roi): index = future_to_roi[future] try: results[index] = future.result() except Exception as e: logger.error(f"ROI calculation {index} failed: {e}") # Create empty result for failed calculation results[index] = ROIResult( x_values=np.array([]), roi_data=np.array([]), roi_type=roi_list[index].roi_type, parameters=roi_list[index].parameters, metadata={"error": str(e)}, processing_time=0.0, memory_used_mb=0.0, ) from typing import cast return cast(list[ROIResult], results) def _get_calculator(self, roi_type: ROIType) -> VectorizedROICalculator: """Get appropriate calculator for ROI type.""" if roi_type == ROIType.PIE: return PieROICalculator() if roi_type == ROIType.RING: return RingROICalculator() # Default to pie calculator for unknown types logger.warning(f"Unknown ROI type {roi_type}, using pie calculator") return PieROICalculator()
# Convenience functions for easy usage
[docs] def calculate_pie_roi( saxs_data: np.ndarray, geometry_data: dict[str, np.ndarray], angle_range: tuple[float, float], **kwargs, ) -> ROIResult: """Convenience function for pie ROI calculation.""" roi_params = ROIParameters( roi_type=ROIType.PIE, parameters={"angle_range": angle_range, **kwargs} ) calculator = PieROICalculator() return calculator.calculate_roi(saxs_data, geometry_data, roi_params)
[docs] def calculate_ring_roi( saxs_data: np.ndarray, geometry_data: dict[str, np.ndarray], radius_range: tuple[float, float], **kwargs, ) -> ROIResult: """Convenience function for ring ROI calculation.""" roi_params = ROIParameters( roi_type=ROIType.RING, parameters={"radius_range": radius_range, **kwargs} ) calculator = RingROICalculator() return calculator.calculate_roi(saxs_data, geometry_data, roi_params)
[docs] def calculate_multiple_rois_parallel( saxs_data: np.ndarray, geometry_data: dict[str, np.ndarray], roi_list: list[ROIParameters], max_workers: int | None = None, ) -> list[ROIResult]: """Convenience function for parallel multiple ROI calculation.""" processor = ParallelROIProcessor(max_workers=max_workers) return processor.calculate_multiple_rois( saxs_data, geometry_data, roi_list, use_parallel=True )