#!/usr/bin/env python3
"""
Pre-installation script for the Phi-3.5-MoE Hugging Face Space.

Installs required dependencies and selects a CPU-safe model revision if needed.
"""
import os
import re
import subprocess
import sys
from pathlib import Path

import torch
from huggingface_hub import HfApi
def install_dependencies():
    """Install runtime dependencies appropriate for the current environment.

    Always installs einops (required by the Phi-3.5-MoE remote modeling code);
    installs flash-attn only when a CUDA GPU is present, since flash-attn has
    no CPU build and its compilation fails on CPU-only runtimes.
    """
    print("🔧 Installing required dependencies...")
    # einops is mandatory: let a failure here propagate and abort pre-install.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "einops>=0.7.0"])
    print("✅ Installed einops")
    if torch.cuda.is_available():
        try:
            # --no-build-isolation lets flash-attn's build see the already
            # installed torch instead of a fresh isolated environment.
            subprocess.check_call(
                [sys.executable, "-m", "pip", "install", "flash-attn>=2.6.0", "--no-build-isolation"]
            )
            print("✅ Installed flash-attn for GPU runtime")
        except subprocess.CalledProcessError:
            # Best-effort: the model still runs (more slowly) without flash-attn.
            print("⚠️ Failed to install flash-attn, continuing without it")
    else:
        print("ℹ️ CPU runtime detected: skipping flash-attn installation")
def select_cpu_safe_revision():
    """Pin HF_REVISION to the newest model revision that imports cleanly on CPU.

    Walks the model repo's commit history (newest first) and picks the first
    revision whose modeling file has no top-level ``flash_attn`` import,
    recording the chosen SHA both in ``.env`` (for later processes) and in
    ``os.environ["HF_REVISION"]`` (for this one).

    No-op when CUDA is available or HF_REVISION is already set. Never raises:
    revision selection is best-effort and must not fail the pre-install step.
    """
    if torch.cuda.is_available() or os.getenv("HF_REVISION"):
        return
    model_id = os.getenv("HF_MODEL_ID", "microsoft/Phi-3.5-MoE-instruct")
    target_file = "modeling_phimoe.py"
    env_file = ".env"
    # A module-level flash_attn import marks a revision as GPU-only.
    # Compiled once here instead of re-parsed on every commit in the loop.
    flash_attn_import = re.compile(r'^\s*import\s+flash_attn|^\s*from\s+flash_attn', flags=re.M)
    print(f"🔍 Selecting CPU-safe revision for {model_id}...")
    try:
        api = HfApi()
        for commit in api.list_repo_commits(model_id, repo_type="model"):
            sha = commit.commit_id
            try:
                file_path = api.hf_hub_download(model_id, target_file, revision=sha, repo_type="model")
                with open(file_path, "r", encoding="utf-8") as f:
                    code = f.read()
                if not flash_attn_import.search(code):
                    with open(env_file, "a", encoding="utf-8") as env_fh:
                        env_fh.write(f"HF_REVISION={sha}\n")
                    os.environ["HF_REVISION"] = sha
                    print(f"✅ Selected CPU-safe revision: {sha}")
                    return
            except Exception:
                # Best-effort: skip revisions whose modeling file is missing
                # or cannot be downloaded.
                continue
        print("⚠️ No CPU-safe revision found")
    except Exception as e:
        # Swallow deliberately: pre-install must continue even if the Hub
        # is unreachable or the commit listing fails.
        print(f"⚠️ Error selecting CPU-safe revision: {e}")
def create_model_patch():
    """Write ``model_patch.py``, a monkey patch for transformers' import check.

    The generated module replaces
    ``transformers.dynamic_module_utils.check_imports`` so that packages a
    remote-code model declares but which are absent (e.g. flash_attn on a
    CPU runtime) are stubbed with mock modules instead of aborting loading.
    """
    patch_file = "model_patch.py"
    # NOTE: the regex inside the patch targets the exact wording of the
    # ImportError raised by transformers' check_imports.
    patch_content = '''\
# Monkey patch for transformers.dynamic_module_utils
import sys
from importlib.machinery import ModuleSpec

from transformers.dynamic_module_utils import check_imports


class MockModule:
    """Placeholder module that fabricates attributes on demand."""

    def __init__(self, name):
        self.__name__ = name
        self.__spec__ = ModuleSpec(name, None)

    def __getattr__(self, key):
        return MockModule(f"{self.__name__}.{key}")


original_check_imports = check_imports


def patched_check_imports(resolved_module_file):
    """check_imports that mocks missing packages instead of raising."""
    try:
        return original_check_imports(resolved_module_file)
    except ImportError as e:
        import re
        missing = re.findall(r'packages that were not found in your environment: ([^.]+)', str(e))
        if missing:
            missing_modules = [m.strip() for m in missing[0].split(',')]
            print(f"⚠️ Missing dependencies: {', '.join(missing_modules)}")
            print("🔧 Creating mock modules to continue loading...")
            for module_name in missing_modules:
                if module_name not in sys.modules:
                    sys.modules[module_name] = MockModule(module_name)
                    print(f"✅ Created mock for {module_name}")
            # Retry now that every reported package resolves to a mock.
            return original_check_imports(resolved_module_file)
        else:
            raise


from transformers import dynamic_module_utils
dynamic_module_utils.check_imports = patched_check_imports
print("✅ Applied transformers patch for handling missing dependencies")
'''
    with open(patch_file, "w", encoding="utf-8") as f:
        f.write(patch_content)
    print(f"✅ Created model patch file: {patch_file}")
if __name__ == "__main__":
    # Order matters: dependencies must be installed before the revision
    # probe, and the patch file is only useful once both have run.
    print("🚀 Running pre-installation script...")
    install_dependencies()
    select_cpu_safe_revision()
    create_model_patch()
    print("✅ Pre-installation complete!")