"""Mask assembly classes for SimpleMask.
This module provides classes for creating, combining, and managing detector masks.
Supports various mask types including file-based, threshold-based, parameter-based,
and drawing-based masks.
Ported from pySimpleMask with modifications:
- Removed TIFF support for initial release (HDF5 only)
- Removed skimage.io dependency
- Uses backend abstraction for GPU acceleration when available.
"""
import logging
import os
from typing import Any
import h5py
import numpy as np
from xpcsviewer.backends import get_backend
from xpcsviewer.backends._conversions import ensure_numpy
logger = logging.getLogger(__name__)
# Angular periodicity constants for degree-based coordinates.
# MaskParameter.evaluate compares a constraint's (vbeg, vend) range against
# these bounds to decide whether an angular map must be shifted by 360 degrees
# before the range test is applied.
ANGLE_MIN_DEG = -180
ANGLE_MAX_DEG = 180
[docs]
def create_qring(
    qmin: float,
    qmax: float,
    pmin: float,
    pmax: float,
    qnum: int = 1,
    flag_const_width: bool = True,
) -> list[tuple[float, float, float, float]]:
    """Build a sequence of Q-ring bounds that all share one phi range.

    The first ring always spans ``[qmin, qmax]`` (the two values are swapped
    if given in reverse order). Ring ``n`` (1-based) is derived from it:

    * constant width: the ring centre is multiplied by ``n`` while the
      half-width stays that of the first ring;
    * scaled width: both the lower and the upper bound are multiplied by ``n``.

    Args:
        qmin: Lower Q bound of the first ring.
        qmax: Upper Q bound of the first ring.
        pmin: Lower phi bound, copied unchanged into every ring.
        pmax: Upper phi bound, copied unchanged into every ring.
        qnum: Number of rings to produce; values below 1 give an empty list.
        flag_const_width: Select constant-width (True) or scaled-width (False)
            rings.

    Returns:
        One ``(qlow, qhigh, pmin, pmax)`` tuple per ring, in increasing order
        of ``n``.
    """
    lower, upper = (qmax, qmin) if qmin > qmax else (qmin, qmax)
    center = (lower + upper) / 2.0
    half_width = (upper - lower) / 2.0

    def ring(order: int) -> tuple[float, float, float, float]:
        # Bounds of the ring with 1-based index `order`.
        if flag_const_width:
            return (
                center * order - half_width,
                center * order + half_width,
                pmin,
                pmax,
            )
        return (
            (center - half_width) * order,
            (center + half_width) * order,
            pmin,
            pmax,
        )

    return [ring(order) for order in range(1, qnum + 1)]
[docs]
class MaskBase:
"""Base class for all mask types."""
[docs]
def __init__(self, shape: tuple[int, int] = (512, 1024)) -> None:
"""Initialize mask with given detector shape.
Args:
shape: Detector dimensions as (height, width)
"""
self.shape = shape
self.zero_loc: np.ndarray | None = None
self.mtype = "base"
self.qrings: list = []
[docs]
def evaluate(self, **kwargs: Any) -> None:
"""Evaluate mask (base method)."""
[docs]
def describe(self) -> str:
"""Return description of mask statistics."""
if self.zero_loc is None:
return "Mask is not initialized"
bad_num = len(self.zero_loc[0])
total_num = self.shape[0] * self.shape[1]
ratio = bad_num / total_num * 100.0
return f"{self.mtype}: bad_pixel: {bad_num}/{ratio:0.3f}%"
[docs]
def get_mask(self) -> np.ndarray:
"""Return the mask as a 2D boolean array.
Returns:
Boolean array where True = valid pixel, False = masked pixel
"""
mask = np.ones(self.shape, dtype=bool)
if self.zero_loc is not None:
mask[tuple(self.zero_loc)] = 0
return mask
[docs]
def combine_mask(self, mask: np.ndarray | None) -> np.ndarray:
"""Combine this mask with another mask.
Args:
mask: Existing mask to combine with, or None
Returns:
Combined mask array
"""
if logger.isEnabledFor(logging.DEBUG):
zero_count = 0 if self.zero_loc is None else self.zero_loc.shape[1]
logger.debug(f"combine_mask: {self.mtype} with {zero_count} masked pixels")
if self.zero_loc is not None:
if mask is None:
mask = self.get_mask()
else:
mask[tuple(self.zero_loc)] = 0
return mask if mask is not None else self.get_mask()
[docs]
class MaskList(MaskBase):
    """Mask described by an explicit list of pixel coordinates."""

    def __init__(self, shape: tuple[int, int] = (512, 1024)) -> None:
        """Create an empty coordinate-list mask.

        Args:
            shape: Detector dimensions as (height, width)
        """
        super().__init__(shape=shape)
        self.xylist: np.ndarray | None = None
        self.mtype = "list"

    def append_zero_pt(self, row: int, col: int) -> None:
        """Add one pixel to the set of masked coordinates.

        Args:
            row: Row index
            col: Column index
        """
        point = np.array([[row], [col]])
        if self.zero_loc is None:
            # First point: it becomes the whole coordinate array.
            self.zero_loc = point
            return
        self.zero_loc = np.append(self.zero_loc, point, axis=1)

    def evaluate(
        self,
        zero_loc: np.ndarray | None = None,
        **kwargs: Any,
    ) -> None:
        """Replace the stored coordinates with the given array.

        Args:
            zero_loc: Array of shape (2, N) with [rows, cols] of masked pixels;
                None clears the mask. Extra keyword arguments are ignored.
        """
        self.zero_loc = zero_loc
