#!/usr/bin/env python3
"""
Pre-installation script for Phi-3.5-MoE Space
Installs required dependencies and selects CPU-safe model revision if needed
"""

import os
import re
import subprocess
import sys

import torch
from huggingface_hub import HfApi

def install_dependencies():
    """Install required dependencies based on environment."""
    print("πŸ”§ Installing required dependencies...")
    
    # Always install einops
    subprocess.check_call([sys.executable, "-m", "pip", "install", "einops>=0.7.0"])
    print("βœ… Installed einops")
    
    # Install flash-attn only if CUDA is available; it has no CPU support
    if torch.cuda.is_available():
        try:
            # --no-build-isolation lets the build see the already-installed torch
            subprocess.check_call([sys.executable, "-m", "pip", "install", "flash-attn>=2.6.0", "--no-build-isolation"])
            print("βœ… Installed flash-attn for GPU runtime")
        except subprocess.CalledProcessError:
            print("⚠️ Failed to install flash-attn, continuing without it")
    else:
        print("ℹ️ CPU runtime detected: skipping flash-attn installation")

def select_cpu_safe_revision():
    """Select a CPU-safe model revision by walking the repo's commit history."""
    # GPU runtimes can use the latest revision, and an explicit HF_REVISION
    # override always takes precedence.
    if torch.cuda.is_available() or os.getenv("HF_REVISION"):
        return
    
    MODEL_ID = os.getenv("HF_MODEL_ID", "microsoft/Phi-3.5-MoE-instruct")
    TARGET_FILE = "modeling_phimoe.py"
    ENV_FILE = ".env"
    
    print(f"πŸ” Selecting CPU-safe revision for {MODEL_ID}...")
    
    try:
        api = HfApi()
        # list_repo_commits returns commits newest-first, so the first match is
        # the most recent revision that imports cleanly without flash-attn.
        for commit in api.list_repo_commits(MODEL_ID, repo_type="model"):
            sha = commit.commit_id
            try:
                file_path = api.hf_hub_download(MODEL_ID, TARGET_FILE, revision=sha, repo_type="model")
                with open(file_path, "r", encoding="utf-8") as f:
                    code = f.read()
                
                # Check if this version doesn't have flash_attn as a top-level import
                if not re.search(r'^\s*import\s+flash_attn|^\s*from\s+flash_attn', code, flags=re.M):
                    # Write to .env file
                    with open(ENV_FILE, "a", encoding="utf-8") as env_file:
                        env_file.write(f"HF_REVISION={sha}\n")
                    
                    # Also set it in the current environment
                    os.environ["HF_REVISION"] = sha
                    
                    print(f"βœ… Selected CPU-safe revision: {sha}")
                    return
            except Exception:
                # The file may be missing or undownloadable at this revision;
                # move on to the next commit
                continue
        
        print("⚠️ No CPU-safe revision found")
    except Exception as e:
        print(f"⚠️ Error selecting CPU-safe revision: {e}")

def create_model_patch():
    """Create a patch file to fix the model loading code."""
    PATCH_FILE = "model_patch.py"
    
    patch_content = """
# Monkey patch for transformers.dynamic_module_utils
import sys
import re
from importlib.machinery import ModuleSpec
from transformers.dynamic_module_utils import check_imports

# Create mock modules for missing dependencies
class MockModule:
    def __init__(self, name):
        self.__name__ = name
        self.__spec__ = ModuleSpec(name, None)
    
    def __getattr__(self, key):
        # Recursively mock attribute lookups, e.g. flash_attn.flash_attn_interface
        return MockModule(f"{self.__name__}.{key}")

# Override check_imports to handle missing dependencies
original_check_imports = check_imports
def patched_check_imports(resolved_module_file):
    try:
        return original_check_imports(resolved_module_file)
    except ImportError as e:
        # Extract the missing module names from the transformers error message
        missing = re.findall(r'packages that were not found in your environment: ([^.]+)', str(e))
        if missing:
            missing_modules = [m.strip() for m in missing[0].split(',')]
            print(f"⚠️ Missing dependencies: {', '.join(missing_modules)}")
            print("πŸ”§ Creating mock modules to continue loading...")
            
            # Create mock modules
            for module_name in missing_modules:
                if module_name not in sys.modules:
                    mock_module = MockModule(module_name)
                    sys.modules[module_name] = mock_module
                    print(f"βœ… Created mock for {module_name}")
            
            # Try again
            return original_check_imports(resolved_module_file)
        else:
            raise

# Apply the patch
from transformers import dynamic_module_utils
dynamic_module_utils.check_imports = patched_check_imports
print("βœ… Applied transformers patch for handling missing dependencies")
"""
    
    with open(PATCH_FILE, "w", encoding="utf-8") as f:
        f.write(patch_content)
    
    print(f"βœ… Created model patch file: {PATCH_FILE}")

if __name__ == "__main__":
    print("πŸš€ Running pre-installation script...")
    install_dependencies()
    select_cpu_safe_revision()
    create_model_patch()
    print("βœ… Pre-installation complete!")