File size: 2,920 Bytes
920b7a4 0ac10e4 920b7a4 0ac10e4 920b7a4 0ac10e4 920b7a4 0ac10e4 920b7a4 0ac10e4 920b7a4 0ac10e4 920b7a4 0ac10e4 920b7a4 0ac10e4 920b7a4 0ac10e4 920b7a4 0ac10e4 920b7a4 0ac10e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import torch
# Display names for the backends this script knows about; any other backend
# falls back to its upper-cased module name.
# NOTE: PyTorch's official AMD ROCm support is implemented inside the
# `torch.cuda` APIs, so the "cuda" backend covers both vendors.
# SEE: https://docs.pytorch.org/docs/stable/cuda.html
_BRAND_NAMES = {
    "cuda": "NVIDIA CUDA / AMD ROCm",
    "xpu": "Intel XPU",
    "mps": "Apple MPS",
}


def show_device_list(backend: str) -> int:
    """
    Displays a list of all detected devices for a given PyTorch backend.

    Args:
        backend: The name of the device backend module (e.g., "cuda", "xpu").

    Returns:
        The number of devices found if the backend is usable, otherwise 0.
        Once the devices have been counted, that count is returned even if
        some (or all) device names cannot be looked up.
    """
    backend_upper = backend.upper()
    brand_name = _BRAND_NAMES.get(backend, backend_upper)

    # Only the availability probe and the device count are required. Keep the
    # `try` narrow so a later failure (e.g. while fetching a device name)
    # cannot discard a device count that has already been reported.
    try:
        # Get the backend module from PyTorch, e.g., `torch.cuda`.
        # NOTE: The built-in backend modules exist even when the user has no
        # devices, but older PyTorch releases may lack some of them (and an
        # arbitrary name may not be a backend at all).
        backend_module = getattr(torch, backend)
        if not backend_module.is_available():
            print(f"PyTorch: No devices found for {brand_name} backend.")
            return 0
        print(f"PyTorch: {brand_name} is available!")
        device_count = backend_module.device_count()
    except AttributeError:
        print(
            f'Error: The PyTorch backend "{backend}" does not exist, or is missing the necessary APIs (is_available, device_count).'
        )
        return 0
    except Exception as e:
        print(f"Error: {e}")
        return 0

    # Show all available hardware acceleration devices.
    print(f" * Number of {backend_upper} devices found: {device_count}")

    # NOTE: Some backends (e.g. Apple MPS) don't provide `get_device_name()`,
    # so device names are only listed when the backend supports it.
    # SEE: https://docs.pytorch.org/docs/stable/mps.html
    get_device_name = getattr(backend_module, "get_device_name", None)
    if get_device_name is None:
        return device_count

    for i in range(device_count):
        try:
            device_name = get_device_name(i)
        except Exception as e:
            # Best effort: one unreadable name must not hide the device or
            # turn an available backend into a "no devices" result.
            print(f" * Device {i}: <name unavailable> ({e})")
        else:
            print(f' * Device {i}: "{device_name}"')
    return device_count
def check_torch_devices() -> None:
    """
    Scan each supported PyTorch acceleration backend and report the result.

    Probes NVIDIA CUDA / AMD ROCm, Intel XPU and Apple Metal Performance
    Shaders (MPS), in that order, printing the devices each one reports,
    then prints a one-line summary of whether any accelerator was found.
    """
    print("Scanning for PyTorch hardware acceleration devices...\n")

    # Probed in order: NVIDIA CUDA / AMD ROCm, Intel XPU, Apple MPS.
    backends = ("cuda", "xpu", "mps")
    total_devices = sum(show_device_list(name) for name in backends)

    summary = (
        "\nHardware acceleration detected. Your system is ready!"
        if total_devices > 0
        else "\nNo hardware acceleration detected. Running in CPU mode."
    )
    print(summary)
if __name__ == "__main__":
    # Run the scan only when executed as a script, not when imported.
    check_torch_devices()
|