Source code for xpcsviewer.utils.state_validator

"""
Lock-Free State Consistency Validation for XPCS Viewer.

This module provides high-performance state validation using atomic operations,
weak references, and lock-free data structures to ensure object consistency
without blocking critical operations.
"""

import hashlib
import threading
import time
from collections import namedtuple
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
from weakref import WeakSet

import numpy as np

from .logging_config import get_logger

# BUG-060: Consolidate duplicate ValidationLevel / StateValidationLevel enums.
# StateValidationLevel is now an alias for the canonical ValidationLevel defined
# in reliability.py.  All existing code using StateValidationLevel continues to
# work without modification.
from .reliability import ValidationLevel as StateValidationLevel  # noqa: F401

logger = get_logger(__name__)


class StateTransition(Enum):
    """Lifecycle states a tracked object can be in.

    Each member names a state; which moves between states are legal is
    decided by the code that consumes this enum, not by the enum itself.
    """

    # Object is being constructed / loaded.
    INITIALIZING = "initializing"
    # Object is usable.
    READY = "ready"
    # Object is in the middle of an operation.
    PROCESSING = "processing"
    # Object has been placed in a cache.
    CACHED = "cached"
    # Object failed an operation or a consistency check.
    INVALID = "invalid"
    # Object has been torn down.
    DESTROYED = "destroyed"
@dataclass
class StateSnapshot:
    """Point-in-time record of a tracked object's state.

    Instances are meant to be treated as read-only once built; the dataclass
    is not frozen, so that is a convention rather than something enforced.
    """

    # id() of the object the snapshot was taken from
    object_id: int
    # Lifecycle state recorded for the object
    state: StateTransition
    # Short digest computed by whoever created the snapshot
    checksum: str
    # Creation time as returned by time.time()
    timestamp: float
    # Attribute values captured alongside the state
    critical_attributes: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # A falsy timestamp (0, 0.0, None) is replaced with the current time.
        if self.timestamp:
            return
        self.timestamp = time.time()
