Spaces:
Running
on
Zero
Running
on
Zero
File size: 10,573 Bytes
bc9c638 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 |
#!/usr/bin/env python3
"""
GPU Detection Diagnostic Tool for ACE-Step
This script helps diagnose GPU detection issues by checking:
- PyTorch installation and build type (CUDA/ROCm/CPU)
- GPU availability and properties
- Environment variables
- Common configuration issues
Usage:
python scripts/check_gpu.py
"""
import os
import sys
import subprocess
# Constants
HEADER_WIDTH = 80
PYTORCH_CUDA_INSTALL_URL = "https://download.pytorch.org/whl/cu121"
PYTORCH_ROCM_INSTALL_URL = "https://download.pytorch.org/whl/rocm6.0"
def print_section(title):
    """Emit a banner-style header separating one diagnostic section from the next."""
    separator = '=' * HEADER_WIDTH
    print(f"\n{separator}")
    print(f" {title}")
    print(separator)
def check_pytorch():
    """Report the installed PyTorch version and build flavor.

    Returns:
        bool: True when a GPU-capable build (CUDA or ROCm) is installed,
        False for a CPU-only build or when PyTorch is missing entirely.
    """
    print_section("PyTorch Installation")
    try:
        import torch
    except ImportError:
        print("β PyTorch not installed")
        print("\nPlease install PyTorch first:")
        print(" pip install torch")
        return False

    print(f"β PyTorch installed: {torch.__version__}")

    # A ROCm wheel exposes torch.version.hip; a CUDA wheel exposes
    # torch.version.cuda. Both attributes exist on most builds but are
    # None when that backend was not compiled in.
    rocm_build = getattr(torch.version, 'hip', None) is not None
    cuda_build = getattr(torch.version, 'cuda', None) is not None

    if rocm_build:
        print(f"β Build type: ROCm (HIP {torch.version.hip})")
        return True
    if cuda_build:
        print(f"β Build type: CUDA {torch.version.cuda}")
        return True

    # Neither backend present: CPU-only wheel. Tell the user how to fix it.
    print("β οΈ Build type: CPU-only")
    print("\nβ You have installed a CPU-only version of PyTorch!")
    print("\nTo enable GPU support:")
    print(" For NVIDIA GPUs:")
    print(f" pip install torch --index-url {PYTORCH_CUDA_INSTALL_URL}")
    print("\n For AMD GPUs with ROCm:")
    print(" Windows: See requirements-rocm.txt")
    print(f" Linux: pip install torch --index-url {PYTORCH_ROCM_INSTALL_URL}")
    return False
def check_cuda_availability():
    """Check whether PyTorch can see a GPU and print per-device details.

    Returns:
        bool: True if torch.cuda.is_available() reports at least one device,
        False otherwise (including when the check itself raises).
    """
    print_section("GPU Availability Check")
    try:
        import torch
        is_available = torch.cuda.is_available()
        print(f"torch.cuda.is_available(): {is_available}")
        if is_available:
            # Fixed: was an f-string with no placeholders (ruff F541).
            print("β GPU detected!")
            device_count = torch.cuda.device_count()
            print(f" Number of GPUs: {device_count}")
            for i in range(device_count):
                device_name = torch.cuda.get_device_name(i)
                props = torch.cuda.get_device_properties(i)
                # total_memory is reported in bytes; convert to GiB for display.
                memory_gb = props.total_memory / (1024**3)
                print(f"\n GPU {i}: {device_name}")
                print(f" Total memory: {memory_gb:.2f} GB")
                print(f" Compute capability: {props.major}.{props.minor}")
            return True
        else:
            print("β No GPU detected by PyTorch")
            return False
    except Exception as e:
        # Broad catch is deliberate in a diagnostic tool: any failure mode
        # (driver error, missing shared libs) should be reported, not crash.
        print(f"β Error checking GPU availability: {e}")
        return False
