NHZ committed
Commit 05b86d4 · verified · 1 parent: c69c25e

Update app.py

Files changed (1): app.py (+45 -63)
app.py CHANGED
@@ -1,58 +1,45 @@
+import os
 import requests
 import numpy as np
 import faiss
 from PyPDF2 import PdfReader
-from transformers import AutoTokenizer, AutoModel
-from groq import Groq
+from sentence_transformers import SentenceTransformer
+from langchain.vectorstores import FAISS
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.chains import RetrievalQA
+from langchain.prompts import PromptTemplate
+from langchain.llms import GroqLLM
 import streamlit as st
-import torch
-import os
 
-# Initialize Groq client using secret API key
-client = Groq(api_key=os.getenv("GROQ_API_KEY"))
+# Initialize Groq API LLM
+llm = GroqLLM(api_key=os.getenv("GROQ_API_KEY"))
 
-# Function to download and extract content from a public Google Drive PDF link
+# Function to extract content from a public Google Drive PDF link
 def extract_pdf_content(drive_url):
-    # Extract file ID from the Google Drive URL
     file_id = drive_url.split("/d/")[1].split("/view")[0]
     download_url = f"https://drive.google.com/uc?export=download&id={file_id}"
-
-    # Download the PDF content
     response = requests.get(download_url)
     if response.status_code != 200:
         return None
 
-    # Save and extract text from the PDF
     with open("document.pdf", "wb") as f:
         f.write(response.content)
+
     reader = PdfReader("document.pdf")
     text = ""
     for page in reader.pages:
         text += page.extract_text()
     return text
 
-# Function to chunk and tokenize text
-def chunk_and_tokenize(text, tokenizer, chunk_size=512):
-    tokens = tokenizer.encode(text, add_special_tokens=False)
-    chunks = [tokens[i:i + chunk_size] for i in range(0, len(tokens), chunk_size)]
-    return chunks
-
-# Function to compute embeddings and build FAISS index
-def build_faiss_index(chunks, model):
-    embeddings = []
-    for chunk in chunks:
-        input_ids = torch.tensor([chunk])
-        with torch.no_grad():
-            embedding = model(input_ids).last_hidden_state.mean(dim=1).detach().numpy()
-        embeddings.append(embedding)
-    embeddings = np.vstack(embeddings)
-
-    index = faiss.IndexFlatL2(embeddings.shape[1])
-    index.add(embeddings)
-    return index
+# Function to create a FAISS vector store from the document content
+def create_vector_store(text):
+    sentences = text.split(". ")
+    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+    vector_store = FAISS.from_texts(sentences, embedding=embeddings)
+    return vector_store, sentences
 
 # Streamlit app
-st.title("RAG-based Application with Groq API")
+st.title("RAG-based Application with Focused Context")
 
 # Predefined Google Drive link
 drive_url = "https://drive.google.com/file/d/1XvqA1OIssRs2gbmOtKFKj-02yQ5X2yg0/view?usp=sharing"
@@ -62,45 +49,40 @@ st.write("Extracting content from the document...")
 text = extract_pdf_content(drive_url)
 if text:
     st.write("Document extracted successfully!")
 
-    # Initialize tokenizer and model
-    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-    model = AutoModel.from_pretrained("bert-base-uncased")
-
-    st.write("Chunking and tokenizing content...")
-    chunks = chunk_and_tokenize(text, tokenizer)
+    st.write("Creating vector store...")
+    vector_store, sentences = create_vector_store(text)
 
-    st.write("Building FAISS index...")
-    index = build_faiss_index(chunks, model)
+    st.write("Vector store created successfully!")
 
-    # Query input
    query = st.text_input("Enter your query:")
    if query:
-        st.write("Searching for the most relevant chunk...")
-        query_tokens = tokenizer.encode(query, add_special_tokens=False)
-        query_embedding = (
-            model(torch.tensor([query_tokens]))
-            .last_hidden_state.mean(dim=1)
-            .detach().numpy()
+        st.write("Retrieving relevant context from the document...")
+        retriever = vector_store.as_retriever()
+        retriever.search_kwargs["k"] = 3  # Retrieve top 3 matches
+
+        # Define a prompt template to guide LLM response generation
+        prompt_template = PromptTemplate(
+            template="""
+            Use the following context to answer the question:
+
+            {context}
+
+            Question: {question}
+            Answer:""",
+            input_variables=["context", "question"]
         )
-        _, indices = index.search(query_embedding, k=1)
-
-        # Retrieve the most relevant chunk
-        relevant_chunk = chunks[indices[0][0]]
-        relevant_text = tokenizer.decode(relevant_chunk)
-        st.write("Relevant chunk found:", relevant_text)
-
-        # Interact with Groq API
-        st.write("Querying the Groq API...")
-        chat_completion = client.chat.completions.create(
-            messages=[
-                {
-                    "role": "user",
-                    "content": relevant_text,
-                }
-            ],
-            model="llama-3.3-70b-versatile",
+
+        # Create a RetrievalQA chain
+        qa_chain = RetrievalQA(
+            retriever=retriever,
+            llm=llm,
+            prompt=prompt_template
        )
-        st.write("Model Response:", chat_completion.choices[0].message.content)
+
+        # Run the query through the QA chain
+        result = qa_chain.run(query)
+        st.write("Answer:", result)
 else:
     st.error("Failed to extract content from the document.")
+
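Review note on the new code path: as committed, app.py is unlikely to run. LangChain does not provide a GroqLLM class under langchain.llms (its Groq integration is ChatGroq from the separate langchain-groq package), RetrievalQA is built through RetrievalQA.from_chain_type(...) rather than by passing retriever/llm/prompt to the constructor directly, and the SentenceTransformer import is never used. Below is a minimal sketch of a working equivalent, assuming a LangChain version whose community integrations live under langchain_community; the embedding model, top-3 retrieval, prompt wording, and llama-3.3-70b-versatile are carried over from the diff, while build_qa_chain is a hypothetical helper, not part of the commit.

# Sketch of a corrected pipeline, not the committed code. Assumed installs:
# langchain, langchain-community, langchain-groq, sentence-transformers,
# faiss-cpu. build_qa_chain is a hypothetical helper used for illustration.
import os

from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_groq import ChatGroq  # the published Groq integration for LangChain

# ChatGroq stands in for the nonexistent langchain.llms.GroqLLM
llm = ChatGroq(
    model="llama-3.3-70b-versatile",
    groq_api_key=os.getenv("GROQ_API_KEY"),
)

def build_qa_chain(text: str) -> RetrievalQA:
    # Same naive sentence-level chunking as the commit; a dedicated text
    # splitter would hold up better on long PDFs.
    sentences = [s for s in text.split(". ") if s.strip()]
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    vector_store = FAISS.from_texts(sentences, embedding=embeddings)
    retriever = vector_store.as_retriever(search_kwargs={"k": 3})  # top 3 matches

    prompt = PromptTemplate(
        template=(
            "Use the following context to answer the question:\n\n"
            "{context}\n\n"
            "Question: {question}\n"
            "Answer:"
        ),
        input_variables=["context", "question"],
    )
    # RetrievalQA is constructed via from_chain_type, not RetrievalQA(...)
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",  # concatenate retrieved chunks into one prompt
        retriever=retriever,
        chain_type_kwargs={"prompt": prompt},
    )

In the Streamlit flow, build_qa_chain(text) would replace the create_vector_store and RetrievalQA steps, and qa_chain.invoke({"query": query})["result"] yields the answer string. One caveat that survives unchanged in both versions of the file: drive.google.com/uc?export=download returns an HTML confirmation page instead of the file for large PDFs, so extract_pdf_content may silently save non-PDF bytes; checking the response Content-Type before writing would harden it.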