# NOTE(review): removed Hugging Face "Spaces: Paused" viewer banner — it was an
# extraction artifact, not part of the source file.
| import os | |
| import glob | |
| from langchain_community.document_loaders import DirectoryLoader, TextLoader, PyPDFLoader, JSONLoader | |
| from langchain_community.vectorstores import Qdrant | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.http import models | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_core.documents import Document | |
class RAGEngine:
    """Local RAG engine: indexes a phishing knowledge base into Qdrant and
    retrieves relevant text chunks for a query.

    Documents are loaded from ``knowledge_base_dir`` (or a fallback
    ``knowledge_base.txt`` file in the working directory), split into
    overlapping chunks, embedded with a HuggingFace sentence-transformer,
    and persisted in an embedded (on-disk) Qdrant collection.
    """

    def __init__(self, knowledge_base_dir="./knowledge_base", persist_directory="./qdrant_db"):
        """
        Args:
            knowledge_base_dir: Directory (or single file) holding the source documents.
            persist_directory: Path for the embedded Qdrant storage.
        """
        self.knowledge_base_dir = knowledge_base_dir
        self.persist_directory = persist_directory
        self.collection_name = "phishing_knowledge"
        # Initialize Embeddings (using same model as before)
        self.embedding_fn = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
        # Initialize Qdrant Client (Local mode)
        self.client = QdrantClient(path=self.persist_directory)
        # Initialize Vector Store wrapper
        self.vector_store = Qdrant(
            client=self.client,
            collection_name=self.collection_name,
            embeddings=self.embedding_fn,
        )
        # Build the index on first run (missing or empty collection).
        # Fix: the original used a bare `except:` (which also swallows
        # KeyboardInterrupt/SystemExit) and kept `_build_index()` inside the
        # `try`, so a failure during the first build triggered a second
        # build attempt from the handler. Only the count lookup can
        # legitimately fail here.
        try:
            count = self.client.count(collection_name=self.collection_name).count
        except Exception:
            # Collection does not exist yet.
            count = 0
        if count == 0:
            self._build_index()

    def _build_index(self):
        """Load documents, split them into chunks, and (re)build the Qdrant index."""
        print("🔄 Building Knowledge Base Index (Qdrant)...")
        documents = self._load_documents()
        if not documents:
            print("⚠️ No documents found to index.")
            return
        # Split documents into overlapping chunks sized for the embedding model.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
            separators=["\n\n", "\n", " ", ""]
        )
        chunks = text_splitter.split_documents(documents)
        if not chunks:
            print("⚠️ No chunks created.")
            return
        # Fix: embedded (local-path) Qdrant allows only ONE client per storage
        # directory. Release our handle before `from_documents` opens its own,
        # otherwise it fails with "already accessed by another instance".
        self.client.close()
        # `from_documents` with force_recreate=True replaces the collection,
        # giving a clean slate on every rebuild.
        self.vector_store = Qdrant.from_documents(
            chunks,
            self.embedding_fn,
            path=self.persist_directory,
            collection_name=self.collection_name,
            force_recreate=True
        )
        # Update the client reference after recreation
        self.client = self.vector_store.client
        print(f"✅ Indexed {len(chunks)} chunks from {len(documents)} documents.")

    def _load_documents(self):
        """Load documents from the knowledge-base directory or fallback file.

        Returns:
            list: loaded LangChain documents; empty list if nothing was found.
        """
        documents = []
        # Prefer the configured directory; fall back to a root-level text file.
        target_path = self.knowledge_base_dir
        if not os.path.exists(target_path):
            if os.path.exists("knowledge_base.txt"):
                target_path = "knowledge_base.txt"
                print("⚠️ Using fallback 'knowledge_base.txt' in root.")
            else:
                print(f"❌ Knowledge base not found at {target_path}")
                return []
        try:
            if os.path.isfile(target_path):
                # Single file: pick the loader by extension.
                if target_path.endswith(".pdf"):
                    loader = PyPDFLoader(target_path)
                else:
                    loader = TextLoader(target_path, encoding="utf-8")
                documents.extend(loader.load())
            else:
                # Directory: load .txt, .md and .pdf files recursively.
                loaders = [
                    DirectoryLoader(target_path, glob="**/*.txt", loader_cls=TextLoader, loader_kwargs={"encoding": "utf-8"}),
                    DirectoryLoader(target_path, glob="**/*.md", loader_cls=TextLoader, loader_kwargs={"encoding": "utf-8"}),
                    DirectoryLoader(target_path, glob="**/*.pdf", loader_cls=PyPDFLoader),
                ]
                for loader in loaders:
                    try:
                        documents.extend(loader.load())
                    except Exception as e:
                        # Best effort: one failing loader must not abort the rest.
                        print(f"⚠️ Error loading with {loader}: {e}")
        except Exception as e:
            print(f"❌ Error loading documents: {e}")
        return documents

    def refresh_knowledge_base(self):
        """Force a rebuild of the index and return a status message."""
        print("♻️ Refreshing Knowledge Base...")
        # _build_index uses force_recreate=True, so this replaces the collection.
        self._build_index()
        return "✅ Knowledge Base Refreshed!"

    def retrieve(self, query, n_results=3, use_mmr=True):
        """Retrieve relevant context for a query.

        Args:
            query: The query string.
            n_results: Number of results to return.
            use_mmr: Use Maximal Marginal Relevance search for diversity
                (True) or plain similarity search (False).

        Returns:
            list[str]: page contents of the matching chunks (possibly empty).
        """
        if use_mmr:
            results = self.vector_store.max_marginal_relevance_search(
                query,
                k=n_results,
                fetch_k=n_results * 3,  # over-fetch so MMR has candidates to diversify
                lambda_mult=0.6,
            )
        else:
            # Standard Similarity Search
            results = self.vector_store.similarity_search(query, k=n_results)
        # A comprehension over an empty result list already yields [], so no
        # separate emptiness guard is needed.
        return [doc.page_content for doc in results]