#!/usr/bin/env python3
"""
Pre-installation script for Phi-3.5-MoE Space
Installs required dependencies and selects CPU-safe model revision if needed
"""
import os
import sys
import subprocess
import torch
import re
from pathlib import Path
from huggingface_hub import HfApi
def install_dependencies():
    """Install runtime dependencies, adapting to CPU vs. GPU environments.

    Always installs ``einops`` (required by the model's remote code).
    Installs ``flash-attn`` only when CUDA is available, since it cannot be
    built or used on CPU-only runtimes.

    Raises:
        subprocess.CalledProcessError: if the einops install fails (the
            flash-attn install failure is tolerated and only logged).
    """
    print("🔧 Installing required dependencies...")

    # einops is needed in every environment, so a failure here is fatal.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "einops>=0.7.0"])
    print("✅ Installed einops")

    if torch.cuda.is_available():
        # flash-attn requires a CUDA toolchain to build; the model can still
        # run (more slowly) without it, so tolerate install failures.
        try:
            subprocess.check_call(
                [sys.executable, "-m", "pip", "install",
                 "flash-attn>=2.6.0", "--no-build-isolation"]
            )
            print("✅ Installed flash-attn for GPU runtime")
        except subprocess.CalledProcessError:
            print("⚠️ Failed to install flash-attn, continuing without it")
    else:
        print("ℹ️ CPU runtime detected: skipping flash-attn installation")
def select_cpu_safe_revision():
    """Pin ``HF_REVISION`` to a model commit that works without flash-attn.

    No-op on GPU runtimes or when ``HF_REVISION`` is already set.  Otherwise
    walks the model repo's commit history and picks the first revision whose
    modeling file has no top-level ``flash_attn`` import, recording the chosen
    SHA both in the current process environment and appended to ``.env`` so
    later processes can reuse it.  All errors are logged, never raised.
    """
    # GPU runtimes can use any revision; an explicit HF_REVISION wins.
    if torch.cuda.is_available() or os.getenv("HF_REVISION"):
        return

    MODEL_ID = os.getenv("HF_MODEL_ID", "microsoft/Phi-3.5-MoE-instruct")
    TARGET_FILE = "modeling_phimoe.py"
    ENV_FILE = ".env"

    print(f"🔍 Selecting CPU-safe revision for {MODEL_ID}...")
    try:
        api = HfApi()
        for commit in api.list_repo_commits(MODEL_ID, repo_type="model"):
            sha = commit.commit_id
            try:
                file_path = api.hf_hub_download(
                    MODEL_ID, TARGET_FILE, revision=sha, repo_type="model"
                )
                with open(file_path, "r", encoding="utf-8") as f:
                    code = f.read()
                # A revision is CPU-safe when flash_attn is never imported at
                # module top level (guarded/optional imports are acceptable).
                if not re.search(
                    r'^\s*import\s+flash_attn|^\s*from\s+flash_attn',
                    code,
                    flags=re.M,
                ):
                    # Persist for future processes, and for this one.
                    with open(ENV_FILE, "a", encoding="utf-8") as env_file:
                        env_file.write(f"HF_REVISION={sha}\n")
                    os.environ["HF_REVISION"] = sha
                    print(f"✅ Selected CPU-safe revision: {sha}")
                    return
            except Exception:
                # Revision lacks the file or the download failed: try the next.
                continue
        print("⚠️ No CPU-safe revision found")
    except Exception as e:
        print(f"⚠️ Error selecting CPU-safe revision: {e}")
def create_model_patch():
    """Write ``model_patch.py``, a monkey patch for transformers' import check.

    The generated module replaces
    ``transformers.dynamic_module_utils.check_imports`` with a wrapper that,
    on ImportError, registers ``MockModule`` stand-ins in ``sys.modules`` for
    the missing optional packages and retries, so remote model code can load
    even when optional dependencies (e.g. flash-attn) are absent.
    """
    PATCH_FILE = "model_patch.py"
    # NOTE: this is a plain (non-f) string; the f-strings inside belong to the
    # generated module, not to this function.
    patch_content = """
# Monkey patch for transformers.dynamic_module_utils
import sys
import importlib
from importlib.abc import Loader
from importlib.machinery import ModuleSpec
from transformers.dynamic_module_utils import check_imports


# Create mock modules for missing dependencies
class MockModule:
    def __init__(self, name):
        self.__name__ = name
        self.__spec__ = ModuleSpec(name, None)

    def __getattr__(self, key):
        return MockModule(f"{self.__name__}.{key}")


# Override check_imports to handle missing dependencies
original_check_imports = check_imports


def patched_check_imports(resolved_module_file):
    try:
        return original_check_imports(resolved_module_file)
    except ImportError as e:
        # Extract missing module names from the transformers error message.
        import re
        missing = re.findall(
            r'packages that were not found in your environment: ([^.]+)', str(e)
        )
        if missing:
            missing_modules = [m.strip() for m in missing[0].split(',')]
            print(f"⚠️ Missing dependencies: {', '.join(missing_modules)}")
            print("🔧 Creating mock modules to continue loading...")
            # Register a mock for each missing module, then retry the check.
            for module_name in missing_modules:
                if module_name not in sys.modules:
                    mock_module = MockModule(module_name)
                    sys.modules[module_name] = mock_module
                    print(f"✅ Created mock for {module_name}")
            return original_check_imports(resolved_module_file)
        else:
            raise


# Apply the patch
from transformers import dynamic_module_utils
dynamic_module_utils.check_imports = patched_check_imports
print("✅ Applied transformers patch for handling missing dependencies")
"""
    with open(PATCH_FILE, "w", encoding="utf-8") as f:
        f.write(patch_content)
    print(f"✅ Created model patch file: {PATCH_FILE}")
if __name__ == "__main__":
    # Run the full pre-installation sequence: dependencies first, then the
    # CPU-safe revision pin, then the transformers compatibility patch.
    print("🚀 Running pre-installation script...")
    install_dependencies()
    select_cpu_safe_revision()
    create_model_patch()
    print("✅ Pre-installation complete!")