# One entry of the validator's record table: the latest snapshot stored for an
# object, the version number drawn when that snapshot was stored, and the
# time.time() value at which it was stored.
StateRecord = namedtuple(
    "StateRecord",
    ("snapshot", "version", "last_validated"),
)
class AtomicCounter:
    """Integer counter whose reads and writes are serialized by a lock.

    Every method acquires the same ``threading.Lock``, so a read never
    observes a half-applied update and concurrent increments are not lost.
    """

    def __init__(self, initial_value: int = 0):
        """Create the counter.

        Args:
            initial_value: Value the counter starts at.
        """
        self._value = initial_value
        self._lock = threading.Lock()

    def increment(self, amount: int = 1) -> int:
        """Add ``amount`` to the counter and return the new value.

        Args:
            amount: Step to add; defaults to 1 so existing callers that pass
                no argument keep their behavior. May be negative.

        Returns:
            The counter value immediately after the addition.
        """
        with self._lock:
            self._value += amount
            return self._value

    def get(self) -> int:
        """Return the current value."""
        with self._lock:
            return self._value

    def set(self, value: int) -> None:
        """Overwrite the current value with ``value``."""
        with self._lock:
            self._value = value

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.get()})"
class LockFreeStateValidator:
    """
    State consistency validator built on weak references and versioned snapshots.

    Tracked objects are held in a ``WeakSet`` so tracking never keeps them
    alive. For each object one ``StateRecord`` is stored under ``id(obj)``;
    an update replaces the whole record with a single dict assignment rather
    than mutating it, and no lock is taken around the record table.

    Records are only removed by ``cleanup_destroyed_objects``. Call it
    periodically: ``id()`` values can be reused after an object is collected,
    and a stale record would otherwise be found for the new object.
    """

    # Arrays larger than this are digested from an evenly spaced sample
    # instead of their full contents, to bound the cost of one validation.
    _MAX_DIGEST_ELEMENTS = 65536

    def __init__(
        self, validation_level: StateValidationLevel = StateValidationLevel.STANDARD
    ):
        """Create an empty validator.

        Args:
            validation_level: How thorough ``validate_object_consistency`` is.
                STRICT and PARANOID compare checksums and every stored
                attribute; STANDARD compares only SAXS shape and filename;
                any other level performs no comparison.
        """
        self.validation_level = validation_level

        # Weak references so that tracking never extends an object's lifetime
        self._tracked_objects: WeakSet[Any] = WeakSet()

        # Latest record per object, keyed by id(obj)
        self._state_records: dict[int, StateRecord] = {}
        # Attributes the caller passed to update_object_state, keyed by
        # id(obj). They cannot be re-derived from the object, so validation
        # re-applies them before comparing against the stored snapshot.
        self._supplied_attributes: dict[int, dict[str, Any]] = {}
        self._version_counter = AtomicCounter()

        # State transition rules
        self._valid_transitions = self._init_transition_rules()

        # Performance metrics
        self._validation_count = AtomicCounter()
        self._inconsistency_count = AtomicCounter()
        self._performance_history: list[float] = []

        # Background validation
        self._validation_active = False
        self._validation_thread: threading.Thread | None = None
        self._validation_interval = 60.0  # 1 minute
        # Set to wake and terminate the background loop; replaced on each start
        self._stop_event = threading.Event()

    def _init_transition_rules(self) -> dict[StateTransition, set[StateTransition]]:
        """Return, for every state, the set of states it may move to."""
        return {
            StateTransition.INITIALIZING: {
                StateTransition.READY,
                StateTransition.INVALID,
                StateTransition.DESTROYED,
            },
            StateTransition.READY: {
                StateTransition.PROCESSING,
                StateTransition.CACHED,
                StateTransition.INVALID,
                StateTransition.DESTROYED,
            },
            StateTransition.PROCESSING: {
                StateTransition.READY,
                StateTransition.CACHED,
                StateTransition.INVALID,
                StateTransition.DESTROYED,
            },
            StateTransition.CACHED: {
                StateTransition.READY,
                StateTransition.PROCESSING,
                StateTransition.INVALID,
                StateTransition.DESTROYED,
            },
            StateTransition.INVALID: {StateTransition.READY, StateTransition.DESTROYED},
            StateTransition.DESTROYED: set(),  # Terminal state
        }

    def register_object(
        self, obj: Any, initial_state: StateTransition = StateTransition.INITIALIZING
    ) -> None:
        """Start tracking ``obj`` and store its first snapshot.

        Registration is best-effort: any failure (for example an object that
        cannot be weakly referenced) is logged at debug level and swallowed
        so that tracking can never break the caller.

        Args:
            obj: Object to track.
            initial_state: State recorded in the first snapshot.
        """
        try:
            self._tracked_objects.add(obj)

            snapshot = self._create_state_snapshot(obj, initial_state)
            version = self._version_counter.increment()

            obj_id = id(obj)
            # A fresh registration carries no caller-supplied attributes
            self._supplied_attributes.pop(obj_id, None)
            self._state_records[obj_id] = StateRecord(
                snapshot=snapshot, version=version, last_validated=time.time()
            )

            logger.debug(
                f"Registered object {id(obj)} for state tracking in state {initial_state.value}"
            )

        except Exception as e:
            # Don't let registration failures break the application
            logger.debug(f"Failed to register object for state tracking: {e}")

    def update_object_state(
        self, obj: Any, new_state: StateTransition, **critical_attributes
    ) -> bool:
        """Move ``obj`` to ``new_state`` if the transition rules allow it.

        Args:
            obj: A previously registered object.
            new_state: State to move to.
            **critical_attributes: Extra attribute values to record in the new
                snapshot; they override extracted attributes of the same name.

        Returns:
            True when the new snapshot was stored. False when the object is
            not registered, the transition is not allowed (this also bumps
            the inconsistency counter), or an unexpected error occurred.
        """
        obj_id = id(obj)
        current_time = time.time()

        try:
            current_record = self._state_records.get(obj_id)
            if current_record is None:
                logger.debug(f"Object {obj_id} not registered for state tracking")
                return False

            # Validate transition
            current_state = current_record.snapshot.state
            if new_state not in self._valid_transitions.get(current_state, set()):
                logger.warning(
                    f"Invalid state transition for object {obj_id}: {current_state.value} -> {new_state.value}"
                )
                self._inconsistency_count.increment()
                return False

            new_snapshot = self._create_state_snapshot(
                obj, new_state, critical_attributes
            )
            new_version = self._version_counter.increment()

            # Replace the record wholesale rather than mutating the old one.
            # The supplied attributes live in a second table; the two writes
            # are not atomic together, so a concurrent validation may briefly
            # pair the new record with the previous supplied attributes.
            self._supplied_attributes[obj_id] = dict(critical_attributes)
            self._state_records[obj_id] = StateRecord(
                snapshot=new_snapshot, version=new_version, last_validated=current_time
            )

            logger.debug(
                f"Updated object {obj_id} state: {current_state.value} -> {new_state.value}"
            )
            return True

        except Exception as e:
            logger.debug(f"Error updating object state: {e}")
            return False

    def _create_state_snapshot(
        self,
        obj: Any,
        state: StateTransition,
        critical_attributes: dict[str, Any] | None = None,
    ) -> StateSnapshot:
        """Build a snapshot of ``obj`` in ``state``.

        Extracted attributes come first; ``critical_attributes`` (if any) are
        merged on top and win on key collisions. The checksum covers the
        object's type name, the state and the merged attributes.
        """
        if critical_attributes is None:
            critical_attributes = {}

        extracted_attributes = self._extract_critical_attributes(obj)
        extracted_attributes.update(critical_attributes)

        checksum = self._calculate_state_checksum(obj, state, extracted_attributes)

        return StateSnapshot(
            object_id=id(obj),
            state=state,
            checksum=checksum,
            timestamp=time.time(),
            critical_attributes=extracted_attributes,
        )

    @classmethod
    def _array_digest(cls, values: Any) -> str:
        """Return an 8-hex-digit digest of an array's contents.

        The digest is computed from the raw element bytes. Arrays with more
        than ``_MAX_DIGEST_ELEMENTS`` elements are reduced to that many evenly
        spaced samples first.
        """
        array = np.asarray(values)
        if array.size > cls._MAX_DIGEST_ELEMENTS:
            picks = np.linspace(
                0, array.size - 1, cls._MAX_DIGEST_ELEMENTS, dtype=int
            )
            array = array.flat[picks]
        return hashlib.md5(
            np.ascontiguousarray(array).tobytes(), usedforsecurity=False
        ).hexdigest()[:8]

    def _extract_critical_attributes(self, obj: Any) -> dict[str, Any]:
        """Collect the attributes used for consistency comparison.

        Only attributes that exist on ``obj`` contribute. Any exception raised
        while reading them is logged at debug level and recorded under the
        ``"extraction_error"`` key instead of propagating.
        """
        attributes: dict[str, Any] = {}

        try:
            # Common XPCS object attributes
            if hasattr(obj, "fname"):
                attributes["filename"] = str(obj.fname)

            if hasattr(obj, "atype"):
                attributes["analysis_type"] = str(obj.atype)

            if hasattr(obj, "qmap") and obj.qmap is not None:
                q_values = getattr(obj.qmap, "q_values", None)
                if q_values is not None and hasattr(q_values, "shape"):
                    attributes["qmap_shape"] = q_values.shape
                    # Digest the element bytes. str() of an ndarray's .data is
                    # the memoryview repr ("<memory at 0x...>"), which encodes
                    # an address rather than the contents.
                    attributes["qmap_checksum"] = self._array_digest(q_values)

            # Data integrity checks
            if hasattr(obj, "saxs_2d_data") and obj.saxs_2d_data is not None:
                data = obj.saxs_2d_data
                if hasattr(data, "shape"):
                    attributes["saxs_shape"] = data.shape
                    # Quick checksum for large arrays: at most 100 evenly
                    # spaced elements
                    if hasattr(data, "size") and data.size > 0:
                        sample_indices = np.linspace(
                            0, data.size - 1, min(100, data.size), dtype=int
                        )
                        sample_data = (
                            data.flat[sample_indices] if hasattr(data, "flat") else [0]
                        )
                        attributes["saxs_checksum"] = hashlib.md5(
                            str(sample_data).encode(), usedforsecurity=False
                        ).hexdigest()[:8]

            if hasattr(obj, "fit_summary") and obj.fit_summary is not None:
                attributes["has_fit_summary"] = "True"
                if isinstance(obj.fit_summary, dict):
                    attributes["fit_keys"] = str(sorted(obj.fit_summary.keys()))

        except Exception as e:
            # Don't let attribute extraction failures break validation
            logger.debug(f"Error extracting critical attributes: {e}")
            attributes["extraction_error"] = str(e)

        return attributes

    def _calculate_state_checksum(
        self, obj: Any, state: StateTransition, attributes: dict[str, Any]
    ) -> str:
        """Return a 16-hex-digit digest of type name, state and attributes.

        Returns the literal string ``"error"`` if the digest cannot be built.
        """
        try:
            checksum_data = {
                "object_type": type(obj).__name__,
                "state": state.value,
                "attributes": attributes,
            }

            checksum_string = str(sorted(checksum_data.items()))
            return hashlib.md5(
                checksum_string.encode(), usedforsecurity=False
            ).hexdigest()[:16]

        except Exception as e:
            logger.debug(f"Error calculating state checksum: {e}")
            return "error"

    def _strict_issues(
        self,
        obj: Any,
        stored_snapshot: StateSnapshot,
        current_attributes: dict[str, Any],
    ) -> list[str]:
        """Compare checksum, every stored attribute and the recorded type name."""
        issues: list[str] = []

        current_checksum = self._calculate_state_checksum(
            obj, stored_snapshot.state, current_attributes
        )
        if current_checksum != stored_snapshot.checksum:
            issues.append(
                f"State checksum mismatch: {current_checksum} != {stored_snapshot.checksum}"
            )

        for key, stored_value in stored_snapshot.critical_attributes.items():
            current_value = current_attributes.get(key)
            if current_value != stored_value:
                issues.append(
                    f"Attribute '{key}' changed: {stored_value} -> {current_value}"
                )

        # Only meaningful when an "object_type" attribute was recorded; the
        # extractor does not produce one, so it has to come from the caller.
        expected_type = stored_snapshot.critical_attributes.get("object_type")
        if expected_type and obj.__class__.__name__ != expected_type:
            issues.append(
                f"Object type changed: {expected_type} -> {obj.__class__.__name__}"
            )

        return issues

    def _standard_issues(
        self, stored_snapshot: StateSnapshot, current_attributes: dict[str, Any]
    ) -> list[str]:
        """Compare only SAXS shape and filename, when both sides have them."""
        issues: list[str] = []
        stored = stored_snapshot.critical_attributes

        stored_shape = stored.get("saxs_shape")
        current_shape = current_attributes.get("saxs_shape")
        if stored_shape and current_shape and stored_shape != current_shape:
            issues.append(
                f"SAXS data shape changed: {stored_shape} -> {current_shape}"
            )

        # Filename changes shouldn't happen
        stored_filename = stored.get("filename")
        current_filename = current_attributes.get("filename")
        if (
            stored_filename
            and current_filename
            and stored_filename != current_filename
        ):
            issues.append(
                f"Filename changed: {stored_filename} -> {current_filename}"
            )

        return issues

    def validate_object_consistency(self, obj: Any) -> tuple[bool, list[str]]:
        """Compare ``obj`` as it is now against its stored snapshot.

        Args:
            obj: Object to check. Objects without a stored record are treated
                as valid.

        Returns:
            ``(is_valid, issues)``. ``issues`` holds one human-readable string
            per detected difference; an unexpected error yields
            ``(False, ["Validation error: ..."])``.
        """
        obj_id = id(obj)
        validation_start = time.time()
        issues: list[str] = []

        try:
            self._validation_count.increment()

            record = self._state_records.get(obj_id)
            if record is None:
                return True, []  # Not tracked, assume valid

            stored_snapshot = record.snapshot

            current_attributes = self._extract_critical_attributes(obj)
            # Attributes supplied through update_object_state cannot be read
            # back from the object. Re-apply the ones the extractor did not
            # produce so they are not reported as changed; extracted values
            # are left as-is so real changes are still detected.
            for key, value in self._supplied_attributes.get(obj_id, {}).items():
                current_attributes.setdefault(key, value)

            if self.validation_level in [
                StateValidationLevel.STRICT,
                StateValidationLevel.PARANOID,
            ]:
                issues = self._strict_issues(obj, stored_snapshot, current_attributes)
            elif self.validation_level == StateValidationLevel.STANDARD:
                issues = self._standard_issues(stored_snapshot, current_attributes)

            if issues:
                self._inconsistency_count.increment()
                logger.debug(f"Object {obj_id} consistency issues: {'; '.join(issues)}")

            # Update performance metrics
            validation_time = time.time() - validation_start
            self._performance_history.append(validation_time)
            if len(self._performance_history) > 1000:
                self._performance_history = self._performance_history[
                    -500:
                ]  # Keep recent history

            return len(issues) == 0, issues

        except Exception as e:
            logger.debug(f"Error during consistency validation: {e}")
            return False, [f"Validation error: {e}"]

    def validate_all_objects(self) -> dict[str, Any]:
        """Validate every object that is still alive.

        Returns:
            A dict with ``total_objects``, ``valid_objects``,
            ``invalid_objects``, ``issues`` (strings prefixed with the object
            id) and ``validation_time`` in seconds.
        """
        start_time = time.time()
        results: dict[str, Any] = {
            "total_objects": 0,
            "valid_objects": 0,
            "invalid_objects": 0,
            "issues": [],
            "validation_time": 0.0,
        }

        # Take strong references up front so the set can change while we work
        live_objects = list(self._tracked_objects)

        for obj in live_objects:
            try:
                results["total_objects"] += 1
                is_valid, issues = self.validate_object_consistency(obj)

                if is_valid:
                    results["valid_objects"] += 1
                else:
                    results["invalid_objects"] += 1
                    results["issues"].extend(
                        [f"Object {id(obj)}: {issue}" for issue in issues]
                    )

            except Exception as e:
                results["invalid_objects"] += 1
                results["issues"].append(f"Object {id(obj)}: Validation error: {e}")

        results["validation_time"] = time.time() - start_time
        return results

    def start_background_validation(self, interval: float = 60.0) -> None:
        """Start a daemon thread that validates all objects periodically.

        Does nothing when the ``XPCS_TEST_MODE`` environment variable is
        ``"1"`` or when background validation is already active.

        Args:
            interval: Target number of seconds between the starts of two
                validation passes.
        """
        import os

        # Skip starting background threads in test mode to prevent threading issues
        if os.environ.get("XPCS_TEST_MODE") == "1":
            return

        if self._validation_active:
            logger.debug("Background validation already active")
            return

        # A fresh event per run: a previous loop that has not exited yet keeps
        # its own (already set) event and cannot be revived by this start.
        self._stop_event = threading.Event()
        self._validation_active = True
        self._validation_interval = interval
        self._validation_thread = threading.Thread(
            target=self._background_validation_loop,
            args=(self._stop_event,),
            name="XPCS-StateValidator",
            daemon=True,
        )
        self._validation_thread.start()
        logger.info(f"Background state validation started (interval: {interval}s)")

    def stop_background_validation(self) -> None:
        """Stop the background thread, waiting up to 5 seconds for it to exit."""
        if not self._validation_active:
            return

        self._validation_active = False
        # Wake the loop if it is waiting between passes
        self._stop_event.set()

        if self._validation_thread and self._validation_thread.is_alive():
            self._validation_thread.join(timeout=5.0)

        logger.info("Background state validation stopped")

    def _background_validation_loop(
        self, stop_event: threading.Event | None = None
    ) -> None:
        """Run validation passes until asked to stop.

        Args:
            stop_event: Event that ends the loop when set. Defaults to the
                validator's current stop event.
        """
        if stop_event is None:
            stop_event = self._stop_event

        logger.debug("Background state validation loop started")

        while self._validation_active and not stop_event.is_set():
            try:
                start_time = time.time()

                results = self.validate_all_objects()

                # Log summary if issues found
                if results["invalid_objects"] > 0:
                    logger.warning(
                        f"State validation found {results['invalid_objects']} inconsistent objects "
                        f"out of {results['total_objects']} total"
                    )

                    # Log first few issues for debugging
                    for issue in results["issues"][:5]:  # Limit to avoid log spam
                        logger.debug(f"State inconsistency: {issue}")

                # Wait out whatever is left of the interval
                elapsed = time.time() - start_time
                pause = max(0, self._validation_interval - elapsed)

            except Exception as e:
                logger.debug(f"Error in background validation: {e}")
                pause = self._validation_interval

            # Unlike time.sleep, this returns as soon as stop is requested
            stop_event.wait(pause)

        logger.debug("Background state validation loop ended")

    def get_statistics(self) -> dict[str, Any]:
        """Return counters and timing figures for the validator.

        ``average_validation_time_ms`` and ``max_validation_time_ms`` are only
        present once at least one validation has been timed.
        """
        total_validations = self._validation_count.get()
        total_inconsistencies = self._inconsistency_count.get()

        stats: dict[str, Any] = {
            "tracked_objects": len(self._tracked_objects),
            "total_validations": total_validations,
            "total_inconsistencies": total_inconsistencies,
            "consistency_rate": (total_validations - total_inconsistencies)
            / max(total_validations, 1)
            * 100,
            "validation_level": self.validation_level.value,
            "background_validation_active": self._validation_active,
        }

        # Performance statistics
        if self._performance_history:
            stats["average_validation_time_ms"] = (
                np.mean(self._performance_history) * 1000
            )
            stats["max_validation_time_ms"] = np.max(self._performance_history) * 1000
            stats["validation_overhead_estimate"] = "< 0.1% CPU"

        return stats

    def cleanup_destroyed_objects(self) -> int:
        """Drop records whose object is no longer alive.

        Returns:
            Number of state records that were removed.
        """
        live_object_ids = {id(obj) for obj in self._tracked_objects}

        destroyed_ids = set(self._state_records) - live_object_ids
        for obj_id in destroyed_ids:
            # pop() tolerates a record that disappeared in the meantime
            self._state_records.pop(obj_id, None)
            self._supplied_attributes.pop(obj_id, None)

        cleaned_count = len(destroyed_ids)
        if cleaned_count > 0:
            logger.debug(f"Cleaned up {cleaned_count} destroyed object state records")

        return cleaned_count
