Spaces:
Sleeping
Sleeping
File size: 5,694 Bytes
6510698 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
"""
CPU-safe model revision selector.
This script finds a model revision that doesn't hard-require flash_attn
for CPU-only environments.
"""
import os
import re
import sys
import logging
from pathlib import Path
from typing import Optional, List
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError
# Logging: configure the root logger at INFO and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration
# Hub repository to inspect; the HF_MODEL_ID environment variable overrides
# the built-in default.
MODEL_ID = os.environ.get("HF_MODEL_ID", "microsoft/Phi-3.5-MoE-instruct")
# Modeling file fetched from each candidate revision and scanned for imports.
TARGET_FILE = "modeling_phimoe.py"
# Dotenv file (relative to the working directory) that receives HF_REVISION.
ENV_FILE = Path(".env")
# Default cap on how many commits are examined per search.
MAX_COMMITS_TO_CHECK = 50
class RevisionSelector:
    """Select a model revision whose modeling file does not import flash_attn.

    The selector lists commits of ``model_id`` on the Hugging Face Hub,
    downloads ``TARGET_FILE`` at each commit, scans it for ``flash_attn``
    import statements, and can persist the chosen commit to ``ENV_FILE``
    as an ``HF_REVISION=...`` entry.
    """

    # Matches an import statement whose imported *module* is ``flash_attn``
    # (or one of its submodules):
    #   import flash_attn / import x, flash_attn / import flash_attn.foo
    #   from flash_attn import ... / from flash_attn.foo import ...
    # Lines that merely mention the substring, e.g.
    # ``from transformers.utils import is_flash_attn_2_available`` or a
    # trailing ``# flash_attn`` comment, are deliberately NOT matched because
    # they do not require the flash_attn package.
    # Indented imports are matched as well, as in the original check.
    # NOTE(review): presumably because the remote-code import scan in
    # transformers also considers indented imports - confirm.
    _FLASH_ATTN_IMPORT = re.compile(
        r"^[ \t]*(?:"
        r"from[ \t]+flash_attn(?:\.\w+)*[ \t]+import\b"
        r"|"
        r"import[ \t]+(?:[^#\n]*,[ \t]*)?flash_attn\b"
        r")",
        flags=re.MULTILINE,
    )

    def __init__(self, model_id: str = MODEL_ID):
        """Create a selector for ``model_id`` backed by a fresh ``HfApi`` client."""
        self.model_id = model_id
        self.api = HfApi()

    @classmethod
    def _imports_flash_attn(cls, code: str) -> bool:
        """Return True if ``code`` contains an import of the flash_attn package."""
        return cls._FLASH_ATTN_IMPORT.search(code) is not None

    def is_cpu_safe_revision(self, revision: str) -> bool:
        """Check if a revision is safe for CPU use (no flash_attn import).

        Args:
            revision: Commit SHA (or other revision identifier) to inspect.

        Returns:
            True when ``TARGET_FILE`` at ``revision`` contains no import of
            ``flash_attn``. False when it does, or when the file could not be
            downloaded or read (best-effort: the failure is logged as a
            warning and the revision is treated as unsafe).
        """
        try:
            # Download the modeling file for this revision
            file_path = hf_hub_download(
                repo_id=self.model_id,
                filename=TARGET_FILE,
                revision=revision,
                repo_type="model",
                cache_dir=".cache",
            )
            with open(file_path, "r", encoding="utf-8") as f:
                code = f.read()
        except Exception as e:
            # Deliberately broad: any failure to obtain the file means the
            # revision cannot be vouched for, so report it as not safe.
            logger.warning(f"Could not check revision {revision}: {e}")
            return False

        if self._imports_flash_attn(code):
            logger.debug(f"Revision {revision} has hard flash_attn import")
            return False

        logger.debug(f"Revision {revision} appears CPU-safe")
        return True

    def get_recent_commits(self, max_commits: int = MAX_COMMITS_TO_CHECK) -> List[str]:
        """Get list of recent commit SHAs.

        Args:
            max_commits: Maximum number of commits to return.

        Returns:
            Up to ``max_commits`` commit ids, in the order the Hub API
            returns them. An empty list if the listing fails (the error is
            logged).
        """
        try:
            commits = list(self.api.list_repo_commits(
                repo_id=self.model_id,
                repo_type="model",
            ))
        except Exception as e:
            logger.error(f"Failed to get commits: {e}")
            return []

        # Limit to max_commits and extract SHAs
        commit_shas = [c.commit_id for c in commits[:max_commits]]
        logger.info(f"Found {len(commit_shas)} recent commits to check")
        return commit_shas

    def find_cpu_safe_revision(self) -> Optional[str]:
        """Find the first listed revision that passes ``is_cpu_safe_revision``.

        Commits are checked in the order returned by ``get_recent_commits``
        (NOTE(review): the Hub documents this as newest first - confirm).

        Returns:
            The commit SHA of the first CPU-safe revision, or None when no
            commits could be listed or none of them is CPU-safe.
        """
        logger.info(f"Searching for CPU-safe revision of {self.model_id}")
        commits = self.get_recent_commits()
        if not commits:
            logger.error("No commits found")
            return None

        total = len(commits)
        for position, commit_sha in enumerate(commits, start=1):
            logger.info(f"Checking commit {position}/{total}: {commit_sha[:8]}...")
            if self.is_cpu_safe_revision(commit_sha):
                logger.info(f"✅ Found CPU-safe revision: {commit_sha}")
                return commit_sha

        logger.error("❌ No CPU-safe revision found in recent commits")
        return None

    def save_revision_to_env(self, revision: str) -> None:
        """Save the selected revision to the .env file.

        Any existing ``HF_REVISION=`` line is dropped, all other lines are
        kept in order, and ``HF_REVISION=<revision>`` is appended. The file
        is always written with a single trailing newline.

        Args:
            revision: Revision identifier to store.

        Raises:
            Exception: Any error raised while reading or writing ``ENV_FILE``
                is logged and re-raised unchanged.
        """
        try:
            # Read existing .env content
            env_content = ENV_FILE.read_text() if ENV_FILE.exists() else ""

            lines = env_content.split("\n")
            # A missing/empty file, or one that ends with a newline, yields a
            # trailing "" element; drop it so no stray blank line is written
            # before the appended entry.
            if lines and lines[-1] == "":
                lines.pop()

            # Remove any existing HF_REVISION line, then add the new one
            lines = [line for line in lines if not line.startswith("HF_REVISION=")]
            lines.append(f"HF_REVISION={revision}")

            # Write back to file, newline-terminated
            ENV_FILE.write_text("\n".join(lines) + "\n")
            logger.info(f"✅ Saved revision {revision} to {ENV_FILE}")
        except Exception as e:
            logger.error(f"Failed to save revision to .env: {e}")
            raise
