import os
import pickle
from typing import Dict, Any, List
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor

from config import Config
from langchain_core.documents import Document
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.retrievers import BM25Retriever
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFaceEndpoint
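# KnowledgeManager wires hybrid retrieval (FAISS vectors + BM25 keyword
# search, run in parallel) into a RetrievalQA chain backed by Mistral-7B.
#
# Assumed interface of config.Config, inferred from its usage below:
# setup_dirs(); KNOWLEDGE_DIR, VECTOR_STORE_PATH and BM25_STORE_PATH as
# pathlib.Path objects; CHUNK_SIZE, CHUNK_OVERLAP and MAX_CONTEXT_CHUNKS
# as integers.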

class KnowledgeManager:
    def __init__(self):
        Config.setup_dirs()
        self.embeddings = self._init_embeddings()
        self.vector_db, self.bm25_retriever = self._init_retrievers()
        self.qa_chain = self._create_qa_chain()

    def _init_embeddings(self):
        print("[i] Using Hugging Face embeddings")
        return HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-mpnet-base-v2",
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )

    def _init_llm(self):
        print("[i] Using HuggingFaceEndpoint with Mistral-7B")
        return HuggingFaceEndpoint(
            repo_id="mistralai/Mistral-7B-Instruct-v0.1",
            temperature=0.1,
            max_new_tokens=512,
            huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
        )

    def _init_retrievers(self):
        faiss_index_path = Config.VECTOR_STORE_PATH / "index.faiss"
        faiss_pkl_path = Config.VECTOR_STORE_PATH / "index.pkl"
        if faiss_index_path.exists() and faiss_pkl_path.exists():
            try:
                vector_db = FAISS.load_local(
                    str(Config.VECTOR_STORE_PATH),
                    self.embeddings,
                    allow_dangerous_deserialization=True,
                )
                # Only reuse the cached stores when the BM25 pickle is also
                # present; otherwise fall through and rebuild both together.
                if Config.BM25_STORE_PATH.exists():
                    with open(Config.BM25_STORE_PATH, "rb") as f:
                        bm25_retriever = pickle.load(f)
                    return vector_db, bm25_retriever
            except Exception as e:
                print(f"[!] Error loading vector store: {e}. Rebuilding...")
        return self._build_retrievers_from_documents()

    def _build_retrievers_from_documents(self):
        if not any(Config.KNOWLEDGE_DIR.glob("**/*.txt")):
            print("[i] No knowledge files found. Creating default base...")
            self._create_default_knowledge()
        loader = DirectoryLoader(
            str(Config.KNOWLEDGE_DIR),
            glob="**/*.txt",
            loader_cls=TextLoader,
            loader_kwargs={"encoding": "utf-8"},
        )
        docs = loader.load()
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=Config.CHUNK_SIZE,
            chunk_overlap=Config.CHUNK_OVERLAP,
            separators=["\n\n", "\n", ". ", "! ", "? ", "; ", " ", ""],
        )
        chunks = splitter.split_documents(docs)
        vector_db = FAISS.from_documents(
            chunks,
            self.embeddings,
            distance_strategy=DistanceStrategy.COSINE,
        )
        vector_db.save_local(str(Config.VECTOR_STORE_PATH))
        bm25_retriever = BM25Retriever.from_documents(chunks)
        bm25_retriever.k = Config.MAX_CONTEXT_CHUNKS
        # Unpickling this file later carries the same trust caveat as
        # FAISS.load_local(allow_dangerous_deserialization=True) above.
        with open(Config.BM25_STORE_PATH, "wb") as f:
            pickle.dump(bm25_retriever, f)
        return vector_db, bm25_retriever

    def _create_default_knowledge(self):
        default_text = (
            "Sirraya xBrain - Advanced AI Platform\n\n"
            "Created by Amir Hameed.\n\n"
            "Features:\n"
            "- Hybrid Retrieval (Vector + BM25)\n"
            "- LISA Assistant\n"
            "- FAISS, BM25 Integration"
        )
        with open(Config.KNOWLEDGE_DIR / "sirraya_xbrain.txt", "w", encoding="utf-8") as f:
            f.write(default_text)

    def _parallel_retrieve(self, question: str) -> List[Document]:
        def retrieve_with_bm25():
            return self.bm25_retriever.invoke(question)

        def retrieve_with_vector():
            retriever = self.vector_db.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={
                    "k": Config.MAX_CONTEXT_CHUNKS,
                    "score_threshold": 0.5,  # Lowered threshold for testing
                },
            )
            docs = retriever.invoke(question)
            # Clamp any reported relevance scores into the 0-1 range.
            for doc in docs:
                if hasattr(doc, "metadata") and "score" in doc.metadata:
                    doc.metadata["score"] = max(0, min(1, doc.metadata["score"]))
            return docs

        # Run both retrievers concurrently and concatenate the results.
        # Note: the merged list may contain duplicate chunks; LangChain's
        # EnsembleRetriever is the built-in way to fuse the two rankings.
        with ThreadPoolExecutor(max_workers=2) as executor:
            bm25_future = executor.submit(retrieve_with_bm25)
            vector_future = executor.submit(retrieve_with_vector)
            bm25_results = bm25_future.result()
            vector_results = vector_future.result()
        return vector_results + bm25_results

    def _create_qa_chain(self):
        if not self.vector_db or not self.bm25_retriever:
            return None
        prompt_template = """You are LISA, an AI assistant for Sirraya xBrain. Answer using the context below:
Context:
{context}
Question: {question}
Instructions:
- Use only the context.
- Be accurate and helpful.
- If unsure, say: "I don't have that information in my knowledge base."
Answer:"""
        return RetrievalQA.from_chain_type(
            llm=self._init_llm(),
            chain_type="stuff",
            retriever=self.vector_db.as_retriever(
                search_kwargs={"k": Config.MAX_CONTEXT_CHUNKS}
            ),
            chain_type_kwargs={
                "prompt": PromptTemplate(
                    template=prompt_template,
                    input_variables=["context", "question"],
                )
            },
            return_source_documents=True,
        )

    def query(self, question: str) -> Dict[str, Any]:
        if not self.qa_chain:
            return {
                "answer": "Knowledge system not initialized. Please reload.",
                "processing_time": 0,
                "source_chunks": [],
            }
        try:
            start_time = datetime.now()
            docs = self._parallel_retrieve(question)
            if not docs:
                print("[i] No docs found with normal threshold, trying lower threshold...")
                retriever = self.vector_db.as_retriever(
                    search_type="similarity_score_threshold",
                    search_kwargs={
                        "k": Config.MAX_CONTEXT_CHUNKS,
                        "score_threshold": 0.3,  # Very low threshold for fallback
                    },
                )
                docs = retriever.invoke(question)
            # RetrievalQA re-retrieves internally and silently ignores extra
            # keys such as "input_documents", so feed the hybrid results
            # straight into its underlying stuff chain instead.
            result = self.qa_chain.combine_documents_chain.invoke(
                {"input_documents": docs, "question": question}
            )
            return {
                "answer": result.get("output_text", "No answer could be generated"),
                "processing_time": (datetime.now() - start_time).total_seconds() * 1000,
                "source_chunks": docs,
            }
        except Exception as e:
            print(f"[!] Query error: {e}")
            return {
                "answer": "I encountered an error processing your query. Please try again.",
                "processing_time": 0,
                "source_chunks": [],
            }

    def get_knowledge_files_count(self) -> int:
        if not Config.KNOWLEDGE_DIR.exists():
            return 0
        return len(list(Config.KNOWLEDGE_DIR.glob("**/*.txt")))

    def save_uploaded_file(self, uploaded_file, filename: str) -> bool:
        # Expects an uploaded-file object exposing getbuffer(), e.g.
        # Streamlit's UploadedFile.
        try:
            with open(Config.KNOWLEDGE_DIR / filename, "wb") as f:
                f.write(uploaded_file.getbuffer())
            return True
        except Exception as e:
            print(f"[!] File save error: {e}")
            return False
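
if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: assumes
    # HUGGINGFACEHUB_API_TOKEN is set and config.Config points at writable
    # knowledge / vector-store directories.
    km = KnowledgeManager()
    print(f"[i] Knowledge files: {km.get_knowledge_files_count()}")
    response = km.query("What is Sirraya xBrain?")
    print(f"[i] Answer ({response['processing_time']:.0f} ms): {response['answer']}")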