| """ | |
| Model Registry - Centralized model configuration and management | |
| Provides pre-configured models with their specs, LoRA settings, and quantization recommendations. | |
| """ | |
| from dataclasses import dataclass | |
| from typing import Dict, List, Optional | |


@dataclass
class ModelInfo:
    """Information about a model in the registry"""
    name: str
    model_id: str  # HuggingFace model ID
    description: str
    vram_required_gb: int
    recommended_lora_rank: int
    recommended_quantization: str
    model_type: str  # "local" or "cloud"
    ollama_equivalent: Optional[str] = None


class ModelRegistry:
    """Registry of pre-configured models for fine-tuning"""

    def __init__(self):
        self.models: Dict[str, ModelInfo] = {}
        self._register_default_models()

    def _register_default_models(self):
        """Register default models with their configurations"""
        # Qwen Models
        self.models["qwen2.5-7b"] = ModelInfo(
            name="Qwen 2.5 7B Instruct",
            model_id="Qwen/Qwen2.5-7B-Instruct",
            description="Fast 7B parameter model, good for quick testing",
            vram_required_gb=6,
            recommended_lora_rank=16,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="qwen2.5:7b",
        )
        self.models["qwen2.5-32b"] = ModelInfo(
            name="Qwen 2.5 32B Instruct",
            model_id="Qwen/Qwen2.5-32B-Instruct",
            description="High-quality 32B model for production use",
            vram_required_gb=24,
            recommended_lora_rank=32,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="qwen2.5:32b",
        )
        self.models["qwen2.5-72b"] = ModelInfo(
            name="Qwen 2.5 72B Instruct",
            model_id="Qwen/Qwen2.5-72B-Instruct",
            description="Largest Qwen 2.5 model for maximum performance",
            vram_required_gb=48,
            recommended_lora_rank=64,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="qwen2.5:72b",
        )
        # Llama Models
        self.models["llama-3.1-8b"] = ModelInfo(
            name="Llama 3.1 8B Instruct",
            model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
            description="Meta's Llama 3.1 8B model",
            vram_required_gb=8,
            recommended_lora_rank=16,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="llama3.1:8b",
        )
        self.models["llama-3.1-70b"] = ModelInfo(
            name="Llama 3.1 70B Instruct",
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
            description="Large Llama model for maximum performance",
            vram_required_gb=48,
            recommended_lora_rank=64,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="llama3.1:70b",
        )
        # Mistral Models
        self.models["mistral-7b"] = ModelInfo(
            name="Mistral 7B Instruct v0.3",
            model_id="mistralai/Mistral-7B-Instruct-v0.3",
            description="Efficient 7B model from Mistral AI",
            vram_required_gb=6,
            recommended_lora_rank=16,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="mistral:7b",
        )
        # Mixtral (MoE)
        self.models["mixtral-8x7b"] = ModelInfo(
            name="Mixtral 8x7B Instruct",
            model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
            description="Mixture-of-Experts model with 8x7B parameters",
            vram_required_gb=40,
            recommended_lora_rank=32,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="mixtral:8x7b",
        )
        self.models["mixtral-8x22b"] = ModelInfo(
            name="Mixtral 8x22B Instruct",
            model_id="mistralai/Mixtral-8x22B-Instruct-v0.1",
            description="Large MoE model for highest quality",
            vram_required_gb=80,
            recommended_lora_rank=64,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent=None,
        )
        # Phi Models (Microsoft)
        self.models["phi-3-mini"] = ModelInfo(
            name="Phi-3 Mini 3.8B",
            model_id="microsoft/Phi-3-mini-4k-instruct",
            description="Small efficient model for quick testing",
            vram_required_gb=4,
            recommended_lora_rank=8,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="phi3:mini",
        )
        # Gemma Models (Google)
        self.models["gemma-7b"] = ModelInfo(
            name="Gemma 7B Instruct",
            model_id="google/gemma-7b-it",
            description="Google's Gemma 7B instruction-tuned model",
            vram_required_gb=6,
            recommended_lora_rank=16,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="gemma:7b",
        )

    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """Get model info by ID"""
        return self.models.get(model_id)

    def get_all_models(self) -> Dict[str, ModelInfo]:
        """Get all registered models"""
        return self.models

    def register_custom_model(self, model_id: str, info: ModelInfo):
        """Register a custom model"""
        self.models[model_id] = info

    def get_models_by_vram(self, max_vram_gb: int) -> List[ModelInfo]:
        """Get models that fit within VRAM budget"""
        return [
            info for info in self.models.values()
            if info.vram_required_gb <= max_vram_gb
        ]

    def get_model_choices_for_gui(self) -> List[Tuple[str, str]]:
        """Get (label, value) model choices formatted for a GUI dropdown"""
        choices = []
        for model_id, info in self.models.items():
            # Each choice is a (display label, value) pair.
            label = f"{info.name} ({info.vram_required_gb}GB VRAM)"
            choices.append((label, model_id))
        return choices

    def get_model_names(self) -> List[str]:
        """Get list of model names"""
        return [info.name for info in self.models.values()]

    def get_model_ids(self) -> List[str]:
        """Get list of model IDs"""
        return list(self.models.keys())

    def list_models(self) -> List[str]:
        """Alias for get_model_ids() - returns list of model IDs"""
        return self.get_model_ids()


# Global registry instance
_registry: Optional[ModelRegistry] = None


def get_registry() -> ModelRegistry:
    """Get the global model registry instance"""
    global _registry
    if _registry is None:
        _registry = ModelRegistry()
    return _registry


# Convenience function
def get_model_info(model_id: str) -> Optional[ModelInfo]:
    """Get model info by ID from global registry"""
    return get_registry().get_model(model_id)
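

# Illustrative usage sketch (an addition, not part of the original module):
# a minimal smoke test exercising the registry API defined above. The custom
# model below uses made-up placeholder values, not real recommendations.
if __name__ == "__main__":
    registry = get_registry()

    # Look up a built-in model via the convenience function.
    qwen = get_model_info("qwen2.5-7b")
    if qwen is not None:
        print(f"{qwen.name}: rank={qwen.recommended_lora_rank}, "
              f"quant={qwen.recommended_quantization}")

    # Filter models by a VRAM budget (e.g. a 24 GB GPU).
    for info in registry.get_models_by_vram(24):
        print(f"fits in 24GB: {info.name} ({info.vram_required_gb}GB)")

    # Register a custom model; the spec values here are hypothetical.
    registry.register_custom_model(
        "my-custom-7b",
        ModelInfo(
            name="My Custom 7B",
            model_id="example-org/my-custom-7b",  # hypothetical HF repo
            description="Example custom entry",
            vram_required_gb=6,
            recommended_lora_rank=16,
            recommended_quantization="4bit",
            model_type="local",
        ),
    )

    # (label, value) pairs ready for a GUI dropdown.
    for label, value in registry.get_model_choices_for_gui():
        print(f"{label} -> {value}")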