Spaces:
Runtime error
Runtime error
Caleb Fahlgren
commited on
Commit
·
e915c68
1
Parent(s):
853c083
add plotting capabilities
Browse files
app.py
CHANGED
|
@@ -1,13 +1,16 @@
|
|
| 1 |
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
| 2 |
from huggingface_hub import HfApi
|
|
|
|
|
|
|
| 3 |
import pandas as pd
|
| 4 |
import gradio as gr
|
| 5 |
import duckdb
|
| 6 |
import requests
|
| 7 |
import llama_cpp
|
| 8 |
import instructor
|
|
|
|
| 9 |
|
| 10 |
-
from pydantic import BaseModel
|
| 11 |
|
| 12 |
BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
|
| 13 |
view_name = "dataset_view"
|
|
@@ -21,6 +24,7 @@ llama = llama_cpp.Llama(
|
|
| 21 |
chat_format="chatml",
|
| 22 |
n_ctx=2048,
|
| 23 |
verbose=False,
|
|
|
|
| 24 |
)
|
| 25 |
|
| 26 |
create = instructor.patch(
|
|
@@ -29,8 +33,23 @@ create = instructor.patch(
|
|
| 29 |
)
|
| 30 |
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
class SQLResponse(BaseModel):
|
| 33 |
sql: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
|
| 36 |
def get_dataset_ddl(dataset_id: str) -> str:
|
|
@@ -63,7 +82,7 @@ CREATE TABLE {} (
|
|
| 63 |
return sql_ddl
|
| 64 |
|
| 65 |
|
| 66 |
-
def
|
| 67 |
ddl = get_dataset_ddl(dataset_id)
|
| 68 |
|
| 69 |
system_prompt = f"""
|
|
@@ -76,6 +95,8 @@ def generate_sql(dataset_id: str, query: str) -> str:
|
|
| 76 |
Please assist the user by writing a SQL query that answers the user's question.
|
| 77 |
"""
|
| 78 |
|
|
|
|
|
|
|
| 79 |
resp: SQLResponse = create(
|
| 80 |
model="Hermes-2-Pro-Llama-3-8B",
|
| 81 |
messages=[
|
|
@@ -88,15 +109,28 @@ def generate_sql(dataset_id: str, query: str) -> str:
|
|
| 88 |
response_model=SQLResponse,
|
| 89 |
)
|
| 90 |
|
| 91 |
-
|
| 92 |
|
|
|
|
| 93 |
|
| 94 |
-
def query_dataset(dataset_id: str, query: str) -> tuple[pd.DataFrame, str]:
|
| 95 |
-
sql_query = generate_sql(dataset_id, query)
|
| 96 |
-
df = conn.execute(sql_query).fetchdf()
|
| 97 |
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
|
| 102 |
with gr.Blocks() as demo:
|
|
@@ -105,19 +139,20 @@ with gr.Blocks() as demo:
|
|
| 105 |
label="Hub Dataset ID",
|
| 106 |
placeholder="Find your favorite dataset...",
|
| 107 |
search_type="dataset",
|
| 108 |
-
value="
|
| 109 |
)
|
| 110 |
user_query = gr.Textbox("", label="Ask anything...")
|
| 111 |
|
| 112 |
btn = gr.Button("Ask 🪄")
|
| 113 |
|
| 114 |
-
df = gr.DataFrame()
|
| 115 |
sql_query = gr.Markdown(label="Output SQL Query")
|
|
|
|
|
|
|
| 116 |
|
| 117 |
btn.click(
|
| 118 |
query_dataset,
|
| 119 |
inputs=[dataset_id, user_query],
|
| 120 |
-
outputs=[df, sql_query],
|
| 121 |
)
|
| 122 |
|
| 123 |
|
|
|
|
| 1 |
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
| 2 |
from huggingface_hub import HfApi
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
from typing import Tuple, Optional
|
| 5 |
import pandas as pd
|
| 6 |
import gradio as gr
|
| 7 |
import duckdb
|
| 8 |
import requests
|
| 9 |
import llama_cpp
|
| 10 |
import instructor
|
| 11 |
+
import enum
|
| 12 |
|
| 13 |
+
from pydantic import BaseModel, Field
|
| 14 |
|
| 15 |
BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
|
| 16 |
view_name = "dataset_view"
|
|
|
|
| 24 |
chat_format="chatml",
|
| 25 |
n_ctx=2048,
|
| 26 |
verbose=False,
|
| 27 |
+
temperature=0.1,
|
| 28 |
)
|
| 29 |
|
| 30 |
create = instructor.patch(
|
|
|
|
| 33 |
)
|
| 34 |
|
| 35 |
|
| 36 |
+
class OutputTypes(str, enum.Enum):
|
| 37 |
+
TABLE = "table"
|
| 38 |
+
BARCHART = "barchart"
|
| 39 |
+
LINECHART = "linechart"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
class SQLResponse(BaseModel):
|
| 43 |
sql: str
|
| 44 |
+
visualization_type: Optional[OutputTypes] = Field(
|
| 45 |
+
None, description="The type of visualization to display"
|
| 46 |
+
)
|
| 47 |
+
data_key: Optional[str] = Field(
|
| 48 |
+
None, description="The column name that contains the data for chart responses"
|
| 49 |
+
)
|
| 50 |
+
label_key: Optional[str] = Field(
|
| 51 |
+
None, description="The column name that contains the labels for chart responses"
|
| 52 |
+
)
|
| 53 |
|
| 54 |
|
| 55 |
def get_dataset_ddl(dataset_id: str) -> str:
|
|
|
|
| 82 |
return sql_ddl
|
| 83 |
|
| 84 |
|
| 85 |
+
def generate_query(dataset_id: str, query: str) -> str:
|
| 86 |
ddl = get_dataset_ddl(dataset_id)
|
| 87 |
|
| 88 |
system_prompt = f"""
|
|
|
|
| 95 |
Please assist the user by writing a SQL query that answers the user's question.
|
| 96 |
"""
|
| 97 |
|
| 98 |
+
print("Calling LLM with system prompt: ", system_prompt)
|
| 99 |
+
|
| 100 |
resp: SQLResponse = create(
|
| 101 |
model="Hermes-2-Pro-Llama-3-8B",
|
| 102 |
messages=[
|
|
|
|
| 109 |
response_model=SQLResponse,
|
| 110 |
)
|
| 111 |
|
| 112 |
+
print("Received Response: ", resp)
|
| 113 |
|
| 114 |
+
return resp
|
| 115 |
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
+
def query_dataset(dataset_id: str, query: str) -> Tuple[pd.DataFrame, str, plt.Figure]:
|
| 118 |
+
response: SQLResponse = generate_query(dataset_id, query)
|
| 119 |
+
df = conn.execute(response.sql).fetchdf()
|
| 120 |
+
|
| 121 |
+
plot = None
|
| 122 |
+
|
| 123 |
+
if response.visualization_type == OutputTypes.LINECHART:
|
| 124 |
+
plot = df.plot(
|
| 125 |
+
kind="line", x=response.data_key, y=response.label_key
|
| 126 |
+
).get_figure()
|
| 127 |
+
elif response.visualization_type == OutputTypes.BARCHART:
|
| 128 |
+
plot = df.plot(
|
| 129 |
+
kind="bar", x=response.data_key, y=response.label_key
|
| 130 |
+
).get_figure()
|
| 131 |
+
|
| 132 |
+
markdown_output = f"""```sql\n{response.sql}\n```"""
|
| 133 |
+
return df, markdown_output, plot
|
| 134 |
|
| 135 |
|
| 136 |
with gr.Blocks() as demo:
|
|
|
|
| 139 |
label="Hub Dataset ID",
|
| 140 |
placeholder="Find your favorite dataset...",
|
| 141 |
search_type="dataset",
|
| 142 |
+
value="teknium/OpenHermes-2.5",
|
| 143 |
)
|
| 144 |
user_query = gr.Textbox("", label="Ask anything...")
|
| 145 |
|
| 146 |
btn = gr.Button("Ask 🪄")
|
| 147 |
|
|
|
|
| 148 |
sql_query = gr.Markdown(label="Output SQL Query")
|
| 149 |
+
df = gr.DataFrame()
|
| 150 |
+
plot = gr.Plot()
|
| 151 |
|
| 152 |
btn.click(
|
| 153 |
query_dataset,
|
| 154 |
inputs=[dataset_id, user_query],
|
| 155 |
+
outputs=[df, sql_query, plot],
|
| 156 |
)
|
| 157 |
|
| 158 |
|