|
|
from __future__ import annotations |
|
|
|
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union |
|
|
|
|
|
from langchain_core.language_models import BaseLanguageModel |
|
|
from langchain_core.output_parsers import StrOutputParser |
|
|
from langchain_core.prompts import BasePromptTemplate |
|
|
from langchain_core.runnables import Runnable, RunnablePassthrough |
|
|
|
|
|
from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from langchain_community.utilities.sql_database import SQLDatabase |
|
|
|
|
|
|
|
|
def _strip(text: str) -> str: |
|
|
return text.strip() |
|
|
|
|
|
|
|
|
class SQLInput(TypedDict):
    """Input for a SQL Chain.

    Holds only the natural-language question. No table filter is supplied, so
    the chain builds its table context without restricting table names.
    """

    # Natural-language question the generated SQL query should answer.
    question: str
|
|
|
|
|
|
|
|
class SQLInputWithTables(TypedDict):
    """Input for a SQL Chain that also names the tables to describe.

    ``table_names_to_use`` restricts which tables' information is placed in
    the prompt context. It does not validate the SQL that gets generated.
    """

    # Natural-language question the generated SQL query should answer.
    question: str
    # Names of the tables whose info should be included in the prompt.
    table_names_to_use: List[str]
|
|
|
|
|
|
|
|
def create_sql_query_chain(
    llm: BaseLanguageModel,
    db: SQLDatabase,
    prompt: Optional[BasePromptTemplate] = None,
    k: int = 5,
) -> Runnable[Union[SQLInput, SQLInputWithTables, Dict[str, Any]], str]:
    """Create a chain that generates SQL queries.

    *Security Note*: This chain generates SQL queries for the given database.

        The SQLDatabase class provides a get_table_info method that can be used
        to get column information as well as sample data from the table.

        To mitigate risk of leaking sensitive data, limit permissions
        to read and scope to the tables that are needed.

        Optionally, use the SQLInputWithTables input type to limit which
        tables' definitions and sample rows are placed in the prompt. This
        only narrows the context given to the model: the chain does not
        validate the generated SQL, so the query may still reference other
        tables. Database permissions remain the actual access control.

        Control access to who can submit requests to this chain.

        See https://python.langchain.com/docs/security for more information.

    Args:
        llm: The language model to use.
        db: The SQLDatabase to generate the query for.
        prompt: The prompt to use. If none is provided, will choose one
            based on dialect. Defaults to None. See Prompt section below for more.
        k: The number of results per select statement to return. Defaults to 5.

    Returns:
        A chain that takes in a question and generates a SQL query that answers
        that question.

    Raises:
        ValueError: If the prompt provides none of ``input``, ``top_k`` or
            ``table_info`` as either an input variable or a partial variable.
            The error is also raised when only some of them are missing.

    Example:

        .. code-block:: python

            # pip install -U langchain langchain-community langchain-openai
            from langchain_openai import ChatOpenAI
            from langchain.chains import create_sql_query_chain
            from langchain_community.utilities import SQLDatabase

            db = SQLDatabase.from_uri("sqlite:///Chinook.db")
            llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
            chain = create_sql_query_chain(llm, db)
            response = chain.invoke({"question": "How many employees are there"})

    Prompt:
        If no prompt is provided, a default prompt is selected based on the
        SQLDatabase dialect. If one is provided, it must support the following
        variables. Each one may be an input variable or already bound as a
        partial variable:

        * input: The user question plus suffix "\\nSQLQuery: " is passed here.
        * top_k: The number of results per select statement (the `k` argument to
          this function) is passed in here.
        * table_info: Table definitions and sample rows are passed in here. If the
          user specifies "table_names_to_use" when invoking chain, only those
          will be included. Otherwise, all tables are included.
        * dialect (optional): If dialect input variable is in prompt, the db
          dialect will be passed in here.

        Here's an example prompt:

        .. code-block:: python

            from langchain_core.prompts import PromptTemplate

            template = '''Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
            Use the following format:

            Question: "Question here"
            SQLQuery: "SQL Query to run"
            SQLResult: "Result of the SQLQuery"
            Answer: "Final answer here"

            Only use the following tables:

            {table_info}.

            Question: {input}'''
            prompt = PromptTemplate.from_template(template)
    """
    # Pick the prompt: explicit argument, else a dialect-specific default,
    # else the generic default.
    if prompt is not None:
        prompt_to_use = prompt
    elif db.dialect in SQL_PROMPTS:
        prompt_to_use = SQL_PROMPTS[db.dialect]
    else:
        prompt_to_use = PROMPT

    # Required variables may be left open or already bound as partials; either
    # way the pipeline below must be able to fill all three.
    provided = set(prompt_to_use.input_variables) | set(
        prompt_to_use.partial_variables
    )
    missing = {"input", "top_k", "table_info"} - provided
    if missing:
        raise ValueError(
            "Prompt must have input variables: 'input', 'top_k', "
            "'table_info'. Received prompt with input variables: "
            f"{prompt_to_use.input_variables}. Full prompt:\n\n{prompt_to_use}"
        )
    if "dialect" in prompt_to_use.input_variables:
        prompt_to_use = prompt_to_use.partial(dialect=db.dialect)

    inputs = {
        "input": lambda x: x["question"] + "\nSQLQuery: ",
        "table_info": lambda x: db.get_table_info(
            table_names=x.get("table_names_to_use")
        ),
    }
    # Raw input keys already turned into "input" / "table_info" above; they
    # are dropped before the prompt is formatted.
    consumed_keys = ("question", "table_names_to_use")
    return (
        RunnablePassthrough.assign(**inputs)
        | (
            lambda x: {
                key: value
                for key, value in x.items()
                if key not in consumed_keys
            }
        )
        | prompt_to_use.partial(top_k=str(k))
        # Stop before the model writes the "SQLResult:" section that follows
        # the query in the Question/SQLQuery/SQLResult prompt format.
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
        | _strip
    )
|
|
|