Spaces:
Sleeping
Sleeping
| """ | |
| Model configuration and environment detection module. | |
| This module handles: | |
| - Environment detection (CPU/GPU) | |
| - Model configuration based on environment | |
| - Dependency validation | |
| - Safe defaults for different environments | |
| """ | |
import importlib
import logging
import os
import sys
from dataclasses import dataclass
from typing import Dict, Any, Optional

import torch
| logger = logging.getLogger(__name__) | |
@dataclass
class ModelConfig:
    """Configuration for model loading and inference.

    The class is a dataclass so that it can be constructed with keyword
    arguments, which is how ``EnvironmentDetector.create_model_config``
    builds it.
    """

    # Model identifier (create_model_config defaults it from HF_MODEL_ID).
    model_id: str
    # Optional model revision; None when no revision was requested.
    revision: Optional[str]
    # Tensor dtype to use (bfloat16 on GPU, float32 on CPU in this module).
    dtype: torch.dtype
    # Device placement string ("auto" on GPU, "cpu" otherwise in this module).
    device_map: str
    # Attention backend name ("sdpa" on GPU, "eager" on CPU in this module).
    attn_implementation: str
    low_cpu_mem_usage: bool
    trust_remote_code: bool

    # NOTE(review): the two accessors below are kept as plain methods, as in
    # the visible source; confirm callers do not expect properties.
    def is_gpu_available(self) -> bool:
        """Return True when torch reports that CUDA is available."""
        return torch.cuda.is_available()

    def device_info(self) -> Dict[str, Any]:
        """Return a snapshot of CUDA device information.

        Returns:
            A dict that always contains ``cuda_available``, ``device_count``
            (0 without CUDA) and ``current_device`` (None without CUDA).
            When CUDA is available it additionally contains ``device_name``,
            ``memory_allocated`` and ``memory_reserved``.
        """
        # Query availability once so every entry reflects the same answer.
        has_cuda = torch.cuda.is_available()
        info: Dict[str, Any] = {
            "cuda_available": has_cuda,
            "device_count": torch.cuda.device_count() if has_cuda else 0,
            "current_device": torch.cuda.current_device() if has_cuda else None,
        }
        if has_cuda:
            info["device_name"] = torch.cuda.get_device_name()
            info["memory_allocated"] = torch.cuda.memory_allocated()
            info["memory_reserved"] = torch.cuda.memory_reserved()
        return info
