File size: 4,680 Bytes
91ed273
77ad74c
91ed273
 
 
3ee6432
91ed273
 
146146a
91ed273
146146a
77ad74c
146146a
 
 
77ad74c
3ee6432
 
 
 
 
 
77ad74c
 
91ed273
 
 
 
 
3ee6432
146146a
 
627c842
91ed273
 
 
 
77ad74c
3ee6432
77ad74c
 
 
 
 
 
3ee6432
77ad74c
 
3ee6432
77ad74c
3ee6432
3656fbb
3ee6432
 
 
 
 
 
 
77ad74c
3ee6432
 
 
 
 
3656fbb
3ee6432
 
77ad74c
 
3ee6432
 
3656fbb
3ee6432
 
77ad74c
 
3ee6432
 
 
 
 
 
 
 
 
3656fbb
3ee6432
3656fbb
 
 
91ed273
304a74a
3656fbb
 
 
 
 
 
 
5d48e70
3656fbb
 
 
 
 
 
 
 
 
 
627c842
3656fbb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import requests
import re
import json
from dotenv import load_dotenv
from huggingface_hub import InferenceClient

class SQLGenerator:
    """Translate natural-language questions into SQL via Hugging Face chat models.

    The API key is read from HF_API_KEY (or HUGGINGFACEHUB_API_TOKEN). Queries
    go to Qwen2.5-Coder-32B first, with an automatic fallback to the 7B model
    on transient API failures (404/429/503).
    """

    # Data-modifying keywords rejected by the safety layer (whole-word match).
    _FORBIDDEN = ("DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "TRUNCATE", "GRANT")

    def __init__(self):
        load_dotenv()

        # 1. CLEAN THE KEY — strip stray whitespace/newlines pasted into secrets.
        raw_key = os.getenv("HF_API_KEY") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        self.api_key = raw_key.strip() if raw_key else None

        # 2. SETUP CLIENT — only when a key is present; without one the client
        # stays None and generate_sql() returns a diagnostic query instead of crashing.
        if self.api_key:
            self.client = InferenceClient(api_key=self.api_key)
        else:
            self.client = None
            print("   ❌ FATAL: API Key missing.")

        # 3. USE QWEN 2.5 (Best Free Model)
        self.model_id = "Qwen/Qwen2.5-Coder-32B-Instruct"

    def generate_followup_questions(self, question, sql_query):
        """Return canned follow-up suggestions (arguments are currently unused)."""
        return ["Visualize this result", "Export as CSV", "Compare with last year"]

    def generate_sql(self, question, context, history=None):
        """Generate SQL answering *question* against the schema in *context*.

        Returns a ``(sql, explanation, message)`` tuple. Every failure path
        still yields a runnable diagnostic ``SELECT`` as the first element so
        downstream execution code never receives an empty query.
        """
        if not self.client:
            return "SELECT 'Error: HF_API_KEY Missing' as status", "Configuration Error", "Please add HF_API_KEY to your Space Secrets."

        # 🛡️ Safety Layer — whole-word match so harmless words such as
        # "updated" or "inserted" inside a question are not wrongly blocked.
        forbidden_re = r"\b(?:" + "|".join(self._FORBIDDEN) + r")\b"
        if re.search(forbidden_re, question.upper()):
            return "SELECT 'Error: Blocked by Safety Layer' as status", "Safety Alert", "I cannot execute commands that modify data."

        # 🧠 SMART PROMPT (fixes the "No Such Table" error): pin the exact
        # schema and demand strict JSON. The example uses double quotes so the
        # model's reply is actually parseable by json.loads.
        messages = [
            {"role": "system", "content": f"""You are a precise SQL Expert.
            
            CRITICAL RULES:
            1. You MUST use the EXACT table names and column names from the SCHEMA below.
            2. Do NOT hallucinate table names (e.g., if schema says 'Employee', do NOT use 'employees').
            3. Output valid JSON only.
            
            SCHEMA:
            {context}
            """},
            {"role": "user", "content": f'Question: {question}\nReturn JSON format: {{"sql": "SELECT ...", "message": "...", "explanation": "..."}}'}
        ]

        try:
            print(f"   ⚡ Generating SQL using {self.model_id}...")

            response = self.client.chat.completions.create(
                model=self.model_id,
                messages=messages,
                max_tokens=500,
                temperature=0.1,  # low temperature keeps output strict/deterministic
                stream=False
            )

            raw_text = response.choices[0].message.content
            return self._process_response(raw_text)

        except Exception as e:
            print(f"   ❌ AI ERROR: {e}")
            # Failover to the 7B model when the 32B endpoint is missing or busy.
            if "404" in str(e) or "429" in str(e) or "503" in str(e):
                return self._fallback_generate(messages)
            return f"SELECT 'Error: {str(e)[:50]}' as status", "System Error", "AI Model unavailable."

    def _fallback_generate(self, messages):
        """Retry *messages* against the smaller, faster 7B model."""
        try:
            backup_model = "Qwen/Qwen2.5-Coder-7B-Instruct"
            print(f"   ⚠️ Switching to backup: {backup_model}...")
            response = self.client.chat.completions.create(
                model=backup_model,
                messages=messages,
                max_tokens=500
            )
            return self._process_response(response.choices[0].message.content)
        except Exception as e:
            print(f"   ❌ Backup failed: {e}")  # surface the cause instead of silently swallowing it
            return "SELECT 'Error: All models failed' as status", "System Error", "Please check your API Key permissions."

    def _process_response(self, raw_text):
        """Extract ``(sql, explanation, message)`` from a raw model reply.

        Tries strict JSON first; on any parse failure falls back to grabbing
        the first ``SELECT ...;`` statement. Never returns an empty query.
        """
        sql_query = ""
        message = "Here is the data."
        explanation = "Query generated successfully."

        try:
            # Strip markdown code fences, then parse the first {...} object.
            clean_json = re.sub(r"```json|```", "", raw_text).strip()
            json_match = re.search(r"\{.*\}", clean_json, re.DOTALL)
            if json_match:
                data = json.loads(json_match.group(0))
                # `or ""` guards against an explicit `"sql": null` in the reply,
                # which would otherwise crash on .strip() below.
                sql_query = str(data.get("sql") or "")
                message = data.get("message", message)
                explanation = data.get("explanation", explanation)
            else:
                match = re.search(r"(SELECT[\s\S]+?;)", raw_text, re.IGNORECASE)
                if match:
                    sql_query = match.group(1)
        except (ValueError, AttributeError):
            # ValueError covers json.JSONDecodeError; AttributeError covers a
            # reply whose top-level JSON is not an object (no .get). Fall back
            # to plucking the first SELECT statement from the raw text.
            match = re.search(r"(SELECT[\s\S]+?;)", raw_text, re.IGNORECASE)
            if match:
                sql_query = match.group(1)

        # Normalise: single line with a trailing semicolon.
        sql_query = sql_query.strip().replace("\n", " ")
        if sql_query and not sql_query.endswith(";"):
            sql_query += ";"

        if not sql_query:
            sql_query = "SELECT 'Error: Empty Query' as status"

        return sql_query, explanation, message