Spaces:
Sleeping
Sleeping
File size: 8,838 Bytes
acd8e16 c1c187b acd8e16 c1c187b acd8e16 e65f7af acd8e16 e65f7af acd8e16 e65f7af 682dc03 05dfa56 682dc03 acd8e16 c1c187b a598199 e65f7af a598199 e65f7af a598199 e65f7af a598199 e65f7af ca333e4 682dc03 c1c187b e65f7af acd8e16 b16182c e65f7af acd8e16 e65f7af acd8e16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 |
"""
Models Registry for Hugging Face Spaces
Optimized for remote inference without local model loading.
"""
import yaml
import os
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
import sys
from huggingface_hub import InferenceClient
# Add src to path for imports
sys.path.append('src')
from utils.config_loader import config_loader
@dataclass
class ModelConfig:
    """Settings describing a single model entry of the registry."""

    # Name used to look the model up in the registry.
    name: str
    # Provider label used for filtering and for routing generation requests.
    provider: str
    # Identifier handed to the inference client as the model to call.
    model_id: str
    # Generation parameters, read with dict lookups such as 'max_new_tokens'.
    params: Dict[str, Any]
    # Free-text description of the model.
    description: str
class ModelsRegistry:
    """Registry for managing models from YAML configuration.

    The YAML file is read once, when the registry is constructed; the
    resulting list of ModelConfig objects is kept in ``self.models``.
    """

    def __init__(self, config_path: str = "config/models.yaml"):
        """Load every model declared in the YAML file at *config_path*.

        Raises:
            FileNotFoundError: if *config_path* does not exist.
        """
        self.config_path = config_path
        self.models = self._load_models()

    def _load_models(self) -> List[ModelConfig]:
        """Load models from YAML configuration file.

        Each entry under the top-level ``models`` key must provide ``name``,
        ``provider`` and ``model_id``; ``params`` defaults to ``{}`` and
        ``description`` to ``''``.

        Returns:
            The configured models, in file order. An empty file or an
            empty/null ``models`` key yields an empty list.

        Raises:
            FileNotFoundError: if the configuration file does not exist.
            KeyError: if an entry lacks one of the required keys.
        """
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"Models config file not found: {self.config_path}")

        # Explicit encoding so parsing does not depend on the platform locale.
        with open(self.config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document; treat that as
            # "no models" instead of failing on None.get below.
            config = yaml.safe_load(f) or {}

        # ``models:`` with no value parses to None, hence the ``or []``.
        return [
            ModelConfig(
                name=model_data['name'],
                provider=model_data['provider'],
                model_id=model_data['model_id'],
                params=model_data.get('params', {}),
                description=model_data.get('description', ''),
            )
            for model_data in config.get('models') or []
        ]

    def get_models(self) -> List[ModelConfig]:
        """Get all available models (the registry's own list, not a copy)."""
        return self.models

    def get_model_by_name(self, name: str) -> Optional[ModelConfig]:
        """Get the first model whose name equals *name*, or None if absent."""
        return next((model for model in self.models if model.name == name), None)

    def get_models_by_provider(self, provider: str) -> List[ModelConfig]:
        """Get all models from a specific provider (exact string match)."""
        return [model for model in self.models if model.provider == provider]
class HuggingFaceInference:
    """Interface for Hugging Face Inference API using InferenceClient."""

    # Providers that generate() calls through the chat-completion API; every
    # other provider goes through text_generation.
    _CHAT_PROVIDERS = ("nebius", "together", "groq")

    def __init__(self, api_token: Optional[str] = None):
        """Store the API token, defaulting to the HF_TOKEN environment variable."""
        self.api_token = api_token or os.getenv("HF_TOKEN")
        # Clients are created per call in generate(), since the provider varies.

    def generate(self, model_id: str, prompt: str, params: Dict[str, Any], provider: str = "hf-inference") -> str:
        """Generate text using Hugging Face Inference API with specified provider.

        Args:
            model_id: Model identifier passed to the client.
            prompt: Prompt text; sent as a single user message for chat
                providers, or as the raw prompt otherwise.
            params: Generation parameters. Reads ``max_new_tokens``
                (default 128), ``temperature`` (default 0.1) and ``top_p``
                (default 0.9).
            provider: Provider name handed to InferenceClient.

        Returns:
            The generated text.

        Raises:
            Exception: for any failure, with a message classified from the
                underlying error text; the original error is chained as
                ``__cause__``.
        """
        try:
            # Use the token given to the constructor; fall back to the
            # environment in case HF_TOKEN was set after construction.
            client = InferenceClient(
                provider=provider,
                api_key=self.api_token or os.environ.get("HF_TOKEN"),
            )
            if provider in self._CHAT_PROVIDERS:
                # These providers are called via chat completion.
                completion = client.chat.completions.create(
                    model=model_id,
                    messages=[
                        {
                            "role": "user",
                            "content": prompt,
                        }
                    ],
                    max_tokens=params.get('max_new_tokens', 128),
                    temperature=params.get('temperature', 0.1),
                    top_p=params.get('top_p', 0.9),
                )
                # Extract the content from the response
                return completion.choices[0].message.content

            # Other providers use text_generation
            return client.text_generation(
                prompt=prompt,
                model=model_id,
                max_new_tokens=params.get('max_new_tokens', 128),
                temperature=params.get('temperature', 0.1),
                top_p=params.get('top_p', 0.9),
                return_full_text=False,  # Only return the generated part
            )
        except Exception as e:
            # Classify the failure from the error text; the order of the
            # checks matters because the first match wins.
            error_msg = str(e)
            print(f"🔍 Debug - Full error: {error_msg}")
            if "404" in error_msg or "Not Found" in error_msg:
                detail = f"Model not found: {model_id} - Model may not be available via {provider} provider"
            elif "401" in error_msg or "Unauthorized" in error_msg:
                detail = f"Authentication failed - check HF_TOKEN for {provider} provider"
            elif "503" in error_msg or "Service Unavailable" in error_msg:
                detail = f"Model {model_id} is loading on {provider}, please try again in a moment"
            elif "timeout" in error_msg.lower():
                detail = f"Request timeout - model may be loading on {provider}"
            elif "not supported for task" in error_msg:
                detail = f"Model {model_id} task not supported by {provider} provider: {error_msg}"
            elif "not supported by provider" in error_msg:
                detail = f"Model {model_id} not supported by {provider} provider: {error_msg}"
            else:
                detail = f"{provider} API error: {error_msg}"
            # Chain the original error so its traceback is not lost.
            raise Exception(detail) from e
class ModelInterface:
    """Unified interface for all model providers.

    Routes SQL generation to the Hugging Face inference wrapper and falls
    back to template-based mock SQL when no token is available, when mock
    mode is enabled, or when the remote call fails.
    """

    # Providers forwarded to HuggingFaceInference.generate. "groq" is listed
    # because generate() handles it as a chat-completion provider.
    # NOTE(review): "huggingface" is forwarded verbatim as the client's
    # provider name — confirm the client accepts it, or map it to
    # "hf-inference".
    _REMOTE_PROVIDERS = ("huggingface", "hf-inference", "together", "nebius", "groq")

    def __init__(self):
        """Create the inference wrapper and snapshot the environment flags.

        MOCK_MODE and HF_TOKEN are read once here; later changes to the
        environment are not picked up by this instance.
        """
        self.hf_interface = HuggingFaceInference()
        self.mock_mode = os.getenv("MOCK_MODE", "false").lower() == "true"
        self.has_hf_token = bool(os.getenv("HF_TOKEN"))

    def _generate_mock_sql(self, model_config: ModelConfig, prompt: str) -> str:
        """Generate mock SQL for demo purposes when API keys aren't available.

        The question is the text between "Question:" and "Requirements:" in
        *prompt* (or "unknown question" when there is no "Question:" marker).
        It is lower-cased and matched by substring against the configured
        pattern lists; the first matching category selects the template.

        Args:
            model_config: Unused here; kept for a uniform call signature.
            prompt: Full prompt text.

        Returns:
            The SQL template string chosen from the mock configuration.
        """
        # Get mock SQL configuration
        mock_config = config_loader.get_mock_sql_config()
        patterns = mock_config["patterns"]
        templates = mock_config["templates"]

        # Extract the question from the prompt
        if "Question:" in prompt:
            question = prompt.split("Question:")[1].split("Requirements:")[0].strip()
        else:
            question = "unknown question"

        question_lower = question.lower()

        # Check patterns in order of specificity
        if any(pattern in question_lower for pattern in patterns["count_queries"]):
            if "trips" in question_lower:
                return templates["count_trips"]
            return templates["count_generic"]
        if any(pattern in question_lower for pattern in patterns["average_queries"]):
            if "fare" in question_lower:
                return templates["avg_fare"]
            return templates["avg_generic"]
        if any(pattern in question_lower for pattern in patterns["total_queries"]):
            return templates["total_amount"]
        if any(pattern in question_lower for pattern in patterns["passenger_queries"]):
            return templates["passenger_count"]

        # Default fallback
        return templates["default"]

    def generate_sql(self, model_config: ModelConfig, prompt: str) -> str:
        """Generate SQL using the specified model.

        Mock SQL is returned when no HF_TOKEN was present at construction,
        when MOCK_MODE was enabled, when the provider is unsupported, or
        when the remote call raises. Remote and provider errors are printed
        and swallowed rather than propagated.

        Args:
            model_config: Model to use; its provider selects the route.
            prompt: Prompt text sent to the model (or parsed for mock SQL).

        Returns:
            The generated (or mock) SQL text.
        """
        # Use mock mode if no HF token is available
        if not self.has_hf_token:
            print(f"🎭 No HF_TOKEN available, using mock mode for {model_config.name}")
            return self._generate_mock_sql(model_config, prompt)

        # Use mock mode only if explicitly set
        if self.mock_mode:
            print(f"🎭 Mock mode enabled for {model_config.name}")
            return self._generate_mock_sql(model_config, prompt)

        try:
            if model_config.provider in self._REMOTE_PROVIDERS:
                print(f"🤗 Using {model_config.provider} Inference API for {model_config.name}")
                return self.hf_interface.generate(
                    model_config.model_id,
                    prompt,
                    model_config.params,
                    model_config.provider,
                )
            raise ValueError(f"Unsupported provider: {model_config.provider}")
        except Exception as e:
            # Deliberate best-effort: any failure degrades to mock SQL.
            print(f"⚠️ Error with {model_config.name}: {str(e)}")
            print(f"🎭 Falling back to mock mode for {model_config.name}")
            return self._generate_mock_sql(model_config, prompt)
# Global instances
# Created at import time: ModelsRegistry() reads the default
# "config/models.yaml" and raises FileNotFoundError if it is missing, and
# ModelInterface() snapshots MOCK_MODE / HF_TOKEN from the environment.
models_registry = ModelsRegistry()
model_interface = ModelInterface()