# Source code for xpcsviewer.utils.device

"""GPU detection and warning utilities (System CUDA version).

This module provides utilities for detecting GPU hardware, system CUDA installation,
and JAX backend configuration. It helps users identify when GPU acceleration is
available but not being used.

Example usage:
    >>> from xpcsviewer.utils.device import check_gpu_availability, get_device_info
    >>> check_gpu_availability()  # Prints warning if GPU available but not used
    >>> info = get_device_info()  # Get comprehensive device information
"""

from __future__ import annotations

import logging
import subprocess
from typing import TYPE_CHECKING

# Placeholder for imports needed only by static type checkers; currently empty.
if TYPE_CHECKING:
    pass

# Module-level logger shared by the detection helpers in this module.
_logger = logging.getLogger(__name__)


def get_system_cuda_version() -> tuple[str | None, int | None]:
    """Detect the system CUDA toolkit version by running ``nvcc --version``.

    Returns
    -------
    tuple[str | None, int | None]
        Tuple of (full_version, major_version), or (None, None) if ``nvcc``
        is not on PATH, times out, exits non-zero, or its output contains no
        ``release X.Y`` marker.
        Example: ("12.6", 12) or ("13.0", 13)

    Examples
    --------
    >>> version, major = get_system_cuda_version()
    >>> if major is not None:
    ...     print(f"CUDA {version} detected")
    """
    import re

    try:
        result = subprocess.run(
            ["nvcc", "--version"],
            capture_output=True,
            text=True,
            timeout=5,
        )
    except subprocess.TimeoutExpired:
        _logger.debug("nvcc timed out")
        return None, None
    except FileNotFoundError:
        _logger.debug("nvcc not found")
        return None, None
    except Exception as e:
        _logger.debug(f"CUDA detection failed: {e}")
        return None, None

    if result.returncode != 0:
        return None, None

    # nvcc prints e.g. "Cuda compilation tools, release 12.6, V12.6.77".
    # Match "release" case-insensitively and capture only the dotted number,
    # so the line that is detected and the text that is parsed always agree.
    match = re.search(r"release\s*([0-9]+(?:\.[0-9]+)*)", result.stdout, re.IGNORECASE)
    if match is None:
        _logger.debug("Failed to parse CUDA version: no 'release X.Y' in nvcc output")
        return None, None

    version = match.group(1)
    return version, int(version.split(".")[0])
def get_gpu_info() -> tuple[str | None, float | None]:
    """Detect the first GPU's name and SM (compute capability) version.

    Queries ``nvidia-smi --query-gpu=name,compute_cap --format=csv,noheader``
    and parses the first output line.

    Returns
    -------
    tuple[str | None, float | None]
        Tuple of (gpu_name, sm_version), or (None, None) if ``nvidia-smi`` is
        missing, times out, fails, or reports no parseable GPU. When the GPU
        name is reported but its compute capability cannot be parsed as a
        number, the result is (gpu_name, None).
        Example: ("NVIDIA GeForce RTX 4090", 8.9)

    Examples
    --------
    >>> name, sm = get_gpu_info()
    >>> if name is not None:
    ...     print(f"GPU: {name} (SM {sm})")
    """
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=name,compute_cap", "--format=csv,noheader"],
            capture_output=True,
            text=True,
            timeout=5,
        )
    except subprocess.TimeoutExpired:
        _logger.debug("nvidia-smi timed out")
        return None, None
    except FileNotFoundError:
        _logger.debug("nvidia-smi not found")
        return None, None
    except Exception as e:
        _logger.debug(f"GPU detection failed: {e}")
        return None, None

    lines = result.stdout.strip().splitlines() if result.returncode == 0 else []
    if not lines:
        return None, None

    # Only the first GPU is reported. Split on the LAST comma so a name that
    # itself contains a comma stays intact, and tolerate any spacing around it.
    name_part, sep, cap_part = lines[0].rpartition(",")
    gpu_name = name_part.strip()
    if not sep or not gpu_name:
        _logger.debug(f"Failed to parse GPU info: {lines[0]!r}")
        return None, None

    try:
        sm_version = float(cap_part)
    except ValueError:
        # e.g. "[N/A]": the GPU exists, only its compute capability is unknown.
        _logger.debug(f"Failed to parse GPU compute capability: {cap_part.strip()!r}")
        return gpu_name, None
    return gpu_name, sm_version
