Kalpokoch committed
Commit f55e2f6 · verified · 1 Parent(s): 1267728

Update app/policy_vector_db.py

Files changed (1): app/policy_vector_db.py (+87 -9)
app/policy_vector_db.py CHANGED
@@ -8,6 +8,10 @@ from sentence_transformers import SentenceTransformer
 import chromadb
 from chromadb.config import Settings
 import logging
+import torch
+import multiprocessing as mp
+from concurrent.futures import ThreadPoolExecutor
+import numpy as np
 
 # --- Basic Logging Setup ---
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -16,20 +20,40 @@ logger = logging.getLogger(__name__)
 class PolicyVectorDB:
     """
     Enhanced vector database for policy documents with metadata-aware search capabilities.
+    Optimized for CPU utilization.
     """
     def __init__(self, persist_directory: str, top_k_default: int = 5, relevance_threshold: float = 0.5):
         self.persist_directory = persist_directory
         self.client = chromadb.PersistentClient(path=persist_directory, settings=Settings(allow_reset=True))
         self.collection_name = "neepco_dop_policies"
 
+        # Optimize CPU usage
+        self.cpu_count = mp.cpu_count()
+        torch.set_num_threads(self.cpu_count)
+
+        logger.info(f"Detected {self.cpu_count} CPU cores, optimizing threading...")
         logger.info("Loading embedding model 'BAAI/bge-large-en-v1.5'. This may take a moment...")
-        self.embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5', device='cpu')
+
+        # Load the embedding model for CPU inference
+        self.embedding_model = SentenceTransformer(
+            'BAAI/bge-large-en-v1.5',
+            device='cpu',
+            # Keep weights in float32, the safe default for CPU inference
+            model_kwargs={'torch_dtype': torch.float32}
+        )
+
+        # Cap the context window; shorter sequences encode faster on CPU
+        self.embedding_model.max_seq_length = 512
+
         logger.info("Embedding model loaded successfully.")
 
         self.collection = None
         self.top_k_default = top_k_default
         self.relevance_threshold = relevance_threshold
 
+        # Thread pool for parallel encoding across CPU cores
+        self.thread_pool = ThreadPoolExecutor(max_workers=self.cpu_count)
+
         # Add monetary normalization for queries
         self.money_patterns = {
             r'(\d+(?:,\d+)*(?:\.\d+)?)\s*crore': lambda x: float(x.replace(',', '')) * 1e7,
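A note on the threading setup above: torch.set_num_threads controls PyTorch's intra-op parallelism, so a single encode call already fans out across cores. A minimal standalone sketch of the same configuration (the sample sentence is illustrative only):

```python
import multiprocessing as mp

import torch
from sentence_transformers import SentenceTransformer

# Pin PyTorch's intra-op thread pool to the visible cores before loading the model.
torch.set_num_threads(mp.cpu_count())

model = SentenceTransformer('BAAI/bge-large-en-v1.5', device='cpu')
model.max_seq_length = 512  # inputs longer than 512 tokens are truncated

vec = model.encode(['sample policy sentence'], normalize_embeddings=True, convert_to_numpy=True)
print(vec.shape)  # (1, 1024) -- bge-large-en-v1.5 produces 1024-dim embeddings
```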
@@ -97,8 +121,40 @@ class PolicyVectorDB:
 
         return entities
 
+    def _encode_batch_parallel(self, texts: List[str]) -> np.ndarray:
+        """Parallel encoding of text batches for better CPU utilization."""
+        # Split texts into smaller batches for parallel processing
+        batch_size = max(1, len(texts) // self.cpu_count)
+        if len(texts) <= batch_size:
+            return self.embedding_model.encode(
+                texts,
+                normalize_embeddings=True,
+                show_progress_bar=False,
+                batch_size=32,  # Optimize batch size for CPU
+                convert_to_numpy=True
+            )
+
+        # Process in parallel batches
+        batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
+
+        def encode_batch(batch):
+            return self.embedding_model.encode(
+                batch,
+                normalize_embeddings=True,
+                show_progress_bar=False,
+                batch_size=16,
+                convert_to_numpy=True
+            )
+
+        # Use thread pool for parallel encoding
+        futures = [self.thread_pool.submit(encode_batch, batch) for batch in batches]
+        results = [future.result() for future in futures]
+
+        # Concatenate results
+        return np.vstack(results) if results else np.array([])
+
     def add_chunks(self, chunks: List[Dict]):
-        """Enhanced chunk addition with better metadata handling."""
+        """Enhanced chunk addition with better metadata handling and parallel processing."""
         collection = self._get_collection()
         if not chunks:
             logger.info("No chunks provided to add.")
@@ -119,7 +175,9 @@
 
         logger.info(f"Adding {len(new_chunks)} new chunks to the vector database...")
 
-        batch_size = 32
+        # Scale the write batch size with the workload, clamped to [16, 64]
+        batch_size = min(64, max(16, len(new_chunks) // 4))
+
         for i in range(0, len(new_chunks), batch_size):
             batch = new_chunks[i:i + batch_size]
 
@@ -127,7 +185,8 @@
             texts = [chunk['text'] for chunk in batch]
             metadatas = [self._flatten_metadata(chunk.get('metadata', {})) for chunk in batch]
 
-            embeddings = self.embedding_model.encode(texts, normalize_embeddings=True, show_progress_bar=False).tolist()
+            # Use parallel encoding
+            embeddings = self._encode_batch_parallel(texts).tolist()
 
             collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)
             logger.info(f"Added batch {i//batch_size + 1}/{(len(new_chunks) + batch_size - 1) // batch_size}")
@@ -137,6 +196,7 @@
     def search(self, query_text: str, top_k: int = None, filters: Dict = None) -> List[Dict]:
         """
         Enhanced search with metadata filtering and entity extraction.
+        Optimized for CPU performance.
         """
         collection = self._get_collection()
 
@@ -158,7 +218,15 @@
             where_conditions["section"] = {"$in": [s.split()[-1] for s in entities['sections']]}
 
         instructed_query = f"Represent this sentence for searching relevant passages: {query_text}"
-        query_embedding = self.embedding_model.encode([instructed_query], normalize_embeddings=True).tolist()
+
+        # Optimized single query encoding
+        query_embedding = self.embedding_model.encode(
+            [instructed_query],
+            normalize_embeddings=True,
+            show_progress_bar=False,
+            batch_size=1,
+            convert_to_numpy=True
+        ).tolist()
 
         k = top_k if top_k is not None else self.top_k_default
 
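The "Represent this sentence for searching relevant passages:" prefix is the query instruction the BGE v1.5 English models were trained with; it is applied to queries only, never to the stored passages. A sketch of the round trip, reusing the model from the earlier example (the passage text is a made-up placeholder):

```python
query = "approval limit for capital works"
instructed = f"Represent this sentence for searching relevant passages: {query}"

q_vec = model.encode([instructed], normalize_embeddings=True, convert_to_numpy=True)
p_vec = model.encode(["Hypothetical passage: capital works require board approval."],
                     normalize_embeddings=True, convert_to_numpy=True)

# With normalized vectors, the dot product is the cosine similarity.
score = float(q_vec @ p_vec.T)
```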
@@ -176,14 +244,13 @@
 
         search_results = []
         if results and results.get('documents') and results['documents'][0]:
-            for i, doc in enumerate(results['documents'][0]): # Fixed: iterate over results['documents'][0]
-                # Fixed: Access distances correctly as results['distances'][0][i]
+            for i, doc in enumerate(results['documents'][0]):
                 relevance_score = 1 - results['distances'][0][i]
 
                 if relevance_score >= self.relevance_threshold:
                     result = {
                         'text': doc,
-                        'metadata': results['metadatas'][0][i], # Fixed: Access metadata correctly
+                        'metadata': results['metadatas'][0][i],
                         'relevance_score': relevance_score
                     }
 
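The 1 - distance conversion treats Chroma's reported distance as a cosine distance, so relevance_score is effectively a cosine similarity (this assumes the collection was created with the cosine space; with Chroma's default L2 space the scale would differ). A worked instance of the threshold check:

```python
distance = 0.38                   # example value as returned by collection.query
relevance_score = 1 - distance    # 0.62
print(relevance_score >= 0.5)     # True -> the hit survives the default threshold
```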
@@ -273,6 +340,11 @@
             logger.warning(f"Error in search_by_amount: {e}")
             return []
 
+    def __del__(self):
+        """Cleanup thread pool on deletion."""
+        if hasattr(self, 'thread_pool'):
+            self.thread_pool.shutdown(wait=False)
+
 def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str) -> bool:
     """Checks if the DB is empty and populates it from a JSONL file if needed."""
     try:
@@ -297,7 +369,13 @@ def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str) -> bool:
             logger.warning(f"Chunks file at '{chunks_file_path}' is empty or invalid. No data to add.")
             return False
 
-        db_instance.add_chunks(chunks_to_add)
+        # Process in batches to avoid memory issues
+        batch_size = 500
+        for i in range(0, len(chunks_to_add), batch_size):
+            batch = chunks_to_add[i:i + batch_size]
+            db_instance.add_chunks(batch)
+            logger.info(f"Processed batch {i//batch_size + 1}/{(len(chunks_to_add) + batch_size - 1) // batch_size}")
+
         logger.info("Vector database population attempt complete.")
         return True
     except Exception as e:
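End to end, population and search now look like this; a usage sketch with hypothetical paths and query (only PolicyVectorDB and ensure_db_populated come from this file):

```python
db = PolicyVectorDB(persist_directory="./chroma_db")  # path is illustrative

if ensure_db_populated(db, "data/chunks.jsonl"):      # hypothetical chunks file
    for hit in db.search("tender committee approval limit", top_k=3):
        print(round(hit['relevance_score'], 3), hit['text'][:80])
```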
 