"""HDF5 file reading utilities with connection pooling.
Provides optimized HDF5 file reading with connection pooling, batch operations,
and automatic resource management for efficient data access.
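Classes:
PooledConnection: Pooled HDF5 file handle with access and health metadata
ConnectionStats: Thread-safe counters for pool and I/O statistics
HDF5ConnectionPool: Thread-safe LRU connection pool for HDF5 files with batch reads
The module-level reader functions (get, batch_read_fields, get_chunked_dataset, ...)
read through a shared HDF5ConnectionPool instance when use_pool=True.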
"""
# Standard library imports
import os
import threading
import time
from collections import OrderedDict
from contextlib import contextmanager, suppress
from typing import Any
# Third-party imports
import h5py
import numpy as np
from xpcsviewer.constants import (
CACHE_ENTRY_TIMEOUT,
DIRECT_READ_LIMIT_MB,
LARGE_DATASET_THRESHOLD_MB,
NDIM_2D,
NDIM_3D,
)
from xpcsviewer.utils.log_utils import log_timing
from xpcsviewer.utils.logging_config import get_logger
from xpcsviewer.xpcs_file.memory import _get_virtual_memory
# Local imports
from .aps_8idi import key as hdf_key
logger = get_logger(__name__)

# Optional performance-monitor hook. Nothing in this module reads it; it is
# presumably kept as a module attribute so tests can patch it — TODO confirm.
_perf_monitor = None

# Fractions of system memory in use (0.0-1.0) above which
# HDF5ConnectionPool shrinks its pool, listed from least to most severe.
MEMORY_PRESSURE_MODERATE = 0.70
MEMORY_PRESSURE_HIGH = 0.80
MEMORY_PRESSURE_CRITICAL = 0.90
[docs]
class PooledConnection:
    """Wrapper for a pooled HDF5 file handle plus bookkeeping metadata.

    Attributes
    ----------
    file_handle : h5py.File
        The open handle shared by every borrower of this connection.
    file_path : str
        Path the handle was opened from; re-checked for existence by
        :meth:`check_health`.
    created_at, last_accessed : float
        ``time.time()`` timestamps of creation and of the latest :meth:`touch`.
    access_count : int
        Number of :meth:`touch` calls.
    is_healthy : bool
        Result of the latest :meth:`check_health`; cleared by :meth:`close`.
    lock : threading.RLock
        Guards the handle and the mutable metadata above.
    """

    def __init__(self, file_handle: "h5py.File", file_path: str):
        self.file_handle = file_handle
        self.file_path = file_path
        self.created_at = time.time()
        self.last_accessed = self.created_at
        self.access_count = 0
        self.is_healthy = True
        self.lock = threading.RLock()

    def touch(self) -> None:
        """Update access time and count.

        Done under ``self.lock``: ``access_count += 1`` is a read-modify-write
        that can lose updates when several threads share this connection.
        """
        with self.lock:
            self.last_accessed = time.time()
            self.access_count += 1

    def check_health(self) -> bool:
        """Return True if the handle is open and the file still exists.

        Sets :attr:`is_healthy` to the result. Errors raised while probing the
        handle are logged at debug level and reported as unhealthy rather than
        propagated.
        """
        try:
            with self.lock:
                # h5py objects evaluate falsy once closed (their id is no longer
                # valid) — check that explicitly instead of relying on an
                # attribute access to raise.
                if not self.file_handle:
                    self.is_healthy = False
                    return False
                # Touch a basic property to surface a broken handle.
                _ = self.file_handle.filename
                # The file may have been deleted or moved since it was opened.
                if not os.path.exists(self.file_path):
                    self.is_healthy = False
                    return False
                self.is_healthy = True
                return True
        except (ValueError, OSError, AttributeError, RuntimeError) as e:
            logger.debug(f"Connection health check failed for {self.file_path}: {e}")
            self.is_healthy = False
            return False

    def close(self) -> None:
        """Close the handle; errors are logged and swallowed.

        :attr:`is_healthy` is always cleared, even if closing fails.
        """
        try:
            with self.lock:
                if hasattr(self.file_handle, "close"):
                    self.file_handle.close()
        except Exception as e:
            logger.debug(f"Error closing connection for {self.file_path}: {e}")
        finally:
            self.is_healthy = False
[docs]
class ConnectionStats:
    """Thread-safe counters describing connection-pool and I/O activity.

    Every recorder and :meth:`get_stats` run under one re-entrant lock.
    """

    def __init__(self) -> None:
        self._lock = threading.RLock()
        # Connection lifecycle.
        self.total_connections_created = 0
        self.total_connections_reused = 0
        self.total_connections_evicted = 0
        # Health checks performed / failed.
        self.total_health_checks = 0
        self.failed_health_checks = 0
        # Lookup outcomes as recorded by the pool.
        self.cache_hits = 0
        self.cache_misses = 0
        # Seconds accumulated through record_io_time().
        self.io_time_seconds = 0.0
        self.start_time = time.time()

    def record_connection_created(self) -> None:
        """Count one newly opened connection."""
        with self._lock:
            self.total_connections_created += 1

    def record_connection_reused(self) -> None:
        """Count one reused connection; a reuse also counts as a cache hit."""
        with self._lock:
            self.cache_hits += 1
            self.total_connections_reused += 1

    def record_connection_evicted(self) -> None:
        """Count one evicted connection."""
        with self._lock:
            self.total_connections_evicted += 1

    def record_health_check(self, success: bool):
        """Count one health check, plus a failure when ``success`` is false."""
        with self._lock:
            self.total_health_checks += 1
            if not success:
                self.failed_health_checks += 1

    def record_cache_miss(self) -> None:
        """Count one cache miss."""
        with self._lock:
            self.cache_misses += 1

    def record_io_time(self, duration: float):
        """Add ``duration`` seconds to the accumulated I/O time."""
        with self._lock:
            self.io_time_seconds += duration

    def get_stats(self) -> dict[str, Any]:
        """Return a snapshot of all counters plus derived ratios.

        Each ratio/average is 0.0 when its denominator is zero.
        """

        def ratio(numerator: float, denominator: float) -> float:
            return numerator / denominator if denominator > 0 else 0.0

        with self._lock:
            lookups = self.cache_hits + self.cache_misses
            snapshot = {
                "uptime_seconds": time.time() - self.start_time,
                "total_connections_created": self.total_connections_created,
                "total_connections_reused": self.total_connections_reused,
                "total_connections_evicted": self.total_connections_evicted,
                "cache_hit_ratio": ratio(self.cache_hits, lookups),
                "cache_hits": self.cache_hits,
                "cache_misses": self.cache_misses,
                "health_checks_performed": self.total_health_checks,
                "health_check_failure_rate": ratio(
                    self.failed_health_checks, self.total_health_checks
                ),
                "total_io_time_seconds": self.io_time_seconds,
                "average_io_time_ms": ratio(self.io_time_seconds, lookups) * 1000,
            }
        return snapshot
[docs]
class HDF5ConnectionPool:
"""
Enhanced connection pool for HDF5 files with LRU eviction, health monitoring,
and comprehensive I/O performance tracking.
Features:
- LRU eviction policy with configurable pool size
- Connection health monitoring with automatic cleanup
- Performance statistics and monitoring
- Thread-safe operations with fine-grained locking
- Memory pressure detection and adaptive sizing
- Batch read operations optimization
"""
[docs]
def __init__(
self,
max_pool_size: int = 20,
health_check_interval: float = 300.0,
enable_memory_pressure_adaptation: bool = True,
):
self.max_pool_size = max_pool_size
self.base_pool_size = max_pool_size
self.health_check_interval = health_check_interval
self.enable_memory_pressure_adaptation = enable_memory_pressure_adaptation
# LRU cache for connections
self._pool: OrderedDict[str, PooledConnection] = OrderedDict()
self._pool_lock = threading.RLock()
self._file_locks: dict[str, threading.RLock] = {}
self._file_locks_lock = threading.RLock()
# Track last-access time per file lock to prune stale entries (BUG-032)
self._file_lock_last_access: dict[str, float] = {}
self._file_lock_max_age_seconds: float = 3600.0 # 1 hour stale threshold
# Statistics tracking
self.stats = ConnectionStats()
# Health monitoring
self._last_health_check = 0.0
self._unhealthy_files: set[str] = set()
# Throttle for memory pressure adaptation (avoid psutil syscall on every request)
self._last_pressure_adaptation = 0.0
self._pressure_adaptation_interval = 30.0 # seconds
# Read-ahead cache for batch operations
self._read_cache: dict[str, dict[str, Any]] = {}
self._read_cache_lock = threading.RLock()
self._read_cache_max_size = 50 # Maximum cached read results per file
logger.info(
f"HDF5ConnectionPool initialized: max_size={max_pool_size}, health_check_interval={health_check_interval}s"
)
def _get_file_lock(self, fname: str) -> threading.RLock:
"""Get or create a lock for a specific file.
Tracks last-access time per entry and prunes stale locks on every
new creation to prevent the lock dict from growing without bound
under long-running contention. (BUG-032)
"""
_MAX_FILE_LOCKS = 256
now = time.time()
with self._file_locks_lock:
if fname not in self._file_locks:
self._file_locks[fname] = threading.RLock()
self._file_lock_last_access[fname] = now
# Prune stale entries first (age-based), then fall back to
# size-based pruning if needed.
stale_keys = [
k
for k, ts in list(self._file_lock_last_access.items())
if k != fname
and (now - ts) > self._file_lock_max_age_seconds
and self._file_locks.get(k) is not None
and self._file_locks[k].acquire(blocking=False)
and (self._file_locks[k].release() is None) # type: ignore[func-returns-value]
]
for key in stale_keys:
self._file_locks.pop(key, None)
self._file_lock_last_access.pop(key, None)
# Size-based fallback pruning — only remove locks that are
# not currently held to avoid corrupting active critical sections.
if len(self._file_locks) > _MAX_FILE_LOCKS:
keys_to_remove = []
for key, lock in list(self._file_locks.items()):
if key == fname:
continue
# Try non-blocking acquire: if it succeeds the lock
# is free and safe to remove.
if lock.acquire(blocking=False):
lock.release()
keys_to_remove.append(key)
if (
len(self._file_locks) - len(keys_to_remove)
<= _MAX_FILE_LOCKS // 2
):
break
for key in keys_to_remove:
self._file_locks.pop(key, None)
self._file_lock_last_access.pop(key, None)
else:
# Update last-access time on every hit
self._file_lock_last_access[fname] = now
return self._file_locks[fname]
def _adapt_pool_size_to_memory_pressure(self) -> None:
"""Dynamically adapt pool size based on system memory pressure.
Throttled to run at most once per ``_pressure_adaptation_interval``
seconds to avoid a psutil syscall on every ``get_connection()`` call.
"""
if not self.enable_memory_pressure_adaptation:
return
current_time = time.time()
if (
current_time - self._last_pressure_adaptation
< self._pressure_adaptation_interval
):
return
self._last_pressure_adaptation = current_time
try:
memory = _get_virtual_memory()
memory_pressure = memory.percent / 100.0
if memory_pressure > MEMORY_PRESSURE_CRITICAL:
# Very high memory pressure - reduce pool size significantly
new_size = max(3, self.base_pool_size // 4)
elif memory_pressure > MEMORY_PRESSURE_HIGH:
# High memory pressure - reduce pool size moderately
new_size = max(5, self.base_pool_size // 2)
elif memory_pressure > MEMORY_PRESSURE_MODERATE:
# Moderate memory pressure - slight reduction
new_size = max(8, int(self.base_pool_size * 0.75))
else:
# Normal memory pressure - use full pool size
new_size = self.base_pool_size
if new_size != self.max_pool_size:
old_size = self.max_pool_size
self.max_pool_size = new_size
logger.info(
f"Adapted pool size from {old_size} to {new_size} due to memory pressure: {memory_pressure:.1%}"
)
# If we reduced the pool size, evict excess connections
if new_size < len(self._pool):
self._evict_excess_connections()
except Exception as e:
logger.warning(f"Failed to adapt pool size to memory pressure: {e}")
def _evict_lru_connections(self, count: int = 1) -> int:
"""Evict least recently used connections.
The pool is an OrderedDict with move_to_end() on access,
so the front is always the LRU entry — O(1) eviction.
Must be called with ``_pool_lock`` held, or will acquire it.
"""
evicted = 0
with self._pool_lock:
for _ in range(min(count, len(self._pool))):
fname, connection = self._pool.popitem(last=False)
connection.close()
self.stats.record_connection_evicted()
logger.debug(f"Evicted LRU connection: {fname}")
evicted += 1
return evicted
def _evict_excess_connections(self) -> None:
"""Evict connections when pool exceeds maximum size."""
excess = len(self._pool) - self.max_pool_size
if excess > 0:
self._evict_lru_connections(excess)
def _perform_health_check(self) -> None:
"""Perform health checks on all connections."""
current_time = time.time()
if current_time - self._last_health_check < self.health_check_interval:
return
self._last_health_check = current_time
unhealthy_connections = []
aged_connections = []
with self._pool_lock:
logger.debug(f"Performing health check on {len(self._pool)} connections")
for fname, connection in list(self._pool.items()):
is_healthy = connection.check_health()
self.stats.record_health_check(is_healthy)
# Check if connection is too old (older than health check interval)
connection_age = current_time - connection.created_at
is_aged = connection_age > self.health_check_interval
if not is_healthy:
unhealthy_connections.append(fname)
self._unhealthy_files.add(fname)
elif is_aged:
aged_connections.append(fname)
# Remove unhealthy connections
for fname in unhealthy_connections:
if fname in self._pool:
self._pool[fname].close()
del self._pool[fname]
logger.info(f"Removed unhealthy connection: {fname}")
# Remove aged connections
for fname in aged_connections:
if fname in self._pool:
connection_age = current_time - self._pool[fname].created_at
self._pool[fname].close()
del self._pool[fname]
logger.info(
f"Removed aged connection: {fname} (age: {connection_age:.1f}s)"
)
total_removed = len(unhealthy_connections) + len(aged_connections)
if total_removed:
logger.info(
f"Health check completed: removed {len(unhealthy_connections)} unhealthy and {len(aged_connections)} aged connections"
)
[docs]
@contextmanager
def get_connection(self, fname: str, mode: str = "r"):
"""
Enhanced context manager to get an HDF5 file connection with comprehensive
health monitoring, LRU management, and performance tracking.
Lock ordering guarantee (BUG-008):
All lock acquisitions follow a strict total order —
``_pool_lock`` (pool-level operations) first.
File locks (``_get_file_lock``) are never held at the same time as
``_pool_lock``. Health checks and memory-pressure adaptation that
require ``_pool_lock`` are performed **before** acquiring any file
lock so the ordering is always: pool operations then yield (no lock
held). This eliminates the deadlock where Thread A holds
``_pool_lock`` (via a yield inside the old ``with self._pool_lock:``
block) while Thread B waits for ``_pool_lock`` from inside its own
file_lock critical section.
Parameters
----------
fname : str
Path to HDF5 file
mode : str
File access mode (default: 'r')
Yields
------
h5py.File
HDF5 file handle
"""
start_time = time.time()
# Normalize file path for consistency before any lock acquisition.
fname = os.path.abspath(fname)
# --- Phase 1: pool-level maintenance (only _pool_lock acquired) ---
# Perform memory-pressure adaptation and periodic health checks here,
# *outside* any file lock, so they can safely acquire _pool_lock
# without risk of inversion.
self._adapt_pool_size_to_memory_pressure()
self._perform_health_check()
# --- Phase 2: pool lookup / connection creation (only _pool_lock) ---
# Resolve the file handle we will use. The yield happens *after* this
# block exits so _pool_lock is never held across the yield.
file_handle_to_yield = None
is_direct_connection = False
connection = None
if fname in self._unhealthy_files:
# Open a fresh direct connection for previously unhealthy files.
logger.debug(
f"Using direct connection for previously unhealthy file: {fname}"
)
try:
file_handle_to_yield = h5py.File(fname, mode)
is_direct_connection = True
except Exception as e:
logger.error(f"Failed to open direct connection for {fname}: {e}")
raise
else:
with self._pool_lock:
if fname in self._pool:
connection = self._pool[fname]
self._pool.move_to_end(fname)
if connection.check_health():
connection.touch()
self.stats.record_connection_reused()
logger.debug(f"Reusing healthy connection for {fname}")
file_handle_to_yield = connection.file_handle
else:
logger.info(f"Removing unhealthy connection: {fname}")
connection.close()
del self._pool[fname]
self._unhealthy_files.add(fname)
connection = None
if file_handle_to_yield is None:
# Need a new connection.
self.stats.record_cache_miss()
if len(self._pool) >= self.max_pool_size:
self._evict_lru_connections(1)
try:
logger.debug(f"Creating new connection for {fname}")
fh = h5py.File(fname, mode)
connection = PooledConnection(fh, fname)
self._pool[fname] = connection
self.stats.record_connection_created()
file_handle_to_yield = fh
except Exception as e:
logger.error(f"Failed to open HDF5 file {fname}: {e}")
self._unhealthy_files.add(fname)
raise
# --- Phase 3: yield the file handle (no lock held) ---
if connection is not None:
connection.touch()
try:
yield file_handle_to_yield
# On success, clear the unhealthy flag for direct connections.
if is_direct_connection:
self._unhealthy_files.discard(fname)
except Exception:
# Mark pooled connection as potentially unhealthy on error.
if not is_direct_connection:
with self._pool_lock:
if fname in self._pool:
logger.warning(
f"Removing connection due to usage error: {fname}"
)
self._pool[fname].close()
del self._pool[fname]
raise
finally:
if is_direct_connection and file_handle_to_yield is not None:
file_handle_to_yield.close()
io_time = time.time() - start_time
self.stats.record_io_time(io_time)
if io_time > 1.0:
logger.debug(f"Slow I/O operation: {fname} took {io_time:.2f}s")
[docs]
def clear_pool(self, from_destructor=False):
"""Close all connections and clear the pool."""
with self._pool_lock:
# Avoid logging during Python shutdown to prevent logging errors
if not from_destructor:
logger.info(
f"Clearing connection pool with {len(self._pool)} connections"
)
for fname, connection in self._pool.items():
try:
connection.close()
except Exception as e:
if not from_destructor:
logger.debug(f"Error closing connection {fname}: {e}")
self._pool.clear()
with self._file_locks_lock:
self._file_locks.clear()
with self._read_cache_lock:
self._read_cache.clear()
self._unhealthy_files.clear()
if not from_destructor:
logger.info("Connection pool cleared")
[docs]
def get_pool_stats(self) -> dict[str, Any]:
"""Get comprehensive pool statistics."""
with self._pool_lock:
pool_info: dict[str, Any] = {
"active_connections": len(self._pool),
"unhealthy_files": len(self._unhealthy_files),
"health_check_interval": self.health_check_interval,
"memory_adaptation_enabled": self.enable_memory_pressure_adaptation,
"connections": [],
}
# Add per-connection statistics
for fname, connection in self._pool.items():
conn_info = {
"file": fname,
"age_seconds": time.time() - connection.created_at,
"last_accessed_seconds_ago": time.time() - connection.last_accessed,
"access_count": connection.access_count,
"is_healthy": connection.is_healthy,
}
pool_info["connections"].append(conn_info)
# Merge with I/O statistics
io_stats = self.stats.get_stats()
return {**pool_info, **io_stats}
[docs]
def force_health_check(self):
"""Force an immediate health check of all connections."""
with self._pool_lock:
self._last_health_check = 0 # Force immediate check
self._perform_health_check()
[docs]
def remove_unhealthy_file(self, fname: str):
"""Remove a file from the unhealthy files set."""
self._unhealthy_files.discard(fname)
logger.debug(f"Removed {fname} from unhealthy files set")
[docs]
def batch_read_datasets(
self, fname: str, dataset_paths: list[str], use_cache: bool = True
) -> dict[str, Any]:
"""
Optimized batch reading of multiple datasets from the same file.
Parameters
----------
fname : str
HDF5 file path
dataset_paths : List[str]
List of dataset paths to read
use_cache : bool
Whether to use read cache
Returns
-------
Dict[str, Any]
Dictionary mapping dataset paths to their values
"""
start_time = time.time()
results = {}
cache_key = f"batch_{hash(tuple(sorted(dataset_paths)))}"
# Check read cache first
if use_cache:
with self._read_cache_lock:
if fname in self._read_cache and cache_key in self._read_cache[fname]:
cached_result = self._read_cache[fname][cache_key]
# Check if cache entry is recent (within 5 minutes)
if time.time() - cached_result["timestamp"] < CACHE_ENTRY_TIMEOUT:
logger.debug(
f"Using cached batch read for {len(dataset_paths)} datasets from {fname}"
)
return cached_result["data"]
# Perform batch read
with self.get_connection(fname, "r") as f:
for path in dataset_paths:
try:
if path in f:
val = f[path][()]
# Handle common data type conversions
if isinstance(val, (np.bytes_, bytes)):
val = val.decode()
elif isinstance(val, np.ndarray) and val.shape == (1, 1):
val = val.item()
results[path] = val
else:
logger.warning(f"Dataset path not found: {path} in {fname}")
results[path] = None
except Exception as e:
logger.error(f"Error reading dataset {path} from {fname}: {e}")
results[path] = None
# Cache the results
if use_cache:
with self._read_cache_lock:
if fname not in self._read_cache:
self._read_cache[fname] = {}
# Limit cache size per file
if len(self._read_cache[fname]) >= self._read_cache_max_size:
# Remove oldest entries
sorted_entries = sorted(
self._read_cache[fname].items(), key=lambda x: x[1]["timestamp"]
)
for old_key, _ in sorted_entries[: len(sorted_entries) // 2]:
del self._read_cache[fname][old_key]
self._read_cache[fname][cache_key] = {
"data": results,
"timestamp": time.time(),
}
read_time = time.time() - start_time
logger.debug(
f"Batch read of {len(dataset_paths)} datasets from {fname} completed in {read_time:.3f}s"
)
return results
[docs]
def clear_read_cache(self, fname: str | None = None):
"""Clear read cache for specific file or all files."""
with self._read_cache_lock:
if fname:
if fname in self._read_cache:
del self._read_cache[fname]
logger.debug(f"Cleared read cache for {fname}")
else:
self._read_cache.clear()
logger.debug("Cleared all read cache")
def __del__(self):
# Pass True to indicate this is called from destructor during shutdown
with suppress(Exception):
self.clear_pool(from_destructor=True)
# Module-wide shared pool used by the reader functions below when
# ``use_pool=True``: up to 25 cached handles, a health sweep at most every
# 180 s, and automatic shrinking under system memory pressure.
_connection_pool = HDF5ConnectionPool(
    enable_memory_pressure_adaptation=True,
    health_check_interval=180.0,
    max_pool_size=25,
)
[docs]
def put(save_path, result, ftype="nexus", mode="raw"):
    """
    Write every entry of ``result`` into an HDF5 file, replacing existing ones.

    Parameters
    ----------
    save_path: str
        path of the HDF5 file; opened in append mode ('a')
    result: dict
        mapping of key -> value to store
    ftype: str
        file type used to resolve alias keys, 'nexus' or 'aps_8idi'
    mode: str
        'raw' to use each key as the HDF5 path, 'alias' to translate it
        through ``hdf_key[ftype]``

    Notes
    -----
    1-D numpy arrays are stored as row vectors of shape (1, N).
    """
    with h5py.File(save_path, "a") as h5file:
        for name, value in result.items():
            target = name if mode != "alias" else hdf_key[ftype][name]
            # Existing datasets cannot be overwritten in place; delete first.
            if target in h5file:
                del h5file[target]
            if isinstance(value, np.ndarray) and value.ndim == 1:
                value = np.reshape(value, (1, -1))
            h5file[target] = value
[docs]
def get_abs_cs_scale(fname, ftype="nexus", use_pool=True):
    """
    Read the absolute cross-section scale factor from an HDF5 file.

    Parameters
    ----------
    fname: str
        HDF5 file path
    ftype: str
        file type used to look up the dataset path in ``hdf_key``
    use_pool: bool
        whether to read through the shared connection pool

    Returns
    -------
    float or None
        the scale factor, or None if the dataset is absent
    """
    key = hdf_key[ftype]["abs_cross_section_scale"]
    context_manager = (
        _connection_pool.get_connection(fname, "r")
        if use_pool
        else h5py.File(fname, "r")
    )
    with context_manager as f:
        if key not in f:
            return None
        raw = f[key][()]
    # ``.item()`` accepts any size-1 value (scalar, (1,), (1, 1)); float() on an
    # array with ndim > 0 is deprecated since NumPy 1.25.
    return float(np.asarray(raw).item())
[docs]
@log_timing(threshold_ms=200)
def get(fname, fields, mode="raw", ret_type="dict", ftype="nexus", use_pool=True):
    """
    Read several fields from a single HDF5 file in one open/close cycle.

    :param fname: path to the HDF5 file
    :param fields: list of keys [key1, key2, ..., ]
    :param mode: ['raw' | 'alias']; with 'alias' each key is translated through
        ``hdf_key[ftype]``, otherwise the key is used as the raw HDF5 path
    :param ret_type: 'dict' to return {key: value}; 'list' to return the values
        in the order of ``fields``
    :param ftype: file type used for alias translation
    :param use_pool: whether to use connection pool for optimization
    :return: dict or list of values. Bytes are decoded to str, (1, 1) arrays
        become Python scalars, and a 3-D ``saxs_2d`` is reduced to its first
        frame.
    :raises AssertionError: if ``mode`` or ``ret_type`` is not supported
    :raises ValueError: if a field's path is missing from the file
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``;
    # AssertionError is kept so callers see the same exception type as before.
    if mode not in ("raw", "alias"):
        raise AssertionError("mode not supported")
    if ret_type not in ("dict", "list"):
        raise AssertionError("ret_type not supported")
    ret = {}
    # Choose context manager based on use_pool parameter
    if use_pool:
        context_manager = _connection_pool.get_connection(fname, "r")
    else:
        context_manager = h5py.File(fname, "r")
    with context_manager as hdf_result:
        # Batch read all fields in a single file operation
        for key in fields:
            path = hdf_key[ftype][key] if mode == "alias" else key
            if path not in hdf_result:
                logger.error("path to field not found: %s", path)
                raise ValueError(f"key not found: {key}:{path}")
            val = hdf_result[path][()]
            # saxs_2d is stored in [1xNxM] format
            if key == "saxs_2d" and isinstance(val, np.ndarray) and val.ndim == NDIM_3D:
                val = val[0]
            # converts bytes to unicode;
            if isinstance(val, (bytes, np.bytes_)):
                val = val.decode()
            if isinstance(val, np.ndarray) and val.shape == (1, 1):
                val = val.item()
            ret[key] = val
    if ret_type == "dict":
        return ret
    return [ret[key] for key in fields]
[docs]
@log_timing(threshold_ms=100)
def get_analysis_type(fname, ftype="nexus", use_pool=True):
    """
    Detect which XPCS analyses a result file contains.

    Folder presence (``/xpcs/multitau``, ``/xpcs/twotime``) is checked first.
    Only if neither folder exists are the individual result datasets from
    ``hdf_key[ftype]`` looked up.

    Parameters
    ----------
    fname: str
        file name
    ftype: str
        file type, 'nexus' or 'legacy'
    use_pool: bool
        whether to use connection pool for optimization

    Returns
    -------
    tuple
        analysis type, 'Twotime' or 'Multitau', or both

    Raises
    ------
    ValueError
        if neither analysis type is found
    """
    c2_prefix = hdf_key[ftype]["c2_prefix"]
    g2_prefix = hdf_key[ftype]["g2"]
    detected = []
    if use_pool:
        opener = _connection_pool.get_connection(fname, "r")
    else:
        opener = h5py.File(fname, "r")
    with opener as hdf_result:
        folder_checks = (
            ("Multitau", "/xpcs/multitau"),
            ("Twotime", "/xpcs/twotime"),
        )
        for label, folder in folder_checks:
            if folder in hdf_result:
                detected.append(label)
                logger.debug(
                    f"Detected {label} format: folder {folder} found in {fname}"
                )
        # Backward compatibility: older files only have the result datasets.
        if not detected:
            logger.debug(
                f"Folder detection failed, falling back to file detection for {fname}"
            )
            for label, dataset in (("Twotime", c2_prefix), ("Multitau", g2_prefix)):
                if dataset in hdf_result:
                    detected.append(label)
                    logger.debug(
                        f"Detected {label} format via file: {dataset} found in {fname}"
                    )
    if not detected:
        raise ValueError(f"No analysis type found in {fname}")
    logger.info(f"File {fname}: detected format(s) {detected}")
    return tuple(detected)
[docs]
@log_timing(threshold_ms=300)
def batch_read_fields(
    fname: str,
    fields: list[str],
    mode: str = "raw",
    ftype: str = "nexus",
    use_pool: bool = True,
) -> dict[str, Any]:
    """
    Read several fields from one HDF5 file via the pool's batch reader.

    Values are post-processed the same way as ``get()``: a 3-D ``saxs_2d``
    is reduced to its first frame, bytes are decoded to str, and (1, 1)
    arrays become Python scalars. With ``use_pool=False`` this simply
    delegates to ``get()``.

    Parameters
    ----------
    fname : str
        HDF5 file path
    fields : List[str]
        List of field names to read
    mode : str
        'raw' or 'alias' mode
    ftype : str
        File type ('nexus' or 'legacy')
    use_pool : bool
        Whether to use connection pooling

    Returns
    -------
    Dict[str, Any]
        Dictionary of field values

    Raises
    ------
    ValueError
        If a field's value could not be obtained (missing path or read error)
    """
    if not use_pool:
        return get(fname, fields, mode, "dict", ftype, use_pool=False)
    dataset_paths = [
        hdf_key[ftype][field] if mode == "alias" else field for field in fields
    ]
    raw_results = _connection_pool.batch_read_datasets(fname, dataset_paths)
    processed_results = {}
    for field, path in zip(fields, dataset_paths):
        val = raw_results.get(path)
        if val is None:
            logger.error(f"Path to field not found: {path} for field {field}")
            raise ValueError(f"Key not found: {field}:{path}")
        is_array = isinstance(val, np.ndarray)
        if field == "saxs_2d" and is_array and val.ndim == NDIM_3D:
            val = val[0]  # stored as [1xNxM]
        if type(val) in (np.bytes_, bytes):
            val = val.decode()
        if isinstance(val, np.ndarray) and val.shape == (1, 1):
            val = val.item()
        processed_results[field] = val
    return processed_results
[docs]
@log_timing(threshold_ms=500)
def get_file_info(fname: str, use_pool: bool = True) -> dict[str, Any]:
    """
    Get basic file information and statistics.

    Parameters
    ----------
    fname : str
        HDF5 file path
    use_pool : bool
        Whether to use connection pooling

    Returns
    -------
    Dict[str, Any]
        File information dictionary. ``large_datasets`` lists datasets whose
        estimated size exceeds ``LARGE_DATASET_THRESHOLD_MB``. Datasets with
        no data (null dataspace) are counted but treated as 0 bytes.

    NOTE(review): ``"hdf5_version"`` holds ``h5py.File.libver``, which h5py
    documents as the library-version compatibility bounds the file was opened
    with — not the HDF5 library version. Confirm the intended meaning.
    """
    context_manager = (
        _connection_pool.get_connection(fname, "r")
        if use_pool
        else h5py.File(fname, "r")
    )
    with context_manager as f:
        info = {
            "file_path": fname,
            "file_size_mb": os.path.getsize(fname) / (1024 * 1024),
            "hdf5_version": f.libver,
            "root_groups": list(f.keys()),
            "estimated_datasets": 0,
            "large_datasets": [],
        }

        def count_datasets(name, obj):
            # Must return None: a non-None return stops visititems early.
            if not isinstance(obj, h5py.Dataset):
                return
            info["estimated_datasets"] += 1
            # Dataset.size is None for an empty (null-dataspace) dataset;
            # treat it as zero elements instead of failing on None * itemsize.
            element_count = obj.size or 0
            estimated_size = element_count * obj.dtype.itemsize / (1024 * 1024)
            if estimated_size > LARGE_DATASET_THRESHOLD_MB:
                info["large_datasets"].append(
                    {
                        "path": name,
                        "shape": obj.shape,
                        "dtype": str(obj.dtype),
                        "estimated_size_mb": estimated_size,
                    }
                )

        f.visititems(count_datasets)
    return info
[docs]
def get_connection_pool_stats() -> dict[str, Any]:
    """
    Return a snapshot of the shared connection pool's statistics.

    Returns
    -------
    Dict[str, Any]
        Pool occupancy and per-connection details merged with the pool's
        I/O counters
    """
    snapshot = _connection_pool.get_pool_stats()
    return snapshot
[docs]
def clear_connection_pool():
    """Empty the module-wide connection pool by delegating to its ``clear_pool()``."""
    pool = _connection_pool
    pool.clear_pool()
[docs]
def force_connection_health_check():
    """Run the shared pool's health check now, ignoring its usual interval."""
    pool = _connection_pool
    pool.force_health_check()
[docs]
@log_timing(threshold_ms=1000)
def get_chunked_dataset(
    fname: str,
    dataset_path: str,
    chunk_size: tuple[int, ...] | None = None,
    use_pool: bool = True,
) -> np.ndarray:
    """
    Read a large dataset in chunks to manage memory usage.

    Datasets whose estimated size is below ``DIRECT_READ_LIMIT_MB`` are read
    in one call. Larger ones are copied into a preallocated array one slab of
    rows at a time along the first axis; only the first entry of the chunk
    shape sizes those slabs.

    Parameters
    ----------
    fname : str
        HDF5 file path
    dataset_path : str
        Path to dataset within HDF5 file
    chunk_size : Tuple[int, ...], optional
        Size of chunks to read. If None, the dataset's native chunking is used
        when present, otherwise a computed default
    use_pool : bool
        Whether to use connection pooling

    Returns
    -------
    np.ndarray
        The dataset array. For an empty (null-dataspace) dataset, whatever h5py
        returns for ``dataset[()]``.

    Raises
    ------
    ValueError
        If ``dataset_path`` does not exist in the file
    """
    context_manager = (
        _connection_pool.get_connection(fname, "r")
        if use_pool
        else h5py.File(fname, "r")
    )
    with context_manager as f:
        if dataset_path not in f:
            raise ValueError(f"Dataset {dataset_path} not found in {fname}")
        dataset = f[dataset_path]
        # Dataset.size is None for a null dataspace; there is nothing to chunk.
        if dataset.size is None:
            return dataset[()]
        estimated_size_mb = dataset.size * dataset.dtype.itemsize / (1024 * 1024)
        if estimated_size_mb < DIRECT_READ_LIMIT_MB:
            logger.debug(
                f"Reading small dataset {dataset_path} directly ({estimated_size_mb:.1f}MB)"
            )
            return dataset[()]
        logger.info(
            f"Reading large dataset {dataset_path} in chunks ({estimated_size_mb:.1f}MB)"
        )
        # Use native chunking if available and no custom chunk_size specified
        if chunk_size is None and dataset.chunks is not None:
            chunk_size = dataset.chunks
            logger.debug(f"Using native chunking: {chunk_size}")
        elif chunk_size is None:
            if len(dataset.shape) == NDIM_2D:
                # Aim for ~10MB chunks
                target_elements = 10 * 1024 * 1024 // dataset.dtype.itemsize
                chunk_rows = min(
                    dataset.shape[0], max(1, int(np.sqrt(target_elements)))
                )
                chunk_cols = min(dataset.shape[1], target_elements // chunk_rows)
                chunk_size = (chunk_rows, chunk_cols)
            else:
                # For other dimensions, read in smaller slices along first dimension
                chunk_size = (min(1000, dataset.shape[0]), *dataset.shape[1:])
            logger.debug(f"Using computed chunking: {chunk_size}")
        # BUG-056: read in slabs along the first axis using chunk_size[0].
        shape = dataset.shape
        result = np.empty(shape, dtype=dataset.dtype)
        # Guard against a 0 or empty caller-supplied chunk, which would make
        # range() raise; an empty tuple means "one slab".
        row_chunk = max(1, int(chunk_size[0]) if chunk_size else shape[0])
        for start in range(0, shape[0], row_chunk):
            end = min(start + row_chunk, shape[0])
            result[start:end] = dataset[start:end]
        return result
if __name__ == "__main__":
    # Library module: running it directly intentionally does nothing.
    ...