Source code for xpcsviewer.simplemask.utils

"""Utility functions for SimpleMask partitioning.

This module provides functions for generating Q-space partitions
and combining partition maps for XPCS analysis.

Uses backend abstraction for GPU acceleration when available.
Ported from pySimpleMask with backend abstraction for JAX support.
"""

from __future__ import annotations

import hashlib
import json
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

import numpy as np

from xpcsviewer.backends import get_backend
from xpcsviewer.backends._conversions import ensure_numpy

if TYPE_CHECKING:
    from numpy.typing import NDArray

logger = logging.getLogger(__name__)

# JIT cache for compiled functions (JAX arrays are not hashable for lru_cache)
_PARTITION_JIT_CACHE: dict[str, Callable] = {}


def hash_numpy_dict(input_dictionary: dict[str, Any]) -> str:
    """Compute a stable SHA256 hash for a dictionary containing NumPy arrays.

    Keys are visited in sorted order so the digest does not depend on
    insertion order. Top-level ``np.ndarray`` values contribute their raw
    bytes (C-contiguous, native byte order); every other value contributes
    its JSON serialization.

    NumPy scalars (e.g. ``np.int64``) and arrays nested inside lists or dicts
    are converted to their plain-Python equivalents before JSON encoding, so
    they no longer raise ``TypeError``. Inputs that were already
    JSON-serializable produce exactly the same digest as before.

    Args:
        input_dictionary: Dictionary with NumPy arrays and other values

    Returns:
        SHA256 hash string of the dictionary contents

    Raises:
        TypeError: If a value is neither JSON-serializable nor a NumPy
            scalar/array.
    """

    def _to_builtin(obj: Any) -> Any:
        # Only invoked by json for objects it cannot encode natively, so
        # already-serializable inputs are unaffected.
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    hasher = hashlib.sha256()
    for key in sorted(input_dictionary):
        hasher.update(str(key).encode())
        value = input_dictionary[key]
        if isinstance(value, np.ndarray):
            # Ensure consistent memory layout and byte order before hashing
            contiguous = np.ascontiguousarray(value)
            native = contiguous.astype(contiguous.dtype.newbyteorder("="))
            hasher.update(native.tobytes())
        else:
            encoded = json.dumps(value, sort_keys=True, default=_to_builtin)
            hasher.update(encoded.encode())
    return hasher.hexdigest()
def optimize_integer_array(arr: np.ndarray) -> np.ndarray:
    """Shrink an integer array to the narrowest dtype that holds its values.

    Non-negative data is mapped to the smallest fitting unsigned type;
    data containing negative values is mapped to the smallest fitting
    signed type.

    Args:
        arr: NumPy array of integers

    Returns:
        Array with optimized integer dtype. The input object itself is
        returned when it is not an ndarray, is empty, is not of integer
        dtype, or already has the selected dtype.
    """
    if not isinstance(arr, np.ndarray) or arr.size == 0:
        return arr
    if not np.issubdtype(arr.dtype, np.integer):
        return arr

    lowest, highest = arr.min(), arr.max()
    non_negative = lowest >= 0

    # Candidate dtypes from narrowest to widest; the last one is the fallback.
    if non_negative:
        ladder: tuple[Any, ...] = (np.uint8, np.uint16, np.uint32)
        target: Any = np.uint64
    else:
        ladder = (np.int8, np.int16, np.int32)
        target = np.int64

    for candidate in ladder:
        limits = np.iinfo(candidate)
        # The lower bound only matters for signed candidates.
        if highest <= limits.max and (non_negative or lowest >= limits.min):
            target = candidate
            break

    if target != arr.dtype:
        return arr.astype(target)
    return arr