[docs] def check_plugin_conflicts() -> list[str]: """Check for known JAX CUDA plugin conflicts. Returns ------- list[str] List of issue descriptions (empty = no issues). Examples -------- >>> issues = check_plugin_conflicts() >>> for issue in issues: ... print(f"WARNING: {issue}") """ issues = [] try: import importlib.metadata as md jaxlib_v = md.version("jaxlib") cuda12 = cuda13 = None try: cuda12 = md.version("jax-cuda12-plugin") except md.PackageNotFoundError: pass try: cuda13 = md.version("jax-cuda13-plugin") except md.PackageNotFoundError: pass # Check for dual plugin conflict if cuda12 and cuda13: issues.append( f"Both cuda12 ({cuda12}) and cuda13 ({cuda13}) plugins installed. " "Only ONE can be active — this causes PJRT registration conflicts." ) # Check for version mismatch for name, version in [("cuda12", cuda12), ("cuda13", cuda13)]: if version and version != jaxlib_v: issues.append( f"jax-{name}-plugin {version} != jaxlib {jaxlib_v}. " "Plugin version must exactly match jaxlib." ) except Exception as e: _logger.debug(f"Plugin conflict check failed: {e}") return issues
def check_gpu_availability(warn: bool = True) -> bool:
    """Check if GPU is available but not being used by JAX.

    Prints a helpful warning if GPU hardware is detected but JAX is running
    in CPU-only mode. The system CUDA version (``nvcc``) is queried only when
    that warning is actually printed.

    Parameters
    ----------
    warn : bool, optional
        If True, print warning when GPU available but not used, by default True.

    Returns
    -------
    bool
        True if JAX reports a CUDA device, False otherwise (including when no
        GPU hardware is found, JAX is not installed, or the check fails).

    Examples
    --------
    >>> if not check_gpu_availability():
    ...     print("Consider enabling GPU acceleration")
    """
    try:
        gpu_name, sm_version = get_gpu_info()
        if gpu_name is None:
            _logger.debug("No GPU hardware detected")
            return False

        # Check if JAX is using GPU
        import jax

        devices = jax.devices()
        using_gpu = any("cuda" in str(d).lower() for d in devices)
        if using_gpu:
            _logger.debug(f"JAX is using GPU: {devices}")
            # Check for plugin issues even when GPU works
            for issue in check_plugin_conflicts():
                _logger.warning(f"Plugin issue: {issue}")
            return True

        # GPU available but not being used. The system CUDA version only feeds
        # the warning text, so avoid running nvcc unless we will print it.
        if warn:
            cuda_version, cuda_major = get_system_cuda_version()
            _print_gpu_warning(gpu_name, sm_version, cuda_version, cuda_major)
        return False
    except ImportError:
        _logger.debug("JAX not installed")
        return False
    except Exception as e:
        _logger.debug(f"GPU check failed: {e}")
        return False