def check_rocm_setup():
    """Inspect ROCm-specific environment variables and tooling for AMD GPUs.

    Prints findings only; returns None. Skips entirely on non-ROCm builds.
    """
    print_section("ROCm Configuration (AMD GPUs)")
    try:
        import torch
    except ImportError:
        print("β PyTorch not installed")
        return

    # Only meaningful for ROCm wheels (torch.version.hip is set).
    if getattr(torch.version, 'hip', None) is None:
        print("Skipping - not a ROCm build")
        return

    print("Checking ROCm environment variables:")

    gfx_override = os.environ.get('HSA_OVERRIDE_GFX_VERSION')
    if gfx_override:
        print(f" β HSA_OVERRIDE_GFX_VERSION = {gfx_override}")
    else:
        print(" β οΈ HSA_OVERRIDE_GFX_VERSION not set")
        print("\n This variable is required for many AMD GPUs!")
        print(" Set it according to your GPU:")
        print(" RX 7900 XT/XTX, RX 9070 XT: HSA_OVERRIDE_GFX_VERSION=11.0.0")
        print(" RX 7800 XT, RX 7700 XT: HSA_OVERRIDE_GFX_VERSION=11.0.1")
        print(" RX 7600: HSA_OVERRIDE_GFX_VERSION=11.0.2")
        print(" RX 6000 series: HSA_OVERRIDE_GFX_VERSION=10.3.0")

    find_mode = os.environ.get('MIOPEN_FIND_MODE')
    if find_mode:
        print(f" β MIOPEN_FIND_MODE = {find_mode}")
    else:
        print(" βΉοΈ MIOPEN_FIND_MODE not set (recommended: FAST)")

    # Probe rocm-smi to confirm the ROCm userspace stack is installed.
    print("\nChecking ROCm system management interface:")
    try:
        smi = subprocess.run(['rocm-smi'], capture_output=True, text=True, timeout=5)
    except FileNotFoundError:
        print(" β rocm-smi not found in PATH")
        print(" This suggests ROCm is not properly installed")
        return
    except Exception as e:
        print(f" β οΈ Error running rocm-smi: {e}")
        return

    if smi.returncode != 0:
        print(" β οΈ rocm-smi found but returned error")
        return

    print(" β rocm-smi found and working")
    print("\n Output (first 10 lines):")
    for line in smi.stdout.split('\n')[:10]:
        if line.strip():
            print(f" {line}")
def check_nvidia_setup():
    """Inspect the CUDA toolchain and NVIDIA driver tooling.

    Prints findings only; returns None. Skips entirely on non-CUDA builds.
    """
    print_section("NVIDIA CUDA Configuration")
    try:
        import torch
    except ImportError:
        print("β PyTorch not installed")
        return

    # Only meaningful for CUDA wheels (torch.version.cuda is set).
    if getattr(torch.version, 'cuda', None) is None:
        print("Skipping - not a CUDA build")
        return

    print(f"CUDA version in PyTorch: {torch.version.cuda}")

    # Probe nvidia-smi to confirm the driver stack is installed.
    print("\nChecking NVIDIA System Management Interface:")
    try:
        smi = subprocess.run(['nvidia-smi'], capture_output=True, text=True, timeout=5)
    except FileNotFoundError:
        print(" β nvidia-smi not found in PATH")
        print(" This suggests NVIDIA drivers are not properly installed")
        return
    except Exception as e:
        print(f" β οΈ Error running nvidia-smi: {e}")
        return

    if smi.returncode != 0:
        print(" β οΈ nvidia-smi found but returned error")
        return

    print(" β nvidia-smi found and working")
    print("\n Output (first 15 lines):")
    for line in smi.stdout.split('\n')[:15]:
        if line.strip():
            print(f" {line}")
def check_ace_step_env():
    """List ACE-Step-related environment variables that are currently set."""
    print_section("ACE-Step Environment Variables")
    relevant_vars = [
        'MAX_CUDA_VRAM',
        'HSA_OVERRIDE_GFX_VERSION',
        'MIOPEN_FIND_MODE',
        'TORCH_COMPILE_BACKEND',
        'ACESTEP_LM_BACKEND',
    ]
    # Keep only variables with a truthy value, preserving the order above.
    configured = [
        (name, os.environ.get(name))
        for name in relevant_vars
        if os.environ.get(name)
    ]
    for name, value in configured:
        print(f" {name} = {value}")
    if not configured:
        print(" No ACE-Step specific environment variables set")