def _get_partition_linear_jit():
    """Get or create the JIT-compiled linear partition function.

    The compiled function is built on first use and stored in
    ``_PARTITION_JIT_CACHE``.

    Returns:
        The cached JIT-compiled function if the JAX backend is active,
        otherwise None.

    Note:
        The bin edges (``v_span``) are computed by the caller, outside the
        traced function, because ``linspace`` needs a concrete number of
        points. Inside the compiled function the bin count is recovered from
        the static shape of ``v_span``.
    """
    backend = get_backend()
    if backend.name != "jax":
        return None

    cache_key = "partition_linear_jit"
    if cache_key not in _PARTITION_JIT_CACHE:
        import jax
        import jax.numpy as jnp

        @jax.jit
        def _partition_linear_core(
            mask_b: Any, xmap_b: Any, v_min: Any, v_max: Any, v_span: Any
        ) -> tuple[Any, Any]:
            """JIT-compiled linear partition core computation.

            ``v_min`` is accepted for signature symmetry but is not used.
            Returns ``(partition, v_list)`` where ``v_list`` holds the
            arithmetic midpoints of consecutive bin edges.
            """
            num_pts = v_span.shape[0] - 1
            v_list = (v_span[1:] + v_span[:-1]) / 2.0

            # Round to 12 decimals to avoid IEEE 754 bin-edge misassignment
            xmap_b = jnp.round(xmap_b, decimals=12)
            v_span = jnp.round(v_span, decimals=12)
            v_max = jnp.round(v_max, decimals=12)

            # Digitize: find bin indices
            partition = jnp.digitize(xmap_b, v_span) * mask_b
            # Indices past the last edge are reset to 0 (masked) ...
            partition = jnp.where(partition > num_pts, 0, partition)
            # ... except pixels exactly at the maximum, which go in the last bin
            partition = jnp.where((xmap_b == v_max) * mask_b, num_pts, partition)
            return partition, v_list

        _PARTITION_JIT_CACHE[cache_key] = _partition_linear_core
        logger.debug("Created JIT-compiled linear partition function")

    return _PARTITION_JIT_CACHE[cache_key]


def _get_partition_log_jit():
    """Get or create the JIT-compiled logarithmic partition function.

    The compiled function is built on first use and stored in
    ``_PARTITION_JIT_CACHE``.

    Returns:
        The cached JIT-compiled function if the JAX backend is active,
        otherwise None.

    Note:
        The bin edges (``v_span``) are computed by the caller, outside the
        traced function, because ``logspace`` needs a concrete number of
        points. Inside the compiled function the bin count is recovered from
        the static shape of ``v_span``.
    """
    backend = get_backend()
    if backend.name != "jax":
        return None

    cache_key = "partition_log_jit"
    if cache_key not in _PARTITION_JIT_CACHE:
        import jax
        import jax.numpy as jnp

        @jax.jit
        def _partition_log_core(
            mask_b: Any, xmap_b: Any, v_min: Any, v_max: Any, v_span: Any
        ) -> tuple[Any, Any]:
            """JIT-compiled logarithmic partition core computation.

            ``v_min`` is accepted for signature symmetry but is not used.
            Returns ``(partition, v_list)`` where ``v_list`` holds the
            geometric means of consecutive bin edges.
            """
            num_pts = v_span.shape[0] - 1
            v_list = jnp.sqrt(v_span[1:] * v_span[:-1])

            # Round to 12 decimals to avoid IEEE 754 bin-edge misassignment
            xmap_b = jnp.round(xmap_b, decimals=12)
            v_span = jnp.round(v_span, decimals=12)
            v_max = jnp.round(v_max, decimals=12)

            # Digitize: find bin indices
            partition = jnp.digitize(xmap_b, v_span) * mask_b
            # Indices past the last edge are reset to 0 (masked) ...
            partition = jnp.where(partition > num_pts, 0, partition)
            # ... except pixels exactly at the maximum, which go in the last bin
            partition = jnp.where((xmap_b == v_max) * mask_b, num_pts, partition)
            return partition, v_list

        _PARTITION_JIT_CACHE[cache_key] = _partition_log_core
        logger.debug("Created JIT-compiled logarithmic partition function")

    return _PARTITION_JIT_CACHE[cache_key]


