# chatbot/app.py
import io
import numpy as np
import streamlit as st
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# -------------------- Config -------------------- #
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL_NAME = "google/gemma-2b-it"  # you can change this later
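# Note: google/gemma models are gated on the Hugging Face Hub; you may need to
# accept the model license and authenticate (e.g. via an HF_TOKEN secret)
# before they can be downloaded.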

# -------------------- Model loaders (cached) -------------------- #
@st.cache_resource(show_spinner=True)
def load_embedder():
    return SentenceTransformer(EMBEDDING_MODEL_NAME)


@st.cache_resource(show_spinner=True)
def load_llm_pipeline():
    """
    Load a text-generation pipeline for the LLM.
    Using device_map="auto" will use a GPU if one is available.
    """
    tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL_NAME,
        device_map="auto",
    )
    gen_pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        do_sample=False,  # greedy decoding; sampling knobs like temperature/top_p are ignored (and warn) when sampling is off
    )
    return gen_pipe
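
# Note: device_map="auto" relies on the `accelerate` package being installed
# alongside transformers; without it, from_pretrained raises an error.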

# -------------------- Helpers -------------------- #
def extract_text_from_pdf(file) -> str:
    """Extract all text from an uploaded PDF file."""
    pdf_reader = PdfReader(file)
    all_text = []
    for page in pdf_reader.pages:
        text = page.extract_text()
        if text:
            all_text.append(text)
    return "\n".join(all_text)

def chunk_text(text, chunk_size=800, overlap=200):
    """Split long text into overlapping chunks (by words)."""
    words = text.split()
    chunks = []
    start = 0
    while start < len(words):
        end = start + chunk_size
        chunks.append(" ".join(words[start:end]))
        start += chunk_size - overlap
    return chunks
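
# Example: with chunk_size=800 and overlap=200, chunk starts advance by 600 words,
# so a 2,000-word document yields four chunks covering words 0-800, 600-1400,
# 1200-2000, and 1800-2000.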

def embed_texts(texts, embedder: SentenceTransformer):
    """Get embeddings for a list of texts."""
    if not texts:
        return np.array([])
    embeddings = embedder.encode(texts, convert_to_numpy=True, show_progress_bar=False)
    return embeddings.astype("float32")
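
# Note: SentenceTransformer.encode also accepts normalize_embeddings=True, which
# would let retrieval use a plain dot product instead of re-normalizing inside
# cosine_sim_matrix below.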

def cosine_sim_matrix(matrix, vector):
    """Cosine similarity between each row in matrix and a single vector."""
    if matrix.size == 0:
        return np.array([])
    matrix_norm = matrix / (np.linalg.norm(matrix, axis=1, keepdims=True) + 1e-10)
    vector_norm = vector / (np.linalg.norm(vector) + 1e-10)
    return np.dot(matrix_norm, vector_norm)
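
# Example: for matrix [[1, 0], [0, 1]] and vector [1, 0] this returns approximately
# [1.0, 0.0] -- the first row is parallel to the query, the second orthogonal.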

def retrieve_relevant_chunks(question, chunks, chunk_embeddings, embedder, top_k=4):
    """Find the top_k chunks most relevant to the question."""
    if len(chunks) == 0:
        return []
    q_emb = embed_texts([question], embedder)[0]
    sims = cosine_sim_matrix(chunk_embeddings, q_emb)
    top_idx = np.argsort(sims)[::-1][:top_k]  # indices sorted by similarity, descending
    return [chunks[i] for i in top_idx]

def build_prompt(question, context_chunks):
    context = "\n\n---\n\n".join(context_chunks)
    system_instruction = (
        "You are a helpful assistant that answers questions "
        "using ONLY the information provided in the document context.\n"
        "If the answer is not in the context, say that you cannot find it in the document."
    )
    prompt = (
        f"{system_instruction}\n\n"
        f"Document context:\n{context}\n\n"
        f"Question: {question}\n\n"
        f"Answer:"
    )
    return prompt
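
# Note: gemma-2b-it is tuned for its chat format; a plain prompt like the one built
# above generally works, but wrapping it with tokenizer.apply_chat_template(
#     [{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True
# ) may follow instructions more reliably (an optional variant, not used below).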

def answer_question(question, chunks, llm_pipe):
    """Call the LLM with the question plus the retrieved context."""
    prompt = build_prompt(question, chunks)
    # A plain prompt works for most HF instruction-tuned models.
    outputs = llm_pipe(
        prompt,
        num_return_sequences=1,
        truncation=True,
    )
    text = outputs[0]["generated_text"]
    # Strip the echoed prompt if the model repeats it verbatim.
    if prompt in text:
        text = text.split(prompt, 1)[-1].strip()
    return text.strip()
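
# Alternative: passing return_full_text=False to the pipeline call makes it return
# only the newly generated text, which would make the prompt-stripping above unnecessary.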

# -------------------- Streamlit UI -------------------- #
st.set_page_config(page_title="Chat with your PDF (HuggingFace)", layout="wide")
st.title("📄 Chat with your PDF (HuggingFace RAG)")
st.markdown(
    """
    Upload a PDF, let the app index it, and then ask questions.
    The model will answer based only on the document content (RAG).
    """
)

with st.sidebar:
    st.header("1. Upload and process PDF")
    uploaded_pdf = st.file_uploader("Choose a PDF file", type=["pdf"])
    process_button = st.button("Process Document")

# Session state to keep document data across reruns
if "chunks" not in st.session_state:
    st.session_state.chunks = []
    st.session_state.embeddings = None

# Load models (cached after the first run)
with st.spinner("Loading models (first time only)..."):
    embedder = load_embedder()
    llm_pipe = load_llm_pipeline()

# Step 1: Process the PDF
if process_button:
    if uploaded_pdf is None:
        st.sidebar.error("Please upload a PDF first.")
    else:
        with st.spinner("Reading and indexing your PDF..."):
            pdf_bytes = io.BytesIO(uploaded_pdf.read())
            text = extract_text_from_pdf(pdf_bytes)
            if not text.strip():
                st.error("Could not extract any text from this PDF.")
            else:
                chunks = chunk_text(text)
                embeddings = embed_texts(chunks, embedder)
                st.session_state.chunks = chunks
                st.session_state.embeddings = embeddings
                st.success(f"Done! Indexed {len(chunks)} chunks from the PDF.")

# Step 2: Ask questions
st.header("2. Ask questions about your document")
question = st.text_input("Type your question here")

if st.button("Get answer"):
    if not st.session_state.chunks:
        st.error("Please upload and process a PDF first.")
    elif not question.strip():
        st.error("Please type a question.")
    else:
        with st.spinner("Thinking with your document..."):
            relevant_chunks = retrieve_relevant_chunks(
                question,
                st.session_state.chunks,
                st.session_state.embeddings,
                embedder,
                top_k=4,
            )
            answer = answer_question(question, relevant_chunks, llm_pipe)
        st.subheader("Answer")
        st.write(answer)
        with st.expander("Show relevant excerpts from the PDF"):
            for i, ch in enumerate(relevant_chunks, start=1):
                st.markdown(f"**Chunk {i}:**")
                st.write(ch)
                st.markdown("---")