"""Core kernel for SimpleMask operations.
This module provides the SimpleMaskKernel class that handles all mask operations,
Q-map computation, partition generation, and ROI management.
Adapted from pySimpleMask for integration with XPCS Viewer.
"""
import logging
from typing import Any, Literal
import h5py
import numpy as np
import pyqtgraph as pg
from pyqtgraph.Qt import QtCore
from xpcsviewer.backends._conversions import ensure_numpy
from xpcsviewer.fileIO.qmap_utils import Q_UNIT_DISPLAY
from xpcsviewer.simplemask.area_mask import MaskAssemble
from xpcsviewer.simplemask.pyqtgraph_mod import ImageViewROI, LineROI
from xpcsviewer.simplemask.qmap import compute_qmap
from xpcsviewer.simplemask.utils import (
check_consistency,
combine_partitions,
generate_partition,
hash_numpy_dict,
optimize_integer_array,
)
pg.setConfigOptions(imageAxisOrder="row-major")
logger = logging.getLogger(__name__)
# Compression threshold for HDF5 datasets (elements)
HDF5_COMPRESSION_THRESHOLD = 1024
# Version for partition file metadata
__version__ = "1.0.0"
[docs]
class SimpleMaskKernel:
"""Core kernel for SimpleMask operations.
This class manages mask creation, editing, Q-map computation,
and partition generation for XPCS analysis.
Attributes:
shape: Detector shape as (height, width)
detector_image: 2D array of detector data
mask: Current combined mask array
qmap: Dictionary of Q-map arrays
qmap_unit: Dictionary of Q-map units
metadata: Geometry metadata dictionary
"""
[docs]
def __init__(
self,
pg_hdl: ImageViewROI | None = None,
infobar: Any | None = None,
):
"""Initialize SimpleMask kernel.
Args:
pg_hdl: PyQtGraph ImageView handle for display
infobar: Status bar widget for messages
"""
self.detector_image: np.ndarray | None = None
self.shape: tuple[int, int] | None = None
self.qmap: dict[str, np.ndarray] | None = None
self.qmap_unit: dict[str, str] | None = None
self.mask: np.ndarray | None = None
self.mask_kernel: MaskAssemble | None = None
self.new_partition: dict | None = None
self.metadata: dict[str, Any] = {}
self.hdl = pg_hdl
self.infobar = infobar
if self.hdl is not None:
self.hdl.scene.sigMouseMoved.connect(self.show_location)
[docs]
def is_ready(self) -> bool:
"""Check if kernel has data loaded.
Returns:
True if detector image is loaded
"""
return self.detector_image is not None
[docs]
def read_data(
self,
detector_image: np.ndarray,
metadata: dict[str, Any],
) -> bool:
"""Load detector data and initialize mask.
Args:
detector_image: 2D array of detector counts
metadata: Dictionary with geometry parameters:
- bcx: Beam center X (column)
- bcy: Beam center Y (row)
- det_dist: Detector distance (mm)
- pix_dim: Pixel size (mm)
- energy: X-ray energy (keV)
- shape: Detector shape (height, width)
Returns:
True if successful
"""
self.detector_image = detector_image.copy()
self.shape = detector_image.shape
self.metadata = metadata.copy()
self.metadata["shape"] = self.shape
# Fill in defaults for missing geometry parameters so Q-map
# computation can proceed even when HDF5 metadata is incomplete.
# These match the spinbox defaults in simplemask_ui.py.
_geometry_defaults = {
"pix_dim": 0.075,
"energy": 10.0,
"det_dist": 5000.0,
"bcx": self.shape[1] / 2.0,
"bcy": self.shape[0] / 2.0,
}
for key, default in _geometry_defaults.items():
if self.metadata.get(key) is None:
logger.info(
f"Geometry parameter '{key}' missing from HDF5 metadata, "
f"using default: {default}"
)
self.metadata[key] = default
# Initialize mask
self.mask = np.ones(self.shape, dtype=bool)
# Compute Q-map
stype = self.metadata.get("stype", "Transmission")
self.qmap, self.qmap_unit = compute_qmap(stype, self.metadata)
# Initialize mask assembly
self.mask_kernel = MaskAssemble(self.shape, self.detector_image)
self.mask_kernel.update_qmap(self.qmap)
return True
[docs]
def compute_qmap(self) -> tuple[dict[str, np.ndarray], dict[str, str]]:
"""Recompute Q-map from current metadata.
Returns:
Tuple of (qmap_dict, units_dict)
"""
stype = self.metadata.get("stype", "Transmission")
logger.debug(f"compute_qmap: stype={stype}, shape={self.shape}")
self.qmap, self.qmap_unit = compute_qmap(stype, self.metadata)
if self.mask_kernel is not None:
self.mask_kernel.update_qmap(self.qmap)
if logger.isEnabledFor(logging.DEBUG) and self.qmap:
sqmap = self.qmap.get("sqmap")
if sqmap is not None:
logger.debug(f"compute_qmap: result sqmap shape={sqmap.shape}")
return self.qmap, self.qmap_unit
[docs]
def mask_evaluate(self, target: str, **kwargs) -> str:
"""Evaluate a mask type without applying.
Args:
target: Mask type key
**kwargs: Parameters for mask evaluation
Returns:
Description string from the mask worker
"""
if self.mask_kernel is None:
return "Mask kernel not initialized"
return self.mask_kernel.evaluate(target, **kwargs)
[docs]
def mask_action(self, action: Literal["undo", "redo", "reset"] = "undo") -> None:
"""Execute undo/redo/reset on mask history.
Args:
action: One of "undo", "redo", or "reset"
"""
if self.mask_kernel is None:
return
self.mask_kernel.redo_undo(action=action)
self.mask_apply()
[docs]
def mask_apply(self, target: str | None = None) -> np.ndarray:
"""Apply a mask type and update current mask.
Args:
target: Mask type to apply, or None to get current mask
Returns:
Current combined mask array
"""
if self.mask_kernel is None or self.mask is None:
return self.mask if self.mask is not None else np.zeros((0, 0), dtype=bool)
self.mask = self.mask_kernel.apply(target)
return self.mask
[docs]
def save_mask(self, save_name: str) -> None:
"""Save mask to HDF5 file.
Args:
save_name: Output file path (.h5/.hdf5)
"""
if self.mask is None:
logger.warning("No mask to save")
return
mask = self.mask.astype(np.uint8)
try:
with h5py.File(save_name, "w") as hf:
hf.create_dataset("mask", data=mask, compression="lzf")
hf.attrs["shape"] = self.shape
hf.attrs["version"] = __version__
except OSError as e:
logger.error(f"Failed to save mask to {save_name}: {e}")
raise
[docs]
def load_mask(self, fname: str, key: str = "mask") -> bool:
"""Load mask from HDF5 file.
Args:
fname: Path to HDF5 file
key: Dataset key for mask data
Returns:
True if successful
"""
if self.mask_kernel is None:
return False
self.mask_kernel.evaluate("mask_file", fname=fname, key=key)
self.mask_apply("mask_file")
return True
[docs]
def save_partition(self, save_fname: str, root: str = "/qmap") -> None:
"""Save partition to HDF5 file.
Args:
save_fname: Output file path (.h5/.hdf5)
root: HDF5 group path for partition data
"""
if self.new_partition is None:
logger.warning("No partition to save")
return
# Optimize integer arrays
for key, val in self.new_partition.items():
self.new_partition[key] = optimize_integer_array(val)
hash_val = hash_numpy_dict(self.new_partition)
logger.info(f"Hash value of the partition: {hash_val}")
def optimize_save(group_handle, key, val):
compression = (
"lzf"
if isinstance(val, np.ndarray) and val.size > HDF5_COMPRESSION_THRESHOLD
else None
)
return group_handle.create_dataset(key, data=val, compression=compression)
with h5py.File(save_fname, "w") as hf:
if root in hf:
del hf[root]
group_handle = hf.create_group(root)
for key, val in self.new_partition.items():
dset = optimize_save(group_handle, key, val)
if "_v_list_dim" in key:
dim = int(key[-1])
dset.attrs["unit"] = self.new_partition["map_units"][dim]
dset.attrs["name"] = self.new_partition["map_names"][dim]
dset.attrs["size"] = val.size
group_handle.attrs["hash"] = hash_val
group_handle.attrs["version"] = __version__
[docs]
def show_location(self, pos) -> None:
"""Display pixel coordinates and Q values at mouse position.
Args:
pos: Mouse scene position
"""
if self.hdl is None or self.shape is None:
return
if not self.hdl.scene.itemsBoundingRect().contains(pos):
return
mouse_point = self.hdl.getView().mapSceneToView(pos)
col = int(mouse_point.x())
row = int(mouse_point.y())
if 0 <= row < self.shape[0] and 0 <= col < self.shape[1]:
msg = f"x={col}, y={row}"
q_val = (
self.qmap.get("q", self.qmap.get("sqmap"))
if self.qmap is not None
else None
)
if self.qmap is not None and q_val is not None:
msg += f", q={q_val[row, col]:.4f} {Q_UNIT_DISPLAY}"
if self.detector_image is not None:
intensity = self.detector_image[row, col]
msg += f", I={intensity:.1f}"
if self.infobar is not None:
self.infobar.clearMessage()
self.infobar.showMessage(msg)
[docs]
def show_saxs(
self,
cmap: str = "jet",
log: bool = True,
plot_center: bool = True,
) -> None:
"""Display detector image with optional beam center marker.
Args:
cmap: Matplotlib colormap name
log: Whether to use log scale (unused, kept for API compatibility)
plot_center: Whether to display beam center marker
"""
if self.detector_image is None or self.hdl is None:
return
self.hdl.clear()
center = self.get_center()
# Create display data (original + mask overlay)
# Convert to NumPy to guard against JAX arrays at PyQtGraph boundary (BUG-027)
data = self.detector_image.copy()
self.hdl.setImage(ensure_numpy(data))
self.hdl.adjust_viewbox()
self.hdl.set_colormap(cmap)
if plot_center and center[0] is not None:
t = pg.ScatterPlotItem()
t.addPoints(x=[center[0]], y=[center[1]], symbol="+", size=15)
self.hdl.add_item(t, label="center")
[docs]
def apply_drawing(self) -> np.ndarray:
"""Apply all pending drawing ROIs to create a mask.
Returns:
Mask array from drawing operations
"""
if self.detector_image is None or self.hdl is None or self.shape is None:
return np.ones(self.shape, dtype=bool) if self.shape else np.array([])
shape = self.shape
ones = np.ones((shape[0] + 1, shape[1] + 1), dtype=bool)
mask_e = np.zeros_like(ones, dtype=bool)
mask_i = np.zeros_like(mask_e)
for k, x in self.hdl.roi_items.items():
if not k.startswith("roi_"):
continue
mask_temp = np.zeros_like(ones, dtype=bool)
sl, _ = x.getArraySlice(self.detector_image, self.hdl.imageItem)
y = x.getArrayRegion(ones, self.hdl.imageItem)
nz_idx = np.nonzero(y)
if len(nz_idx[0]) == 0:
continue
h_beg = np.min(nz_idx[1])
h_end = np.max(nz_idx[1]) + 1
v_beg = np.min(nz_idx[0])
v_end = np.max(nz_idx[0]) + 1
sl_v = slice(sl[0].start, sl[0].start + v_end - v_beg)
sl_h = slice(sl[1].start, sl[1].start + h_end - h_beg)
mask_temp[sl_v, sl_h] = y[v_beg:v_end, h_beg:h_end]
if hasattr(x, "sl_mode"):
if x.sl_mode == "exclusive":
mask_e[mask_temp] = 1
elif x.sl_mode == "inclusive":
mask_i[mask_temp] = 1
self.hdl.remove_rois(filter_str="roi_")
if np.sum(mask_i) == 0:
mask_i = np.ones_like(mask_e, dtype=bool)
mask_p = np.logical_not(mask_e) * mask_i
mask_p = mask_p[:-1, :-1]
return mask_p
[docs]
def add_drawing(
self,
sl_type: Literal[
"Rectangle", "Circle", "Ellipse", "Polygon", "Line"
] = "Polygon",
sl_mode: Literal["exclusive", "inclusive"] = "exclusive",
num_edges: int | None = None,
radius: float = 60,
color: str = "r",
width: int = 3,
second_point: tuple[float, float] | None = None,
label: str | None = None,
movable: bool = True,
) -> pg.ROI | None:
"""Add a drawing ROI to the view.
Args:
sl_type: Shape type - Rectangle, Circle, Ellipse, Polygon, or Line
sl_mode: "exclusive" masks pixels, "inclusive" preserves pixels
num_edges: Number of edges for Polygon (random 6-10 if None)
radius: Default radius for Circle
color: Pen color
width: Pen width
second_point: End point for Line ROI
label: ROI label (auto-generated if None)
movable: Whether ROI can be moved
Returns:
Created ROI object, or None if creation failed
"""
if self.hdl is None or self.shape is None:
return None
if label is not None and label in self.hdl.roi_items:
self.hdl.remove_item(label)
cen = self.get_center()
if cen[0] is None or cen[1] is None or cen[0] < 0 or cen[1] < 0:
cen = (self.shape[1] // 2, self.shape[0] // 2)
elif cen[0] > self.shape[1] or cen[1] > self.shape[0]:
logger.warning("Beam center out of range, using image center")
cen_x: float = float(self.shape[1] // 2)
cen_y: float = float(self.shape[0] // 2)
cen = (cen_x, cen_y)
if sl_mode == "inclusive":
pen = pg.mkPen(color=color, width=width, style=QtCore.Qt.DotLine)
else:
pen = pg.mkPen(color=color, width=width)
handle_pen = pg.mkPen(color=color, width=width)
kwargs = {
"pen": pen,
"removable": True,
"hoverPen": pen,
"handlePen": handle_pen,
"movable": movable,
}
new_roi: pg.ROI | None = None
if sl_type == "Ellipse":
new_roi = pg.EllipseROI(cen, [60, 80], **kwargs)
# Midpoint handles (4)
new_roi.addScaleHandle([0.5, 0], [0.5, 1])
new_roi.addScaleHandle([0.5, 1], [0.5, 0])
new_roi.addScaleHandle([0, 0.5], [1, 0.5])
new_roi.addScaleHandle([1, 0.5], [0, 0.5])
# Corner handles (4) - positions adjusted for ellipse geometry
new_roi.addScaleHandle([0.1464, 0.1464], [1, 1])
new_roi.addScaleHandle([0.1464, 0.8536], [1, 0])
new_roi.addScaleHandle([0.8536, 0.1464], [0, 1])
new_roi.addScaleHandle([0.8536, 0.8536], [0, 0])
elif sl_type == "Circle":
cx, cy = cen
radius = 10.0
if second_point is not None and cx is not None and cy is not None:
radius = np.sqrt(
(second_point[1] - cy) ** 2 + (second_point[0] - cx) ** 2
)
if cx is not None and cy is not None:
new_roi = pg.CircleROI(
pos=[cx - radius, cy - radius], radius=radius, **kwargs
)
# Add 2 opposite-side handles for uniform scaling
new_roi.addScaleHandle([0.5, 0], [0.5, 0.5])
new_roi.addScaleHandle([0.5, 1], [0.5, 0.5])
else:
return None
elif sl_type == "Polygon":
if num_edges is None:
num_edges = np.random.randint(6, 11)
offset = np.random.randint(0, 360)
theta = np.linspace(0, np.pi * 2, num_edges + 1) + np.deg2rad(offset)
x = radius * np.cos(theta) + (cen[0] if cen[0] is not None else 0)
y = radius * np.sin(theta) + (cen[1] if cen[1] is not None else 0)
pts = np.vstack([x, y]).T
new_roi = pg.PolyLineROI(pts, closed=True, **kwargs)
elif sl_type == "Rectangle":
new_roi = pg.RectROI(cen, [200, 150], **kwargs)
# Corner handles (4)
new_roi.addScaleHandle([0, 0], [1, 1]) # bottom-left
new_roi.addScaleHandle([0, 1], [1, 0]) # top-left
new_roi.addScaleHandle([1, 0], [0, 1]) # bottom-right
new_roi.addScaleHandle([1, 1], [0, 0]) # top-right
# Midpoint handles (4)
new_roi.addScaleHandle([0, 0.5], [1, 0.5]) # left-mid
new_roi.addScaleHandle([1, 0.5], [0, 0.5]) # right-mid
new_roi.addScaleHandle([0.5, 0], [0.5, 1]) # bottom-mid
new_roi.addScaleHandle([0.5, 1], [0.5, 0]) # top-mid
elif sl_type == "Line":
if second_point is None or cen[0] is None or cen[1] is None:
return None
line_width = kwargs.pop("width", 1) if "width" in kwargs else 1
# cast to float tuple for LineROI
cen_f = (float(cen[0]), float(cen[1]))
new_roi = LineROI(cen_f, second_point, line_width, **kwargs)
else:
raise TypeError(f"ROI type not implemented: {sl_type}")
if new_roi is not None:
new_roi.sl_mode = sl_mode
roi_key = self.hdl.add_item(new_roi, label)
new_roi.sigRemoveRequested.connect(lambda: self.remove_roi(roi_key))
return new_roi
[docs]
def remove_roi(self, roi_key: str) -> None:
"""Remove an ROI by key.
Args:
roi_key: Label of the ROI to remove
"""
if self.hdl is not None:
self.hdl.remove_item(roi_key)
[docs]
def compute_partition(
self,
mode: str = "q-phi",
dq_num: int = 10,
sq_num: int = 100,
dp_num: int = 36,
sp_num: int = 360,
style: str = "linear",
phi_offset: float = 0.0,
symmetry_fold: int = 1,
) -> dict | None:
"""Compute Q-space partition for XPCS analysis.
Args:
mode: Partition mode (e.g., "q-phi", "x-y")
dq_num: Number of dynamic Q bins
sq_num: Number of static Q bins
dp_num: Number of dynamic phi bins
sp_num: Number of static phi bins
style: Binning style - "linear" or "logarithmic"
phi_offset: Phi angle offset
symmetry_fold: Symmetry fold for phi
Returns:
Partition dictionary or None if no data
"""
if self.detector_image is None or self.qmap is None or self.mask is None:
return None
map_names = mode.split("-")
logger.info(f"Computing partition with mode {mode}: map_names {map_names}")
name0, name1 = map_names
# Generate dynamic partition
pack_dq = generate_partition(
name0, self.mask, self.qmap[name0], dq_num, style=style
)
pack_dp = generate_partition(
name1,
self.mask,
self.qmap[name1],
dp_num,
style=style,
phi_offset=phi_offset,
symmetry_fold=symmetry_fold,
)
dynamic_map = combine_partitions(pack_dq, pack_dp, prefix="dynamic")
# Generate static partition
pack_sq = generate_partition(
name0, self.mask, self.qmap[name0], sq_num, style=style
)
pack_sp = generate_partition(
name1,
self.mask,
self.qmap[name1],
sp_num,
style=style,
phi_offset=phi_offset,
symmetry_fold=symmetry_fold,
)
static_map = combine_partitions(pack_sq, pack_sp, prefix="static")
# Check consistency
d_map = ensure_numpy(dynamic_map["dynamic_roi_map"])
s_map = ensure_numpy(static_map["static_roi_map"])
flag_consistency = check_consistency(d_map, s_map, self.mask)
logger.info(f"dqmap/sqmap consistency check: {flag_consistency}")
center = self.get_center()
# Warn when using default physical constants
_defaults = {"pix_dim": 0.075, "energy": 10.0, "det_dist": 5000.0}
for _key, _default in _defaults.items():
if _key not in self.metadata or self.metadata[_key] is None:
logger.warning(
f"Using default {_key}={_default} — no value found in metadata"
)
partition = {
"beam_center_x": center[0],
"beam_center_y": center[1],
"pixel_size": self.metadata.get("pix_dim", 0.075),
"mask": self.mask,
"blemish": self.mask_kernel.blemish if self.mask_kernel else self.mask,
"energy": self.metadata.get("energy", 10.0),
"detector_distance": self.metadata.get("det_dist", 5000.0),
"map_names": list(map_names),
"map_units": [
self.qmap_unit[name0] if self.qmap_unit else "",
self.qmap_unit[name1] if self.qmap_unit else "",
],
"source_file": self.metadata.get("source_file", ""),
}
partition.update(dynamic_map)
partition.update(static_map)
self.new_partition = partition
return partition
[docs]
def update_parameters(self, new_metadata: dict[str, Any]) -> None:
"""Update geometry parameters and recompute Q-map.
Args:
new_metadata: Dictionary with updated geometry values
"""
self.metadata.update(new_metadata)
self.compute_qmap()
[docs]
def get_center(
self, mode: Literal["xy", "vh"] = "xy"
) -> tuple[float | None, float | None]:
"""Get beam center coordinates.
Args:
mode: "xy" for (x, y) = (column, row), "vh" for (vertical, horizontal)
Returns:
Beam center coordinates
"""
if not self.metadata:
return (None, None)
bcx = self.metadata.get("bcx")
bcy = self.metadata.get("bcy")
if bcx is None or bcy is None:
return (None, None)
if mode == "vh":
return (bcy, bcx)
return (bcx, bcy)