def _get_phi_transform_jit():
    """Get or create the JIT-compiled phi angle transformation function.

    The compiled function is built on first use and stored in
    ``_PARTITION_JIT_CACHE``.

    Returns:
        The cached JIT-compiled function if the JAX backend is active,
        otherwise None.
    """
    backend = get_backend()
    if backend.name != "jax":
        return None

    cache_key = "phi_transform_jit"
    if cache_key not in _PARTITION_JIT_CACHE:
        import jax
        import jax.numpy as jnp

        @jax.jit
        def _phi_transform_core(xmap_b, phi_offset, symmetry_fold):
            """JIT-compiled phi angle transformation.

            Returns ``(xmap_folded, unit_xmap)``: the folded angle map and a
            boolean map of pixels whose offset angle lies in
            ``[0, 360 / symmetry_fold)``.
            """
            # Apply phi offset and wrap the result via arctan2
            angle_rad = jnp.deg2rad(xmap_b + phi_offset)
            xmap_transformed = jnp.rad2deg(
                jnp.arctan2(jnp.sin(angle_rad), jnp.cos(angle_rad))
            )

            # Apply symmetry folding
            unit_xmap = (xmap_transformed < (360.0 / symmetry_fold)) * (
                xmap_transformed >= 0
            )
            xmap_folded = xmap_transformed + 180.0
            xmap_folded = jnp.mod(xmap_folded, 360.0 / symmetry_fold)

            return xmap_folded, unit_xmap

        _PARTITION_JIT_CACHE[cache_key] = _phi_transform_core
        logger.debug("Created JIT-compiled phi transform function")

    return _PARTITION_JIT_CACHE[cache_key]


