Source code for xpcsviewer.utils.streaming_processor

"""
Streaming Data Processing for XPCS Viewer

This module provides memory-efficient streaming processing for large XPCS datasets,
particularly for operations like logarithmic transformations of SAXS data.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any

import numpy as np

from .logging_config import get_logger
from .memory_manager import MemoryPressure, get_memory_manager

logger = get_logger(__name__)


@dataclass
class ChunkInfo:
    """Metadata describing one chunk cut out of a larger array.

    Attributes
    ----------
    index : int
        Zero-based position of the chunk in the chunk sequence.
    slice_obj : tuple[slice, ...]
        One slice per array dimension; indexing the source array with it
        yields the chunk.
    shape : tuple[int, ...]
        Shape of the chunk.
    estimated_size_mb : float
        Estimated chunk size, expressed as bytes / 1024**2.
    total_chunks : int
        Number of chunks in the whole sequence.
    """

    # Field order is part of the positional constructor signature.
    index: int
    slice_obj: tuple[slice, ...]
    shape: tuple[int, ...]
    estimated_size_mb: float
    total_chunks: int
def calculate_chunk_slices(
    shape: tuple[int, ...], dtype: np.dtype, chunk_size_mb: float
) -> list[tuple[slice, ...]]:
    """
    Split an array shape into slices along the first dimension.

    Every returned entry is a tuple with one slice per dimension. Only the
    first dimension is partitioned; all remaining dimensions are always
    covered in full.

    Parameters
    ----------
    shape : tuple[int, ...]
        Shape of the array to be chunked
    dtype : np.dtype
        Element type of the array (anything ``np.dtype`` accepts)
    chunk_size_mb : float
        Target size of a single chunk in MB (1 MB = 1024 * 1024 bytes)

    Returns
    -------
    list[tuple[slice, ...]]
        Slice tuples covering the array in order. Arrays that fit into one
        chunk, empty arrays and 0-d shapes yield a single full-range entry.
    """
    # Local import keeps this block self-contained; the module header does
    # not import ``math``.
    import math

    bytes_per_mb = 1024 * 1024
    # Normalise so that dtype-likes such as ``np.float32`` are accepted too.
    itemsize = np.dtype(dtype).itemsize
    dims = tuple(int(dim_size) for dim_size in shape)

    # Python integers cannot overflow, unlike a fixed-width NumPy product.
    total_elements = math.prod(dims)
    total_size_mb = total_elements * itemsize / bytes_per_mb

    full_range = tuple(slice(0, dim_size) for dim_size in dims)
    if not dims or total_elements == 0 or total_size_mb <= chunk_size_mb:
        return [full_range]

    elements_per_chunk = int(chunk_size_mb * bytes_per_mb / itemsize)
    elements_per_row = math.prod(dims[1:])  # equals 1 for 1-D input

    # Always advance by at least one row/element. Without the floor a budget
    # smaller than one element would give a zero step for range().
    rows_per_chunk = max(1, elements_per_chunk // elements_per_row)

    n_rows = dims[0]
    trailing = full_range[1:]
    return [
        (slice(start_row, min(start_row + rows_per_chunk, n_rows)), *trailing)
        for start_row in range(0, n_rows, rows_per_chunk)
    ]
class StreamingProcessor(ABC):
    """Base class for processors that walk through an array chunk by chunk.

    Subclasses provide ``process_chunk`` (the per-chunk work) and
    ``combine_chunks`` (how per-chunk results are merged), while
    ``process_array_streaming`` drives the loop.
    """

    def __init__(self, chunk_size_mb: float = 50.0):
        """Remember the chunk size (in MB) and obtain the memory manager."""
        self.chunk_size_mb = chunk_size_mb
        self.memory_manager = get_memory_manager()

    @abstractmethod
    def process_chunk(self, chunk: np.ndarray, chunk_info: ChunkInfo) -> Any:
        """Process a single data chunk."""

    @abstractmethod
    def combine_chunks(
        self, processed_chunks: list, original_shape: tuple[int, ...]
    ) -> Any:
        """Combine processed chunks into final result."""

    def process_array_streaming(
        self, data: np.ndarray, output_dtype: np.dtype | None = None
    ) -> np.ndarray:
        """
        Run ``process_chunk`` over every chunk of ``data`` and merge the results.

        Parameters
        ----------
        data : np.ndarray
            Input array to process
        output_dtype : np.dtype, optional
            Requested output data type. It defaults to ``data.dtype`` but is
            not read again by this base implementation.

        Returns
        -------
        np.ndarray
            Whatever ``combine_chunks`` returns for the collected
            ``(slice_obj, processed_chunk)`` pairs
        """
        if output_dtype is None:
            # Resolved here, but nothing below consults it.
            output_dtype = data.dtype

        regions = calculate_chunk_slices(data.shape, data.dtype, self.chunk_size_mb)
        total_chunks = len(regions)

        logger.info(
            f"Processing array with streaming: {total_chunks} chunks of ~{self.chunk_size_mb:.1f}MB each"
        )

        # Loop invariants, computed once.
        bytes_per_mb = 1024 * 1024
        itemsize = np.dtype(data.dtype).itemsize
        report_every = max(1, total_chunks // 10)

        collected = []
        memory_before = self.memory_manager.get_cache_stats()["current_memory_mb"]

        for position, region in enumerate(regions):
            # Check memory pressure before touching the next chunk.
            if self.memory_manager.get_memory_pressure() == MemoryPressure.CRITICAL:
                logger.warning(
                    "Critical memory pressure during streaming, triggering cleanup"
                )
                self.memory_manager._emergency_cleanup()

            extent = tuple(axis.stop - axis.start for axis in region)
            info = ChunkInfo(
                index=position,
                slice_obj=region,
                shape=extent,
                estimated_size_mb=np.prod(extent) * itemsize / bytes_per_mb,
                total_chunks=total_chunks,
            )

            outcome = self.process_chunk(data[region], info)
            collected.append((region, outcome))

            # Report progress roughly ten times per run.
            if position % report_every == 0:
                done = position + 1
                progress = done / total_chunks * 100
                logger.debug(
                    f"Streaming progress: {progress:.1f}% ({done}/{total_chunks})"
                )

        result = self.combine_chunks(collected, data.shape)

        memory_after = self.memory_manager.get_cache_stats()["current_memory_mb"]
        memory_used = memory_after - memory_before
        logger.info(
            f"Streaming processing completed, memory used: {memory_used:+.1f}MB"
        )

        return result
class SAXSLogProcessor(StreamingProcessor):
    """Streaming processor that applies a base-10 logarithm to SAXS data."""

    def __init__(self, chunk_size_mb: float = 50.0, epsilon: float = 1e-10):
        """Store ``epsilon`` and use it as the initial replacement floor."""
        super().__init__(chunk_size_mb)
        self.epsilon = epsilon
        # Overwritten by process_array_streaming before any chunk is processed.
        self._global_min: float = epsilon

    def process_array_streaming(
        self, data: np.ndarray, output_dtype: np.dtype | None = None
    ) -> np.ndarray:
        """
        Log-transform ``data`` in two streaming passes.

        The first pass scans every chunk for the smallest strictly positive
        value. The second pass (the inherited chunk loop) takes ``log10``
        after replacing non-positive pixels with that value. Using one floor
        for the whole array keeps the replacement consistent across chunks;
        a per-chunk minimum would give different values in different chunks.

        The floor is never smaller than ``epsilon``. ``epsilon`` is also used
        when no finite positive value is found.
        """
        # Pass 1: smallest positive value over the whole array.
        smallest_positive = np.inf
        for region in calculate_chunk_slices(
            data.shape, data.dtype, self.chunk_size_mb
        ):
            block = data[region]
            positives = block[block > 0]
            if positives.size > 0:
                smallest_positive = min(smallest_positive, float(positives.min()))

        floor = smallest_positive if np.isfinite(smallest_positive) else self.epsilon
        self._global_min = max(floor, self.epsilon)
        logger.debug(f"SAXSLogProcessor global_min = {self._global_min:.6g}")

        # Pass 2: the base-class loop calls process_chunk for every chunk.
        return super().process_array_streaming(data, output_dtype)

    def process_chunk(self, chunk: np.ndarray, chunk_info: ChunkInfo) -> np.ndarray:
        """
        Log-transform one chunk.

        Parameters
        ----------
        chunk : np.ndarray
            Input SAXS data chunk
        chunk_info : ChunkInfo
            Information about the chunk (used for the debug message)

        Returns
        -------
        np.ndarray
            ``log10`` of the chunk as float32, with entries that are not
            greater than zero replaced by the stored floor value first
        """
        values = chunk.astype(np.float32)
        # Everything that is not strictly positive gets the shared floor.
        values[~(values > 0)] = self._global_min
        transformed = np.log10(values)

        logger.debug(
            f"Processed chunk {chunk_info.index + 1}/{chunk_info.total_chunks}: "
            f"{chunk_info.shape} -> log transform"
        )
        return transformed

    def combine_chunks(
        self, processed_chunks: list, original_shape: tuple[int, ...]
    ) -> np.ndarray:
        """
        Assemble the transformed chunks into one array.

        Parameters
        ----------
        processed_chunks : list
            List of (slice_obj, processed_chunk) tuples
        original_shape : tuple[int, ...]
            Shape of original array

        Returns
        -------
        np.ndarray
            float32 array of ``original_shape`` with every chunk written at
            its slice position
        """
        assembled = np.zeros(original_shape, dtype=np.float32)
        for region, block in processed_chunks:
            assembled[region] = block
        return assembled
class ROIProcessor(StreamingProcessor):
    """Streaming processor for ROI (Region of Interest) calculations."""

    def __init__(self, roi_mask: np.ndarray, chunk_size_mb: float = 50.0):
        """Store the ROI mask that is applied to every chunk."""
        super().__init__(chunk_size_mb)
        self.roi_mask = roi_mask

    def process_chunk(self, chunk: np.ndarray, chunk_info: ChunkInfo) -> dict:
        """
        Compute per-frame ROI statistics for one chunk.

        The first axis of ``chunk`` is treated as the frame axis and the
        remaining axes as spatial axes. Pixels whose mask entry is non-zero
        belong to the ROI. Values are multiplied by their mask entry, so a
        non-boolean mask acts as a weight.

        Mean and standard deviation are taken over the ROI pixels only.
        Pixels outside the ROI contribute neither to the sum nor to the
        pixel count, so the mean is not diluted by masked-out zeros.

        Parameters
        ----------
        chunk : np.ndarray
            Input data chunk with at least one spatial axis
        chunk_info : ChunkInfo
            Information about the chunk; its ``slice_obj`` selects the
            matching part of the mask

        Returns
        -------
        dict
            ``"sum"``, ``"mean"`` and ``"std"`` (one value per frame) plus
            ``"slice_obj"``. For an empty ROI, mean and std are zero.
        """
        # Spatial part of the mask for this chunk (the frame slice is dropped).
        roi_slice = np.asarray(self.roi_mask[chunk_info.slice_obj[1:]])
        # Masks that merely broadcast against the spatial shape stay supported.
        roi_slice = np.broadcast_to(roi_slice, chunk.shape[1:])
        in_roi = roi_slice != 0

        # Shape (frames, n_roi_pixels): only ROI pixels are materialised.
        roi_values = chunk[:, in_roi] * roi_slice[in_roi]

        roi_sum = np.sum(roi_values, axis=1)
        if roi_values.shape[1] > 0:
            roi_mean = np.mean(roi_values, axis=1)
            roi_std = np.std(roi_values, axis=1)
        else:
            # Empty ROI: report zeros instead of NaN plus runtime warnings.
            roi_mean = np.zeros(chunk.shape[0])
            roi_std = np.zeros(chunk.shape[0])

        return {
            "sum": roi_sum,
            "mean": roi_mean,
            "std": roi_std,
            "slice_obj": chunk_info.slice_obj,
        }

    def combine_chunks(
        self, processed_chunks: list, original_shape: tuple[int, ...]
    ) -> dict:
        """
        Combine ROI chunk results into final statistics.

        Parameters
        ----------
        processed_chunks : list
            List of (slice_obj, roi_stats) tuples
        original_shape : tuple[int, ...]
            Shape of original array; its first entry is the frame count

        Returns
        -------
        dict
            ``"roi_sum"``, ``"roi_mean"`` and ``"roi_std"``, each a float
            array with one entry per frame
        """
        total_time_points = original_shape[0]
        combined = {
            key: np.zeros(total_time_points) for key in ("sum", "mean", "std")
        }

        for slice_obj, roi_stats in processed_chunks:
            # The first slice addresses the frames covered by this chunk.
            time_slice = slice_obj[0]
            for key, target in combined.items():
                target[time_slice] = roi_stats[key]

        return {f"roi_{key}": values for key, values in combined.items()}
class AdaptiveChunkSizer:
    """Choose a chunk size from current memory pressure and dataset size."""

    def __init__(self, base_chunk_size_mb: float = 50.0):
        """Store the base chunk size (MB) and obtain the memory manager."""
        self.base_chunk_size_mb = base_chunk_size_mb
        self.memory_manager = get_memory_manager()
        # Starts empty; nothing in this class appends to it.
        self.performance_history: list[float] = []

    def get_optimal_chunk_size(
        self, data_shape: tuple[int, ...], data_dtype: np.dtype
    ) -> float:
        """
        Derive a chunk size for the given data under current conditions.

        The base size is scaled down under memory pressure (x0.25 critical,
        x0.5 high, x0.75 moderate). It is then limited by the dataset size:
        datasets below 100 MB cap the chunk at the dataset size, datasets
        above 1000 MB cap it at 100 MB. Finally it is raised to at least 10 MB.

        Parameters
        ----------
        data_shape : tuple[int, ...]
            Shape of data to be processed
        data_dtype : np.dtype
            Data type of array

        Returns
        -------
        float
            Chunk size in MB
        """
        pressure = self.memory_manager.get_memory_pressure()
        chunk_size = self.base_chunk_size_mb

        # The first matching pressure level applies; any other level keeps
        # the base size unchanged.
        reductions = (
            (MemoryPressure.CRITICAL, 0.25),
            (MemoryPressure.HIGH, 0.5),
            (MemoryPressure.MODERATE, 0.75),
        )
        for level, factor in reductions:
            if pressure == level:
                chunk_size *= factor
                break

        total_size_mb = np.prod(data_shape) * data_dtype.itemsize / (1024 * 1024)

        if total_size_mb < 100:
            # Small dataset: a chunk need not exceed the data itself.
            chunk_size = min(chunk_size, total_size_mb)
        elif total_size_mb > 1000:
            # Very large dataset: cap the chunk size.
            chunk_size = min(chunk_size, 100.0)

        # The lower bound is applied last, so it overrides the limits above.
        chunk_size = max(chunk_size, 10.0)

        logger.debug(
            f"Adaptive chunk size: {chunk_size:.1f}MB "
            f"(pressure: {pressure.value}, data: {total_size_mb:.1f}MB)"
        )
        return chunk_size
def process_saxs_log_streaming(
    data: np.ndarray, chunk_size_mb: float | None = None, epsilon: float = 1e-10
) -> np.ndarray:
    """
    Log-transform SAXS data with a ``SAXSLogProcessor``.

    Parameters
    ----------
    data : np.ndarray
        Input SAXS data
    chunk_size_mb : float, optional
        Chunk size in MB; when None it is obtained from
        ``AdaptiveChunkSizer.get_optimal_chunk_size``
    epsilon : float
        Passed to ``SAXSLogProcessor`` for handling non-positive data

    Returns
    -------
    np.ndarray
        Log-transformed SAXS data
    """
    if chunk_size_mb is None:
        chunk_size_mb = AdaptiveChunkSizer().get_optimal_chunk_size(
            data.shape, data.dtype
        )

    return SAXSLogProcessor(
        chunk_size_mb=chunk_size_mb, epsilon=epsilon
    ).process_array_streaming(data)
def process_roi_streaming(
    data: np.ndarray, roi_mask: np.ndarray, chunk_size_mb: float | None = None
) -> dict:
    """
    Compute ROI statistics with a ``ROIProcessor``.

    The chunk loop is delegated to the processor's inherited
    ``process_array_streaming``. The ROI path therefore gets the same
    memory-pressure monitoring and progress logging as the SAXS path, with
    no second copy of the loop kept here.

    Parameters
    ----------
    data : np.ndarray
        Input data array
    roi_mask : np.ndarray
        ROI mask array
    chunk_size_mb : float, optional
        Chunk size in MB; when None it is obtained from
        ``AdaptiveChunkSizer.get_optimal_chunk_size``

    Returns
    -------
    dict
        The dictionary produced by ``ROIProcessor.combine_chunks``
    """
    if chunk_size_mb is None:
        chunk_size_mb = AdaptiveChunkSizer().get_optimal_chunk_size(
            data.shape, data.dtype
        )

    processor = ROIProcessor(roi_mask=roi_mask, chunk_size_mb=chunk_size_mb)
    # The driver collects (slice_obj, roi_stats) pairs and passes them to
    # combine_chunks, which is what the hand-written loop did.
    return processor.process_array_streaming(data)
class MemoryEfficientIterator:
    """Single-pass iterator yielding ``(chunk, ChunkInfo)`` pairs of an array.

    Before each chunk is produced the memory pressure is checked, and a
    cleanup is requested from the memory manager when it is high or critical.
    """

    def __init__(self, data: np.ndarray, chunk_size_mb: float = 50.0):
        """Store the array and precompute the slices of all chunks."""
        self.data = data
        self.chunk_size_mb = chunk_size_mb
        self.memory_manager = get_memory_manager()
        self.chunk_slices = calculate_chunk_slices(
            data.shape, data.dtype, chunk_size_mb
        )
        # Position of the next chunk to hand out.
        self.current_index = 0

    def __iter__(self):
        """Return the iterator itself; iteration state is not reset."""
        return self

    def __next__(self):
        """Return the next ``(chunk, chunk_info)`` pair or raise StopIteration."""
        position = self.current_index
        total = len(self.chunk_slices)
        if position >= total:
            raise StopIteration

        # Check memory pressure before yielding the chunk.
        pressure = self.memory_manager.get_memory_pressure()
        if pressure in (MemoryPressure.HIGH, MemoryPressure.CRITICAL):
            logger.warning("Memory pressure during iteration, triggering cleanup")
            self.memory_manager._aggressive_cleanup()

        region = self.chunk_slices[position]
        chunk = self.data[region]
        info = ChunkInfo(
            index=position,
            slice_obj=region,
            shape=chunk.shape,
            estimated_size_mb=chunk.nbytes / (1024 * 1024),
            total_chunks=total,
        )

        # Advance only after the chunk has been built successfully.
        self.current_index = position + 1
        return chunk, info

    def __len__(self):
        """Return the total number of chunks, regardless of progress."""
        return len(self.chunk_slices)