Spaces:
Sleeping
Sleeping
| from langchain.chat_models import ChatOpenAI | |
| from langchain.chains import ConversationChain | |
| from langchain.chains.conversation.memory import ConversationBufferWindowMemory | |
| from langchain.prompts import ( | |
| SystemMessagePromptTemplate, | |
| HumanMessagePromptTemplate, | |
| ChatPromptTemplate, | |
| MessagesPlaceholder | |
| ) | |
| import streamlit as st | |
| from streamlit_chat import message | |
| from utils import * | |
| import os | |
| from dotenv import load_dotenv | |
# Load environment variables from the .env file so the OpenAI key does not
# have to be hard-coded in this script.
load_dotenv()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Fail fast when the key is missing *or* blank: an empty value (for example a
# bare "OPENAI_API_KEY=" entry) passes an `is None` check but is just as
# unusable as a missing one.
if not OPENAI_API_KEY:
    raise ValueError("OpenAI API key is not found in the .env file")
# Page heading for the chatbot.
st.subheader("Article Chatbot")

# Seed the per-session chat transcript the first time the script runs; on
# later reruns the stored lists are left as they are.
for state_key, initial_value in (
    ("responses", ["How can I assist you?"]),
    ("requests", []),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value

# Chat model used to answer queries.
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo",
    openai_api_key=OPENAI_API_KEY,
)

# Windowed conversation memory (k=3), created once per session and reused.
if "buffer_memory" not in st.session_state:
    st.session_state["buffer_memory"] = ConversationBufferWindowMemory(
        k=3,
        return_messages=True,
    )

# Prompt layout: system instruction, then the stored history, then the
# current human input.
system_msg_template = SystemMessagePromptTemplate.from_template(
    template="""Answer the question as truthfully as possible using the provided context,
and if the answer is not contained within the text below, say 'I don't know'"""
)
human_msg_template = HumanMessagePromptTemplate.from_template(template="{input}")
prompt_template = ChatPromptTemplate.from_messages(
    [
        system_msg_template,
        MessagesPlaceholder(variable_name="history"),
        human_msg_template,
    ]
)

# Chain that ties the model, the prompt and the session's memory together.
conversation = ConversationChain(
    llm=llm,
    prompt=prompt_template,
    memory=st.session_state["buffer_memory"],
    verbose=True,
)
# Two layout containers, created in this order: one for the chat history and
# one for the text box.
response_container = st.container()
textcontainer = st.container()

with textcontainer:
    query = st.text_input("Query: ", key="input")
    if query:
        with st.spinner("typing..."):
            # Refine the raw query against the conversation so far and show
            # the refined form to the user. (Both helpers come from `utils`;
            # their exact behaviour is not visible here.)
            conversation_string = get_conversation_string()
            refined_query = query_refiner(conversation_string, query)
            st.subheader("Refined Query:")
            st.write(refined_query)

            # Fetch matching context for the refined query, then ask the
            # chain using that context together with the original query.
            context = find_match(refined_query)
            prompt_input = f"Context:\n {context} \n\n Query:\n{query}"
            response = conversation.predict(input=prompt_input)

        # Record the exchange so it is shown in the chat history.
        st.session_state["requests"].append(query)
        st.session_state["responses"].append(response)
with response_container:
    bot_replies = st.session_state["responses"]
    user_queries = st.session_state["requests"]

    # Render the transcript in order: each stored response, followed by the
    # request at the same index when one exists. `responses` may be longer
    # than `requests`, hence the bounds check. Every bubble gets its own key.
    for idx, reply in enumerate(bot_replies):
        message(reply, key=str(idx))
        if idx < len(user_queries):
            message(user_queries[idx], is_user=True, key=f"{idx}_user")