def _generate_partition_backend(
    map_name: str,
    mask: np.ndarray,
    xmap: np.ndarray,
    num_pts: int,
    style: str = "linear",
    phi_offset: float | None = None,
    symmetry_fold: int = 1,
) -> dict[str, str | int | np.ndarray]:
    """Backend-accelerated partition generation.

    Uses the backend abstraction layer for GPU acceleration when available.
    JIT compilation is applied for repeated calls with JAX backend.

    Args:
        map_name: Name of the map. "phi" enables offset/symmetry handling and
            "q" enables logarithmic binning.
        mask: Mask array; pixels with ``mask > 0`` are treated as valid.
        xmap: Array of values to partition.
        num_pts: Number of partition bins.
        style: "logarithmic" selects log-spaced bins, but only when
            ``map_name == "q"``; every other combination uses linear bins.
        phi_offset: Offset in degrees added to a phi map before binning.
        symmetry_fold: For phi maps, values > 1 fold the angles into
            ``360 / symmetry_fold`` degrees.

    Returns:
        Dictionary with keys ``map_name``, ``num_pts``, ``partition``
        (uint32 NumPy array of bin labels, 0 = masked) and ``v_list``
        (NumPy array of bin centers). For folded phi maps ``v_list`` holds
        the mean unfolded phi of each bin within the first fold and always
        has ``num_pts`` entries (0 for bins with no pixel in that fold).

    Raises:
        ValueError: If logarithmic binning is requested and no valid pixel
            has a positive value.
    """
    backend = get_backend()

    # Convert inputs to backend arrays
    mask_b: Any = backend.array(mask)
    xmap_b: Any = backend.array(xmap)

    xmap_phi = None
    unit_xmap = None
    if map_name == "phi":
        # Keep the untouched angles for the per-bin average computed later
        xmap_phi = xmap_b
        if phi_offset is not None:
            # Equivalent to
            # np.rad2deg(np.angle(np.exp(1j * np.deg2rad(xmap + phi_offset))))
            angle_rad = backend.deg2rad(xmap_b + phi_offset)
            xmap_b = backend.rad2deg(
                backend.arctan2(backend.sin(angle_rad), backend.cos(angle_rad))
            )
        if symmetry_fold > 1:
            unit_xmap = (xmap_b < (360 / symmetry_fold)) * (xmap_b >= 0)
            xmap_b = xmap_b + 180.0
            xmap_b = backend.mod(xmap_b, 360.0 / symmetry_fold)

    roi = mask_b > 0
    # Use where to extract valid values for min/max computation
    valid_values = backend.where(roi, xmap_b, backend.array(float("nan")))
    v_min = backend.nanmin(valid_values)
    v_max = backend.nanmax(valid_values)

    use_log = map_name == "q" and style == "logarithmic"

    # Pre-compute v_span (needs concrete num_pts value) and pick the JIT kernel
    if use_log:
        # Non-positive values cannot be log-binned: drop them from the mask
        mask_b = mask_b * (xmap_b > 0)
        valid_xmap = backend.where(mask_b > 0, xmap_b, backend.array(float("nan")))
        v_min = backend.nanmin(valid_xmap)
        if backend.isnan(v_min) or float(v_min) <= 0:
            raise ValueError(
                "Invalid xmap values for logarithmic binning. All values are non-positive."
            )
        xmap_b = backend.where(xmap_b > 0, xmap_b, backend.array(float("nan")))
        v_span: Any = backend.logspace(
            backend.log10(backend.array(v_min)),
            backend.log10(backend.array(v_max)),
            num_pts + 1,
        )
        jit_fn = _get_partition_log_jit()
    else:
        v_span = backend.linspace(v_min, v_max, num_pts + 1)
        jit_fn = _get_partition_linear_jit()

    if jit_fn is not None:
        import jax.numpy as jnp

        partition, v_list = jit_fn(
            jnp.asarray(mask_b),
            jnp.asarray(xmap_b),
            jnp.asarray(v_min),
            jnp.asarray(v_max),
            jnp.asarray(v_span),
        )
    else:
        # Non-JIT path; bin centers are taken from the unrounded edges
        if use_log:
            v_list = backend.sqrt(v_span[1:] * v_span[:-1])
        else:
            v_list = (v_span[1:] + v_span[:-1]) / 2.0

        # Round to 12 decimals to avoid IEEE 754 bin-edge misassignment
        xmap_b = backend.round(xmap_b, decimals=12)
        v_span = backend.round(v_span, decimals=12)
        v_max_r: Any = backend.round(backend.array(v_max), decimals=12)

        partition = backend.digitize(xmap_b, v_span) * mask_b
        partition = backend.where(partition > num_pts, backend.array(0), partition)
        partition = backend.where(
            (xmap_b == v_max_r) * mask_b, backend.array(num_pts), partition
        )

    # Convert to NumPy for output (I/O boundary)
    partition_np = ensure_numpy(partition).astype(np.uint32)
    v_list_np = ensure_numpy(v_list)

    # unit_xmap is only set for phi maps with symmetry_fold > 1
    if unit_xmap is not None and xmap_phi is not None:
        # Use NumPy for bincount: average unfolded phi per bin, first fold only
        unit_xmap_np = ensure_numpy(unit_xmap)
        xmap_phi_np = ensure_numpy(xmap_phi)
        idx_flat = (
            (unit_xmap_np * partition_np.astype(np.int64)).astype(np.int64).ravel()
        )

        # minlength keeps one slot per bin (plus slot 0 for masked pixels)
        # even when the highest bins have no pixel inside the first fold;
        # otherwise v_list would come out shorter than num_pts.
        n_slots = num_pts + 1
        sum_value = np.bincount(
            idx_flat, weights=xmap_phi_np.ravel(), minlength=n_slots
        )
        norm_factor = np.bincount(idx_flat, minlength=n_slots)
        # Clip the counts to >= 1 so empty bins yield 0 instead of dividing by 0
        v_list_np = sum_value / np.clip(norm_factor, 1, np.inf)
        v_list_np = v_list_np[1:]

    return {
        "map_name": map_name,
        "num_pts": num_pts,
        "partition": partition_np,
        "v_list": v_list_np,
    }
