Kalpokoch committed on
Commit
a038694
·
verified ·
1 Parent(s): 344f629

Update app/policy_vector_db.py

Browse files
Files changed (1) hide show
  1. app/policy_vector_db.py +20 -46
app/policy_vector_db.py CHANGED
@@ -7,22 +7,17 @@ import chromadb
7
  from chromadb.config import Settings
8
  import logging
9
 
10
- logger = logging.getLogger("app") # Use the same logger as app.py
11
 
12
  class PolicyVectorDB:
13
  def __init__(self, persist_directory: str, top_k_default: int = 5, relevance_threshold: float = 0.5):
14
  self.persist_directory = persist_directory
15
- # Allow reset is useful for development, can be removed in production if not needed
16
  self.client = chromadb.PersistentClient(path=persist_directory, settings=Settings(allow_reset=True))
17
  self.collection_name = "neepco_dop_policies"
18
-
19
- # IMPORTANT: Keeping BAAI/bge-large-en-v1.5 as per your requirement for accuracy.
20
- # Be aware this will be slow on CPU.
21
  self.embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5', device='cuda' if torch.cuda.is_available() else 'cpu')
22
-
23
  self.collection = None
24
  self.top_k_default = top_k_default
25
- self.relevance_threshold = relevance_threshold # Note: set to 0 in app.py
26
 
27
  def _get_collection(self):
28
  if self.collection is None:
@@ -33,8 +28,6 @@ class PolicyVectorDB:
33
  return self.collection
34
 
35
  def _flatten_metadata(self, metadata: Dict) -> Dict:
36
- # ChromaDB requires metadata values to be JSON-serializable strings, ints, floats, or bools.
37
- # Ensuring all values are string for consistency and compatibility.
38
  return {key: str(value) for key, value in metadata.items()}
39
 
40
  def add_chunks(self, chunks: List[Dict]):
@@ -43,58 +36,43 @@ class PolicyVectorDB:
43
  logger.info("No chunks provided to add.")
44
  return
45
 
46
- # Fetch existing IDs to avoid adding duplicates
47
  existing_ids = set()
48
  try:
49
- # This can be slow for very large collections, consider optimizing if needed
50
  existing_ids = set(collection.get(include=[])['ids'])
51
  except Exception as e:
52
- logger.warning(f"Could not retrieve existing IDs from ChromaDB: {e}. Assuming no existing IDs for now.")
53
 
54
  new_chunks = [chunk for chunk in chunks if chunk.get('id') and chunk['id'] not in existing_ids]
55
-
56
  if not new_chunks:
57
  logger.info("No new chunks to add.")
58
  return
59
 
60
- logger.info(f"Adding {len(new_chunks)} new chunks to the vector database...")
61
-
62
- batch_size = 128 # Good batch size for embedding
63
  for i in range(0, len(new_chunks), batch_size):
64
  batch = new_chunks[i:i + batch_size]
65
  texts = [chunk['text'] for chunk in batch]
66
  ids = [chunk['id'] for chunk in batch]
67
- metadatas = [self._flatten_metadata(chunk['metadata']) if chunk.get('metadata') else {} for chunk in batch]
68
-
69
- # Embed texts. This is the CPU-heavy part for BAAI/bge-large-en-v1.5
70
  embeddings = self.embedding_model.encode(texts, show_progress_bar=False).tolist()
71
-
72
  collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)
73
- logger.info(f"Added batch {i//batch_size + 1}/{(len(new_chunks) + batch_size - 1) // batch_size}")
74
  logger.info(f"Finished adding {len(new_chunks)} chunks.")
75
 
76
  def search(self, query_text: str, top_k: int = None) -> List[Dict]:
77
  collection = self._get_collection()
78
-
79
- # Embed query text. This is also CPU-heavy for BAAI/bge-large-en-v1.5
80
  query_embedding = self.embedding_model.encode([query_text]).tolist()
81
-
82
- top_k = top_k if top_k else self.top_k_default
83
  results = collection.query(
84
  query_embeddings=query_embedding,
85
  n_results=top_k,
86
  include=["documents", "metadatas", "distances"]
87
  )
88
-
89
  search_results = []
90
- # Ensure results are not empty before accessing
91
- if results and results['documents'] and results['documents'][0]:
92
  for i, doc in enumerate(results['documents'][0]):
