# Source code for xpcsviewer.simplemask.area_mask

"""Mask assembly classes for SimpleMask.

This module provides classes for creating, combining, and managing detector masks.
Supports various mask types including file-based, threshold-based, parameter-based,
and drawing-based masks.

Ported from pySimpleMask with modifications:
- Removed TIFF support for initial release (HDF5 only)
- Removed skimage.io dependency
- Uses backend abstraction for GPU acceleration when available.
"""

import logging
import os
from typing import Any

import h5py
import numpy as np

from xpcsviewer.backends import get_backend
from xpcsviewer.backends._conversions import ensure_numpy

logger = logging.getLogger(__name__)

# Degree range used by MaskParameter to unwrap angular Q-map coordinates
# ("phi", "chi", "alpha") when a constraint interval crosses the +/-180 seam.
ANGLE_MAX_DEG = 180
ANGLE_MIN_DEG = -ANGLE_MAX_DEG


def create_qring(
    qmin: float,
    qmax: float,
    pmin: float,
    pmax: float,
    qnum: int = 1,
    flag_const_width: bool = True,
) -> list[tuple[float, float, float, float]]:
    """Build a list of Q-ring definitions for parameter-mask constraints.

    The first ring spans ``[qmin, qmax]`` (the bounds are swapped if given in
    reverse order). Ring ``n`` (1-based) is derived from the centre ``c`` and
    half-width ``h`` of that first ring:

    * constant width: ``(c*n - h, c*n + h)``
    * scaled width:   ``((c - h)*n, (c + h)*n)``

    Args:
        qmin: One bound of the first Q ring.
        qmax: The other bound of the first Q ring.
        pmin: Lower phi bound, copied unchanged into every ring.
        pmax: Upper phi bound, copied unchanged into every ring.
        qnum: Number of rings to generate; ``0`` or less yields ``[]``.
        flag_const_width: Keep every ring the same width when True; otherwise
            the width grows with the ring index.

    Returns:
        A list of ``(qlow, qhigh, pmin, pmax)`` tuples, one per ring.
    """
    lo_q, hi_q = (qmax, qmin) if qmin > qmax else (qmin, qmax)
    centre = (lo_q + hi_q) / 2.0
    half_width = (hi_q - lo_q) / 2.0

    def _ring_bounds(order: int) -> tuple[float, float]:
        # Bounds of the ring with 1-based index ``order``.
        if flag_const_width:
            return centre * order - half_width, centre * order + half_width
        return (centre - half_width) * order, (centre + half_width) * order

    return [(*_ring_bounds(order), pmin, pmax) for order in range(1, qnum + 1)]
[docs] class MaskBase: """Base class for all mask types."""
[docs] def __init__(self, shape: tuple[int, int] = (512, 1024)) -> None: """Initialize mask with given detector shape. Args: shape: Detector dimensions as (height, width) """ self.shape = shape self.zero_loc: np.ndarray | None = None self.mtype = "base" self.qrings: list = []
[docs] def evaluate(self, **kwargs: Any) -> None: """Evaluate mask (base method)."""
[docs] def describe(self) -> str: """Return description of mask statistics.""" if self.zero_loc is None: return "Mask is not initialized" bad_num = len(self.zero_loc[0]) total_num = self.shape[0] * self.shape[1] ratio = bad_num / total_num * 100.0 return f"{self.mtype}: bad_pixel: {bad_num}/{ratio:0.3f}%"
[docs] def get_mask(self) -> np.ndarray: """Return the mask as a 2D boolean array. Returns: Boolean array where True = valid pixel, False = masked pixel """ mask = np.ones(self.shape, dtype=bool) if self.zero_loc is not None: mask[tuple(self.zero_loc)] = 0 return mask
[docs] def combine_mask(self, mask: np.ndarray | None) -> np.ndarray: """Combine this mask with another mask. Args: mask: Existing mask to combine with, or None Returns: Combined mask array """ if logger.isEnabledFor(logging.DEBUG): zero_count = 0 if self.zero_loc is None else self.zero_loc.shape[1] logger.debug(f"combine_mask: {self.mtype} with {zero_count} masked pixels") if self.zero_loc is not None: if mask is None: mask = self.get_mask() else: mask[tuple(self.zero_loc)] = 0 return mask if mask is not None else self.get_mask()
class MaskList(MaskBase):
    """Mask defined by an explicit list of pixel coordinates."""

    def __init__(self, shape: tuple[int, int] = (512, 1024)) -> None:
        """Create an empty coordinate-list mask.

        Args:
            shape: Detector dimensions as (height, width).
        """
        super().__init__(shape=shape)
        self.mtype = "list"
        self.xylist: np.ndarray | None = None

    def append_zero_pt(self, row: int, col: int) -> None:
        """Append a single pixel to the mask.

        Args:
            row: Row index.
            col: Column index.
        """
        point = np.array([[row], [col]])
        if self.zero_loc is None:
            self.zero_loc = point
        else:
            self.zero_loc = np.append(self.zero_loc, point, axis=1)

    def evaluate(
        self,
        zero_loc: np.ndarray | None = None,
        **kwargs: Any,
    ) -> None:
        """Set the mask from a coordinate array.

        Args:
            zero_loc: Array-like of shape (2, N) holding [rows, cols] of the
                masked pixels, or None to clear the mask.
        """
        if zero_loc is None:
            self.zero_loc = None
            return
        # Normalize list/tuple input so downstream ``.shape`` access works.
        coords = np.asarray(zero_loc)
        if coords.size == 0:
            # Keep an empty selection 2-D: a 1-D ``(0,)`` array would turn into
            # ``mask[()] = 0`` in get_mask(), masking every pixel, and would
            # make describe() fail on ``zero_loc[0]``.
            coords = np.empty((2, 0), dtype=np.intp)
        self.zero_loc = coords
