File size: 10,573 Bytes
bc9c638
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
#!/usr/bin/env python3
"""
GPU Detection Diagnostic Tool for ACE-Step

This script helps diagnose GPU detection issues by checking:
- PyTorch installation and build type (CUDA/ROCm/CPU)
- GPU availability and properties
- Environment variables
- Common configuration issues

Usage:
    python scripts/check_gpu.py
"""

import os
import sys
import subprocess

# Constants
# Width, in characters, of the "=" rules that frame each report section.
HEADER_WIDTH = 80

# PyTorch wheel indexes suggested in the reinstall hints (cu121 / rocm6.0).
_PYTORCH_WHEEL_INDEX = "https://download.pytorch.org/whl"
PYTORCH_CUDA_INSTALL_URL = f"{_PYTORCH_WHEEL_INDEX}/cu121"
PYTORCH_ROCM_INSTALL_URL = f"{_PYTORCH_WHEEL_INDEX}/rocm6.0"


def print_section(title):
    """Write a banner that opens a new report section.

    The banner is a blank line, a rule of ``HEADER_WIDTH`` "=" characters,
    the title indented by two spaces, and a closing rule of the same width.
    """
    rule = "=" * HEADER_WIDTH
    # Leading "" plus sep="\n" yields the blank line before the first rule.
    print("", rule, f"  {title}", rule, sep="\n")


def check_pytorch():
    """Check that PyTorch imports and report which backend it was built for.

    Prints the installed version and the build type (ROCm/HIP, CUDA, or
    CPU-only). For a CPU-only build it also prints reinstall commands for
    NVIDIA and AMD GPUs.

    Returns:
        bool: True for a CUDA or ROCm build; False when PyTorch is missing,
        fails to import, or is a CPU-only build.
    """
    print_section("PyTorch Installation")

    try:
        import torch
    except ImportError:
        print("❌ PyTorch not installed")
        print("\nPlease install PyTorch first:")
        print("  pip install torch")
        return False
    except Exception as e:
        # An installed-but-broken PyTorch can fail with errors other than
        # ImportError (e.g. OSError when a native library fails to load on
        # Windows). Report it instead of crashing the whole diagnostic.
        print(f"❌ PyTorch is installed but failed to import: {type(e).__name__}: {e}")
        return False

    print(f"✓ PyTorch installed: {torch.__version__}")

    # ROCm builds set torch.version.hip; CUDA builds set torch.version.cuda.
    # The attribute may be absent or None when the build lacks that backend.
    is_rocm = getattr(torch.version, 'hip', None) is not None
    is_cuda = getattr(torch.version, 'cuda', None) is not None

    if is_rocm:
        print(f"✓ Build type: ROCm (HIP {torch.version.hip})")
        return True
    if is_cuda:
        print(f"✓ Build type: CUDA {torch.version.cuda}")
        return True

    print("⚠️ Build type: CPU-only")
    print("\n❌ You have installed a CPU-only version of PyTorch!")
    print("\nTo enable GPU support:")
    print("  For NVIDIA GPUs:")
    print(f"    pip install torch --index-url {PYTORCH_CUDA_INSTALL_URL}")
    print("\n  For AMD GPUs with ROCm:")
    print("    Windows: See requirements-rocm.txt")
    print(f"    Linux: pip install torch --index-url {PYTORCH_ROCM_INSTALL_URL}")
    return False


