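"""
Chat with your PDF: a minimal RAG demo built on Streamlit and Hugging Face.

Pipeline: extract text from the uploaded PDF (pypdf) -> split into overlapping
word chunks -> embed the chunks with a SentenceTransformer -> retrieve the
top-k chunks by cosine similarity -> answer with an instruction-tuned LLM.

Assumed dependencies (standard PyPI names):
    pip install streamlit pypdf numpy sentence-transformers transformers accelerate
"""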
import io
import numpy as np
import streamlit as st
from pypdf import PdfReader

from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline


# -------------------- Config -------------------- #

EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# NOTE: Gemma checkpoints are gated on the Hugging Face Hub; accept the model
# license and log in (huggingface-cli login) before the first download.
LLM_MODEL_NAME = "google/gemma-2b-it"  # swap in another instruct model if desired


# -------------------- Model loaders (cached) -------------------- #

@st.cache_resource(show_spinner=True)
def load_embedder():
    return SentenceTransformer(EMBEDDING_MODEL_NAME)


@st.cache_resource(show_spinner=True)
def load_llm_pipeline():
    """
    Load a text-generation pipeline for the LLM.
    Using device_map="auto" will use GPU if available.
    """
    tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL_NAME,
        device_map="auto",
    )
    gen_pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        do_sample=False,  # greedy decoding; temperature/top_p only apply when sampling
    )
    return gen_pipe


# -------------------- Helpers -------------------- #

def extract_text_from_pdf(file) -> str:
    """Extract all text from an uploaded PDF file."""
    pdf_reader = PdfReader(file)
    all_text = []
    for page in pdf_reader.pages:
        text = page.extract_text()
        if text:
            all_text.append(text)
    return "\n".join(all_text)


def chunk_text(text, chunk_size=800, overlap=200):
    """Split long text into overlapping chunks (by words)."""
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    words = text.split()
    chunks = []
    start = 0
    while start < len(words):
        end = start + chunk_size
        chunk = " ".join(words[start:end])
        chunks.append(chunk)
        start += chunk_size - overlap
    return chunks
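
# Worked example: with the defaults (chunk_size=800, overlap=200) the window
# advances 600 words per step, so a 3,000-word document yields five chunks
# covering words 0-800, 600-1400, 1200-2000, 1800-2600, and 2400-3000.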


def embed_texts(texts, embedder: SentenceTransformer):
    """Get embeddings for a list of texts."""
    if not texts:
        return np.array([])
    embeddings = embedder.encode(texts, convert_to_numpy=True, show_progress_bar=False)
    return embeddings.astype("float32")
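
# Alternative: SentenceTransformer.encode also accepts normalize_embeddings=True,
# which returns unit-length vectors and would reduce cosine_sim_matrix below to
# a plain dot product.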


def cosine_sim_matrix(matrix, vector):
    """Cosine similarity between each row in matrix and a single vector."""
    if matrix.size == 0:
        return np.array([])
    matrix_norm = matrix / (np.linalg.norm(matrix, axis=1, keepdims=True) + 1e-10)
    vector_norm = vector / (np.linalg.norm(vector) + 1e-10)
    return np.dot(matrix_norm, vector_norm)


def retrieve_relevant_chunks(question, chunks, chunk_embeddings, embedder, top_k=4):
    """Find top_k most relevant chunks for the question."""
    if len(chunks) == 0:
        return []

    q_emb = embed_texts([question], embedder)[0]
    sims = cosine_sim_matrix(chunk_embeddings, q_emb)
    top_idx = np.argsort(sims)[::-1][:top_k]
    return [chunks[i] for i in top_idx]
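
# For very large documents, np.argpartition(sims, -top_k) would find the top_k
# indices without a full sort; argsort is fine at this scale.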


def build_prompt(question, context_chunks):
    context = "\n\n---\n\n".join(context_chunks)
    system_instruction = (
        "You are a helpful assistant that answers questions "
        "using ONLY the information provided in the document context.\n"
        "If the answer is not in the context, say that you cannot find it in the document."
    )

    prompt = (
        f"{system_instruction}\n\n"
        f"Document context:\n{context}\n\n"
        f"Question: {question}\n\n"
        f"Answer:"
    )
    return prompt


def answer_question(question, chunks, llm_pipe):
    """Call the LLM with the question + retrieved context."""
    prompt = build_prompt(question, chunks)

    # For most HF instruction models a plain prompt works acceptably;
    # return_full_text=False asks the pipeline for only the generated
    # continuation rather than the echoed prompt.
    outputs = llm_pipe(
        prompt,
        num_return_sequences=1,
        truncation=True,
        return_full_text=False,
    )
    text = outputs[0]["generated_text"]

    # Safety net in case a model still echoes the prompt
    if prompt in text:
        text = text.split(prompt, 1)[-1].strip()

    return text.strip()
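
# Hedged alternative (not wired in): instruction-tuned checkpoints such as
# gemma-2b-it are trained on a chat format, so wrapping the prompt in the
# model's chat template may improve answers:
#
#   chat = [{"role": "user", "content": prompt}]
#   templated = llm_pipe.tokenizer.apply_chat_template(
#       chat, tokenize=False, add_generation_prompt=True
#   )
#   outputs = llm_pipe(templated, return_full_text=False)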


# -------------------- Streamlit UI -------------------- #

st.set_page_config(page_title="Chat with your PDF (HuggingFace)", layout="wide")

st.title("📄 Chat with your PDF (HuggingFace RAG)")

st.markdown(
    """
Upload a PDF, let the app index it, and then ask questions.
The model answers using only the document's content (retrieval-augmented generation).
"""
)

with st.sidebar:
    st.header("1. Upload and process PDF")
    uploaded_pdf = st.file_uploader("Choose a PDF file", type=["pdf"])
    process_button = st.button("Process Document")

# Session state to keep doc data
if "chunks" not in st.session_state:
    st.session_state.chunks = []
    st.session_state.embeddings = None

# Load models once; st.cache_resource keeps them across reruns
with st.spinner("Loading models (first time only)..."):
    embedder = load_embedder()
    llm_pipe = load_llm_pipeline()

# Step 1: Process PDF
if process_button:
    if uploaded_pdf is None:
        st.sidebar.error("Please upload a PDF first.")
    else:
        with st.spinner("Reading and indexing your PDF..."):
            pdf_bytes = io.BytesIO(uploaded_pdf.read())
            text = extract_text_from_pdf(pdf_bytes)

            if not text.strip():
                st.error("Could not extract any text from this PDF.")
            else:
                chunks = chunk_text(text)
                embeddings = embed_texts(chunks, embedder)

                st.session_state.chunks = chunks
                st.session_state.embeddings = embeddings

                st.success(f"Done! Indexed {len(chunks)} chunks from the PDF.")

# Step 2: Ask questions
st.header("2. Ask questions about your document")

question = st.text_input("Type your question here")

if st.button("Get answer"):
    if not st.session_state.chunks:
        st.error("Please upload and process a PDF first.")
    elif not question.strip():
        st.error("Please type a question.")
    else:
        with st.spinner("Thinking with your document..."):
            relevant_chunks = retrieve_relevant_chunks(
                question,
                st.session_state.chunks,
                st.session_state.embeddings,
                embedder,
                top_k=4,
            )
            answer = answer_question(question, relevant_chunks, llm_pipe)

        st.subheader("Answer")
        st.write(answer)

        with st.expander("Show relevant excerpts from the PDF"):
            for i, ch in enumerate(relevant_chunks, start=1):
                st.markdown(f"**Chunk {i}:**")
                st.write(ch)
                st.markdown("---")