def _print_gpu_warning(
    gpu_name: str,
    sm_version: float | None,
    cuda_version: str | None,
    cuda_major: int | None,
) -> None:
    """Print warning about GPU acceleration availability.

    Writes a banner to stdout with the detected GPU, the system CUDA version,
    any JAX plugin conflicts reported by ``check_plugin_conflicts``, and
    instructions for installing a GPU-enabled JAX.

    Parameters
    ----------
    gpu_name : str
        GPU model name.
    sm_version : float | None
        Compute capability; shown as "unknown" when None.
    cuda_version : str | None
        System CUDA version string; shown as "Not found" when None or empty.
    cuda_major : int | None
        System CUDA major version. When None, generic CUDA-install steps are
        printed; otherwise a ``jax[cuda<major>-local]`` install command is shown.
    """
    # Show "unknown" instead of the Python repr "None" for a missing SM version.
    sm_label = "unknown" if sm_version is None else sm_version

    print("\nGPU ACCELERATION AVAILABLE")
    print("===========================")
    print(f"GPU: {gpu_name} (SM {sm_label})")
    print(f"System CUDA: {cuda_version or 'Not found'}")
    print("JAX backend: CPU-only")

    issues = check_plugin_conflicts()
    if issues:
        print("\nIssues detected:")
        for issue in issues:
            print(f" - {issue}")

    print()
    if cuda_major is None:
        print("To enable GPU acceleration:")
        print(" 1. Install CUDA toolkit (12.x or 13.x)")
        print(" 2. Ensure nvcc is in PATH")
        print(" 3. Run: make install-jax-gpu")
    else:
        pkg = f"jax[cuda{cuda_major}-local]"
        print("Enable 20-100x speedup:")
        print(" make install-jax-gpu")
        print()
        print("Or manually:")
        print(
            " pip uninstall -y jax jaxlib "
            "jax-cuda13-plugin jax-cuda13-pjrt "
            "jax-cuda12-plugin jax-cuda12-pjrt"
        )
        print(f' pip install "{pkg}"')
    print("\nSee README.rst for details.\n")
def get_device_info() -> dict:
    """Get comprehensive device information.

    Every section is best effort: if JAX is missing or its backend fails to
    initialise, the JAX-related keys keep the defaults shown below (any value
    already read, such as ``jax_version``, is kept).

    Returns
    -------
    dict
        Dictionary containing:

        - jax_version: JAX version string
        - jax_backend: Current backend (cpu, gpu)
        - devices: List of device strings
        - gpu_count: Number of GPU devices
        - using_gpu: Boolean
        - gpu_hardware: GPU name
        - gpu_sm_version: SM version (float)
        - system_cuda_version: System CUDA version string
        - system_cuda_major: System CUDA major version (int)
        - recommended_package: Recommended JAX package
        - plugin_issues: List of detected plugin conflict descriptions

    Examples
    --------
    >>> info = get_device_info()
    >>> print(f"JAX backend: {info['jax_backend']}")
    >>> if info['using_gpu']:
    ...     print(f"Using {info['gpu_count']} GPU(s)")
    """
    info: dict = {
        "jax_version": None,
        "jax_backend": None,
        "devices": [],
        "gpu_count": 0,
        "using_gpu": False,
        "gpu_hardware": None,
        "gpu_sm_version": None,
        "system_cuda_version": None,
        "system_cuda_major": None,
        "recommended_package": None,
        "plugin_issues": [],
    }

    # JAX info
    try:
        import jax

        info["jax_version"] = jax.__version__
        info["jax_backend"] = jax.default_backend()
        devices = jax.devices()
        info["devices"] = [str(d) for d in devices]
        info["gpu_count"] = sum(1 for d in devices if "cuda" in str(d).lower())
        info["using_gpu"] = info["gpu_count"] > 0
    except ImportError:
        pass
    except Exception as e:
        # Backend initialisation can fail at runtime (e.g. a broken CUDA
        # plugin); report what we have instead of aborting the whole query.
        _logger.debug(f"JAX device query failed: {e}")

    # GPU hardware info
    gpu_name, sm_version = get_gpu_info()
    info["gpu_hardware"] = gpu_name
    info["gpu_sm_version"] = sm_version

    # System CUDA info
    cuda_version, cuda_major = get_system_cuda_version()
    info["system_cuda_version"] = cuda_version
    info["system_cuda_major"] = cuda_major

    # Recommended package and plugin health
    info["recommended_package"] = get_recommended_package()
    info["plugin_issues"] = check_plugin_conflicts()
    return info