DataEngEval / src /models_registry.py
uparekh01151's picture
feat: add Groq provider models and show provider info in UI
05dfa56
"""
Models Registry for Hugging Face Spaces
Optimized for remote inference without local model loading.
"""
import os
import sys
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional

import yaml
from huggingface_hub import InferenceClient
# Add src to path for imports
sys.path.append('src')
from utils.config_loader import config_loader
@dataclass
class ModelConfig:
    """Configuration for a single model entry from the models YAML file.

    ``params`` and ``description`` are optional, matching how
    ``ModelsRegistry._load_models`` treats them when reading the YAML.
    """

    # Display name; ModelsRegistry.get_model_by_name looks models up by it.
    name: str
    # Inference provider key (e.g. "hf-inference", "together", "nebius", "groq").
    provider: str
    # Model identifier passed to the inference client.
    model_id: str
    # Generation settings; keys read downstream include max_new_tokens,
    # temperature and top_p. default_factory gives each instance its own dict.
    params: Dict[str, Any] = field(default_factory=dict)
    # Free-text description shown alongside the model.
    description: str = ""
class ModelsRegistry:
    """Registry of model definitions loaded from a YAML configuration file."""

    def __init__(self, config_path: str = "config/models.yaml"):
        """Load every model listed in ``config_path``.

        Args:
            config_path: Path to the models YAML file. A relative path is
                resolved against the current working directory.
        """
        self.config_path = config_path
        self.models = self._load_models()

    def _load_models(self) -> List[ModelConfig]:
        """Parse the YAML config into ``ModelConfig`` objects.

        The file is expected to contain a top-level ``models`` list. Each
        entry needs ``name``, ``provider`` and ``model_id``. ``params`` and
        ``description`` are optional. An empty file, or a missing or null
        ``models`` key, yields an empty list.

        Raises:
            FileNotFoundError: if the config file does not exist.
            KeyError: if an entry is missing a required key.
        """
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"Models config file not found: {self.config_path}")
        with open(self.config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document; treat that as "no models".
            config = yaml.safe_load(f) or {}
        return [
            ModelConfig(
                name=model_data['name'],
                provider=model_data['provider'],
                model_id=model_data['model_id'],
                # `or` also normalises explicit nulls in the YAML.
                params=model_data.get('params') or {},
                description=model_data.get('description') or '',
            )
            for model_data in config.get('models') or []
        ]

    def get_models(self) -> List[ModelConfig]:
        """Return all loaded models (the registry's own list, not a copy)."""
        return self.models

    def get_model_by_name(self, name: str) -> Optional[ModelConfig]:
        """Return the first model whose ``name`` equals ``name``, or ``None``."""
        return next((model for model in self.models if model.name == name), None)

    def get_models_by_provider(self, provider: str) -> List[ModelConfig]:
        """Return all models whose ``provider`` equals ``provider``."""
        return [model for model in self.models if model.provider == provider]
class HuggingFaceInference:
    """Thin wrapper around ``huggingface_hub.InferenceClient``.

    A fresh ``InferenceClient`` is built on each ``generate`` call so that
    every request can target a different inference provider.
    """

    # Providers called through the chat-completion API instead of
    # ``text_generation`` (they only expose conversational tasks).
    CHAT_PROVIDERS = ("nebius", "together", "groq")

    def __init__(self, api_token: Optional[str] = None):
        """Remember the API token to use.

        Args:
            api_token: Explicit Hugging Face token. When omitted, the
                ``HF_TOKEN`` environment variable is read here.
        """
        self.api_token = api_token or os.getenv("HF_TOKEN")

    def generate(self, model_id: str, prompt: str, params: Dict[str, Any], provider: str = "hf-inference") -> str:
        """Generate text for ``prompt`` with ``model_id`` on ``provider``.

        Args:
            model_id: Model identifier understood by the provider.
            prompt: Prompt text. It is sent as a single user message for chat
                providers and as a raw prompt otherwise.
            params: Generation settings. Reads ``max_new_tokens`` (default 128),
                ``temperature`` (default 0.1) and ``top_p`` (default 0.9).
            provider: Provider name passed to ``InferenceClient``.

        Returns:
            Only the generated text; the prompt is not echoed back.

        Raises:
            Exception: with a readable message built by ``_describe_error``.
                The underlying error is chained as ``__cause__``.
        """
        try:
            client = InferenceClient(
                provider=provider,
                # Prefer the token given to the constructor; otherwise use the
                # current HF_TOKEN environment value.
                api_key=self.api_token or os.environ.get("HF_TOKEN"),
            )
            max_tokens = params.get('max_new_tokens', 128)
            temperature = params.get('temperature', 0.1)
            top_p = params.get('top_p', 0.9)

            if provider in self.CHAT_PROVIDERS:
                completion = client.chat.completions.create(
                    model=model_id,
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                )
                return completion.choices[0].message.content

            return client.text_generation(
                prompt=prompt,
                model=model_id,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                return_full_text=False,  # only return the generated part
            )
        except Exception as e:
            error_msg = str(e)
            print(f"🔍 Debug - Full error: {error_msg}")
            raise Exception(self._describe_error(error_msg, model_id, provider)) from e

    @staticmethod
    def _describe_error(error_msg: str, model_id: str, provider: str) -> str:
        """Map a raw provider error message to a user-facing explanation.

        Checks run in order and the first match wins. Unrecognised errors
        are reported as a generic provider API error.
        """
        if "404" in error_msg or "Not Found" in error_msg:
            return f"Model not found: {model_id} - Model may not be available via {provider} provider"
        if "401" in error_msg or "Unauthorized" in error_msg:
            return f"Authentication failed - check HF_TOKEN for {provider} provider"
        if "503" in error_msg or "Service Unavailable" in error_msg:
            return f"Model {model_id} is loading on {provider}, please try again in a moment"
        if "timeout" in error_msg.lower():
            return f"Request timeout - model may be loading on {provider}"
        if "not supported for task" in error_msg:
            return f"Model {model_id} task not supported by {provider} provider: {error_msg}"
        if "not supported by provider" in error_msg:
            return f"Model {model_id} not supported by {provider} provider: {error_msg}"
        return f"{provider} API error: {error_msg}"
class ModelInterface:
    """Unified entry point for generating SQL with any configured model.

    Falls back to template-based mock SQL when no HF_TOKEN is set, when
    MOCK_MODE=true, or when the remote call fails.
    """

    # Providers routed through HuggingFaceInference.generate(). That method
    # handles "groq" (as well as "nebius"/"together") via chat completion,
    # so it must be accepted here too.
    SUPPORTED_PROVIDERS = ("huggingface", "hf-inference", "together", "nebius", "groq")

    def __init__(self):
        """Create the remote client and snapshot the MOCK_MODE / HF_TOKEN settings."""
        self.hf_interface = HuggingFaceInference()
        self.mock_mode = os.getenv("MOCK_MODE", "false").lower() == "true"
        self.has_hf_token = bool(os.getenv("HF_TOKEN"))

    def _generate_mock_sql(self, model_config: ModelConfig, prompt: str) -> str:
        """Return canned SQL chosen by keyword-matching the prompt's question.

        Patterns and templates come from ``config_loader.get_mock_sql_config()``.
        ``model_config`` is accepted for interface symmetry but not used.
        """
        mock_config = config_loader.get_mock_sql_config()
        patterns = mock_config["patterns"]
        templates = mock_config["templates"]

        # Take the text after the first "Question:" up to "Requirements:" (if any).
        if "Question:" in prompt:
            question = prompt.split("Question:")[1].split("Requirements:")[0].strip()
        else:
            question = "unknown question"
        question_lower = question.lower()

        def matches(group: str) -> bool:
            return any(pattern in question_lower for pattern in patterns[group])

        # Checked in order of specificity; the first matching group wins.
        if matches("count_queries"):
            return templates["count_trips"] if "trips" in question_lower else templates["count_generic"]
        if matches("average_queries"):
            return templates["avg_fare"] if "fare" in question_lower else templates["avg_generic"]
        if matches("total_queries"):
            return templates["total_amount"]
        if matches("passenger_queries"):
            return templates["passenger_count"]
        return templates["default"]

    def generate_sql(self, model_config: ModelConfig, prompt: str) -> str:
        """Generate SQL for ``prompt`` using ``model_config``.

        Returns mock SQL if no HF_TOKEN was available at construction time,
        if MOCK_MODE is enabled, or if the remote call (or provider lookup)
        fails. Any such failure is printed, not raised.
        """
        if not self.has_hf_token:
            print(f"🎭 No HF_TOKEN available, using mock mode for {model_config.name}")
            return self._generate_mock_sql(model_config, prompt)
        if self.mock_mode:
            print(f"🎭 Mock mode enabled for {model_config.name}")
            return self._generate_mock_sql(model_config, prompt)
        try:
            if model_config.provider not in self.SUPPORTED_PROVIDERS:
                raise ValueError(f"Unsupported provider: {model_config.provider}")
            print(f"🤗 Using {model_config.provider} Inference API for {model_config.name}")
            return self.hf_interface.generate(
                model_config.model_id,
                prompt,
                model_config.params,
                model_config.provider,
            )
        except Exception as e:
            # Deliberate best-effort: keep the demo usable by falling back to mock SQL.
            print(f"⚠️ Error with {model_config.name}: {str(e)}")
            print(f"🎭 Falling back to mock mode for {model_config.name}")
            return self._generate_mock_sql(model_config, prompt)
# Global instances shared by importers of this module.
# NOTE: constructing ModelsRegistry reads "config/models.yaml" (relative to the
# current working directory) at import time and raises FileNotFoundError if it
# is missing. ModelInterface snapshots MOCK_MODE / HF_TOKEN from the environment
# at this point, so later changes to those variables are not seen by its checks.
models_registry = ModelsRegistry()
model_interface = ModelInterface()