93
- # ChromaDB distances are L2, which means smaller is better.
94
- # Converting to a similarity score where 1 is perfect match, 0 is no match.
95
- # A common conversion for L2 is 1 / (1 + distance) or max_dist - dist.
96
- # Here, 1 - distance is used, assuming normalized embeddings leading to dist between 0 and 2.
97
- relevance_score = 1 - results['distances'][0][i]
98
  search_results.append({
99
  'text': doc,
100
  'metadata': results['metadatas'][0][i],
@@ -103,30 +81,26 @@ class PolicyVectorDB:
103
  return search_results
104
 
105
  def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str):
106
- """
107
- Ensures the ChromaDB is populated with data from the chunks file if it's currently empty.
108
- """
109
  try:
110
- # Check if the collection already has data
111
  if db_instance._get_collection().count() == 0:
112
- logger.info("Vector database is empty. Attempting to populate from chunks file.")
113
  if not os.path.exists(chunks_file_path):
114
- logger.error(f"Chunks file not found at {chunks_file_path}. Cannot populate DB.")
115
  return False
116
-
117
  with open(chunks_file_path, 'r', encoding='utf-8') as f:
118
  chunks_to_add = json.load(f)
119
-
120
  if not chunks_to_add:
121
- logger.warning(f"Chunks file at {chunks_file_path} is empty. No data to add to DB.")
122
  return False
123
 
124
  db_instance.add_chunks(chunks_to_add)
125
- logger.info("Vector database population attempt complete.")
126
  return True
127
  else:
128
- logger.info("Vector database already contains data. Skipping population.")
129
  return True
130
  except Exception as e:
131
- logger.error(f"DB Population Error: {e}", exc_info=True) # exc_info for full traceback
132
- return False
 
7
  from chromadb.config import Settings
8
  import logging
9
 
10
+ logger = logging.getLogger("app")
11
 
12
  class PolicyVectorDB:
13
  def __init__(self, persist_directory: str, top_k_default: int = 5, relevance_threshold: float = 0.5):
14
  self.persist_directory = persist_directory
 
15
  self.client = chromadb.PersistentClient(path=persist_directory, settings=Settings(allow_reset=True))
16
  self.collection_name = "neepco_dop_policies"
 
 
 
17
  self.embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5', device='cuda' if torch.cuda.is_available() else 'cpu')
 
18
  self.collection = None
19
  self.top_k_default = top_k_default
20
+ self.relevance_threshold = relevance_threshold
21
 
22
  def _get_collection(self):
23
  if self.collection is None:
 
28
  return self.collection
29
 
30
  def _flatten_metadata(self, metadata: Dict) -> Dict:
 
 
31
  return {key: str(value) for key, value in metadata.items()}
32
 
33
  def add_chunks(self, chunks: List[Dict]):
 
36
  logger.info("No chunks provided to add.")
37
  return
38
 
 
39
  existing_ids = set()
40
  try:
 
41
  existing_ids = set(collection.get(include=[])['ids'])
42
  except Exception as e:
43
+ logger.warning(f"Could not retrieve existing IDs from ChromaDB: {e}")
44
 
45
  new_chunks = [chunk for chunk in chunks if chunk.get('id') and chunk['id'] not in existing_ids]
 
46
  if not new_chunks:
47
  logger.info("No new chunks to add.")
48
  return
49
 
50
+ batch_size = 128
 
 
51
  for i in range(0, len(new_chunks), batch_size):
52
  batch = new_chunks[i:i + batch_size]
53
  texts = [chunk['text'] for chunk in batch]
54
  ids = [chunk['id'] for chunk in batch]
55
+ metadatas = [self._flatten_metadata(chunk.get('metadata', {})) for chunk in batch]
 
 
56
  embeddings = self.embedding_model.encode(texts, show_progress_bar=False).tolist()
 
57
  collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)
58
+ logger.info(f"Added batch {i // batch_size + 1}/{(len(new_chunks) + batch_size - 1) // batch_size}")
59
  logger.info(f"Finished adding {len(new_chunks)} chunks.")
60
 
61
  def search(self, query_text: str, top_k: int = None) -> List[Dict]:
62
  collection = self._get_collection()
 
 
63
  query_embedding = self.embedding_model.encode([query_text]).tolist()
64
+ top_k = top_k or self.top_k_default
65
+
66
  results = collection.query(
67
  query_embeddings=query_embedding,
68
  n_results=top_k,
69
  include=["documents", "metadatas", "distances"]
70
  )
71
+
72
  search_results = []
73
+ if results and results['documents'][0]:
 
74
  for i, doc in enumerate(results['documents'][0]):
75
+ relevance_score = 1 - results['distances'][0][i]
 
 
 
 
76
  search_results.append({
77
  'text': doc,
78
  'metadata': results['metadatas'][0][i],
 
81
  return search_results
82
 
83
  def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str):
 
 
 
84
  try:
 
85
  if db_instance._get_collection().count() == 0:
86
+ logger.info("Vector database is empty. Attempting to populate...")
87
  if not os.path.exists(chunks_file_path):
88
+ logger.error(f"Chunks file not found at {chunks_file_path}")
89
  return False
90
+
91
  with open(chunks_file_path, 'r', encoding='utf-8') as f:
92
  chunks_to_add = json.load(f)
93
+
94
  if not chunks_to_add:
95
+ logger.warning("Chunks file is empty.")
96
  return False
97
 
98
  db_instance.add_chunks(chunks_to_add)
99
+ logger.info("Database population complete.")
100
  return True
101
  else:
102
+ logger.info("Database already populated.")
103
  return True
104
  except Exception as e:
105
+ logger.error(f"DB Population Error: {e}", exc_info=True)
106
+ return False