phi35-moe-demo / app /config /model_config.py
ianshank's picture
🚀 Deploy robust modular solution with comprehensive testing and CPU/GPU support
6510698 verified
"""
Model configuration and environment detection module.
This module handles:
- Environment detection (CPU/GPU)
- Model configuration based on environment
- Dependency validation
- Safe defaults for different environments
"""
import os
import torch
from typing import Dict, Any, Optional
from dataclasses import dataclass
import logging
logger = logging.getLogger(__name__)
@dataclass
class ModelConfig:
    """Settings bundle describing how a model is loaded and run.

    All fields are required; the class adds no defaults of its own. Two
    read-only properties expose the CUDA state reported by ``torch`` at the
    moment they are accessed (they do not depend on the stored fields).
    """

    # Identifier of the model to load.
    model_id: str
    # Optional revision of the model; None means no revision was specified.
    revision: Optional[str]
    # Torch dtype selected for the model.
    dtype: torch.dtype
    # Device placement string (this file uses "auto" and "cpu").
    device_map: str
    # Attention backend name (this file uses "sdpa" and "eager").
    attn_implementation: str
    # Flag for the low-CPU-memory loading option.
    low_cpu_mem_usage: bool
    # Flag for the trust-remote-code loading option.
    trust_remote_code: bool

    @property
    def is_gpu_available(self) -> bool:
        """Report whether ``torch`` currently sees a CUDA device."""
        return torch.cuda.is_available()

    @property
    def device_info(self) -> Dict[str, Any]:
        """Summarise the CUDA device state as a dictionary.

        Returns:
            Always contains ``cuda_available``, ``device_count`` and
            ``current_device``. When CUDA is available the dictionary also
            carries ``device_name``, ``memory_allocated`` and
            ``memory_reserved``; without CUDA those three keys are absent.
        """
        has_cuda = torch.cuda.is_available()

        # Without CUDA there is nothing to query: report the fixed placeholders.
        if not has_cuda:
            return {
                "cuda_available": has_cuda,
                "device_count": 0,
                "current_device": None,
            }

        return {
            "cuda_available": has_cuda,
            "device_count": torch.cuda.device_count(),
            "current_device": torch.cuda.current_device(),
            "device_name": torch.cuda.get_device_name(),
            "memory_allocated": torch.cuda.memory_allocated(),
            "memory_reserved": torch.cuda.memory_reserved(),
        }
