Avinashstat committed on
Commit 72a7de1 · verified · 1 Parent(s): e3f2ca1

Create app.py

Files changed (1): app.py (+212, -0)
app.py ADDED
@@ -0,0 +1,212 @@
import io
import numpy as np
import streamlit as st
from pypdf import PdfReader

from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline


# -------------------- Config -------------------- #

EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL_NAME = "google/gemma-2b-it"  # you can change this later
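# NOTE: google/gemma-2b-it is a gated checkpoint on the Hugging Face Hub, so
# loading it normally requires accepting its license and authenticating,
# e.g. via `huggingface-cli login` or the HF_TOKEN environment variable.
# A sketch of swapping in an ungated model (an assumed substitute, not from
# the original commit):
#
#     LLM_MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"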
# -------------------- Model loaders (cached) -------------------- #

@st.cache_resource(show_spinner=True)
def load_embedder():
    return SentenceTransformer(EMBEDDING_MODEL_NAME)


@st.cache_resource(show_spinner=True)
def load_llm_pipeline():
    """
    Load a text-generation pipeline for the LLM.
    Using device_map="auto" will put the model on a GPU if one is available.
    """
    tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL_NAME,
        device_map="auto",
    )
    gen_pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        # Greedy decoding. temperature/top_p are only honored when
        # do_sample=True, so they are omitted here.
        do_sample=False,
    )
    return gen_pipe
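# NOTE: For chat-tuned models such as Gemma, wrapping the prompt with the
# tokenizer's chat template usually improves instruction-following over a
# raw string. A minimal sketch, assuming the tokenizer ships a chat template:
#
#     messages = [{"role": "user", "content": prompt}]
#     prompt = tokenizer.apply_chat_template(
#         messages, tokenize=False, add_generation_prompt=True
#     )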
# -------------------- Helpers -------------------- #

def extract_text_from_pdf(file) -> str:
    """Extract all text from an uploaded PDF file."""
    pdf_reader = PdfReader(file)
    all_text = []
    for page in pdf_reader.pages:
        text = page.extract_text()
        if text:
            all_text.append(text)
    return "\n".join(all_text)


def chunk_text(text, chunk_size=800, overlap=200):
    """Split long text into overlapping chunks (by words)."""
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    words = text.split()
    chunks = []
    start = 0
    while start < len(words):
        end = start + chunk_size
        chunks.append(" ".join(words[start:end]))
        start += chunk_size - overlap
    return chunks
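# Worked example: with chunk_size=800 and overlap=200 the window advances
# 600 words per step, so a 2,000-word document produces chunks starting at
# word offsets 0, 600, 1200, and 1800: four chunks, each sharing 200 words
# with its neighbor.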
def embed_texts(texts, embedder: SentenceTransformer):
    """Get embeddings for a list of texts."""
    if not texts:
        return np.array([])
    embeddings = embedder.encode(texts, convert_to_numpy=True, show_progress_bar=False)
    return embeddings.astype("float32")


def cosine_sim_matrix(matrix, vector):
    """Cosine similarity between each row in matrix and a single vector."""
    if matrix.size == 0:
        return np.array([])
    matrix_norm = matrix / (np.linalg.norm(matrix, axis=1, keepdims=True) + 1e-10)
    vector_norm = vector / (np.linalg.norm(vector) + 1e-10)
    return np.dot(matrix_norm, vector_norm)
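# Quick illustrative check: with orthonormal rows, the similarity is simply
# the matching coordinate of the query vector.
#
#     cosine_sim_matrix(np.array([[1.0, 0.0], [0.0, 1.0]]), np.array([1.0, 0.0]))
#     # -> array([1., 0.])  (up to the 1e-10 stabilizer)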
def retrieve_relevant_chunks(question, chunks, chunk_embeddings, embedder, top_k=4):
    """Find the top_k chunks most relevant to the question."""
    if len(chunks) == 0:
        return []

    q_emb = embed_texts([question], embedder)[0]
    sims = cosine_sim_matrix(chunk_embeddings, q_emb)
    top_idx = np.argsort(sims)[::-1][:top_k]
    return [chunks[i] for i in top_idx]


def build_prompt(question, context_chunks):
    """Combine the system instruction, retrieved context, and question into one prompt."""
    context = "\n\n---\n\n".join(context_chunks)
    system_instruction = (
        "You are a helpful assistant that answers questions "
        "using ONLY the information provided in the document context.\n"
        "If the answer is not in the context, say that you cannot find it in the document."
    )

    prompt = (
        f"{system_instruction}\n\n"
        f"Document context:\n{context}\n\n"
        f"Question: {question}\n\n"
        f"Answer:"
    )
    return prompt
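# The assembled prompt looks roughly like this (illustrative, with <...>
# standing in for real content):
#
#     You are a helpful assistant that answers questions using ONLY the ...
#
#     Document context:
#     <chunk 1>
#
#     ---
#
#     <chunk 2>
#
#     Question: <user question>
#
#     Answer: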
def answer_question(question, chunks, llm_pipe):
    """Call the LLM with the question plus the retrieved context."""
    prompt = build_prompt(question, chunks)

    # For most HF instruction models, a plain string prompt works well enough.
    outputs = llm_pipe(
        prompt,
        num_return_sequences=1,
        truncation=True,
    )
    text = outputs[0]["generated_text"]

    # Remove the echoed prompt if the pipeline returned it with the output.
    if prompt in text:
        text = text.split(prompt, 1)[-1].strip()

    return text.strip()
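# NOTE: An alternative to stripping the echoed prompt by hand is to pass
# return_full_text=False to the pipeline call, which makes the HF
# text-generation pipeline return only the newly generated text:
#
#     outputs = llm_pipe(prompt, num_return_sequences=1, return_full_text=False)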
# -------------------- Streamlit UI -------------------- #

st.set_page_config(page_title="Chat with your PDF (HuggingFace)", layout="wide")

st.title("📄 Chat with your PDF (HuggingFace RAG)")

st.markdown(
    """
    Upload a PDF, let the app index it, and then ask questions.
    The model will answer based only on the document content (RAG).
    """
)

with st.sidebar:
    st.header("1. Upload and process PDF")
    uploaded_pdf = st.file_uploader("Choose a PDF file", type=["pdf"])
    process_button = st.button("Process Document")

# Session state to keep document data across reruns
if "chunks" not in st.session_state:
    st.session_state.chunks = []
    st.session_state.embeddings = None

# Load models (cached after the first run)
with st.spinner("Loading models (first time only)..."):
    embedder = load_embedder()
    llm_pipe = load_llm_pipeline()

# Step 1: Process the PDF
if process_button:
    if uploaded_pdf is None:
        st.sidebar.error("Please upload a PDF first.")
    else:
        with st.spinner("Reading and indexing your PDF..."):
            pdf_bytes = io.BytesIO(uploaded_pdf.read())
            text = extract_text_from_pdf(pdf_bytes)

            if not text.strip():
                st.error("Could not extract any text from this PDF.")
            else:
                chunks = chunk_text(text)
                embeddings = embed_texts(chunks, embedder)

                st.session_state.chunks = chunks
                st.session_state.embeddings = embeddings

                st.success(f"Done! Indexed {len(chunks)} chunks from the PDF.")

# Step 2: Ask questions
st.header("2. Ask questions about your document")

question = st.text_input("Type your question here")

if st.button("Get answer"):
    if not st.session_state.chunks:
        st.error("Please upload and process a PDF first.")
    elif not question.strip():
        st.error("Please type a question.")
    else:
        with st.spinner("Thinking with your document..."):
            relevant_chunks = retrieve_relevant_chunks(
                question,
                st.session_state.chunks,
                st.session_state.embeddings,
                embedder,
                top_k=4,
            )
            answer = answer_question(question, relevant_chunks, llm_pipe)

        st.subheader("Answer")
        st.write(answer)

        with st.expander("Show relevant excerpts from the PDF"):
            for i, ch in enumerate(relevant_chunks, start=1):
                st.markdown(f"**Chunk {i}:**")
                st.write(ch)
                st.markdown("---")
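# ------------------------------------------------------------------ #
# Running the app, as a sketch (the exact dependency list is an assumption):
#
#     pip install streamlit pypdf sentence-transformers transformers accelerate torch numpy
#     streamlit run app.py
#
# accelerate is required for device_map="auto" in load_llm_pipeline().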