"""
CPU-safe model revision selector.

This script finds a model revision that doesn't hard-require flash_attn
for CPU-only environments.
"""

import os
import re
import sys
import logging
from pathlib import Path
from typing import Optional, List

from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration
MODEL_ID = os.getenv("HF_MODEL_ID", "microsoft/Phi-3.5-MoE-instruct")
TARGET_FILE = "modeling_phimoe.py"
ENV_FILE = Path(".env")
MAX_COMMITS_TO_CHECK = 50

# Heuristic for a "hard" flash_attn import: an import statement that starts in
# column 0 and targets the flash_attn package itself (or one of its
# submodules). Indented imports are deliberately NOT matched, because they sit
# inside a try/if/def body and are therefore guarded or lazy. The trailing \b
# keeps look-alike names such as `flash_attn_2_cuda` from matching, and import
# lines that merely mention the substring (for example
# `from transformers.utils import is_flash_attn_2_available`) are ignored.
_HARD_FLASH_ATTN_IMPORT = re.compile(
    r"^(?:import|from)\s+flash_attn\b",
    flags=re.MULTILINE,
)


def _has_hard_flash_attn_import(code: str) -> bool:
    """Return True if *code* has an unindented import of the flash_attn package."""
    return _HARD_FLASH_ATTN_IMPORT.search(code) is not None


class RevisionSelector:
    """Selects CPU-safe model revisions."""

    def __init__(self, model_id: str = MODEL_ID):
        """Remember the target repo id and create the Hub API client."""
        self.model_id = model_id
        self.api = HfApi()

    def is_cpu_safe_revision(self, revision: str) -> bool:
        """Check if a revision is safe for CPU use (no hard flash_attn import).

        Downloads ``TARGET_FILE`` at *revision* into the local ``.cache``
        directory and scans it for an unconditional, module-level import of
        ``flash_attn``.

        Args:
            revision: Revision identifier passed to ``hf_hub_download``.

        Returns:
            True when no hard import is found. False when one is found, or
            when the file could not be downloaded or read (a revision that
            cannot be inspected is treated as unsafe).
        """
        try:
            # Download the modeling file for this revision
            file_path = hf_hub_download(
                repo_id=self.model_id,
                filename=TARGET_FILE,
                revision=revision,
                repo_type="model",
                cache_dir=".cache",
            )

            # Read and analyze the file
            with open(file_path, "r", encoding="utf-8") as f:
                code = f.read()

            if _has_hard_flash_attn_import(code):
                logger.debug(f"Revision {revision} has hard flash_attn import")
                return False

            logger.debug(f"Revision {revision} appears CPU-safe")
            return True

        except Exception as e:
            # Best-effort check: any failure means "not usable", not a crash.
            logger.warning(f"Could not check revision {revision}: {e}")
            return False

    def get_recent_commits(self, max_commits: int = MAX_COMMITS_TO_CHECK) -> List[str]:
        """Get list of recent commit SHAs.

        Args:
            max_commits: Upper bound on how many commits are returned, taken
                from the front of the list the Hub API yields.

        Returns:
            The ``commit_id`` of each selected commit, or an empty list if
            the commit listing failed.
        """
        try:
            commits = list(self.api.list_repo_commits(
                repo_id=self.model_id,
                repo_type="model",
            ))

            # Limit to max_commits and extract SHAs
            commit_shas = [c.commit_id for c in commits[:max_commits]]
            logger.info(f"Found {len(commit_shas)} recent commits to check")
            return commit_shas

        except Exception as e:
            logger.error(f"Failed to get commits: {e}")
            return []

    def find_cpu_safe_revision(self) -> Optional[str]:
        """Find the most recent CPU-safe revision.

        Walks the commits from ``get_recent_commits`` in order and returns
        the first one accepted by ``is_cpu_safe_revision``.

        Returns:
            The matching commit SHA, or None when there are no commits or
            none of them passes the check.
        """
        logger.info(f"Searching for CPU-safe revision of {self.model_id}")

        commits = self.get_recent_commits()
        if not commits:
            logger.error("No commits found")
            return None

        for i, commit_sha in enumerate(commits):
            logger.info(f"Checking commit {i+1}/{len(commits)}: {commit_sha[:8]}...")

            if self.is_cpu_safe_revision(commit_sha):
                logger.info(f"✅ Found CPU-safe revision: {commit_sha}")
                return commit_sha

        logger.error("❌ No CPU-safe revision found in recent commits")
        return None

    def save_revision_to_env(self, revision: str) -> None:
        """Save the selected revision to .env file.

        Any existing ``HF_REVISION=`` line is replaced; every other line is
        kept in its original order. The file always ends with exactly one
        newline so that later appends start on a fresh line.

        Args:
            revision: Value written as ``HF_REVISION=<revision>``.

        Raises:
            Exception: Whatever the file read/write raised, re-raised after
                being logged.
        """
        try:
            # Read existing .env content (explicit encoding so the result does
            # not depend on the platform's locale default).
            env_content = ""
            if ENV_FILE.exists():
                env_content = ENV_FILE.read_text(encoding="utf-8")

            # Remove any existing HF_REVISION line. splitlines() avoids the
            # phantom empty element that split('\n') yields for content that
            # ends in a newline.
            lines = [
                line
                for line in env_content.splitlines()
                if not line.startswith("HF_REVISION=")
            ]

            # Trim trailing blank lines so repeated runs do not pile up empty
            # lines ahead of the new entry.
            while lines and not lines[-1].strip():
                lines.pop()

            # Add new revision
            lines.append(f"HF_REVISION={revision}")

            # Write back to file, terminated by a newline
            ENV_FILE.write_text("\n".join(lines) + "\n", encoding="utf-8")
            logger.info(f"✅ Saved revision {revision} to {ENV_FILE}")

        except Exception as e:
            logger.error(f"Failed to save revision to .env: {e}")
            raise


def main():
    """Main function to select and save CPU-safe revision.

    Returns:
        Process exit code: 0 when nothing needs to be done (CUDA is
        available or ``HF_REVISION`` is already set in the environment) or
        when a revision was found and saved; 1 when no revision was found
        or an error occurred while selecting/saving it.
    """
    # Check if we're on CPU and don't already have a revision set.
    # torch is imported here rather than at module top, so importing this
    # module does not itself require torch.
    import torch

    if torch.cuda.is_available():
        logger.info("GPU detected - no need to select CPU-safe revision")
        return 0

    existing_revision = os.getenv("HF_REVISION")
    if existing_revision:
        logger.info(f"HF_REVISION already set to: {existing_revision}")
        return 0

    logger.info("CPU-only environment detected - selecting CPU-safe revision")

    try:
        selector = RevisionSelector()
        revision = selector.find_cpu_safe_revision()

        if revision:
            selector.save_revision_to_env(revision)
            logger.info(f"🎉 Successfully selected CPU-safe revision: {revision}")
            return 0
        else:
            logger.error("❌ Could not find CPU-safe revision")
            logger.error("Consider using a different model or enabling GPU")
            return 1

    except Exception as e:
        logger.error(f"❌ Error selecting revision: {e}")
        return 1


if __name__ == "__main__":
    sys.exit(main())