class EnvironmentDetector:
    """Detects and configures environment-specific settings."""

    @staticmethod
    def _probe_package(package: str) -> Dict[str, Any]:
        """Try to import *package* and describe the outcome.

        Returns:
            A two-entry dictionary keyed ``<package>_available`` and
            ``<package>_version``. The version is the module's
            ``__version__`` attribute, ``"unknown"`` when the module has no
            such attribute, or ``None`` when the import raised ImportError.
        """
        import importlib

        try:
            module = importlib.import_module(package)
        except ImportError:
            return {f"{package}_available": False, f"{package}_version": None}
        return {
            f"{package}_available": True,
            f"{package}_version": getattr(module, "__version__", "unknown"),
        }

    @staticmethod
    def detect_environment() -> Dict[str, Any]:
        """Detect current environment capabilities.

        Returns:
            A dictionary with the CUDA availability and version, the torch
            version, ``os.name``, the interpreter version string, and the
            availability/version of the ``flash_attn`` and ``einops``
            packages. The result is also logged at INFO level.
        """
        # Use the documented module directly rather than the incidental
        # ``os.sys`` attribute.
        import sys

        cuda_available = torch.cuda.is_available()
        env_info: Dict[str, Any] = {
            "cuda_available": cuda_available,
            "cuda_version": torch.version.cuda if cuda_available else None,
            "torch_version": torch.__version__,
            "platform": os.name,
            "python_version": sys.version,
        }

        # Optional accelerator/helper packages are probed identically.
        for package in ("flash_attn", "einops"):
            env_info.update(EnvironmentDetector._probe_package(package))

        logger.info("Environment detected: %s", env_info)
        return env_info

    @staticmethod
    def create_model_config(
        model_id: Optional[str] = None,
        revision: Optional[str] = None
    ) -> "ModelConfig":
        """Create model configuration based on environment.

        Args:
            model_id: Model identifier. When None, the ``HF_MODEL_ID``
                environment variable is used, falling back to
                ``microsoft/Phi-3.5-MoE-instruct``.
            revision: Model revision. When None, the ``HF_REVISION``
                environment variable is used, falling back to None.

        Empty or whitespace-only environment values are treated as unset, so
        a blank variable never produces an empty model id or revision.
        Explicitly passed arguments are used as given.

        Returns:
            A ModelConfig using bfloat16 / ``device_map="auto"`` / ``sdpa``
            when CUDA is available, otherwise float32 / ``device_map="cpu"``
            / ``eager`` with ``low_cpu_mem_usage`` enabled.
        """
        if model_id is None:
            model_id = (
                (os.getenv("HF_MODEL_ID") or "").strip()
                or "microsoft/Phi-3.5-MoE-instruct"
            )
        if revision is None:
            revision = (os.getenv("HF_REVISION") or "").strip() or None

        # NOTE(review): both branches set trust_remote_code=True; presumably
        # the loader then runs code shipped with the model repository, so
        # pinning a revision (argument or HF_REVISION) is advisable - confirm
        # against the code that consumes this config.
        if torch.cuda.is_available():
            # GPU configuration - optimized for performance
            config = ModelConfig(
                model_id=model_id,
                revision=revision,
                dtype=torch.bfloat16,  # Use bfloat16 for better GPU performance
                device_map="auto",
                attn_implementation="sdpa",  # Use scaled dot-product attention
                low_cpu_mem_usage=False,
                trust_remote_code=True,
            )
            logger.info("Created GPU-optimized model configuration")
        else:
            # CPU configuration - optimized for compatibility
            config = ModelConfig(
                model_id=model_id,
                revision=revision,
                dtype=torch.float32,  # Use float32 for CPU compatibility
                device_map="cpu",
                attn_implementation="eager",  # Use eager attention for CPU
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
            logger.info("Created CPU-optimized model configuration")
        return config
class DependencyValidator:
    """Validates required dependencies are available."""

    # Packages whose absence makes the environment unusable.
    REQUIRED_PACKAGES = [
        "transformers",
        "accelerate",
        "einops",
        "huggingface_hub",
        "gradio",
        "torch"
    ]
    # Packages that are probed and reported but never block readiness.
    OPTIONAL_PACKAGES = [
        "flash_attn"  # Only required for GPU with certain model revisions
    ]

    @staticmethod
    def _can_import(package: str) -> bool:
        """Return True when *package* imports without raising ImportError."""
        import importlib

        try:
            importlib.import_module(package)
        except ImportError:
            return False
        return True

    @classmethod
    def validate_dependencies(cls) -> Dict[str, bool]:
        """Validate all dependencies.

        Every name in REQUIRED_PACKAGES and OPTIONAL_PACKAGES is actually
        imported. A missing required package is logged at ERROR level; all
        other outcomes are logged at DEBUG level.

        Returns:
            Mapping of package name to whether it could be imported.
        """
        results: Dict[str, bool] = {}

        for package in cls.REQUIRED_PACKAGES:
            available = cls._can_import(package)
            results[package] = available
            if available:
                logger.debug("✅ %s is available", package)
            else:
                logger.error("❌ %s is missing", package)

        for package in cls.OPTIONAL_PACKAGES:
            available = cls._can_import(package)
            results[package] = available
            if available:
                logger.debug("✅ %s (optional) is available", package)
            else:
                logger.debug("⚠️ %s (optional) is missing", package)

        return results

    @classmethod
    def get_missing_required_packages(cls) -> list:
        """Get list of missing required packages.

        Returns:
            Names from REQUIRED_PACKAGES, in declaration order, that failed
            to import during a fresh validate_dependencies() run.
        """
        validation = cls.validate_dependencies()
        return [pkg for pkg in cls.REQUIRED_PACKAGES if not validation.get(pkg, False)]

    @classmethod
    def is_environment_ready(cls) -> bool:
        """Check if environment has all required dependencies.

        Returns:
            True when no required package is missing; otherwise logs the
            missing names at ERROR level and returns False.
        """
        missing = cls.get_missing_required_packages()
        if not missing:
            return True
        logger.error("Missing required packages: %s", missing)
        return False