Spaces:
Runtime error
Runtime error
| from gradio_huggingfacehub_search import HuggingfaceHubSearch | |
| from huggingface_hub import HfApi | |
| import matplotlib.pyplot as plt | |
| from typing import Tuple, Optional | |
| import pandas as pd | |
| import gradio as gr | |
| import duckdb | |
| import requests | |
| import llama_cpp | |
| import instructor | |
| import enum | |
| from pydantic import BaseModel, Field | |
# Root URL of the Hugging Face datasets-server REST API.
BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
# Name of the DuckDB view that get_dataset_ddl (re)creates for the selected dataset.
view_name = "dataset_view"

hf_api = HfApi()
# Single shared DuckDB connection; no database path is supplied.
conn = duckdb.connect()

# Local GGUF model used to translate natural-language questions into SQL.
# NOTE(review): `temperature` looks like a sampling option rather than a
# constructor option -- confirm that llama_cpp.Llama actually honours it here.
llama = llama_cpp.Llama(
    model_path="Hermes-2-Pro-Llama-3-8B-Q8_0.gguf",
    chat_format="chatml",
    n_ctx=2048,
    n_gpu_layers=-1,
    temperature=0.1,
    verbose=False,
)

# Wrap the OpenAI-style chat-completion entry point with instructor so that
# callers can pass `response_model=` and get a pydantic object back.
create = instructor.patch(
    mode=instructor.Mode.JSON_SCHEMA,
    create=llama.create_chat_completion_openai_v1,
)
class OutputTypes(str, enum.Enum):
    """Ways in which a query result can be presented.

    The enum mixes in ``str``, so every member compares equal to its
    plain string value.
    """

    TABLE = "table"
    BARCHART = "barchart"
    LINECHART = "linechart"
class SQLResponse(BaseModel):
    """Structured reply requested from the LLM for one user question.

    Only ``sql`` is required; the three visualization fields default to
    ``None``.
    """

    # SQL text produced by the model.
    sql: str
    # Requested presentation of the result.
    visualization_type: Optional[OutputTypes] = Field(
        default=None,
        description="The type of visualization to display",
    )
    # Result column holding the plotted values.
    data_key: Optional[str] = Field(
        default=None,
        description="The column name that contains the data for chart responses",
    )
    # Result column holding the labels.
    label_key: Optional[str] = Field(
        default=None,
        description="The column name that contains the labels for chart responses",
    )
def get_dataset_ddl(dataset_id: str) -> str:
    """Expose a dataset's first parquet file as a DuckDB view and describe it.

    Asks the datasets server for the parquet files of ``dataset_id``,
    (re)creates the view named by ``view_name`` over the first file, and
    returns a ``CREATE TABLE``-style description of the view's columns.

    Args:
        dataset_id: Hub dataset identifier, e.g. ``"user/dataset"``.

    Returns:
        A DDL-like string listing each column name and type of the view.

    Raises:
        requests.HTTPError: If the datasets server answers with an error status.
        ValueError: If the server lists no parquet files, or the first entry
            has no URL.
    """
    # Pass the id through `params` so it is URL-encoded, and bound the wait so
    # a stalled server cannot hang the request forever.
    response = requests.get(
        f"{BASE_DATASETS_SERVER_URL}/parquet",
        params={"dataset": dataset_id},
        timeout=30,
    )
    response.raise_for_status()

    parquet_files = response.json().get("parquet_files") or []
    if not parquet_files:
        # Previously this surfaced as a bare IndexError from `[...][0]`.
        raise ValueError(f"No parquet files found for dataset {dataset_id!r}.")

    first_parquet_url = parquet_files[0].get("url")
    if not first_parquet_url:
        raise ValueError("No valid URL found for the first parquet file.")

    # Double single quotes so the URL cannot terminate the SQL string literal.
    quoted_url = first_parquet_url.replace("'", "''")
    conn.execute(
        f"CREATE OR REPLACE VIEW {view_name} as SELECT * FROM read_parquet('{quoted_url}');"
    )

    # The list comprehension below reads name and type from positions 1 and 2
    # of each PRAGMA table_info row.
    dataset_ddl = conn.execute(f"PRAGMA table_info('{view_name}');").fetchall()
    column_data_types = ",\n\t".join(
        f"{column[1]} {column[2]}" for column in dataset_ddl
    )
    sql_ddl = """
CREATE TABLE {} (
{}
);
""".format(
        view_name, column_data_types
    )
    return sql_ddl
