Kalpokoch committed
Commit 46f56aa · verified · 1 Parent(s): 5a1000f

Update app/policy_vector_db.py

Files changed (1)
  1. app/policy_vector_db.py +60 -121
app/policy_vector_db.py CHANGED
@@ -1,174 +1,113 @@
-import json
 import os
-import shutil # Keep for potential cleanup during local testing
+import json
+import shutil
+import logging
 from typing import List, Dict
 
 import chromadb
 from sentence_transformers import SentenceTransformer
-import torch # Imported for device detection (e.g., cuda vs cpu)
+import torch
+
+logger = logging.getLogger("vector-db")
 
 class PolicyVectorDB:
-    """Manages the creation and searching of a persistent vector database."""
-    def __init__(self, persist_directory: str):
-        self.persist_directory = persist_directory # Store the path for later use
-        self.client = chromadb.PersistentClient(path=persist_directory)
+    def __init__(self, persist_directory: str, top_k_default: int = 5, relevance_threshold: float = 0.65):
+        self.persist_directory = persist_directory
         self.collection_name = "neepco_dop_policies"
-        # Use 'cuda' if available, otherwise fallback to 'cpu' for the embedding model
-        self.embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5', device='cuda' if torch.cuda.is_available() else 'cpu')
-
-        # Collection is not retrieved/created immediately here.
-        # This is handled by _get_collection() which is called on demand.
-        self.collection = None # Initialize as None
+        self.top_k_default = top_k_default
+        self.relevance_threshold = relevance_threshold
+
+        self.client = chromadb.PersistentClient(path=self.persist_directory)
+        self.collection = None
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.embedding_model = SentenceTransformer("BAAI/bge-large-en-v1.5", device=device)
+        logger.info(f"[INIT] Embedding model loaded on {device.upper()}.")
 
     def _get_collection(self):
-        """Lazy loads or creates the collection to ensure it exists before operations."""
         if self.collection is None:
-            print(f"Attempting to get or create collection '{self.collection_name}' at '{self.persist_directory}'...")
             self.collection = self.client.get_or_create_collection(
                 name=self.collection_name,
                 metadata={"description": "NEEPCO Delegation of Powers Policy"}
             )
-            print(f"Collection '{self.collection_name}' is ready. Current count: {self.collection.count()} documents.")
+            logger.info(f"[COLLECTION] Loaded collection '{self.collection_name}'. Count: {self.collection.count()}")
         return self.collection
 
     def _flatten_metadata(self, metadata: Dict) -> Dict:
-        """Ensures all metadata values are strings for ChromaDB compatibility."""
-        return {key: str(value) for key, value in metadata.items()}
+        return {k: str(v) for k, v in metadata.items()}
 
     def add_chunks(self, chunks: List[Dict]):
-        """Encodes and adds a list of chunk dictionaries to the database."""
-        collection = self._get_collection() # Ensure collection is active
+        collection = self._get_collection()
         if not chunks:
-            print("No chunks provided to add.")
+            logger.warning("[ADD] No chunks to add.")
             return
 
-        # Fetch existing IDs to avoid re-adding the same chunks on subsequent runs
-        existing_ids = set(collection.get()['ids']) # This gets all IDs by default
-        new_chunks = [chunk for chunk in chunks if chunk.get('id') not in existing_ids]
+        existing_ids = set(collection.get()['ids'])
+        new_chunks = [c for c in chunks if c['id'] not in existing_ids]
 
         if not new_chunks:
-            print("No new chunks to add. All provided chunks already exist in the database.")
+            logger.info("[ADD] All chunks already exist in DB.")
             return
 
-        print(f"Found {len(new_chunks)} new chunks to add to the DB.")
-        batch_size = 128 # Process in batches to manage memory and network efficiently
-
+        logger.info(f"[ADD] Adding {len(new_chunks)} new chunks.")
+        batch_size = 128
         for i in range(0, len(new_chunks), batch_size):
            batch = new_chunks[i:i + batch_size]
-            print(f" - Processing batch {i//batch_size + 1}/{ -(-len(new_chunks) // batch_size) }...")
-
-            texts = [chunk['text'] for chunk in batch]
-            ids = [chunk['id'] for chunk in batch]
-            metadatas = [self._flatten_metadata(chunk['metadata']) for chunk in batch]
+            texts = [c['text'] for c in batch]
+            ids = [c['id'] for c in batch]
+            metadatas = [self._flatten_metadata(c['metadata']) for c in batch]
 
             embeddings = self.embedding_model.encode(texts, show_progress_bar=False).tolist()
             collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)
 
-        print(f"Successfully added {len(new_chunks)} new chunks to the database! Total documents: {collection.count()}")
+        logger.info(f"[ADD] Total docs after insert: {collection.count()}")
+
+    def search(self, query_text: str, top_k: int = None) -> List[Dict]:
+        collection = self._get_collection()
+        top_k = top_k or self.top_k_default
 
-    def search(self, query_text: str, top_k: int = 3) -> List[Dict]:
-        """Searches the collection for a given query text."""
-        collection = self._get_collection() # Ensure collection is active
         query_embedding = self.embedding_model.encode([query_text]).tolist()
         results = collection.query(
             query_embeddings=query_embedding,
             n_results=top_k,
-            include=['documents', 'metadatas', 'distances'] # Request necessary info
+            include=["documents", "metadatas", "distances"]
         )
 
         search_results = []
-        if not results.get('documents'):
-            print("No search results found.")
+        if not results.get("documents"):
+            logger.warning("[SEARCH] No documents found.")
             return []
 
-        for i, doc in enumerate(results['documents'][0]):
-            relevance_score = 1 - results['distances'][0][i] # Higher score = more relevant
+        for i, doc in enumerate(results["documents"][0]):
+            score = 1 - results["distances"][0][i]
             search_results.append({
-                'text': doc,
-                'metadata': results['metadatas'][0][i],
-                'relevance_score': relevance_score
+                "text": doc,
+                "metadata": results["metadatas"][0][i],
+                "relevance_score": round(score, 4)
             })
+
+        logger.info(f"[SEARCH] Retrieved {len(search_results)} results for query: {query_text}")
         return search_results
 
-# --- NEW FUNCTION: To be called by app.py to ensure DB is populated ---
-def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str):
-    """
-    Checks if the database is populated. If not, loads chunks from JSON and adds them.
-    This function is intended to run at application startup.
-    """
-    print(f"Checking if database at '{db_instance.persist_directory}' needs population...")
+
+def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str) -> bool:
+    logger.info("[POPULATE] Checking vector DB...")
+
     try:
-        # Check count of the collection to see if it's already populated
         if db_instance._get_collection().count() == 0:
-            print("Database is empty or collection not found. Populating from chunks...")
             if not os.path.exists(chunks_file_path):
-                print(f"ERROR: Chunks file not found at '{chunks_file_path}'. Cannot populate DB.")
+                logger.error(f"[ERROR] Chunks file not found at {chunks_file_path}")
                 return False
 
-            with open(chunks_file_path, 'r', encoding='utf-8') as f:
-                chunks_to_add = json.load(f)
-
-            print(f"Loaded {len(chunks_to_add)} chunks from '{chunks_file_path}'.")
-            db_instance.add_chunks(chunks_to_add)
-            print(f"Database population complete. Total documents: {db_instance._get_collection().count()}")
-            return True
+            with open(chunks_file_path, "r", encoding="utf-8") as f:
+                chunks = json.load(f)
+
+            logger.info(f"[POPULATE] Loaded {len(chunks)} chunks. Populating DB...")
+            db_instance.add_chunks(chunks)
+            logger.info("[POPULATE] DB population complete.")
         else:
-            print(f"Database already populated with {db_instance._get_collection().count()} documents.")
-            return True
+            logger.info("[POPULATE] DB already populated.")
+        return True
     except Exception as e:
-        print(f"An error occurred during database population check: {e}")
-        # Log more details for debugging if needed
+        logger.exception(f"[EXCEPTION] During DB population: {str(e)}")
         return False
-
-# The 'main' function is kept for local testing/manual initial setup,
-# but it WILL NOT be called by the Dockerized application on Hugging Face Spaces.
-if __name__ == "__main__":
-    print("\n--- Running PolicyVectorDB main for LOCAL TESTING/BUILD ONLY ---")
-    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-    INPUT_CHUNKS_PATH = os.path.join(BASE_DIR, "../processed_chunks.json")
-    # Use a temporary local path for building so it doesn't interfere with your repo structure
-    PERSIST_DIRECTORY = "./.temp_local_vector_db_build"
-
-    # Clean up old local build directory if it exists for a fresh build
-    if os.path.exists(PERSIST_DIRECTORY):
-        print(f"Removing existing local build database at '{PERSIST_DIRECTORY}' to ensure a clean build.")
-        shutil.rmtree(PERSIST_DIRECTORY)
-
-    print(f"Creating database directory: '{PERSIST_DIRECTORY}'")
-    os.makedirs(PERSIST_DIRECTORY, exist_ok=True)
-    os.chmod(PERSIST_DIRECTORY, 0o777) # Ensure write permissions for local build
-
-    print("\nStep 1: Loading processed chunks...")
-    with open(INPUT_CHUNKS_PATH, 'r', encoding='utf-8') as f:
-        chunks_to_add = json.load(f)
-    print(f"Loaded {len(chunks_to_add)} chunks.")
-
-    print("\nStep 2: Setting up persistent vector database (local build)...")
-    db = PolicyVectorDB(persist_directory=PERSIST_DIRECTORY)
-
-    print("\nStep 3: Adding chunks to the database...")
-    db.add_chunks(chunks_to_add)
-
-    print(f"\n✅ Local vector database setup complete. Total chunks in DB: {db._get_collection().count()}")
-    print(f"Database is saved in: {os.path.abspath(PERSIST_DIRECTORY)}")
-    print("\n--- Remember: This local build is for testing. The deployed app will build its own DB. ---")
-
-    print("\n--- Running Local Verification Tests ---")
-    test_questions = [
-        "Who can approve changes to the pay structure?",
-        "What is the financial limit for a DGM for works on a limited tender basis?",
-        "What's the delegation power of an ED for single tender O&M contracts from an OEM?"
-    ]
-
-    for question in test_questions:
-        print(f"\n--- Testing Query ---")
-        print(f"Query: {question}")
-        search_results = db.search(question, top_k=2)
-        if search_results:
-            for j, result in enumerate(search_results, 1):
-                print(f"    Result {j} (Relevance: {result['relevance_score']:.4f}):")
-                print(f"    Text: {result['text'][:300]}...")
-                print(f"    Metadata: {result['metadata']}")
-        else:
-            print("    No results found.")
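
Usage note: a minimal sketch (not part of this commit) of how the refactored module might be wired up at application startup. The import path, persist directory, and chunks filename are assumptions for illustration; add_chunks() expects each chunk dict to carry 'id', 'text', and 'metadata' keys. Also note that __init__ stores relevance_threshold but search() never applies it, so a caller would filter explicitly:

import logging

from app.policy_vector_db import PolicyVectorDB, ensure_db_populated

logging.basicConfig(level=logging.INFO)  # surface the module's "vector-db" logger

# Hypothetical paths, chosen only for this sketch.
db = PolicyVectorDB(persist_directory="./vector_db")
if not ensure_db_populated(db, chunks_file_path="processed_chunks.json"):
    raise RuntimeError("Vector DB could not be populated; see logs for details.")

# search() falls back to top_k_default (5) when top_k is omitted.
results = db.search("Who can approve changes to the pay structure?")

# relevance_threshold (default 0.65) is stored by __init__ but unused inside
# search(), so apply it on the caller side:
relevant = [r for r in results if r["relevance_score"] >= db.relevance_threshold]
for r in relevant:
    print(r["relevance_score"], r["text"][:200])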