import os
import glob
from langchain_community.document_loaders import DirectoryLoader, TextLoader, PyPDFLoader
from langchain_qdrant import Qdrant
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from qdrant_client import QdrantClient, models
from datasets import load_dataset

class RAGEngine:
    def __init__(self, knowledge_base_dir="./knowledge_base"):
        self.knowledge_base_dir = knowledge_base_dir

        # Initialize embeddings
        self.embedding_fn = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

        # Qdrant Cloud configuration
        # Prioritize env vars, fall back to hardcoded values (user provided)
        self.qdrant_url = os.environ.get("QDRANT_URL") or "https://abd29675-7fb9-4d95-8941-e6130b09bf7f.us-east4-0.gcp.cloud.qdrant.io"
        self.qdrant_api_key = os.environ.get("QDRANT_API_KEY") or "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.L0aAAAbxRypLfBeGCtFr2xX06iveGb76NrA3BPJQiNM"
        self.collection_name = "phishing_knowledge"

        if not self.qdrant_url or not self.qdrant_api_key:
            print("⚠️ QDRANT_URL or QDRANT_API_KEY not set. RAG will not function correctly.")
            # Leave both handles unset so later calls (retrieve, refresh) can no-op safely
            self.client = None
            self.vector_store = None
            return
| print(f"☁️ Connecting to Qdrant Cloud: {self.qdrant_url}...") | |
| # Initialize Qdrant Client | |
| self.client = QdrantClient( | |
| url=self.qdrant_url, | |
| api_key=self.qdrant_api_key | |
| ) | |
| # Initialize Vector Store Wrapper | |
| self.vector_store = Qdrant( | |
| client=self.client, | |
| collection_name=self.collection_name, | |
| embeddings=self.embedding_fn | |
| ) | |
| # Check if collection exists/is empty and build if needed | |
| try: | |
| if not self.client.collection_exists(self.collection_name): | |
| print(f"⚠️ Collection '{self.collection_name}' not found. Creating...") | |
| self.client.create_collection( | |
| collection_name=self.collection_name, | |
| vectors_config=models.VectorParams(size=384, distance=models.Distance.COSINE) | |
| ) | |
| print(f"✅ Collection '{self.collection_name}' created!") | |
| self._build_index() | |
| else: | |
| # Check if dataset is already indexed | |
| dataset_filter = models.Filter( | |
| must=[ | |
| models.FieldCondition( | |
| key="metadata.source", | |
| match=models.MatchValue(value="hf_dataset") | |
| ) | |
| ] | |
| ) | |
| dataset_count = self.client.count( | |
| collection_name=self.collection_name, | |
| count_filter=dataset_filter | |
| ).count | |
| print(f"✅ Qdrant Collection '{self.collection_name}' ready with {count} vectors.") | |
                if dataset_count == 0:
                    print("⚠️ Phishing dataset not found in collection. Loading...")
                    self.load_from_huggingface()
        except Exception as e:
            print(f"⚠️ Collection check/creation failed: {e}")
            # Try to build anyway; the wrapper may handle it
            self._build_index()
            self.load_from_huggingface()

    def _build_index(self):
        """Load documents and build index"""
        print("🔄 Building Knowledge Base Index on Qdrant Cloud...")
        documents = self._load_documents()
        if not documents:
            print("⚠️ No documents found to index.")
            return

        # Split documents
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
            separators=["\n\n", "\n", " ", ""]
        )
        chunks = text_splitter.split_documents(documents)

        if chunks:
            # Add to vector store (Qdrant handles persistence automatically)
            try:
                self.vector_store.add_documents(chunks)
                print(f"✅ Indexed {len(chunks)} chunks to Qdrant Cloud.")
            except Exception as e:
                print(f"❌ Error indexing to Qdrant: {e}")
        else:
            print("⚠️ No chunks created.")

    def _load_documents(self):
        """Load documents from directory or fallback file"""
        documents = []

        # Check for directory or fallback file
        target_path = self.knowledge_base_dir
        if not os.path.exists(target_path):
            if os.path.exists("knowledge_base.txt"):
                target_path = "knowledge_base.txt"
                print("⚠️ Using fallback 'knowledge_base.txt' in root.")
            else:
                print(f"❌ Knowledge base not found at {target_path}")
                return []

        try:
            if os.path.isfile(target_path):
                # Load a single file
                if target_path.endswith(".pdf"):
                    loader = PyPDFLoader(target_path)
                else:
                    loader = TextLoader(target_path, encoding="utf-8")
                documents.extend(loader.load())
            else:
                # Load the directory
                loaders = [
                    DirectoryLoader(target_path, glob="**/*.txt", loader_cls=TextLoader, loader_kwargs={"encoding": "utf-8"}),
                    DirectoryLoader(target_path, glob="**/*.md", loader_cls=TextLoader, loader_kwargs={"encoding": "utf-8"}),
                    DirectoryLoader(target_path, glob="**/*.pdf", loader_cls=PyPDFLoader),
                ]
                for loader in loaders:
                    try:
                        docs = loader.load()
                        documents.extend(docs)
                    except Exception as e:
                        print(f"⚠️ Error loading with {loader}: {e}")
        except Exception as e:
            print(f"❌ Error loading documents: {e}")
        return documents

    def load_from_huggingface(self):
        """Load and index dataset manually from Hugging Face JSON"""
        dataset_url = "https://huggingface.co/datasets/ealvaradob/phishing-dataset/resolve/main/combined_reduced.json"
        print(f"📥 Downloading dataset from {dataset_url}...")
        try:
            import requests

            response = requests.get(dataset_url)
            if response.status_code != 200:
                print(f"❌ Failed to download dataset: {response.status_code}")
                return
            data = response.json()
            print(f"✅ Dataset downloaded. Processing {len(data)} rows...")

            documents = []
            for row in data:
                # Each row has the structure: text, label
                content = row.get('text', '')
                label = row.get('label', -1)
                if content:
                    doc = Document(
                        page_content=content,
                        metadata={"source": "hf_dataset", "label": label}
                    )
                    documents.append(doc)

            if documents:
                # Batch add to vector store
                print(f"🔄 Indexing {len(documents)} documents to Qdrant...")
                # Use a larger chunk size for efficiency since these are likely short texts
                text_splitter = RecursiveCharacterTextSplitter(
                    chunk_size=1000,
                    chunk_overlap=100
                )
                chunks = text_splitter.split_documents(documents)

                # Add in batches to avoid hitting API limits or timeouts
                batch_size = 100
                total_chunks = len(chunks)
                for i in range(0, total_chunks, batch_size):
                    batch = chunks[i:i + batch_size]
                    try:
                        self.vector_store.add_documents(batch)
                        print(f" - Indexed batch {i // batch_size + 1}/{(total_chunks + batch_size - 1) // batch_size}")
                    except Exception as e:
                        print(f" ⚠️ Error indexing batch {i}: {e}")
                print(f"✅ Successfully indexed {total_chunks} chunks from dataset!")
            else:
                print("⚠️ No valid documents found in dataset.")
        except Exception as e:
            print(f"❌ Error loading HF dataset: {e}")

    def refresh_knowledge_base(self):
        """Force rebuild of the index"""
        print("♻️ Refreshing Knowledge Base...")
        if self.client:
            try:
                self.client.delete_collection(self.collection_name)
                # Recreate the collection before re-indexing; add_documents expects it to exist
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=models.VectorParams(size=384, distance=models.Distance.COSINE)
                )
                self._build_index()
                self.load_from_huggingface()
                return "✅ Knowledge Base Refreshed on Cloud!"
            except Exception as e:
                return f"❌ Error refreshing: {e}"
        return "❌ Qdrant Client not initialized."

    def retrieve(self, query, n_results=3):
        """Retrieve relevant context"""
        if not self.vector_store:
            return []
        # Search
        try:
            results = self.vector_store.similarity_search(query, k=n_results)
            if results:
                return [doc.page_content for doc in results]
        except Exception as e:
            print(f"⚠️ Retrieval Error: {e}")
        return []
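

# Minimal usage sketch (not part of the original module): it assumes QDRANT_URL and
# QDRANT_API_KEY are set in the environment and that ./knowledge_base exists. The
# sample query string below is purely illustrative.
if __name__ == "__main__":
    engine = RAGEngine(knowledge_base_dir="./knowledge_base")

    # Optionally wipe and rebuild the collection from local docs plus the HF dataset:
    # print(engine.refresh_knowledge_base())

    # Retrieve the top-3 most similar chunks for an example phishing-style query
    query = "Verify your account immediately or it will be suspended"
    for idx, chunk in enumerate(engine.retrieve(query, n_results=3), start=1):
        print(f"[{idx}] {chunk[:200]}")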