import os
import warnings
from dotenv import load_dotenv

import numpy as np
from sklearn.preprocessing import normalize

# Avoid tokenizers parallelism warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore", category=UserWarning, module="tokenizers")

# Document loading
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Embeddings & vector store
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

# Prompt & Chains
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

# LLM
from langchain_community.chat_models import ChatOpenAI

# Gradio
import gradio as gr

# -----------------------------
# Configuration
# -----------------------------
PDF_PATH = "data/pdfs/Stream-Processing-with-Apache-Flink.pdf"
CHUNK_SIZE = 512
CHUNK_OVERLAP = 50
TOP_K = 3

# -----------------------------
# 1️⃣ Load environment variables
# -----------------------------
load_dotenv()
print("✅ Environment ready")

# -----------------------------
# 2️⃣ Load the PDF document
# -----------------------------
loader = PyPDFLoader(PDF_PATH)
documents = loader.load()
print(f"✅ Loaded {len(documents)} pages")

# -----------------------------
# 3️⃣ Split the text into chunks
# -----------------------------
text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
texts = text_splitter.split_documents(documents)
print(f"✅ Split into {len(texts)} chunks")

# -----------------------------
# 4️⃣ Create embeddings & the vector store
# -----------------------------
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
# Compute the embeddings up front so they can be L2-normalized below
vectors = embedding_model.embed_documents([doc.page_content for doc in texts])
# Row-wise L2 normalization: with unit vectors, L2 distance ranks results the same as cosine similarity
vectors = normalize(np.array(vectors))

# Build the FAISS vector store (note: this embeds the texts a second time internally)
vector_store = FAISS.from_texts(
    [doc.page_content for doc in texts],
    embedding_model,
    metadatas=[doc.metadata for doc in texts]
)

# Swap the index contents for the normalized vectors; the docstore id mapping
# still lines up because the same number of vectors is re-added in the same order
vector_store.index.reset()
vector_store.index.add(vectors.astype(np.float32))
print("✅ Embeddings created, normalized and FAISS index ready")

# -----------------------------
# 5️⃣ Retriever
# -----------------------------
retriever = vector_store.as_retriever(search_kwargs={"k": TOP_K})
print("✅ Retriever ready")

# -----------------------------
# 6️⃣ LLM
# -----------------------------
llm = ChatOpenAI(
    model_name="deepseek-chat",  # 或 "gpt-3.5-turbo"
    temperature=0.7,
    max_tokens=512
)
print("✅ LLM ready")

# -----------------------------
# 7️⃣ Prompt template
# -----------------------------
template = """
Use the following context to answer the question. If unsure, say "I don't know."
Context:
{context}
Question: {question}
Answer:
"""
prompt = PromptTemplate(template=template, input_variables=["context", "question"])

# -----------------------------
# 8️⃣ Build the RetrievalQA chain
# -----------------------------
rag_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type_kwargs={"prompt": prompt},
    return_source_documents=True
)
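
# Quick smoke test (the sample question is hypothetical, not from the source):
#   result = rag_chain.invoke({"query": "What is event time in Flink?"})
#   print(result["result"])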

# -----------------------------
# 9️⃣ Conversation memory
# -----------------------------
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True,
    output_key="answer"
)

# -----------------------------
# 10️⃣ Conversational RAG chain
# -----------------------------
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=retriever,
    memory=memory,
    verbose=False
)
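
# Note: the Gradio UI below uses the single-turn rag_chain; this conversational
# chain is built but never wired into the interface. A minimal multi-turn call
# (input/output keys per ConversationalRetrievalChain's defaults):
#   out = qa_chain.invoke({"question": "What is a watermark?"})  # hypothetical sample question
#   print(out["answer"])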

# -----------------------------
# 11️⃣ Gradio Q&A function
# -----------------------------
def answer_question(query, threshold=0.4):
    # FAISS is not asked to threshold at query time here, so retrieve TOP_K
    # chunks through the chain and filter the returned sources manually
    result = rag_chain.invoke({"query": query})
    answer = result["result"]
    sources = result.get("source_documents", [])

    # Embed the query once (not once per document) and L2-normalize it
    query_emb = np.asarray(embedding_model.embed_query(query))
    query_emb = query_emb / np.linalg.norm(query_emb)

    # Compute cosine similarity for each source chunk and apply the threshold
    filtered_sources = []
    for doc in sources:
        emb = np.asarray(embedding_model.embed_documents([doc.page_content])[0])
        emb = emb / np.linalg.norm(emb)
        score = float(np.dot(emb, query_emb))
        if score >= threshold:
            filtered_sources.append((doc.page_content, score))

    # Format the retained sources for display
    context = "\n\n".join(f"Score: {score:.4f}\n{text[:400]}..." for text, score in filtered_sources)
    return answer, context

# -----------------------------
# 12️⃣ Gradio interface
# -----------------------------
demo = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Textbox(label="🔎 输入你的问题"),
        gr.Slider(0.0, 1.0, value=0.4, step=0.05, label="相似度阈值")
    ],
    outputs=[
        gr.Textbox(label="💬 模型回答"),
        gr.Textbox(label="📄 检索到的文档")
    ],
    title="📘 Multi-PDF RAG System"
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)