def generate_partition(
    map_name: str,
    mask: np.ndarray,
    xmap: np.ndarray,
    num_pts: int,
    style: str = "linear",
    phi_offset: float | None = None,
    symmetry_fold: int = 1,
) -> dict[str, str | int | np.ndarray]:
    """Build a partition (bin-label) map for X-ray scattering analysis.

    Public entry point; all work is delegated to
    ``_generate_partition_backend``.

    Args:
        map_name: Name of the map ("q", "phi", "x", "y")
        mask: 2D boolean mask array (True = valid)
        xmap: 2D array of values to partition
        num_pts: Number of partition bins
        style: Binning style - "linear" or "logarithmic"
        phi_offset: Offset for phi angle (only for phi map)
        symmetry_fold: Symmetry fold for phi partitioning

    Returns:
        Dictionary with keys:
            - map_name: Name of the partition
            - num_pts: Number of bins
            - partition: 2D array of bin labels (1-indexed, 0=masked)
            - v_list: Array of bin center values
    """
    result = _generate_partition_backend(
        map_name=map_name,
        mask=mask,
        xmap=xmap,
        num_pts=num_pts,
        style=style,
        phi_offset=phi_offset,
        symmetry_fold=symmetry_fold,
    )
    return result
def _combine_partitions_backend(
    pack1: dict[str, str | int | np.ndarray],
    pack2: dict[str, str | int | np.ndarray],
    prefix: str = "dynamic",
) -> dict[str, list | np.ndarray]:
    """Backend-accelerated partition combination.

    Uses the backend abstraction layer for GPU acceleration when available.

    Each pixel's pair of 1-based labels ``(p1, p2)`` is merged into the single
    index ``(p1 - 1) * num_pts2 + (p2 - 1) + 1``; the occupied indices are
    then renumbered consecutively, with 0 reserved for masked pixels.

    A pixel is treated as masked (label 0) when it is masked in *either*
    input. Without this, a pixel valid in ``pack1`` with label ``j >= 2`` but
    masked in ``pack2`` would be given index ``(j - 1) * num_pts2``, which is
    the index of the valid bin ``(j - 1, num_pts2)``.

    Args:
        pack1: First partition dictionary; reads "partition", "num_pts" and
            "v_list".
        pack2: Second partition dictionary; reads the same keys.
        prefix: Prefix for the output keys.

    Returns:
        Dictionary with keys ``{prefix}_num_pts``, ``{prefix}_roi_map``
        (uint32 array, 0 = masked), ``{prefix}_v_list_dim0``,
        ``{prefix}_v_list_dim1`` and ``{prefix}_index_mapping`` (zero-based
        merged indices of the occupied bins, in ascending order).
    """
    backend = get_backend()

    labels1 = pack1["partition"].astype(np.int64)  # type: ignore
    labels2 = pack2["partition"].astype(np.int64)  # type: ignore
    num_pts2 = int(pack2["num_pts"])  # type: ignore

    p1: Any = backend.array(labels1)
    p2: Any = backend.array(labels2)

    # Convert to zero-based indexing, merge, convert back
    partition = (p1 - 1) * num_pts2 + (p2 - 1) + 1

    # Convert to NumPy for the masking and unique operations
    partition_np = ensure_numpy(partition).astype(np.int64)

    # Mask every pixel that is masked in either input. This also removes all
    # non-positive merged indices, so no separate (float-promoting) clip is
    # needed.
    valid = (labels1 > 0) & (labels2 > 0)
    partition_np = np.where(valid, partition_np, 0)

    unique_idx, inverse = np.unique(partition_np, return_inverse=True)
    partition_natural_order = inverse.reshape(partition_np.shape).astype(np.uint32)

    # Shift if needed to preserve masked pixel indicator (0): with no masked
    # pixel present, the smallest real bin would otherwise be renumbered to 0.
    if partition_np.size > 0 and partition_np.min() > 0:
        partition_natural_order += 1

    return {
        f"{prefix}_num_pts": [pack1["num_pts"], pack2["num_pts"]],
        f"{prefix}_roi_map": partition_natural_order,
        f"{prefix}_v_list_dim0": ensure_numpy(pack1["v_list"]),
        f"{prefix}_v_list_dim1": ensure_numpy(pack2["v_list"]),
        f"{prefix}_index_mapping": unique_idx[unique_idx >= 1] - 1,
    }
