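"""Knowledge manager for Sirraya xBrain: hybrid retrieval (FAISS vector search
plus BM25 keyword search) over local .txt files, with answers generated by a
Hugging Face-hosted Mistral-7B model through a RetrievalQA chain."""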
import os
import pickle
from typing import Dict, Any, List
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor

from config import Config
from langchain_core.documents import Document
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_community.retrievers import BM25Retriever
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
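
# NOTE: `Config` is provided by the local config.py module (not shown here).
# From its usage below it is assumed to expose at least:
#   Config.setup_dirs()       -- create the directories below if missing
#   Config.KNOWLEDGE_DIR      -- Path to the directory of .txt knowledge files
#   Config.VECTOR_STORE_PATH  -- Path where the FAISS index is saved/loaded
#   Config.BM25_STORE_PATH    -- Path of the pickled BM25 retriever
#   Config.CHUNK_SIZE / Config.CHUNK_OVERLAP -- text splitter settings
#   Config.MAX_CONTEXT_CHUNKS -- top-k results requested per retriever
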
class KnowledgeManager:
    def __init__(self):
        Config.setup_dirs()
        self.embeddings = self._init_embeddings()
        self.vector_db, self.bm25_retriever = self._init_retrievers()
        self.qa_chain = self._create_qa_chain()

    def _init_embeddings(self):
        print("[i] Using Hugging Face embeddings")
        return HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-mpnet-base-v2",
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True}
        )

    def _init_llm(self):
        print("[i] Using HuggingFaceEndpoint with Mistral-7B")
        return HuggingFaceEndpoint(
            repo_id="mistralai/Mistral-7B-Instruct-v0.1",
            temperature=0.1,
            max_new_tokens=512,
            huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
        )

    def _init_retrievers(self):
        faiss_index_path = Config.VECTOR_STORE_PATH / "index.faiss"
        faiss_pkl_path = Config.VECTOR_STORE_PATH / "index.pkl"
        if faiss_index_path.exists() and faiss_pkl_path.exists():
            try:
                vector_db = FAISS.load_local(
                    str(Config.VECTOR_STORE_PATH),
                    self.embeddings,
                    allow_dangerous_deserialization=True
                )
                if Config.BM25_STORE_PATH.exists():
                    with open(Config.BM25_STORE_PATH, "rb") as f:
                        bm25_retriever = pickle.load(f)
                    return vector_db, bm25_retriever
                # BM25 store is missing: fall through and rebuild both stores
                # so the two retrievers stay in sync.
            except Exception as e:
                print(f"[!] Error loading vector store: {e}. Rebuilding...")
        return self._build_retrievers_from_documents()

    def _build_retrievers_from_documents(self):
        if not any(Config.KNOWLEDGE_DIR.glob("**/*.txt")):
            print("[i] No knowledge files found. Creating default base...")
            self._create_default_knowledge()
        loader = DirectoryLoader(
            str(Config.KNOWLEDGE_DIR),
            glob="**/*.txt",
            loader_cls=TextLoader,
            loader_kwargs={'encoding': 'utf-8'}
        )
        docs = loader.load()
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=Config.CHUNK_SIZE,
            chunk_overlap=Config.CHUNK_OVERLAP,
            separators=["\n\n", "\n", ". ", "! ", "? ", "; ", " ", ""]
        )
        chunks = splitter.split_documents(docs)
        # Cosine distance matches the normalized embeddings configured above.
        vector_db = FAISS.from_documents(
            chunks,
            self.embeddings,
            distance_strategy=DistanceStrategy.COSINE
        )
        vector_db.save_local(str(Config.VECTOR_STORE_PATH))
        bm25_retriever = BM25Retriever.from_documents(chunks)
        bm25_retriever.k = Config.MAX_CONTEXT_CHUNKS
        with open(Config.BM25_STORE_PATH, "wb") as f:
            pickle.dump(bm25_retriever, f)
        return vector_db, bm25_retriever

    def _create_default_knowledge(self):
        default_text = (
            "Sirraya xBrain - Advanced AI Platform\n\n"
            "Created by Amir Hameed.\n\n"
            "Features:\n"
            "- Hybrid Retrieval (Vector + BM25)\n"
            "- LISA Assistant\n"
            "- FAISS, BM25 Integration"
        )
        with open(Config.KNOWLEDGE_DIR / "sirraya_xbrain.txt", "w", encoding="utf-8") as f:
            f.write(default_text)

    def _parallel_retrieve(self, question: str) -> List[Document]:
        def retrieve_with_bm25():
            return self.bm25_retriever.invoke(question)

        def retrieve_with_vector():
            retriever = self.vector_db.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={
                    "k": Config.MAX_CONTEXT_CHUNKS,
                    "score_threshold": 0.5
                }
            )
            docs = retriever.invoke(question)
            # Clamp any relevance scores stashed in metadata to [0, 1].
            for doc in docs:
                if hasattr(doc, 'metadata') and 'score' in doc.metadata:
                    doc.metadata['score'] = max(0, min(1, doc.metadata['score']))
            return docs

        # Run both retrievers concurrently and merge their results.
        with ThreadPoolExecutor(max_workers=2) as executor:
            bm25_future = executor.submit(retrieve_with_bm25)
            vector_future = executor.submit(retrieve_with_vector)
            bm25_results = bm25_future.result()
            vector_results = vector_future.result()
        # Deduplicate by content so overlapping hits are not passed in twice.
        seen = set()
        merged = []
        for doc in vector_results + bm25_results:
            if doc.page_content not in seen:
                seen.add(doc.page_content)
                merged.append(doc)
        return merged

    def _create_qa_chain(self):
        if not self.vector_db or not self.bm25_retriever:
            return None
        prompt_template = """You are LISA, an AI assistant for Sirraya xBrain. Answer using the context below:

Context:
{context}

Question: {question}

Instructions:
- Use only the context.
- Be accurate and helpful.
- If unsure, say: "I don't have that information in my knowledge base."

Answer:"""
        return RetrievalQA.from_chain_type(
            llm=self._init_llm(),
            chain_type="stuff",
            retriever=self.vector_db.as_retriever(
                search_kwargs={"k": Config.MAX_CONTEXT_CHUNKS}
            ),
            chain_type_kwargs={
                "prompt": PromptTemplate(
                    template=prompt_template,
                    input_variables=["context", "question"]
                )
            },
            return_source_documents=True
        )

    def query(self, question: str) -> Dict[str, Any]:
        if not self.qa_chain:
            return {
                "answer": "Knowledge system not initialized. Please reload.",
                "processing_time": 0,
                "source_chunks": []
            }
        try:
            start_time = datetime.now()
            docs = self._parallel_retrieve(question)
            if not docs:
                print("[i] No docs found with normal threshold, trying lower threshold...")
                retriever = self.vector_db.as_retriever(
                    search_type="similarity_score_threshold",
                    search_kwargs={
                        "k": Config.MAX_CONTEXT_CHUNKS,
                        "score_threshold": 0.3  # permissive fallback threshold
                    }
                )
                docs = retriever.invoke(question)
            # RetrievalQA ignores externally supplied documents and re-runs
            # its own retriever, so feed the hybrid results straight into the
            # underlying stuff-documents chain instead.
            result = self.qa_chain.combine_documents_chain.invoke(
                {"input_documents": docs, "question": question}
            )
            return {
                "answer": result.get("output_text", "No answer could be generated"),
                "processing_time": (datetime.now() - start_time).total_seconds() * 1000,
                "source_chunks": docs
            }
        except Exception as e:
            print(f"[!] Query error: {str(e)}")
            return {
                "answer": "I encountered an error processing your query. Please try again.",
                "processing_time": 0,
                "source_chunks": []
            }

    def get_knowledge_files_count(self) -> int:
        if not Config.KNOWLEDGE_DIR.exists():
            return 0
        return len(list(Config.KNOWLEDGE_DIR.glob("**/*.txt")))

    def save_uploaded_file(self, uploaded_file, filename: str) -> bool:
        # `uploaded_file` is likely a Streamlit UploadedFile; any object
        # exposing getbuffer() works here.
        try:
            with open(Config.KNOWLEDGE_DIR / filename, "wb") as f:
                f.write(uploaded_file.getbuffer())
            return True
        except Exception as e:
            print(f"[!] File save error: {e}")
            return False
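
# Illustrative usage sketch: assumes HUGGINGFACEHUB_API_TOKEN is set in the
# environment and that config.py defines the paths sketched near the top of
# this file; the question string is only an example.
if __name__ == "__main__":
    km = KnowledgeManager()
    print(f"[i] Knowledge files indexed: {km.get_knowledge_files_count()}")
    response = km.query("Who created Sirraya xBrain?")
    print(f"Answer: {response['answer']}")
    print(f"({response['processing_time']:.0f} ms, "
          f"{len(response['source_chunks'])} source chunks)")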