"""
Model loading module with robust error handling and environment adaptation.
"""

import logging
from typing import Optional, Tuple, Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from .config.model_config import ModelConfig, EnvironmentDetector, DependencyValidator

logger = logging.getLogger(__name__)


class ModelLoader:
    """Load a tokenizer, causal-LM model and text-generation pipeline.

    The loader keeps the resolved configuration and every loaded component
    as instance attributes. The individual ``load_*`` / ``create_pipeline``
    steps report success as a boolean and log failures instead of raising;
    ``load_complete_model`` chains all steps and is the usual entry point.
    """

    def __init__(self) -> None:
        self.config: Optional[ModelConfig] = None
        self.model: Optional[Any] = None
        self.tokenizer: Optional[Any] = None
        self.pipeline: Optional[Any] = None
        # Set only after load_complete_model() has finished every step.
        self._is_loaded = False

    def validate_environment(self) -> bool:
        """Validate that the environment is ready for model loading.

        Returns:
            ``False`` when ``DependencyValidator`` reports the environment
            as not ready, ``True`` otherwise (after logging the detected
            environment information).
        """
        logger.info("🔍 Validating environment...")

        # Check dependencies
        if not DependencyValidator.is_environment_ready():
            logger.error("❌ Environment validation failed - missing dependencies")
            return False

        # Log environment info
        env_info = EnvironmentDetector.detect_environment()
        logger.info(f"📊 Environment info: {env_info}")
        return True

    def create_config(
        self,
        model_id: Optional[str] = None,
        revision: Optional[str] = None
    ) -> ModelConfig:
        """Create the model configuration and store it on ``self.config``.

        Args:
            model_id: Forwarded to ``EnvironmentDetector.create_model_config``.
            revision: Forwarded to ``EnvironmentDetector.create_model_config``.

        Returns:
            The configuration object that was stored on ``self.config``.
        """
        logger.info("⚙️ Creating model configuration...")

        self.config = EnvironmentDetector.create_model_config(model_id, revision)

        logger.info("📋 Model config created:")
        logger.info(f" Model ID: {self.config.model_id}")
        logger.info(f" Revision: {self.config.revision or 'latest'}")
        logger.info(f" Device: {self.config.device_map}")
        logger.info(f" Dtype: {self.config.dtype}")
        logger.info(f" Attention: {self.config.attn_implementation}")
        return self.config

    def load_tokenizer(self) -> bool:
        """Load the tokenizer described by ``self.config``.

        Returns:
            ``True`` on success; ``False`` when no configuration exists or
            loading raised (the exception is logged with its traceback).
        """
        if self.config is None:
            logger.error("❌ No configuration available")
            return False

        try:
            logger.info("📝 Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config.model_id,
                trust_remote_code=self.config.trust_remote_code,
                revision=self.config.revision
            )
            logger.info("✅ Tokenizer loaded successfully")
            return True
        except Exception as e:
            # logger.exception keeps the traceback, which the bare message loses.
            logger.exception(f"❌ Failed to load tokenizer: {e}")
            return False

    def load_model(self) -> bool:
        """Load the model with the environment-specific configuration.

        The model is switched to evaluation mode via ``.eval()`` right after
        loading.

        Returns:
            ``True`` on success; ``False`` when no configuration exists or
            loading raised (the exception is logged with its traceback).
        """
        if self.config is None:
            logger.error("❌ No configuration available")
            return False

        try:
            logger.info("🤖 Loading model...")
            logger.info(f" This may take several minutes for {self.config.model_id}")

            # Load model with configuration
            self.model = AutoModelForCausalLM.from_pretrained(
                self.config.model_id,
                trust_remote_code=self.config.trust_remote_code,
                revision=self.config.revision,
                attn_implementation=self.config.attn_implementation,
                dtype=self.config.dtype,  # Use dtype instead of deprecated torch_dtype
                device_map=self.config.device_map,
                low_cpu_mem_usage=self.config.low_cpu_mem_usage
            ).eval()

            logger.info("✅ Model loaded successfully")

            # Log model info
            summary = self._model_summary()
            if summary is not None:
                architectures, param_label = summary
                logger.info("📊 Model info:")
                logger.info(f" Architecture: {architectures}")
                logger.info(f" Parameters: {param_label}")
            return True
        except Exception as e:
            logger.exception(f"❌ Failed to load model: {e}")
            return False

    def create_pipeline(self) -> bool:
        """Create the text-generation inference pipeline.

        Requires the model, the tokenizer and the configuration to be
        present.

        Returns:
            ``True`` on success; ``False`` when a prerequisite is missing or
            pipeline construction raised (logged with its traceback).
        """
        # Compare against None explicitly: tokenizers implement __len__, so a
        # truthiness test would depend on vocabulary size rather than on
        # whether loading actually happened.
        if self.model is None or self.tokenizer is None:
            logger.error("❌ Model or tokenizer not loaded")
            return False
        # The pipeline arguments below are read from the configuration.
        if self.config is None:
            logger.error("❌ No configuration available")
            return False

        try:
            logger.info("🔧 Creating inference pipeline...")
            self.pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                dtype=self.config.dtype,  # Use dtype instead of deprecated torch_dtype
                device_map=self.config.device_map,
                trust_remote_code=self.config.trust_remote_code
            )
            logger.info("✅ Pipeline created successfully")
            return True
        except Exception as e:
            logger.exception(f"❌ Failed to create pipeline: {e}")
            return False

    def load_complete_model(
        self,
        model_id: Optional[str] = None,
        revision: Optional[str] = None
    ) -> bool:
        """Load complete model (tokenizer + model + pipeline).

        Runs environment validation, configuration, tokenizer loading, model
        loading, pipeline creation and a smoke test, in that order. A failed
        smoke test is only logged as a warning and does not fail the load.

        Args:
            model_id: Forwarded to ``create_config``.
            revision: Forwarded to ``create_config``.

        Returns:
            ``True`` when every mandatory step succeeded, ``False`` otherwise.
        """
        logger.info("🚀 Starting complete model loading process...")

        try:
            # Validate environment
            if not self.validate_environment():
                return False

            # From here on the stored components are replaced one by one, so
            # a previous successful load must no longer count as "loaded"
            # unless this run completes as well.
            self._is_loaded = False

            # Create configuration
            self.create_config(model_id, revision)

            # Load components in order
            if not self.load_tokenizer():
                return False
            if not self.load_model():
                return False
            if not self.create_pipeline():
                return False

            # Run smoke test
            if not self.smoke_test():
                logger.warning("⚠️ Smoke test failed, but model appears loaded")

            self._is_loaded = True
            logger.info("🎉 Model loading completed successfully!")
            return True
        except Exception as e:
            logger.exception(f"❌ Complete model loading failed: {e}")
            return False

    def smoke_test(self) -> bool:
        """Run a quick smoke test to verify the pipeline can generate.

        Returns:
            ``True`` when a short greedy generation returns a non-empty
            result; ``False`` when no pipeline exists, the result is empty,
            or generation raised.
        """
        if self.pipeline is None:
            return False

        try:
            logger.info("🧪 Running smoke test...")

            # Simple test generation
            test_input = "Hello"
            result = self.pipeline(
                test_input,
                max_new_tokens=4,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id
            )

            if result:
                logger.info("✅ Smoke test passed")
                return True

            logger.warning("⚠️ Smoke test returned empty result")
            return False
        except Exception as e:
            logger.warning(f"⚠️ Smoke test failed: {e}", exc_info=True)
            return False

    @property
    def is_loaded(self) -> bool:
        """Check if model is fully loaded and ready."""
        return self._is_loaded and self.pipeline is not None

    def get_model_info(self) -> dict:
        """Get information about the loaded model.

        Returns:
            ``{"status": "not_loaded"}`` when nothing is loaded; otherwise a
            dict with the configuration values and, when the model exposes a
            ``config`` attribute, its architecture and parameter count.
        """
        if not self.is_loaded:
            return {"status": "not_loaded"}

        info = {
            "status": "loaded",
            "model_id": self.config.model_id,
            "revision": self.config.revision,
            "device": self.config.device_map,
            "dtype": str(self.config.dtype),
            "attention": self.config.attn_implementation,
            "device_info": self.config.device_info
        }

        summary = self._model_summary()
        if summary is not None:
            info["architecture"], info["parameters"] = summary
        return info

    def _model_summary(self) -> Optional[Tuple[Any, str]]:
        """Return ``(architectures, parameter-count label)`` for ``self.model``.

        Returns ``None`` when the model object has no ``config`` attribute.
        """
        if not hasattr(self.model, 'config'):
            return None
        architectures = getattr(self.model.config, 'architectures', 'unknown')
        param_label = f"~{self.model.num_parameters() / 1e9:.1f}B"
        return architectures, param_label