"""
Streaming Data Processing for XPCS Viewer
This module provides memory-efficient streaming processing for large XPCS datasets,
particularly for operations like logarithmic transformations of SAXS data.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any
import numpy as np
from .logging_config import get_logger
from .memory_manager import MemoryPressure, get_memory_manager
logger = get_logger(__name__)
[docs]
@dataclass
class ChunkInfo:
    """
    Bookkeeping record describing one chunk of a streamed array.

    The streaming drivers in this module create one of these per chunk and
    pass it to ``process_chunk`` together with the chunk data.
    """

    # Zero-based position of this chunk within the chunk sequence.
    index: int
    # One slice per dimension selecting this chunk out of the full array.
    slice_obj: tuple[slice, ...]
    # Shape of the selected chunk.
    shape: tuple[int, ...]
    # Approximate size of the chunk in megabytes (1 MB = 1024 * 1024 bytes).
    estimated_size_mb: float
    # Number of chunks the whole array was split into.
    total_chunks: int
[docs]
def calculate_chunk_slices(
    shape: tuple[int, ...], dtype: np.dtype, chunk_size_mb: float
) -> list[tuple[slice, ...]]:
    """
    Calculate chunk slices for streaming processing.

    Chunks are taken along the first axis only; every other axis is covered
    in full by each chunk, so the returned slices tile the whole array in
    order and without overlap.

    Parameters
    ----------
    shape : tuple[int, ...]
        Shape of the array to be chunked.
    dtype : np.dtype or dtype-like
        Element type; only its ``itemsize`` is used. Anything accepted by
        ``np.dtype`` (e.g. ``np.float32``) works.
    chunk_size_mb : float
        Target chunk size in megabytes (1 MB = 1024 * 1024 bytes).

    Returns
    -------
    list[tuple[slice, ...]]
        One tuple of slices (one slice per dimension, with explicit start and
        stop) per chunk. Arrays that fit within ``chunk_size_mb``, and 0-d
        arrays, yield a single chunk. Every chunk holds at least one row of
        the first axis, even when one row exceeds ``chunk_size_mb``.
    """
    itemsize = np.dtype(dtype).itemsize
    total_size_mb = np.prod(shape) * itemsize / (1024 * 1024)

    # Small arrays, and 0-d arrays (no axis to split), are a single chunk.
    if not shape or total_size_mb <= chunk_size_mb:
        return [tuple(slice(0, dim_size) for dim_size in shape)]

    # Split along the first axis; each chunk takes whole "rows" (all of the
    # trailing dimensions).
    elements_per_chunk = int(chunk_size_mb * 1024 * 1024 / itemsize)
    elements_per_row = int(np.prod(shape[1:]))  # 1 for 1-D arrays
    # Clamp to at least one row: a zero step would make range() raise, and a
    # negative step would silently produce no chunks at all.
    rows_per_chunk = max(1, elements_per_chunk // max(1, elements_per_row))

    n_rows = shape[0]
    trailing = tuple(slice(0, dim_size) for dim_size in shape[1:])
    return [
        (slice(start, min(start + rows_per_chunk, n_rows)), *trailing)
        for start in range(0, n_rows, rows_per_chunk)
    ]
[docs]
class StreamingProcessor(ABC):
    """
    Abstract base class for streaming data processors.

    Subclasses implement :meth:`process_chunk` (work on one slab of the
    input) and :meth:`combine_chunks` (assemble the per-chunk results). The
    shared driver, :meth:`process_array_streaming`, walks the input in chunks
    along its first axis and checks memory pressure before each chunk.
    """

    def __init__(self, chunk_size_mb: float = 50.0):
        """
        Parameters
        ----------
        chunk_size_mb : float
            Target size of each input chunk, in megabytes.
        """
        self.chunk_size_mb = chunk_size_mb
        self.memory_manager = get_memory_manager()

    @abstractmethod
    def process_chunk(self, chunk: np.ndarray, chunk_info: ChunkInfo) -> Any:
        """Process a single data chunk."""

    @abstractmethod
    def combine_chunks(
        self, processed_chunks: list, original_shape: tuple[int, ...]
    ) -> Any:
        """
        Combine processed chunks into the final result.

        ``processed_chunks`` is a list of ``(slice_obj, processed_chunk)``
        tuples in chunk order, as built by :meth:`process_array_streaming`.
        """

    def process_array_streaming(
        self, data: np.ndarray, output_dtype: np.dtype | None = None
    ) -> np.ndarray:
        """
        Process array data using streaming to reduce memory footprint.

        Parameters
        ----------
        data : np.ndarray
            Input array to process; it is chunked along its first axis.
        output_dtype : np.dtype, optional
            If given and the combined result is an ``np.ndarray``, the result
            is converted to this dtype (without a copy when it already
            matches). ``None`` (default) keeps whatever dtype
            :meth:`combine_chunks` produced.

        Returns
        -------
        np.ndarray
            The value returned by :meth:`combine_chunks` (cast as described
            above). Subclasses whose ``combine_chunks`` builds something other
            than an array get that object back unchanged.
        """
        chunk_slices = calculate_chunk_slices(
            data.shape, data.dtype, self.chunk_size_mb
        )
        total_chunks = len(chunk_slices)
        itemsize = np.dtype(data.dtype).itemsize
        # Emit a progress line roughly every 10% of the chunks.
        report_every = max(1, total_chunks // 10)

        logger.info(
            f"Processing array with streaming: {total_chunks} chunks of ~{self.chunk_size_mb:.1f}MB each"
        )

        processed_chunks = []
        memory_before = self.memory_manager.get_cache_stats()["current_memory_mb"]

        for i, slice_obj in enumerate(chunk_slices):
            # Let the memory manager free memory before the next chunk is
            # materialized.
            pressure = self.memory_manager.get_memory_pressure()
            if pressure == MemoryPressure.CRITICAL:
                logger.warning(
                    "Critical memory pressure during streaming, triggering cleanup"
                )
                self.memory_manager._emergency_cleanup()

            # Slices from calculate_chunk_slices always carry explicit
            # start/stop values, so stop - start is the chunk extent.
            chunk_shape = tuple(s.stop - s.start for s in slice_obj)
            chunk_info = ChunkInfo(
                index=i,
                slice_obj=slice_obj,
                shape=chunk_shape,
                estimated_size_mb=np.prod(chunk_shape) * itemsize / (1024 * 1024),
                total_chunks=total_chunks,
            )

            chunk = data[slice_obj]
            processed_chunks.append((slice_obj, self.process_chunk(chunk, chunk_info)))

            # Always report the last chunk so the log reaches 100%.
            if i % report_every == 0 or i == total_chunks - 1:
                progress = (i + 1) / total_chunks * 100
                logger.debug(
                    f"Streaming progress: {progress:.1f}% ({i + 1}/{total_chunks})"
                )

        result = self.combine_chunks(processed_chunks, data.shape)
        # Honor an explicitly requested output dtype. Non-array results
        # (e.g. dicts of statistics) are returned untouched.
        if output_dtype is not None and isinstance(result, np.ndarray):
            result = result.astype(output_dtype, copy=False)

        memory_after = self.memory_manager.get_cache_stats()["current_memory_mb"]
        memory_used = memory_after - memory_before
        logger.info(
            f"Streaming processing completed, memory used: {memory_used:+.1f}MB"
        )
        return result
[docs]
class SAXSLogProcessor(StreamingProcessor):
    """
    Streaming processor that applies ``log10`` to SAXS intensity data.

    Pixels that are not strictly positive cannot be log-transformed, so they
    are replaced by one floor value shared by every chunk. The floor is the
    smallest positive pixel in the whole array, but never less than
    ``epsilon``.
    """

    def __init__(self, chunk_size_mb: float = 50.0, epsilon: float = 1e-10):
        """
        Parameters
        ----------
        chunk_size_mb : float
            Target chunk size in megabytes.
        epsilon : float
            Lower bound for the replacement value of non-positive pixels.
        """
        super().__init__(chunk_size_mb)
        self.epsilon = epsilon
        # Replacement value for non-positive pixels. It starts at epsilon and
        # is recomputed from the data by process_array_streaming before any
        # chunk is transformed.
        self._global_min: float = epsilon

    def process_array_streaming(
        self, data: np.ndarray, output_dtype: np.dtype | None = None
    ) -> np.ndarray:
        """
        Log-transform ``data`` in two streaming passes.

        The first pass scans every chunk for the smallest strictly positive
        value and stores ``max(that value, epsilon)`` as the floor. If no
        positive value exists, the floor is ``epsilon``. The second pass
        delegates to the base-class chunk loop, where each chunk is
        transformed with that floor.

        A single global floor keeps non-positive pixels consistent across
        chunks. A per-chunk local minimum would assign different values
        depending on the neighbouring pixels, for example a detector
        beamstop surrounded by high-intensity regions.
        """
        # Pass 1: track the lowest strictly positive value, one chunk at a time.
        lowest = np.inf
        for region in calculate_chunk_slices(
            data.shape, data.dtype, self.chunk_size_mb
        ):
            block = data[region]
            positives = block[block > 0]
            if not positives.size:
                continue
            lowest = min(lowest, float(positives.min()))

        floor = self.epsilon
        self._global_min = max(lowest, floor) if np.isfinite(lowest) else floor
        logger.debug(f"SAXSLogProcessor global_min = {self._global_min:.6g}")

        # Pass 2: the shared chunk loop calls process_chunk for every chunk.
        return super().process_array_streaming(data, output_dtype)

    def process_chunk(self, chunk: np.ndarray, chunk_info: ChunkInfo) -> np.ndarray:
        """
        Return ``log10`` of one chunk as float32.

        The chunk is copied to float32 first, so the caller's data is never
        modified. Entries that are not ``> 0`` after the cast, NaN included,
        are set to the current floor (``self._global_min``) before the
        logarithm is taken.

        Parameters
        ----------
        chunk : np.ndarray
            Input SAXS data chunk.
        chunk_info : ChunkInfo
            Position and shape of the chunk; used for the debug log line.

        Returns
        -------
        np.ndarray
            Log-transformed chunk (float32).
        """
        as_float = chunk.astype(np.float32)
        not_positive = ~(as_float > 0)
        as_float[not_positive] = self._global_min
        transformed = np.log10(as_float)
        logger.debug(
            f"Processed chunk {chunk_info.index + 1}/{chunk_info.total_chunks}: "
            f"{chunk_info.shape} -> log transform"
        )
        return transformed

    def combine_chunks(
        self, processed_chunks: list, original_shape: tuple[int, ...]
    ) -> np.ndarray:
        """
        Stitch the log-transformed chunks back into a single float32 array.

        Parameters
        ----------
        processed_chunks : list
            ``(slice_obj, processed_chunk)`` tuples.
        original_shape : tuple[int, ...]
            Shape of the input array, which is also the output shape.

        Returns
        -------
        np.ndarray
            float32 array of ``original_shape``. Positions not covered by any
            chunk remain 0.
        """
        stitched = np.zeros(original_shape, dtype=np.float32)
        for region, values in processed_chunks:
            stitched[region] = values
        return stitched
[docs]
class ROIProcessor(StreamingProcessor):
    """
    Streaming processor for ROI (Region of Interest) statistics.

    For every entry along the first axis of the data, it computes the
    mask-weighted sum of the pixels, plus the mean and the population
    standard deviation of the pixels inside the ROI. The first axis is
    treated as time.
    """

    def __init__(self, roi_mask: np.ndarray, chunk_size_mb: float = 50.0):
        """
        Parameters
        ----------
        roi_mask : np.ndarray
            Mask covering every axis after the first one of the data. A
            boolean mask selects pixels. A numeric mask acts as per-pixel
            weights.
        chunk_size_mb : float
            Target chunk size in megabytes.
        """
        super().__init__(chunk_size_mb)
        self.roi_mask = roi_mask

    def process_chunk(self, chunk: np.ndarray, chunk_info: ChunkInfo) -> dict:
        """
        Compute per-frame ROI statistics for one chunk.

        With mask weights ``w`` and pixel values ``x`` of one frame:

        * ``sum``  = sum(w * x)
        * ``mean`` = sum(w * x) / sum(w)
        * ``std``  = sqrt(sum(w * (x - mean)**2) / sum(w))

        For a boolean mask these are the plain sum, mean and population std
        (``ddof=0``) of the ROI pixels only. Pixels outside the ROI no longer
        dilute the mean or std. When ``sum(w) <= 0`` (e.g. an empty mask),
        ``mean`` and ``std`` are all zeros.

        Parameters
        ----------
        chunk : np.ndarray
            Input data chunk; the first axis is time.
        chunk_info : ChunkInfo
            Position of the chunk; its trailing slices select the mask part.

        Returns
        -------
        dict
            ``"sum"``, ``"mean"`` and ``"std"`` arrays, each with one value
            per frame of the chunk, plus ``"slice_obj"``.
        """
        spatial_axes = tuple(range(1, chunk.ndim))
        # Skip the time dimension when selecting the matching mask region.
        weights = np.asarray(
            self.roi_mask[chunk_info.slice_obj[1:]], dtype=np.float64
        )
        weight_total = float(weights.sum())

        # Broadcast the weights over the time dimension.
        weighted = chunk * weights[np.newaxis, ...]
        roi_sum = np.sum(weighted, axis=spatial_axes)

        if weight_total > 0:
            roi_mean = roi_sum / weight_total
            # Re-expand the per-frame mean so it broadcasts against the chunk.
            frame_mean = roi_mean.reshape((-1,) + (1,) * len(spatial_axes))
            squared_dev = (chunk - frame_mean) ** 2
            roi_var = np.sum(weights * squared_dev, axis=spatial_axes) / weight_total
            roi_std = np.sqrt(roi_var)
        else:
            # Empty ROI: keep the all-zero statistics rather than dividing by 0.
            roi_mean = np.zeros_like(roi_sum, dtype=np.float64)
            roi_std = np.zeros_like(roi_sum, dtype=np.float64)

        return {
            "sum": roi_sum,
            "mean": roi_mean,
            "std": roi_std,
            "slice_obj": chunk_info.slice_obj,
        }

    def combine_chunks(
        self, processed_chunks: list, original_shape: tuple[int, ...]
    ) -> dict:
        """
        Combine per-chunk ROI statistics into full-length time series.

        Parameters
        ----------
        processed_chunks : list
            ``(slice_obj, roi_stats)`` tuples. ``slice_obj[0]`` is the
            time-axis slice covered by the chunk.
        original_shape : tuple[int, ...]
            Shape of the original array; ``original_shape[0]`` is the number
            of time points.

        Returns
        -------
        dict
            ``"roi_sum"``, ``"roi_mean"`` and ``"roi_std"`` float64 arrays of
            length ``original_shape[0]``.
        """
        total_time_points = original_shape[0]
        roi_sum = np.zeros(total_time_points)
        roi_mean = np.zeros(total_time_points)
        roi_std = np.zeros(total_time_points)

        for slice_obj, roi_stats in processed_chunks:
            time_slice = slice_obj[0]
            roi_sum[time_slice] = roi_stats["sum"]
            roi_mean[time_slice] = roi_stats["mean"]
            roi_std[time_slice] = roi_stats["std"]

        return {"roi_sum": roi_sum, "roi_mean": roi_mean, "roi_std": roi_std}
[docs]
class AdaptiveChunkSizer:
    """
    Picks a chunk size from the current memory pressure and the data size.
    """

    def __init__(self, base_chunk_size_mb: float = 50.0):
        """
        Parameters
        ----------
        base_chunk_size_mb : float
            Chunk size used when there is no pressure-based reduction.
        """
        self.base_chunk_size_mb = base_chunk_size_mb
        self.memory_manager = get_memory_manager()
        self.performance_history: list[float] = []

    def get_optimal_chunk_size(
        self, data_shape: tuple[int, ...], data_dtype: np.dtype
    ) -> float:
        """
        Choose a chunk size in MB for an array of the given shape and dtype.

        The size is computed in three steps:

        1. Scale the base size by memory pressure: CRITICAL x0.25, HIGH x0.5,
           MODERATE x0.75. Any other pressure level keeps the base size.
        2. Adjust for the data size. Data under 100 MB is never given a
           chunk larger than the data itself. Data over 1000 MB is capped at
           100 MB chunks.
        3. Never return less than 10 MB.

        Parameters
        ----------
        data_shape : tuple[int, ...]
            Shape of the data to be processed.
        data_dtype : np.dtype
            Data type of the array; only ``itemsize`` is used.

        Returns
        -------
        float
            Chunk size in MB.
        """
        pressure = self.memory_manager.get_memory_pressure()
        pressure_factors = {
            MemoryPressure.CRITICAL: 0.25,
            MemoryPressure.HIGH: 0.5,
            MemoryPressure.MODERATE: 0.75,
        }

        size_mb = self.base_chunk_size_mb
        factor = pressure_factors.get(pressure)
        if factor is not None:
            size_mb *= factor

        data_mb = np.prod(data_shape) * data_dtype.itemsize / (1024 * 1024)
        if data_mb < 100:
            # Small dataset: a chunk larger than the data is pointless.
            size_mb = min(size_mb, data_mb)
        elif data_mb > 1000:
            # Very large dataset: keep individual chunks bounded.
            size_mb = min(size_mb, 100.0)

        size_mb = max(size_mb, 10.0)

        logger.debug(
            f"Adaptive chunk size: {size_mb:.1f}MB "
            f"(pressure: {pressure.value}, data: {data_mb:.1f}MB)"
        )
        return size_mb
[docs]
def process_saxs_log_streaming(
    data: np.ndarray, chunk_size_mb: float | None = None, epsilon: float = 1e-10
) -> np.ndarray:
    """
    Log10-transform SAXS data with a streaming processor.

    Parameters
    ----------
    data : np.ndarray
        Input SAXS data.
    chunk_size_mb : float, optional
        Chunk size in MB. When ``None``, an ``AdaptiveChunkSizer`` picks it
        from the current memory pressure and the size of ``data``.
    epsilon : float
        Lower bound for the value substituted for non-positive pixels.

    Returns
    -------
    np.ndarray
        Log-transformed SAXS data.
    """
    size_mb = chunk_size_mb
    if size_mb is None:
        size_mb = AdaptiveChunkSizer().get_optimal_chunk_size(data.shape, data.dtype)

    log_processor = SAXSLogProcessor(chunk_size_mb=size_mb, epsilon=epsilon)
    return log_processor.process_array_streaming(data)
[docs]
def process_roi_streaming(
    data: np.ndarray, roi_mask: np.ndarray, chunk_size_mb: float | None = None
) -> dict:
    """
    Compute per-frame ROI statistics with a streaming processor.

    Parameters
    ----------
    data : np.ndarray
        Input data array; chunked along its first (time) axis.
    roi_mask : np.ndarray
        ROI mask covering the axes after the first one.
    chunk_size_mb : float, optional
        Chunk size in MB. When ``None``, an ``AdaptiveChunkSizer`` picks it
        from the current memory pressure and the size of ``data``.

    Returns
    -------
    dict
        ``"roi_sum"``, ``"roi_mean"`` and ``"roi_std"`` arrays, as built by
        ``ROIProcessor.combine_chunks``.
    """
    if chunk_size_mb is None:
        sizer = AdaptiveChunkSizer()
        chunk_size_mb = sizer.get_optimal_chunk_size(data.shape, data.dtype)

    processor = ROIProcessor(roi_mask=roi_mask, chunk_size_mb=chunk_size_mb)
    # Reuse the shared chunk loop instead of a hand-copied variant. ROI
    # processing then gets the same per-chunk memory-pressure check and
    # cleanup, and the same progress logging, as the other streaming
    # processors. It builds identical chunks and ChunkInfo records.
    return processor.process_array_streaming(data)
[docs]
class MemoryEfficientIterator:
    """
    Single-pass iterator over an array's chunks.

    Each step yields ``(chunk, ChunkInfo)``. Before a chunk is extracted, the
    iterator checks memory pressure and asks the memory manager to clean up
    when the pressure is HIGH or CRITICAL.
    """

    def __init__(self, data: np.ndarray, chunk_size_mb: float = 50.0):
        """
        Parameters
        ----------
        data : np.ndarray
            Array to iterate over; chunked along its first axis.
        chunk_size_mb : float
            Target chunk size in megabytes.
        """
        self.data = data
        self.chunk_size_mb = chunk_size_mb
        self.memory_manager = get_memory_manager()
        # The chunk layout is fixed up front.
        self.chunk_slices = calculate_chunk_slices(
            data.shape, data.dtype, chunk_size_mb
        )
        # Position of the next chunk to yield. It is never reset, so the
        # iterator can only be consumed once.
        self.current_index = 0

    def __iter__(self):
        return self

    def __next__(self):
        position = self.current_index
        if position >= len(self.chunk_slices):
            raise StopIteration

        if self.memory_manager.get_memory_pressure() in (
            MemoryPressure.HIGH,
            MemoryPressure.CRITICAL,
        ):
            logger.warning("Memory pressure during iteration, triggering cleanup")
            self.memory_manager._aggressive_cleanup()

        region = self.chunk_slices[position]
        chunk = self.data[region]
        info = ChunkInfo(
            index=position,
            slice_obj=region,
            shape=chunk.shape,
            estimated_size_mb=chunk.nbytes / (1024 * 1024),
            total_chunks=len(self.chunk_slices),
        )
        # Advance only after the chunk was extracted successfully.
        self.current_index = position + 1
        return chunk, info

    def __len__(self):
        # Total number of chunks, not the number remaining.
        return len(self.chunk_slices)