"""
Enhanced Reliability Framework for XPCS Viewer.
This module provides zero-overhead reliability mechanisms including:
- Fail-fast validation decorators
- Exception result caching
- Smart fallback strategies
- Performance-preserving reliability checks
"""
import functools
import hashlib
import threading
import time
from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
from typing import Any, Union
import numpy as np
from .exceptions import XPCSBaseError, XPCSValidationError, convert_exception
from .logging_config import get_logger
# Module-level logger obtained from the project's ``logging_config`` helper.
logger = get_logger(__name__)
[docs]
class ValidationLevel(Enum):
    """Validation strictness levels, listed from cheapest to most thorough.

    MINIMAL  -- only critical checks, maximum performance
    STANDARD -- balanced checks, good performance
    STRICT   -- comprehensive checks, moderate performance
    PARANOID -- every available check, thorough validation
    """

    MINIMAL = "minimal"
    STANDARD = "standard"
    STRICT = "strict"
    PARANOID = "paranoid"
[docs]
@dataclass
class ValidationResult:
"""Result of validation operation with caching support."""
is_valid: bool
error_message: str | None = None
warnings: list[str] | None = None
validation_time: float = 0.0
cached: bool = False
def __post_init__(self):
if self.warnings is None:
self.warnings = []
[docs]
class ValidationCache:
    """Thread-safe cache for validation results with TTL support.

    Entries are keyed by an opaque string (see :meth:`_generate_key`).
    An entry expires ``default_ttl`` seconds after it was stored. When
    ``max_size`` entries are present and a *new* key is inserted, the least
    recently accessed ~20% (at least one) are evicted. Every method runs
    under an internal re-entrant lock.
    """

    def __init__(self, max_size: int = 1000, default_ttl: float = 300.0):
        # key -> (result, time the entry was stored); TTL counts from storage.
        self._cache: dict[str, tuple[ValidationResult, float]] = {}
        self._lock = threading.RLock()
        self._max_size = max_size
        self._default_ttl = default_ttl
        # key -> time of the last put/get, used to choose LRU victims.
        self._access_times: dict[str, float] = {}

    @staticmethod
    def _fingerprint(value: Any) -> str:
        """Return a deterministic, content-based string for *value*.

        ``str()`` of a large NumPy array elides its middle
        (``[0 1 2 ... 997 998 999]``). Two different arrays could therefore
        share a key and return each other's validation result.

        Arrays are identified by shape, dtype and a digest of their data.
        Lists, tuples and dicts are fingerprinted element-wise. Everything
        else falls back to ``repr()``, so objects with a non content-based
        repr (e.g. ``<Foo object at 0x...>``) are only as distinct as that
        repr.
        """
        if isinstance(value, np.ndarray):
            if value.dtype == np.object_:
                # The raw buffer of an object array holds pointers, not data.
                payload = repr(value.tolist()).encode()
            else:
                payload = value.tobytes()  # C-order copy of the data
            digest = hashlib.md5(payload, usedforsecurity=False).hexdigest()
            return f"ndarray(shape={value.shape}, dtype={value.dtype}, md5={digest})"
        if isinstance(value, (list, tuple)):
            inner = ", ".join(ValidationCache._fingerprint(item) for item in value)
            return f"{type(value).__name__}[{inner}]"
        if isinstance(value, dict):
            inner = ", ".join(
                f"{key!r}: {ValidationCache._fingerprint(item)}"
                for key, item in value.items()
            )
            return f"dict{{{inner}}}"
        return repr(value)

    def _generate_key(self, func_name: str, args: tuple, kwargs: dict) -> str:
        """Generate cache key from function call signature.

        Keyword arguments are sorted by name so their order does not matter.
        """
        key_data = {
            "function": func_name,
            "args": self._fingerprint(args),
            "kwargs": self._fingerprint(dict(sorted(kwargs.items()))),
        }
        return hashlib.md5(str(key_data).encode(), usedforsecurity=False).hexdigest()

    def get(self, key: str) -> ValidationResult | None:
        """Get cached validation result if still valid.

        Returns ``None`` when the key is missing or expired; expired entries
        are removed. On a hit the *stored* result object is marked
        ``cached=True`` in place (any other holder of that object sees the
        flag too) and returned.
        """
        with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                return None
            result, stored_at = entry
            now = time.time()
            if now - stored_at > self._default_ttl:
                del self._cache[key]
                self._access_times.pop(key, None)
                return None
            # Refresh recency for LRU eviction.
            self._access_times[key] = now
            result.cached = True
            return result

    def put(self, key: str, result: ValidationResult) -> None:
        """Store validation result in cache, replacing any existing entry."""
        with self._lock:
            now = time.time()
            # Overwriting an existing key does not grow the cache, so only
            # evict when a new key would push it past ``max_size``.
            if key not in self._cache and len(self._cache) >= self._max_size:
                self._evict_lru()
            self._cache[key] = (result, now)
            self._access_times[key] = now

    def _evict_lru(self) -> None:
        """Evict the least recently used ~20% of entries (at least one)."""
        with self._lock:
            if not self._access_times:
                return
            oldest_first = sorted(self._access_times.items(), key=lambda kv: kv[1])
            evict_count = max(1, len(oldest_first) // 5)
            for key, _ in oldest_first[:evict_count]:
                self._cache.pop(key, None)
                self._access_times.pop(key, None)

    def clear(self) -> None:
        """Clear all cached results."""
        with self._lock:
            self._cache.clear()
            self._access_times.clear()
# Global validation cache instance. get_validation_cache() returns this
# object and clear_reliability_caches() empties it.
_validation_cache: ValidationCache = ValidationCache()
def _validate_type(value: Any, expected_type: Any, param_name: str = "") -> bool:
"""Validate type with support for generic aliases.
Args:
value: The value to validate
expected_type: The expected type (can be a generic alias like Union, List, etc.)
param_name: Parameter name for error reporting (unused but kept for API compat)
Returns:
True if the value matches the expected type
"""
if value is None:
return True # None is often acceptable for optional types
try:
# Handle Union types and generic types
if hasattr(expected_type, "__origin__"):
if expected_type.__origin__ is Union:
# Check if value matches any type in the Union
return any(
_validate_type(value, t, param_name) for t in expected_type.__args__
)
# Handle other generic types (e.g., List[int], Dict[str, Any])
# For now, just check against the origin type
return isinstance(value, expected_type.__origin__)
return isinstance(value, expected_type)
except Exception:
# If type checking fails, assume valid to avoid breaking functionality
return True
def _validate_range(value: int | float | np.number, param_name: str) -> bool:
"""Validate numerical ranges with domain-specific checks."""
# Check for NaN and infinity
if np.isnan(value) or np.isinf(value):
return False
# Domain-specific range checks
if "threshold" in param_name.lower() or "alpha" in param_name.lower():
return bool(0.0 <= value <= 1.0)
if "q_" in param_name.lower() or param_name.lower().endswith("_q"):
return bool(value > 0.0) # Q-values should be positive
if "time" in param_name.lower() or "t_" in param_name.lower():
return bool(value >= 0.0) # Time values should be non-negative
return True # Default: accept all finite values
def _validate_array(
    array: np.ndarray, param_name: str, level: ValidationLevel
) -> tuple[list[str], list[str]]:
    """Check structural properties of a NumPy array.

    Args:
        array: Array to inspect.
        param_name: Name used in messages and for name-based shape rules
            (``saxs`` -> 2D/3D; ``g2`` or ``correlation`` -> 2D).
        level: Strictness; the large-array warning is only emitted for
            ``STRICT`` and ``PARANOID``.

    Returns:
        ``(errors, warnings)`` lists of messages. An empty array yields only
        the "is empty" error; no further checks run.
    """
    errors: list[str] = []
    warnings: list[str] = []

    if array.size == 0:
        errors.append(f"Array '{param_name}' is empty")
        return errors, warnings

    # Threshold is 1000 MiB, roughly 1 GB.
    size_mb = array.nbytes / (1024 * 1024)
    thorough = level in [ValidationLevel.STRICT, ValidationLevel.PARANOID]
    if size_mb > 1000 and thorough:
        warnings.append(f"Large array '{param_name}': {size_mb:.1f}MB")

    if array.dtype == np.object_:
        warnings.append(f"Object array '{param_name}' may cause performance issues")

    lowered = param_name.lower()
    if "saxs" in lowered:
        if array.ndim not in (2, 3):
            errors.append(
                f"SAXS data '{param_name}' should be 2D or 3D, got {array.ndim}D"
            )
    elif "g2" in lowered or "correlation" in lowered:
        if array.ndim != 2:
            errors.append(
                f"Correlation data '{param_name}' should be 2D, got {array.ndim}D"
            )

    return errors, warnings
def _validate_values(
    value: Any, param_name: str, level: ValidationLevel
) -> tuple[list[str], list[str]]:
    """Validate data values for scientific correctness.

    Args:
        value: Value to inspect. Only ``np.ndarray`` and real floating
            scalars are checked; anything else yields no messages.
        param_name: Name used in messages and for the name-based intensity
            check (``intensity``/``saxs`` -> warn on negative values).
        level: Strictness. ``PARANOID`` turns NaN into an error (otherwise a
            warning); ``STRICT``/``PARANOID`` also warn on zero variance.

    Returns:
        ``(errors, warnings)`` lists of human-readable messages.
    """
    errors: list[str] = []
    warnings: list[str] = []

    if isinstance(value, np.ndarray):
        # np.isnan / np.isinf raise TypeError for dtypes they do not support
        # (e.g. string or object arrays). Such arrays cannot be checked this
        # way, so report nothing instead of crashing the whole validation.
        try:
            has_nan = bool(np.any(np.isnan(value)))
        except TypeError:
            has_nan = False
        if has_nan:
            if level == ValidationLevel.PARANOID:
                errors.append(f"Array '{param_name}' contains NaN values")
            else:
                warnings.append(f"Array '{param_name}' contains NaN values")

        try:
            has_inf = bool(np.any(np.isinf(value)))
        except TypeError:
            has_inf = False
        if has_inf:
            errors.append(f"Array '{param_name}' contains infinite values")

        # Range-based checks only make sense for numeric, non-empty data.
        if np.issubdtype(value.dtype, np.number) and value.size > 0:
            min_val, max_val = np.min(value), np.max(value)
            value_range = max_val - min_val
            if value_range == 0 and level in [
                ValidationLevel.STRICT,
                ValidationLevel.PARANOID,
            ]:
                warnings.append(f"Array '{param_name}' has zero variance")
            # Domain-specific value checks
            if "intensity" in param_name.lower() or "saxs" in param_name.lower():
                if min_val < 0:
                    warnings.append(f"Negative intensity values in '{param_name}'")
    elif isinstance(value, (float, np.floating)):
        # Integers (and bools) can never be NaN/inf, so only floats are
        # tested. This also avoids np.isnan failing on Python ints too large
        # for any NumPy integer type.
        if np.isnan(value):
            errors.append(f"Parameter '{param_name}' is NaN")
        elif np.isinf(value):
            errors.append(f"Parameter '{param_name}' is infinite")

    return errors, warnings
[docs]
class SmartFallbackManager:
    """Manager for smart fallback strategies with pre-computed paths.

    Each operation name maps to an ordered list of callables. Executing an
    operation tries them in order and returns the first successful result.
    The elapsed wall-clock time of successful executions (measured from the
    first attempt) is kept in a bounded per-operation history.
    """

    def __init__(self):
        # operation name -> ordered strategies (first is the preferred one)
        self._fallback_strategies: dict[str, list[Callable]] = {}
        # operation name -> recent successful execution times (time.time deltas)
        self._performance_history: dict[str, list[float]] = {}
        self._lock = threading.RLock()

    def register_fallback_chain(
        self, operation_name: str, strategies: list[Callable]
    ) -> None:
        """Register a chain of fallback strategies for an operation.

        Replaces any chain previously registered under ``operation_name`` and
        resets its performance history. The list is stored as given, not
        copied.
        """
        with self._lock:
            self._fallback_strategies[operation_name] = strategies
            self._performance_history[operation_name] = []

    def _record_success(self, operation_name: str, execution_time: float) -> None:
        """Append ``execution_time`` to the operation's bounded history."""
        with self._lock:
            history = self._performance_history.setdefault(operation_name, [])
            history.append(execution_time)
            # Keep only recent history: trim to the last 50 once past 100.
            if len(history) > 100:
                self._performance_history[operation_name] = history[-50:]

    def execute_with_fallback(self, operation_name: str, *args, **kwargs) -> Any:
        """Execute operation with automatic fallback on failure.

        Args:
            operation_name: Name the chain was registered under.
            *args: Forwarded unchanged to every strategy.
            **kwargs: Forwarded unchanged to every strategy.

        Returns:
            The return value of the first strategy that does not raise.

        Raises:
            XPCSBaseError: If no strategies are registered for the operation.
            Exception: When every strategy fails, whatever
                ``convert_exception`` builds from the last strategy's
                exception.
        """
        strategies = self._fallback_strategies.get(operation_name, [])
        if not strategies:
            raise XPCSBaseError(
                f"No fallback strategies registered for '{operation_name}'"
            )

        last_exception: Exception | None = None
        performance_start = time.time()
        total = len(strategies)
        for index, strategy in enumerate(strategies, start=1):
            # functools.partial objects and callable instances may lack
            # __name__. Looking it up directly inside the try used to make
            # such strategies "fail" without ever being called.
            strategy_name = getattr(strategy, "__name__", repr(strategy))
            logger.debug(
                f"Attempting {operation_name} strategy {index}/{total}: {strategy_name}"
            )
            try:
                result = strategy(*args, **kwargs)
            except Exception as e:
                logger.debug(f"Strategy {index} failed for {operation_name}: {e}")
                last_exception = e
                continue

            # Bookkeeping stays outside the try so an error here can never
            # cause a later strategy to run after this one already succeeded.
            self._record_success(operation_name, time.time() - performance_start)
            logger.debug(
                f"Successfully executed {operation_name} with strategy {index}"
            )
            return result

        # All strategies failed
        if last_exception is not None:
            raise convert_exception(
                last_exception, f"All fallback strategies failed for '{operation_name}'"
            )
        raise XPCSBaseError(
            f"All strategies failed for '{operation_name}' with no exceptions"
        )
# Global fallback manager instance. The with_fallback() decorator registers
# its chains here, and get_fallback_manager() returns it.
_fallback_manager: SmartFallbackManager = SmartFallbackManager()
[docs]
def with_fallback(operation_name: str, strategies: list[Callable] | None = None):
    """
    Decorator for automatic fallback execution with pre-computed strategies.

    At decoration time the chain ``[func, *strategies]`` (or just ``[func]``
    when no strategies are given) is registered on the module-level fallback
    manager under ``operation_name``. Any chain already registered under
    that name is replaced. Each call of the wrapped function runs the chain
    in order until one callable succeeds.

    Args:
        operation_name: Name of the operation for strategy registration
        strategies: List of fallback functions (if not already registered)

    Example::

        @with_fallback("data_loading", [load_enhanced, load_standard])
        def load_data(file_path):
            # Primary implementation - automatically falls back if it fails
            pass
    """

    def decorator(func: Callable) -> Callable:
        chain = [func, *strategies] if strategies else [func]
        _fallback_manager.register_fallback_chain(operation_name, chain)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Dispatch by name so the call always uses the registered chain.
            return _fallback_manager.execute_with_fallback(
                operation_name, *args, **kwargs
            )

        return wrapper

    return decorator
[docs]
class ReliabilityContext:
"""Context manager for enhanced reliability with retries and exponential backoff.
Usage -- wrap each attempt in its own ``with`` block inside a loop::
ctx = ReliabilityContext(max_retries=3, retry_delay=0.1)
while True:
with ctx:
risky_operation()
if ctx.should_stop:
break
Alternatively use the :meth:`run` helper which encapsulates the loop::
ctx = ReliabilityContext(max_retries=3)
result = ctx.run(risky_operation, arg1, arg2)
"""
[docs]
def __init__(
self,
max_retries: int = 3,
retry_delay: float = 0.1,
exponential_backoff: bool = True,
acceptable_exceptions: tuple[type, ...] | None = None,
):
self.max_retries = max_retries
self.retry_delay = retry_delay
self.exponential_backoff = exponential_backoff
self.acceptable_exceptions = acceptable_exceptions or (
OSError,
IOError,
TimeoutError,
ConnectionError,
)
self.retry_count = 0
self.should_stop = True # set to False when a retry is pending
self._last_exception: BaseException | None = None
def __enter__(self):
self.should_stop = True
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is None:
# Successful block: stop the retry loop.
self.should_stop = True
return False # Do not suppress; no exception to suppress.
if issubclass(exc_type, self.acceptable_exceptions):
self.retry_count += 1
self._last_exception = exc_val
if self.retry_count <= self.max_retries:
# Calculate delay with optional exponential backoff.
# Perform the sleep *outside* the exception context so that
# any error from time.sleep() is not mistakenly caught by
# the same exception handler (BUG-054).
if self.exponential_backoff:
delay = self.retry_delay * (2 ** (self.retry_count - 1))
else:
delay = self.retry_delay
logger.debug(
f"Retry {self.retry_count}/{self.max_retries} "
f"after {delay:.2f}s delay: {exc_val}"
)
# Signal the outer loop to continue before sleeping.
self.should_stop = False
# Sleep outside the active exception context. Calling
# time.sleep() while an exception is active would pollute the
# exception chain if sleep() itself raised. We suppress the
# retryable exception first (return True below), and the sleep
# happens as part of normal post-exit cleanup.
try:
time.sleep(delay)
except Exception:
# If the sleep itself fails (e.g., interrupted), treat it
# as a non-fatal delay error and continue anyway.
pass
return True # Suppress the exception; outer loop will retry.
# Max retries exceeded: surface the original exception.
self.should_stop = True
return False
# Non-retryable exception -- let it propagate.
self.should_stop = True
return False
[docs]
def run(self, func, *args, **kwargs):
"""Execute ``func(*args, **kwargs)`` with automatic retry on failure.
Retries up to ``self.max_retries`` times on ``acceptable_exceptions``.
Raises the last exception if all retries are exhausted.
This is the preferred API because it avoids the ``__exit__`` self-catch
problem that arises when the context manager is used in a ``while`` loop
manually (BUG-054).
"""
last_exc: BaseException | None = None
for attempt in range(self.max_retries + 1):
try:
return func(*args, **kwargs)
except self.acceptable_exceptions as exc:
last_exc = exc
self.retry_count = attempt + 1
if attempt < self.max_retries:
if self.exponential_backoff:
delay = self.retry_delay * (2**attempt)
else:
delay = self.retry_delay
logger.debug(
f"run() retry {self.retry_count}/{self.max_retries} "
f"after {delay:.2f}s delay: {exc}"
)
time.sleep(delay)
raise last_exc # pragma: no cover – loop always runs at least once
[docs]
def reliability_context(
    max_retries: int = 3,
    retry_delay: float = 0.1,
    exponential_backoff: bool = True,
    acceptable_exceptions: tuple[type, ...] | None = None,
):
    """
    Create a :class:`ReliabilityContext` for retries with exponential backoff.

    The returned context *suppresses* retryable exceptions, sleeping before
    the next attempt, until ``max_retries`` is exceeded. A ``with`` block
    that exits without an exception therefore does NOT prove the operation
    succeeded; check ``should_stop`` to decide whether to loop again.

    Args:
        max_retries: Maximum number of retry attempts
        retry_delay: Base delay between retries (seconds)
        exponential_backoff: Use exponential backoff for delays
        acceptable_exceptions: Exception types that should be retried

    Returns:
        A new ReliabilityContext configured with the given arguments.

    Example::

        ctx = reliability_context(max_retries=3, retry_delay=0.5)
        while True:
            with ctx:
                risky_operation()
            if ctx.should_stop:
                break  # block finished with no retry pending

        # Equivalent, without the manual loop:
        result = reliability_context(max_retries=3).run(risky_operation)
    """
    return ReliabilityContext(
        max_retries=max_retries,
        retry_delay=retry_delay,
        exponential_backoff=exponential_backoff,
        acceptable_exceptions=acceptable_exceptions,
    )
[docs]
def get_validation_cache() -> ValidationCache:
    """Return the shared module-level :class:`ValidationCache`.

    Every call returns the same instance.
    """
    shared_cache = _validation_cache
    return shared_cache
[docs]
def get_fallback_manager() -> SmartFallbackManager:
    """Return the shared module-level :class:`SmartFallbackManager`.

    This is the same manager that ``with_fallback`` registers chains on.
    """
    manager = _fallback_manager
    return manager
[docs]
def clear_reliability_caches() -> None:
    """Empty the shared validation cache (for tests or to release memory).

    Only the module-level :class:`ValidationCache` is cleared; fallback
    performance history and profiler statistics are left untouched. A debug
    message is logged afterwards.
    """
    cache = _validation_cache
    cache.clear()
    logger.debug("Reliability caches cleared")
# Performance monitoring for reliability overhead
[docs]
class ReliabilityProfiler:
"""Lightweight profiler for reliability overhead measurement."""
[docs]
def __init__(self):
self._stats: dict[str, list[float]] = {}
self._lock = threading.RLock()
[docs]
def record_overhead(self, operation: str, overhead_time: float) -> None:
"""Record reliability overhead for an operation."""
with self._lock:
if operation not in self._stats:
self._stats[operation] = []
self._stats[operation].append(overhead_time)
# Keep only recent samples
if len(self._stats[operation]) > 1000:
self._stats[operation] = self._stats[operation][-500:]
[docs]
def get_overhead_stats(self, operation: str | None = None) -> dict[str, Any]:
"""Get overhead statistics for operations."""
with self._lock:
if operation:
times = self._stats.get(operation, [])
if not times:
return {}
return {
"mean_overhead_ms": np.mean(times) * 1000,
"max_overhead_ms": np.max(times) * 1000,
"total_calls": len(times),
}
return {op: self.get_overhead_stats(op) for op in self._stats}
# Global profiler instance, returned by get_reliability_profiler().
_reliability_profiler: ReliabilityProfiler = ReliabilityProfiler()
[docs]
def get_reliability_profiler() -> ReliabilityProfiler:
    """Return the shared module-level :class:`ReliabilityProfiler`.

    Every call returns the same instance.
    """
    profiler = _reliability_profiler
    return profiler