# NOTE: the following header lines are scrape residue from the Hugging Face
# Spaces file viewer (status banner, file size, commit id, line-number gutter),
# not part of the module source; commented out so the file parses.
# Spaces: Runtime error | File size: 6,993 Bytes | commit ec8f374
"""
Model Registry - Centralized model configuration and management
Provides pre-configured models with their specs, LoRA settings, and quantization recommendations.
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
@dataclass
class ModelInfo:
    """Information about a model in the registry.

    Field order is significant: it defines the positional-argument order of
    the generated __init__.
    """
    name: str                         # human-readable display name
    model_id: str                     # HuggingFace model ID (e.g. "Qwen/Qwen2.5-7B-Instruct")
    description: str                  # short blurb shown to users
    vram_required_gb: int             # approximate GPU memory needed for fine-tuning
    recommended_lora_rank: int        # suggested LoRA rank for this model size
    recommended_quantization: str     # e.g. "4bit"
    model_type: str                   # "local" or "cloud"
    ollama_equivalent: Optional[str] = None  # matching Ollama tag, if one exists
class ModelRegistry:
    """Registry of pre-configured models for fine-tuning.

    Keys are short registry IDs (e.g. "qwen2.5-7b"); values are ModelInfo
    records carrying the HuggingFace model ID plus recommended LoRA rank and
    quantization settings. A default catalog is registered on construction,
    and callers may add their own entries via register_custom_model().
    """

    def __init__(self):
        # Registry-ID -> ModelInfo mapping; filled with the default catalog.
        self.models: Dict[str, ModelInfo] = {}
        self._register_default_models()

    def _register_default_models(self):
        """Register the built-in catalog (Qwen, Llama, Mistral/Mixtral, Phi, Gemma)."""
        # Qwen Models
        self.models["qwen2.5-7b"] = ModelInfo(
            name="Qwen 2.5 7B Instruct",
            model_id="Qwen/Qwen2.5-7B-Instruct",
            description="Fast 7B parameter model, good for quick testing",
            vram_required_gb=6,
            recommended_lora_rank=16,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="qwen2.5:7b"
        )
        self.models["qwen2.5-32b"] = ModelInfo(
            name="Qwen 2.5 32B Instruct",
            model_id="Qwen/Qwen2.5-32B-Instruct",
            description="High-quality 32B model for production use",
            vram_required_gb=24,
            recommended_lora_rank=32,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="qwen2.5:32b"
        )
        self.models["qwen2.5-72b"] = ModelInfo(
            name="Qwen 2.5 72B Instruct",
            model_id="Qwen/Qwen2.5-72B-Instruct",
            description="Largest Qwen 2.5 model for maximum performance",
            vram_required_gb=48,
            recommended_lora_rank=64,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="qwen2.5:72b"
        )
        # Llama Models
        self.models["llama-3.1-8b"] = ModelInfo(
            name="Llama 3.1 8B Instruct",
            model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
            description="Meta's Llama 3.1 8B model",
            vram_required_gb=8,
            recommended_lora_rank=16,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="llama3.1:8b"
        )
        self.models["llama-3.1-70b"] = ModelInfo(
            name="Llama 3.1 70B Instruct",
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
            description="Large Llama model for maximum performance",
            vram_required_gb=48,
            recommended_lora_rank=64,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="llama3.1:70b"
        )
        # Mistral Models
        self.models["mistral-7b"] = ModelInfo(
            name="Mistral 7B Instruct v0.3",
            model_id="mistralai/Mistral-7B-Instruct-v0.3",
            description="Efficient 7B model from Mistral AI",
            vram_required_gb=6,
            recommended_lora_rank=16,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="mistral:7b"
        )
        # Mixtral (MoE)
        self.models["mixtral-8x7b"] = ModelInfo(
            name="Mixtral 8x7B Instruct",
            model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
            description="Mixture-of-Experts model with 8x7B parameters",
            vram_required_gb=40,
            recommended_lora_rank=32,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="mixtral:8x7b"
        )
        self.models["mixtral-8x22b"] = ModelInfo(
            name="Mixtral 8x22B Instruct",
            model_id="mistralai/Mixtral-8x22B-Instruct-v0.1",
            description="Large MoE model for highest quality",
            vram_required_gb=80,
            recommended_lora_rank=64,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent=None  # no Ollama tag for this one
        )
        # Phi Models (Microsoft)
        self.models["phi-3-mini"] = ModelInfo(
            name="Phi-3 Mini 3.8B",
            model_id="microsoft/Phi-3-mini-4k-instruct",
            description="Small efficient model for quick testing",
            vram_required_gb=4,
            recommended_lora_rank=8,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="phi3:mini"
        )
        # Gemma Models (Google)
        self.models["gemma-7b"] = ModelInfo(
            name="Gemma 7B Instruct",
            model_id="google/gemma-7b-it",
            description="Google's Gemma 7B instruction-tuned model",
            vram_required_gb=6,
            recommended_lora_rank=16,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="gemma:7b"
        )

    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """Return the ModelInfo for a registry ID, or None if unknown."""
        return self.models.get(model_id)

    def get_all_models(self) -> Dict[str, ModelInfo]:
        """Return all registered models (the live internal dict, not a copy)."""
        return self.models

    def register_custom_model(self, model_id: str, info: ModelInfo):
        """Register (or overwrite) a model under the given registry ID."""
        self.models[model_id] = info

    def get_models_by_vram(self, max_vram_gb: int) -> List[ModelInfo]:
        """Return models whose VRAM requirement fits within max_vram_gb."""
        return [
            info for info in self.models.values()
            if info.vram_required_gb <= max_vram_gb
        ]

    def get_model_choices_for_gui(self) -> List[Tuple[str, str]]:
        """Return (label, registry_id) pairs for a GUI dropdown.

        Fix: the previous ``List[str]`` annotation was wrong — the method has
        always returned 2-tuples of (display label, registry ID), the shape
        GUI dropdown widgets expect.
        """
        return [
            (f"{info.name} ({info.vram_required_gb}GB VRAM)", model_id)
            for model_id, info in self.models.items()
        ]

    def get_model_names(self) -> List[str]:
        """Return the display names of all registered models."""
        return [info.name for info in self.models.values()]

    def get_model_ids(self) -> List[str]:
        """Return all registry IDs (insertion order)."""
        return list(self.models.keys())

    def list_models(self) -> List[str]:
        """Alias for get_model_ids() - returns list of model IDs"""
        return self.get_model_ids()
# Process-wide singleton, created lazily on first access.
_registry = None


def get_registry() -> ModelRegistry:
    """Return the global ModelRegistry, constructing it on first use."""
    global _registry
    if _registry is None:
        _registry = ModelRegistry()
    return _registry
# Convenience wrapper around the global registry.
def get_model_info(model_id: str) -> Optional[ModelInfo]:
    """Look up a registry ID in the global registry; None if not found."""
    registry = get_registry()
    return registry.get_model(model_id)
# (trailing scrape residue removed)