# NOTE: the lines originally at the top of this file were web-page residue from a
# Hugging Face Space viewer (status "Paused", file size 6,486 bytes, git-blame
# commit hashes fc74095 / f961bd3 / e96edf2 / 9e66bad, and a 1-158 line-number
# gutter). They are preserved as this comment so the module remains importable.
import os
import glob
from langchain_community.document_loaders import DirectoryLoader, TextLoader, PyPDFLoader
from langchain_qdrant import Qdrant
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from qdrant_client import QdrantClient, models
class RAGEngine:
    """Retrieval-augmented-generation helper backed by a Qdrant Cloud collection.

    On construction it embeds with a HuggingFace sentence-transformer model,
    connects to Qdrant, and makes sure the collection exists and is populated
    from ``knowledge_base_dir`` (``.txt``, ``.md`` and ``.pdf`` files).
    If Qdrant credentials are missing, the engine degrades gracefully:
    ``retrieve`` returns ``[]`` and ``refresh_knowledge_base`` reports that the
    client is not initialised.
    """

    # Embedding model and its output dimension. These two must change together:
    # the Qdrant collection is created with vectors of exactly EMBEDDING_DIM.
    EMBEDDING_MODEL = "all-MiniLM-L6-v2"
    EMBEDDING_DIM = 384

    def __init__(
        self,
        knowledge_base_dir="./knowledge_base",
        *,
        qdrant_url=None,
        qdrant_api_key=None,
        collection_name="phishing_knowledge",
    ):
        """Initialise embeddings, connect to Qdrant and ensure the index exists.

        Args:
            knowledge_base_dir: Directory (or single file) to index.
            qdrant_url: Qdrant endpoint; falls back to the ``QDRANT_URL`` env var.
            qdrant_api_key: Qdrant API key; falls back to ``QDRANT_API_KEY``.
            collection_name: Name of the Qdrant collection to use.
        """
        self.knowledge_base_dir = knowledge_base_dir
        self.collection_name = collection_name
        # Always define these so the other methods can test them safely, even
        # when we return early without a Qdrant connection.
        self.client = None
        self.vector_store = None

        self.embedding_fn = HuggingFaceEmbeddings(model_name=self.EMBEDDING_MODEL)

        # Credentials come from explicit arguments or the environment only.
        # Never hardcode them in source: a key that was committed here is public
        # and must be rotated.
        self.qdrant_url = qdrant_url or os.environ.get("QDRANT_URL")
        self.qdrant_api_key = qdrant_api_key or os.environ.get("QDRANT_API_KEY")
        if not self.qdrant_url or not self.qdrant_api_key:
            print("⚠️ QDRANT_URL or QDRANT_API_KEY not set. RAG will not function correctly.")
            return

        print(f"☁️ Connecting to Qdrant Cloud: {self.qdrant_url}...")
        self.client = QdrantClient(url=self.qdrant_url, api_key=self.qdrant_api_key)
        self.vector_store = Qdrant(
            client=self.client,
            collection_name=self.collection_name,
            embeddings=self.embedding_fn,
        )
        self._ensure_index()

    def _create_collection(self):
        """Create the collection with vector parameters matching the embedding model."""
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=models.VectorParams(
                size=self.EMBEDDING_DIM, distance=models.Distance.COSINE
            ),
        )

    def _ensure_index(self):
        """Create the collection if it is missing and index documents if it is empty."""
        try:
            if self.client.collection_exists(self.collection_name):
                count = self.client.count(collection_name=self.collection_name).count
            else:
                print(f"⚠️ Collection '{self.collection_name}' not found. Creating...")
                self._create_collection()
                print(f"✅ Collection '{self.collection_name}' created!")
                count = 0
        except Exception as e:
            print(f"⚠️ Collection check/creation failed: {e}")
            # Best effort, as before: the collection may still be usable, and
            # _build_index reports its own failure if it is not.
            count = 0

        # Build outside the try so a build failure cannot trigger a second build.
        if count == 0:
            self._build_index()
        else:
            print(f"✅ Qdrant Collection '{self.collection_name}' ready with {count} vectors.")

    def _build_index(self):
        """Load, chunk and upload the knowledge base to the Qdrant collection.

        Returns:
            True if chunks were uploaded, False if there was nothing to index or
            the upload failed. Failures are printed rather than raised.
        """
        print("🔄 Building Knowledge Base Index on Qdrant Cloud...")
        documents = self._load_documents()
        if not documents:
            print("⚠️ No documents found to index.")
            return False

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
            separators=["\n\n", "\n", " ", ""],
        )
        chunks = text_splitter.split_documents(documents)
        if not chunks:
            print("⚠️ No chunks created.")
            return False

        # The wrapper only upserts points; the collection must already exist.
        try:
            self.vector_store.add_documents(chunks)
        except Exception as e:
            print(f"❌ Error indexing to Qdrant: {e}")
            return False
        print(f"✅ Indexed {len(chunks)} chunks to Qdrant Cloud.")
        return True

    def _load_documents(self):
        """Load documents from the knowledge-base path.

        If ``knowledge_base_dir`` does not exist, a ``knowledge_base.txt`` in the
        current working directory is used instead. A single file is loaded as PDF
        or UTF-8 text by extension. A directory is scanned recursively for
        ``.txt``, ``.md`` and ``.pdf`` files.

        Returns:
            A list of LangChain documents (possibly empty). Errors are printed,
            not raised.
        """
        documents = []

        target_path = self.knowledge_base_dir
        if not os.path.exists(target_path):
            if os.path.exists("knowledge_base.txt"):
                target_path = "knowledge_base.txt"
                print("⚠️ Using fallback 'knowledge_base.txt' in root.")
            else:
                print(f"❌ Knowledge base not found at {target_path}")
                return []

        try:
            if os.path.isfile(target_path):
                if target_path.lower().endswith(".pdf"):
                    loader = PyPDFLoader(target_path)
                else:
                    loader = TextLoader(target_path, encoding="utf-8")
                documents.extend(loader.load())
            else:
                # One glob per file type, so one failing type does not block the others.
                specs = (
                    ("**/*.txt", TextLoader, {"encoding": "utf-8"}),
                    ("**/*.md", TextLoader, {"encoding": "utf-8"}),
                    ("**/*.pdf", PyPDFLoader, {}),
                )
                for pattern, loader_cls, loader_kwargs in specs:
                    try:
                        loader = DirectoryLoader(
                            target_path,
                            glob=pattern,
                            loader_cls=loader_cls,
                            loader_kwargs=loader_kwargs,
                        )
                        documents.extend(loader.load())
                    except Exception as e:
                        print(f"⚠️ Error loading '{pattern}' files: {e}")
        except Exception as e:
            print(f"❌ Error loading documents: {e}")
        return documents

    def refresh_knowledge_base(self):
        """Drop, recreate and re-populate the collection.

        Returns:
            A human-readable status string (success or error).
        """
        print("♻️ Refreshing Knowledge Base...")
        if not self.client:
            return "❌ Qdrant Client not initialized."
        try:
            self.client.delete_collection(self.collection_name)
            # Deleting removes the schema too; recreate it, because the vector-store
            # wrapper only upserts and would otherwise fail with "not found".
            self._create_collection()
        except Exception as e:
            return f"❌ Error refreshing: {e}"
        if self._build_index():
            return "✅ Knowledge Base Refreshed on Cloud!"
        return "❌ Error refreshing: no documents were indexed (see logs above)."

    def retrieve(self, query, n_results=3):
        """Return the page contents of the ``n_results`` chunks most similar to ``query``.

        Returns an empty list when the engine has no vector store, when
        ``n_results`` is not positive, or when the search fails.
        """
        if not self.vector_store or n_results <= 0:
            return []
        try:
            results = self.vector_store.similarity_search(query, k=n_results)
        except Exception as e:
            print(f"⚠️ Retrieval Error: {e}")
            return []
        return [doc.page_content for doc in results or []]
# (end of file; a stray "|" left over from the web page layout was removed here)