Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from langchain.callbacks.streamlit.streamlit_callback_handler import ( | |
| StreamlitCallbackHandler, | |
| ) | |
| from langchain.schema.output import LLMResult | |
| from sql_formatter.core import format_sql | |
class VectorSQLSearchDBCallBackHandler(StreamlitCallbackHandler):
    """Streamlit callback handler that tracks a vector-SQL search run.

    Draws a progress bar that advances by ``prog_interval`` each time a chain
    starts and each time the LLM emits a ``SELECT`` statement, and renders
    that generated SQL statement as a formatted ``sql`` code block.
    """

    def __init__(self) -> None:
        # NOTE(review): StreamlitCallbackHandler.__init__ is deliberately not
        # called (it expects a parent container). Inherited callbacks that are
        # not overridden here may rely on state it would have set -- confirm
        # they are never triggered for this handler.
        self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
        self.status_bar = st.empty()
        self.prog_value = 0
        self.prog_interval = 0.2

    def _advance_progress(self, text: str) -> None:
        """Advance the progress bar by one step and update its caption.

        The value is clamped to 1.0 because ``st.progress`` rejects float
        values outside [0.0, 1.0]; with nested chains the number of steps
        can exceed ``1 / prog_interval``.
        """
        self.prog_value = min(self.prog_value + self.prog_interval, 1.0)
        self.progress_bar.progress(value=self.prog_value, text=text)

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        """Ignore LLM start events (overrides the parent's handling)."""

    def on_llm_end(
        self,
        response: LLMResult,
        *args,
        **kwargs,
    ) -> None:
        """Render the generated SQL when the LLM's first generation is a SELECT.

        Args:
            response: The LLM result; only ``generations[0][0].text`` is
                inspected.
        """
        # Guard against an empty result instead of raising IndexError.
        if not response.generations or not response.generations[0]:
            return
        text = response.generations[0][0].text
        # lstrip() instead of removing spaces only: LLM output frequently
        # starts with a newline, which previously hid the generated SQL.
        if text.lstrip().upper().startswith("SELECT"):
            st.markdown("### Generated Vector Search SQL Statement \n"
                        "> This sql statement is generated by LLM \n\n")
            st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
            self._advance_progress("Searching in DB...")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        """Advance the progress bar and show which chain is starting.

        Args:
            serialized: Serialized chain description; its ``id`` list is
                joined with dots to form the displayed chain name.
        """
        # Fall back to a placeholder instead of raising KeyError/TypeError
        # when the serialized description is missing or has no id.
        ids = (serialized or {}).get("id") or ["unknown"]
        cid = ".".join(ids)
        self._advance_progress(f"Running Chain `{cid}`...")

    def on_chain_end(self, outputs, **kwargs) -> None:
        """Ignore chain end events."""
| class VectorSQLSearchLLMCallBackHandler(VectorSQLSearchDBCallBackHandler): | |
| def __init__(self, table: str) -> None: | |
| self.progress_bar = st.progress(value=0.0, text="Writing SQL...") | |
| self.status_bar = st.empty() | |
| self.prog_value = 0 | |
| self.prog_interval = 0.1 | |
| self.table = table | |