# Hugging Face Spaces page banner (scrape residue) - Space status: Sleeping
| import streamlit as st | |
| import os | |
| from langchain_groq import ChatGroq | |
| from langchain_community.document_loaders import WebBaseLoader | |
| from langchain_community.embeddings import OllamaEmbeddings | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.chains.combine_documents import create_stuff_documents_chain | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain.chains import create_retrieval_chain | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain_community.document_loaders import PyPDFDirectoryLoader | |
| from langchain_community.embeddings import HuggingFaceBgeEmbeddings | |
| # from langchain.vectorstores.cassandra import Cassandra | |
| from langchain_community.vectorstores import Cassandra | |
| from langchain_community.llms import Ollama | |
| from cassandra.auth import PlainTextAuthProvider | |
| import tempfile | |
| import cassio | |
| from PyPDF2 import PdfReader | |
| from cassandra.cluster import Cluster | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| from dotenv import load_dotenv | |
| import time | |
| load_dotenv() | |
# --- Astra DB secure-connect bundle (path is relative to the working dir) ---
ASTRA_DB_SECURE_BUNDLE_PATH = "secure-connect-pdf-query-db.zip"

# --- LangChain tracing settings, read from the environment / .env file ---
os.environ["LANGCHAIN_TRACING_V2"] = "true"
LANGCHAIN_API_KEY = os.environ.get("LANGCHAIN_API_KEY")
LANGCHAIN_PROJECT = os.environ.get("LANGCHAIN_PROJECT")
LANGCHAIN_ENDPOINT = os.environ.get("LANGCHAIN_ENDPOINT")

# --- Astra DB connection settings (each is None when the variable is unset) ---
ASTRA_DB_APPLICATION_TOKEN = os.environ.get("ASTRA_DB_APPLICATION_TOKEN")
ASTRA_DB_ID = os.environ.get("ASTRA_DB_ID")
ASTRA_DB_KEYSPACE = os.environ.get("ASTRA_DB_KEYSPACE")
ASTRA_DB_API_ENDPOINT = os.environ.get("ASTRA_DB_API_ENDPOINT")
ASTRA_DB_CLIENT_ID = os.environ.get("ASTRA_DB_CLIENT_ID")
ASTRA_DB_CLIENT_SECRET = os.environ.get("ASTRA_DB_CLIENT_SECRET")
ASTRA_DB_TABLE = os.environ.get("ASTRA_DB_TABLE")

# --- Groq API key; note the environment variable name is lower-case ---
groq_api_key = os.environ.get("groq_api_key")

# Initialise cassio once at import time with the token, database id and bundle.
cassio.init(
    token=ASTRA_DB_APPLICATION_TOKEN,
    database_id=ASTRA_DB_ID,
    secure_connect_bundle=ASTRA_DB_SECURE_BUNDLE_PATH,
)

cloud_config = {"secure_connect_bundle": ASTRA_DB_SECURE_BUNDLE_PATH}
def doc_loader(pdf_reader, table_name="qa_mini_demo"):
    """Index a PDF into the Astra DB vector table and return the vector store.

    The PDF is loaded, split into overlapping chunks, embedded with the
    ``BAAI/bge-small-en-v1.5`` model on CPU and written to ``table_name``.
    The table is emptied first so the returned store only holds chunks of
    this PDF.

    Args:
        pdf_reader: Path of the PDF file to index (handed to ``PyPDFLoader``).
        table_name: Name of the vector table. Defaults to ``"qa_mini_demo"``,
            the table this function has always written to.

    Returns:
        The ``Cassandra`` vector store containing the embedded chunks.
    """
    embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-small-en-v1.5",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
    )

    # Split exactly once with the configured splitter. Calling
    # load_and_split() and then splitting again ran two different splitters
    # over the same text.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
    )
    final_documents = text_splitter.split_documents(
        PyPDFLoader(pdf_reader).load()
    )

    # NOTE(review): a new cluster connection is opened on every call and is
    # not shut down here because the returned store keeps using the session.
    # Consider sharing one session across calls.
    astrasession = Cluster(
        cloud={"secure_connect_bundle": ASTRA_DB_SECURE_BUNDLE_PATH},
        auth_provider=PlainTextAuthProvider("token", ASTRA_DB_APPLICATION_TOKEN),
    ).connect()

    astra_vector_store = Cassandra(
        embedding=embeddings,
        table_name=table_name,
        session=astrasession,
        keyspace=ASTRA_DB_KEYSPACE,
    )

    # Empty the table through the store itself. The previous raw
    # `TRUNCATE {keyspace}.{ASTRA_DB_TABLE}` could target a different table
    # than the one written to, and it ran before the store was constructed.
    # NOTE(review): relies on the store creating its table on construction,
    # so clear() works on a first run - confirm against the installed
    # LangChain version.
    astra_vector_store.clear()

    # A PDF without extractable text yields no chunks; skip the insert then.
    if final_documents:
        astra_vector_store.add_documents(final_documents)
    return astra_vector_store
