import os
import warnings

from dotenv import load_dotenv
import numpy as np
from sklearn.preprocessing import normalize

# Silence tokenizers parallelism warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore", category=UserWarning, module="tokenizers")

# Document loading
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Embeddings & vector store
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

# Prompt & chains
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

# LLM
from langchain_community.chat_models import ChatOpenAI

# Gradio
import gradio as gr

# -----------------------------
# Configuration
# -----------------------------
PDF_PATH = "data/pdfs/Stream-Processing-with-Apache-Flink.pdf"
CHUNK_SIZE = 512
CHUNK_OVERLAP = 50
TOP_K = 3

# -----------------------------
# 1️⃣ Load environment variables
# -----------------------------
# Expects OPENAI_API_KEY (and, when targeting DeepSeek, an OpenAI-compatible
# base URL) to be provided via .env.
load_dotenv()
print("✅ Environment ready")

# -----------------------------
# 2️⃣ Load the PDF
# -----------------------------
loader = PyPDFLoader(PDF_PATH)
documents = loader.load()
print(f"✅ Loaded {len(documents)} pages")

# -----------------------------
# 3️⃣ Split the text
# -----------------------------
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
)
texts = text_splitter.split_documents(documents)
print(f"✅ Split into {len(texts)} chunks")

# -----------------------------
# 4️⃣ Embeddings & FAISS index
# -----------------------------
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2"
)

# Compute the embeddings up front
vectors = embedding_model.embed_documents([doc.page_content for doc in texts])

# L2-normalize so inner products equal cosine similarity
vectors = normalize(np.array(vectors))

# Build the FAISS vector store
vector_store = FAISS.from_texts(
    [doc.page_content for doc in texts],
    embedding_model,
    metadatas=[doc.metadata for doc in texts],
)

# Swap in the normalized vectors. The insertion order matches `texts`, so the
# index-to-docstore mapping built by `from_texts` stays valid.
vector_store.index.reset()
vector_store.index.add(vectors.astype(np.float32))
print("✅ Embeddings created, normalized and FAISS index ready")

# -----------------------------
# 5️⃣ Retriever
# -----------------------------
retriever = vector_store.as_retriever(search_kwargs={"k": TOP_K})
print("✅ Retriever ready")
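# -----------------------------
# Optional: score-threshold filtering straight from FAISS (a sketch)
# -----------------------------
# The Gradio handler below re-embeds every retrieved chunk to compute cosine
# scores. A cheaper alternative, sketched here, reuses the squared L2
# distance d that the flat index already returns: for unit-norm vectors,
# cosine = 1 - d / 2. The helper name is ours, and the formula assumes the
# query embedding is also unit-norm (e.g. HuggingFaceEmbeddings with
# encode_kwargs={"normalize_embeddings": True}); the document vectors were
# normalized above.
def sources_with_cosine(query: str, threshold: float = 0.4):
    hits = vector_store.similarity_search_with_score(query, k=TOP_K)
    return [
        (doc.page_content, 1.0 - float(dist) / 2.0)
        for doc, dist in hits
        if 1.0 - float(dist) / 2.0 >= threshold
    ]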
# -----------------------------
# 6️⃣ LLM
# -----------------------------
llm = ChatOpenAI(
    model_name="deepseek-chat",  # or "gpt-3.5-turbo"
    temperature=0.7,
    max_tokens=512,
)
print("✅ LLM ready")

# -----------------------------
# 7️⃣ Prompt template
# -----------------------------
template = """
Use the following context to answer the question.
If unsure, say "I don't know."

Context: {context}

Question: {question}

Answer: """
prompt = PromptTemplate(template=template, input_variables=["context", "question"])

# -----------------------------
# 8️⃣ Build the RetrievalQA chain
# -----------------------------
rag_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type_kwargs={"prompt": prompt},
    return_source_documents=True,
)

# -----------------------------
# 9️⃣ Conversation memory
# -----------------------------
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True,
    output_key="answer",
)

# -----------------------------
# 10️⃣ Conversational RAG chain
# -----------------------------
# Kept available for multi-turn use; the Gradio handler below uses the
# single-turn rag_chain.
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=retriever,
    memory=memory,
    verbose=False,
)

# -----------------------------
# 11️⃣ Gradio Q&A function
# -----------------------------
def answer_question(query, threshold=0.4):
    # FAISS applies no score threshold here, so retrieve TOP_K and filter manually
    result = rag_chain.invoke({"query": query})
    answer = result["result"]
    sources = result.get("source_documents", [])

    # Embed the query once, outside the loop
    query_emb = np.asarray(embedding_model.embed_query(query))
    query_emb = query_emb / np.linalg.norm(query_emb)

    # Compute cosine similarity for each source and apply the threshold
    filtered_sources = []
    for doc in sources:
        emb = np.asarray(embedding_model.embed_documents([doc.page_content])[0])
        emb = emb / np.linalg.norm(emb)
        score = float(np.dot(emb, query_emb))
        if score >= threshold:
            filtered_sources.append((doc.page_content, score))

    # Render the source documents
    context = "\n\n".join(
        f"Score: {score:.4f}\n{text[:400]}..." for text, score in filtered_sources
    )
    return answer, context

# -----------------------------
# 12️⃣ Gradio UI
# -----------------------------
demo = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Textbox(label="🔎 Enter your question"),
        gr.Slider(0.0, 1.0, value=0.4, step=0.05, label="Similarity threshold"),
    ],
    outputs=[
        gr.Textbox(label="💬 Model answer"),
        gr.Textbox(label="📄 Retrieved documents"),
    ],
    title="📘 PDF RAG System",
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
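# -----------------------------
# Optional: CLI smoke test (a sketch; the sample question is an assumption)
# -----------------------------
# Uncomment to exercise the pipeline without the web UI:
# answer, context = answer_question("How does Flink handle event time?", threshold=0.4)
# print(answer)
# print(context)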