# LaunchLLM / model_registry.py
"""
Model Registry - Centralized model configuration and management
Provides pre-configured models with their specs, LoRA settings, and quantization recommendations.
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
@dataclass
class ModelInfo:
"""Information about a model in the registry"""
name: str
model_id: str # HuggingFace model ID
description: str
    vram_required_gb: int  # estimated GPU memory needed (GB)
recommended_lora_rank: int
recommended_quantization: str
model_type: str # "local" or "cloud"
ollama_equivalent: Optional[str] = None
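
# Illustrative only (hypothetical model ID): a custom entry for the registry
# might look like the following; pass it to
# ModelRegistry.register_custom_model() below to make it available.
#
#   ModelInfo(
#       name="My Fine-Tune 7B",
#       model_id="my-org/my-finetune-7b",
#       description="Internal fine-tuned checkpoint",
#       vram_required_gb=6,
#       recommended_lora_rank=16,
#       recommended_quantization="4bit",
#       model_type="local",
#   )
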
class ModelRegistry:
"""Registry of pre-configured models for fine-tuning"""
def __init__(self):
self.models: Dict[str, ModelInfo] = {}
self._register_default_models()
def _register_default_models(self):
"""Register default models with their configurations"""
# Qwen Models
self.models["qwen2.5-7b"] = ModelInfo(
name="Qwen 2.5 7B Instruct",
model_id="Qwen/Qwen2.5-7B-Instruct",
description="Fast 7B parameter model, good for quick testing",
vram_required_gb=6,
recommended_lora_rank=16,
recommended_quantization="4bit",
model_type="cloud",
ollama_equivalent="qwen2.5:7b"
)
self.models["qwen2.5-32b"] = ModelInfo(
name="Qwen 2.5 32B Instruct",
model_id="Qwen/Qwen2.5-32B-Instruct",
description="High-quality 32B model for production use",
vram_required_gb=24,
recommended_lora_rank=32,
recommended_quantization="4bit",
model_type="cloud",
ollama_equivalent="qwen2.5:32b"
)
self.models["qwen2.5-72b"] = ModelInfo(
name="Qwen 2.5 72B Instruct",
model_id="Qwen/Qwen2.5-72B-Instruct",
description="Largest Qwen 2.5 model for maximum performance",
vram_required_gb=48,
recommended_lora_rank=64,
recommended_quantization="4bit",
model_type="cloud",
ollama_equivalent="qwen2.5:72b"
)
# Llama Models
self.models["llama-3.1-8b"] = ModelInfo(
name="Llama 3.1 8B Instruct",
model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
description="Meta's Llama 3.1 8B model",
vram_required_gb=8,
recommended_lora_rank=16,
recommended_quantization="4bit",
model_type="cloud",
ollama_equivalent="llama3.1:8b"
)
self.models["llama-3.1-70b"] = ModelInfo(
name="Llama 3.1 70B Instruct",
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
description="Large Llama model for maximum performance",
vram_required_gb=48,
recommended_lora_rank=64,
recommended_quantization="4bit",
model_type="cloud",
ollama_equivalent="llama3.1:70b"
)
# Mistral Models
self.models["mistral-7b"] = ModelInfo(
name="Mistral 7B Instruct v0.3",
model_id="mistralai/Mistral-7B-Instruct-v0.3",
description="Efficient 7B model from Mistral AI",
vram_required_gb=6,
recommended_lora_rank=16,
recommended_quantization="4bit",
model_type="cloud",
ollama_equivalent="mistral:7b"
)
# Mixtral (MoE)
self.models["mixtral-8x7b"] = ModelInfo(
name="Mixtral 8x7B Instruct",
model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
description="Mixture-of-Experts model with 8x7B parameters",
vram_required_gb=40,
recommended_lora_rank=32,
recommended_quantization="4bit",
model_type="cloud",
ollama_equivalent="mixtral:8x7b"
)
self.models["mixtral-8x22b"] = ModelInfo(
name="Mixtral 8x22B Instruct",
model_id="mistralai/Mixtral-8x22B-Instruct-v0.1",
description="Large MoE model for highest quality",
vram_required_gb=80,
recommended_lora_rank=64,
recommended_quantization="4bit",
model_type="cloud",
ollama_equivalent=None
)
# Phi Models (Microsoft)
self.models["phi-3-mini"] = ModelInfo(
name="Phi-3 Mini 3.8B",
model_id="microsoft/Phi-3-mini-4k-instruct",
description="Small efficient model for quick testing",
vram_required_gb=4,
recommended_lora_rank=8,
recommended_quantization="4bit",
model_type="cloud",
ollama_equivalent="phi3:mini"
)
# Gemma Models (Google)
self.models["gemma-7b"] = ModelInfo(
name="Gemma 7B Instruct",
model_id="google/gemma-7b-it",
description="Google's Gemma 7B instruction-tuned model",
vram_required_gb=6,
recommended_lora_rank=16,
recommended_quantization="4bit",
model_type="cloud",
ollama_equivalent="gemma:7b"
)
def get_model(self, model_id: str) -> Optional[ModelInfo]:
"""Get model info by ID"""
return self.models.get(model_id)
def get_all_models(self) -> Dict[str, ModelInfo]:
"""Get all registered models"""
return self.models
def register_custom_model(self, model_id: str, info: ModelInfo):
"""Register a custom model"""
self.models[model_id] = info
def get_models_by_vram(self, max_vram_gb: int) -> List[ModelInfo]:
"""Get models that fit within VRAM budget"""
return [
info for info in self.models.values()
if info.vram_required_gb <= max_vram_gb
]
    def get_model_choices_for_gui(self) -> List[Tuple[str, str]]:
        """Get list of (label, value) choices formatted for a GUI dropdown"""
choices = []
for model_id, info in self.models.items():
label = f"{info.name} ({info.vram_required_gb}GB VRAM)"
choices.append((label, model_id))
return choices
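    # Note (assumption): the (label, value) tuples above match the `choices`
    # format accepted by Gradio's gr.Dropdown, the usual GUI toolkit for a
    # Hugging Face Space; other toolkits may expect plain strings instead.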
def get_model_names(self) -> List[str]:
"""Get list of model names"""
return [info.name for info in self.models.values()]
def get_model_ids(self) -> List[str]:
"""Get list of model IDs"""
return list(self.models.keys())
def list_models(self) -> List[str]:
"""Alias for get_model_ids() - returns list of model IDs"""
return self.get_model_ids()
# Global registry instance
_registry: Optional[ModelRegistry] = None
def get_registry() -> ModelRegistry:
"""Get the global model registry instance"""
global _registry
if _registry is None:
_registry = ModelRegistry()
return _registry
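# Note: this lazy initializer is not thread-safe. If the first call can come
# from concurrent GUI callbacks, consider guarding creation with a
# threading.Lock; the worst case here is merely constructing a second,
# identical registry.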
# Convenience function
def get_model_info(model_id: str) -> Optional[ModelInfo]:
"""Get model info by ID from global registry"""
return get_registry().get_model(model_id)
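
if __name__ == "__main__":
    # Minimal smoke-test sketch of the registry API defined above.
    registry = get_registry()

    # Look up a built-in entry by its registry key.
    info = registry.get_model("qwen2.5-7b")
    if info is not None:
        print(f"{info.name} -> {info.model_id} ({info.vram_required_gb}GB VRAM)")

    # Filter to models that fit within a 24GB card.
    for m in registry.get_models_by_vram(24):
        print(f"fits in 24GB: {m.name}")

    # Register a custom model (hypothetical HF repo ID, for illustration only).
    registry.register_custom_model(
        "my-finetune-7b",
        ModelInfo(
            name="My Fine-Tune 7B",
            model_id="my-org/my-finetune-7b",
            description="Internal fine-tuned checkpoint",
            vram_required_gb=6,
            recommended_lora_rank=16,
            recommended_quantization="4bit",
            model_type="local",
        ),
    )
    print(registry.list_models())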