File size: 8,838 Bytes
acd8e16
 
 
 
 
 
 
 
 
 
c1c187b
acd8e16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1c187b
acd8e16
 
 
e65f7af
acd8e16
e65f7af
 
acd8e16
e65f7af
 
 
 
 
 
682dc03
05dfa56
682dc03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
acd8e16
c1c187b
a598199
 
 
 
 
e65f7af
a598199
e65f7af
a598199
e65f7af
a598199
e65f7af
ca333e4
 
682dc03
 
c1c187b
e65f7af
acd8e16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b16182c
e65f7af
acd8e16
 
 
e65f7af
 
acd8e16
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
"""
Models Registry for Hugging Face Spaces
Optimized for remote inference without local model loading.
"""

import yaml
import os
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
import sys
from huggingface_hub import InferenceClient

# Add src to path for imports
sys.path.append('src')
from utils.config_loader import config_loader


@dataclass
class ModelConfig:
    """Configuration for a single model entry.

    Instances are built from the entries of the models YAML file by
    ``ModelsRegistry._load_models``.
    """
    # Model name; ModelsRegistry.get_model_by_name matches it exactly.
    name: str
    # Provider key used to route the request, e.g. "hf-inference",
    # "together", "nebius".
    provider: str
    # Model identifier passed as ``model`` to the InferenceClient calls.
    model_id: str
    # Generation parameters; the inference code reads ``max_new_tokens``,
    # ``temperature`` and ``top_p`` from it (each with a default).
    params: Dict[str, Any]
    # Free-text description; the loader uses '' when the YAML omits it.
    description: str