# Process-wide validator singleton. It starts out unset and is filled in by
# get_state_validator(); _validator_lock serializes that one-time creation.
_validator_lock = threading.Lock()
_state_validator: LockFreeStateValidator | None = None
def get_state_validator(
    level: StateValidationLevel = StateValidationLevel.STANDARD,
) -> LockFreeStateValidator:
    """Return the process-wide validator, building it on the first call.

    Args:
        level: Validation level for the new instance. It is only used when
            the instance is created; once one exists it is returned as-is,
            whatever ``level`` is passed.

    Returns:
        The shared ``LockFreeStateValidator``.
    """
    global _state_validator  # noqa: PLW0603 - intentional singleton pattern

    # Fast path: already created, no lock needed
    if _state_validator is not None:
        return _state_validator

    with _validator_lock:
        # Check again now that the lock is held; another thread may have
        # created the instance while this one was waiting.
        if _state_validator is None:
            _state_validator = LockFreeStateValidator(level)
    return _state_validator
def track_object_state(
    obj: Any, initial_state: StateTransition = StateTransition.INITIALIZING
) -> None:
    """Start tracking ``obj`` with the shared validator.

    Args:
        obj: Object to register.
        initial_state: State recorded for the object at registration.
    """
    get_state_validator().register_object(obj, initial_state)
