"""
Model configuration and environment detection module.

This module handles:
- Environment detection (CPU/GPU)
- Model configuration based on environment
- Dependency validation
- Safe defaults for different environments
"""

import importlib
import logging
import os
import sys
from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch

logger = logging.getLogger(__name__)


@dataclass
class ModelConfig:
    """Configuration for model loading and inference.

    Attributes:
        model_id: Hugging Face model identifier.
        revision: Optional model revision (branch, tag, or commit).
        dtype: Torch dtype used when loading weights.
        device_map: Placement strategy passed to the loader (e.g. "auto", "cpu").
        attn_implementation: Attention backend (e.g. "sdpa", "eager").
        low_cpu_mem_usage: Whether to load with reduced CPU memory footprint.
        trust_remote_code: Whether custom model code from the hub may run.
    """

    model_id: str
    revision: Optional[str]
    dtype: torch.dtype
    device_map: str
    attn_implementation: str
    low_cpu_mem_usage: bool
    trust_remote_code: bool

    @property
    def is_gpu_available(self) -> bool:
        """Return True when at least one CUDA device is usable."""
        return torch.cuda.is_available()

    @property
    def device_info(self) -> Dict[str, Any]:
        """Return a snapshot of the current device state.

        Always contains ``cuda_available``, ``device_count`` and
        ``current_device``; when CUDA is available it additionally reports
        the device name and the current memory counters.
        """
        # Query availability once per call instead of once per key; it is
        # still re-evaluated on every access so the snapshot stays current.
        cuda = torch.cuda.is_available()
        info: Dict[str, Any] = {
            "cuda_available": cuda,
            "device_count": torch.cuda.device_count() if cuda else 0,
            "current_device": torch.cuda.current_device() if cuda else None,
        }

        if cuda:
            info["device_name"] = torch.cuda.get_device_name()
            info["memory_allocated"] = torch.cuda.memory_allocated()
            info["memory_reserved"] = torch.cuda.memory_reserved()

        return info


class EnvironmentDetector:
    """Detects and configures environment-specific settings."""
    
    @staticmethod
    def detect_environment() -> Dict[str, Any]:
        """Detect current environment capabilities."""
        env_info = {
            "cuda_available": torch.cuda.is_available(),
            "cuda_version": torch.version.cuda if torch.cuda.is_available() else None,
            "torch_version": torch.__version__,
            "platform": os.name,
            "python_version": os.sys.version,
        }
        
        # Check for flash_attn availability
        try:
            import importlib
            flash_attn = importlib.import_module("flash_attn")
            env_info["flash_attn_available"] = True
            env_info["flash_attn_version"] = getattr(flash_attn, "__version__", "unknown")
        except ImportError:
            env_info["flash_attn_available"] = False
            env_info["flash_attn_version"] = None
        
        # Check for einops availability
        try:
            import importlib
            einops = importlib.import_module("einops")
            env_info["einops_available"] = True
            env_info["einops_version"] = getattr(einops, "__version__", "unknown")
        except ImportError:
            env_info["einops_available"] = False
            env_info["einops_version"] = None
        
        logger.info(f"Environment detected: {env_info}")
        return env_info
    
    @staticmethod
    def create_model_config(
        model_id: Optional[str] = None,
        revision: Optional[str] = None
    ) -> ModelConfig:
        """Create model configuration based on environment."""
        
        # Default model
        if model_id is None:
            model_id = os.getenv("HF_MODEL_ID", "microsoft/Phi-3.5-MoE-instruct")
        
        # Get revision from environment if not provided
        if revision is None:
            revision = os.getenv("HF_REVISION")
        
        # Detect environment
        is_gpu = torch.cuda.is_available()
        
        # Configure based on environment
        if is_gpu:
            # GPU configuration - optimized for performance
            config = ModelConfig(
                model_id=model_id,
                revision=revision,
                dtype=torch.bfloat16,  # Use bfloat16 for better GPU performance
                device_map="auto",
                attn_implementation="sdpa",  # Use scaled dot-product attention
                low_cpu_mem_usage=False,
                trust_remote_code=True
            )
            logger.info("Created GPU-optimized model configuration")
        else:
            # CPU configuration - optimized for compatibility
            config = ModelConfig(
                model_id=model_id,
                revision=revision,
                dtype=torch.float32,  # Use float32 for CPU compatibility
                device_map="cpu",
                attn_implementation="eager",  # Use eager attention for CPU
                low_cpu_mem_usage=True,
                trust_remote_code=True
            )
            logger.info("Created CPU-optimized model configuration")
        
        return config


class DependencyValidator:
    """Validates required dependencies are available."""

    # Packages the application cannot run without.
    REQUIRED_PACKAGES = [
        "transformers",
        "accelerate",
        "einops",
        "huggingface_hub",
        "gradio",
        "torch"
    ]

    OPTIONAL_PACKAGES = [
        "flash_attn"  # Only required for GPU with certain model revisions
    ]

    @staticmethod
    def _is_importable(package: str) -> bool:
        """Return True when *package* can be imported, False otherwise."""
        try:
            importlib.import_module(package)
        except ImportError:
            return False
        return True

    @classmethod
    def validate_dependencies(cls) -> Dict[str, bool]:
        """Validate all dependencies.

        Returns:
            Mapping of package name to availability. Missing required
            packages are logged at ERROR level; missing optional packages
            only at DEBUG level.
        """
        results: Dict[str, bool] = {}

        for package in cls.REQUIRED_PACKAGES:
            available = cls._is_importable(package)
            results[package] = available
            if available:
                logger.debug(f"✅ {package} is available")
            else:
                logger.error(f"❌ {package} is missing")

        for package in cls.OPTIONAL_PACKAGES:
            available = cls._is_importable(package)
            results[package] = available
            if available:
                logger.debug(f"✅ {package} (optional) is available")
            else:
                logger.debug(f"⚠️ {package} (optional) is missing")

        return results

    @classmethod
    def get_missing_required_packages(cls) -> "list[str]":
        """Get list of missing required packages, in declaration order."""
        validation = cls.validate_dependencies()
        return [pkg for pkg in cls.REQUIRED_PACKAGES if not validation.get(pkg, False)]

    @classmethod
    def is_environment_ready(cls) -> bool:
        """Check if environment has all required dependencies."""
        missing = cls.get_missing_required_packages()
        if missing:
            logger.error(f"Missing required packages: {missing}")
            return False
        return True