aamirhameed committed on
Commit
d4ae976
·
verified ·
1 Parent(s): 7921157

Update knowledge_engine.py

Browse files
Files changed (1) hide show
  1. knowledge_engine.py +41 -92
knowledge_engine.py CHANGED
@@ -1,112 +1,61 @@
1
  import os
2
  from pathlib import Path
3
- from typing import List, Optional
4
-
5
- import faiss
6
- import numpy as np
7
- from sentence_transformers import SentenceTransformer
8
-
9
- from langchain.llms import HuggingFacePipeline
10
- from langchain.chains import RetrievalQA
11
- from langchain.vectorstores.faiss import FAISS
12
- from langchain.embeddings import HuggingFaceEmbeddings
13
  from langchain.document_loaders import TextLoader
14
  from langchain.text_splitter import RecursiveCharacterTextSplitter
15
-
16
- import torch
17
- from transformers import pipeline
 
18
 
19
  class KnowledgeManager:
20
- def __init__(self, knowledge_dir="knowledge_base"):
21
  self.knowledge_dir = Path(knowledge_dir)
22
- self.knowledge_dir.mkdir(exist_ok=True, parents=True)
23
-
24
  self.documents = []
25
- self.texts = []
26
  self.vectorstore = None
27
  self.retriever = None
28
- self.qa_chain = None
29
  self.llm = None
 
30
 
31
- self.device = "cpu" # For HF Spaces, CPU only
32
-
33
- # Initialize embeddings
34
- self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
35
-
36
- # Load and prepare knowledge
37
- self.load_documents()
38
- self.create_vectorstore()
39
- self.init_llm()
40
- self.init_qa_chain()
41
-
42
- def load_documents(self):
43
- # Load text files and split into chunks
44
  files = list(self.knowledge_dir.glob("*.txt"))
45
- self.documents = []
 
 
46
  for file in files:
47
- loader = TextLoader(str(file), encoding="utf-8")
48
- docs = loader.load()
49
- self.documents.extend(docs)
50
-
51
- # Split into smaller chunks (to improve retrieval granularity)
52
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
53
- self.texts = text_splitter.split_documents(self.documents)
54
 
55
- def create_vectorstore(self):
56
- if not self.texts:
57
- self.vectorstore = None
58
- return
59
- self.vectorstore = FAISS.from_documents(self.texts, self.embeddings)
60
- self.retriever = self.vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
61
-
62
- def init_llm(self):
63
- # Initialize HuggingFace pipeline + LangChain wrapper LLM
64
 
65
- # Try flan-t5-small first
66
- try:
67
- pipe = pipeline(
68
- "text2text-generation",
69
- model="google/flan-t5-small",
70
- device=-1, # CPU only
71
- max_length=256,
72
- do_sample=False,
73
- )
74
- self.llm = HuggingFacePipeline(pipeline=pipe)
75
- except Exception as e:
76
- print(f"Failed to load flan-t5-small: {e}")
77
- self.llm = None
78
 
79
- # Fallback: if no LLM, set to None and warn
80
- if self.llm is None:
81
- print("No LLM available, will fallback to retrieval-only.")
82
 
83
- def init_qa_chain(self):
84
- if self.llm and self.retriever:
85
- self.qa_chain = RetrievalQA.from_chain_type(
86
- llm=self.llm,
87
- retriever=self.retriever,
88
- return_source_documents=True,
89
- chain_type="stuff", # Stuff all docs in prompt, or "map_reduce"
90
- )
91
- else:
92
- self.qa_chain = None
93
 
94
- def get_knowledge_summary(self) -> str:
95
- count = len(self.texts) if self.texts else 0
96
- return f"{count} document chunks loaded."
 
97
 
98
- def query(self, question: str):
99
- if self.qa_chain:
100
- # Use LLM + retrieval
101
- result = self.qa_chain({"query": question})
102
- answer = result.get("result", "No answer found.")
103
- sources = result.get("source_documents", [])
104
- source_texts = [doc.page_content for doc in sources]
105
- return answer, source_texts
106
- elif self.retriever:
107
- # Retrieval only fallback
108
- docs = self.retriever.get_relevant_documents(question)
109
- answers = [doc.page_content for doc in docs]
110
- return "\n\n".join(answers), []
111
- else:
112
- return "Knowledge base not initialized.", []
 
1
  import os
2
  from pathlib import Path
 
 
 
 
 
 
 
 
 
 
3
  from langchain.document_loaders import TextLoader
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from langchain.vectorstores import FAISS
6
+ from langchain.embeddings import HuggingFaceEmbeddings
7
+ from langchain.chains import RetrievalQA
8
+ from langchain.llms import HuggingFaceHub
9
 
10
  class KnowledgeManager:
11
+ def __init__(self, knowledge_dir="."): # root dir by default
12
  self.knowledge_dir = Path(knowledge_dir)
 
 
13
  self.documents = []
14
+ self.embeddings = None
15
  self.vectorstore = None
16
  self.retriever = None
 
17
  self.llm = None
18
+ self.qa_chain = None
19
 
20
+ self._load_documents()
21
+ if self.documents:
22
+ self._initialize_embeddings()
23
+ self._initialize_vectorstore()
24
+ self._initialize_llm()
25
+ self._initialize_qa_chain()
26
+
27
+ def _load_documents(self):
28
+ if not self.knowledge_dir.exists():
29
+ raise FileNotFoundError(f"Directory {self.knowledge_dir} does not exist.")
30
+
 
 
31
  files = list(self.knowledge_dir.glob("*.txt"))
32
+ if not files:
33
+ raise FileNotFoundError(f"No .txt files found in {self.knowledge_dir}. Please upload your knowledge base files in root.")
34
+
35
  for file in files:
36
+ loader = TextLoader(str(file))
37
+ self.documents.extend(loader.load())
38
+
39
+ splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
40
+ self.documents = splitter.split_documents(self.documents)
 
 
41
 
42
+ def _initialize_embeddings(self):
43
+ self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
 
 
 
 
 
 
44
 
45
+ def _initialize_vectorstore(self):
46
+ self.vectorstore = FAISS.from_documents(self.documents, self.embeddings)
47
+ self.retriever = self.vectorstore.as_retriever()
 
 
 
 
 
 
 
 
 
 
48
 
49
+ def _initialize_llm(self):
50
+ self.llm = HuggingFaceHub(repo_id="google/flan-t5-small", model_kwargs={"temperature":0, "max_length":256})
 
51
 
52
+ def _initialize_qa_chain(self):
53
+ self.qa_chain = RetrievalQA.from_chain_type(llm=self.llm, chain_type="stuff", retriever=self.retriever)
 
 
 
 
 
 
 
 
54
 
55
+ def ask(self, query):
56
+ if not self.qa_chain:
57
+ return "Knowledge base not initialized properly."
58
+ return self.qa_chain.run(query)
59
 
60
+ def get_knowledge_summary(self):
61
+ return f"Loaded {len(self.documents)} document chunks from {self.knowledge_dir}"