""" Model Registry - Centralized model configuration and management Provides pre-configured models with their specs, LoRA settings, and quantization recommendations. """ from dataclasses import dataclass from typing import Dict, List, Optional @dataclass class ModelInfo: """Information about a model in the registry""" name: str model_id: str # HuggingFace model ID description: str vram_required_gb: int recommended_lora_rank: int recommended_quantization: str model_type: str # "local" or "cloud" ollama_equivalent: Optional[str] = None class ModelRegistry: """Registry of pre-configured models for fine-tuning""" def __init__(self): self.models: Dict[str, ModelInfo] = {} self._register_default_models() def _register_default_models(self): """Register default models with their configurations""" # Qwen Models self.models["qwen2.5-7b"] = ModelInfo( name="Qwen 2.5 7B Instruct", model_id="Qwen/Qwen2.5-7B-Instruct", description="Fast 7B parameter model, good for quick testing", vram_required_gb=6, recommended_lora_rank=16, recommended_quantization="4bit", model_type="cloud", ollama_equivalent="qwen2.5:7b" ) self.models["qwen2.5-32b"] = ModelInfo( name="Qwen 2.5 32B Instruct", model_id="Qwen/Qwen2.5-32B-Instruct", description="High-quality 32B model for production use", vram_required_gb=24, recommended_lora_rank=32, recommended_quantization="4bit", model_type="cloud", ollama_equivalent="qwen2.5:32b" ) self.models["qwen2.5-72b"] = ModelInfo( name="Qwen 2.5 72B Instruct", model_id="Qwen/Qwen2.5-72B-Instruct", description="Largest Qwen 2.5 model for maximum performance", vram_required_gb=48, recommended_lora_rank=64, recommended_quantization="4bit", model_type="cloud", ollama_equivalent="qwen2.5:72b" ) # Llama Models self.models["llama-3.1-8b"] = ModelInfo( name="Llama 3.1 8B Instruct", model_id="meta-llama/Meta-Llama-3.1-8B-Instruct", description="Meta's Llama 3.1 8B model", vram_required_gb=8, recommended_lora_rank=16, recommended_quantization="4bit", model_type="cloud", ollama_equivalent="llama3.1:8b" ) self.models["llama-3.1-70b"] = ModelInfo( name="Llama 3.1 70B Instruct", model_id="meta-llama/Meta-Llama-3.1-70B-Instruct", description="Large Llama model for maximum performance", vram_required_gb=48, recommended_lora_rank=64, recommended_quantization="4bit", model_type="cloud", ollama_equivalent="llama3.1:70b" ) # Mistral Models self.models["mistral-7b"] = ModelInfo( name="Mistral 7B Instruct v0.3", model_id="mistralai/Mistral-7B-Instruct-v0.3", description="Efficient 7B model from Mistral AI", vram_required_gb=6, recommended_lora_rank=16, recommended_quantization="4bit", model_type="cloud", ollama_equivalent="mistral:7b" ) # Mixtral (MoE) self.models["mixtral-8x7b"] = ModelInfo( name="Mixtral 8x7B Instruct", model_id="mistralai/Mixtral-8x7B-Instruct-v0.1", description="Mixture-of-Experts model with 8x7B parameters", vram_required_gb=40, recommended_lora_rank=32, recommended_quantization="4bit", model_type="cloud", ollama_equivalent="mixtral:8x7b" ) self.models["mixtral-8x22b"] = ModelInfo( name="Mixtral 8x22B Instruct", model_id="mistralai/Mixtral-8x22B-Instruct-v0.1", description="Large MoE model for highest quality", vram_required_gb=80, recommended_lora_rank=64, recommended_quantization="4bit", model_type="cloud", ollama_equivalent=None ) # Phi Models (Microsoft) self.models["phi-3-mini"] = ModelInfo( name="Phi-3 Mini 3.8B", model_id="microsoft/Phi-3-mini-4k-instruct", description="Small efficient model for quick testing", vram_required_gb=4, recommended_lora_rank=8, recommended_quantization="4bit", model_type="cloud", ollama_equivalent="phi3:mini" ) # Gemma Models (Google) self.models["gemma-7b"] = ModelInfo( name="Gemma 7B Instruct", model_id="google/gemma-7b-it", description="Google's Gemma 7B instruction-tuned model", vram_required_gb=6, recommended_lora_rank=16, recommended_quantization="4bit", model_type="cloud", ollama_equivalent="gemma:7b" ) def get_model(self, model_id: str) -> Optional[ModelInfo]: """Get model info by ID""" return self.models.get(model_id) def get_all_models(self) -> Dict[str, ModelInfo]: """Get all registered models""" return self.models def register_custom_model(self, model_id: str, info: ModelInfo): """Register a custom model""" self.models[model_id] = info def get_models_by_vram(self, max_vram_gb: int) -> List[ModelInfo]: """Get models that fit within VRAM budget""" return [ info for info in self.models.values() if info.vram_required_gb <= max_vram_gb ] def get_model_choices_for_gui(self) -> List[str]: """Get list of model choices formatted for GUI dropdown""" choices = [] for model_id, info in self.models.items(): label = f"{info.name} ({info.vram_required_gb}GB VRAM)" choices.append((label, model_id)) return choices def get_model_names(self) -> List[str]: """Get list of model names""" return [info.name for info in self.models.values()] def get_model_ids(self) -> List[str]: """Get list of model IDs""" return list(self.models.keys()) def list_models(self) -> List[str]: """Alias for get_model_ids() - returns list of model IDs""" return self.get_model_ids() # Global registry instance _registry = None def get_registry() -> ModelRegistry: """Get the global model registry instance""" global _registry if _registry is None: _registry = ModelRegistry() return _registry # Convenience function def get_model_info(model_id: str) -> Optional[ModelInfo]: """Get model info by ID from global registry""" return get_registry().get_model(model_id)