Source code for xpcsviewer.module.twotime_utils

from __future__ import annotations

# Standard library imports
import multiprocessing as mp
from collections.abc import Callable
from concurrent.futures import ProcessPoolExecutor
from functools import lru_cache
from multiprocessing import Pool

# Third-party imports
import h5py
import numpy as np

from xpcsviewer.utils.logging_config import get_logger

# Local imports
from ..fileIO.aps_8idi import key as key_map
from ..fileIO.hdf_reader import _connection_pool

key_map_nexus = key_map["nexus"]
logger = get_logger(__name__)

# Global process pool for reuse
_process_pool: ProcessPoolExecutor | None = None
_pool_size: int | None = None
_shared_arrays = {}


[docs] def get_optimal_worker_count() -> int: """Get optimal number of worker processes based on system resources.""" cpu_count = mp.cpu_count() # Use 75% of available CPUs, minimum 1, maximum 16 optimal = max(1, min(16, int(cpu_count * 0.75))) logger.debug(f"Optimal worker count: {optimal} (CPU count: {cpu_count})") return optimal
[docs] def get_process_pool(num_workers: int | None = None) -> ProcessPoolExecutor: """ Get or create a reusable process pool. Args: num_workers: Number of worker processes. If None, uses optimal count. Returns: ProcessPoolExecutor instance """ global _process_pool, _pool_size # noqa: PLW0603 - intentional for process pool reuse if num_workers is None: num_workers = get_optimal_worker_count() # Create new pool if none exists or size changed if _process_pool is None or _pool_size != num_workers: if _process_pool is not None: _process_pool.shutdown(wait=False) _process_pool = ProcessPoolExecutor(max_workers=num_workers) _pool_size = num_workers logger.info(f"Created process pool with {num_workers} workers") return _process_pool
[docs] def shutdown_process_pool(): """Shutdown the global process pool.""" global _process_pool, _pool_size # noqa: PLW0603 - intentional for process pool cleanup if _process_pool is not None: _process_pool.shutdown(wait=True) _process_pool = None _pool_size = None logger.info("Process pool shutdown")
[docs] def create_shared_array( name: str, shape: tuple[int, ...], dtype=np.float32 ) -> np.ndarray: """ Create a shared memory array for multiprocessing. Args: name: Unique name for the array shape: Shape of the array dtype: Data type Returns: Numpy array backed by shared memory """ try: from multiprocessing import shared_memory # Calculate size size = int(np.prod(shape)) * np.dtype(dtype).itemsize # Create shared memory shm = shared_memory.SharedMemory(create=True, size=size, name=name) # Create numpy array array = np.ndarray(shape, dtype=dtype, buffer=shm.buf) # Store reference to prevent garbage collection _shared_arrays[name] = (shm, array) logger.debug(f"Created shared array '{name}' with shape {shape}") return array except ImportError: # Fallback to regular array if shared_memory not available logger.warning("shared_memory not available, using regular arrays") return np.empty(shape, dtype=dtype)
[docs] def get_shared_array(name: str) -> np.ndarray | None: """Get a previously created shared array by name.""" if name in _shared_arrays: return _shared_arrays[name][1] return None
[docs] def cleanup_shared_arrays(): """Clean up all shared memory arrays.""" for name, (shm, _array) in _shared_arrays.items(): try: shm.close() shm.unlink() logger.debug(f"Cleaned up shared array '{name}'") except Exception as e: logger.warning(f"Error cleaning up shared array '{name}': {e}") _shared_arrays.clear()
[docs] def read_single_c2_enhanced(args: tuple) -> tuple[np.ndarray | None, int, str]: """ Vectorized version of read_single_c2 with optimized matrix operations. Args: args: Tuple of (full_path, index_str, max_size, correct_diag, progress_callback) Returns: Tuple of (c2_matrix, sampling_rate, status_message) """ full_path, index_str, max_size, correct_diag = args[:4] progress_callback = args[4] if len(args) > 4 else None try: c2_prefix = key_map_nexus["c2_prefix"] if progress_callback: progress_callback(f"Reading {index_str}") with h5py.File(full_path, "r") as f: if f"{c2_prefix}/{index_str}" not in f: return None, 0, f"Dataset {index_str} not found" c2_half = f[f"{c2_prefix}/{index_str}"][()] # Vectorized C2 matrix reconstruction c2 = _reconstruct_c2_matrix_vectorized(c2_half) # Vectorized sampling if needed sampling_rate = 1 if max_size > 0 and max_size < c2.shape[0]: sampling_rate = (c2.shape[0] + max_size - 1) // max_size c2 = c2[::sampling_rate, ::sampling_rate] # Apply diagonal correction if requested if correct_diag: c2 = correct_diagonal_c2_vectorized(c2) if progress_callback: progress_callback(f"Completed {index_str}") return c2, sampling_rate, "success" except Exception as e: error_msg = f"Error reading {index_str}: {e!s}" logger.error(error_msg) return None, 0, error_msg
[docs] def process_c2_batch( batch_args: list[tuple], progress_callback: Callable | None = None ) -> list[tuple]: """ Optimized batch processing of C2 matrices with vectorized operations. Args: batch_args: List of arguments for read_single_c2_enhanced progress_callback: Optional progress callback Returns: List of (c2_matrix, sampling_rate, status) tuples """ results = [] # Dynamic threshold based on system resources optimal_threshold = max(4, get_optimal_worker_count() // 2) if len(batch_args) < optimal_threshold: # Use sequential processing for small batches with vectorized operations for i, args in enumerate(batch_args): if progress_callback: progress_callback(i, len(batch_args), f"Processing {args[1]}") result = read_single_c2_enhanced(args) results.append(result) else: # Use parallel processing for larger batches pool = get_process_pool() # Batch submit for better resource utilization futures = [] for args in batch_args: future = pool.submit(read_single_c2_enhanced, args) futures.append((future, args)) # Collect results as they complete with better error handling for completed_count, (future, args) in enumerate(futures, start=1): if progress_callback: progress_callback( completed_count, len(batch_args), f"Completed {args[1]}" ) try: result = future.result(timeout=300) # 5-minute timeout per operation results.append(result) except Exception as e: logger.error(f"Error in parallel C2 processing for {args[1]}: {e}") results.append((None, 0, str(e))) return results
[docs] def correct_diagonal_c2_vectorized(c2_mat): """ Vectorized diagonal correction for C2 matrices. Optimized to eliminate loops and use advanced NumPy operations. """ size = c2_mat.shape[0] if size < 2: return c2_mat # Extract side bands using advanced indexing - more cache-friendly upper_diag = np.diag(c2_mat, k=1) # Upper diagonal lower_diag = np.diag(c2_mat, k=-1) # Lower diagonal # Vectorized diagonal value computation diag_val = np.zeros(size, dtype=c2_mat.dtype) diag_val[:-1] += upper_diag diag_val[1:] += lower_diag # Vectorized normalization - avoid creating full norm array diag_val[0] /= 1.0 # Edge case diag_val[-1] /= 1.0 # Edge case if size > 2: diag_val[1:-1] /= 2.0 # Interior points # In-place diagonal assignment for better memory efficiency np.fill_diagonal(c2_mat, diag_val) return c2_mat
def _reconstruct_c2_matrix_vectorized(c2_half): """ Vectorized C2 matrix reconstruction from half matrix. Optimized for memory efficiency and cache performance. """ # Use optimized transpose and addition c2 = c2_half + c2_half.T # Vectorized diagonal correction - avoid creating index arrays diag_vals = np.diag(c2_half) np.fill_diagonal(c2, diag_vals) # Diagonal should not be doubled return c2
[docs] def read_single_c2(args): """ Optimized single C2 reading with vectorized matrix operations. """ if len(args) == 4: # Legacy mode: open file each time full_path, index_str, max_size, correct_diag = args c2_prefix = key_map_nexus["c2_prefix"] with _connection_pool.get_connection(full_path, "r") as f: c2_half = f[f"{c2_prefix}/{index_str}"][()] c2 = _reconstruct_c2_matrix_vectorized(c2_half) sampling_rate = 1 if max_size > 0 and max_size < c2.shape[0]: sampling_rate = (c2.shape[0] + max_size - 1) // max_size c2 = c2[::sampling_rate, ::sampling_rate] else: # Optimized mode: reuse file handle # Handle both 5+ args (new) and other arg counts f, index_str, max_size, correct_diag = args[:4] c2_prefix = key_map_nexus["c2_prefix"] c2_half = f[f"{c2_prefix}/{index_str}"][()] c2 = _reconstruct_c2_matrix_vectorized(c2_half) sampling_rate = 1 if max_size > 0 and max_size < c2.shape[0]: sampling_rate = (c2.shape[0] + max_size - 1) // max_size c2 = c2[::sampling_rate, ::sampling_rate] if correct_diag: c2 = correct_diagonal_c2_vectorized(c2) return c2, sampling_rate
[docs] @lru_cache(maxsize=16) def get_all_c2_from_hdf( full_path, dq_selection=None, max_c2_num=32, max_size=512, num_workers=12, correct_diag=True, ): # t0 = time.perf_counter() idx_toload = [] c2_prefix = key_map_nexus["c2_prefix"] # Read the index list and close the parent connection *before* spawning # Pool workers. HDF5 file handles are not fork-safe: keeping the parent # connection open while child processes try to open the same file causes # corruption or hangs on some platforms (SRE-7). with _connection_pool.get_connection(full_path, "r") as f: if c2_prefix not in f: return None idxlist = list(f[c2_prefix]) for idx in idxlist: if dq_selection is not None and int(idx[4:]) not in dq_selection: continue idx_toload.append(idx) if max_c2_num > 0 and len(idx_toload) > max_c2_num: break # Parent HDF5 connection is now closed (SRE-7). if len(idx_toload) >= 6: # Use multiprocessing with legacy approach (file paths). # The parent connection is already closed above so workers can # safely open their own handles to the same file. args_list = [(full_path, index, max_size, correct_diag) for index in idx_toload] with Pool(min(len(idx_toload), num_workers)) as p: result = p.map(read_single_c2, args_list) else: # Use single thread with a fresh connection for optimization with _connection_pool.get_connection(full_path, "r") as f: args_list = [(f, index, max_size, correct_diag) for index in idx_toload] result = [read_single_c2(args) for args in args_list] c2_all = np.array([res[0] for res in result]) sampling_rate_all = {res[1] for res in result} if len(sampling_rate_all) != 1: # Inconsistent rates across C2 matrices — use the minimum (finest # resolution) and warn rather than crashing with AssertionError. logger.warning( "Sampling rates are not consistent across C2 matrices: " f"{sampling_rate_all}. Using minimum rate as fallback." ) sampling_rate = min(sampling_rate_all) c2_result = { "c2_all": c2_all, "delta_t": 1.0 * sampling_rate, # put absolute time in xpcs_file "acquire_period": 1.0, "dq_selection": dq_selection, } return c2_result
[docs] def get_all_c2_from_hdf_enhanced( full_path, dq_selection=None, max_c2_num=32, max_size=512, num_workers=None, correct_diag=True, progress_callback=None, ): """ Enhanced version of get_all_c2_from_hdf with better multiprocessing and progress reporting. Args: full_path: Path to HDF5 file dq_selection: List of q-indices to load max_c2_num: Maximum number of C2 matrices to load max_size: Maximum size for C2 matrices (for downsampling) num_workers: Number of worker processes (auto-determined if None) correct_diag: Whether to apply diagonal correction progress_callback: Optional callback for progress updates Returns: Dictionary containing C2 data and metadata """ if progress_callback: progress_callback(0, 100, "Scanning HDF5 file structure") idx_toload = [] c2_prefix = key_map_nexus["c2_prefix"] try: with h5py.File(full_path, "r") as f: if c2_prefix not in f: logger.error(f"C2 prefix '{c2_prefix}' not found in {full_path}") return None idxlist = list(f[c2_prefix]) logger.info(f"Found {len(idxlist)} C2 datasets in {full_path}") for idx in idxlist: if dq_selection is not None and int(idx[4:]) not in dq_selection: continue idx_toload.append(idx) if max_c2_num > 0 and len(idx_toload) >= max_c2_num: break if progress_callback: progress_callback(10, 100, f"Selected {len(idx_toload)} datasets to load") except Exception as e: logger.error(f"Error scanning HDF5 file {full_path}: {e}") return None if not idx_toload: logger.warning(f"No C2 datasets selected for loading from {full_path}") return None # Prepare arguments for parallel processing args_list = [(full_path, index, max_size, correct_diag) for index in idx_toload] if progress_callback: progress_callback(20, 100, "Starting C2 matrix processing") # Process with enhanced parallel function try: def batch_progress(current, total, message): if progress_callback: # Scale progress from 20% to 90% progress_pct = 20 + int((current / total) * 70) progress_callback(progress_pct, 100, message) results = process_c2_batch(args_list, batch_progress) if progress_callback: progress_callback(90, 100, "Assembling results") # Filter out failed results and extract data successful_results = [ (c2, sr) for c2, sr, status in results if c2 is not None and status == "success" ] if not successful_results: logger.error(f"No C2 matrices successfully loaded from {full_path}") return None failed_count = len(results) - len(successful_results) if failed_count > 0: logger.warning(f"{failed_count} C2 matrices failed to load") # Extract C2 matrices and sampling rates c2_matrices = [res[0] for res in successful_results] sampling_rates = [res[1] for res in successful_results] # Verify sampling rate consistency unique_rates = set(sampling_rates) if len(unique_rates) != 1: logger.warning(f"Inconsistent sampling rates found: {unique_rates}") # Use the most common sampling rate from collections import Counter rate_counts = Counter(sampling_rates) sampling_rate = rate_counts.most_common(1)[0][0] logger.info(f"Using most common sampling rate: {sampling_rate}") else: sampling_rate = next(iter(unique_rates)) # Convert to numpy array c2_all = np.array(c2_matrices) if progress_callback: progress_callback(100, 100, "C2 loading complete") c2_result = { "c2_all": c2_all, "delta_t": 1.0 * sampling_rate, "acquire_period": 1.0, "dq_selection": dq_selection, "loaded_indices": [ idx_toload[i] for i, (_, _, status) in enumerate(results) if status == "success" ], "failed_count": failed_count, "total_requested": len(idx_toload), } logger.info( f"Successfully loaded {len(c2_matrices)} C2 matrices from {full_path}" ) return c2_result except Exception as e: logger.error(f"Error processing C2 matrices from {full_path}: {e}") return None
[docs] @lru_cache(maxsize=16) def get_single_c2_from_hdf( full_path, selection=0, max_size=512, t0=1, correct_diag=True ): c2_prefix = key_map_nexus["c2_prefix"] # Use connection pool and batch operations with _connection_pool.get_connection(full_path, "r") as f: if c2_prefix not in f: return None idxstr = list(f[c2_prefix])[selection] # Read C2 data using shared file handle (use 5 args to avoid legacy mode) c2_mat, sampling_rate = read_single_c2( (f, idxstr, max_size, correct_diag, True) ) # Read G2 partials in the same file operation g2_full_key = key_map_nexus["c2_g2"] # Dataset {5000, 25} g2_partial_key = key_map_nexus["c2_g2_segments"] # Dataset {1000, 5, 25} g2_full = f[g2_full_key][()] g2_partial = f[g2_partial_key][()] g2_full = np.swapaxes(g2_full, 0, 1) g2_partial = np.swapaxes(g2_partial, 0, 2) c2_result = { "c2_mat": c2_mat, "delta_t": t0 * sampling_rate, # put absolute time in xpcs_file "acquire_period": t0, "dq_selection": selection, "g2_full": g2_full[selection], "g2_partial": g2_partial[selection], } return c2_result
[docs] @lru_cache(maxsize=16) def get_c2_g2partials_from_hdf(full_path): # t0 = time.perf_counter() c2_prefix = key_map_nexus["c2_prefix"] g2_full_key = key_map_nexus["c2_g2"] # Dataset {5000, 25} g2_partial_key = key_map_nexus["c2_g2_segments"] # Dataset {1000, 5, 25} # Use connection pool for optimization with _connection_pool.get_connection(full_path, "r") as f: if c2_prefix not in f: return None g2_full = f[g2_full_key][()] g2_partial = f[g2_partial_key][()] g2_full = np.swapaxes(g2_full, 0, 1) g2_partial = np.swapaxes(g2_partial, 0, 2) g2_partials = { "g2_full": g2_full, "g2_partial": g2_partial, } return g2_partials
[docs] def get_c2_stream(full_path, max_size=-1): """Returns (idxlist, generator) where the generator yields C2 streams.""" c2_prefix = key_map_nexus["c2_prefix"] # Use connection pool for reading the index list with _connection_pool.get_connection(full_path, "r") as f: idxlist = list(f[c2_prefix]) if c2_prefix in f else [] def generator(): for idx in idxlist: # Use idxlist for iteration c2, _sampling_rate = read_single_c2((full_path, idx, max_size, False)) yield int(idx[4:]), c2 return idxlist, generator()
[docs] def batch_c2_matrix_operations(c2_matrices, operations=None): """ Vectorized batch operations on multiple C2 matrices. Args: c2_matrices: List or array of C2 matrices operations: List of operations to apply ('normalize', 'symmetrize', 'diagonal_correct') Returns: Processed C2 matrices """ if operations is None: operations = ["symmetrize", "diagonal_correct"] # Convert to numpy array for vectorized operations c2_array = np.array(c2_matrices) if "normalize" in operations: # Vectorized normalization across all matrices norms = np.linalg.norm(c2_array.reshape(c2_array.shape[0], -1), axis=1) c2_array = c2_array / norms.reshape(-1, 1, 1) if "symmetrize" in operations: # Vectorized symmetrization c2_array = 0.5 * (c2_array + np.swapaxes(c2_array, -2, -1)) if "diagonal_correct" in operations: # Fully vectorized batch diagonal correction without Python loop size = c2_array.shape[-1] if size >= 2: # Extract upper and lower diagonals for all matrices at once # np.diagonal with axis1=-2, axis2=-1 works on batches upper_diag = np.diagonal(c2_array, offset=1, axis1=-2, axis2=-1) lower_diag = np.diagonal(c2_array, offset=-1, axis1=-2, axis2=-1) # Compute corrected diagonal values: [batch, size] diag_val = np.zeros((c2_array.shape[0], size), dtype=c2_array.dtype) diag_val[:, :-1] += upper_diag diag_val[:, 1:] += lower_diag # Interior points are average of upper and lower neighbors if size > 2: diag_val[:, 1:-1] /= 2.0 # Batch diagonal assignment using advanced indexing idx = np.arange(size) c2_array[:, idx, idx] = diag_val return c2_array
[docs] def compute_c2_statistics_vectorized(c2_matrices): """ Compute statistical measures for C2 matrices using vectorized operations. Uses a JAX-accelerated path when the matrix size is >= 1024, where the GPU transfer overhead is amortised by the larger workload (8.8× measured speedup at 2048×2048). Smaller matrices use the NumPy path to avoid the 0.7× slowdown seen at 512×512. Args: c2_matrices: Array of C2 matrices [batch, height, width] Returns: Dictionary with statistical measures """ c2_array = np.array(c2_matrices) n = c2_array.shape[-1] # Matrix size # JAX-accelerated path for large matrices (>= 1024×1024). # Gated on matrix size: JAX is ~0.7× slower at 512×512 due to # host↔device transfer overhead but 8.8× faster at 2048×2048. # Only safe on the main process — never called from Pool workers. if n >= 1024: try: import jax.numpy as jnp c2_jax = jnp.asarray(c2_array) diag_jax = jnp.diagonal(c2_jax, axis1=-2, axis2=-1) # [batch, n] trace_jax = jnp.sum(diag_jax, axis=-1) # [batch] total_sum_jax = jnp.sum(c2_jax, axis=(-2, -1)) off_diagonal_sum_jax = total_sum_jax - trace_jax off_diagonal_count = n * (n - 1) stats_jax = { "mean": np.asarray(jnp.mean(c2_jax, axis=0)), "std": np.asarray(jnp.std(c2_jax, axis=0)), "median": np.median(c2_array, axis=0), # no JAX median; use numpy "min": np.asarray(jnp.min(c2_jax, axis=0)), "max": np.asarray(jnp.max(c2_jax, axis=0)), "trace": np.asarray(trace_jax), "diagonal_mean": np.asarray(jnp.mean(diag_jax, axis=-1)), "off_diagonal_mean": ( np.asarray(off_diagonal_sum_jax / off_diagonal_count) if off_diagonal_count > 0 else np.full(c2_array.shape[0], np.nan) ), } return stats_jax except Exception: # Fall through to NumPy path on any JAX error (no GPU, import error, etc.) pass # NumPy path — always used for n < 1024, fallback for larger matrices. stats = { "mean": np.mean(c2_array, axis=0), "std": np.std(c2_array, axis=0), "median": np.median(c2_array, axis=0), "min": np.min(c2_array, axis=0), "max": np.max(c2_array, axis=0), "trace": np.trace(c2_array, axis1=-2, axis2=-1), # Batch trace computation "diagonal_mean": np.mean(np.diagonal(c2_array, axis1=-2, axis2=-1), axis=-1), } # Fully vectorized off-diagonal mean computation (OPT-010) # total_sum - diagonal_sum = off_diagonal_sum # off_diagonal_count = n*n - n = n*(n-1) total_sum = np.sum(c2_array, axis=(-2, -1)) # Sum over each matrix diagonal_sum = stats["trace"] # Already computed off_diagonal_sum = total_sum - diagonal_sum off_diagonal_count = n * (n - 1) # Guard against n==1 (1x1 matrix has no off-diagonal elements) if off_diagonal_count > 0: stats["off_diagonal_mean"] = off_diagonal_sum / off_diagonal_count else: stats["off_diagonal_mean"] = np.full_like(off_diagonal_sum, np.nan) return stats
[docs] def optimized_c2_sampling(c2_matrix, target_size, method="bilinear"): """ Optimized C2 matrix downsampling using vectorized operations. Args: c2_matrix: Input C2 matrix target_size: Target matrix size method: Sampling method ('uniform', 'bilinear', 'adaptive') Returns: Downsampled C2 matrix """ current_size = c2_matrix.shape[0] if current_size <= target_size: return c2_matrix if method == "uniform": # Simple uniform sampling step = current_size // target_size return c2_matrix[::step, ::step] if method == "bilinear": # Bilinear interpolation using JAX-compatible backend from xpcsviewer.backends.scipy_replacements import zoom zoom_factor = target_size / current_size return zoom(c2_matrix, zoom_factor, order=1) if method == "adaptive": # Adaptive sampling preserving important features # Use block averaging for better feature preservation block_size = current_size // target_size # Vectorized block averaging trimmed_size = target_size * block_size trimmed_matrix = c2_matrix[:trimmed_size, :trimmed_size] # Reshape and compute block means reshaped = trimmed_matrix.reshape( target_size, block_size, target_size, block_size ) return np.mean(reshaped, axis=(1, 3)) return c2_matrix