[docs]
class MaskFile(MaskBase):
    """Mask loaded from a dataset in an HDF5 file."""

    def __init__(
        self,
        shape: tuple[int, int] = (512, 1024),
        fname: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Create an empty file-backed mask.

        Args:
            shape: Detector dimensions as (height, width)
            fname: Accepted for call compatibility; not used here. The file to
                read is passed to :meth:`evaluate`.
            **kwargs: Ignored.
        """
        super().__init__(shape=shape)
        self.mtype = "file"

    def evaluate(
        self,
        fname: str | None = None,
        key: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Load the mask from an HDF5 file.

        On any failure (missing file, unsupported extension, unreadable
        dataset, shape mismatch) the problem is logged where applicable and
        ``zero_loc`` is left as None; no exception is propagated for these
        cases.

        Args:
            fname: Path to HDF5 file (extension ``.hdf``, ``.h5`` or ``.hdf5``,
                matched case-insensitively)
            key: Dataset key within the HDF5 file containing the mask
            **kwargs: Ignored.
        """
        # Start from "no mask" so every early return below leaves a clean state.
        self.zero_loc = None
        if fname is None or not os.path.isfile(fname):
            return

        ext = os.path.splitext(fname)[1].lower()
        if ext not in (".hdf", ".h5", ".hdf5"):
            logger.error(f"MaskFile only supports HDF5 files. Found: {fname}")
            return

        try:
            with h5py.File(fname, "r") as f:
                mask = np.asarray(f[key][()])
        except Exception:
            # Best-effort load: keep going without a mask, but record the
            # traceback so the cause (bad key, corrupt file, ...) is visible.
            logger.exception(f"Cannot read HDF file: {fname}, key: {key}")
            return

        expected = tuple(self.shape)
        # Handle transposed masks. Only arrays with at least two axes can be
        # swapped; scalar or 1-D datasets fall through to the mismatch error
        # below instead of raising from np.swapaxes.
        if mask.shape != expected and mask.ndim >= 2:
            mask = np.swapaxes(mask, 0, 1)
        if mask.shape != expected:
            logger.error(
                f"Mask shape {mask.shape} doesn't match detector shape {self.shape}"
            )
            return

        # Masked pixels have value <= 0
        self.zero_loc = np.array(np.nonzero(mask <= 0))
[docs]
class MaskThreshold(MaskBase):
    """Mask pixels whose intensity falls outside a threshold window."""

    def __init__(self, shape: tuple[int, int] = (512, 1024)) -> None:
        """Create an empty threshold mask.

        Args:
            shape: Detector dimensions as (height, width)
        """
        super().__init__(shape=shape)
        self.mtype = "threshold"

    def evaluate(
        self,
        saxs_lin: np.ndarray | None = None,
        low: float = 0,
        high: float = 1e8,
        low_enable: bool = True,
        high_enable: bool = True,
        **kwargs: Any,
    ) -> None:
        """Mask pixels outside the enabled intensity bounds.

        A pixel is kept when it satisfies every enabled test: ``value >= low``
        for the lower bound and ``value < high`` for the upper bound. All
        other pixels are recorded in ``zero_loc``. The comparisons run on the
        array type returned by ``get_backend()``.

        Args:
            saxs_lin: 2D intensity array; None clears the mask
            low: Lower threshold value (inclusive)
            high: Upper threshold value (exclusive)
            low_enable: Whether to apply lower threshold
            high_enable: Whether to apply upper threshold
            **kwargs: Ignored.
        """
        if saxs_lin is None:
            self.zero_loc = None
            return

        backend = get_backend()
        intensity: Any = backend.array(saxs_lin)

        # Collect only the tests that are switched on.
        tests = []
        if low_enable:
            tests.append(intensity >= low)
        if high_enable:
            tests.append(intensity < high)

        keep: Any = backend.ones_like(intensity, dtype=bool)
        for passed in tests:
            keep = keep * passed

        # Convert to NumPy at I/O boundary
        rejected = ensure_numpy(backend.logical_not(keep))
        self.zero_loc = np.array(np.nonzero(rejected))
[docs]
class MaskParameter(MaskBase):
    """Mask built from range constraints on Q-map arrays (e.g., Q-rings)."""

    def __init__(self, shape: tuple[int, int] = (512, 1024)) -> None:
        """Create an empty parameter mask.

        Args:
            shape: Detector dimensions as (height, width)
        """
        super().__init__(shape=shape)
        self.constraints: list = []
        self.mtype = "parameter"

    def evaluate(
        self,
        qmap: dict[str, np.ndarray] | None = None,
        constraints: list[tuple[str, str, str, float, float]] | None = None,
        **kwargs: Any,
    ) -> None:
        """Mask every pixel that the combined constraints do not select.

        Constraints are folded in order into a selection that starts as
        all-True; each one tests ``vbeg <= map <= vend`` and is merged with
        the running selection by its logic operator. The arithmetic runs on
        the array type returned by ``get_backend()``.

        NOTE(review): because the selection starts all-True, an "OR"
        constraint that is not preceded by an "AND" leaves it all-True.
        A logic value other than "AND"/"OR" is skipped without any message.

        Args:
            qmap: Dictionary of Q-map arrays; None clears the mask
            constraints: List of (map_name, logic, unit, vbeg, vend) tuples;
                None clears the mask
                - map_name: Key in qmap dict (e.g., "q", "phi")
                - logic: "AND" or "OR"
                - unit: Unit string (e.g., "deg")
                - vbeg, vend: Value range bounds (both inclusive)
            **kwargs: Ignored.
        """
        if qmap is None or constraints is None:
            self.zero_loc = None
            return

        backend = get_backend()
        selected: Any = backend.ones(self.shape, dtype=bool)

        for map_name, logic, unit, vbeg, vend in constraints:
            values: Any = backend.array(qmap[map_name])

            # Handle periodicity of angular coordinates
            if unit == "deg" and map_name in ("phi", "chi", "alpha"):
                values = backend.copy(values)
                range_crosses_min = (
                    vbeg <= ANGLE_MIN_DEG
                    and ANGLE_MIN_DEG <= vend
                    and vend <= ANGLE_MAX_DEG
                )
                range_crosses_max = (
                    vend > ANGLE_MAX_DEG
                    and ANGLE_MIN_DEG <= vbeg
                    and vbeg <= ANGLE_MAX_DEG
                )
                # The two cases are mutually exclusive (vend <= max vs.
                # vend > max), so at most one shift is applied.
                if range_crosses_min:
                    # Values above the range are moved down by one turn.
                    values = backend.where(values > vend, values - 360.0, values)
                elif range_crosses_max:
                    # Values below the range are moved up by one turn.
                    values = backend.where(values < vbeg, values + 360.0, values)

            in_range = (values >= vbeg) * (values <= vend)
            if logic == "AND":
                selected = backend.logical_and(selected, in_range)
            elif logic == "OR":
                selected = backend.logical_or(selected, in_range)

        # Convert to NumPy at I/O boundary
        selected_np = ensure_numpy(selected)
        self.zero_loc = np.array(np.nonzero(~selected_np))
[docs]
class MaskArray(MaskBase):
    """Mask taken directly from a boolean/integer array."""

    def __init__(self, shape: tuple[int, int] = (512, 1024)) -> None:
        """Create an empty array-backed mask.

        Args:
            shape: Detector dimensions as (height, width)
        """
        super().__init__(shape=shape)
        self.mtype = "array"

    def evaluate(self, arr: np.ndarray | None = None, **kwargs: Any) -> None:
        """Record the nonzero entries of ``arr`` as masked pixels.

        Args:
            arr: Boolean or integer array where nonzero = masked. When None,
                the previously stored mask is left untouched.
            **kwargs: Ignored.
        """
        if arr is None:
            return
        self.zero_loc = np.array(np.nonzero(arr))
[docs]
class MaskAssemble:
"""Manager for combining multiple mask types with undo/redo support.
Maintains a history of mask states enabling undo, redo, and reset
operations. Individual mask workers (threshold, file, draw,
parameter, etc.) are evaluated independently and combined via
logical AND.
Attributes:
shape: Detector dimensions as ``(height, width)``.
workers: Dictionary of named mask worker instances.
mask_record: List of historical mask states.
mask_ptr: Index into *mask_record* for the current state.
"""
[docs]
def __init__(
self,
shape: tuple[int, int] = (128, 128),
saxs_lin: np.ndarray | None = None,
qmap: dict[str, np.ndarray] | None = None,
max_history: int = 50,
) -> None:
"""Initialize mask assembly with detector shape and optional data.
Args:
shape: Detector dimensions as (height, width)
saxs_lin: 2D intensity array for threshold masking
qmap: Q-map dictionary for parameter masking
max_history: Maximum number of undo steps to retain. Older entries are
dropped when the limit is reached to prevent unbounded memory growth
(a 2048×2048 boolean mask costs ~4 MB per step).
"""
self.workers: dict[str, Any] = {
"mask_blemish": MaskFile(shape),
"mask_file": MaskFile(shape),
"mask_threshold": MaskThreshold(shape),
"mask_list": MaskList(shape),
"mask_draw": MaskArray(shape),
"mask_outlier": MaskList(shape),
"mask_parameter": MaskParameter(shape),
}
self.shape = shape
self.saxs_lin = saxs_lin
self.qmap = qmap
self.max_history = max_history
# Initialize mask history for undo/redo
initial_mask = (
np.ones_like(saxs_lin, dtype=bool)
if saxs_lin is not None
else np.ones(shape, dtype=bool)
)
self.mask_record: list[np.ndarray] = [initial_mask]
self.mask_ptr = 0
# 0: no mask; 1: apply the default mask
self.mask_ptr_min = 0
[docs]
def update_qmap(self, qmap_all: dict[str, np.ndarray]) -> None:
"""Update the Q-map used for parameter masking.
Args:
qmap_all: Dictionary of Q-map arrays
"""
self.qmap = qmap_all
[docs]
def apply(self, target: str | None) -> np.ndarray:
"""Apply a mask type and add to history.
Args:
target: Mask type key (e.g., "mask_draw", "mask_threshold")
or None to return current mask
Returns:
Current combined mask array
"""
if target is None:
return self.get_mask()
new_layer = self.get_one_mask(target)
# Use read-only view for the current mask — avoids a 4MB memcpy.
# np.logical_and always allocates a fresh output array, so `combined`
# is already an independent copy that is safe to push to history.
current_ref = self._get_mask_ref()
combined = np.logical_and(current_ref, new_layer)
# Short-circuit change detection: compare nonzero counts first (integer
# comparison) before falling back to the full O(H*W) element scan.
last = self.mask_record[self.mask_ptr]
changed = (combined.sum() != last.sum()) or not np.array_equal(last, combined)
if changed:
# Remove any redo states beyond current pointer
while len(self.mask_record) > self.mask_ptr + 1:
self.mask_record.pop()
# combined is freshly allocated — push directly, no extra copy needed.
self.mask_record.append(combined)
self.mask_ptr += 1
# Enforce bounded history: drop the oldest entry (beyond mask_ptr_min
# anchor) when the limit is exceeded to prevent unbounded memory growth.
if len(self.mask_record) > self.max_history + self.mask_ptr_min:
self.mask_record.pop(self.mask_ptr_min)
self.mask_ptr -= 1
return combined
[docs]
def evaluate(self, target: str, **kwargs: Any) -> str:
"""Evaluate a mask type with given parameters.
Args:
target: Mask type key
**kwargs: Parameters for the mask evaluation
Returns:
Description string from the mask worker
"""
if target == "mask_threshold":
self.workers[target].evaluate(self.saxs_lin, **kwargs)
elif target == "mask_parameter":
self.workers[target].evaluate(qmap=self.qmap, **kwargs)
else:
self.workers[target].evaluate(**kwargs)
return self.workers[target].describe()
[docs]
def redo_undo(self, action: str = "redo") -> None:
"""Navigate mask history.
Args:
action: One of "undo", "redo", or "reset"
"""
if action == "undo":
if self.mask_ptr > self.mask_ptr_min:
self.mask_ptr -= 1
elif action == "redo":
if self.mask_ptr < len(self.mask_record) - 1:
self.mask_ptr += 1
elif action == "reset":
# Keep the default mask if one exists
while len(self.mask_record) > 1 + self.mask_ptr_min:
self.mask_record.pop()
self.mask_ptr = self.mask_ptr_min
[docs]
def get_one_mask(self, target: str) -> np.ndarray:
"""Get mask from a single worker.
Args:
target: Mask type key
Returns:
Mask array from the specified worker
"""
return self.workers[target].get_mask()
def _get_mask_ref(self) -> np.ndarray:
"""Return a read-only view of the current mask without copying.
For internal use only. Callers must not mutate the returned array.
The view is ~700x faster than get_mask() for large detectors because
no memory allocation occurs.
"""
arr = self.mask_record[self.mask_ptr]
arr.flags.writeable = False
return arr
[docs]
def get_mask(self) -> np.ndarray:
"""Get the current combined mask.
Returns:
A copy of the current mask from history. Safe for external callers
that may mutate the result. Use _get_mask_ref() internally where
the result is only read.
"""
return self.mask_record[self.mask_ptr].copy()
@property
def blemish(self) -> np.ndarray:
"""Get the blemish mask (pixels marked as bad in blemish file).
Returns:
Boolean array where True = valid pixel, False = blemish pixel
"""
return self.workers["mask_blemish"].get_mask()