class MaskFile(MaskBase):
    """Mask loaded from a dataset inside an HDF5 file."""

    # Extensions (compared case-insensitively) treated as HDF5 files.
    _HDF_EXTENSIONS = (".hdf", ".h5", ".hdf5")

    def __init__(
        self,
        shape: tuple[int, int] = (512, 1024),
        fname: str | None = None,
        **kwargs,
    ) -> None:
        """Create an unloaded file mask.

        Args:
            shape: Detector dimensions as (height, width).
            fname: Accepted for call compatibility but not used here; the file
                is only read by ``evaluate()``.
        """
        super().__init__(shape=shape)
        self.mtype = "file"

    def evaluate(
        self,
        fname: str | None = None,
        key: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Load mask from an HDF5 file.

        Pixels whose stored value is ``<= 0`` become masked. A dataset stored
        transposed relative to the detector shape is swapped back. On any
        failure (missing file, non-HDF5 extension, unreadable dataset, shape
        mismatch) an error is logged where applicable and ``zero_loc`` is
        reset to None.

        Args:
            fname: Path to HDF5 file.
            key: Dataset key within the HDF5 file containing the mask.
        """
        if fname is None or not os.path.isfile(fname):
            self.zero_loc = None
            return

        mask = self._read_hdf(fname, key)
        if mask is not None:
            mask = self._match_detector_shape(mask)
        if mask is None:
            self.zero_loc = None
            return

        # Masked pixels have value <= 0
        self.zero_loc = np.array(np.nonzero(mask <= 0))

    def _read_hdf(self, fname: str, key: str | None) -> np.ndarray | None:
        """Read ``key`` from ``fname`` as an array, or return None on failure."""
        _, ext = os.path.splitext(fname)
        if ext.lower() not in self._HDF_EXTENSIONS:
            logger.error(f"MaskFile only supports HDF5 files. Found: {fname}")
            return None
        try:
            with h5py.File(fname, "r") as f:
                return np.asarray(f[key][()])
        except Exception as err:
            # Deliberate best-effort: a bad file/key must not crash the GUI.
            logger.error(f"Cannot read HDF file: {fname}, key: {key} ({err})")
            return None

    def _match_detector_shape(self, mask: np.ndarray) -> np.ndarray | None:
        """Return ``mask`` in detector orientation, or None if it cannot fit."""
        # tuple() so a list-valued shape compares correctly with ndarray.shape.
        expected = tuple(self.shape)
        if mask.ndim == 2 and mask.shape != expected:
            # Handle transposed masks
            mask = np.swapaxes(mask, 0, 1)
        if mask.ndim != 2 or mask.shape != expected:
            logger.error(
                f"Mask shape {mask.shape} doesn't match detector shape {self.shape}"
            )
            return None
        return mask
class MaskThreshold(MaskBase):
    """Mask pixels whose intensity lies outside an enabled threshold window."""

    def __init__(self, shape: tuple[int, int] = (512, 1024)) -> None:
        """Create an unevaluated threshold mask.

        Args:
            shape: Detector dimensions as (height, width).
        """
        super().__init__(shape=shape)
        self.mtype = "threshold"

    def evaluate(
        self,
        saxs_lin: np.ndarray | None = None,
        low: float = 0,
        high: float = 1e8,
        low_enable: bool = True,
        high_enable: bool = True,
        **kwargs: Any,
    ) -> None:
        """Mask pixels that fail the enabled intensity bounds.

        A pixel is kept when ``data >= low`` (if ``low_enable``) and
        ``data < high`` (if ``high_enable``); every other pixel is masked.
        The comparison runs on the active compute backend and the result is
        converted to NumPy before the coordinates are stored.

        Args:
            saxs_lin: 2D intensity array; None clears the mask.
            low: Lower threshold value (inclusive).
            high: Upper threshold value (exclusive).
            low_enable: Whether to apply the lower threshold.
            high_enable: Whether to apply the upper threshold.
        """
        if saxs_lin is None:
            self.zero_loc = None
            return

        xb = get_backend()
        image: Any = xb.array(saxs_lin)
        keep: Any = xb.ones_like(image, dtype=bool)
        if low_enable:
            keep = keep * (image >= low)
        if high_enable:
            keep = keep * (image < high)

        # Convert to NumPy at I/O boundary
        rejected = ensure_numpy(xb.logical_not(keep))
        self.zero_loc = np.array(np.nonzero(rejected))
class MaskParameter(MaskBase):
    """Mask based on Q-map parameter constraints (e.g., Q-rings)."""

    def __init__(self, shape: tuple[int, int] = (512, 1024)) -> None:
        """Create an unevaluated parameter mask.

        Args:
            shape: Detector dimensions as (height, width).
        """
        super().__init__(shape=shape)
        self.mtype = "parameter"
        self.constraints: list = []

    def evaluate(
        self,
        qmap: dict[str, np.ndarray] | None = None,
        constraints: list[tuple[str, str, str, float, float]] | None = None,
        **kwargs: Any,
    ) -> None:
        """Mask pixels that fail the combined Q-map constraints.

        The running "keep" mask starts all True and each constraint's
        in-range test ``vbeg <= value <= vend`` is folded in with AND or OR.
        Constraints whose logic is neither "AND" nor "OR" are ignored.

        NOTE(review): because the running mask starts all True, an "OR"
        constraint leaves it all True unless an earlier "AND" narrowed it;
        confirm this is the intended semantics.

        Args:
            qmap: Dictionary of Q-map arrays.
            constraints: List of (map_name, logic, unit, vbeg, vend) tuples
                - map_name: Key in qmap dict (e.g., "q", "phi")
                - logic: "AND" or "OR"
                - unit: Unit string (e.g., "deg")
                - vbeg, vend: Value range bounds
        """
        if qmap is None or constraints is None:
            self.zero_loc = None
            return

        xb = get_backend()
        keep: Any = xb.ones(self.shape, dtype=bool)
        for name, logic, unit, vbeg, vend in constraints:
            values: Any = xb.array(qmap[name])
            if name in ("phi", "chi", "alpha") and unit == "deg":
                values = self._unwrap_angles(xb, values, vbeg, vend)
            in_range = (values >= vbeg) * (values <= vend)
            if logic == "AND":
                keep = xb.logical_and(keep, in_range)
            elif logic == "OR":
                keep = xb.logical_or(keep, in_range)

        # Convert to NumPy at I/O boundary
        self.zero_loc = np.array(np.nonzero(~ensure_numpy(keep)))

    @staticmethod
    def _unwrap_angles(xb: Any, values: Any, vbeg: float, vend: float) -> Any:
        """Shift degree values by 360 so an interval crossing +/-180 is contiguous.

        If the interval covers ANGLE_MIN_DEG, values above ``vend`` are moved
        down by 360; if it extends past ANGLE_MAX_DEG, values below ``vbeg``
        are moved up by 360.
        """
        values = xb.copy(values)
        if vbeg <= ANGLE_MIN_DEG <= vend <= ANGLE_MAX_DEG:
            values = xb.where(values > vend, values - 360.0, values)
        if vend > ANGLE_MAX_DEG and ANGLE_MIN_DEG <= vbeg <= ANGLE_MAX_DEG:
            values = xb.where(values < vbeg, values + 360.0, values)
        return values
class MaskArray(MaskBase):
    """Mask taken directly from a boolean/integer array (e.g. a drawing)."""

    def __init__(self, shape: tuple[int, int] = (512, 1024)) -> None:
        """Create an empty array mask.

        Args:
            shape: Detector dimensions as (height, width).
        """
        super().__init__(shape=shape)
        self.mtype = "array"

    def evaluate(self, arr: np.ndarray | None = None, **kwargs: Any) -> None:
        """Take the masked pixels from ``arr``.

        Args:
            arr: Boolean or integer array where nonzero = masked. When None,
                the previously stored ``zero_loc`` is left untouched.
        """
        if arr is None:
            return
        self.zero_loc = np.array(np.nonzero(arr))
[docs] class MaskAssemble: """Manager for combining multiple mask types with undo/redo support. Maintains a history of mask states enabling undo, redo, and reset operations. Individual mask workers (threshold, file, draw, parameter, etc.) are evaluated independently and combined via logical AND. Attributes: shape: Detector dimensions as ``(height, width)``. workers: Dictionary of named mask worker instances. mask_record: List of historical mask states. mask_ptr: Index into *mask_record* for the current state. """
[docs] def __init__( self, shape: tuple[int, int] = (128, 128), saxs_lin: np.ndarray | None = None, qmap: dict[str, np.ndarray] | None = None, max_history: int = 50, ) -> None: """Initialize mask assembly with detector shape and optional data. Args: shape: Detector dimensions as (height, width) saxs_lin: 2D intensity array for threshold masking qmap: Q-map dictionary for parameter masking max_history: Maximum number of undo steps to retain. Older entries are dropped when the limit is reached to prevent unbounded memory growth (a 2048×2048 boolean mask costs ~4 MB per step). """ self.workers: dict[str, Any] = { "mask_blemish": MaskFile(shape), "mask_file": MaskFile(shape), "mask_threshold": MaskThreshold(shape), "mask_list": MaskList(shape), "mask_draw": MaskArray(shape), "mask_outlier": MaskList(shape), "mask_parameter": MaskParameter(shape), } self.shape = shape self.saxs_lin = saxs_lin self.qmap = qmap self.max_history = max_history # Initialize mask history for undo/redo initial_mask = ( np.ones_like(saxs_lin, dtype=bool) if saxs_lin is not None else np.ones(shape, dtype=bool) ) self.mask_record: list[np.ndarray] = [initial_mask] self.mask_ptr = 0 # 0: no mask; 1: apply the default mask self.mask_ptr_min = 0
[docs] def update_qmap(self, qmap_all: dict[str, np.ndarray]) -> None: """Update the Q-map used for parameter masking. Args: qmap_all: Dictionary of Q-map arrays """ self.qmap = qmap_all
[docs] def apply(self, target: str | None) -> np.ndarray: """Apply a mask type and add to history. Args: target: Mask type key (e.g., "mask_draw", "mask_threshold") or None to return current mask Returns: Current combined mask array """ if target is None: return self.get_mask() new_layer = self.get_one_mask(target) # Use read-only view for the current mask — avoids a 4MB memcpy. # np.logical_and always allocates a fresh output array, so `combined` # is already an independent copy that is safe to push to history. current_ref = self._get_mask_ref() combined = np.logical_and(current_ref, new_layer) # Short-circuit change detection: compare nonzero counts first (integer # comparison) before falling back to the full O(H*W) element scan. last = self.mask_record[self.mask_ptr] changed = (combined.sum() != last.sum()) or not np.array_equal(last, combined) if changed: # Remove any redo states beyond current pointer while len(self.mask_record) > self.mask_ptr + 1: self.mask_record.pop() # combined is freshly allocated — push directly, no extra copy needed. self.mask_record.append(combined) self.mask_ptr += 1 # Enforce bounded history: drop the oldest entry (beyond mask_ptr_min # anchor) when the limit is exceeded to prevent unbounded memory growth. if len(self.mask_record) > self.max_history + self.mask_ptr_min: self.mask_record.pop(self.mask_ptr_min) self.mask_ptr -= 1 return combined
[docs] def evaluate(self, target: str, **kwargs: Any) -> str: """Evaluate a mask type with given parameters. Args: target: Mask type key **kwargs: Parameters for the mask evaluation Returns: Description string from the mask worker """ if target == "mask_threshold": self.workers[target].evaluate(self.saxs_lin, **kwargs) elif target == "mask_parameter": self.workers[target].evaluate(qmap=self.qmap, **kwargs) else: self.workers[target].evaluate(**kwargs) return self.workers[target].describe()
[docs] def redo_undo(self, action: str = "redo") -> None: """Navigate mask history. Args: action: One of "undo", "redo", or "reset" """ if action == "undo": if self.mask_ptr > self.mask_ptr_min: self.mask_ptr -= 1 elif action == "redo": if self.mask_ptr < len(self.mask_record) - 1: self.mask_ptr += 1 elif action == "reset": # Keep the default mask if one exists while len(self.mask_record) > 1 + self.mask_ptr_min: self.mask_record.pop() self.mask_ptr = self.mask_ptr_min
[docs] def get_one_mask(self, target: str) -> np.ndarray: """Get mask from a single worker. Args: target: Mask type key Returns: Mask array from the specified worker """ return self.workers[target].get_mask()
def _get_mask_ref(self) -> np.ndarray: """Return a read-only view of the current mask without copying. For internal use only. Callers must not mutate the returned array. The view is ~700x faster than get_mask() for large detectors because no memory allocation occurs. """ arr = self.mask_record[self.mask_ptr] arr.flags.writeable = False return arr
[docs] def get_mask(self) -> np.ndarray: """Get the current combined mask. Returns: A copy of the current mask from history. Safe for external callers that may mutate the result. Use _get_mask_ref() internally where the result is only read. """ return self.mask_record[self.mask_ptr].copy()
@property def blemish(self) -> np.ndarray: """Get the blemish mask (pixels marked as bad in blemish file). Returns: Boolean array where True = valid pixel, False = blemish pixel """ return self.workers["mask_blemish"].get_mask()