def generate_query(dataset_id: str, query: str) -> "SQLResponse":
    """Ask the LLM for a SQL query (plus chart hints) answering ``query``.

    The dataset's schema is obtained from ``get_dataset_ddl`` and embedded in
    the system prompt; the user's question is sent as the user message.

    Args:
        dataset_id: Hub dataset identifier whose schema is shown to the model.
        query: The user's natural-language question.

    Returns:
        The parsed ``SQLResponse`` produced by the model (not a plain string).
    """
    ddl = get_dataset_ddl(dataset_id)
    system_prompt = f"""
You are an expert SQL assistant with access to the following DuckDB Table:
```sql
{ddl}
```
Please assist the user by writing a SQL query that answers the user's question.
"""
    print("Calling LLM with system prompt: ", system_prompt)
    resp: SQLResponse = create(
        model="Hermes-2-Pro-Llama-3-8B",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query},
        ],
        response_model=SQLResponse,
        # Sampling temperature belongs on the completion call; the value
        # mirrors the 0.1 the module passes when constructing the model.
        temperature=0.1,
    )
    print("Received Response: ", resp)
    return resp
def _build_plot(df: pd.DataFrame, response: "SQLResponse") -> "Optional[plt.Figure]":
    """Return a figure for chart responses, or None when no chart can be drawn."""
    chart_kinds = {OutputTypes.LINECHART: "line", OutputTypes.BARCHART: "bar"}
    kind = chart_kinds.get(response.visualization_type)
    if kind is None:
        return None

    # The keys come from the LLM and may name columns the query did not
    # return; skip the chart instead of failing the whole request. A key left
    # as None is passed through so pandas applies its own default.
    for column in (response.label_key, response.data_key):
        if column is not None and column not in df.columns:
            return None

    # Labels go on the x axis and the data values on the y axis.
    axes = df.plot(kind=kind, x=response.label_key, y=response.data_key)
    return axes.get_figure()


def query_dataset(
    dataset_id: str, query: str
) -> "Tuple[pd.DataFrame, str, Optional[plt.Figure]]":
    """Answer a natural-language question about a dataset.

    Args:
        dataset_id: Hub dataset identifier to query.
        query: The user's natural-language question.

    Returns:
        A tuple of the result frame, the generated SQL wrapped in a markdown
        ``sql`` code fence, and a matplotlib figure (``None`` unless the model
        requested a line or bar chart that can be drawn from the result).
    """
    response: SQLResponse = generate_query(dataset_id, query)
    # NOTE(review): this executes model-generated SQL verbatim on the shared
    # connection; consider restricting it to read-only statements.
    df = conn.execute(response.sql).fetchdf()
    plot = _build_plot(df, response)
    markdown_output = f"""```sql\n{response.sql}\n```"""
    return df, markdown_output, plot
# UI: dataset picker, question box and button, followed by the generated SQL,
# the result table and an optional chart.
# NOTE(review): the trailing characters in the heading and the button label
# look like mis-decoded emoji -- confirm against the original file.
with gr.Blocks() as demo:
    gr.Markdown("# Query your HF Datasets with Natural Language ππ")
    dataset_id = HuggingfaceHubSearch(
        value="teknium/OpenHermes-2.5",
        search_type="dataset",
        label="Hub Dataset ID",
        placeholder="Find your favorite dataset...",
    )
    user_query = gr.Textbox(value="", label="Ask anything...")
    btn = gr.Button(value="Ask πͺ")
    sql_query = gr.Markdown(label="Output SQL Query")
    df = gr.DataFrame()
    plot = gr.Plot()

    # The outputs are listed in the same order as query_dataset's return tuple.
    btn.click(
        fn=query_dataset,
        inputs=[dataset_id, user_query],
        outputs=[df, sql_query, plot],
    )

if __name__ == "__main__":
    demo.launch()