siyu618 committed
Commit d0af142 · verified · 1 Parent(s): e188610

Upload 2 files

Files changed (2)
  1. .env +2 -0
  2. app.py +162 -18
.env ADDED
@@ -0,0 +1,2 @@
+ OPENAI_API_KEY=sk-71c364e42b4545ceba5e0b5b9f71df08
+ OPENAI_API_BASE=https://api.deepseek.com/v1
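Note: a minimal sketch (not part of this commit) of how these two variables get consumed, assuming python-dotenv and an OpenAI-compatible client. DeepSeek exposes an OpenAI-compatible endpoint, so pointing the base URL at api.deepseek.com/v1 reroutes the client without code changes:

    import os
    from dotenv import load_dotenv

    load_dotenv()  # reads .env from the working directory

    # The legacy ChatOpenAI wrapper used in app.py picks both variables up from
    # the environment, so neither the key nor the base URL appears in code.
    assert os.getenv("OPENAI_API_KEY"), "OPENAI_API_KEY missing from environment"
    print("API base:", os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"))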
app.py CHANGED
@@ -1,35 +1,179 @@
- import pickle
  import gradio as gr
- from config.rag_config import RAGConfig
- from src.rag_pipeline import RAGPipeline

- config = RAGConfig()

- # Load the vector store
- with open(config.vector_db_path, "rb") as f:
-     data = pickle.load(f)

- docs, doc_embeddings = data["texts"], data["embeddings"]
- pipeline = RAGPipeline(config, docs, doc_embeddings)

- def answer_question(query, threshold):
-     pipeline.config.similarity_threshold = threshold
-     answer, retrieved = pipeline.ask(query)
-     context = "\n\n".join([f"Score: {s:.4f}\n{t}" for t, s in retrieved])
      return answer, context

  demo = gr.Interface(
      fn=answer_question,
      inputs=[
-         gr.Textbox(label="Enter your question"),
-         gr.Slider(0.0, 1.0, value=0.4, step=0.05, label="Similarity Threshold")
      ],
      outputs=[
-         gr.Textbox(label="Answer"),
-         gr.Textbox(label="Retrieved Contexts")
      ],
      title="📘 Multi-PDF RAG System"
  )

  if __name__ == "__main__":
-     demo.launch()

+ import os
+ import warnings
+ from dotenv import load_dotenv
+
+ import numpy as np
+ from sklearn.preprocessing import normalize
+
+ # Suppress tokenizers parallelism warnings
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ warnings.filterwarnings("ignore", category=UserWarning, module="tokenizers")
+
+ # Document loading
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+ # Embeddings & vector store
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import FAISS
+
+ # Prompt & chains
+ from langchain.prompts import PromptTemplate
+ from langchain.chains import RetrievalQA, ConversationalRetrievalChain
+ from langchain.memory import ConversationBufferMemory
+
+ # LLM
+ from langchain_community.chat_models import ChatOpenAI
+
+ # Gradio
  import gradio as gr

+ # -----------------------------
+ # Configuration
+ # -----------------------------
+ PDF_PATH = "pdfs/Stream-Processing-with-Apache-Flink.pdf"
+ CHUNK_SIZE = 512
+ CHUNK_OVERLAP = 50
+ TOP_K = 3
+
+ # -----------------------------
+ # 1️⃣ Load environment variables
+ # -----------------------------
+ load_dotenv()
+ print("✅ Environment ready")
+
+ # -----------------------------
+ # 2️⃣ Load the PDF document
+ # -----------------------------
+ loader = PyPDFLoader(PDF_PATH)
+ documents = loader.load()
+ print(f"✅ Loaded {len(documents)} pages")
+
+ # -----------------------------
+ # 3️⃣ Split the text
+ # -----------------------------
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
+ texts = text_splitter.split_documents(documents)
+ print(f"✅ Split into {len(texts)} chunks")
+
+ # -----------------------------
+ # 4️⃣ Generate embeddings & build the vector store
+ # -----------------------------
+ embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
+ # Compute the embeddings first
+ vectors = embedding_model.embed_documents([doc.page_content for doc in texts])
+ # L2-normalize them
+ vectors = normalize(np.array(vectors))
+
+ # Create the FAISS vector store
+ vector_store = FAISS.from_texts(
+     [doc.page_content for doc in texts],
+     embedding_model,
+     metadatas=[doc.metadata for doc in texts]
+ )
+
+ # Replace the index contents with the normalized vectors
+ vector_store.index.reset()
+ vector_store.index.add(vectors.astype(np.float32))
+ print("✅ Embeddings created, normalized and FAISS index ready")
+
+ # -----------------------------
+ # 5️⃣ Retriever
+ # -----------------------------
+ retriever = vector_store.as_retriever(search_kwargs={"k": TOP_K})
+ print("✅ Retriever ready")
+
+ # -----------------------------
+ # 6️⃣ LLM
+ # -----------------------------
+ llm = ChatOpenAI(
+     model_name="deepseek-chat",  # or "gpt-3.5-turbo"
+     temperature=0.7,
+     max_tokens=512
+ )
+ print("✅ LLM ready")
+
+ # -----------------------------
+ # 7️⃣ Prompt template
+ # -----------------------------
+ template = """
+ Use the following context to answer the question. If unsure, say "I don't know."
+ Context:
+ {context}
+ Question: {question}
+ Answer:
+ """
+ prompt = PromptTemplate(template=template, input_variables=["context", "question"])
+
+ # -----------------------------
+ # 8️⃣ Build the RetrievalQA chain
+ # -----------------------------
+ rag_chain = RetrievalQA.from_chain_type(
+     llm=llm,
+     retriever=retriever,
+     chain_type_kwargs={"prompt": prompt},
+     return_source_documents=True
+ )
+
+ # -----------------------------
+ # 9️⃣ Conversation memory
+ # -----------------------------
+ memory = ConversationBufferMemory(
+     memory_key="chat_history",
+     return_messages=True,
+     output_key="answer"
+ )
+
+ # -----------------------------
+ # 10️⃣ Conversational RAG chain
+ # -----------------------------
+ # NOTE: qa_chain is built here but not yet wired into the Gradio handler
+ # below, which uses the single-turn rag_chain.
+ qa_chain = ConversationalRetrievalChain.from_llm(
+     llm=llm,
+     retriever=retriever,
+     memory=memory,
+     verbose=False
+ )

+ # -----------------------------
+ # 11️⃣ Gradio QA function
+ # -----------------------------
+ def answer_question(query, threshold=0.4):
+     # FAISS offers no direct threshold filtering here, so retrieve TOP_K first and filter manually
+     result = rag_chain({"query": query})
+     answer = result["result"]
+     sources = result.get("source_documents", [])

+     # Embed the query once, then score each source by cosine similarity and apply the threshold
+     query_emb = np.array(embedding_model.embed_query(query))
+     query_emb = query_emb / np.linalg.norm(query_emb)
+     filtered_sources = []
+     for doc in sources:
+         emb = np.array(embedding_model.embed_documents([doc.page_content])[0])
+         emb = emb / np.linalg.norm(emb)
+         score = float(np.dot(emb, query_emb))
+         if score >= threshold:
+             filtered_sources.append((doc.page_content, score))

+     # Show the source documents
+     context = "\n\n".join([f"Score: {score:.4f}\n{doc[:400]}..." for doc, score in filtered_sources])

      return answer, context

+ # -----------------------------
+ # 12️⃣ Gradio interface
+ # -----------------------------
  demo = gr.Interface(
      fn=answer_question,
      inputs=[
+         gr.Textbox(label="🔎 Enter your question"),
+         gr.Slider(0.0, 1.0, value=0.4, step=0.05, label="Similarity threshold")
      ],
      outputs=[
+         gr.Textbox(label="💬 Model answer"),
+         gr.Textbox(label="📄 Retrieved documents")
      ],
      title="📘 Multi-PDF RAG System"
  )

  if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860)
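
Note: the manual cosine filter in answer_question re-embeds every source document on each query. A sketch of an alternative (an assumption, not part of this commit): LangChain retrievers accept a similarity_score_threshold search type that pushes the cutoff into the vector store itself:

    # Sketch only: let the retriever drop hits below the cutoff instead of
    # re-embedding sources per query. Assumes LangChain's
    # "similarity_score_threshold" search type; 0.4 mirrors the UI default.
    retriever = vector_store.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"k": TOP_K, "score_threshold": 0.4},
    )
    docs = retriever.get_relevant_documents("What is event time in Flink?")  # example query
    for doc in docs:
        print(doc.page_content[:120])

One caveat: the score scale depends on the underlying FAISS index (FAISS.from_texts defaults to L2 distance, which LangChain converts to a relevance score), so the threshold may need retuning against the slider's cosine scale.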