|
|
import streamlit as st |
|
|
from openai import OpenAI |
|
|
import sqlite3 |
|
|
from init_db import initialize_database |
|
|
|
|
|
|
|
|
# Prepare the database through init_db before any UI is built.
# NOTE(review): as module-level code in a Streamlit script this runs on every
# rerun; presumably initialize_database() is idempotent — confirm in init_db.
initialize_database()

# Page settings are applied ahead of every other st.* call below.
st.set_page_config(layout="wide", page_title="Zero SQL")
st.title("Zero SQL - Natural Language to SQL Query")

# Sidebar: the user pastes their own OpenAI key; type="password" masks the input.
st.sidebar.header("API Configuration")
api_key = st.sidebar.text_input("OpenAI API Key", type="password")
|
|
|
|
|
|
|
|
# Request form: the text area and the submit button belong to the same
# st.form, so the typed request is submitted together with "Generate Query".
query_form = st.form("query_form")
user_input = query_form.text_area(
    "Enter your data request in natural language:",
    placeholder="e.g. Show all orders from last week",
    height=150,
)
submitted = query_form.form_submit_button("Generate Query")
|
|
|
|
|
# Schema handed to the model as system context. The dialect is named explicitly
# because the generated statement is executed by SQLite below; without that
# hint the model may emit constructs from other dialects (e.g. NOW(),
# INTERVAL) that SQLite rejects.
SYSTEM_CONTEXT = """Given the following SQLite tables, your job is to write SQLite queries given a user's request.

CREATE TABLE Produkte (
    ProduktID INTEGER PRIMARY KEY AUTOINCREMENT,
    Produktname TEXT NOT NULL,
    Preis REAL NOT NULL
);

CREATE TABLE Bestellungen (
    BestellungID INTEGER PRIMARY KEY AUTOINCREMENT,
    ProduktID INTEGER NOT NULL,
    Menge INTEGER NOT NULL,
    Bestelldatum TEXT NOT NULL,
    Person TEXT NOT NULL,
    FOREIGN KEY (ProduktID) REFERENCES Produkte(ProduktID)
);"""

# Same database file as before ('database.db' in the working directory), but
# opened read-only: the SQL is produced by a language model from free-form
# user text, so it must not be able to modify or drop tables.
DATABASE_URI = "file:database.db?mode=ro"


def _strip_code_fences(text: str) -> str:
    """Return *text* stripped of whitespace and of a surrounding Markdown fence.

    The prompt asks for raw SQL, but models sometimes still answer with
    "```sql\\n...\\n```"; SQLite would reject the backticks.
    """
    text = text.strip()
    if len(text) >= 6 and text.startswith("```") and text.endswith("```"):
        inner = text[3:-3]
        first_line, newline, rest = inner.partition("\n")
        # The opening fence may carry a language tag ("```sql"); drop that line
        # only when it is clearly a tag, so real SQL on it is never lost.
        if newline and first_line.strip().lower() in ("", "sql", "sqlite"):
            inner = rest
        text = inner.strip()
    return text


def _generate_sql(api_key: str, request: str) -> str:
    """Ask the model to translate *request* into a single SQLite query.

    Raises:
        ValueError: if the model's reply contains no SQL text.
        Exceptions raised by the OpenAI client (authentication, network, ...)
        propagate unchanged.
    """
    client = OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": SYSTEM_CONTEXT},
            {
                "role": "user",
                "content": (
                    f"Generate the SQL query for: {request}. "
                    "Only output the raw SQL query without any code block delimiters or markdown."
                ),
            },
        ],
        response_format={"type": "text"},
    )
    # message.content may be None (e.g. a refusal); treat that like empty text.
    sql_query = _strip_code_fences(response.choices[0].message.content or "")
    if not sql_query:
        raise ValueError("The model did not return a SQL query.")
    return sql_query


def _run_query(sql: str):
    """Execute *sql* on the read-only database and return (column_names, rows).

    Both lists are empty when the statement produces no result set
    (cursor.description is None).

    Raises:
        sqlite3.Error: if SQLite rejects the statement (syntax error, write
        attempt on the read-only connection, more than one statement, ...).
    """
    conn = sqlite3.connect(DATABASE_URI, uri=True)
    try:
        cursor = conn.execute(sql)
        if cursor.description is None:
            return [], []
        column_names = [column[0] for column in cursor.description]
        return column_names, cursor.fetchall()
    finally:
        # Always release the connection, including when execute() raises.
        conn.close()


def _unique_column_names(names):
    """Return *names* with repeats suffixed (`ProduktID`, `ProduktID_2`, ...).

    Joins such as SELECT * over both tables yield duplicate column names,
    which would otherwise collapse into one column when rows become dicts.
    """
    used = set()
    unique = []
    for name in names:
        candidate, suffix = name, 2
        while candidate in used:
            candidate = f"{name}_{suffix}"
            suffix += 1
        used.add(candidate)
        unique.append(candidate)
    return unique


if submitted:
    if not api_key.strip():
        st.error("🔑 API key is required!")
    elif not user_input.strip():
        st.error("📝 Please enter your data request!")
    else:
        try:
            sql_query = _generate_sql(api_key.strip(), user_input.strip())

            # Show the query before running it so it stays visible even when
            # SQLite rejects it.
            st.subheader("Generated SQL Query")
            st.code(sql_query, language="sql")

            column_names, results = _run_query(sql_query)

            st.subheader("Query Results")
            if results:
                columns = _unique_column_names(column_names)
                # st.dataframe has no `columns` keyword; pass rows keyed by
                # column name so the headers come from the query itself.
                st.dataframe(
                    data=[dict(zip(columns, row)) for row in results],
                    use_container_width=True,
                    hide_index=True,
                )
            else:
                st.info("No results found", icon="ℹ️")

        except sqlite3.Error as e:
            st.error(f"SQL Error: {e}")
        except Exception as e:
            # Top-level UI boundary: report any other failure (API errors,
            # empty model output, ...) to the user instead of crashing.
            st.error(f"An error occurred: {e}")