def update_object_state(
    obj: Any, new_state: StateTransition, **critical_attributes
) -> bool:
    """Ask the shared validator to move ``obj`` to ``new_state``.

    Args:
        obj: Tracked object.
        new_state: State to move to.
        **critical_attributes: Forwarded unchanged to the validator.

    Returns:
        Whatever the validator's ``update_object_state`` returns.
    """
    shared = get_state_validator()
    outcome = shared.update_object_state(obj, new_state, **critical_attributes)
    return outcome
def validate_object_state(obj: Any) -> tuple[bool, list[str]]:
    """Check ``obj`` against its stored state using the shared validator.

    Returns:
        The ``(is_valid, issues)`` pair produced by the validator.
    """
    return get_state_validator().validate_object_consistency(obj)
def start_state_monitoring(
    interval: float = 60.0, level: StateValidationLevel = StateValidationLevel.STANDARD
) -> None:
    """Start periodic background validation on the shared validator.

    Args:
        interval: Passed to the validator as the validation interval.
        level: Passed to ``get_state_validator`` when obtaining the shared
            instance.
    """
    get_state_validator(level).start_background_validation(interval)
def stop_state_monitoring() -> None:
    """Stop background validation if a shared validator exists.

    Unlike the other module-level helpers this never creates the validator.
    """
    existing = _state_validator
    if not existing:
        return
    existing.stop_background_validation()
