# PhishingTest / rag_engine.py
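"""RAG engine backed by Qdrant Cloud.

Indexes a local knowledge base (txt/md/pdf) plus the ealvaradob/phishing-dataset
from Hugging Face into a Qdrant collection, then retrieves relevant context
via similarity search.
"""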
import os

import requests
from langchain_community.document_loaders import DirectoryLoader, TextLoader, PyPDFLoader
from langchain_qdrant import Qdrant
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from qdrant_client import QdrantClient, models


class RAGEngine:
def __init__(self, knowledge_base_dir="./knowledge_base"):
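        """Connect to Qdrant Cloud, creating and populating the collection if needed.

        Credentials come from the QDRANT_URL / QDRANT_API_KEY environment
        variables, with hardcoded fallbacks below.
        """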
self.knowledge_base_dir = knowledge_base_dir
# Initialize Embeddings
self.embedding_fn = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
        # Qdrant Cloud configuration: prefer environment variables, fall back
        # to the hardcoded values below. NOTE: shipping credentials in source
        # is insecure; prefer setting QDRANT_URL / QDRANT_API_KEY as secrets.
        self.qdrant_url = os.environ.get("QDRANT_URL") or "https://abd29675-7fb9-4d95-8941-e6130b09bf7f.us-east4-0.gcp.cloud.qdrant.io"
        self.qdrant_api_key = os.environ.get("QDRANT_API_KEY") or "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.L0aAAAbxRypLfBeGCtFr2xX06iveGb76NrA3BPJQiNM"
self.collection_name = "phishing_knowledge"
        if not self.qdrant_url or not self.qdrant_api_key:
            print("⚠️ QDRANT_URL or QDRANT_API_KEY not set. RAG will not function correctly.")
            self.client = None
            self.vector_store = None
            return
print(f"☁️ Connecting to Qdrant Cloud: {self.qdrant_url}...")
# Initialize Qdrant Client
self.client = QdrantClient(
url=self.qdrant_url,
api_key=self.qdrant_api_key
)
# Initialize Vector Store Wrapper
self.vector_store = Qdrant(
client=self.client,
collection_name=self.collection_name,
embeddings=self.embedding_fn
)
# Check if collection exists/is empty and build if needed
try:
if not self.client.collection_exists(self.collection_name):
print(f"⚠️ Collection '{self.collection_name}' not found. Creating...")
self.client.create_collection(
collection_name=self.collection_name,
vectors_config=models.VectorParams(size=384, distance=models.Distance.COSINE)
)
print(f"✅ Collection '{self.collection_name}' created!")
self._build_index()
            else:
                # Collection exists; check whether the HF dataset is already indexed
                dataset_filter = models.Filter(
                    must=[
                        models.FieldCondition(
                            key="metadata.source",
                            match=models.MatchValue(value="hf_dataset")
                        )
                    ]
                )
                dataset_count = self.client.count(
                    collection_name=self.collection_name,
                    count_filter=dataset_filter
                ).count
                total_count = self.client.count(collection_name=self.collection_name).count
                print(f"✅ Qdrant Collection '{self.collection_name}' ready with {total_count} vectors.")
                if dataset_count == 0:
                    print("⚠️ Phishing dataset not found in collection. Loading...")
                    self.load_from_huggingface()
except Exception as e:
print(f"⚠️ Collection check/creation failed: {e}")
            # Fall back to building anyway; the vector-store wrapper may create the collection on demand
self._build_index()
self.load_from_huggingface()
def _build_index(self):
"""Load documents and build index"""
print("🔄 Building Knowledge Base Index on Qdrant Cloud...")
documents = self._load_documents()
if not documents:
print("⚠️ No documents found to index.")
return
# Split documents
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=50,
separators=["\n\n", "\n", " ", ""]
)
chunks = text_splitter.split_documents(documents)
if chunks:
# Add to vector store (Qdrant handles persistence automatically)
try:
self.vector_store.add_documents(chunks)
print(f"✅ Indexed {len(chunks)} chunks to Qdrant Cloud.")
except Exception as e:
print(f"❌ Error indexing to Qdrant: {e}")
else:
print("⚠️ No chunks created.")
def _load_documents(self):
"""Load documents from directory or fallback file"""
documents = []
# Check for directory or fallback file
target_path = self.knowledge_base_dir
if not os.path.exists(target_path):
if os.path.exists("knowledge_base.txt"):
target_path = "knowledge_base.txt"
print("⚠️ Using fallback 'knowledge_base.txt' in root.")
else:
print(f"❌ Knowledge base not found at {target_path}")
return []
try:
if os.path.isfile(target_path):
# Load single file
if target_path.endswith(".pdf"):
loader = PyPDFLoader(target_path)
else:
loader = TextLoader(target_path, encoding="utf-8")
documents.extend(loader.load())
else:
# Load directory
loaders = [
DirectoryLoader(target_path, glob="**/*.txt", loader_cls=TextLoader, loader_kwargs={"encoding": "utf-8"}),
DirectoryLoader(target_path, glob="**/*.md", loader_cls=TextLoader, loader_kwargs={"encoding": "utf-8"}),
DirectoryLoader(target_path, glob="**/*.pdf", loader_cls=PyPDFLoader),
]
for loader in loaders:
try:
docs = loader.load()
documents.extend(docs)
except Exception as e:
print(f"⚠️ Error loading with {loader}: {e}")
except Exception as e:
print(f"❌ Error loading documents: {e}")
return documents
def load_from_huggingface(self):
"""Load and index dataset manually from Hugging Face JSON"""
dataset_url = "https://huggingface.co/datasets/ealvaradob/phishing-dataset/resolve/main/combined_reduced.json"
print(f"📥 Downloading dataset from {dataset_url}...")
        try:
            # Timeout guards against a hung download; 60s is a generous limit
            response = requests.get(dataset_url, timeout=60)
            if response.status_code != 200:
                print(f"❌ Failed to download dataset: HTTP {response.status_code}")
                return
            data = response.json()
print(f"✅ Dataset downloaded. Processing {len(data)} rows...")
documents = []
for row in data:
# Structure: text, label
content = row.get('text', '')
label = row.get('label', -1)
if content:
doc = Document(
page_content=content,
metadata={"source": "hf_dataset", "label": label}
)
documents.append(doc)
if documents:
# Batch add to vector store
print(f"🔄 Indexing {len(documents)} documents to Qdrant...")
                # Larger chunk size for efficiency: most dataset texts are short, so few are actually split
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=100
)
chunks = text_splitter.split_documents(documents)
# Add in batches to avoid hitting API limits or timeouts
batch_size = 100
total_chunks = len(chunks)
for i in range(0, total_chunks, batch_size):
batch = chunks[i:i+batch_size]
try:
self.vector_store.add_documents(batch)
print(f" - Indexed batch {i//batch_size + 1}/{(total_chunks + batch_size - 1)//batch_size}")
                    except Exception as e:
                        print(f"  ⚠️ Error indexing batch {i//batch_size + 1}: {e}")
                print(f"✅ Finished indexing {total_chunks} chunks from dataset!")
else:
print("⚠️ No valid documents found in dataset.")
except Exception as e:
print(f"❌ Error loading HF dataset: {e}")
    def refresh_knowledge_base(self):
        """Force a full rebuild of the index"""
        print("♻️ Refreshing Knowledge Base...")
        if self.client:
            try:
                self.client.delete_collection(self.collection_name)
                # Recreate the collection before re-indexing: add_documents
                # expects it to exist
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=models.VectorParams(size=384, distance=models.Distance.COSINE)
                )
                self._build_index()
                self.load_from_huggingface()
                return "✅ Knowledge Base Refreshed on Cloud!"
            except Exception as e:
                return f"❌ Error refreshing: {e}"
        return "❌ Qdrant Client not initialized."
def retrieve(self, query, n_results=3):
"""Retrieve relevant context"""
if not self.vector_store:
return []
        # Search; an empty result set simply yields an empty list
        try:
            results = self.vector_store.similarity_search(query, k=n_results)
            return [doc.page_content for doc in results]
        except Exception as e:
            print(f"⚠️ Retrieval Error: {e}")
            return []
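

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module: it assumes
    # QDRANT_URL and QDRANT_API_KEY are set (or that the hardcoded fallbacks
    # above still work). The query string is purely illustrative.
    engine = RAGEngine()
    for snippet in engine.retrieve("urgent: verify your account now", n_results=3):
        print("---")
        print(snippet[:200])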