def combine_partitions(
    pack1: dict[str, str | int | np.ndarray],
    pack2: dict[str, str | int | np.ndarray],
    prefix: str = "dynamic",
) -> dict[str, list | np.ndarray]:
    """Merge two partition maps into one combined partition space.

    Public entry point; all work is delegated to
    ``_combine_partitions_backend``.

    Args:
        pack1: First partition dictionary (e.g., Q partition)
        pack2: Second partition dictionary (e.g., phi partition)
        prefix: Prefix for output keys ("dynamic" or "static")

    Returns:
        Dictionary with combined partition:
            - {prefix}_num_pts: [num_pts1, num_pts2]
            - {prefix}_roi_map: Combined 2D partition array
            - {prefix}_v_list_dim0: Bin centers for first dimension
            - {prefix}_v_list_dim1: Bin centers for second dimension
            - {prefix}_index_mapping: Unique partition indices
    """
    combined = _combine_partitions_backend(pack1=pack1, pack2=pack2, prefix=prefix)
    return combined
def check_consistency(dqmap: np.ndarray, sqmap: np.ndarray, mask: np.ndarray) -> bool:
    """Check consistency between dynamic and static Q-maps.

    Two conditions must hold:

    1. Both maps are positive exactly where ``mask`` is positive.
    2. Each unique value in ``sqmap`` corresponds to only one unique value
       in ``dqmap``.

    The second condition is evaluated with vectorized NumPy operations rather
    than a per-pixel Python loop, so it stays fast on full detector images.
    The maps are expected to hold bin labels; NaN entries are not given any
    special treatment.

    Args:
        dqmap: Dynamic Q-map (coarse bins)
        sqmap: Static Q-map (fine bins)
        mask: Boolean mask array

    Returns:
        True if maps are consistent, False otherwise

    Raises:
        ValueError: If array shapes don't match
    """
    if dqmap.shape != sqmap.shape:
        raise ValueError("dqmap and sqmap must have the same shape")
    if dqmap.shape != mask.shape:
        raise ValueError("dqmap and mask must have the same shape")

    valid = mask > 0
    if not np.all(valid == (dqmap > 0)):
        return False
    if not np.all(valid == (sqmap > 0)):
        return False

    sq_flat = sqmap.ravel()
    dq_flat = dqmap.ravel()

    # Group pixels by static label. `first_seen` indexes one representative
    # pixel per group and `group` assigns every pixel to its group.
    _, first_seen, group = np.unique(
        sq_flat, return_index=True, return_inverse=True
    )
    # A static label maps to a single dynamic label iff every pixel carries
    # the same dynamic label as its group's representative.
    representative_dq = dq_flat[first_seen][group.ravel()]
    return bool(np.all(representative_dq == dq_flat))
def create_partition(
    qmap: np.ndarray,
    mask: np.ndarray,
    n_bins: int = 36,
    spacing: str = "linear",
) -> dict[str, str | int | np.ndarray]:
    """Partition a Q-map into Q-space bins.

    Thin convenience wrapper that calls ``generate_partition`` with
    ``map_name="q"``.

    Args:
        qmap: 2D array of Q-values (momentum transfer)
        mask: 2D boolean mask array (True = valid)
        n_bins: Number of partition bins (default 36)
        spacing: Binning style - "linear" or "log" (default "linear")

    Returns:
        Dictionary with keys:
            - map_name: "q"
            - num_pts: Number of bins
            - partition: 2D array of bin labels (1-indexed, 0=masked)
            - v_list: Array of bin center Q-values
    """
    # The internal function spells the log option "logarithmic"; any other
    # spacing value is forwarded unchanged.
    style = spacing
    if spacing == "log":
        style = "logarithmic"

    return generate_partition("q", mask, qmap, n_bins, style)