Spaces:
Running
Running
| import os | |
| import requests | |
| import re | |
| import json | |
| from dotenv import load_dotenv | |
| from huggingface_hub import InferenceClient | |
class SQLGenerator:
    """Translate natural-language questions into SQL through a hosted chat model.

    The generator sends the question together with a schema description to a
    Hugging Face ``InferenceClient`` chat-completion endpoint and parses the
    reply into a ``(sql, explanation, message)`` tuple.  Every failure mode is
    reported as a harmless ``SELECT '<error>' as status`` statement so callers
    always receive the same tuple shape.
    """

    # Statement keywords that modify data, schema or permissions.
    _FORBIDDEN_KEYWORDS = ("DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "TRUNCATE", "GRANT")
    # Word boundaries keep harmless words such as "updated" or "dropdown"
    # from tripping the filter.
    _FORBIDDEN_RE = re.compile(r"\b(?:" + "|".join(_FORBIDDEN_KEYWORDS) + r")\b", re.IGNORECASE)
    # Last-resort extraction of a ';'-terminated SELECT from free-form text.
    _SELECT_RE = re.compile(r"(SELECT[\s\S]+?;)", re.IGNORECASE)
    # Smaller model tried when the primary one answers with 404/429/503.
    _BACKUP_MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct"
    # Tuple returned whenever the safety layer rejects a question or a query.
    _BLOCKED_RESULT = (
        "SELECT 'Error: Blocked by Safety Layer' as status",
        "Safety Alert",
        "I cannot execute commands that modify data.",
    )

    def __init__(self):
        """Load the API key from the environment and create the inference client.

        The key is read from ``HF_API_KEY`` or, failing that,
        ``HUGGINGFACEHUB_API_TOKEN``.  When neither yields a non-blank value,
        ``self.client`` is ``None`` and ``generate_sql`` reports a
        configuration error instead of calling the API.
        """
        load_dotenv()
        raw_key = os.getenv("HF_API_KEY") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        # A whitespace-only key is treated the same as a missing key.
        self.api_key = (raw_key or "").strip() or None
        if self.api_key:
            self.client = InferenceClient(api_key=self.api_key)
        else:
            self.client = None
            print(" ❌ FATAL: API Key missing.")
        self.model_id = "Qwen/Qwen2.5-Coder-32B-Instruct"

    def generate_followup_questions(self, question, sql_query):
        """Return a fixed list of follow-up suggestions.

        Both arguments are currently ignored; the same three suggestions are
        returned for every call.
        """
        return ["Visualize this result", "Export as CSV", "Compare with last year"]

    def generate_sql(self, question, context, history=None):
        """Generate a SQL query answering ``question`` against the given schema.

        Args:
            question: Natural-language question from the user.
            context: Schema description inserted verbatim into the system prompt.
            history: Accepted for interface compatibility; currently unused.

        Returns:
            A ``(sql, explanation, message)`` tuple.  On configuration errors,
            safety rejections or API failures ``sql`` is a
            ``SELECT '<error>' as status`` statement.
        """
        if not self.client:
            return (
                "SELECT 'Error: HF_API_KEY Missing' as status",
                "Configuration Error",
                "Please add HF_API_KEY to your Space Secrets.",
            )

        # Reject questions that explicitly ask for a data-modifying statement.
        if self._FORBIDDEN_RE.search(question or ""):
            return self._BLOCKED_RESULT

        messages = self._build_messages(question, context)
        try:
            print(f" ⚡ Generating SQL using {self.model_id}...")
            response = self.client.chat.completions.create(
                model=self.model_id,
                messages=messages,
                max_tokens=500,
                temperature=0.1,  # low temperature keeps the output close to the schema
                stream=False,
            )
            return self._process_response(response.choices[0].message.content)
        except Exception as e:
            print(f" ❌ AI ERROR: {e}")
            error_text = str(e)
            # Model missing, rate-limited or overloaded: retry on the backup model.
            if any(code in error_text for code in ("404", "429", "503")):
                return self._fallback_generate(messages)
            return (
                self._status_sql(f"Error: {error_text[:50]}"),
                "System Error",
                "AI Model unavailable.",
            )

    def _fallback_generate(self, messages):
        """Retry the same chat messages on the backup model.

        Returns the parsed ``(sql, explanation, message)`` tuple, or a generic
        error tuple when the backup request fails as well.
        """
        try:
            backup_model = self._BACKUP_MODEL_ID
            print(f" ⚠️ Switching to backup: {backup_model}...")
            response = self.client.chat.completions.create(
                model=backup_model,
                messages=messages,
                max_tokens=500,
                temperature=0.1,  # same sampling setting as the primary request
            )
            return self._process_response(response.choices[0].message.content)
        except Exception as e:
            # Log the cause instead of discarding it silently.
            print(f" ❌ AI ERROR: {e}")
            return (
                "SELECT 'Error: All models failed' as status",
                "System Error",
                "Please check your API Key permissions.",
            )

    def _process_response(self, raw_text):
        """Extract ``(sql, explanation, message)`` from the model's raw reply.

        The reply is first treated as JSON (optionally wrapped in Markdown
        code fences).  If no JSON object can be parsed, the first
        ``SELECT ...;`` found in the text is used instead.  Newlines in the
        query are replaced by spaces and a trailing ``;`` is appended when
        missing.  An empty result yields an ``Error: Empty Query`` status
        statement; a query containing a forbidden keyword yields the safety
        layer's rejection tuple.
        """
        sql_query = ""
        message = "Here is the data."
        explanation = "Query generated successfully."
        # The API may hand back None (or another non-string) as content.
        text = raw_text if isinstance(raw_text, str) else ""

        data = None
        unfenced = re.sub(r"```json|```", "", text).strip()
        json_match = re.search(r"\{.*\}", unfenced, re.DOTALL)
        if json_match:
            try:
                data = json.loads(json_match.group(0))
            except ValueError:
                # Malformed JSON: fall through to the plain-text SELECT search.
                data = None

        if isinstance(data, dict):
            # Ignore null / non-string fields rather than crashing on .strip().
            if isinstance(data.get("sql"), str):
                sql_query = data["sql"]
            if isinstance(data.get("message"), str):
                message = data["message"]
            if isinstance(data.get("explanation"), str):
                explanation = data["explanation"]
        else:
            select_match = self._SELECT_RE.search(text)
            if select_match:
                sql_query = select_match.group(1)

        sql_query = sql_query.strip().replace("\n", " ")
        if not sql_query:
            return "SELECT 'Error: Empty Query' as status", explanation, message

        # The model's output is untrusted too: never hand back a statement
        # that contains a data-modifying keyword.
        if self._FORBIDDEN_RE.search(sql_query):
            return self._BLOCKED_RESULT

        if not sql_query.endswith(";"):
            sql_query += ";"
        return sql_query, explanation, message

    @staticmethod
    def _build_messages(question, context):
        """Build the system/user chat messages sent to the model."""
        system_prompt = (
            "You are a precise SQL Expert.\n"
            "CRITICAL RULES:\n"
            "1. You MUST use the EXACT table names and column names from the SCHEMA below.\n"
            "2. Do NOT hallucinate table names (e.g., if schema says 'Employee', do NOT use 'employees').\n"
            "3. Output valid JSON only.\n"
            "SCHEMA:\n"
            f"{context}\n"
        )
        # The example uses double quotes so that a reply mirroring it is valid JSON.
        user_prompt = (
            f"Question: {question}\n"
            'Return JSON format: { "sql": "SELECT ...", "message": "...", "explanation": "..." }'
        )
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

    @staticmethod
    def _status_sql(text):
        """Return ``SELECT '<text>' as status`` with single quotes in ``text`` doubled."""
        return "SELECT '" + text.replace("'", "''") + "' as status"