# CG_AskPDF / app.py
import os
from huggingface_hub import InferenceClient
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
import gradio as gr
import warnings
import uuid
MODEL_OPTIONS = [
"allenai/Olmo-3-32B-Think",
"allenai/Olmo-3-7B-Instruct",
"allenai/Olmo-3-7B-Think",
"ArliAI/QwQ-32B-ArliAI-RpR-v4",
"baichuan-inc/Baichuan-M2-32B",
"darkc0de/XortronCriminalComputingConfig",
"deepseek-ai/DeepSeek-R1",
"deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
"deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
"deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
"deepseek-ai/DeepSeek-V3.1-Terminus",
"deepseek-ai/DeepSeek-V3.2-Exp",
"DeepHat/DeepHat-V1-7B",
"dphn/Dolphin-Mistral-24B-Venice-Edition",
"Goekdeniz-Guelmez/Josiefied-Qwen3-8B-abliterated-v1",
"google/gemma-2-2b-it",
"Gryphe/MythoMax-L2-13b",
"HuggingFaceH4/zephyr-7b-beta",
"HuggingFaceTB/SmolLM3-3B",
"inclusionAI/Ling-1T",
"Intelligent-Internet/II-Medical-8B",
"meta-llama/Llama-3-8B-Instruct",
"meta-llama/Llama-3.1-8B",
"meta-llama/Llama-3.1-8B-Instruct",
"meta-llama/Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.2-3B-Instruct",
"meta-llama/Llama-3.3-70B-Instruct",
"meta-llama/Llama-Guard-3-8B",
"meta-llama/Meta-Llama-3-8B",
"meta-llama/Meta-Llama-3-8B-Instruct",
"MiniMaxAI/MiniMax-M2",
"mistralai/Mistral-7B-Instruct-v0.2",
"moonshotai/Kimi-K2-Instruct",
"moonshotai/Kimi-K2-Instruct-0905",
"moonshotai/Kimi-K2-Thinking",
"moonshotai/Kimi-K2-Tinking",
"nvidia/NVIDIA-Nemotron-Nano-12B-v2",
"openai/gpt-oss-120b",
"openai/gpt-oss-20b",
"PrimeIntellect/INTELLECT-3-FP8",
"Qwen/Qwen2.5-1.5B-Instruct",
"Qwen/Qwen2.5-7B",
"Qwen/Qwen2.5-7B-Instruct",
"Qwen/Qwen2.5-Coder-1.5B-Instruct",
"Qwen/Qwen2.5-Coder-7B-Instruct",
"Qwen/Qwen3-1.7B",
"Qwen/Qwen3-14B",
"Qwen/Qwen3-30B-A3B",
"Qwen/Qwen3-30B-A3B-Instruct-2507",
"Qwen/Qwen3-32B",
"Qwen/Qwen3-4B-Instruct-2507",
"Qwen/Qwen3-4B-Thinking-2507",
"Qwen/Qwen3-8B",
"Qwen/Qwen3-235B-A22B-Instruct-2507",
"Qwen/Qwen3-Coder-30B-A3B-Instruct",
"Qwen/Qwen3-Next-80B-A3B-Instruct",
"Qwen/Qwen3-Next-80B-A3B-Thinking",
"zai-org/GLM-4.5",
"zai-org/GLM-4.5-Air",
"zai-org/GLM-4.6",
]
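# Note: these ids are passed straight to the HF Inference API; which of them are
# actually served can change over time, and an unavailable id will surface as an
# error message from retriever_qa rather than crash the app.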
# Suppress warnings
def warn(*args, **kwargs):
pass
warnings.warn = warn
warnings.filterwarnings("ignore")
# ---------------------------
# Get credentials from environment variables
# ---------------------------
def get_huggingface_token():
"""
Get HuggingFace API token from environment.
Set this in your Space settings under Settings > Repository secrets:
- HF_TOKEN or HUGGINGFACE_TOKEN
"""
token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
if not token:
raise ValueError(
"HF_TOKEN not found. Please set it in your HuggingFace Space secrets."
)
return token
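# Note: outside a Hugging Face Space the same lookup works if the token is
# exported in the shell first (e.g. `export HF_TOKEN=hf_xxx`, placeholder value);
# the function raises ValueError when neither variable is set.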
# ---------------------------
# LLM
# ---------------------------
def get_llm(model_id: str = MODEL_OPTIONS[0], max_tokens: int = 256, temperature: float = 0.8):
"""
    Return the InferenceClient together with the selected model id, max_tokens, and temperature.
"""
token = get_huggingface_token()
client = InferenceClient(token=token)
return client, model_id, max_tokens, temperature
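# Illustrative sketch (not called by the app): a plain chat completion with the
# client returned by get_llm, independent of any PDF. Assumes HF_TOKEN is set;
# the prompt text is only an example.
def _example_chat_completion():
    client, model_id, max_tok, temp = get_llm(MODEL_OPTIONS[0], max_tokens=64, temperature=0.2)
    response = client.chat_completion(
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        model=model_id,
        max_tokens=max_tok,
        temperature=temp,
    )
    return response.choices[0].message.content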
# ---------------------------
# Document loader
# ---------------------------
def document_loader(file):
# Handle file path string from Gradio
file_path = file if isinstance(file, str) else file.name
loader = PyPDFLoader(file_path)
loaded_document = loader.load()
return loaded_document
# ---------------------------
# Text splitter
# ---------------------------
def text_splitter(data, chunk_size: int = 500, chunk_overlap: int = 50):
splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len,
)
chunks = splitter.split_documents(data)
return chunks
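# Illustrative sketch (not called by the app): loading a PDF and splitting it
# into overlapping chunks. The file name "sample.pdf" is hypothetical.
def _example_load_and_split():
    pages = document_loader("sample.pdf")  # one Document per PDF page
    chunks = text_splitter(pages, chunk_size=500, chunk_overlap=50)
    return len(pages), len(chunks)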
# ---------------------------
# Embedding model
# ---------------------------
def get_embedding_model(model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
"""
Create HuggingFace embedding model.
    Uses a sentence-transformers model to compute embeddings efficiently on CPU.
"""
embedding = HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs={'device': 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
return embedding
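# Illustrative sketch (not called by the app): embedding a single query string.
# The default all-MiniLM-L6-v2 model returns 384-dimensional vectors, normalized
# because encode_kwargs sets normalize_embeddings=True.
def _example_embed_query():
    embedding = get_embedding_model()
    vector = embedding.embed_query("What does this document discuss?")
    return len(vector)  # 384 for all-MiniLM-L6-v2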
# ---------------------------
# Vector DB
# ---------------------------
def vector_database(chunks, embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
embedding_model = get_embedding_model(embedding_model_name)
# Create unique collection name to avoid reusing cached data
collection_name = f"rag_collection_{uuid.uuid4().hex[:8]}"
vectordb = Chroma.from_documents(
chunks,
embedding_model,
collection_name=collection_name
)
return vectordb
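# Illustrative sketch (not called by the app): querying the Chroma collection
# directly with a similarity search, bypassing the retriever wrapper.
# "sample.pdf" is hypothetical.
def _example_similarity_search():
    chunks = text_splitter(document_loader("sample.pdf"))
    vectordb = vector_database(chunks)
    hits = vectordb.similarity_search("main findings", k=3)
    return [hit.page_content[:100] for hit in hits]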
# ---------------------------
# Retriever
# ---------------------------
def retriever(file, chunk_size: int = 500, chunk_overlap: int = 50, embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"):
splits = document_loader(file)
chunks = text_splitter(splits, chunk_size, chunk_overlap)
vectordb = vector_database(chunks, embedding_model)
retriever_obj = vectordb.as_retriever()
return retriever_obj
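# Illustrative sketch (not called by the app): the end-to-end retrieval step that
# retriever_qa below relies on. "sample.pdf" is hypothetical.
def _example_retrieve():
    retriever_obj = retriever("sample.pdf", chunk_size=500, chunk_overlap=50)
    docs = retriever_obj.invoke("What is the main topic?")
    return [doc.page_content for doc in docs]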
# ---------------------------
# QA Chain
# ---------------------------
def retriever_qa(file, query, model_choice, max_tokens, temperature, embedding_model, chunk_size, chunk_overlap):
if not file:
return "Please upload a PDF file first."
if not query.strip():
return "Please enter a query."
try:
selected_model = model_choice or MODEL_OPTIONS[0]
client, model_id, max_tok, temp = get_llm(selected_model, int(max_tokens), float(temperature))
retriever_obj = retriever(file, int(chunk_size), int(chunk_overlap), embedding_model)
# Get relevant documents
docs = retriever_obj.invoke(query)
context = "\n\n".join(doc.page_content for doc in docs)
# Create messages for chat completion
messages = [
{
"role": "system",
"content": "You are a helpful assistant that answers questions based only on the provided context."
},
{
"role": "user",
"content": f"""Context:
{context}
Question: {query}
Please answer the question based only on the context provided above."""
}
]
# Call chat completion API
response = client.chat_completion(
messages=messages,
model=model_id,
max_tokens=max_tok,
temperature=temp
)
return response.choices[0].message.content
except Exception as e:
import traceback
error_details = traceback.format_exc()
return f"Error: {str(e)}\n\nDetails:\n{error_details}"
# ---------------------------
# Gradio Interface
# ---------------------------
with gr.Blocks(title="QA Bot - PDF Question Answering") as demo:
gr.Markdown("# 📄 QA Bot - PDF Question Answering")
gr.Markdown(
"Upload a PDF document and ask questions about its content. "
"Powered by HuggingFace models and LangChain."
)
with gr.Row():
with gr.Column(scale=1):
file_input = gr.File(
label="Upload PDF File",
file_count="single",
file_types=[".pdf"],
type="filepath"
)
query_input = gr.Textbox(
label="Your Question",
lines=3,
placeholder="Ask a question about the uploaded document..."
)
model_dropdown = gr.Dropdown(
label="LLM Model",
choices=MODEL_OPTIONS,
value=MODEL_OPTIONS[0],
)
with gr.Accordion("⚙️ Advanced Settings", open=False):
max_tokens_slider = gr.Slider(
label="Max New Tokens",
minimum=50,
maximum=2048,
value=256,
step=1,
info="Maximum number of tokens in the generated output"
)
temperature_slider = gr.Slider(
label="Temperature",
minimum=0.0,
maximum=2.0,
value=0.8,
step=0.1,
info="Controls randomness/creativity of responses"
)
truncate_slider = gr.Dropdown(
label="Embedding Model",
choices=[
"ai-forever/ru-en-RoSBERTa",
"BAAI/bge-base-en-v1.5",
"BAAI/bge-base-zh-v1.5",
"BAAI/bge-large-en-v1.5",
"BAAI/bge-m3",
"BAAI/bge-small-en-v1.5",
"cointegrated/rubert-tiny2",
"google/embeddinggemma-300m",
"intfloat/multilingual-e5-base",
"intfloat/multilingual-e5-large",
"intfloat/multilingual-e5-small",
"jhgan/ko-sroberta-multitask",
"lokeshch19/ModernPubMedBERT",
"mixedbread-ai/mxbai-embed-large-v1",
"mixedbread-ai/mxbai-embed-xsmall-v1",
"MongoDB/mdbr-leaf-mt",
"pritamdeka/S-Biomed-Roberta-snli-multinli-stsb",
"pritamdeka/S-PubMedBert-MS-MARCO",
"Qwen/Qwen3-Embedding-8B",
"sentence-transformers/all-MiniLM-L6-v2",
"sentence-transformers/all-MiniLM-L12-v2",
"sentence-transformers/all-mpnet-base-v2",
"sentence-transformers/clip-ViT-B-32-multilingual-v1",
"sentence-transformers/LaBSE",
"sentence-transformers/msmarco-MiniLM-L6-v3",
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
"sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
"shibing624/text2vec-base-chinese",
"Snowflake/snowflake-arctic-embed-l-v2.0",
"Snowflake/snowflake-arctic-embed-m-v1.5",
],
value="sentence-transformers/all-MiniLM-L6-v2",
info="Model used for generating embeddings"
)
chunk_size_slider = gr.Slider(
label="Chunk Size",
minimum=100,
maximum=2000,
value=500,
step=50,
info="Size of text chunks for processing"
)
chunk_overlap_slider = gr.Slider(
label="Chunk Overlap",
minimum=0,
maximum=500,
value=50,
step=10,
info="Overlap between consecutive chunks"
)
submit_btn = gr.Button("Ask Question", variant="primary")
with gr.Column(scale=1):
output_text = gr.Textbox(
label="Answer",
lines=15
)
submit_btn.click(
fn=retriever_qa,
inputs=[
file_input,
query_input,
model_dropdown,
max_tokens_slider,
temperature_slider,
truncate_slider,
chunk_size_slider,
chunk_overlap_slider
],
outputs=output_text
)
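    # Note: the inputs list maps positionally onto retriever_qa(file, query,
    # model_choice, max_tokens, temperature, embedding_model, chunk_size,
    # chunk_overlap); despite its name, truncate_slider is the embedding-model dropdown.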
# ---------------------------
# Launch the app
# ---------------------------
if __name__ == "__main__":
demo.launch()