import os
import pickle
from typing import Dict, Any
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor

from config import Config
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.retrievers import BM25Retriever
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceHub
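
# NOTE: config.py is not part of this file. A minimal sketch consistent with
# how Config is used below -- the attribute names are taken from this module,
# but the concrete paths and values are placeholder assumptions:
#
#     from pathlib import Path
#
#     class Config:
#         KNOWLEDGE_DIR = Path("data/knowledge")         # assumed location
#         VECTOR_STORE_PATH = Path("data/vector_store")  # assumed location
#         BM25_STORE_PATH = Path("data/bm25.pkl")        # assumed location
#         CHUNK_SIZE = 1000          # placeholder value
#         CHUNK_OVERLAP = 200        # placeholder value
#         MAX_CONTEXT_CHUNKS = 4     # placeholder value
#
#         @classmethod
#         def setup_dirs(cls):
#             cls.KNOWLEDGE_DIR.mkdir(parents=True, exist_ok=True)
#             cls.VECTOR_STORE_PATH.mkdir(parents=True, exist_ok=True)
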
class KnowledgeManager:
    """Hybrid RAG manager: FAISS vector search plus BM25 keyword search over local .txt knowledge files."""

    def __init__(self):
        Config.setup_dirs()
        self.embeddings = self._init_embeddings()
        self.vector_db, self.bm25_retriever = self._init_retrievers()
        self.qa_chain = self._create_qa_chain()

    def _init_embeddings(self):
        print("[i] Using Hugging Face embeddings")
        return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    def _init_llm(self):
        print("[i] Using HuggingFaceHub with Mistral-7B")
        return HuggingFaceHub(
            repo_id="mistralai/Mistral-7B-Instruct-v0.1",
            huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
            model_kwargs={
                "temperature": 0.1,
                "max_new_tokens": 512,
                "do_sample": True,
            },
        )
    def _init_retrievers(self):
        faiss_index_path = Config.VECTOR_STORE_PATH / "index.faiss"
        faiss_pkl_path = Config.VECTOR_STORE_PATH / "index.pkl"
        if faiss_index_path.exists() and faiss_pkl_path.exists():
            try:
                vector_db = FAISS.load_local(
                    str(Config.VECTOR_STORE_PATH),
                    self.embeddings,
                    allow_dangerous_deserialization=True,
                )
                # Reuse the cached stores only when the BM25 pickle is also
                # present; otherwise fall through and rebuild both together.
                if Config.BM25_STORE_PATH.exists():
                    with open(Config.BM25_STORE_PATH, "rb") as f:
                        bm25_retriever = pickle.load(f)
                    return vector_db, bm25_retriever
            except Exception as e:
                print(f"[!] Error loading vector store: {e}. Rebuilding...")
        return self._build_retrievers_from_documents()
    def _build_retrievers_from_documents(self):
        if not any(Config.KNOWLEDGE_DIR.glob("**/*.txt")):
            print("[i] No knowledge files found. Creating default base...")
            self._create_default_knowledge()
        loader = DirectoryLoader(
            str(Config.KNOWLEDGE_DIR),
            glob="**/*.txt",
            loader_cls=TextLoader,
            loader_kwargs={"encoding": "utf-8"},
        )
        docs = loader.load()
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=Config.CHUNK_SIZE,
            chunk_overlap=Config.CHUNK_OVERLAP,
            separators=["\n\n", "\n", ". ", "! ", "? ", "; ", " ", ""],
        )
        chunks = splitter.split_documents(docs)
        vector_db = FAISS.from_documents(chunks, self.embeddings)
        vector_db.save_local(str(Config.VECTOR_STORE_PATH))
        bm25_retriever = BM25Retriever.from_documents(chunks)
        bm25_retriever.k = Config.MAX_CONTEXT_CHUNKS
        with open(Config.BM25_STORE_PATH, "wb") as f:
            pickle.dump(bm25_retriever, f)
        return vector_db, bm25_retriever
    def _create_default_knowledge(self):
        default_text = """Sirraya xBrain - Advanced AI Platform

Created by Amir Hameed.

Features:
- Hybrid Retrieval (Vector + BM25)
- LISA Assistant
- FAISS, BM25 Integration"""
        with open(Config.KNOWLEDGE_DIR / "sirraya_xbrain.txt", "w", encoding="utf-8") as f:
            f.write(default_text)
    def _parallel_retrieve(self, question: str):
        def retrieve_with_bm25():
            return self.bm25_retriever.get_relevant_documents(question)

        def retrieve_with_vector():
            retriever = self.vector_db.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={"k": Config.MAX_CONTEXT_CHUNKS, "score_threshold": 0.83},
            )
            return retriever.get_relevant_documents(question)

        # Run both retrievers concurrently; they are independent and I/O-bound.
        with ThreadPoolExecutor(max_workers=2) as executor:
            bm25_future = executor.submit(retrieve_with_bm25)
            vector_future = executor.submit(retrieve_with_vector)
            bm25_results = bm25_future.result()
            vector_results = vector_future.result()

        # Merge the two result lists, dropping chunks both retrievers returned
        # so duplicates don't inflate the prompt context.
        seen = set()
        merged = []
        for doc in vector_results + bm25_results:
            if doc.page_content not in seen:
                seen.add(doc.page_content)
                merged.append(doc)
        return merged
    def _create_qa_chain(self):
        if not self.vector_db or not self.bm25_retriever:
            return None
        prompt_template = """You are LISA, an AI assistant for Sirraya xBrain. Answer using the context below:

Context:
{context}

Question: {question}

Instructions:
- Use only the context.
- Be accurate and helpful.
- If unsure, say: "I don't have that information in my knowledge base."

Answer:"""
        return RetrievalQA.from_chain_type(
            llm=self._init_llm(),
            chain_type="stuff",
            # This retriever is only a fallback: query() normally feeds the
            # hybrid (vector + BM25) documents into the stuff chain directly.
            retriever=self.vector_db.as_retriever(search_kwargs={"k": 1}),
            chain_type_kwargs={
                "prompt": PromptTemplate(
                    template=prompt_template,
                    input_variables=["context", "question"],
                )
            },
            return_source_documents=True,
        )
    def query(self, question: str) -> Dict[str, Any]:
        if not self.qa_chain:
            return {
                "answer": "Knowledge system not initialized. Please reload.",
                "processing_time": 0,
                "source_chunks": [],
            }
        try:
            start_time = datetime.now()
            docs = self._parallel_retrieve(question)
            if not docs:
                # Fall back to plain similarity search when the hybrid pass
                # returns nothing (e.g. the score threshold filtered all hits).
                retriever = self.vector_db.as_retriever(
                    search_kwargs={"k": Config.MAX_CONTEXT_CHUNKS}
                )
                docs = retriever.get_relevant_documents(question)
            # Feed the hybrid-retrieved documents straight into the "stuff"
            # chain. Invoking the RetrievalQA wrapper itself would re-retrieve
            # with its own k=1 retriever and silently discard `docs`.
            result = self.qa_chain.combine_documents_chain.invoke(
                {"input_documents": docs, "question": question}
            )
            processing_time = (datetime.now() - start_time).total_seconds() * 1000
            return {
                "answer": result.get("output_text", ""),
                "processing_time": processing_time,
                "source_chunks": docs,
            }
        except Exception as e:
            print(f"[!] Query error: {e}")
            return {
                "answer": f"Error: {e}",
                "processing_time": 0,
                "source_chunks": [],
            }
    def get_knowledge_files_count(self) -> int:
        return len(list(Config.KNOWLEDGE_DIR.glob("**/*.txt"))) if Config.KNOWLEDGE_DIR.exists() else 0

    def save_uploaded_file(self, uploaded_file, filename: str) -> bool:
        try:
            with open(Config.KNOWLEDGE_DIR / filename, "wb") as f:
                f.write(uploaded_file.getbuffer())
            return True
        except Exception as e:
            print(f"[!] File save error: {e}")
            return False
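

# A minimal usage sketch, assuming HUGGINGFACEHUB_API_TOKEN is set in the
# environment and config.py is importable; the question is illustrative only.
if __name__ == "__main__":
    km = KnowledgeManager()
    print(f"[i] Knowledge files indexed: {km.get_knowledge_files_count()}")
    response = km.query("What is Sirraya xBrain?")
    print(f"Answer ({response['processing_time']:.0f} ms): {response['answer']}")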