Kalpokoch committed
Commit 04147ae · verified · 1 Parent(s): 72e44d7

Update app/policy_vector_db.py

Files changed (1)
  1. app/policy_vector_db.py +67 -17
app/policy_vector_db.py CHANGED
@@ -5,16 +5,24 @@ from typing import List, Dict
 from sentence_transformers import SentenceTransformer
 import chromadb
 from chromadb.config import Settings
+import logging
+
+logger = logging.getLogger("app")  # Use the same logger as app.py
 
 class PolicyVectorDB:
     def __init__(self, persist_directory: str, top_k_default: int = 5, relevance_threshold: float = 0.5):
         self.persist_directory = persist_directory
+        # Allow reset is useful for development, can be removed in production if not needed
         self.client = chromadb.PersistentClient(path=persist_directory, settings=Settings(allow_reset=True))
         self.collection_name = "neepco_dop_policies"
+
+        # IMPORTANT: Keeping BAAI/bge-large-en-v1.5 as per your requirement for accuracy.
+        # Be aware this will be slow on CPU.
         self.embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5', device='cuda' if torch.cuda.is_available() else 'cpu')
+
         self.collection = None
         self.top_k_default = top_k_default
-        self.relevance_threshold = relevance_threshold
+        self.relevance_threshold = relevance_threshold  # Note: set to 0 in app.py
 
     def _get_collection(self):
         if self.collection is None:
@@ -25,58 +33,100 @@ class PolicyVectorDB:
         return self.collection
 
     def _flatten_metadata(self, metadata: Dict) -> Dict:
+        # ChromaDB requires metadata values to be JSON-serializable strings, ints, floats, or bools.
+        # Ensuring all values are string for consistency and compatibility.
         return {key: str(value) for key, value in metadata.items()}
 
     def add_chunks(self, chunks: List[Dict]):
         collection = self._get_collection()
         if not chunks:
-            print("No chunks provided to add.")
+            logger.info("No chunks provided to add.")
             return
-        existing_ids = set(collection.get()['ids'])
-        new_chunks = [chunk for chunk in chunks if chunk.get('id') not in existing_ids]
+
+        # Fetch existing IDs to avoid adding duplicates
+        existing_ids = set()
+        try:
+            # This can be slow for very large collections, consider optimizing if needed
+            existing_ids = set(collection.get(include=[])['ids'])
+        except Exception as e:
+            logger.warning(f"Could not retrieve existing IDs from ChromaDB: {e}. Assuming no existing IDs for now.")
+
+        new_chunks = [chunk for chunk in chunks if chunk.get('id') and chunk['id'] not in existing_ids]
+
         if not new_chunks:
-            print("No new chunks to add.")
+            logger.info("No new chunks to add.")
             return
-        batch_size = 128
+
+        logger.info(f"Adding {len(new_chunks)} new chunks to the vector database...")
+
+        batch_size = 128  # Good batch size for embedding
         for i in range(0, len(new_chunks), batch_size):
             batch = new_chunks[i:i + batch_size]
             texts = [chunk['text'] for chunk in batch]
             ids = [chunk['id'] for chunk in batch]
-            metadatas = [self._flatten_metadata(chunk['metadata']) for chunk in batch]
+            metadatas = [self._flatten_metadata(chunk['metadata']) if chunk.get('metadata') else {} for chunk in batch]
+
+            # Embed texts. This is the CPU-heavy part for BAAI/bge-large-en-v1.5
             embeddings = self.embedding_model.encode(texts, show_progress_bar=False).tolist()
+
             collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)
+            logger.info(f"Added batch {i//batch_size + 1}/{(len(new_chunks) + batch_size - 1) // batch_size}")
+        logger.info(f"Finished adding {len(new_chunks)} chunks.")
 
     def search(self, query_text: str, top_k: int = None) -> List[Dict]:
         collection = self._get_collection()
+
+        # Embed query text. This is also CPU-heavy for BAAI/bge-large-en-v1.5
         query_embedding = self.embedding_model.encode([query_text]).tolist()
+
         top_k = top_k if top_k else self.top_k_default
         results = collection.query(
             query_embeddings=query_embedding,
             n_results=top_k,
             include=["documents", "metadatas", "distances"]
         )
+
         search_results = []
-        for i, doc in enumerate(results['documents'][0]):
-            relevance_score = 1 - results['distances'][0][i]
-            search_results.append({
-                'text': doc,
-                'metadata': results['metadatas'][0][i],
-                'relevance_score': relevance_score
-            })
+        # Ensure results are not empty before accessing
+        if results and results['documents'] and results['documents'][0]:
+            for i, doc in enumerate(results['documents'][0]):
+                # ChromaDB distances are L2, which means smaller is better.
+                # Converting to a similarity score where 1 is perfect match, 0 is no match.
+                # A common conversion for L2 is 1 / (1 + distance) or max_dist - dist.
+                # Here, 1 - distance is used, assuming normalized embeddings leading to dist between 0 and 2.
+                relevance_score = 1 - results['distances'][0][i]
+                search_results.append({
+                    'text': doc,
+                    'metadata': results['metadatas'][0][i],
+                    'relevance_score': relevance_score
+                })
        return search_results
 
 def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str):
+    """
+    Ensures the ChromaDB is populated with data from the chunks file if it's currently empty.
+    """
     try:
+        # Check if the collection already has data
         if db_instance._get_collection().count() == 0:
+            logger.info("Vector database is empty. Attempting to populate from chunks file.")
             if not os.path.exists(chunks_file_path):
-                print(f"Chunks file not found at {chunks_file_path}")
+                logger.error(f"Chunks file not found at {chunks_file_path}. Cannot populate DB.")
                 return False
+
             with open(chunks_file_path, 'r', encoding='utf-8') as f:
                 chunks_to_add = json.load(f)
+
+            if not chunks_to_add:
+                logger.warning(f"Chunks file at {chunks_file_path} is empty. No data to add to DB.")
+                return False
+
             db_instance.add_chunks(chunks_to_add)
+            logger.info("Vector database population attempt complete.")
             return True
         else:
+            logger.info("Vector database already contains data. Skipping population.")
             return True
     except Exception as e:
-        print(f"DB Population Error: {e}")
-        return False
+        logger.error(f"DB Population Error: {e}", exc_info=True)  # exc_info for full traceback
+        return False
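For context, a minimal usage sketch of the updated module. The paths, the logging setup, and the query string are illustrative assumptions, not part of this commit; the class and function signatures come from the diff above.

    import logging
    from app.policy_vector_db import PolicyVectorDB, ensure_db_populated

    logging.basicConfig(level=logging.INFO)  # lets the module's "app" logger emit to stderr

    # Hypothetical paths: a local Chroma directory and a JSON file of chunks shaped
    # like [{"id": "...", "text": "...", "metadata": {...}}, ...] as add_chunks() expects.
    db = PolicyVectorDB(persist_directory="./chroma_db", top_k_default=5, relevance_threshold=0.5)

    if ensure_db_populated(db, "policy_chunks.json"):
        for hit in db.search("delegation of powers for contract awards", top_k=3):
            print(f"{hit['relevance_score']:.3f}  {hit['text'][:80]}")

Note that search() returns every hit regardless of score: relevance_threshold is stored in __init__ but never applied inside this module, which matches the in-code note that app.py sets it to 0.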
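The comments in search() flag that `1 - distance` is only a clean similarity score when embeddings are normalized and the collection measures cosine distance; ChromaDB collections default to L2. A hedged sketch of one way to pin that down (this reconfigures the collection and is not something the commit does):

    import chromadb
    from sentence_transformers import SentenceTransformer

    client = chromadb.PersistentClient(path="./chroma_db")  # assumed path
    # Ask Chroma for cosine distance instead of the default L2 space...
    collection = client.get_or_create_collection(
        name="neepco_dop_policies",
        metadata={"hnsw:space": "cosine"},
    )
    model = SentenceTransformer('BAAI/bge-large-en-v1.5')
    # ...and L2-normalize the BGE embeddings, so 1 - distance equals cosine similarity.
    embeddings = model.encode(["sample policy text"], normalize_embeddings=True).tolist()
    collection.add(ids=["doc-1"], embeddings=embeddings, documents=["sample policy text"])

With that setup, the relevance_score computed in search() is exactly the cosine similarity between query and chunk, rather than an approximation that depends on the embedding norms.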