dungeon29 committed on
Commit
f961bd3
·
verified ·
1 Parent(s): cfcd4e2

Update rag_engine.py

Browse files
Files changed (1) hide show
  1. rag_engine.py +63 -33
rag_engine.py CHANGED
@@ -1,33 +1,58 @@
1
  import os
2
  import glob
3
- from langchain_community.document_loaders import DirectoryLoader, TextLoader, PyPDFLoader, JSONLoader
4
- from langchain_community.vectorstores import Chroma
5
  from langchain_huggingface import HuggingFaceEmbeddings
6
  from langchain_text_splitters import RecursiveCharacterTextSplitter
7
- from langchain_core.documents import Document
8
 
9
  class RAGEngine:
10
- def __init__(self, knowledge_base_dir="./knowledge_base", persist_directory="./chroma_db"):
11
  self.knowledge_base_dir = knowledge_base_dir
12
- self.persist_directory = persist_directory
13
 
14
- # Initialize Embeddings (using same model as before)
15
  self.embedding_fn = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
16
 
17
- # Initialize Vector Store
18
- self.vector_store = Chroma(
19
- persist_directory=self.persist_directory,
20
- embedding_function=self.embedding_fn,
21
- collection_name="phishing_knowledge"
 
 
 
 
 
 
 
 
 
 
 
 
22
  )
23
 
24
- # Build index if empty or on init
25
- if not self.vector_store.get()['ids']:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  self._build_index()
27
 
28
  def _build_index(self):
29
  """Load documents and build index"""
30
- print("🔄 Building Knowledge Base Index...")
31
 
32
  documents = self._load_documents()
33
  if not documents:
@@ -43,10 +68,12 @@ class RAGEngine:
43
  chunks = text_splitter.split_documents(documents)
44
 
45
  if chunks:
46
- # Add to vector store
47
- self.vector_store.add_documents(chunks)
48
- self.vector_store.persist()
49
- print(f"✅ Indexed {len(chunks)} chunks from {len(documents)} documents.")
 
 
50
  else:
51
  print("⚠️ No chunks created.")
52
 
@@ -95,23 +122,26 @@ class RAGEngine:
95
  def refresh_knowledge_base(self):
96
  """Force rebuild of the index"""
97
  print("♻️ Refreshing Knowledge Base...")
98
- # Clear existing collection
99
- self.vector_store.delete_collection()
100
- self.vector_store = Chroma(
101
- persist_directory=self.persist_directory,
102
- embedding_function=self.embedding_fn,
103
- collection_name="phishing_knowledge"
104
- )
105
- # Rebuild
106
- self._build_index()
107
- return "✅ Knowledge Base Refreshed!"
108
 
109
  def retrieve(self, query, n_results=3):
110
  """Retrieve relevant context"""
 
 
 
111
  # Search
112
- results = self.vector_store.similarity_search(query, k=n_results)
113
-
114
- # Format results
115
- if results:
116
- return [doc.page_content for doc in results]
 
 
117
  return []
 
1
  import os
2
  import glob
3
+ from langchain_community.document_loaders import DirectoryLoader, TextLoader, PyPDFLoader
4
+ from langchain_community.vectorstores import Qdrant
5
  from langchain_huggingface import HuggingFaceEmbeddings
6
  from langchain_text_splitters import RecursiveCharacterTextSplitter
7
+ from qdrant_client import QdrantClient
8
 
9
  class RAGEngine:
10
+ def __init__(self, knowledge_base_dir="./knowledge_base"):
11
  self.knowledge_base_dir = knowledge_base_dir
 
12
 
13
+ # Initialize Embeddings
14
  self.embedding_fn = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
15
 
16
+ # Qdrant Cloud Configuration
17
+ # Prioritize Env Vars, fallback to Hardcoded (User provided)
18
+ self.qdrant_url = os.environ.get("QDRANT_URL") or "https://abd29675-7fb9-4d95-8941-e6130b09bf7f.us-east4-0.gcp.cloud.qdrant.io"
19
+ self.qdrant_api_key = os.environ.get("QDRANT_API_KEY") or "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.L0aAAAbxRypLfBeGCtFr2xX06iveGb76NrA3BPJQiNM"
20
+ self.collection_name = "phishing_knowledge"
21
+
22
+ if not self.qdrant_url or not self.qdrant_api_key:
23
+ print("⚠️ QDRANT_URL or QDRANT_API_KEY not set. RAG will not function correctly.")
24
+ self.vector_store = None
25
+ return
26
+
27
+ print(f"☁️ Connecting to Qdrant Cloud: {self.qdrant_url}...")
28
+
29
+ # Initialize Qdrant Client
30
+ self.client = QdrantClient(
31
+ url=self.qdrant_url,
32
+ api_key=self.qdrant_api_key
33
  )
34
 
35
+ # Initialize Vector Store Wrapper
36
+ self.vector_store = Qdrant(
37
+ client=self.client,
38
+ collection_name=self.collection_name,
39
+ embeddings=self.embedding_fn
40
+ )
41
+
42
+ # Check if collection exists/is empty and build if needed
43
+ try:
44
+ count = self.client.count(collection_name=self.collection_name).count
45
+ if count == 0:
46
+ self._build_index()
47
+ else:
48
+ print(f"✅ Qdrant Collection '{self.collection_name}' ready with {count} vectors.")
49
+ except Exception as e:
50
+ print(f"⚠️ Collection check failed (might not exist): {e}")
51
  self._build_index()
52
 
53
  def _build_index(self):
54
  """Load documents and build index"""
55
+ print("🔄 Building Knowledge Base Index on Qdrant Cloud...")
56
 
57
  documents = self._load_documents()
58
  if not documents:
 
68
  chunks = text_splitter.split_documents(documents)
69
 
70
  if chunks:
71
+ # Add to vector store (Qdrant handles persistence automatically)
72
+ try:
73
+ self.vector_store.add_documents(chunks)
74
+ print(f"✅ Indexed {len(chunks)} chunks to Qdrant Cloud.")
75
+ except Exception as e:
76
+ print(f"❌ Error indexing to Qdrant: {e}")
77
  else:
78
  print("⚠️ No chunks created.")
79
 
 
122
  def refresh_knowledge_base(self):
123
  """Force rebuild of the index"""
124
  print("♻️ Refreshing Knowledge Base...")
125
+ if self.client:
126
+ try:
127
+ self.client.delete_collection(self.collection_name)
128
+ self._build_index()
129
+ return "✅ Knowledge Base Refreshed on Cloud!"
130
+ except Exception as e:
131
+ return f"❌ Error refreshing: {e}"
132
+ return "❌ Qdrant Client not initialized."
 
 
133
 
134
  def retrieve(self, query, n_results=3):
135
  """Retrieve relevant context"""
136
+ if not self.vector_store:
137
+ return []
138
+
139
  # Search
140
+ try:
141
+ results = self.vector_store.similarity_search(query, k=n_results)
142
+ if results:
143
+ return [doc.page_content for doc in results]
144
+ except Exception as e:
145
+ print(f"⚠️ Retrieval Error: {e}")
146
+
147
  return []