# xTwin / knowledge_engine.py
# NOTE: Hugging Face file-page artifacts (uploader "aamirhameed", commit
# 5304fdb, raw/history/blame links, 7.08 kB size) converted to comments so
# the module parses as Python.
import os
import pickle
from typing import Dict, Any
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
from config import Config
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.retrievers import BM25Retriever
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceHub
class KnowledgeManager:
    """Hybrid (FAISS vector + BM25 keyword) knowledge base wired to a
    HuggingFaceHub LLM through a RetrievalQA chain.

    On construction it ensures the storage directories exist, loads (or
    rebuilds) both retrievers, and assembles the QA chain.
    """

    def __init__(self) -> None:
        # Create knowledge/vector-store directories before anything touches them.
        Config.setup_dirs()
        self.embeddings = self._init_embeddings()
        # Loads persisted indexes when present, otherwise builds from documents.
        self.vector_db, self.bm25_retriever = self._init_retrievers()
        self.qa_chain = self._create_qa_chain()
def _init_embeddings(self):
print("[i] Using Hugging Face embeddings")
return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
def _init_llm(self):
print("[i] Using HuggingFaceHub with Mistral-7B")
return HuggingFaceHub(
repo_id="mistralai/Mistral-7B-Instruct-v0.1",
huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
model_kwargs={
"temperature": 0.1,
"max_new_tokens": 512,
"do_sample": True
}
)
def _init_retrievers(self):
faiss_index_path = Config.VECTOR_STORE_PATH / "index.faiss"
faiss_pkl_path = Config.VECTOR_STORE_PATH / "index.pkl"
if faiss_index_path.exists() and faiss_pkl_path.exists():
try:
vector_db = FAISS.load_local(
str(Config.VECTOR_STORE_PATH),
self.embeddings,
allow_dangerous_deserialization=True
)
if Config.BM25_STORE_PATH.exists():
with open(Config.BM25_STORE_PATH, "rb") as f:
bm25_retriever = pickle.load(f)
return vector_db, bm25_retriever
except Exception as e:
print(f"[!] Error loading vector store: {e}. Rebuilding...")
return self._build_retrievers_from_documents()
def _build_retrievers_from_documents(self):
if not any(Config.KNOWLEDGE_DIR.glob("**/*.txt")):
print("[i] No knowledge files found. Creating default base...")
self._create_default_knowledge()
loader = DirectoryLoader(
str(Config.KNOWLEDGE_DIR),
glob="**/*.txt",
loader_cls=TextLoader,
loader_kwargs={'encoding': 'utf-8'}
)
docs = loader.load()
splitter = RecursiveCharacterTextSplitter(
chunk_size=Config.CHUNK_SIZE,
chunk_overlap=Config.CHUNK_OVERLAP,
separators=["\n\n", "\n", ". ", "! ", "? ", "; ", " ", ""]
)
chunks = splitter.split_documents(docs)
vector_db = FAISS.from_documents(chunks, self.embeddings)
vector_db.save_local(str(Config.VECTOR_STORE_PATH))
bm25_retriever = BM25Retriever.from_documents(chunks)
bm25_retriever.k = Config.MAX_CONTEXT_CHUNKS
with open(Config.BM25_STORE_PATH, "wb") as f:
pickle.dump(bm25_retriever, f)
return vector_db, bm25_retriever
def _create_default_knowledge(self):
default_text = """Sirraya xBrain - Advanced AI Platform\n\nCreated by Amir Hameed.\n\nFeatures:\n- Hybrid Retrieval (Vector + BM25)\n- LISA Assistant\n- FAISS, BM25 Integration"""
with open(Config.KNOWLEDGE_DIR / "sirraya_xbrain.txt", "w", encoding="utf-8") as f:
f.write(default_text)
def _parallel_retrieve(self, question: str):
def retrieve_with_bm25():
return self.bm25_retriever.get_relevant_documents(question)
def retrieve_with_vector():
retriever = self.vector_db.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={"k": Config.MAX_CONTEXT_CHUNKS, "score_threshold": 0.83}
)
return retriever.get_relevant_documents(question)
with ThreadPoolExecutor(max_workers=2) as executor:
bm25_future = executor.submit(retrieve_with_bm25)
vector_future = executor.submit(retrieve_with_vector)
bm25_results = bm25_future.result()
vector_results = vector_future.result()
return vector_results + bm25_results
def _create_qa_chain(self):
if not self.vector_db or not self.bm25_retriever:
return None
prompt_template = """You are LISA, an AI assistant for Sirraya xBrain. Answer using the context below:
Context:
{context}
Question: {question}
Instructions:
- Use only the context.
- Be accurate and helpful.
- If unsure, say: "I don't have that information in my knowledge base."
Answer:"""
return RetrievalQA.from_chain_type(
llm=self._init_llm(),
chain_type="stuff",
retriever=self.vector_db.as_retriever(search_kwargs={"k": 1}),
chain_type_kwargs={
"prompt": PromptTemplate(
template=prompt_template,
input_variables=["context", "question"]
)
},
return_source_documents=True
)
def query(self, question: str) -> Dict[str, Any]:
if not self.qa_chain:
return {
"answer": "Knowledge system not initialized. Please reload.",
"processing_time": 0,
"source_chunks": []
}
try:
start_time = datetime.now()
docs = self._parallel_retrieve(question)
if not docs:
retriever = self.vector_db.as_retriever(search_kwargs={"k": Config.MAX_CONTEXT_CHUNKS})
docs = retriever.get_relevant_documents(question)
result = self.qa_chain.invoke({"input_documents": docs, "query": question})
processing_time = (datetime.now() - start_time).total_seconds() * 1000
return {
"answer": result.get("result", ""),
"processing_time": processing_time,
"source_chunks": result.get("source_documents", [])
}
except Exception as e:
print(f"[!] Query error: {e}")
return {
"answer": f"Error: {e}",
"processing_time": 0,
"source_chunks": []
}
def get_knowledge_files_count(self) -> int:
return len(list(Config.KNOWLEDGE_DIR.glob("**/*.txt"))) if Config.KNOWLEDGE_DIR.exists() else 0
def save_uploaded_file(self, uploaded_file, filename: str) -> bool:
try:
with open(Config.KNOWLEDGE_DIR / filename, "wb") as f:
f.write(uploaded_file.getbuffer())
return True
except Exception as e:
print(f"[!] File save error: {e}")
return False