"""
Model Registry - Centralized model configuration and management

Provides pre-configured models with their specs, LoRA settings, and quantization recommendations.
"""

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class ModelInfo:
    """Information about a model in the registry"""
    name: str  # human-readable display name
    model_id: str  # HuggingFace model ID
    description: str
    vram_required_gb: int  # approximate VRAM needed for fine-tuning
    recommended_lora_rank: int
    recommended_quantization: str  # e.g. "4bit"
    model_type: str  # "local" or "cloud"
    ollama_equivalent: Optional[str] = None  # matching Ollama tag, if any


class ModelRegistry:
    """Registry of pre-configured models for fine-tuning"""

    def __init__(self):
        self.models: Dict[str, ModelInfo] = {}
        self._register_default_models()

    def _register_default_models(self):
        """Register default models with their configurations"""

        # Qwen Models
        self.models["qwen2.5-7b"] = ModelInfo(
            name="Qwen 2.5 7B Instruct",
            model_id="Qwen/Qwen2.5-7B-Instruct",
            description="Fast 7B parameter model, good for quick testing",
            vram_required_gb=6,
            recommended_lora_rank=16,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="qwen2.5:7b"
        )

        self.models["qwen2.5-32b"] = ModelInfo(
            name="Qwen 2.5 32B Instruct",
            model_id="Qwen/Qwen2.5-32B-Instruct",
            description="High-quality 32B model for production use",
            vram_required_gb=24,
            recommended_lora_rank=32,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="qwen2.5:32b"
        )

        self.models["qwen2.5-72b"] = ModelInfo(
            name="Qwen 2.5 72B Instruct",
            model_id="Qwen/Qwen2.5-72B-Instruct",
            description="Largest Qwen 2.5 model for maximum performance",
            vram_required_gb=48,
            recommended_lora_rank=64,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="qwen2.5:72b"
        )

        # Llama Models
        self.models["llama-3.1-8b"] = ModelInfo(
            name="Llama 3.1 8B Instruct",
            model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
            description="Meta's Llama 3.1 8B model",
            vram_required_gb=8,
            recommended_lora_rank=16,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="llama3.1:8b"
        )

        self.models["llama-3.1-70b"] = ModelInfo(
            name="Llama 3.1 70B Instruct",
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
            description="Large Llama model for maximum performance",
            vram_required_gb=48,
            recommended_lora_rank=64,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="llama3.1:70b"
        )

        # Mistral Models
        self.models["mistral-7b"] = ModelInfo(
            name="Mistral 7B Instruct v0.3",
            model_id="mistralai/Mistral-7B-Instruct-v0.3",
            description="Efficient 7B model from Mistral AI",
            vram_required_gb=6,
            recommended_lora_rank=16,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="mistral:7b"
        )

        # Mixtral (MoE)
        self.models["mixtral-8x7b"] = ModelInfo(
            name="Mixtral 8x7B Instruct",
            model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
            description="Mixture-of-Experts model with 8x7B parameters",
            vram_required_gb=40,
            recommended_lora_rank=32,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="mixtral:8x7b"
        )

        self.models["mixtral-8x22b"] = ModelInfo(
            name="Mixtral 8x22B Instruct",
            model_id="mistralai/Mixtral-8x22B-Instruct-v0.1",
            description="Large MoE model for highest quality",
            vram_required_gb=80,
            recommended_lora_rank=64,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent=None
        )

        # Phi Models (Microsoft)
        self.models["phi-3-mini"] = ModelInfo(
            name="Phi-3 Mini 3.8B",
            model_id="microsoft/Phi-3-mini-4k-instruct",
            description="Small efficient model for quick testing",
            vram_required_gb=4,
            recommended_lora_rank=8,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="phi3:mini"
        )

        # Gemma Models (Google)
        self.models["gemma-7b"] = ModelInfo(
            name="Gemma 7B Instruct",
            model_id="google/gemma-7b-it",
            description="Google's Gemma 7B instruction-tuned model",
            vram_required_gb=6,
            recommended_lora_rank=16,
            recommended_quantization="4bit",
            model_type="cloud",
            ollama_equivalent="gemma:7b"
        )

    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """Get model info by registry key (e.g. 'qwen2.5-7b'), not the HuggingFace ID"""
        return self.models.get(model_id)

    def get_all_models(self) -> Dict[str, ModelInfo]:
        """Get all registered models"""
        return self.models

    def register_custom_model(self, model_id: str, info: ModelInfo):
        """Register a custom model"""
        self.models[model_id] = info

    def get_models_by_vram(self, max_vram_gb: int) -> List[ModelInfo]:
        """Get models that fit within VRAM budget"""
        return [
            info for info in self.models.values()
            if info.vram_required_gb <= max_vram_gb
        ]

    def get_model_choices_for_gui(self) -> List[Tuple[str, str]]:
        """Get (label, model_id) choices formatted for a GUI dropdown"""
        choices = []
        for model_id, info in self.models.items():
            label = f"{info.name} ({info.vram_required_gb}GB VRAM)"
            choices.append((label, model_id))
        return choices

    def get_model_names(self) -> List[str]:
        """Get list of model names"""
        return [info.name for info in self.models.values()]

    def get_model_ids(self) -> List[str]:
        """Get list of model IDs"""
        return list(self.models.keys())

    def list_models(self) -> List[str]:
        """Alias for get_model_ids() - returns list of model IDs"""
        return self.get_model_ids()


# Global registry instance
_registry: Optional[ModelRegistry] = None


def get_registry() -> ModelRegistry:
    """Get the global model registry instance"""
    global _registry
    if _registry is None:
        _registry = ModelRegistry()
    return _registry


# Convenience function
def get_model_info(model_id: str) -> Optional[ModelInfo]:
    """Get model info by ID from global registry"""
    return get_registry().get_model(model_id)