class ModelsRegistry:
    """Registry for managing models from YAML configuration.

    The YAML file is read once, when the registry is constructed. It is
    expected to contain a top-level ``models`` list whose entries carry
    ``name``, ``provider`` and ``model_id`` (required) plus optional
    ``params`` and ``description``.
    """

    def __init__(self, config_path: str = "config/models.yaml"):
        """Load every model listed in ``config_path``.

        Args:
            config_path: Path to the models YAML file (relative paths are
                resolved against the current working directory).

        Raises:
            FileNotFoundError: If ``config_path`` does not exist.
            ValueError: If the YAML document is not a mapping.
            KeyError: If a model entry lacks ``name``, ``provider`` or
                ``model_id``.
        """
        self.config_path = config_path
        self.models = self._load_models()

    def _load_models(self) -> List[ModelConfig]:
        """Parse the YAML config file into a list of ``ModelConfig`` objects."""
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"Models config file not found: {self.config_path}")

        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(self.config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document; treat it as "no models".
            config = yaml.safe_load(f) or {}

        if not isinstance(config, dict):
            raise ValueError(
                f"Models config must be a mapping at top level: {self.config_path}"
            )

        # A bare ``models:`` key with no entries also parses to None.
        entries = config.get('models') or []
        return [
            ModelConfig(
                name=entry['name'],
                provider=entry['provider'],
                model_id=entry['model_id'],
                # Optional keys may be present but null; normalise to defaults.
                params=entry.get('params') or {},
                description=entry.get('description') or '',
            )
            for entry in entries
        ]

    def get_models(self) -> List[ModelConfig]:
        """Get all available models."""
        return self.models

    def get_model_by_name(self, name: str) -> Optional[ModelConfig]:
        """Return the first model whose name equals ``name``, or ``None``."""
        return next((model for model in self.models if model.name == name), None)

    def get_models_by_provider(self, provider: str) -> List[ModelConfig]:
        """Get all models from a specific provider."""
        return [model for model in self.models if model.provider == provider]


class HuggingFaceInference:
    """Interface for Hugging Face Inference API using InferenceClient.

    A new ``InferenceClient`` is created on every call because the routing
    provider can differ from model to model.
    """

    # Providers called through the chat-completion API instead of
    # ``text_generation`` (Nebius only supports conversational tasks;
    # together and groq are routed the same way).
    _CHAT_PROVIDERS = frozenset({"nebius", "together", "groq"})

    # Registry provider names that are not InferenceClient provider names,
    # mapped to their InferenceClient equivalent.
    _PROVIDER_ALIASES = {"huggingface": "hf-inference"}

    def __init__(self, api_token: Optional[str] = None):
        """Create the interface.

        Args:
            api_token: Hugging Face API token. When omitted, ``HF_TOKEN`` is
                read from the environment now, and again at call time if it
                was unset here.
        """
        self.api_token = api_token or os.getenv("HF_TOKEN")

    @classmethod
    def _resolve_provider(cls, provider: str) -> str:
        """Map a registry provider name to an InferenceClient provider name."""
        return cls._PROVIDER_ALIASES.get(provider, provider)

    @staticmethod
    def _translate_error(error_msg: str, model_id: str, provider: str) -> Exception:
        """Build a user-facing exception from a raw provider error message.

        Checks are ordered; the first matching substring decides the message.
        """
        if "404" in error_msg or "Not Found" in error_msg:
            return Exception(f"Model not found: {model_id} - Model may not be available via {provider} provider")
        if "401" in error_msg or "Unauthorized" in error_msg:
            return Exception(f"Authentication failed - check HF_TOKEN for {provider} provider")
        if "503" in error_msg or "Service Unavailable" in error_msg:
            return Exception(f"Model {model_id} is loading on {provider}, please try again in a moment")
        if "timeout" in error_msg.lower():
            return Exception(f"Request timeout - model may be loading on {provider}")
        if "not supported for task" in error_msg:
            return Exception(f"Model {model_id} task not supported by {provider} provider: {error_msg}")
        if "not supported by provider" in error_msg:
            return Exception(f"Model {model_id} not supported by {provider} provider: {error_msg}")
        return Exception(f"{provider} API error: {error_msg}")

    def generate(self, model_id: str, prompt: str, params: Dict[str, Any], provider: str = "hf-inference") -> str:
        """Generate text using Hugging Face Inference API with specified provider.

        Args:
            model_id: Model identifier passed to the InferenceClient call.
            prompt: User prompt text.
            params: Generation parameters; reads ``max_new_tokens`` (default
                128), ``temperature`` (default 0.1) and ``top_p`` (default 0.9).
            provider: InferenceClient provider name (``"huggingface"`` is
                accepted as an alias for ``"hf-inference"``).

        Returns:
            The generated text only (not the prompt).

        Raises:
            Exception: A descriptive error chained to the underlying failure.
        """
        provider = self._resolve_provider(provider)
        # Honour an explicitly supplied token; fall back to the environment at
        # call time so a token exported after construction is still used.
        api_key = self.api_token or os.environ.get("HF_TOKEN")
        try:
            max_tokens = params.get('max_new_tokens', 128)
            temperature = params.get('temperature', 0.1)
            top_p = params.get('top_p', 0.9)

            client = InferenceClient(provider=provider, api_key=api_key)

            if provider in self._CHAT_PROVIDERS:
                completion = client.chat.completions.create(
                    model=model_id,
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                )
                content = completion.choices[0].message.content
                # Message content can be None; keep the ``-> str`` contract.
                return content if content is not None else ""

            return client.text_generation(
                prompt=prompt,
                model=model_id,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                return_full_text=False,  # Only return the generated part
            )
        except Exception as e:
            error_msg = str(e)
            print(f"🔍 Debug - Full error: {error_msg}")
            raise self._translate_error(error_msg, model_id, provider) from e


class ModelInterface:
    """Unified interface for all model providers.

    SQL is generated remotely through ``HuggingFaceInference``. Canned mock
    SQL is returned instead when ``HF_TOKEN`` is unset, when
    ``MOCK_MODE=true``, or when the remote call fails for any reason.
    """

    # Providers dispatched to HuggingFaceInference.generate. "groq" is listed
    # because HuggingFaceInference.generate routes it via chat completion.
    SUPPORTED_PROVIDERS = frozenset(
        {"huggingface", "hf-inference", "together", "nebius", "groq"}
    )

    def __init__(self):
        self.hf_interface = HuggingFaceInference()
        # Both flags are read from the environment once, at construction.
        self.mock_mode = os.getenv("MOCK_MODE", "false").lower() == "true"
        self.has_hf_token = bool(os.getenv("HF_TOKEN"))

    def _generate_mock_sql(self, model_config: ModelConfig, prompt: str) -> str:
        """Generate mock SQL for demo purposes when API keys aren't available.

        The question is extracted from the prompt and keyword-matched against
        the configured pattern lists; the first matching category picks a
        template. ``model_config`` is accepted for signature symmetry with
        ``generate_sql`` but is not used.
        """
        mock_config = config_loader.get_mock_sql_config()
        patterns = mock_config["patterns"]
        templates = mock_config["templates"]

        # The question sits between "Question:" and "Requirements:".
        if "Question:" in prompt:
            question = prompt.split("Question:")[1].split("Requirements:")[0].strip()
        else:
            question = "unknown question"
        question_lower = question.lower()

        def _matches(category: str) -> bool:
            """True if any configured pattern of ``category`` occurs in the question."""
            return any(pattern in question_lower for pattern in patterns[category])

        # Checked in order of specificity; the first match wins.
        if _matches("count_queries"):
            return templates["count_trips" if "trips" in question_lower else "count_generic"]
        if _matches("average_queries"):
            return templates["avg_fare" if "fare" in question_lower else "avg_generic"]
        if _matches("total_queries"):
            return templates["total_amount"]
        if _matches("passenger_queries"):
            return templates["passenger_count"]
        return templates["default"]

    def generate_sql(self, model_config: ModelConfig, prompt: str) -> str:
        """Generate SQL using the specified model, falling back to mock SQL.

        Never raises for remote failures: any exception (including an
        unsupported provider) is printed and mock SQL is returned instead.
        """
        if not self.has_hf_token:
            print(f"🎭 No HF_TOKEN available, using mock mode for {model_config.name}")
            return self._generate_mock_sql(model_config, prompt)

        if self.mock_mode:
            print(f"🎭 Mock mode enabled for {model_config.name}")
            return self._generate_mock_sql(model_config, prompt)

        try:
            if model_config.provider not in self.SUPPORTED_PROVIDERS:
                # Deliberately routed to the mock fallback in the handler below.
                raise ValueError(f"Unsupported provider: {model_config.provider}")
            print(f"🤗 Using {model_config.provider} Inference API for {model_config.name}")
            return self.hf_interface.generate(
                model_config.model_id,
                prompt,
                model_config.params,
                model_config.provider,
            )
        except Exception as e:
            print(f"⚠️ Error with {model_config.name}: {e}")
            print(f"🎭 Falling back to mock mode for {model_config.name}")
            return self._generate_mock_sql(model_config, prompt)


# Global instances
# NOTE: these are created at import time. ModelsRegistry() reads
# "config/models.yaml" relative to the current working directory and raises
# FileNotFoundError if it is missing; ModelInterface() reads MOCK_MODE and
# HF_TOKEN from the environment once, at this point.
models_registry = ModelsRegistry()
model_interface = ModelInterface()