def check_cuda_availability():
    """Check whether PyTorch can see a GPU and list each detected device.

    The same ``torch.cuda`` API is used for CUDA and ROCm builds. For each
    device this prints its name, total memory and compute capability. On
    ROCm builds it also prints the GPU architecture name when PyTorch
    exposes it.

    Returns:
        bool: True if ``torch.cuda.is_available()`` is True, otherwise False
        (including when querying the GPU raises an exception).
    """
    print_section("GPU Availability Check")

    try:
        import torch

        is_available = torch.cuda.is_available()
        print(f"torch.cuda.is_available(): {is_available}")

        if not is_available:
            print("❌ No GPU detected by PyTorch")
            return False

        print("✓ GPU detected!")
        device_count = torch.cuda.device_count()
        print(f"  Number of GPUs: {device_count}")

        is_rocm = getattr(torch.version, 'hip', None) is not None
        for i in range(device_count):
            device_name = torch.cuda.get_device_name(i)
            props = torch.cuda.get_device_properties(i)
            memory_gb = props.total_memory / (1024**3)
            print(f"\n  GPU {i}: {device_name}")
            print(f"    Total memory: {memory_gb:.2f} GB")
            print(f"    Compute capability: {props.major}.{props.minor}")
            if is_rocm:
                # The AMD architecture name (e.g. "gfx1100") is useful when
                # choosing an HSA_OVERRIDE_GFX_VERSION value. getattr keeps
                # this safe on PyTorch versions that don't expose it.
                gcn_arch = getattr(props, 'gcnArchName', None)
                if gcn_arch:
                    print(f"    GPU architecture: {gcn_arch}")

        return True

    except Exception as e:
        print(f"❌ Error checking GPU availability: {e}")
        return False


def check_rocm_setup():
    """Check ROCm-specific setup for AMD GPUs.

    Prints a "Skipping" note unless PyTorch is a ROCm (HIP) build. Otherwise
    it reports the HSA_OVERRIDE_GFX_VERSION and MIOPEN_FIND_MODE environment
    variables and whether ``rocm-smi`` runs. On success it echoes the start
    of the tool's output; on failure it shows the exit code and error output.
    Results are printed; nothing is returned.
    """
    print_section("ROCm Configuration (AMD GPUs)")

    try:
        import torch
    except ImportError:
        print("❌ PyTorch not installed")
        return

    if getattr(torch.version, 'hip', None) is None:
        print("Skipping - not a ROCm build")
        return

    print("Checking ROCm environment variables:")

    # Check HSA_OVERRIDE_GFX_VERSION
    hsa_override = os.environ.get('HSA_OVERRIDE_GFX_VERSION')
    if hsa_override:
        print(f"  ✓ HSA_OVERRIDE_GFX_VERSION = {hsa_override}")
    else:
        print("  ⚠️ HSA_OVERRIDE_GFX_VERSION not set")
        print("\n  This variable is required for many AMD GPUs!")
        print("  Set it according to your GPU:")
        print("    RX 7900 XT/XTX, RX 9070 XT: HSA_OVERRIDE_GFX_VERSION=11.0.0")
        print("    RX 7800 XT, RX 7700 XT: HSA_OVERRIDE_GFX_VERSION=11.0.1")
        print("    RX 7600: HSA_OVERRIDE_GFX_VERSION=11.0.2")
        print("    RX 6000 series: HSA_OVERRIDE_GFX_VERSION=10.3.0")

    # Check MIOPEN_FIND_MODE
    miopen_mode = os.environ.get('MIOPEN_FIND_MODE')
    if miopen_mode:
        print(f"  ✓ MIOPEN_FIND_MODE = {miopen_mode}")
    else:
        print("  ℹ️ MIOPEN_FIND_MODE not set (recommended: FAST)")

    # Try to run rocm-smi. A missing, hanging or failing tool is reported
    # rather than raised, so the rest of the diagnostic still runs.
    print("\nChecking ROCm system management interface:")
    try:
        # errors='replace' keeps undecodable tool output from aborting the check.
        result = subprocess.run(
            ['rocm-smi'], capture_output=True, text=True, errors='replace', timeout=5
        )
    except FileNotFoundError:
        print("  ❌ rocm-smi not found in PATH")
        print("     This suggests ROCm is not properly installed")
    except subprocess.TimeoutExpired as e:
        print(f"  ⚠️ rocm-smi did not finish within {e.timeout} seconds")
    except Exception as e:
        print(f"  ⚠️ Error running rocm-smi: {e}")
    else:
        if result.returncode == 0:
            print("  ✓ rocm-smi found and working")
            print("\n  Output (first 10 lines):")
            for line in result.stdout.split('\n')[:10]:
                if line.strip():
                    print(f"    {line}")
        else:
            print("  ⚠️ rocm-smi found but returned error")
            print(f"     Exit code: {result.returncode}")
            # Show why it failed; the message may be on either stream.
            for line in (result.stderr or result.stdout).strip().splitlines()[:5]:
                print(f"     {line}")


