Source code for xpcsviewer.backends._conversions
"""Array conversion utilities for I/O boundaries.
This module provides functions for converting between NumPy arrays and
backend-specific arrays at I/O boundaries (file I/O, visualization, etc.).
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import numpy as np
if TYPE_CHECKING:
from numpy.typing import ArrayLike
from ._base import BackendProtocol
[docs]
def ensure_numpy(array: ArrayLike) -> np.ndarray:
    """Return ``array`` as a writable NumPy array.

    Intended for boundaries that require plain NumPy data, such as
    HDF5 file I/O (h5py), PyQtGraph, Matplotlib and Pandas.

    Parameters
    ----------
    array : array-like
        NumPy array, JAX array, list, or any other array-like object.

    Returns
    -------
    np.ndarray
        A writable NumPy array. A NumPy input that is already writable
        is handed back unchanged (same object, no copy); read-only
        inputs and JAX arrays are copied.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> x = jnp.array([1, 2, 3])
    >>> np_x = ensure_numpy(x)
    >>> isinstance(np_x, np.ndarray)
    True
    >>> # Plain Python sequences are accepted as well
    >>> ensure_numpy([1, 2, 3])
    array([1, 2, 3])
    """
    # NumPy input: only pay for a copy when the buffer is read-only.
    if isinstance(array, np.ndarray):
        return array if array.flags.writeable else np.array(array)

    # JAX is optional; when it cannot be imported the JAX check is skipped.
    try:
        import jax.numpy as jnp
    except ImportError:
        jnp = None
    if jnp is not None and isinstance(array, jnp.ndarray):
        # np.array (rather than np.asarray) so the result is a fresh copy.
        return np.array(array)

    # Objects without the __array__ protocol (lists, scalars, ...) go
    # straight through np.array.
    if not hasattr(array, "__array__"):
        return np.array(array)

    # __array__ providers: convert cheaply first, then copy only if the
    # converted array turned out to be read-only.
    converted = np.asarray(array)
    return converted if converted.flags.writeable else np.array(converted)
[docs]
def ensure_backend_array(
    array: ArrayLike, backend: BackendProtocol | None = None
) -> Any:
    """Convert external data into a backend's array type.

    The input is first normalised with ``ensure_numpy`` and the result
    is then passed to the backend's ``from_numpy`` method.

    Parameters
    ----------
    array : array-like
        NumPy array, JAX array, list, or any array-like object.
    backend : BackendProtocol, optional
        Target backend. When omitted (``None``), the backend returned
        by ``get_backend()`` is used.

    Returns
    -------
    ArrayType
        Whatever ``backend.from_numpy`` returns for the converted data.

    Examples
    --------
    >>> from xpcsviewer.backends import get_backend, ensure_backend_array
    >>> import numpy as np
    >>> x = np.array([1, 2, 3])
    >>> backend = get_backend()
    >>> bx = ensure_backend_array(x, backend)
    """
    target = backend
    if target is None:
        # Imported lazily, only when no backend was supplied.
        from . import get_backend

        target = get_backend()

    to_native = target.from_numpy
    host_array = ensure_numpy(array)
    return to_native(host_array)
[docs]
def is_jax_array(array: Any) -> bool:
    """Report whether ``array`` is an instance of ``jax.numpy.ndarray``.

    Parameters
    ----------
    array : Any
        Object to inspect.

    Returns
    -------
    bool
        True when ``array`` is a JAX array; False otherwise, including
        when JAX cannot be imported.
    """
    try:
        import jax.numpy as jnp
    except ImportError:
        # Without JAX available the answer is always False.
        return False
    return isinstance(array, jnp.ndarray)
[docs]
def is_numpy_array(array: Any) -> bool:
    """Report whether ``array`` is an instance of ``np.ndarray``.

    Parameters
    ----------
    array : Any
        Object to inspect.

    Returns
    -------
    bool
        True when ``array`` is a ``np.ndarray`` (subclasses included,
        as with any ``isinstance`` check); False otherwise.
    """
    is_ndarray = isinstance(array, np.ndarray)
    return is_ndarray
[docs]
def get_array_backend(array: Any) -> str:
    """Name the backend that ``array`` belongs to.

    The NumPy check runs first; the JAX check is only consulted when
    the NumPy check fails.

    Parameters
    ----------
    array : Any
        Object to classify.

    Returns
    -------
    str
        ``'numpy'``, ``'jax'``, or ``'unknown'`` when neither check
        matches.
    """
    # Ordered (label, predicate) pairs; the first matching predicate wins.
    checks = (
        ("numpy", is_numpy_array),
        ("jax", is_jax_array),
    )
    for label, matches in checks:
        if matches(array):
            return label
    return "unknown"
[docs]
def arrays_compatible(a: Any, b: Any) -> bool:
    """Report whether two objects are classified under the same backend.

    Each argument is classified with ``get_array_backend`` and the two
    resulting labels are compared for equality.

    Parameters
    ----------
    a, b : Any
        Objects to compare.

    Returns
    -------
    bool
        True when both labels are equal. Note that two objects that
        are both classified as ``'unknown'`` therefore also compare
        as compatible.
    """
    backend_of_a = get_array_backend(a)
    backend_of_b = get_array_backend(b)
    return backend_of_a == backend_of_b