def print_recommendations():
    """Print recommendations based on detected issues.

    Re-derives the build flavor and GPU availability (rather than taking
    them as parameters) so this function can be called standalone, then
    prints one of four advice blocks: working setup, ROCm build without a
    visible GPU, CUDA build without a visible GPU, or CPU-only build.
    Prints only; returns None.
    """
    print_section("Recommendations")
    try:
        import torch
        # Same build-type probes used elsewhere in this script: a ROCm wheel
        # sets torch.version.hip, a CUDA wheel sets torch.version.cuda.
        is_rocm = hasattr(torch.version, 'hip') and torch.version.hip is not None
        is_cuda = hasattr(torch.version, 'cuda') and torch.version.cuda is not None
        is_available = torch.cuda.is_available()
        if is_available:
            print("β Your GPU setup appears to be working correctly!")
            print("\nYou can now run ACE-Step with GPU acceleration.")
        elif is_rocm:
            # GPU-capable ROCm build, but no device visible: most often a
            # missing HSA_OVERRIDE_GFX_VERSION or broken ROCm install.
            print("β ROCm build detected but GPU not available")
            print("\nTroubleshooting steps for AMD GPUs:")
            print(" 1. Set HSA_OVERRIDE_GFX_VERSION for your GPU model (see above)")
            print(" 2. Verify ROCm installation with: rocm-smi")
            print(" 3. Check that your GPU is supported by your ROCm version")
            print(" 4. On Windows: Use start_gradio_ui_rocm.bat which sets all required variables")
            print(" 5. On Linux: See docs/en/ACE-Step1.5-Rocm-Manual-Linux.md")
            print("\nFor RX 9070 XT specifically:")
            print(" export HSA_OVERRIDE_GFX_VERSION=11.0.0")
            print(" or on Windows: set HSA_OVERRIDE_GFX_VERSION=11.0.0")
        elif is_cuda:
            # GPU-capable CUDA build, but no device visible: driver issue.
            print("β CUDA build detected but GPU not available")
            print("\nTroubleshooting steps for NVIDIA GPUs:")
            print(" 1. Install NVIDIA drivers from https://www.nvidia.com/download/index.aspx")
            print(" 2. Verify installation with: nvidia-smi")
            print(" 3. Ensure CUDA version compatibility between driver and PyTorch")
        else:
            # CPU-only wheel: reinstalling PyTorch is the only remedy.
            print("β CPU-only PyTorch build detected")
            print("\nYou need to reinstall PyTorch with GPU support:")
            print("\nFor NVIDIA GPUs:")
            print(" pip uninstall torch torchvision torchaudio")
            print(f" pip install torch torchvision torchaudio --index-url {PYTORCH_CUDA_INSTALL_URL}")
            print("\nFor AMD GPUs:")
            print(" Windows: Follow instructions in requirements-rocm.txt")
            print(f" Linux: pip install torch --index-url {PYTORCH_ROCM_INSTALL_URL}")
    except ImportError:
        print("β PyTorch not installed")
        print("\nPlease install PyTorch first. See README.md for instructions.")
def main():
    """Main diagnostic routine.

    Runs every check in sequence and finishes with a recommendations block.
    GPU-specific checks are skipped when PyTorch itself is not importable.
    """
    print("=" * HEADER_WIDTH)
    print(" ACE-Step GPU Detection Diagnostic Tool")
    print("=" * HEADER_WIDTH)
    print("\nThis tool will help diagnose GPU detection issues.")
    print("Please share the output with support when reporting issues.")
    # Run all checks. Fixed: the availability result was previously stored
    # in an unused local variable (gpu_ok).
    if check_pytorch():
        check_cuda_availability()
        check_rocm_setup()
        check_nvidia_setup()
    check_ace_step_env()
    print_recommendations()
    print("\n" + "=" * HEADER_WIDTH)
    print(" Diagnostic Complete")
    print("=" * HEADER_WIDTH)
# Run the diagnostics only when executed as a script (not on import).
if __name__ == "__main__":
    main()
|