PhishingTest / rag_engine.py
dungeon29's picture
Update rag_engine.py
6c05eaf verified
raw
history blame
5.65 kB
import os
import glob
from langchain_community.document_loaders import DirectoryLoader, TextLoader, PyPDFLoader, JSONLoader
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient
from qdrant_client.http import models
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
class RAGEngine:
    """Local RAG engine backed by a Qdrant (embedded/local-mode) vector store.

    Loads .txt/.md/.pdf documents from a knowledge-base directory (or a
    fallback single file), chunks them, embeds them with a HuggingFace
    sentence-transformer, and serves similarity / MMR retrieval.
    """

    def __init__(self, knowledge_base_dir="./knowledge_base", persist_directory="./qdrant_db"):
        """
        Args:
            knowledge_base_dir: Directory of source documents (falls back to
                a single 'knowledge_base.txt' in the CWD if missing).
            persist_directory: On-disk path for Qdrant local-mode storage.
        """
        self.knowledge_base_dir = knowledge_base_dir
        self.persist_directory = persist_directory
        self.collection_name = "phishing_knowledge"

        # Initialize Embeddings (using same model as before)
        self.embedding_fn = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

        # Initialize Qdrant Client (Local mode). NOTE: local mode permits only
        # ONE client per path — any rebuild must close this client first.
        self.client = QdrantClient(path=self.persist_directory)

        # Initialize Vector Store wrapper
        self.vector_store = Qdrant(
            client=self.client,
            collection_name=self.collection_name,
            embeddings=self.embedding_fn,
        )

        # Build the index on first run (empty or missing collection).
        # Catch Exception, not a bare except, so Ctrl-C / SystemExit still work.
        try:
            count = self.client.count(collection_name=self.collection_name).count
            if count == 0:
                self._build_index()
        except Exception:
            # Collection might not exist yet
            self._build_index()

    def _build_index(self):
        """Load documents, chunk them, and (re)build the Qdrant collection."""
        print("🔄 Building Knowledge Base Index (Qdrant)...")
        documents = self._load_documents()
        if not documents:
            print("⚠️ No documents found to index.")
            return

        # Split documents into overlapping chunks suited to short retrieval hits.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
            separators=["\n\n", "\n", " ", ""],
        )
        chunks = text_splitter.split_documents(documents)

        if chunks:
            # Qdrant local mode locks the storage path to a single client.
            # `from_documents(path=...)` opens a NEW client on the same path,
            # so release the current one first or the rebuild raises
            # "Storage folder ... is already accessed by another instance".
            try:
                self.client.close()
            except Exception:
                pass  # best-effort: client may already be closed/unusable

            # Re-create collection to ensure a clean slate.
            self.vector_store = Qdrant.from_documents(
                chunks,
                self.embedding_fn,
                path=self.persist_directory,
                collection_name=self.collection_name,
                force_recreate=True,
            )
            # Update the client reference after recreation
            self.client = self.vector_store.client
            print(f"✅ Indexed {len(chunks)} chunks from {len(documents)} documents.")
        else:
            print("⚠️ No chunks created.")

    def _load_documents(self):
        """Load documents from the knowledge-base directory or fallback file.

        Returns:
            list[Document]: Loaded documents; empty list if nothing was found.
        """
        documents = []

        # Check for directory or fallback file
        target_path = self.knowledge_base_dir
        if not os.path.exists(target_path):
            if os.path.exists("knowledge_base.txt"):
                target_path = "knowledge_base.txt"
                print("⚠️ Using fallback 'knowledge_base.txt' in root.")
            else:
                print(f"❌ Knowledge base not found at {target_path}")
                return []

        try:
            if os.path.isfile(target_path):
                # Load single file (PDF gets its own loader; everything else is text)
                if target_path.endswith(".pdf"):
                    loader = PyPDFLoader(target_path)
                else:
                    loader = TextLoader(target_path, encoding="utf-8")
                documents.extend(loader.load())
            else:
                # Load directory: one recursive loader per supported extension.
                loaders = [
                    DirectoryLoader(target_path, glob="**/*.txt", loader_cls=TextLoader, loader_kwargs={"encoding": "utf-8"}),
                    DirectoryLoader(target_path, glob="**/*.md", loader_cls=TextLoader, loader_kwargs={"encoding": "utf-8"}),
                    DirectoryLoader(target_path, glob="**/*.pdf", loader_cls=PyPDFLoader),
                ]
                for loader in loaders:
                    try:
                        docs = loader.load()
                        documents.extend(docs)
                    except Exception as e:
                        # Best-effort: a failure for one extension shouldn't
                        # abort loading of the others.
                        print(f"⚠️ Error loading with {loader}: {e}")
        except Exception as e:
            print(f"❌ Error loading documents: {e}")

        return documents

    def refresh_knowledge_base(self):
        """Force a full rebuild of the index.

        Returns:
            str: A human-readable status message.
        """
        print("♻️ Refreshing Knowledge Base...")
        # In Qdrant local, we can just rebuild with force_recreate=True which is handled in _build_index
        self._build_index()
        return "✅ Knowledge Base Refreshed!"

    def retrieve(self, query, n_results=3, use_mmr=True):
        """Retrieve relevant context for a query.

        Args:
            query: The query string.
            n_results: Number of results to return.
            use_mmr: Use MMR (True) or plain similarity search (False).

        Returns:
            list[str]: Page contents of the retrieved chunks (may be empty).
        """
        if use_mmr:
            # Max Marginal Relevance: over-fetch 3x, then trade off relevance
            # vs. diversity (lambda_mult=0.6 leans slightly toward relevance).
            results = self.vector_store.max_marginal_relevance_search(
                query,
                k=n_results,
                fetch_k=n_results * 3,
                lambda_mult=0.6,
            )
        else:
            # Standard Similarity Search
            results = self.vector_store.similarity_search(query, k=n_results)

        # Format results
        if results:
            return [doc.page_content for doc in results]
        return []