Source code for xpcsviewer.fileIO.hdf_reader

"""HDF5 file reading utilities with connection pooling.

Provides optimized HDF5 file reading with connection pooling, batch operations,
and automatic resource management for efficient data access.

Classes:
    HDF5ConnectionPool: Thread-safe connection pool for HDF5 files
    HDF5Reader: High-level reader with caching and batch operations
"""

# Standard library imports
import os
import threading
import time
from collections import OrderedDict
from contextlib import contextmanager, suppress
from typing import Any

# Third-party imports
import h5py
import numpy as np

from xpcsviewer.constants import (
    CACHE_ENTRY_TIMEOUT,
    DIRECT_READ_LIMIT_MB,
    LARGE_DATASET_THRESHOLD_MB,
    NDIM_2D,
    NDIM_3D,
)
from xpcsviewer.utils.log_utils import log_timing
from xpcsviewer.utils.logging_config import get_logger
from xpcsviewer.xpcs_file.memory import _get_virtual_memory

# Local imports
from .aps_8idi import key as hdf_key

logger = get_logger(__name__)

# Performance monitoring stub for testing
_perf_monitor = None

# Memory pressure thresholds for connection pool management
MEMORY_PRESSURE_CRITICAL = 0.90  # 90% - Critical memory pressure
MEMORY_PRESSURE_HIGH = 0.80  # 80% - High memory pressure
MEMORY_PRESSURE_MODERATE = 0.70  # 70% - Moderate memory pressure