def check_nvidia_setup():
    """Check NVIDIA CUDA setup.

    Prints a "Skipping" note unless PyTorch is a CUDA build. Otherwise it
    reports the CUDA version PyTorch was built with and whether
    ``nvidia-smi`` runs. On success it echoes the start of the tool's
    output; on failure it shows the exit code and error output. Results are
    printed; nothing is returned.
    """
    print_section("NVIDIA CUDA Configuration")

    try:
        import torch
    except ImportError:
        print("❌ PyTorch not installed")
        return

    if getattr(torch.version, 'cuda', None) is None:
        print("Skipping - not a CUDA build")
        return

    print(f"CUDA version in PyTorch: {torch.version.cuda}")

    # Try to run nvidia-smi. A missing, hanging or failing tool is reported
    # rather than raised, so the rest of the diagnostic still runs.
    print("\nChecking NVIDIA System Management Interface:")
    try:
        # errors='replace' keeps undecodable tool output from aborting the check.
        result = subprocess.run(
            ['nvidia-smi'], capture_output=True, text=True, errors='replace', timeout=5
        )
    except FileNotFoundError:
        print("  ❌ nvidia-smi not found in PATH")
        print("     This suggests NVIDIA drivers are not properly installed")
    except subprocess.TimeoutExpired as e:
        print(f"  ⚠️ nvidia-smi did not finish within {e.timeout} seconds")
    except Exception as e:
        print(f"  ⚠️ Error running nvidia-smi: {e}")
    else:
        if result.returncode == 0:
            print("  ✓ nvidia-smi found and working")
            print("\n  Output (first 15 lines):")
            for line in result.stdout.split('\n')[:15]:
                if line.strip():
                    print(f"    {line}")
        else:
            print("  ⚠️ nvidia-smi found but returned error")
            print(f"     Exit code: {result.returncode}")
            # Show why it failed; the message may be on either stream.
            for line in (result.stderr or result.stdout).strip().splitlines()[:5]:
                print(f"     {line}")


def check_ace_step_env():
    """List the ACE-Step related environment variables that are set.

    Variables that are unset or set to an empty string are skipped. If none
    of the watched variables has a value, a single note is printed instead.
    """
    print_section("ACE-Step Environment Variables")

    watched = (
        'MAX_CUDA_VRAM',
        'HSA_OVERRIDE_GFX_VERSION',
        'MIOPEN_FIND_MODE',
        'TORCH_COMPILE_BACKEND',
        'ACESTEP_LM_BACKEND',
    )

    lookups = [(name, os.environ.get(name)) for name in watched]
    present = [(name, value) for name, value in lookups if value]

    if not present:
        print("  No ACE-Step specific environment variables set")
        return

    for name, value in present:
        print(f"  {name} = {value}")


