# phi35-moe-demo / scripts / select_revision.py
# Author: ianshank
# Commit: 🚀 Deploy robust modular solution with comprehensive testing and CPU/GPU support
# Revision: 6510698 (verified)
"""
CPU-safe model revision selector.
This script finds a model revision that doesn't hard-require flash_attn
for CPU-only environments.
"""
import os
import re
import sys
import logging
from pathlib import Path
from typing import Optional, List
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError
# Module logger; the root logger is configured to emit INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# --- Configuration ---------------------------------------------------------
# Hugging Face repository to inspect (overridable via HF_MODEL_ID).
MODEL_ID = os.environ.get("HF_MODEL_ID", "microsoft/Phi-3.5-MoE-instruct")
# Remote-code file whose imports decide whether a revision is CPU-safe.
TARGET_FILE = "modeling_phimoe.py"
# dotenv file that receives the selected HF_REVISION.
ENV_FILE = Path(".env")
# Maximum number of commits examined during the search.
MAX_COMMITS_TO_CHECK = 50
class RevisionSelector:
    """Selects CPU-safe model revisions.

    A revision is considered CPU-safe when its ``TARGET_FILE`` contains no
    import statement for a ``flash_attn*`` package.
    """

    # Matches an import of a `flash_attn*` package (or a submodule of one) at
    # the start of a line, at any indentation, e.g.:
    #   import flash_attn            import os, flash_attn.bert_padding
    #   from flash_attn import f     from flash_attn.bert_padding import g
    # The package name must directly follow `import`/`from`, so names that
    # merely contain "flash_attn" (e.g. transformers' `is_flash_attn_2_available`)
    # are not reported as a flash_attn dependency.
    _FLASH_ATTN_IMPORT_RE = re.compile(
        r"^[ \t]*(?:"
        r"from[ \t]+flash_attn[\w.]*[ \t]+import\b"
        r"|import[ \t]+(?:[\w.]+(?:[ \t]+as[ \t]+\w+)?[ \t]*,[ \t]*)*flash_attn"
        r")",
        re.MULTILINE,
    )

    # Matches an existing HF_REVISION assignment in a dotenv file, tolerating
    # leading whitespace, an `export` prefix and spaces around `=`.
    _REVISION_LINE_RE = re.compile(r"^\s*(?:export\s+)?HF_REVISION\s*=")

    def __init__(self, model_id: str = MODEL_ID):
        """Create a selector for the Hub repository *model_id*."""
        self.model_id = model_id
        self.api = HfApi()

    def is_cpu_safe_revision(self, revision: str) -> bool:
        """Return True if ``TARGET_FILE`` at *revision* never imports ``flash_attn``.

        The file is fetched with ``hf_hub_download`` into the local ``.cache``
        directory and scanned line by line. Imports are detected at any
        indentation, so guarded imports (inside ``if``/``try`` blocks or
        functions) also disqualify a revision; this errs on the side of
        rejecting it.

        Any failure while downloading or reading the file is logged as a
        warning and reported as "not CPU-safe".
        """
        try:
            file_path = hf_hub_download(
                repo_id=self.model_id,
                filename=TARGET_FILE,
                revision=revision,
                repo_type="model",
                cache_dir=".cache",
            )
            with open(file_path, "r", encoding="utf-8") as f:
                code = f.read()
        except Exception as e:
            # Best effort: a missing file, network error, etc. simply
            # disqualifies this revision instead of aborting the search.
            logger.warning(f"Could not check revision {revision}: {e}")
            return False

        if self._FLASH_ATTN_IMPORT_RE.search(code):
            logger.debug(f"Revision {revision} has hard flash_attn import")
            return False
        logger.debug(f"Revision {revision} appears CPU-safe")
        return True

    def get_recent_commits(self, max_commits: int = MAX_COMMITS_TO_CHECK) -> List[str]:
        """Return up to *max_commits* commit SHAs of the model repository.

        SHAs keep the order in which ``HfApi.list_repo_commits`` returns them
        (presumably newest first -- confirm against the huggingface_hub docs).
        A negative *max_commits* is treated as 0. Returns an empty list if
        listing the commits fails.
        """
        try:
            commits = list(self.api.list_repo_commits(
                repo_id=self.model_id,
                repo_type="model",
            ))
            # Clamp so a negative limit cannot turn into "all but the last N".
            commit_shas = [c.commit_id for c in commits[:max(max_commits, 0)]]
            logger.info(f"Found {len(commit_shas)} recent commits to check")
            return commit_shas
        except Exception as e:
            logger.error(f"Failed to get commits: {e}")
            return []

    def find_cpu_safe_revision(self) -> Optional[str]:
        """Return the first CPU-safe SHA from ``get_recent_commits()``, or None.

        Commits are checked in list order and the search stops at the first
        CPU-safe one.
        """
        logger.info(f"Searching for CPU-safe revision of {self.model_id}")
        commits = self.get_recent_commits()
        if not commits:
            logger.error("No commits found")
            return None

        total = len(commits)
        for position, commit_sha in enumerate(commits, start=1):
            logger.info(f"Checking commit {position}/{total}: {commit_sha[:8]}...")
            if self.is_cpu_safe_revision(commit_sha):
                logger.info(f"✅ Found CPU-safe revision: {commit_sha}")
                return commit_sha

        logger.error("❌ No CPU-safe revision found in recent commits")
        return None

    def save_revision_to_env(self, revision: str) -> None:
        """Write ``HF_REVISION=<revision>`` to ``ENV_FILE``, replacing any prior value.

        All other lines are kept in their original order. The file is read
        and written as UTF-8 and always ends with a newline.

        Raises:
            Whatever exception reading or writing the file produced; it is
            logged before being re-raised.
        """
        try:
            lines = (
                ENV_FILE.read_text(encoding="utf-8").splitlines()
                if ENV_FILE.exists()
                else []
            )
            # Drop every existing assignment so the key appears exactly once.
            lines = [line for line in lines if not self._REVISION_LINE_RE.match(line)]
            lines.append(f"HF_REVISION={revision}")
            # splitlines() dropped the terminators; restore a final newline.
            ENV_FILE.write_text("\n".join(lines) + "\n", encoding="utf-8")
            logger.info(f"✅ Saved revision {revision} to {ENV_FILE}")
        except Exception as e:
            logger.error(f"Failed to save revision to .env: {e}")
            raise