def prompt_temp():
    """Return the chat prompt template used to answer from retrieved context.

    The template has two placeholders: ``{context}`` for the retrieved
    document text and ``{input}`` for the user's question.
    """
    # The template text is kept flush-left so it matches the original
    # string as it appears in the source.
    template_text = """
Answer the question based on the provided context only.
Please provide the most accurate response based on the question.
{context},
Questions:{input}
"""
    return ChatPromptTemplate.from_template(template_text)
def generate_response(llm, prompt, user_input, vectorstore):
    """Run a retrieval chain for ``user_input`` and return its raw response.

    Args:
        llm: Chat model passed to the stuff-documents chain.
        prompt: Prompt template passed to the stuff-documents chain.
        user_input: The question, sent to the chain under the ``"input"`` key.
        vectorstore: Store whose retriever supplies the documents.

    Returns:
        Whatever the retrieval chain's ``invoke`` returns; the caller in
        this file reads its ``'answer'`` entry.
    """
    stuffing_chain = create_stuff_documents_chain(llm, prompt)

    # Similarity search limited to the five closest chunks.
    top_k_retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 5},
    )

    qa_chain = create_retrieval_chain(top_k_retriever, stuffing_chain)
    return qa_chain.invoke({"input": user_input})
def main():
    """Render the Streamlit page and answer a question about an uploaded PDF.

    On submit, the uploaded PDF is written to a temporary file, indexed via
    ``doc_loader``, queried via ``generate_response`` and the ``'answer'``
    entry of the response is displayed. The temporary file is removed
    afterwards, whether or not the pipeline succeeded.
    """
    st.set_page_config(page_title='Chat Groq Demo')
    st.header('Chat Groq Demo')
    user_input = st.text_input('Enter the Prompt here')
    file = st.file_uploader('Choose Invoice File', type='pdf')
    submit = st.button("Submit")

    # Mirror the button state into the session for this run.
    st.session_state.submit_clicked = bool(submit)
    if not submit:
        return

    # Previously a submit without a prompt or file silently did nothing.
    if not (user_input and file):
        st.warning('Please enter a prompt and upload a PDF file.')
        return

    # delete=False so the file survives the `with` block and can be
    # reopened by path; it is removed explicitly in the `finally` below.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as temp_file:
        temp_file.write(file.getbuffer())
        file_path = temp_file.name

    try:
        llm = ChatGroq(groq_api_key=groq_api_key, model_name="gemma-7b-it")
        prompt = prompt_temp()
        vectorstore = doc_loader(file_path)
        response = generate_response(llm, prompt, user_input, vectorstore)
        st.write(response['answer'])
    finally:
        # Best-effort cleanup: a failure to delete the temp file must not
        # mask an error raised by the pipeline above.
        try:
            os.remove(file_path)
        except OSError:
            pass


if __name__ == "__main__":
    main()