def print_recommendations():
    """Print next steps based on the PyTorch build type and GPU availability.

    Covers four cases: a usable GPU, a ROCm build without a visible GPU, a
    CUDA build without a visible GPU, and a CPU-only build. It also handles
    PyTorch being missing or failing to import, because it may run after
    earlier checks have already failed.
    """
    print_section("Recommendations")

    try:
        import torch
    except ImportError:
        print("❌ PyTorch not installed")
        print("\nPlease install PyTorch first. See README.md for instructions.")
        return
    except Exception as e:
        # A broken install (e.g. a native library that fails to load) raises
        # something other than ImportError; report it instead of crashing.
        print(f"❌ PyTorch failed to import: {type(e).__name__}: {e}")
        print("\nPlease reinstall PyTorch. See README.md for instructions.")
        return

    is_rocm = getattr(torch.version, 'hip', None) is not None
    is_cuda = getattr(torch.version, 'cuda', None) is not None
    is_available = torch.cuda.is_available()

    if is_available:
        print("✓ Your GPU setup appears to be working correctly!")
        print("\nYou can now run ACE-Step with GPU acceleration.")
    elif is_rocm:
        print("❌ ROCm build detected but GPU not available")
        print("\nTroubleshooting steps for AMD GPUs:")
        print("  1. Set HSA_OVERRIDE_GFX_VERSION for your GPU model (see above)")
        print("  2. Verify ROCm installation with: rocm-smi")
        print("  3. Check that your GPU is supported by your ROCm version")
        print("  4. On Windows: Use start_gradio_ui_rocm.bat which sets all required variables")
        print("  5. On Linux: See docs/en/ACE-Step1.5-Rocm-Manual-Linux.md")
        print("\nFor RX 9070 XT specifically:")
        print("  export HSA_OVERRIDE_GFX_VERSION=11.0.0")
        print("  or on Windows: set HSA_OVERRIDE_GFX_VERSION=11.0.0")
    elif is_cuda:
        print("❌ CUDA build detected but GPU not available")
        print("\nTroubleshooting steps for NVIDIA GPUs:")
        print("  1. Install NVIDIA drivers from https://www.nvidia.com/download/index.aspx")
        print("  2. Verify installation with: nvidia-smi")
        print("  3. Ensure CUDA version compatibility between driver and PyTorch")
    else:
        print("❌ CPU-only PyTorch build detected")
        print("\nYou need to reinstall PyTorch with GPU support:")
        print("\nFor NVIDIA GPUs:")
        print("  pip uninstall torch torchvision torchaudio")
        print(f"  pip install torch torchvision torchaudio --index-url {PYTORCH_CUDA_INSTALL_URL}")
        print("\nFor AMD GPUs:")
        print("  Windows: Follow instructions in requirements-rocm.txt")
        print(f"  Linux: pip install torch --index-url {PYTORCH_ROCM_INSTALL_URL}")


def main():
    """Run all diagnostics and print the full report to stdout.

    The GPU-specific checks run only when ``check_pytorch()`` reports a CUDA
    or ROCm build. The environment-variable and recommendation sections
    always run.
    """
    # Users are asked to share this output, which is usually done by
    # redirecting it to a file. On Windows a redirected stdout uses the
    # legacy code page (e.g. cp1252), which cannot encode the status symbols
    # printed here and would abort the run with UnicodeEncodeError.
    # Substitute unencodable characters instead.
    reconfigure = getattr(sys.stdout, "reconfigure", None)
    if reconfigure is not None:
        try:
            reconfigure(errors="replace")
        except (TypeError, ValueError):
            pass  # Non-standard stream: keep its current settings (best effort).

    print("=" * HEADER_WIDTH)
    print("  ACE-Step GPU Detection Diagnostic Tool")
    print("=" * HEADER_WIDTH)
    print("\nThis tool will help diagnose GPU detection issues.")
    print("Please share the output with support when reporting issues.")

    # These checks need an importable, GPU-enabled PyTorch build.
    if check_pytorch():
        check_cuda_availability()
        check_rocm_setup()
        check_nvidia_setup()

    # Environment variables are read from os.environ and don't need PyTorch,
    # so report them even when the PyTorch check failed.
    check_ace_step_env()

    print_recommendations()

    print("\n" + "=" * HEADER_WIDTH)
    print("  Diagnostic Complete")
    print("=" * HEADER_WIDTH)


if __name__ == "__main__":
    main()