def get_state_statistics() -> dict[str, Any]:
    """Return the statistics dict reported by the shared validator."""
    return get_state_validator().get_statistics()
# Decorator for automatic state tracking
def track_state(initial_state: StateTransition = StateTransition.INITIALIZING):
    """Class decorator that registers every new instance for state tracking.

    The decorated class's ``__init__`` is replaced by a wrapper that first
    runs the original constructor and then calls ``track_object_state`` on
    the instance.

    Args:
        initial_state: State each instance is registered in.

    Returns:
        A decorator that patches ``__init__`` on the class it receives and
        returns that same class.
    """
    # Function-scope import so the block does not rely on a file-level import
    import functools

    def decorator(cls):
        original_init = cls.__init__

        # wraps() keeps the original constructor's name, docstring and
        # signature visible to help(), inspect and documentation tools.
        @functools.wraps(original_init)
        def new_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            # Register only after construction succeeded, so the first
            # snapshot sees the fully initialized object.
            track_object_state(self, initial_state)

        cls.__init__ = new_init
        return cls

    return decorator
# Context manager for state validation
class state_validation_context:
    """Context manager that validates an object before and after an operation.

    On entry the object is validated and the result is kept on the instance.
    On exit the object is moved to ``expected_final_state`` when the body
    finished normally, or to ``StateTransition.INVALID`` when it raised, and
    is then validated again. Exceptions from the body are never suppressed.
    """

    def __init__(self, obj: Any, expected_final_state: StateTransition):
        """Remember the object and the state it should end up in.

        Args:
            obj: Object whose state is validated and updated.
            expected_final_state: State to move to when the body succeeds.
        """
        self.obj = obj
        self.expected_final_state = expected_final_state
        # Filled in by __enter__
        self.initial_valid = False
        self.initial_issues: list[str] = []

    def __enter__(self):
        entry_valid, entry_issues = validate_object_state(self.obj)
        self.initial_valid = entry_valid
        self.initial_issues = entry_issues

        if not entry_valid:
            logger.warning(
                f"Object entered context with state issues: {'; '.join(self.initial_issues)}"
            )
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Pick the target state from whether the body raised
        if exc_type is None:
            target_state = self.expected_final_state
        else:
            target_state = StateTransition.INVALID
        update_object_state(self.obj, target_state)

        final_valid, final_issues = validate_object_state(self.obj)
        if not final_valid:
            logger.warning(
                f"Object exited context with state issues: {'; '.join(final_issues)}"
            )

        return False  # Don't suppress exceptions