def main():
    """Select a CPU-safe revision and record it in the .env file.

    Nothing is done when torch reports an available CUDA device, or when
    ``HF_REVISION`` is already set in the process environment.

    Returns:
        Process exit code: 0 on success or when no selection is needed,
        1 when no CPU-safe revision was found or selection raised an error.
    """
    # torch is imported lazily. If it is not installed there is no CUDA
    # device it could use, so take the CPU path instead of crashing.
    try:
        import torch
    except ImportError:
        logger.warning("torch is not installed - assuming a CPU-only environment")
        gpu_available = False
    else:
        gpu_available = torch.cuda.is_available()

    if gpu_available:
        logger.info("GPU detected - no need to select CPU-safe revision")
        return 0

    existing_revision = os.getenv("HF_REVISION")
    if existing_revision:
        logger.info(f"HF_REVISION already set to: {existing_revision}")
        return 0

    logger.info("CPU-only environment detected - selecting CPU-safe revision")
    try:
        selector = RevisionSelector()
        revision = selector.find_cpu_safe_revision()
        if not revision:
            logger.error("❌ Could not find CPU-safe revision")
            logger.error("Consider using a different model or enabling GPU")
            return 1
        selector.save_revision_to_env(revision)
        logger.info(f"🎉 Successfully selected CPU-safe revision: {revision}")
        return 0
    except Exception as e:
        # logger.exception also records the traceback for diagnosis.
        logger.exception(f"❌ Error selecting revision: {e}")
        return 1
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_code = main()
    sys.exit(exit_code)