# xTwin / knowledge_engine.py
import os
import pickle
from typing import Dict, Any, List
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
from config import Config
from langchain_core.documents import Document
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.retrievers import BM25Retriever
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFaceEndpoint
class KnowledgeManager:
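    """Hybrid retrieval knowledge engine for Sirraya xBrain.

    Combines FAISS vector search and BM25 keyword search over local .txt
    knowledge files and answers questions through a hosted Mistral-7B endpoint.
    """
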
    def __init__(self):
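        """Set up storage directories, embeddings, retrievers, and the QA chain."""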
        Config.setup_dirs()
        self.embeddings = self._init_embeddings()
        self.vector_db, self.bm25_retriever = self._init_retrievers()
        self.qa_chain = self._create_qa_chain()

    def _init_embeddings(self):
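        """Return a CPU sentence-transformers embedding model with normalized vectors."""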
print("[i] Using Hugging Face embeddings")
return HuggingFaceEmbeddings(
model_name="sentence-transformers/all-mpnet-base-v2",
model_kwargs={'device': 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
def _init_llm(self):
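        """Return the hosted Mistral-7B-Instruct endpoint used for generation.

        Requires HUGGINGFACEHUB_API_TOKEN to be set in the environment.
        """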
print("[i] Using HuggingFaceEndpoint with Mistral-7B")
return HuggingFaceEndpoint(
repo_id="mistralai/Mistral-7B-Instruct-v0.1",
temperature=0.1,
max_new_tokens=512,
huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
)
def _init_retrievers(self):
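        """Load the persisted FAISS index and BM25 retriever, or rebuild them from documents."""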
        faiss_index_path = Config.VECTOR_STORE_PATH / "index.faiss"
        faiss_pkl_path = Config.VECTOR_STORE_PATH / "index.pkl"
        if faiss_index_path.exists() and faiss_pkl_path.exists():
            try:
                vector_db = FAISS.load_local(
                    str(Config.VECTOR_STORE_PATH),
                    self.embeddings,
                    allow_dangerous_deserialization=True
                )
                if Config.BM25_STORE_PATH.exists():
                    with open(Config.BM25_STORE_PATH, "rb") as f:
                        bm25_retriever = pickle.load(f)
                    return vector_db, bm25_retriever
                # BM25 store is missing: fall through and rebuild both stores.
            except Exception as e:
                print(f"[!] Error loading vector store: {e}. Rebuilding...")
        return self._build_retrievers_from_documents()

    def _build_retrievers_from_documents(self):
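        """Chunk all .txt knowledge files, then build and persist the FAISS and BM25 stores."""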
        if not any(Config.KNOWLEDGE_DIR.glob("**/*.txt")):
            print("[i] No knowledge files found. Creating default base...")
            self._create_default_knowledge()
        loader = DirectoryLoader(
            str(Config.KNOWLEDGE_DIR),
            glob="**/*.txt",
            loader_cls=TextLoader,
            loader_kwargs={'encoding': 'utf-8'}
        )
        docs = loader.load()
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=Config.CHUNK_SIZE,
            chunk_overlap=Config.CHUNK_OVERLAP,
            separators=["\n\n", "\n", ". ", "! ", "? ", "; ", " ", ""]
        )
        chunks = splitter.split_documents(docs)
        vector_db = FAISS.from_documents(
            chunks,
            self.embeddings,
            distance_strategy="COSINE"
        )
        vector_db.save_local(str(Config.VECTOR_STORE_PATH))
        bm25_retriever = BM25Retriever.from_documents(chunks)
        bm25_retriever.k = Config.MAX_CONTEXT_CHUNKS
        with open(Config.BM25_STORE_PATH, "wb") as f:
            pickle.dump(bm25_retriever, f)
        return vector_db, bm25_retriever

    def _create_default_knowledge(self):
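        """Write a minimal default knowledge file so there is always something to index."""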
        default_text = """Sirraya xBrain - Advanced AI Platform\n\nCreated by Amir Hameed.\n\nFeatures:\n- Hybrid Retrieval (Vector + BM25)\n- LISA Assistant\n- FAISS, BM25 Integration"""
        with open(Config.KNOWLEDGE_DIR / "sirraya_xbrain.txt", "w", encoding="utf-8") as f:
            f.write(default_text)

    def _parallel_retrieve(self, question: str) -> List[Document]:
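        """Run BM25 and vector retrieval concurrently and merge the results."""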
        def retrieve_with_bm25():
            return self.bm25_retriever.invoke(question)

        def retrieve_with_vector():
            retriever = self.vector_db.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={
                    "k": Config.MAX_CONTEXT_CHUNKS,
                    "score_threshold": 0.5  # lowered threshold for testing
                }
            )
            docs = retriever.invoke(question)
            # Clamp any reported relevance scores into the 0-1 range.
            for doc in docs:
                if hasattr(doc, 'metadata') and 'score' in doc.metadata:
                    doc.metadata['score'] = max(0, min(1, doc.metadata['score']))
            return docs

        with ThreadPoolExecutor(max_workers=2) as executor:
            bm25_future = executor.submit(retrieve_with_bm25)
            vector_future = executor.submit(retrieve_with_vector)
            bm25_results = bm25_future.result()
            vector_results = vector_future.result()
        # Merge the two result sets, dropping chunks returned by both retrievers.
        merged, seen = [], set()
        for doc in vector_results + bm25_results:
            if doc.page_content not in seen:
                seen.add(doc.page_content)
                merged.append(doc)
        return merged

    def _create_qa_chain(self):
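        """Build a 'stuff' RetrievalQA chain with the LISA prompt, or return None if retrievers are missing."""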
        if not self.vector_db or not self.bm25_retriever:
            return None
        prompt_template = """You are LISA, an AI assistant for Sirraya xBrain. Answer using the context below:
Context:
{context}
Question: {question}
Instructions:
- Use only the context.
- Be accurate and helpful.
- If unsure, say: "I don't have that information in my knowledge base."
Answer:"""
        return RetrievalQA.from_chain_type(
            llm=self._init_llm(),
            chain_type="stuff",
            retriever=self.vector_db.as_retriever(search_kwargs={"k": Config.MAX_CONTEXT_CHUNKS}),
            chain_type_kwargs={
                "prompt": PromptTemplate(
                    template=prompt_template,
                    input_variables=["context", "question"]
                )
            },
            return_source_documents=True
        )

    def query(self, question: str) -> Dict[str, Any]:
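        """Answer a question from the hybrid-retrieved context.

        Returns the answer text, the processing time in milliseconds, and the
        source chunks that were used.
        """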
        if not self.qa_chain:
            return {
                "answer": "Knowledge system not initialized. Please reload.",
                "processing_time": 0,
                "source_chunks": []
            }
        try:
            start_time = datetime.now()
            docs = self._parallel_retrieve(question)
            if not docs:
                print("[i] No docs found with normal threshold, trying lower threshold...")
                retriever = self.vector_db.as_retriever(
                    search_type="similarity_score_threshold",
                    search_kwargs={
                        "k": Config.MAX_CONTEXT_CHUNKS,
                        "score_threshold": 0.3  # very low threshold for fallback
                    }
                )
                docs = retriever.invoke(question)
            if docs:
                # Feed the hybrid results directly into the underlying stuff chain;
                # RetrievalQA.invoke would ignore "input_documents" and re-retrieve
                # with the FAISS retriever alone.
                result = self.qa_chain.combine_documents_chain.invoke(
                    {"input_documents": docs, "question": question}
                )
                answer = result.get("output_text", "No answer could be generated")
                sources = docs
            else:
                result = self.qa_chain.invoke({"query": question})
                answer = result.get("result", "No answer could be generated")
                sources = result.get("source_documents", [])
            return {
                "answer": answer,
                "processing_time": (datetime.now() - start_time).total_seconds() * 1000,
                "source_chunks": sources
            }
        except Exception as e:
            print(f"[!] Query error: {str(e)}")
            return {
                "answer": "I encountered an error processing your query. Please try again.",
                "processing_time": 0,
                "source_chunks": []
            }

    def get_knowledge_files_count(self) -> int:
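        """Return the number of .txt files currently in the knowledge directory."""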
        return len(list(Config.KNOWLEDGE_DIR.glob("**/*.txt"))) if Config.KNOWLEDGE_DIR.exists() else 0

    def save_uploaded_file(self, uploaded_file, filename: str) -> bool:
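        """Save an uploaded file (any object exposing getbuffer()) into the knowledge directory."""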
        try:
            with open(Config.KNOWLEDGE_DIR / filename, "wb") as f:
                f.write(uploaded_file.getbuffer())
            return True
        except Exception as e:
            print(f"[!] File save error: {e}")
            return False
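

# Minimal usage sketch: assumes HUGGINGFACEHUB_API_TOKEN is set and that
# Config's knowledge and vector-store directories are writable.
if __name__ == "__main__":
    km = KnowledgeManager()
    print(f"[i] Knowledge files indexed: {km.get_knowledge_files_count()}")
    response = km.query("What is Sirraya xBrain?")
    print(response["answer"])
    print(f"[i] Took {response['processing_time']:.0f} ms, "
          f"{len(response['source_chunks'])} source chunks")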