aamirhameed committed
Commit cdd24a3 · verified · 1 Parent(s): 5304fdb

Update knowledge_engine.py

Files changed (1):
  1. knowledge_engine.py +31 -21
knowledge_engine.py CHANGED
@@ -1,10 +1,11 @@
 import os
 import pickle
-from typing import Dict, Any
+from typing import Dict, Any, List
 from datetime import datetime
 from concurrent.futures import ThreadPoolExecutor
 
 from config import Config
+from langchain_core.documents import Document
 from langchain_community.document_loaders import TextLoader, DirectoryLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import FAISS
@@ -12,7 +13,7 @@ from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
 from langchain.retrievers import BM25Retriever
 from langchain_community.embeddings import HuggingFaceEmbeddings
-from langchain_community.llms import HuggingFaceHub
+from langchain_huggingface import HuggingFaceEndpoint
 
 class KnowledgeManager:
     def __init__(self):
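
This hunk tracks LangChain's deprecation of the community `HuggingFaceHub` wrapper in favor of the `langchain-huggingface` partner package. A minimal sketch of the migration, assuming the package is installed:

```python
# Assumes: pip install langchain-huggingface
# Deprecated community wrapper:
#   from langchain_community.llms import HuggingFaceHub
# Maintained partner-package replacement used by this commit:
from langchain_huggingface import HuggingFaceEndpoint
```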
@@ -23,18 +24,18 @@ class KnowledgeManager:
 
     def _init_embeddings(self):
         print("[i] Using Hugging Face embeddings")
-        return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+        return HuggingFaceEmbeddings(
+            model_name="sentence-transformers/all-mpnet-base-v2",
+            model_kwargs={'device': 'cpu'}
+        )
 
     def _init_llm(self):
-        print("[i] Using HuggingFaceHub with Mistral-7B")
-        return HuggingFaceHub(
+        print("[i] Using HuggingFaceEndpoint with Mistral-7B")
+        return HuggingFaceEndpoint(
             repo_id="mistralai/Mistral-7B-Instruct-v0.1",
-            huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
-            model_kwargs={
-                "temperature": 0.1,
-                "max_new_tokens": 512,
-                "do_sample": True
-            }
+            temperature=0.1,
+            max_length=512,
+            token=os.getenv("HUGGINGFACEHUB_API_TOKEN")
         )
 
     def _init_retrievers(self):
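
The new `_init_llm` passes generation settings as constructor arguments instead of a `model_kwargs` dict. One caveat worth hedging: in `langchain-huggingface` the documented parameter names are `max_new_tokens` and `huggingfacehub_api_token`; `max_length` and `token` as written are not declared fields and appear to get forwarded into `model_kwargs`. A sketch using the documented names:

```python
import os
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFaceEndpoint

# Sketch only; assumes HUGGINGFACEHUB_API_TOKEN is set in the environment.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2",  # 768-dim, heavier but stronger than MiniLM
    model_kwargs={"device": "cpu"},
)

llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.1",
    temperature=0.1,
    max_new_tokens=512,  # documented name for the output-length cap
    huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
)
```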
@@ -76,7 +77,11 @@
         )
         chunks = splitter.split_documents(docs)
 
-        vector_db = FAISS.from_documents(chunks, self.embeddings)
+        vector_db = FAISS.from_documents(
+            chunks,
+            self.embeddings,
+            distance_strategy="COSINE"  # Ensures scores between 0-1
+        )
         vector_db.save_local(str(Config.VECTOR_STORE_PATH))
 
         bm25_retriever = BM25Retriever.from_documents(chunks)
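
The `distance_strategy="COSINE"` argument is what makes the 0.83 `score_threshold` used later in `_parallel_retrieve` meaningful: with cosine similarity the relevance scores are normalized, whereas the default L2 strategy yields unbounded distances. An equivalent, slightly more type-safe construction, assuming the same `chunks` and `embeddings` objects:

```python
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy

# DistanceStrategy is a str-based Enum, so this is equivalent to the raw
# "COSINE" string passed in the commit.
vector_db = FAISS.from_documents(
    chunks,
    embeddings,
    distance_strategy=DistanceStrategy.COSINE,
)
```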
@@ -92,16 +97,19 @@
         with open(Config.KNOWLEDGE_DIR / "sirraya_xbrain.txt", "w", encoding="utf-8") as f:
             f.write(default_text)
 
-    def _parallel_retrieve(self, question: str):
+    def _parallel_retrieve(self, question: str) -> List[Document]:
         def retrieve_with_bm25():
-            return self.bm25_retriever.get_relevant_documents(question)
+            return self.bm25_retriever.invoke(question)  # Updated to use invoke()
 
         def retrieve_with_vector():
             retriever = self.vector_db.as_retriever(
                 search_type="similarity_score_threshold",
-                search_kwargs={"k": Config.MAX_CONTEXT_CHUNKS, "score_threshold": 0.83}
+                search_kwargs={
+                    "k": Config.MAX_CONTEXT_CHUNKS,
+                    "score_threshold": 0.83
+                }
             )
-            return retriever.get_relevant_documents(question)
+            return retriever.invoke(question)  # Updated to use invoke()
 
         with ThreadPoolExecutor(max_workers=2) as executor:
             bm25_future = executor.submit(retrieve_with_bm25)
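
`get_relevant_documents` is deprecated in recent LangChain releases in favor of the Runnable `invoke` interface, which is the substance of this hunk. The hunk ends before the result-merging code, so the following dedup-and-merge step is only an illustrative pattern, not the committed logic:

```python
from concurrent.futures import ThreadPoolExecutor

# Hypothetical merge of the two retrievers' results (not shown in the
# commit): run both lookups in parallel, then deduplicate by page content.
def merge_retrievals(retrieve_with_bm25, retrieve_with_vector):
    with ThreadPoolExecutor(max_workers=2) as executor:
        bm25_docs = executor.submit(retrieve_with_bm25).result()
        vector_docs = executor.submit(retrieve_with_vector).result()

    seen, merged = set(), []
    for doc in bm25_docs + vector_docs:
        if doc.page_content not in seen:
            seen.add(doc.page_content)
            merged.append(doc)
    return merged
```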
@@ -132,7 +140,7 @@ Answer:"""
         return RetrievalQA.from_chain_type(
             llm=self._init_llm(),
             chain_type="stuff",
-            retriever=self.vector_db.as_retriever(search_kwargs={"k": 1}),
+            retriever=self.vector_db.as_retriever(search_kwargs={"k": Config.MAX_CONTEXT_CHUNKS}),
             chain_type_kwargs={
                 "prompt": PromptTemplate(
                     template=prompt_template,
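
Bumping `k` from a single chunk to `Config.MAX_CONTEXT_CHUNKS` means the "stuff" chain now concatenates several chunks into one prompt, so the value has to stay within the model's context budget. `config.py` is not part of this diff; a purely hypothetical shape for the values referenced here:

```python
from pathlib import Path

# Hypothetical Config for illustration only; the real config.py is not
# included in this commit.
class Config:
    MAX_CONTEXT_CHUNKS = 4                    # chunks stuffed into the prompt
    VECTOR_STORE_PATH = Path("vector_store")  # FAISS save_local target
    KNOWLEDGE_DIR = Path("knowledge")         # source .txt documents
```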
@@ -155,10 +163,12 @@ Answer:"""
             docs = self._parallel_retrieve(question)
 
             if not docs:
-                retriever = self.vector_db.as_retriever(search_kwargs={"k": Config.MAX_CONTEXT_CHUNKS})
-                docs = retriever.get_relevant_documents(question)
+                retriever = self.vector_db.as_retriever(
+                    search_kwargs={"k": Config.MAX_CONTEXT_CHUNKS}
+                )
+                docs = retriever.invoke(question)  # Updated to use invoke()
 
-            result = self.qa_chain.invoke({"input_documents": docs, "query": question})
+            result = self.qa_chain.invoke({"query": question, "input_documents": docs})
             processing_time = (datetime.now() - start_time).total_seconds() * 1000
 
             return {
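
One subtlety: the legacy `RetrievalQA` chain fetches documents through the retriever it was constructed with, so the `input_documents` key passed to `invoke` here is ignored; only `"query"` is consumed. To actually reuse the documents from `_parallel_retrieve`, the combine-documents chain can be driven directly. A sketch, assuming the same `llm`, `prompt`, `docs`, and `question` objects as elsewhere in the file:

```python
from langchain.chains.question_answering import load_qa_chain

# Feed pre-retrieved documents straight into a "stuff" chain, bypassing
# RetrievalQA's internal retrieval step.
qa = load_qa_chain(llm, chain_type="stuff", prompt=prompt)
result = qa.invoke({"input_documents": docs, "question": question})
answer = result["output_text"]
```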
@@ -169,7 +179,7 @@ Answer:"""
         except Exception as e:
             print(f"[!] Query error: {e}")
             return {
-                "answer": f"Error: {e}",
+                "answer": f"Error processing your query: {str(e)}",
                 "processing_time": 0,
                 "source_chunks": []
             }
 