class EnvironmentDetector:
    """Detects and configures environment-specific settings.

    Both public methods are static: they take no instance or class state
    and can be called on the class or on an instance.
    """

    @staticmethod
    def _probe_package(name: str) -> Dict[str, Any]:
        """Return ``{name}_available`` / ``{name}_version`` entries for a package.

        The package is actually imported; only ``ImportError`` is treated as
        "not installed". An importable package without ``__version__`` is
        reported with version ``"unknown"``.
        """
        try:
            module = importlib.import_module(name)
        except ImportError:
            return {f"{name}_available": False, f"{name}_version": None}
        return {
            f"{name}_available": True,
            f"{name}_version": getattr(module, "__version__", "unknown"),
        }

    @staticmethod
    def detect_environment() -> Dict[str, Any]:
        """Detect current environment capabilities.

        Returns:
            A dict with CUDA availability and version, the torch version,
            ``os.name``, the Python version string, and availability/version
            entries for the ``flash_attn`` and ``einops`` packages.
        """
        has_cuda = torch.cuda.is_available()
        env_info: Dict[str, Any] = {
            "cuda_available": has_cuda,
            "cuda_version": torch.version.cuda if has_cuda else None,
            "torch_version": torch.__version__,
            "platform": os.name,
            # sys.version directly, rather than the undocumented os.sys alias.
            "python_version": sys.version,
        }

        # Optional acceleration / tensor-manipulation packages.
        for package in ("flash_attn", "einops"):
            env_info.update(EnvironmentDetector._probe_package(package))

        logger.info(f"Environment detected: {env_info}")
        return env_info

    @staticmethod
    def create_model_config(
        model_id: Optional[str] = None,
        revision: Optional[str] = None
    ) -> "ModelConfig":
        """Create model configuration based on environment.

        Args:
            model_id: Model identifier. When None, the ``HF_MODEL_ID``
                environment variable is used, falling back to
                ``"microsoft/Phi-3.5-MoE-instruct"``.
            revision: Model revision. When None, the ``HF_REVISION``
                environment variable is used (and may itself be unset).

        Returns:
            A ``ModelConfig`` using bfloat16 / ``device_map="auto"`` / SDPA
            attention when CUDA is available, and float32 / ``"cpu"`` /
            eager attention otherwise.
        """
        if model_id is None:
            model_id = os.getenv("HF_MODEL_ID", "microsoft/Phi-3.5-MoE-instruct")
        if revision is None:
            revision = os.getenv("HF_REVISION")

        # NOTE(review): trust_remote_code=True is set on both paths; confirm
        # this is intended, especially when no revision is pinned.
        if torch.cuda.is_available():
            # GPU configuration - optimized for performance.
            config = ModelConfig(
                model_id=model_id,
                revision=revision,
                dtype=torch.bfloat16,
                device_map="auto",
                attn_implementation="sdpa",
                low_cpu_mem_usage=False,
                trust_remote_code=True,
            )
            logger.info("Created GPU-optimized model configuration")
        else:
            # CPU configuration - optimized for compatibility.
            config = ModelConfig(
                model_id=model_id,
                revision=revision,
                dtype=torch.float32,
                device_map="cpu",
                attn_implementation="eager",
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
            logger.info("Created CPU-optimized model configuration")
        return config
class DependencyValidator:
    """Validates required dependencies are available.

    All checks are classmethods that read ``REQUIRED_PACKAGES`` and
    ``OPTIONAL_PACKAGES`` from the class they are invoked on.
    """

    REQUIRED_PACKAGES = [
        "transformers",
        "accelerate",
        "einops",
        "huggingface_hub",
        "gradio",
        "torch"
    ]
    OPTIONAL_PACKAGES = [
        "flash_attn"  # Only required for GPU with certain model revisions
    ]

    @staticmethod
    def _is_importable(package: str) -> bool:
        """Return True when *package* imports without raising ImportError."""
        try:
            importlib.import_module(package)
        except ImportError:
            return False
        return True

    @classmethod
    def validate_dependencies(cls) -> Dict[str, bool]:
        """Validate all dependencies.

        Returns:
            A dict mapping every required and optional package name to
            whether it could be imported. Missing required packages are
            logged at ERROR level; missing optional ones only at DEBUG.
        """
        results: Dict[str, bool] = {}

        for package in cls.REQUIRED_PACKAGES:
            available = cls._is_importable(package)
            results[package] = available
            if available:
                logger.debug(f"✅ {package} is available")
            else:
                logger.error(f"❌ {package} is missing")

        for package in cls.OPTIONAL_PACKAGES:
            available = cls._is_importable(package)
            results[package] = available
            if available:
                logger.debug(f"✅ {package} (optional) is available")
            else:
                logger.debug(f"⚠️ {package} (optional) is missing")

        return results

    @classmethod
    def get_missing_required_packages(cls) -> list:
        """Get list of missing required packages, in REQUIRED_PACKAGES order."""
        validation = cls.validate_dependencies()
        return [pkg for pkg in cls.REQUIRED_PACKAGES if not validation.get(pkg, False)]

    @classmethod
    def is_environment_ready(cls) -> bool:
        """Check if environment has all required dependencies.

        Returns:
            True when no required package is missing; otherwise logs the
            missing names at ERROR level and returns False.
        """
        missing = cls.get_missing_required_packages()
        if missing:
            logger.error(f"Missing required packages: {missing}")
            return False
        return True