| """ | |
| Model loading module with robust error handling and environment adaptation. | |
| """ | |
| import logging | |
| import torch | |
| from typing import Optional, Tuple, Any | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
| from .config.model_config import ModelConfig, EnvironmentDetector, DependencyValidator | |
| logger = logging.getLogger(__name__) | |
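
# Interface assumed of the imported config objects, inferred from how they are
# used in this module (a reading aid, not a guarantee of their definitions):
#   - ModelConfig exposes model_id, revision, device_map, dtype,
#     attn_implementation, trust_remote_code, low_cpu_mem_usage, and device_info.
#   - EnvironmentDetector provides detect_environment() and
#     create_model_config(model_id, revision).
#   - DependencyValidator provides is_environment_ready().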


class ModelLoader:
    """Handles model loading with environment-specific optimizations."""

    def __init__(self):
        self.config: Optional[ModelConfig] = None
        self.model: Optional[Any] = None
        self.tokenizer: Optional[Any] = None
        self.pipeline: Optional[Any] = None
        self._is_loaded = False

    def validate_environment(self) -> bool:
        """Validate that the environment is ready for model loading."""
        logger.info("🔍 Validating environment...")

        # Check dependencies
        if not DependencyValidator.is_environment_ready():
            logger.error("❌ Environment validation failed - missing dependencies")
            return False

        # Log environment info
        env_info = EnvironmentDetector.detect_environment()
        logger.info(f"📊 Environment info: {env_info}")
        return True

    def create_config(
        self,
        model_id: Optional[str] = None,
        revision: Optional[str] = None
    ) -> ModelConfig:
        """Create model configuration based on environment."""
        logger.info("⚙️ Creating model configuration...")
        self.config = EnvironmentDetector.create_model_config(model_id, revision)

        logger.info("📋 Model config created:")
        logger.info(f"   Model ID: {self.config.model_id}")
        logger.info(f"   Revision: {self.config.revision or 'latest'}")
        logger.info(f"   Device: {self.config.device_map}")
        logger.info(f"   Dtype: {self.config.dtype}")
        logger.info(f"   Attention: {self.config.attn_implementation}")
        return self.config

    def load_tokenizer(self) -> bool:
        """Load the tokenizer."""
        if not self.config:
            logger.error("❌ No configuration available")
            return False

        try:
            logger.info("📝 Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config.model_id,
                trust_remote_code=self.config.trust_remote_code,
                revision=self.config.revision
            )
            logger.info("✅ Tokenizer loaded successfully")
            return True
        except Exception as e:
            logger.error(f"❌ Failed to load tokenizer: {e}")
            return False

    def load_model(self) -> bool:
        """Load the model with environment-specific configuration."""
        if not self.config:
            logger.error("❌ No configuration available")
            return False

        try:
            logger.info("🤖 Loading model...")
            logger.info(f"   This may take several minutes for {self.config.model_id}")

            # Load model with configuration
            self.model = AutoModelForCausalLM.from_pretrained(
                self.config.model_id,
                trust_remote_code=self.config.trust_remote_code,
                revision=self.config.revision,
                attn_implementation=self.config.attn_implementation,
                dtype=self.config.dtype,  # use dtype instead of the deprecated torch_dtype
                device_map=self.config.device_map,
                low_cpu_mem_usage=self.config.low_cpu_mem_usage
            ).eval()
            logger.info("✅ Model loaded successfully")

            # Log model info
            if hasattr(self.model, 'config'):
                logger.info("📊 Model info:")
                logger.info(f"   Architecture: {getattr(self.model.config, 'architectures', 'unknown')}")
                logger.info(f"   Parameters: ~{self.model.num_parameters() / 1e9:.1f}B")
            return True
        except Exception as e:
            logger.error(f"❌ Failed to load model: {e}")
            return False

    def create_pipeline(self) -> bool:
        """Create inference pipeline."""
        if not self.model or not self.tokenizer:
            logger.error("❌ Model or tokenizer not loaded")
            return False

        try:
            logger.info("🔧 Creating inference pipeline...")
            self.pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                dtype=self.config.dtype,  # use dtype instead of the deprecated torch_dtype
                device_map=self.config.device_map,
                trust_remote_code=self.config.trust_remote_code
            )
            logger.info("✅ Pipeline created successfully")
            return True
        except Exception as e:
            logger.error(f"❌ Failed to create pipeline: {e}")
            return False

    def load_complete_model(
        self,
        model_id: Optional[str] = None,
        revision: Optional[str] = None
    ) -> bool:
        """Load the complete stack (tokenizer + model + pipeline)."""
        logger.info("🚀 Starting complete model loading process...")

        try:
            # Validate environment
            if not self.validate_environment():
                return False

            # Create configuration
            self.create_config(model_id, revision)

            # Load components in order
            if not self.load_tokenizer():
                return False
            if not self.load_model():
                return False
            if not self.create_pipeline():
                return False

            # Run smoke test; a failure here is logged but does not abort loading
            if not self.smoke_test():
                logger.warning("⚠️ Smoke test failed, but model appears loaded")

            self._is_loaded = True
            logger.info("🎉 Model loading completed successfully!")
            return True
        except Exception as e:
            logger.error(f"❌ Complete model loading failed: {e}")
            return False

    def smoke_test(self) -> bool:
        """Run a quick smoke test to verify the model works."""
        if not self.pipeline:
            return False

        try:
            logger.info("🧪 Running smoke test...")

            # Simple test generation: greedy decoding, a few new tokens
            test_input = "Hello"
            result = self.pipeline(
                test_input,
                max_new_tokens=4,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id
            )

            if result and len(result) > 0:
                logger.info("✅ Smoke test passed")
                return True
            else:
                logger.warning("⚠️ Smoke test returned empty result")
                return False
        except Exception as e:
            logger.warning(f"⚠️ Smoke test failed: {e}")
            return False

    def is_loaded(self) -> bool:
        """Check if model is fully loaded and ready."""
        return self._is_loaded and self.pipeline is not None

    def get_model_info(self) -> dict:
        """Get information about the loaded model."""
        # is_loaded is a method, not a property, so it must be called here
        if not self.is_loaded():
            return {"status": "not_loaded"}

        info = {
            "status": "loaded",
            "model_id": self.config.model_id,
            "revision": self.config.revision,
            "device": self.config.device_map,
            "dtype": str(self.config.dtype),
            "attention": self.config.attn_implementation,
            "device_info": self.config.device_info
        }

        if hasattr(self.model, 'config'):
            info["architecture"] = getattr(self.model.config, 'architectures', 'unknown')
            info["parameters"] = f"~{self.model.num_parameters() / 1e9:.1f}B"
        return info
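

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal end-to-end example of how ModelLoader is meant to be driven. Because
# this module uses relative imports, it must run as part of its package, e.g.
# `python -m <package>.model_loader` (the package and module names here are
# assumptions). The model id below is a hypothetical placeholder chosen only
# because it is small; passing model_id=None would instead let
# EnvironmentDetector choose its own default.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    loader = ModelLoader()
    if loader.load_complete_model(model_id="gpt2"):  # "gpt2" is an assumption
        print(loader.get_model_info())
        # text-generation pipelines return a list of dicts with "generated_text"
        result = loader.pipeline("The quick brown fox", max_new_tokens=16, do_sample=False)
        print(result[0]["generated_text"])
    else:
        print("Model loading failed; see the log output above for details.")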