Spaces:
Sleeping
Sleeping
File size: 6,587 Bytes
6510698 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 |
"""
Model configuration and environment detection module.
This module handles:
- Environment detection (CPU/GPU)
- Model configuration based on environment
- Dependency validation
- Safe defaults for different environments
"""
import os
import torch
from typing import Dict, Any, Optional
from dataclasses import dataclass
import logging
logger = logging.getLogger(__name__)
@dataclass
class ModelConfig:
"""Configuration for model loading and inference."""
model_id: str
revision: Optional[str]
dtype: torch.dtype
device_map: str
attn_implementation: str
low_cpu_mem_usage: bool
trust_remote_code: bool
@property
def is_gpu_available(self) -> bool:
"""Check if GPU is available."""
return torch.cuda.is_available()
@property
def device_info(self) -> Dict[str, Any]:
"""Get device information."""
info = {
"cuda_available": torch.cuda.is_available(),
"device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
"current_device": torch.cuda.current_device() if torch.cuda.is_available() else None,
}
if torch.cuda.is_available():
info["device_name"] = torch.cuda.get_device_name()
info["memory_allocated"] = torch.cuda.memory_allocated()
info["memory_reserved"] = torch.cuda.memory_reserved()
return info
class EnvironmentDetector:
"""Detects and configures environment-specific settings."""
@staticmethod
def detect_environment() -> Dict[str, Any]:
"""Detect current environment capabilities."""
env_info = {
"cuda_available": torch.cuda.is_available(),
"cuda_version": torch.version.cuda if torch.cuda.is_available() else None,
"torch_version": torch.__version__,
"platform": os.name,
"python_version": os.sys.version,
}
# Check for flash_attn availability
try:
import importlib
flash_attn = importlib.import_module("flash_attn")
env_info["flash_attn_available"] = True
env_info["flash_attn_version"] = getattr(flash_attn, "__version__", "unknown")
except ImportError:
env_info["flash_attn_available"] = False
env_info["flash_attn_version"] = None
# Check for einops availability
try:
import importlib
einops = importlib.import_module("einops")
env_info["einops_available"] = True
env_info["einops_version"] = getattr(einops, "__version__", "unknown")
except ImportError:
env_info["einops_available"] = False
env_info["einops_version"] = None
logger.info(f"Environment detected: {env_info}")
return env_info
@staticmethod
def create_model_config(
model_id: Optional[str] = None,
revision: Optional[str] = None
) -> ModelConfig:
"""Create model configuration based on environment."""
# Default model
if model_id is None:
model_id = os.getenv("HF_MODEL_ID", "microsoft/Phi-3.5-MoE-instruct")
# Get revision from environment if not provided
if revision is None:
revision = os.getenv("HF_REVISION")
# Detect environment
is_gpu = torch.cuda.is_available()
# Configure based on environment
if is_gpu:
# GPU configuration - optimized for performance
config = ModelConfig(
model_id=model_id,
revision=revision,
dtype=torch.bfloat16, # Use bfloat16 for better GPU performance
device_map="auto",
attn_implementation="sdpa", # Use scaled dot-product attention
low_cpu_mem_usage=False,
trust_remote_code=True
)
logger.info("Created GPU-optimized model configuration")
else:
# CPU configuration - optimized for compatibility
config = ModelConfig(
model_id=model_id,
revision=revision,
dtype=torch.float32, # Use float32 for CPU compatibility
device_map="cpu",
attn_implementation="eager", # Use eager attention for CPU
low_cpu_mem_usage=True,
trust_remote_code=True
)
logger.info("Created CPU-optimized model configuration")
return config
class DependencyValidator:
"""Validates required dependencies are available."""
REQUIRED_PACKAGES = [
"transformers",
"accelerate",
"einops",
"huggingface_hub",
"gradio",
"torch"
]
OPTIONAL_PACKAGES = [
"flash_attn" # Only required for GPU with certain model revisions
]
@classmethod
def validate_dependencies(cls) -> Dict[str, bool]:
"""Validate all dependencies."""
results = {}
# Check required packages
for package in cls.REQUIRED_PACKAGES:
try:
import importlib
importlib.import_module(package)
results[package] = True
logger.debug(f"✅ {package} is available")
except ImportError:
results[package] = False
logger.error(f"❌ {package} is missing")
# Check optional packages
for package in cls.OPTIONAL_PACKAGES:
try:
import importlib
importlib.import_module(package)
results[package] = True
logger.debug(f"✅ {package} (optional) is available")
except ImportError:
results[package] = False
logger.debug(f"⚠️ {package} (optional) is missing")
return results
@classmethod
def get_missing_required_packages(cls) -> list:
"""Get list of missing required packages."""
validation = cls.validate_dependencies()
return [pkg for pkg in cls.REQUIRED_PACKAGES if not validation.get(pkg, False)]
@classmethod
def is_environment_ready(cls) -> bool:
"""Check if environment has all required dependencies."""
missing = cls.get_missing_required_packages()
if missing:
logger.error(f"Missing required packages: {missing}")
return False
return True
|