# app/policy_vector_db.py
import json
import os
import shutil # Keep for potential cleanup during local testing
from typing import List, Dict
import chromadb
from sentence_transformers import SentenceTransformer
import torch # Imported for device detection (e.g., cuda vs cpu)
class PolicyVectorDB:
"""Manages the creation and searching of a persistent vector database."""
def __init__(self, persist_directory: str):
self.persist_directory = persist_directory # Store the path for later use
self.client = chromadb.PersistentClient(path=persist_directory)
self.collection_name = "neepco_dop_policies"
# Use 'cuda' if available, otherwise fallback to 'cpu' for the embedding model
self.embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5', device='cuda' if torch.cuda.is_available() else 'cpu')
# Collection is not retrieved/created immediately here.
# This is handled by _get_collection() which is called on demand.
self.collection = None # Initialize as None
def _get_collection(self):
"""Lazy loads or creates the collection to ensure it exists before operations."""
if self.collection is None:
print(f"Attempting to get or create collection '{self.collection_name}' at '{self.persist_directory}'...")
self.collection = self.client.get_or_create_collection(
name=self.collection_name,
metadata={"description": "NEEPCO Delegation of Powers Policy"}
)
print(f"Collection '{self.collection_name}' is ready. Current count: {self.collection.count()} documents.")
return self.collection
    def _flatten_metadata(self, metadata: Dict) -> Dict:
        """Converts all metadata values to strings for ChromaDB compatibility."""
        return {key: str(value) for key, value in metadata.items()}
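    # Expected chunk layout for add_chunks below. The field names ('id', 'text',
    # 'metadata') follow how the method reads each chunk; the sample values are
    # purely illustrative and not taken from the actual processed_chunks.json:
    #   {"id": "chunk_0001", "text": "<policy clause text>", "metadata": {"section": "...", "page": "12"}}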
def add_chunks(self, chunks: List[Dict]):
"""Encodes and adds a list of chunk dictionaries to the database."""
collection = self._get_collection() # Ensure collection is active
if not chunks:
print("No chunks provided to add.")
return
# Fetch existing IDs to avoid re-adding the same chunks on subsequent runs
existing_ids = set(collection.get()['ids']) # This gets all IDs by default
new_chunks = [chunk for chunk in chunks if chunk.get('id') not in existing_ids]
if not new_chunks:
print("No new chunks to add. All provided chunks already exist in the database.")
return
print(f"Found {len(new_chunks)} new chunks to add to the DB.")
        batch_size = 128  # Process in batches to manage memory and network efficiently
        total_batches = -(-len(new_chunks) // batch_size)  # Ceiling division
        for i in range(0, len(new_chunks), batch_size):
            batch = new_chunks[i:i + batch_size]
            print(f" - Processing batch {i // batch_size + 1}/{total_batches}...")
texts = [chunk['text'] for chunk in batch]
ids = [chunk['id'] for chunk in batch]
metadatas = [self._flatten_metadata(chunk['metadata']) for chunk in batch]
embeddings = self.embedding_model.encode(texts, show_progress_bar=False).tolist()
collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)
print(f"Successfully added {len(new_chunks)} new chunks to the database! Total documents: {collection.count()}")
def search(self, query_text: str, top_k: int = 3) -> List[Dict]:
"""Searches the collection for a given query text."""
collection = self._get_collection() # Ensure collection is active
query_embedding = self.embedding_model.encode([query_text]).tolist()
results = collection.query(
query_embeddings=query_embedding,
n_results=top_k,
include=['documents', 'metadatas', 'distances'] # Request necessary info
)
search_results = []
        if not results.get('documents') or not results['documents'][0]:
            print("No search results found.")
            return []
for i, doc in enumerate(results['documents'][0]):
            # Convert distance to a similarity-style score (higher = more relevant);
            # note: with Chroma's default L2 distance this can fall below 0 for distant matches.
            relevance_score = 1 - results['distances'][0][i]
search_results.append({
'text': doc,
'metadata': results['metadatas'][0][i],
'relevance_score': relevance_score
})
return search_results
# --- NEW FUNCTION: To be called by app.py to ensure DB is populated ---
def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str):
"""
Checks if the database is populated. If not, loads chunks from JSON and adds them.
This function is intended to run at application startup.
"""
print(f"Checking if database at '{db_instance.persist_directory}' needs population...")
try:
# Check count of the collection to see if it's already populated
if db_instance._get_collection().count() == 0:
print("Database is empty or collection not found. Populating from chunks...")
if not os.path.exists(chunks_file_path):
print(f"ERROR: Chunks file not found at '{chunks_file_path}'. Cannot populate DB.")
return False
with open(chunks_file_path, 'r', encoding='utf-8') as f:
chunks_to_add = json.load(f)
print(f"Loaded {len(chunks_to_add)} chunks from '{chunks_file_path}'.")
db_instance.add_chunks(chunks_to_add)
print(f"Database population complete. Total documents: {db_instance._get_collection().count()}")
return True
else:
print(f"Database already populated with {db_instance._get_collection().count()} documents.")
return True
except Exception as e:
print(f"An error occurred during database population check: {e}")
# Log more details for debugging if needed
return False
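# A minimal startup sketch (assumed wiring; the real app.py, its import path, and its
# file locations may differ -- shown only to illustrate the intended call order):
#
#   from app.policy_vector_db import PolicyVectorDB, ensure_db_populated
#
#   db = PolicyVectorDB(persist_directory="/data/vector_db")            # hypothetical path
#   if ensure_db_populated(db, chunks_file_path="processed_chunks.json"):
#       for hit in db.search("Who can approve changes to the pay structure?", top_k=3):
#           print(hit["relevance_score"], hit["text"][:80])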
# The __main__ block below is kept for local testing/manual initial setup,
# but it WILL NOT be executed by the Dockerized application on Hugging Face Spaces.
if __name__ == "__main__":
print("\n--- Running PolicyVectorDB main for LOCAL TESTING/BUILD ONLY ---")
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
INPUT_CHUNKS_PATH = os.path.join(BASE_DIR, "../processed_chunks.json")
# Use a temporary local path for building so it doesn't interfere with your repo structure
PERSIST_DIRECTORY = "./.temp_local_vector_db_build"
# Clean up old local build directory if it exists for a fresh build
if os.path.exists(PERSIST_DIRECTORY):
print(f"Removing existing local build database at '{PERSIST_DIRECTORY}' to ensure a clean build.")
shutil.rmtree(PERSIST_DIRECTORY)
print(f"Creating database directory: '{PERSIST_DIRECTORY}'")
os.makedirs(PERSIST_DIRECTORY, exist_ok=True)
os.chmod(PERSIST_DIRECTORY, 0o777) # Ensure write permissions for local build
print("\nStep 1: Loading processed chunks...")
with open(INPUT_CHUNKS_PATH, 'r', encoding='utf-8') as f:
chunks_to_add = json.load(f)
print(f"Loaded {len(chunks_to_add)} chunks.")
print("\nStep 2: Setting up persistent vector database (local build)...")
db = PolicyVectorDB(persist_directory=PERSIST_DIRECTORY)
print("\nStep 3: Adding chunks to the database...")
db.add_chunks(chunks_to_add)
print(f"\n✅ Local vector database setup complete. Total chunks in DB: {db._get_collection().count()}")
print(f"Database is saved in: {os.path.abspath(PERSIST_DIRECTORY)}")
print("\n--- Remember: This local build is for testing. The deployed app will build its own DB. ---")
print("\n--- Running Local Verification Tests ---")
test_questions = [
"Who can approve changes to the pay structure?",
"What is the financial limit for a DGM for works on a limited tender basis?",
"What's the delegation power of an ED for single tender O&M contracts from an OEM?"
]
for question in test_questions:
        print("\n--- Testing Query ---")
print(f"Query: {question}")
search_results = db.search(question, top_k=2)
if search_results:
for j, result in enumerate(search_results, 1):
print(f" Result {j} (Relevance: {result['relevance_score']:.4f}):")
print(f" Text: {result['text'][:300]}...")
print(f" Metadata: {result['metadata']}")
else:
print(" No results found.")