File size: 5,694 Bytes
6510698
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""
CPU-safe model revision selector.

This script finds a model revision that doesn't hard-require flash_attn
for CPU-only environments.
"""

import ast
import logging
import os
import re
import sys
from pathlib import Path
from typing import Optional, List

from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError

# Logging: route INFO-and-above records through a basic root configuration.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration. MODEL_ID can be overridden via the HF_MODEL_ID environment variable.
MODEL_ID = os.environ.get("HF_MODEL_ID", "microsoft/Phi-3.5-MoE-instruct")
TARGET_FILE = "modeling_phimoe.py"  # remote-code file inspected at each revision
ENV_FILE = Path(".env")  # file the selected revision is written to
MAX_COMMITS_TO_CHECK = 50  # upper bound on the number of commits scanned


class RevisionSelector:
    """Select a revision of a Hub model whose remote code can be loaded on CPU.

    A revision is treated as CPU-safe when its ``TARGET_FILE`` can be
    downloaded and contains no *hard* ``flash_attn`` import. A hard import is
    an absolute import of a ``flash_attn*`` module that is not inside the body
    of a ``try`` statement with ``except`` handlers.
    """

    # Line-based fallback, used only when the modeling file cannot be parsed
    # with ``ast``. It matches on the imported *module* name
    # (``import flash_attn...``, ``import a, flash_attn...``,
    # ``from flash_attn... import``). Helper names such as
    # ``from transformers.utils import is_flash_attn_2_available`` therefore
    # do not count as a flash_attn dependency.
    _FLASH_ATTN_IMPORT_RE = re.compile(
        r"^\s*(?:import\s+(?:[\w.]+(?:\s+as\s+\w+)?\s*,\s*)*flash_attn"
        r"|from\s+flash_attn)",
        re.MULTILINE,
    )

    # Matches an existing HF_REVISION assignment in a .env file. It tolerates
    # leading whitespace, an ``export`` prefix and spaces around ``=``.
    _ENV_REVISION_LINE_RE = re.compile(r"^\s*(?:export\s+)?HF_REVISION\s*=")

    def __init__(self, model_id: str = MODEL_ID):
        """Create a selector for ``model_id`` (defaults to ``MODEL_ID``)."""
        self.model_id = model_id
        self.api = HfApi()

    @staticmethod
    def _is_flash_attn_module(name: Optional[str]) -> bool:
        """Return True if dotted module ``name`` is rooted at a ``flash_attn*`` package."""
        return bool(name) and name.split(".")[0].startswith("flash_attn")

    @classmethod
    def _has_hard_flash_attn_import(cls, code: str) -> bool:
        """Return True if ``code`` imports a ``flash_attn*`` module unguarded.

        Imports inside the body of a ``try`` that has ``except`` handlers are
        treated as optional. All other absolute imports count, wherever they
        appear (module level, ``if`` blocks, function bodies).
        NOTE(review): this is deliberately conservative, presumably so that it
        matches the import scan ``transformers`` applies to remote code.
        Confirm this against the transformers version in use.
        """
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError, RecursionError):
            # Unparseable with this interpreter: use the line-based scan instead.
            return cls._FLASH_ATTN_IMPORT_RE.search(code) is not None

        # ast.TryStar only exists on Python 3.11+.
        try_nodes = tuple(
            getattr(ast, name) for name in ("Try", "TryStar") if hasattr(ast, name)
        )

        # Iterative walk. Each node carries a flag: "inside a guarded try body".
        stack = [(tree, False)]
        while stack:
            node, guarded = stack.pop()
            if not guarded:
                if isinstance(node, ast.Import) and any(
                    cls._is_flash_attn_module(alias.name) for alias in node.names
                ):
                    return True
                # level > 0 means a relative (package-local) import, never flash_attn.
                if (
                    isinstance(node, ast.ImportFrom)
                    and node.level == 0
                    and cls._is_flash_attn_module(node.module)
                ):
                    return True
            if isinstance(node, try_nodes):
                # Only the try *body* is protected, and only if something catches errors.
                body_guarded = guarded or bool(node.handlers)
                stack.extend((child, body_guarded) for child in node.body)
                stack.extend(
                    (child, guarded)
                    for child in node.handlers + node.orelse + node.finalbody
                )
            else:
                stack.extend((child, guarded) for child in ast.iter_child_nodes(node))
        return False

    def is_cpu_safe_revision(self, revision: str) -> bool:
        """Check if a revision is safe for CPU use (no hard flash_attn import).

        Downloads ``TARGET_FILE`` at ``revision`` into ``.cache`` and inspects
        its imports. If the file cannot be downloaded or read, a warning is
        logged and the revision is reported as not CPU-safe.
        """
        try:
            file_path = hf_hub_download(
                repo_id=self.model_id,
                filename=TARGET_FILE,
                revision=revision,
                repo_type="model",
                cache_dir=".cache",
            )
            with open(file_path, "r", encoding="utf-8") as f:
                code = f.read()
        except Exception as e:
            # Best effort: a revision we cannot inspect is skipped, not fatal.
            logger.warning(f"Could not check revision {revision}: {e}")
            return False

        if self._has_hard_flash_attn_import(code):
            logger.debug(f"Revision {revision} has hard flash_attn import")
            return False

        logger.debug(f"Revision {revision} appears CPU-safe")
        return True

    def get_recent_commits(self, max_commits: int = MAX_COMMITS_TO_CHECK) -> List[str]:
        """Return up to ``max_commits`` commit SHAs in the order the Hub lists them.

        The order is presumably newest first; verify against the
        ``list_repo_commits`` docs. A non-positive ``max_commits`` yields an
        empty list. Errors are logged and also yield an empty list.
        """
        try:
            commits = list(self.api.list_repo_commits(
                repo_id=self.model_id,
                repo_type="model",
            ))
        except Exception as e:
            logger.error(f"Failed to get commits: {e}")
            return []

        # Clamp so a negative limit cannot turn into "all but the last N".
        commit_shas = [c.commit_id for c in commits[:max(max_commits, 0)]]
        logger.info(f"Found {len(commit_shas)} recent commits to check")
        return commit_shas

    def find_cpu_safe_revision(self) -> Optional[str]:
        """Return the first CPU-safe commit SHA from ``get_recent_commits``.

        Returns None when no commits are available or none of them is CPU-safe.
        """
        logger.info(f"Searching for CPU-safe revision of {self.model_id}")

        commits = self.get_recent_commits()
        if not commits:
            logger.error("No commits found")
            return None

        for i, commit_sha in enumerate(commits, start=1):
            logger.info(f"Checking commit {i}/{len(commits)}: {commit_sha[:8]}...")

            if self.is_cpu_safe_revision(commit_sha):
                logger.info(f"✅ Found CPU-safe revision: {commit_sha}")
                return commit_sha

        logger.error("❌ No CPU-safe revision found in recent commits")
        return None

    def save_revision_to_env(self, revision: str) -> None:
        """Write ``HF_REVISION=<revision>`` to ``ENV_FILE``, replacing any prior value.

        Other lines are kept in their original order. Trailing blank lines are
        dropped so repeated runs do not accumulate them. The file is written
        as UTF-8 and ends with a newline. Any error is logged and re-raised.
        """
        try:
            try:
                env_content = ENV_FILE.read_text(encoding="utf-8")
            except FileNotFoundError:
                env_content = ""

            lines = [
                line
                for line in env_content.splitlines()
                if not self._ENV_REVISION_LINE_RE.match(line)
            ]
            while lines and not lines[-1].strip():
                lines.pop()
            lines.append(f"HF_REVISION={revision}")

            ENV_FILE.write_text("\n".join(lines) + "\n", encoding="utf-8")
            logger.info(f"✅ Saved revision {revision} to {ENV_FILE}")

        except Exception as e:
            logger.error(f"Failed to save revision to .env: {e}")
            raise


def main() -> int:
    """Select a CPU-safe revision of ``MODEL_ID`` and record it in ``ENV_FILE``.

    Nothing is selected when a CUDA device is available, or when
    ``HF_REVISION`` is already set in the process environment.

    Returns:
        Process exit status: 0 on success or when no selection is needed,
        1 when no CPU-safe revision was found or an error occurred.
    """
    # torch is only used here to detect a GPU. Without it there is no usable
    # GPU, so treat the environment as CPU-only instead of crashing.
    try:
        import torch
    except ImportError:
        torch = None
        logger.warning("torch is not installed - assuming a CPU-only environment")

    if torch is not None and torch.cuda.is_available():
        logger.info("GPU detected - no need to select CPU-safe revision")
        return 0

    existing_revision = os.getenv("HF_REVISION")
    if existing_revision:
        logger.info(f"HF_REVISION already set to: {existing_revision}")
        return 0

    logger.info("CPU-only environment detected - selecting CPU-safe revision")

    try:
        selector = RevisionSelector()
        revision = selector.find_cpu_safe_revision()
        if not revision:
            logger.error("❌ Could not find CPU-safe revision")
            logger.error("Consider using a different model or enabling GPU")
            return 1
        selector.save_revision_to_env(revision)
    except Exception as e:
        # Top-level boundary: keep the traceback and report failure via exit status.
        logger.exception(f"❌ Error selecting revision: {e}")
        return 1

    logger.info(f"🎉 Successfully selected CPU-safe revision: {revision}")
    return 0


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())