Upload 4 files
llm.py
ADDED
@@ -0,0 +1,16 @@
from langchain_openai import OpenAI
from langchain_community.llms import HuggingFaceHub
from dotenv import load_dotenv
import os

load_dotenv()

openai_key = os.getenv("OPENAI_API_KEY")
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")

llm = OpenAI(temperature=0.6, openai_api_key=openai_key)

#! Alternatively, can use HuggingFace hub's LLM
# llm = HuggingFaceHub(
#     repo_id='google/flan-t5-large', model_kwargs={"temperature": 0.7, "max_length": 256}
# )
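
A quick way to sanity-check this module is a one-off completion. The snippet below is an illustrative sketch, not part of the commit; it assumes llm.py sits on the import path and that OPENAI_API_KEY is set in the .env file.

# Minimal smoke test for llm.py (assumes OPENAI_API_KEY is set in .env).
from llm import llm

if __name__ == "__main__":
    # LangChain LLMs expose invoke() for a single completion.
    print(llm.invoke("Say hello in one short sentence."))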
main.py
ADDED
@@ -0,0 +1,36 @@
from fastapi import FastAPI
from langserve import add_routes
from langchain.chains import ConversationChain
from memory import vectorstore_as_memory
from prompt import PROMPT
from llm import llm

app = FastAPI(title="Retrieval App")

# Initialize the conversation chain with a default memory
memory = vectorstore_as_memory("USER1")
final_chain = ConversationChain(
    llm=llm,
    prompt=PROMPT,
    memory=memory,
    verbose=False
)

# Define a function to update the memory associated with the final_chain
def update_memory(username):
    memory = vectorstore_as_memory(username)
    final_chain.memory = memory

# Define a route to handle API calls
@app.post("/api/{username}")
async def api_endpoint(username: str):
    update_memory(username)
    return {"message": f"Memory updated successfully with username: {username}"}

# Add routes to the FastAPI app
add_routes(app, final_chain)

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="localhost", port=8000)
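
Once the server is running, the per-user memory route and the LangServe chain endpoints can be exercised over HTTP. The client below is an illustrative sketch, not part of the commit; the username "alice" is a placeholder, and it assumes the app is reachable at http://localhost:8000 with LangServe's default /invoke route mounted at the root.

# Illustrative client sketch (assumes main.py is running locally on port 8000).
import requests

# Point the chain's memory at a specific user's Redis index ("alice" is an example username).
requests.post("http://localhost:8000/api/alice")

# Call the chain through LangServe's /invoke endpoint.
# ConversationChain expects its input under the "input" key.
resp = requests.post(
    "http://localhost:8000/invoke",
    json={"input": {"input": "Hi, my name is Alice."}},
)
print(resp.json())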
memory.py
ADDED
@@ -0,0 +1,66 @@
from dotenv import load_dotenv
import os
from langchain.memory import VectorStoreRetrieverMemory
from langchain_community.vectorstores.redis import Redis
from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_core.runnables import ConfigurableField

load_dotenv()

redis_url = os.getenv("REDIS_URL")
openai_key = os.getenv("OPENAI_API_KEY")
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")


embedding_fn = OpenAIEmbeddings(openai_api_key=openai_key)


#! Alternatively, can use Hugging Face embeddings if you don't have an OpenAI key
# modelPath = "HuggingFaceH4/zephyr-7b-beta"
# model_kwargs = {'device': 'cpu'}
# encode_kwargs = {'normalize_embeddings': False}
# embedding_fn = HuggingFaceEmbeddings(
#     model_name=modelPath,
#     model_kwargs=model_kwargs,
#     encode_kwargs=encode_kwargs
# )

schema = {'text': [{'name': 'content',
                    'weight': 1,
                    'no_stem': False,
                    'withsuffixtrie': False,
                    'no_index': False,
                    'sortable': False}],
          'vector': [{'name': 'content_vector',
                      'dims': 1536,
                      'algorithm': 'FLAT',
                      'datatype': 'FLOAT32',
                      'distance_metric': 'COSINE'}]}

def vectorstore_as_memory(username):
    try:
        new_rds = Redis.from_existing_index(
            embedding=embedding_fn,
            index_name=username,
            redis_url=redis_url,
            # schema=rds.schema,
            schema=schema,
        )

        retriever = new_rds.as_retriever(search_type="similarity", search_kwargs={"k": 3})
        memory = VectorStoreRetrieverMemory(retriever=retriever)
        return memory

    except ValueError:
        rds = Redis.from_texts(
            texts=["Hi there"],
            embedding=embedding_fn,
            redis_url=redis_url,
            index_name=username
        )

        retriever = rds.as_retriever(search_type="similarity", search_kwargs={"k": 3})
        memory = VectorStoreRetrieverMemory(retriever=retriever)
        return memory

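
The object returned by vectorstore_as_memory follows the standard VectorStoreRetrieverMemory interface, so it can also be exercised on its own. The snippet below is a minimal sketch, not part of the commit; "demo_user" is a placeholder index name, and it assumes REDIS_URL and OPENAI_API_KEY are set in .env.

# Standalone check of the Redis-backed memory (assumes REDIS_URL and OPENAI_API_KEY are set).
from memory import vectorstore_as_memory

memory = vectorstore_as_memory("demo_user")

# Store one exchange, then retrieve the most similar prior snippets for a new query.
memory.save_context({"input": "My favorite color is blue"}, {"output": "Got it."})
print(memory.load_memory_variables({"input": "What is my favorite color?"}))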
prompt.py
ADDED
@@ -0,0 +1,18 @@
from langchain.prompts import PromptTemplate


_DEFAULT_TEMPLATE = """The following is a friendly conversation between a human and an AI.
The AI is talkative and provides lots of specific details from its context.
If the AI does not know the answer to a question, it truthfully says it does not know.

Relevant pieces of previous conversation:
{history}

(Note that you do not need to use these pieces of information if not relevant)

Current conversation:
Human: {input}
AI:"""
PROMPT = PromptTemplate(
    input_variables=["history", "input"], template=_DEFAULT_TEMPLATE
)
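
The template takes the two variables ConversationChain supplies: history (filled from the retriever memory) and input (the user's message). As an illustrative sketch, not part of the commit, the rendered prompt can be inspected like this:

# Illustrative: render the template with sample values to inspect the final prompt.
from prompt import PROMPT

print(PROMPT.format(
    history="Human: My favorite color is blue\nAI: Got it.",
    input="What is my favorite color?",
))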