Kalpokoch commited on
Commit
c5aeabe
·
verified ·
1 Parent(s): 67708c4

Update app/policy_vector_db.py

Browse files
Files changed (1) hide show
  1. app/policy_vector_db.py +16 -5
app/policy_vector_db.py CHANGED
@@ -23,7 +23,9 @@ class PolicyVectorDB:
23
 
24
  # Using a powerful open-source embedding model.
25
  # Change 'cpu' to 'cuda' if a GPU is available for significantly faster embedding.
 
26
  self.embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5', device='cpu')
 
27
 
28
  self.collection = None # Initialize collection as None for lazy loading
29
  self.top_k_default = top_k_default
@@ -69,7 +71,7 @@ class PolicyVectorDB:
69
  logger.info(f"Adding {len(new_chunks)} new chunks to the vector database...")
70
 
71
  # Process in batches for efficiency
72
- batch_size = 64
73
  for i in range(0, len(new_chunks), batch_size):
74
  batch = new_chunks[i:i + batch_size]
75
 
@@ -77,7 +79,8 @@ class PolicyVectorDB:
77
  texts = [chunk['text'] for chunk in batch]
78
  metadatas = [self._flatten_metadata(chunk.get('metadata', {})) for chunk in batch]
79
 
80
- embeddings = self.embedding_model.encode(texts, show_progress_bar=False).tolist()
 
81
 
82
  collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)
83
  logger.info(f"Added batch {i//batch_size + 1}/{(len(new_chunks) + batch_size - 1) // batch_size}")
@@ -90,19 +93,26 @@ class PolicyVectorDB:
90
  Returns a list of results filtered by a relevance threshold.
91
  """
92
  collection = self._get_collection()
93
- query_embedding = self.embedding_model.encode([query_text]).tolist()
 
 
 
 
 
 
94
  k = top_k if top_k is not None else self.top_k_default
95
 
96
  # Retrieve more results initially to allow for filtering
97
  results = collection.query(
98
  query_embeddings=query_embedding,
99
- n_results=k * 2,
100
  include=["documents", "metadatas", "distances"]
101
  )
102
 
103
  search_results = []
104
  if results and results.get('documents') and results['documents'][0]:
105
  for i, doc in enumerate(results['documents'][0]):
 
106
  relevance_score = 1 - results['distances'][0][i]
107
 
108
  if relevance_score >= self.relevance_threshold:
@@ -112,6 +122,7 @@ class PolicyVectorDB:
112
  'relevance_score': relevance_score
113
  })
114
 
 
115
  return sorted(search_results, key=lambda x: x['relevance_score'], reverse=True)[:k]
116
 
117
  def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str) -> bool:
@@ -145,4 +156,4 @@ def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str) -> b
145
  return True
146
  except Exception as e:
147
  logger.error(f"An error occurred during DB population check: {e}", exc_info=True)
148
- return False
 
23
 
24
  # Using a powerful open-source embedding model.
25
  # Change 'cpu' to 'cuda' if a GPU is available for significantly faster embedding.
26
+ logger.info("Loading embedding model 'BAAI/bge-large-en-v1.5'. This may take a moment...")
27
  self.embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5', device='cpu')
28
+ logger.info("Embedding model loaded successfully.")
29
 
30
  self.collection = None # Initialize collection as None for lazy loading
31
  self.top_k_default = top_k_default
 
71
  logger.info(f"Adding {len(new_chunks)} new chunks to the vector database...")
72
 
73
  # Process in batches for efficiency
74
+ batch_size = 32 # Reduced batch size for potentially large embeddings
75
  for i in range(0, len(new_chunks), batch_size):
76
  batch = new_chunks[i:i + batch_size]
77
 
 
79
  texts = [chunk['text'] for chunk in batch]
80
  metadatas = [self._flatten_metadata(chunk.get('metadata', {})) for chunk in batch]
81
 
82
+ # For BGE models, it's recommended not to add instructions to the document embeddings
83
+ embeddings = self.embedding_model.encode(texts, normalize_embeddings=True, show_progress_bar=False).tolist()
84
 
85
  collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)
86
  logger.info(f"Added batch {i//batch_size + 1}/{(len(new_chunks) + batch_size - 1) // batch_size}")
 
93
  Returns a list of results filtered by a relevance threshold.
94
  """
95
  collection = self._get_collection()
96
+
97
+ # ✅ IMPROVEMENT: Add the recommended instruction prefix for BGE retrieval models.
98
+ instructed_query = f"Represent this sentence for searching relevant passages: {query_text}"
99
+
100
+ # ✅ IMPROVEMENT: Normalize embeddings for more accurate similarity search.
101
+ query_embedding = self.embedding_model.encode([instructed_query], normalize_embeddings=True).tolist()
102
+
103
  k = top_k if top_k is not None else self.top_k_default
104
 
105
  # Retrieve more results initially to allow for filtering
106
  results = collection.query(
107
  query_embeddings=query_embedding,
108
+ n_results=k * 2, # Retrieve more to filter by threshold
109
  include=["documents", "metadatas", "distances"]
110
  )
111
 
112
  search_results = []
113
  if results and results.get('documents') and results['documents'][0]:
114
  for i, doc in enumerate(results['documents'][0]):
115
+ # The distance for normalized embeddings is often interpreted as 1 - cosine_similarity
116
  relevance_score = 1 - results['distances'][0][i]
117
 
118
  if relevance_score >= self.relevance_threshold:
 
122
  'relevance_score': relevance_score
123
  })
124
 
125
+ # Sort by relevance score and return the top_k results
126
  return sorted(search_results, key=lambda x: x['relevance_score'], reverse=True)[:k]
127
 
128
  def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str) -> bool:
 
156
  return True
157
  except Exception as e:
158
  logger.error(f"An error occurred during DB population check: {e}", exc_info=True)
159
+ return False