Update app/policy_vector_db.py
app/policy_vector_db.py (+16 -5)
@@ -23,7 +23,9 @@ class PolicyVectorDB:
 
         # Using a powerful open-source embedding model.
         # Change 'cpu' to 'cuda' if a GPU is available for significantly faster embedding.
+        logger.info("Loading embedding model 'BAAI/bge-large-en-v1.5'. This may take a moment...")
         self.embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5', device='cpu')
+        logger.info("Embedding model loaded successfully.")
 
         self.collection = None  # Initialize collection as None for lazy loading
         self.top_k_default = top_k_default
@@ -69,7 +71,7 @@ class PolicyVectorDB:
         logger.info(f"Adding {len(new_chunks)} new chunks to the vector database...")
 
         # Process in batches for efficiency
-        batch_size =
+        batch_size = 32  # Reduced batch size for potentially large embeddings
         for i in range(0, len(new_chunks), batch_size):
             batch = new_chunks[i:i + batch_size]
 
@@ -77,7 +79,8 @@ class PolicyVectorDB:
             texts = [chunk['text'] for chunk in batch]
             metadatas = [self._flatten_metadata(chunk.get('metadata', {})) for chunk in batch]
 
-
+            # For BGE models, it's recommended not to add instructions to the document embeddings
+            embeddings = self.embedding_model.encode(texts, normalize_embeddings=True, show_progress_bar=False).tolist()
 
             collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)
             logger.info(f"Added batch {i//batch_size + 1}/{(len(new_chunks) + batch_size - 1) // batch_size}")
@@ -90,19 +93,26 @@ class PolicyVectorDB:
         Returns a list of results filtered by a relevance threshold.
         """
         collection = self._get_collection()
-
+
+        # ✅ IMPROVEMENT: Add the recommended instruction prefix for BGE retrieval models.
+        instructed_query = f"Represent this sentence for searching relevant passages: {query_text}"
+
+        # ✅ IMPROVEMENT: Normalize embeddings for more accurate similarity search.
+        query_embedding = self.embedding_model.encode([instructed_query], normalize_embeddings=True).tolist()
+
         k = top_k if top_k is not None else self.top_k_default
 
         # Retrieve more results initially to allow for filtering
         results = collection.query(
             query_embeddings=query_embedding,
-            n_results=k * 2,
+            n_results=k * 2,  # Retrieve more to filter by threshold
             include=["documents", "metadatas", "distances"]
         )
 
         search_results = []
         if results and results.get('documents') and results['documents'][0]:
             for i, doc in enumerate(results['documents'][0]):
+                # The distance for normalized embeddings is often interpreted as 1 - cosine_similarity
                 relevance_score = 1 - results['distances'][0][i]
 
                 if relevance_score >= self.relevance_threshold:
@@ -112,6 +122,7 @@ class PolicyVectorDB:
                     'relevance_score': relevance_score
                 })
 
+        # Sort by relevance score and return the top_k results
         return sorted(search_results, key=lambda x: x['relevance_score'], reverse=True)[:k]
 
 def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str) -> bool:
@@ -145,4 +156,4 @@ def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str) -> bool:
         return True
     except Exception as e:
         logger.error(f"An error occurred during DB population check: {e}", exc_info=True)
-        return False
+        return False
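For reference, the two retrieval-side changes in this commit (the query instruction prefix and normalize_embeddings=True) follow the documented usage of bge-large-en-v1.5: queries get an instruction prefix, passages do not, and once vectors are unit-length, cosine distance and cosine similarity are complementary. A minimal standalone sketch of that convention (the sample texts here are invented for illustration):

# Standalone sketch of the BGE conventions used in the diff above.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('BAAI/bge-large-en-v1.5', device='cpu')

# BGE v1.5 expects this prefix on queries only; passages are encoded as-is.
BGE_QUERY_PREFIX = "Represent this sentence for searching relevant passages: "

docs = [
    "Employees accrue 1.5 vacation days per month.",   # invented sample text
    "The office is closed on public holidays.",
]
query = "How many vacation days do I get?"

doc_emb = model.encode(docs, normalize_embeddings=True)
query_emb = model.encode([BGE_QUERY_PREFIX + query], normalize_embeddings=True)

# With unit-length vectors the dot product is the cosine similarity,
# which is why a cosine distance d maps to a relevance score of 1 - d.
similarities = doc_emb @ query_emb[0]
print(similarities)  # higher = more relevant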
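One caveat the scoring change relies on implicitly: relevance_score = 1 - distance is only a true cosine similarity if the collection itself uses cosine distance, and Chroma's HNSW index defaults to L2. The _get_collection helper is not part of this diff, so the following is a hypothetical sketch of the setup that interpretation assumes (the client path and collection name are invented, not taken from the diff):

# Hypothetical collection setup assumed by the 1 - distance scoring above.
import chromadb

client = chromadb.PersistentClient(path="./chroma_db")  # path is illustrative
collection = client.get_or_create_collection(
    name="policy_chunks",               # invented name; not from the diff
    metadata={"hnsw:space": "cosine"},  # query distances become cosine distances
)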
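The batching hunk also consumes an ids list whose construction falls outside the diff context. As a rough, self-contained sketch of how the loop fits together, with model and collection as in the sketches above (the chunk shape and the deterministic ID scheme are assumptions):

# End-to-end batching sketch; chunk structure and ID scheme are assumed.
chunks = [
    {"text": "Employees accrue 1.5 vacation days per month.",
     "metadata": {"source": "handbook.pdf", "page": 4}},  # invented sample
    # ... more chunks ...
]

batch_size = 32
num_batches = (len(chunks) + batch_size - 1) // batch_size  # ceiling division

for i in range(0, len(chunks), batch_size):
    batch = chunks[i:i + batch_size]
    ids = [f"chunk-{i + j}" for j in range(len(batch))]  # hypothetical stable IDs
    texts = [c["text"] for c in batch]
    embeddings = model.encode(texts, normalize_embeddings=True).tolist()
    collection.add(
        ids=ids,
        embeddings=embeddings,
        documents=texts,
        metadatas=[c["metadata"] for c in batch],
    )
    print(f"Added batch {i // batch_size + 1}/{num_batches}")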