[docs] class PooledConnection: """Wrapper for pooled HDF5 connections with metadata."""
[docs] def __init__(self, file_handle: h5py.File, file_path: str): self.file_handle = file_handle self.file_path = file_path self.created_at = time.time() self.last_accessed = self.created_at self.access_count = 0 self.is_healthy = True self.lock = threading.RLock()
[docs] def touch(self) -> None: """Update access time and count.""" self.last_accessed = time.time() self.access_count += 1
[docs] def check_health(self) -> bool: """Check if connection is still healthy.""" try: with self.lock: # Try to access a basic property _ = self.file_handle.filename # Check if file still exists if not os.path.exists(self.file_path): self.is_healthy = False return False self.is_healthy = True return True except (ValueError, OSError, AttributeError) as e: logger.debug(f"Connection health check failed for {self.file_path}: {e}") self.is_healthy = False return False
[docs] def close(self) -> None: """Close the connection safely.""" try: with self.lock: if hasattr(self.file_handle, "close"): self.file_handle.close() except Exception as e: logger.debug(f"Error closing connection for {self.file_path}: {e}") finally: self.is_healthy = False
[docs] class ConnectionStats: """Statistics tracker for HDF5 connections."""
[docs] def __init__(self) -> None: self.total_connections_created = 0 self.total_connections_reused = 0 self.total_connections_evicted = 0 self.total_health_checks = 0 self.failed_health_checks = 0 self.cache_hits = 0 self.cache_misses = 0 self.io_time_seconds = 0.0 self.start_time = time.time() self._lock = threading.RLock()
[docs] def record_connection_created(self) -> None: with self._lock: self.total_connections_created += 1
[docs] def record_connection_reused(self) -> None: with self._lock: self.total_connections_reused += 1 self.cache_hits += 1
[docs] def record_connection_evicted(self) -> None: with self._lock: self.total_connections_evicted += 1
[docs] def record_health_check(self, success: bool): with self._lock: self.total_health_checks += 1 if not success: self.failed_health_checks += 1
[docs] def record_cache_miss(self) -> None: with self._lock: self.cache_misses += 1
[docs] def record_io_time(self, duration: float): with self._lock: self.io_time_seconds += duration
[docs] def get_stats(self) -> dict[str, Any]: with self._lock: uptime = time.time() - self.start_time total_operations = self.cache_hits + self.cache_misses return { "uptime_seconds": uptime, "total_connections_created": self.total_connections_created, "total_connections_reused": self.total_connections_reused, "total_connections_evicted": self.total_connections_evicted, "cache_hit_ratio": self.cache_hits / total_operations if total_operations > 0 else 0.0, "cache_hits": self.cache_hits, "cache_misses": self.cache_misses, "health_checks_performed": self.total_health_checks, "health_check_failure_rate": self.failed_health_checks / self.total_health_checks if self.total_health_checks > 0 else 0.0, "total_io_time_seconds": self.io_time_seconds, "average_io_time_ms": (self.io_time_seconds / total_operations * 1000) if total_operations > 0 else 0.0, }
[docs] class HDF5ConnectionPool: """ Enhanced connection pool for HDF5 files with LRU eviction, health monitoring, and comprehensive I/O performance tracking. Features: - LRU eviction policy with configurable pool size - Connection health monitoring with automatic cleanup - Performance statistics and monitoring - Thread-safe operations with fine-grained locking - Memory pressure detection and adaptive sizing - Batch read operations optimization """
[docs] def __init__( self, max_pool_size: int = 20, health_check_interval: float = 300.0, enable_memory_pressure_adaptation: bool = True, ): self.max_pool_size = max_pool_size self.base_pool_size = max_pool_size self.health_check_interval = health_check_interval self.enable_memory_pressure_adaptation = enable_memory_pressure_adaptation # LRU cache for connections self._pool: OrderedDict[str, PooledConnection] = OrderedDict() self._pool_lock = threading.RLock() self._file_locks: dict[str, threading.RLock] = {} self._file_locks_lock = threading.RLock() # Track last-access time per file lock to prune stale entries (BUG-032) self._file_lock_last_access: dict[str, float] = {} self._file_lock_max_age_seconds: float = 3600.0 # 1 hour stale threshold # Statistics tracking self.stats = ConnectionStats() # Health monitoring self._last_health_check = 0.0 self._unhealthy_files: set[str] = set() # Throttle for memory pressure adaptation (avoid psutil syscall on every request) self._last_pressure_adaptation = 0.0 self._pressure_adaptation_interval = 30.0 # seconds # Read-ahead cache for batch operations self._read_cache: dict[str, dict[str, Any]] = {} self._read_cache_lock = threading.RLock() self._read_cache_max_size = 50 # Maximum cached read results per file logger.info( f"HDF5ConnectionPool initialized: max_size={max_pool_size}, health_check_interval={health_check_interval}s" )
def _get_file_lock(self, fname: str) -> threading.RLock: """Get or create a lock for a specific file. Tracks last-access time per entry and prunes stale locks on every new creation to prevent the lock dict from growing without bound under long-running contention. (BUG-032) """ _MAX_FILE_LOCKS = 256 now = time.time() with self._file_locks_lock: if fname not in self._file_locks: self._file_locks[fname] = threading.RLock() self._file_lock_last_access[fname] = now # Prune stale entries first (age-based), then fall back to # size-based pruning if needed. stale_keys = [ k for k, ts in list(self._file_lock_last_access.items()) if k != fname and (now - ts) > self._file_lock_max_age_seconds and self._file_locks.get(k) is not None and self._file_locks[k].acquire(blocking=False) and (self._file_locks[k].release() is None) # type: ignore[func-returns-value] ] for key in stale_keys: self._file_locks.pop(key, None) self._file_lock_last_access.pop(key, None) # Size-based fallback pruning — only remove locks that are # not currently held to avoid corrupting active critical sections. if len(self._file_locks) > _MAX_FILE_LOCKS: keys_to_remove = [] for key, lock in list(self._file_locks.items()): if key == fname: continue # Try non-blocking acquire: if it succeeds the lock # is free and safe to remove. if lock.acquire(blocking=False): lock.release() keys_to_remove.append(key) if ( len(self._file_locks) - len(keys_to_remove) <= _MAX_FILE_LOCKS // 2 ): break for key in keys_to_remove: self._file_locks.pop(key, None) self._file_lock_last_access.pop(key, None) else: # Update last-access time on every hit self._file_lock_last_access[fname] = now return self._file_locks[fname] def _adapt_pool_size_to_memory_pressure(self) -> None: """Dynamically adapt pool size based on system memory pressure. Throttled to run at most once per ``_pressure_adaptation_interval`` seconds to avoid a psutil syscall on every ``get_connection()`` call. """ if not self.enable_memory_pressure_adaptation: return current_time = time.time() if ( current_time - self._last_pressure_adaptation < self._pressure_adaptation_interval ): return self._last_pressure_adaptation = current_time try: memory = _get_virtual_memory() memory_pressure = memory.percent / 100.0 if memory_pressure > MEMORY_PRESSURE_CRITICAL: # Very high memory pressure - reduce pool size significantly new_size = max(3, self.base_pool_size // 4) elif memory_pressure > MEMORY_PRESSURE_HIGH: # High memory pressure - reduce pool size moderately new_size = max(5, self.base_pool_size // 2) elif memory_pressure > MEMORY_PRESSURE_MODERATE: # Moderate memory pressure - slight reduction new_size = max(8, int(self.base_pool_size * 0.75)) else: # Normal memory pressure - use full pool size new_size = self.base_pool_size if new_size != self.max_pool_size: old_size = self.max_pool_size self.max_pool_size = new_size logger.info( f"Adapted pool size from {old_size} to {new_size} due to memory pressure: {memory_pressure:.1%}" ) # If we reduced the pool size, evict excess connections if new_size < len(self._pool): self._evict_excess_connections() except Exception as e: logger.warning(f"Failed to adapt pool size to memory pressure: {e}") def _evict_lru_connections(self, count: int = 1) -> int: """Evict least recently used connections. The pool is an OrderedDict with move_to_end() on access, so the front is always the LRU entry — O(1) eviction. Must be called with ``_pool_lock`` held, or will acquire it. """ evicted = 0 with self._pool_lock: for _ in range(min(count, len(self._pool))): fname, connection = self._pool.popitem(last=False) connection.close() self.stats.record_connection_evicted() logger.debug(f"Evicted LRU connection: {fname}") evicted += 1 return evicted def _evict_excess_connections(self) -> None: """Evict connections when pool exceeds maximum size.""" excess = len(self._pool) - self.max_pool_size if excess > 0: self._evict_lru_connections(excess) def _perform_health_check(self) -> None: """Perform health checks on all connections.""" current_time = time.time() if current_time - self._last_health_check < self.health_check_interval: return self._last_health_check = current_time unhealthy_connections = [] aged_connections = [] with self._pool_lock: logger.debug(f"Performing health check on {len(self._pool)} connections") for fname, connection in list(self._pool.items()): is_healthy = connection.check_health() self.stats.record_health_check(is_healthy) # Check if connection is too old (older than health check interval) connection_age = current_time - connection.created_at is_aged = connection_age > self.health_check_interval if not is_healthy: unhealthy_connections.append(fname) self._unhealthy_files.add(fname) elif is_aged: aged_connections.append(fname) # Remove unhealthy connections for fname in unhealthy_connections: if fname in self._pool: self._pool[fname].close() del self._pool[fname] logger.info(f"Removed unhealthy connection: {fname}") # Remove aged connections for fname in aged_connections: if fname in self._pool: connection_age = current_time - self._pool[fname].created_at self._pool[fname].close() del self._pool[fname] logger.info( f"Removed aged connection: {fname} (age: {connection_age:.1f}s)" ) total_removed = len(unhealthy_connections) + len(aged_connections) if total_removed: logger.info( f"Health check completed: removed {len(unhealthy_connections)} unhealthy and {len(aged_connections)} aged connections" )
[docs] @contextmanager def get_connection(self, fname: str, mode: str = "r"): """ Enhanced context manager to get an HDF5 file connection with comprehensive health monitoring, LRU management, and performance tracking. Lock ordering guarantee (BUG-008): All lock acquisitions follow a strict total order — ``_pool_lock`` (pool-level operations) first. File locks (``_get_file_lock``) are never held at the same time as ``_pool_lock``. Health checks and memory-pressure adaptation that require ``_pool_lock`` are performed **before** acquiring any file lock so the ordering is always: pool operations then yield (no lock held). This eliminates the deadlock where Thread A holds ``_pool_lock`` (via a yield inside the old ``with self._pool_lock:`` block) while Thread B waits for ``_pool_lock`` from inside its own file_lock critical section. Parameters ---------- fname : str Path to HDF5 file mode : str File access mode (default: 'r') Yields ------ h5py.File HDF5 file handle """ start_time = time.time() # Normalize file path for consistency before any lock acquisition. fname = os.path.abspath(fname) # --- Phase 1: pool-level maintenance (only _pool_lock acquired) --- # Perform memory-pressure adaptation and periodic health checks here, # *outside* any file lock, so they can safely acquire _pool_lock # without risk of inversion. self._adapt_pool_size_to_memory_pressure() self._perform_health_check() # --- Phase 2: pool lookup / connection creation (only _pool_lock) --- # Resolve the file handle we will use. The yield happens *after* this # block exits so _pool_lock is never held across the yield. file_handle_to_yield = None is_direct_connection = False connection = None if fname in self._unhealthy_files: # Open a fresh direct connection for previously unhealthy files. logger.debug( f"Using direct connection for previously unhealthy file: {fname}" ) try: file_handle_to_yield = h5py.File(fname, mode) is_direct_connection = True except Exception as e: logger.error(f"Failed to open direct connection for {fname}: {e}") raise else: with self._pool_lock: if fname in self._pool: connection = self._pool[fname] self._pool.move_to_end(fname) if connection.check_health(): connection.touch() self.stats.record_connection_reused() logger.debug(f"Reusing healthy connection for {fname}") file_handle_to_yield = connection.file_handle else: logger.info(f"Removing unhealthy connection: {fname}") connection.close() del self._pool[fname] self._unhealthy_files.add(fname) connection = None if file_handle_to_yield is None: # Need a new connection. self.stats.record_cache_miss() if len(self._pool) >= self.max_pool_size: self._evict_lru_connections(1) try: logger.debug(f"Creating new connection for {fname}") fh = h5py.File(fname, mode) connection = PooledConnection(fh, fname) self._pool[fname] = connection self.stats.record_connection_created() file_handle_to_yield = fh except Exception as e: logger.error(f"Failed to open HDF5 file {fname}: {e}") self._unhealthy_files.add(fname) raise # --- Phase 3: yield the file handle (no lock held) --- if connection is not None: connection.touch() try: yield file_handle_to_yield # On success, clear the unhealthy flag for direct connections. if is_direct_connection: self._unhealthy_files.discard(fname) except Exception: # Mark pooled connection as potentially unhealthy on error. if not is_direct_connection: with self._pool_lock: if fname in self._pool: logger.warning( f"Removing connection due to usage error: {fname}" ) self._pool[fname].close() del self._pool[fname] raise finally: if is_direct_connection and file_handle_to_yield is not None: file_handle_to_yield.close() io_time = time.time() - start_time self.stats.record_io_time(io_time) if io_time > 1.0: logger.debug(f"Slow I/O operation: {fname} took {io_time:.2f}s")
[docs] def clear_pool(self, from_destructor=False): """Close all connections and clear the pool.""" with self._pool_lock: # Avoid logging during Python shutdown to prevent logging errors if not from_destructor: logger.info( f"Clearing connection pool with {len(self._pool)} connections" ) for fname, connection in self._pool.items(): try: connection.close() except Exception as e: if not from_destructor: logger.debug(f"Error closing connection {fname}: {e}") self._pool.clear() with self._file_locks_lock: self._file_locks.clear() with self._read_cache_lock: self._read_cache.clear() self._unhealthy_files.clear() if not from_destructor: logger.info("Connection pool cleared")
[docs] def get_pool_stats(self) -> dict[str, Any]: """Get comprehensive pool statistics.""" with self._pool_lock: pool_info: dict[str, Any] = { "active_connections": len(self._pool), "unhealthy_files": len(self._unhealthy_files), "health_check_interval": self.health_check_interval, "memory_adaptation_enabled": self.enable_memory_pressure_adaptation, "connections": [], } # Add per-connection statistics for fname, connection in self._pool.items(): conn_info = { "file": fname, "age_seconds": time.time() - connection.created_at, "last_accessed_seconds_ago": time.time() - connection.last_accessed, "access_count": connection.access_count, "is_healthy": connection.is_healthy, } pool_info["connections"].append(conn_info) # Merge with I/O statistics io_stats = self.stats.get_stats() return {**pool_info, **io_stats}
[docs] def force_health_check(self): """Force an immediate health check of all connections.""" with self._pool_lock: self._last_health_check = 0 # Force immediate check self._perform_health_check()
[docs] def remove_unhealthy_file(self, fname: str): """Remove a file from the unhealthy files set.""" self._unhealthy_files.discard(fname) logger.debug(f"Removed {fname} from unhealthy files set")
[docs] def batch_read_datasets( self, fname: str, dataset_paths: list[str], use_cache: bool = True ) -> dict[str, Any]: """ Optimized batch reading of multiple datasets from the same file. Parameters ---------- fname : str HDF5 file path dataset_paths : List[str] List of dataset paths to read use_cache : bool Whether to use read cache Returns ------- Dict[str, Any] Dictionary mapping dataset paths to their values """ start_time = time.time() results = {} cache_key = f"batch_{hash(tuple(sorted(dataset_paths)))}" # Check read cache first if use_cache: with self._read_cache_lock: if fname in self._read_cache and cache_key in self._read_cache[fname]: cached_result = self._read_cache[fname][cache_key] # Check if cache entry is recent (within 5 minutes) if time.time() - cached_result["timestamp"] < CACHE_ENTRY_TIMEOUT: logger.debug( f"Using cached batch read for {len(dataset_paths)} datasets from {fname}" ) return cached_result["data"] # Perform batch read with self.get_connection(fname, "r") as f: for path in dataset_paths: try: if path in f: val = f[path][()] # Handle common data type conversions if isinstance(val, (np.bytes_, bytes)): val = val.decode() elif isinstance(val, np.ndarray) and val.shape == (1, 1): val = val.item() results[path] = val else: logger.warning(f"Dataset path not found: {path} in {fname}") results[path] = None except Exception as e: logger.error(f"Error reading dataset {path} from {fname}: {e}") results[path] = None # Cache the results if use_cache: with self._read_cache_lock: if fname not in self._read_cache: self._read_cache[fname] = {} # Limit cache size per file if len(self._read_cache[fname]) >= self._read_cache_max_size: # Remove oldest entries sorted_entries = sorted( self._read_cache[fname].items(), key=lambda x: x[1]["timestamp"] ) for old_key, _ in sorted_entries[: len(sorted_entries) // 2]: del self._read_cache[fname][old_key] self._read_cache[fname][cache_key] = { "data": results, "timestamp": time.time(), } read_time = time.time() - start_time logger.debug( f"Batch read of {len(dataset_paths)} datasets from {fname} completed in {read_time:.3f}s" ) return results
[docs] def clear_read_cache(self, fname: str | None = None): """Clear read cache for specific file or all files.""" with self._read_cache_lock: if fname: if fname in self._read_cache: del self._read_cache[fname] logger.debug(f"Cleared read cache for {fname}") else: self._read_cache.clear() logger.debug("Cleared all read cache")
def __del__(self): # Pass True to indicate this is called from destructor during shutdown with suppress(Exception): self.clear_pool(from_destructor=True)
# Global connection pool instance with enhanced settings _connection_pool = HDF5ConnectionPool( max_pool_size=25, # Increased pool size for better performance health_check_interval=180.0, # Health check every 3 minutes enable_memory_pressure_adaptation=True, )
[docs] def put(save_path, result, ftype="nexus", mode="raw"): """ save the result to hdf5 file Parameters ---------- save_path: str path to save the result result: dict dictionary to save ftype: str file type, 'nexus' or 'aps_8idi' mode: str 'raw' or 'alias' """ with h5py.File(save_path, "a") as f: for key, value in result.items(): dest_key = hdf_key[ftype][key] if mode == "alias" else key if dest_key in f: del f[dest_key] dest_value = ( np.reshape(value, (1, -1)) if isinstance(value, np.ndarray) and value.ndim == 1 else value ) f[dest_key] = dest_value return
[docs] def get_abs_cs_scale(fname, ftype="nexus", use_pool=True): key = hdf_key[ftype]["abs_cross_section_scale"] context_manager = ( _connection_pool.get_connection(fname, "r") if use_pool else h5py.File(fname, "r") ) with context_manager as f: if key not in f: return None return float(f[key][()])
[docs] @log_timing(threshold_ms=500) def read_metadata_to_dict(file_path, use_pool=True): """ Reads an HDF5 file and loads its contents into a nested dictionary. Optimized to read all metadata groups in one file open operation. Parameters ---------- file_path : str Path to the HDF5 file. use_pool : bool Whether to use connection pool for optimization. Returns ------- dict A nested dictionary containing datasets as NumPy arrays. """ def recursive_read(h5_group, target_dict): """Recursively reads groups and datasets into the dictionary.""" for key in h5_group: obj = h5_group[key] if isinstance(obj, h5py.Dataset): val = obj[()] if type(val) in [np.bytes_, bytes]: val = val.decode() # Check for units attribute and append to key name units = obj.attrs.get("units", None) if units is not None: if isinstance(units, (np.bytes_, bytes)): units = units.decode() target_key = f"{key} ({units})" else: target_key = key target_dict[target_key] = val elif isinstance(obj, h5py.Group): target_dict[key] = {} recursive_read(obj, target_dict[key]) data_dict = {} groups = [ "/entry/instrument", "/xpcs/multitau/config", "/xpcs/twotime/config", "/entry/sample", "/entry/user", ] # Use connection pool for single file operation reading all metadata groups context_manager = ( _connection_pool.get_connection(file_path, "r") if use_pool else h5py.File(file_path, "r") ) with context_manager as hdf_file: # Batch read all groups in single file operation for group in groups: if group in hdf_file: data_dict[group] = {} recursive_read(hdf_file[group], data_dict[group]) return data_dict
[docs] @log_timing(threshold_ms=200) def get(fname, fields, mode="raw", ret_type="dict", ftype="nexus", use_pool=True): """ get the values for the various keys listed in fields for a single file; :param fname: :param fields_raw: list of keys [key1, key2, ..., ] :param mode: ['raw' | 'alias']; alias is defined in .hdf_key otherwise the raw hdf key will be used :param ret_type: return dictonary if 'dict', list if it is 'list' :param use_pool: whether to use connection pool for optimization :return: dictionary or dictionary; """ assert mode in ["raw", "alias"], "mode not supported" assert ret_type in ["dict", "list"], "ret_type not supported" ret = {} # Choose context manager based on use_pool parameter if use_pool: context_manager = _connection_pool.get_connection(fname, "r") else: context_manager = h5py.File(fname, "r") with context_manager as hdf_result: # Batch read all fields in a single file operation for key in fields: path = hdf_key[ftype][key] if mode == "alias" else key if path not in hdf_result: logger.error("path to field not found: %s", path) raise ValueError(f"key not found: {key}:{path}") val = hdf_result.get(path)[()] if ( key in ["saxs_2d"] and val.ndim == NDIM_3D ): # saxs_2d is in [1xNxM] format val = val[0] # converts bytes to unicode; if type(val) in [np.bytes_, bytes]: val = val.decode() if isinstance(val, np.ndarray) and val.shape == (1, 1): val = val.item() ret[key] = val if ret_type == "dict": return ret if ret_type == "list": return [ret[key] for key in fields] return None
[docs] @log_timing(threshold_ms=100) def get_analysis_type(fname, ftype="nexus", use_pool=True): """ determine the analysis type of the file Parameters ---------- fname: str file name ftype: str file type, 'nexus' or 'legacy' use_pool: bool whether to use connection pool for optimization Returns ------- tuple analysis type, 'Twotime' or 'Multitau', or both """ c2_prefix = hdf_key[ftype]["c2_prefix"] g2_prefix = hdf_key[ftype]["g2"] analysis_type = [] context_manager = ( _connection_pool.get_connection(fname, "r") if use_pool else h5py.File(fname, "r") ) with context_manager as hdf_result: # Primary detection: Check for folder presence (more reliable) multitau_folder = "/xpcs/multitau" twotime_folder = "/xpcs/twotime" if multitau_folder in hdf_result: analysis_type.append("Multitau") logger.debug( f"Detected Multitau format: folder {multitau_folder} found in {fname}" ) if twotime_folder in hdf_result: analysis_type.append("Twotime") logger.debug( f"Detected Twotime format: folder {twotime_folder} found in {fname}" ) # Fallback detection: Check for specific files (backward compatibility) if not analysis_type: logger.debug( f"Folder detection failed, falling back to file detection for {fname}" ) if c2_prefix in hdf_result: analysis_type.append("Twotime") logger.debug( f"Detected Twotime format via file: {c2_prefix} found in {fname}" ) if g2_prefix in hdf_result: analysis_type.append("Multitau") logger.debug( f"Detected Multitau format via file: {g2_prefix} found in {fname}" ) if len(analysis_type) == 0: raise ValueError(f"No analysis type found in {fname}") logger.info(f"File {fname}: detected format(s) {analysis_type}") return tuple(analysis_type)
[docs] @log_timing(threshold_ms=300) def batch_read_fields( fname: str, fields: list[str], mode: str = "raw", ftype: str = "nexus", use_pool: bool = True, ) -> dict[str, Any]: """ Optimized batch reading of multiple fields from HDF5 file. Parameters ---------- fname : str HDF5 file path fields : List[str] List of field names to read mode : str 'raw' or 'alias' mode ftype : str File type ('nexus' or 'legacy') use_pool : bool Whether to use connection pooling Returns ------- Dict[str, Any] Dictionary of field values """ if not use_pool: return get(fname, fields, mode, "dict", ftype, use_pool=False) # Convert fields to dataset paths dataset_paths = [] for field in fields: path = hdf_key[ftype][field] if mode == "alias" else field dataset_paths.append(path) # Use batch read from connection pool raw_results = _connection_pool.batch_read_datasets(fname, dataset_paths) # Process results (same logic as get() function) processed_results = {} for i, field in enumerate(fields): path = dataset_paths[i] val = raw_results.get(path) if val is None: logger.error(f"Path to field not found: {path} for field {field}") raise ValueError(f"Key not found: {field}:{path}") # Apply same processing as original get() function if field == "saxs_2d" and isinstance(val, np.ndarray) and val.ndim == NDIM_3D: val = val[0] # saxs_2d is in [1xNxM] format # converts bytes to unicode (matches get() behavior) if type(val) in [np.bytes_, bytes]: val = val.decode() if isinstance(val, np.ndarray) and val.shape == (1, 1): val = val.item() processed_results[field] = val return processed_results
[docs] @log_timing(threshold_ms=500) def get_file_info(fname: str, use_pool: bool = True) -> dict[str, Any]: """ Get basic file information and statistics. Parameters ---------- fname : str HDF5 file path use_pool : bool Whether to use connection pooling Returns ------- Dict[str, Any] File information dictionary """ context_manager = ( _connection_pool.get_connection(fname, "r") if use_pool else h5py.File(fname, "r") ) with context_manager as f: info = { "file_path": fname, "file_size_mb": os.path.getsize(fname) / (1024 * 1024), "hdf5_version": f.libver, "root_groups": list(f.keys()), "estimated_datasets": 0, "large_datasets": [], } def count_datasets(name, obj): if isinstance(obj, h5py.Dataset): info["estimated_datasets"] += 1 # Track large datasets (>10MB estimated) estimated_size = obj.size * obj.dtype.itemsize / (1024 * 1024) if estimated_size > LARGE_DATASET_THRESHOLD_MB: info["large_datasets"].append( { "path": name, "shape": obj.shape, "dtype": str(obj.dtype), "estimated_size_mb": estimated_size, } ) f.visititems(count_datasets) return info
[docs] def get_connection_pool_stats() -> dict[str, Any]: """ Get comprehensive statistics about the global connection pool. Returns ------- Dict[str, Any] Connection pool statistics """ return _connection_pool.get_pool_stats()
[docs] def clear_connection_pool(): """ Clear all connections in the global connection pool. """ _connection_pool.clear_pool()
[docs] def force_connection_health_check(): """ Force an immediate health check of all pooled connections. """ _connection_pool.force_health_check()
[docs] @log_timing(threshold_ms=1000) def get_chunked_dataset( fname: str, dataset_path: str, chunk_size: tuple[int, ...] | None = None, use_pool: bool = True, ) -> np.ndarray: """ Read a large dataset in chunks to manage memory usage. Parameters ---------- fname : str HDF5 file path dataset_path : str Path to dataset within HDF5 file chunk_size : Tuple[int, ...], optional Size of chunks to read. If None, will use dataset's native chunking use_pool : bool Whether to use connection pooling Returns ------- np.ndarray The dataset array """ context_manager = ( _connection_pool.get_connection(fname, "r") if use_pool else h5py.File(fname, "r") ) with context_manager as f: if dataset_path not in f: raise ValueError(f"Dataset {dataset_path} not found in {fname}") dataset = f[dataset_path] # For small datasets, read directly estimated_size_mb = dataset.size * dataset.dtype.itemsize / (1024 * 1024) if estimated_size_mb < DIRECT_READ_LIMIT_MB: # Less than 100MB, read directly logger.debug( f"Reading small dataset {dataset_path} directly ({estimated_size_mb:.1f}MB)" ) return dataset[()] logger.info( f"Reading large dataset {dataset_path} in chunks ({estimated_size_mb:.1f}MB)" ) # Use native chunking if available and no custom chunk_size specified if chunk_size is None and dataset.chunks is not None: chunk_size = dataset.chunks logger.debug(f"Using native chunking: {chunk_size}") elif chunk_size is None: # Default chunking strategy for 2D data if len(dataset.shape) == NDIM_2D: # Aim for ~10MB chunks target_elements = 10 * 1024 * 1024 // dataset.dtype.itemsize chunk_rows = min( dataset.shape[0], max(1, int(np.sqrt(target_elements))) ) chunk_cols = min(dataset.shape[1], target_elements // chunk_rows) chunk_size = (chunk_rows, chunk_cols) else: # For other dimensions, read in smaller slices along first dimension chunk_size = (min(1000, dataset.shape[0]), *dataset.shape[1:]) logger.debug(f"Using computed chunking: {chunk_size}") # BUG-056: Actually use the computed chunking strategy. # Read the dataset in chunks along the first axis using the computed chunk_size. if chunk_size is None: # Fallback: no chunking strategy available; read directly return dataset[()] shape = dataset.shape result = np.empty(shape, dtype=dataset.dtype) row_chunk = chunk_size[0] for start in range(0, shape[0], row_chunk): end = min(start + row_chunk, shape[0]) result[start:end] = dataset[start:end] return result
if __name__ == "__main__": pass