| """ | |
| Models Registry for Hugging Face Spaces | |
| Optimized for remote inference without local model loading. | |
| """ | |
| import yaml | |
| import os | |
| from typing import List, Dict, Any, Optional | |
| from dataclasses import dataclass | |
| import sys | |
| from huggingface_hub import InferenceClient | |
| # Add src to path for imports | |
| sys.path.append('src') | |
| from utils.config_loader import config_loader | |
@dataclass
class ModelConfig:
    """Configuration for a model."""
    name: str
    provider: str
    model_id: str
    params: Dict[str, Any]
    description: str
class ModelsRegistry:
    """Registry for managing models from YAML configuration."""

    def __init__(self, config_path: str = "config/models.yaml"):
        self.config_path = config_path
        self.models = self._load_models()

    def _load_models(self) -> List[ModelConfig]:
        """Load models from the YAML configuration file."""
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"Models config file not found: {self.config_path}")
        with open(self.config_path, 'r') as f:
            config = yaml.safe_load(f)
        models = []
        for model_data in config.get('models', []):
            model = ModelConfig(
                name=model_data['name'],
                provider=model_data['provider'],
                model_id=model_data['model_id'],
                params=model_data.get('params', {}),
                description=model_data.get('description', '')
            )
            models.append(model)
        return models

    def get_models(self) -> List[ModelConfig]:
        """Get all available models."""
        return self.models

    def get_model_by_name(self, name: str) -> Optional[ModelConfig]:
        """Get a specific model by name."""
        for model in self.models:
            if model.name == name:
                return model
        return None

    def get_models_by_provider(self, provider: str) -> List[ModelConfig]:
        """Get all models from a specific provider."""
        return [model for model in self.models if model.provider == provider]
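
# A hypothetical config/models.yaml illustrating the shape _load_models expects.
# The names, model id, and params below are placeholders for illustration, not
# the project's actual configuration:
#
#   models:
#     - name: "example-chat-model"        # display name used by get_model_by_name
#       provider: "nebius"                # routed to chat completions in generate()
#       model_id: "some-org/some-model"   # Hub repo id passed to InferenceClient
#       params:
#         max_new_tokens: 128
#         temperature: 0.1
#         top_p: 0.9
#       description: "Placeholder entry"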
class HuggingFaceInference:
    """Interface for the Hugging Face Inference API using InferenceClient."""

    def __init__(self, api_token: Optional[str] = None):
        self.api_token = api_token or os.getenv("HF_TOKEN")
        # Clients are created dynamically per request, based on the provider

    def generate(self, model_id: str, prompt: str, params: Dict[str, Any], provider: str = "hf-inference") -> str:
        """Generate text using the Hugging Face Inference API with the specified provider."""
        try:
            # Create an InferenceClient for the specified provider
            client = InferenceClient(
                provider=provider,
                api_key=self.api_token
            )
            # Choose the endpoint based on provider capabilities
            if provider in ("nebius", "together", "groq"):
                # These providers only support conversational tasks, so use chat completion
                completion = client.chat.completions.create(
                    model=model_id,
                    messages=[
                        {
                            "role": "user",
                            "content": prompt
                        }
                    ],
                    max_tokens=params.get('max_new_tokens', 128),
                    temperature=params.get('temperature', 0.1),
                    top_p=params.get('top_p', 0.9)
                )
                # Extract the generated content from the response
                return completion.choices[0].message.content
            else:
                # Other providers support plain text generation
                result = client.text_generation(
                    prompt=prompt,
                    model=model_id,
                    max_new_tokens=params.get('max_new_tokens', 128),
                    temperature=params.get('temperature', 0.1),
                    top_p=params.get('top_p', 0.9),
                    return_full_text=False  # Only return the generated part
                )
                return result
        except Exception as e:
            # Map common failure modes to clearer error messages
            error_msg = str(e)
            print(f"🔍 Debug - Full error: {error_msg}")
            if "404" in error_msg or "Not Found" in error_msg:
                raise Exception(f"Model not found: {model_id} - model may not be available via the {provider} provider")
            elif "401" in error_msg or "Unauthorized" in error_msg:
                raise Exception(f"Authentication failed - check HF_TOKEN for the {provider} provider")
            elif "503" in error_msg or "Service Unavailable" in error_msg:
                raise Exception(f"Model {model_id} is loading on {provider}, please try again in a moment")
            elif "timeout" in error_msg.lower():
                raise Exception(f"Request timeout - model may be loading on {provider}")
            elif "not supported for task" in error_msg:
                raise Exception(f"Model {model_id} task not supported by the {provider} provider: {error_msg}")
            elif "not supported by provider" in error_msg:
                raise Exception(f"Model {model_id} not supported by the {provider} provider: {error_msg}")
            else:
                raise Exception(f"{provider} API error: {error_msg}")
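
# A minimal usage sketch (illustrative; assumes HF_TOKEN is set in the
# environment, and the model id below is a placeholder, not a real entry):
#
#   hf = HuggingFaceInference()
#   text = hf.generate(
#       model_id="some-org/some-model",
#       prompt="Translate this question to SQL: ...",
#       params={"max_new_tokens": 64},
#       provider="nebius",
#   )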
class ModelInterface:
    """Unified interface for all model providers."""

    def __init__(self):
        self.hf_interface = HuggingFaceInference()
        self.mock_mode = os.getenv("MOCK_MODE", "false").lower() == "true"
        self.has_hf_token = bool(os.getenv("HF_TOKEN"))

    def _generate_mock_sql(self, model_config: ModelConfig, prompt: str) -> str:
        """Generate mock SQL for demo purposes when API keys aren't available."""
        # Get the mock SQL configuration
        mock_config = config_loader.get_mock_sql_config()
        patterns = mock_config["patterns"]
        templates = mock_config["templates"]
        # Extract the question from the prompt
        if "Question:" in prompt:
            question = prompt.split("Question:")[1].split("Requirements:")[0].strip()
        else:
            question = "unknown question"
        # Simple mock SQL generation based on configured patterns
        question_lower = question.lower()
        # Check patterns in order of specificity
        if any(pattern in question_lower for pattern in patterns["count_queries"]):
            if "trips" in question_lower:
                return templates["count_trips"]
            else:
                return templates["count_generic"]
        elif any(pattern in question_lower for pattern in patterns["average_queries"]):
            if "fare" in question_lower:
                return templates["avg_fare"]
            else:
                return templates["avg_generic"]
        elif any(pattern in question_lower for pattern in patterns["total_queries"]):
            return templates["total_amount"]
        elif any(pattern in question_lower for pattern in patterns["passenger_queries"]):
            return templates["passenger_count"]
        else:
            # Default fallback
            return templates["default"]
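
    # The mock SQL config is expected to look roughly like the sketch below,
    # inferred from the key lookups above; the actual phrases and SQL live in
    # the config handled by config_loader, so everything here is a placeholder:
    #
    #   patterns:
    #     count_queries: ["how many", "count"]
    #     average_queries: ["average", "mean"]
    #     total_queries: ["total", "sum"]
    #     passenger_queries: ["passenger"]
    #   templates:
    #     count_trips: "SELECT COUNT(*) FROM trips;"
    #     default: "SELECT * FROM trips LIMIT 10;"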
    def generate_sql(self, model_config: ModelConfig, prompt: str) -> str:
        """Generate SQL using the specified model."""
        # Fall back to mock mode when no HF token is available
        if not self.has_hf_token:
            print(f"🎭 No HF_TOKEN available, using mock mode for {model_config.name}")
            return self._generate_mock_sql(model_config, prompt)
        # Use mock mode when it is explicitly enabled
        if self.mock_mode:
            print(f"🎭 Mock mode enabled for {model_config.name}")
            return self._generate_mock_sql(model_config, prompt)
        try:
            # "groq" is included since HuggingFaceInference.generate already routes it
            if model_config.provider in ["huggingface", "hf-inference", "together", "nebius", "groq"]:
                print(f"🤗 Using {model_config.provider} Inference API for {model_config.name}")
                return self.hf_interface.generate(
                    model_config.model_id,
                    prompt,
                    model_config.params,
                    model_config.provider
                )
            else:
                raise ValueError(f"Unsupported provider: {model_config.provider}")
        except Exception as e:
            print(f"⚠️ Error with {model_config.name}: {str(e)}")
            print(f"🎭 Falling back to mock mode for {model_config.name}")
            return self._generate_mock_sql(model_config, prompt)
# Global instances
models_registry = ModelsRegistry()
model_interface = ModelInterface()
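
# A minimal smoke test (illustrative; assumes config/models.yaml exists so
# ModelsRegistry can load, and runs entirely in mock mode when HF_TOKEN is
# unset). The demo prompt mirrors the "Question: ... Requirements: ..." format
# that _generate_mock_sql parses:
if __name__ == "__main__":
    models = models_registry.get_models()
    for model in models:
        print(f"{model.name} ({model.provider}): {model.model_id}")
    if models:
        demo_prompt = "Question: How many trips were taken? Requirements: return valid SQL"
        print(model_interface.generate_sql(models[0], demo_prompt))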