def main() -> int:
    """Select and save a CPU-safe revision when running without a GPU.

    Does nothing (and reports success) when CUDA is available or when the
    ``HF_REVISION`` environment variable is already set. Otherwise searches
    for a CPU-safe revision and writes it to the .env file.

    Returns:
        0 on success or when no selection is needed; 1 when no CPU-safe
        revision was found or an error occurred during selection.
    """
    # Imported here rather than at module top, so importing this module
    # does not itself require torch.
    import torch

    # Check if we're on CPU and don't already have a revision set
    if torch.cuda.is_available():
        logger.info("GPU detected - no need to select CPU-safe revision")
        return 0

    existing_revision = os.getenv("HF_REVISION")
    if existing_revision:
        logger.info(f"HF_REVISION already set to: {existing_revision}")
        return 0

    logger.info("CPU-only environment detected - selecting CPU-safe revision")
    try:
        selector = RevisionSelector()
        revision = selector.find_cpu_safe_revision()
        if not revision:
            logger.error("❌ Could not find CPU-safe revision")
            logger.error("Consider using a different model or enabling GPU")
            return 1

        selector.save_revision_to_env(revision)
        logger.info(f"🎉 Successfully selected CPU-safe revision: {revision}")
        return 0
    except Exception as e:
        # Top-level boundary: log with traceback and convert to an exit code.
        logger.exception(f"❌ Error selecting revision: {e}")
        return 1
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|