import os
import warnings
from dotenv import load_dotenv
import numpy as np
from sklearn.preprocessing import normalize
# Avoid tokenizers parallelism warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore", category=UserWarning, module="tokenizers")
# Document loading
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Embeddings & vector store
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
# Prompt & chains
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
# LLM (newer LangChain releases deprecate this path in favor of langchain_openai.ChatOpenAI)
from langchain_community.chat_models import ChatOpenAI
# Gradio
import gradio as gr
# -----------------------------
# Configuration
# -----------------------------
PDF_PATH = "data/pdfs/Stream-Processing-with-Apache-Flink.pdf"
CHUNK_SIZE = 512        # characters per chunk
CHUNK_OVERLAP = 50      # characters shared between consecutive chunks
TOP_K = 3               # number of chunks to retrieve per query
# -----------------------------
# 1️⃣ Load environment variables
# -----------------------------
load_dotenv()
print("✅ Environment ready")
# -----------------------------
# 2️⃣ Load the PDF document
# -----------------------------
loader = PyPDFLoader(PDF_PATH)
documents = loader.load()  # one Document per page
print(f"✅ Loaded {len(documents)} pages")
# -----------------------------
# 3️⃣ Split the text
# -----------------------------
text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
texts = text_splitter.split_documents(documents)
print(f"✅ Split into {len(texts)} chunks")
# -----------------------------
# 4️⃣ Create embeddings & vector store
# -----------------------------
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
# Compute the document embeddings up front
vectors = embedding_model.embed_documents([doc.page_content for doc in texts])
# L2-normalize them so that L2 distance ranks results the same way as cosine similarity
vectors = normalize(np.array(vectors))
# Build the FAISS store (note: from_texts embeds the texts a second time;
# see the alternative sketch below for a way around both redundancies)
vector_store = FAISS.from_texts(
    [doc.page_content for doc in texts],
    embedding_model,
    metadatas=[doc.metadata for doc in texts]
)
# Swap in the normalized vectors. They are added in the same order from_texts used,
# so index positions still line up with index_to_docstore_id.
vector_store.index.reset()
vector_store.index.add(vectors.astype(np.float32))
print("✅ Embeddings created, normalized and FAISS index ready")
# -----------------------------
# 5️⃣ Retriever
# -----------------------------
retriever = vector_store.as_retriever(search_kwargs={"k": TOP_K})
print("✅ Retriever ready")
# -----------------------------
# 6️⃣ LLM
# -----------------------------
llm = ChatOpenAI(
    model_name="deepseek-chat",  # or "gpt-3.5-turbo"; assumes the matching API key/base URL are in .env
    temperature=0.7,
    max_tokens=512
)
print("✅ LLM ready")
# -----------------------------
# 7️⃣ Prompt template
# -----------------------------
template = """
Use the following context to answer the question. If unsure, say "I don't know."
Context:
{context}
Question: {question}
Answer:
"""
prompt = PromptTemplate(template=template, input_variables=["context", "question"])
# -----------------------------
# 8️⃣ Build the RetrievalQA chain
# -----------------------------
rag_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type_kwargs={"prompt": prompt},  # inject the custom prompt into the default "stuff" chain
    return_source_documents=True
)
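# Example single-turn invocation (illustrative question; makes a real LLM call):
#   result = rag_chain.invoke({"query": "How does Flink handle watermarks?"})
#   print(result["result"])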
# -----------------------------
# 9️⃣ Build conversation memory
# -----------------------------
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True,
    output_key="answer"
)
# -----------------------------
# 10️⃣ Multi-turn conversational RAG chain
# -----------------------------
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=retriever,
    memory=memory,
    verbose=False
)
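# The Gradio UI below only calls the single-turn rag_chain; qa_chain keeps chat
# history in memory and is available for multi-turn use, e.g. (illustrative):
#   out = qa_chain.invoke({"question": "What is a checkpoint in Flink?"})
#   print(out["answer"])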
# -----------------------------
# 11️⃣ Gradio answer function
# -----------------------------
def answer_question(query, threshold=0.4):
    # FAISS exposes no direct threshold filter here, so retrieve TOP_K and filter manually
    result = rag_chain.invoke({"query": query})
    answer = result["result"]
    sources = result.get("source_documents", [])
    # Embed the query once (not once per document) and L2-normalize it
    query_emb = np.asarray(embedding_model.embed_query(query))
    query_emb = query_emb / np.linalg.norm(query_emb)
    # Compute cosine similarity for each source document and apply the threshold
    filtered_sources = []
    for doc in sources:
        emb = np.asarray(embedding_model.embed_documents([doc.page_content])[0])
        emb = emb / np.linalg.norm(emb)
        score = float(np.dot(emb, query_emb))
        if score >= threshold:
            filtered_sources.append((doc.page_content, score))
    # Format the retrieved sources for display
    context = "\n\n".join(f"Score: {score:.4f}\n{text[:400]}..." for text, score in filtered_sources)
    return answer, context
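# Alternative sketch (not used here): the vector store can score results directly,
# avoiding the per-document re-embedding above; for the IndexFlatL2 index built
# earlier, a lower returned distance means a closer match.
#   pairs = vector_store.similarity_search_with_score(query, k=TOP_K)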
# -----------------------------
# 12️⃣ Gradio interface
# -----------------------------
demo = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Textbox(label="🔎 Enter your question"),
        gr.Slider(0.0, 1.0, value=0.4, step=0.05, label="Similarity threshold")
    ],
    outputs=[
        gr.Textbox(label="💬 Model answer"),
        gr.Textbox(label="📄 Retrieved documents")
    ],
    title="📘 Multi-PDF RAG System"
)
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)