import os
import warnings
from dotenv import load_dotenv
import numpy as np
from sklearn.preprocessing import normalize
# Avoid tokenizers parallelism warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore", category=UserWarning, module="tokenizers")
# Document loading
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Embeddings & vector store
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
# Prompt & Chains
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
# LLM
from langchain_community.chat_models import ChatOpenAI
# Gradio
import gradio as gr
# -----------------------------
# Configuration
# -----------------------------
PDF_PATH = "data/pdfs/Stream-Processing-with-Apache-Flink.pdf"
CHUNK_SIZE = 512
CHUNK_OVERLAP = 50
TOP_K = 3
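# Assumed .env entries (not shown in this repo): OPENAI_API_KEY, and
# OPENAI_API_BASE when pointing ChatOpenAI at an OpenAI-compatible
# endpoint such as DeepSeek.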
# -----------------------------
# 1️⃣ Load environment variables
# -----------------------------
load_dotenv()
print("✅ Environment ready")
# -----------------------------
# 2️⃣ Load the PDF document
# -----------------------------
loader = PyPDFLoader(PDF_PATH)
documents = loader.load()
print(f"✅ Loaded {len(documents)} pages")
# -----------------------------
# 3️⃣ Split the text
# -----------------------------
text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
texts = text_splitter.split_documents(documents)
print(f"✅ Split into {len(texts)} chunks")
# -----------------------------
# 4️⃣ Compute embeddings & build the vector store
# -----------------------------
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
# Pre-compute embeddings so they can be L2-normalized
vectors = embedding_model.embed_documents([doc.page_content for doc in texts])
# Normalize so L2 distance in the flat index ranks like cosine similarity
vectors = normalize(np.array(vectors))
# Create the FAISS vector store (note: this embeds the texts a second time)
vector_store = FAISS.from_texts(
    [doc.page_content for doc in texts],
    embedding_model,
    metadatas=[doc.metadata for doc in texts]
)
# Swap in the normalized vectors; this touches FAISS internals and only works
# because the insertion order (and thus index_to_docstore_id) is unchanged
vector_store.index.reset()
vector_store.index.add(vectors.astype(np.float32))
print("✅ Embeddings created, normalized and FAISS index ready")
# -----------------------------
# 5️⃣ Retriever
# -----------------------------
retriever = vector_store.as_retriever(search_kwargs={"k": TOP_K})
print("✅ Retriever ready")
# -----------------------------
# 6️⃣ LLM
# -----------------------------
llm = ChatOpenAI(
    model_name="deepseek-chat",  # or "gpt-3.5-turbo"; "deepseek-chat" needs an
                                 # OpenAI-compatible base URL (e.g. OPENAI_API_BASE in .env)
    temperature=0.7,
    max_tokens=512
)
print("✅ LLM ready")
# -----------------------------
# 7️⃣ Prompt template
# -----------------------------
template = """
Use the following context to answer the question. If unsure, say "I don't know."
Context:
{context}
Question: {question}
Answer:
"""
prompt = PromptTemplate(template=template, input_variables=["context", "question"])
# -----------------------------
# 8️⃣ Build the RetrievalQA chain
# -----------------------------
rag_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type_kwargs={"prompt": prompt},
    return_source_documents=True
)
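# Single-turn usage sketch (illustrative question, commented out):
# result = rag_chain({"query": "How does Flink handle event-time watermarks?"})
# print(result["result"])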
# -----------------------------
# 9️⃣ Conversation memory
# -----------------------------
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True,
    output_key="answer"
)
# -----------------------------
# 10️⃣ Conversational RAG chain
# -----------------------------
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=retriever,
    memory=memory,
    verbose=False
)
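# Note: the Gradio handler below uses the single-turn rag_chain; qa_chain
# keeps chat history via memory for multi-turn use, e.g. (commented out):
# out = qa_chain({"question": "What is a keyed stream?"})
# out = qa_chain({"question": "How does it interact with state?"})  # follow-up
# print(out["answer"])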
# -----------------------------
# 11️⃣ Gradio Q&A function
# -----------------------------
def answer_question(query, threshold=0.4):
    # FAISS has no built-in score-threshold filtering here, so retrieve
    # TOP_K documents and filter them manually afterwards
    result = rag_chain({"query": query})
    answer = result["result"]
    sources = result.get("source_documents", [])
    # Embed the query once (not per document) and normalize it
    query_emb = np.array(embedding_model.embed_query(query))
    query_emb = query_emb / np.linalg.norm(query_emb)
    # Compute cosine similarity per source document and apply the threshold
    filtered_sources = []
    for doc in sources:
        emb = np.array(embedding_model.embed_documents([doc.page_content])[0])
        emb = emb / np.linalg.norm(emb)
        score = float(np.dot(emb, query_emb))
        if score >= threshold:
            filtered_sources.append((doc.page_content, score))
    # Render the retrieved sources with their scores
    context = "\n\n".join(f"Score: {score:.4f}\n{text[:400]}..." for text, score in filtered_sources)
    return answer, context
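# Local smoke test (illustrative, commented out):
# ans, ctx = answer_question("What are Flink's state backends?", threshold=0.3)
# print(ans)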
# -----------------------------
# 12️⃣ Gradio interface
# -----------------------------
demo = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Textbox(label="🔎 Enter your question"),
        gr.Slider(0.0, 1.0, value=0.4, step=0.05, label="Similarity threshold")
    ],
    outputs=[
        gr.Textbox(label="💬 Model answer"),
        gr.Textbox(label="📄 Retrieved documents")
    ],
    title="📘 Multi-PDF RAG System"
)
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)