import json
import os
import shutil  # Kept for potential cleanup during local testing
from typing import List, Dict

import chromadb
from sentence_transformers import SentenceTransformer
import torch  # Used for device detection (cuda vs cpu)


class PolicyVectorDB:
    """Manages the creation and searching of a persistent vector database."""

    def __init__(self, persist_directory: str):
        self.persist_directory = persist_directory  # Store the path for later use
        self.client = chromadb.PersistentClient(path=persist_directory)
        self.collection_name = "neepco_dop_policies"
        # Use 'cuda' if available, otherwise fall back to 'cpu' for the embedding model
        self.embedding_model = SentenceTransformer(
            'BAAI/bge-large-en-v1.5',
            device='cuda' if torch.cuda.is_available() else 'cpu'
        )
        # The collection is not retrieved/created here; _get_collection(),
        # called on demand, handles that.
        self.collection = None

    def _get_collection(self):
        """Lazily gets or creates the collection so it exists before any operation."""
        if self.collection is None:
            print(f"Attempting to get or create collection '{self.collection_name}' at '{self.persist_directory}'...")
            self.collection = self.client.get_or_create_collection(
                name=self.collection_name,
                # Cosine distance keeps the `1 - distance` relevance score in search() in a sensible range
                metadata={
                    "description": "NEEPCO Delegation of Powers Policy",
                    "hnsw:space": "cosine",
                },
            )
            print(f"Collection '{self.collection_name}' is ready. Current count: {self.collection.count()} documents.")
        return self.collection

    def _flatten_metadata(self, metadata: Dict) -> Dict:
        """Ensures all metadata values are strings for ChromaDB compatibility."""
        return {key: str(value) for key, value in metadata.items()}

    def add_chunks(self, chunks: List[Dict]):
        """Encodes and adds a list of chunk dictionaries to the database."""
        collection = self._get_collection()  # Ensure the collection is active
        if not chunks:
            print("No chunks provided to add.")
            return

        # Fetch existing IDs to avoid re-adding the same chunks on subsequent runs
        existing_ids = set(collection.get()['ids'])  # get() with no filters returns every ID
        new_chunks = [chunk for chunk in chunks if chunk.get('id') not in existing_ids]
        if not new_chunks:
            print("No new chunks to add. All provided chunks already exist in the database.")
            return

        print(f"Found {len(new_chunks)} new chunks to add to the DB.")
        batch_size = 128  # Process in batches to manage memory efficiently
        total_batches = -(-len(new_chunks) // batch_size)  # Ceiling division
        for i in range(0, len(new_chunks), batch_size):
            batch = new_chunks[i:i + batch_size]
            print(f"  - Processing batch {i // batch_size + 1}/{total_batches}...")
            texts = [chunk['text'] for chunk in batch]
            ids = [chunk['id'] for chunk in batch]
            metadatas = [self._flatten_metadata(chunk['metadata']) for chunk in batch]
            embeddings = self.embedding_model.encode(texts, show_progress_bar=False).tolist()
            collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)
        print(f"Successfully added {len(new_chunks)} new chunks to the database! Total documents: {collection.count()}")

    def search(self, query_text: str, top_k: int = 3) -> List[Dict]:
        """Searches the collection for a given query text."""
        collection = self._get_collection()  # Ensure the collection is active
        query_embedding = self.embedding_model.encode([query_text]).tolist()
        results = collection.query(
            query_embeddings=query_embedding,
            n_results=top_k,
            include=['documents', 'metadatas', 'distances']  # Request only the fields we need
        )
        # Chroma returns one result list per query; guard against a missing or empty list
        if not results.get('documents') or not results['documents'][0]:
            print("No search results found.")
            return []
        search_results = []
        for i, doc in enumerate(results['documents'][0]):
            relevance_score = 1 - results['distances'][0][i]  # Cosine distance, so higher score = more relevant
            search_results.append({
                'text': doc,
                'metadata': results['metadatas'][0][i],
                'relevance_score': relevance_score
            })
        return search_results


# --- NEW FUNCTION: To be called by app.py to ensure the DB is populated ---
def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str):
    """
    Checks if the database is populated. If not, loads chunks from JSON and adds them.
    This function is intended to run at application startup.
    """
    print(f"Checking if database at '{db_instance.persist_directory}' needs population...")
    try:
        # Check the collection count to see if it is already populated
        if db_instance._get_collection().count() == 0:
            print("Database is empty or collection not found. Populating from chunks...")
            if not os.path.exists(chunks_file_path):
                print(f"ERROR: Chunks file not found at '{chunks_file_path}'. Cannot populate DB.")
                return False
            with open(chunks_file_path, 'r', encoding='utf-8') as f:
                chunks_to_add = json.load(f)
            print(f"Loaded {len(chunks_to_add)} chunks from '{chunks_file_path}'.")
            db_instance.add_chunks(chunks_to_add)
            print(f"Database population complete. Total documents: {db_instance._get_collection().count()}")
            return True
        else:
            print(f"Database already populated with {db_instance._get_collection().count()} documents.")
            return True
    except Exception as e:
        print(f"An error occurred during the database population check: {e}")
        # Log more details here for debugging if needed
        return False
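
# --- Illustrative sketch (an assumption, not part of the original app): one way app.py
# might wire this up at startup. The module name `vector_db`, the persist directory
# "/data/vector_db", and the chunks path "processed_chunks.json" are placeholders;
# adjust them to the actual Space layout.
#
#     from vector_db import PolicyVectorDB, ensure_db_populated
#
#     db = PolicyVectorDB(persist_directory="/data/vector_db")
#     if not ensure_db_populated(db, "processed_chunks.json"):
#         raise RuntimeError("Vector database could not be populated at startup.")
#     results = db.search("Who can approve changes to the pay structure?", top_k=3)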


# The 'main' block is kept for local testing/manual initial setup,
# but it WILL NOT be called by the Dockerized application on Hugging Face Spaces.
if __name__ == "__main__":
    print("\n--- Running PolicyVectorDB main for LOCAL TESTING/BUILD ONLY ---")
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    INPUT_CHUNKS_PATH = os.path.join(BASE_DIR, "../processed_chunks.json")
    # Use a temporary local path for building so it doesn't interfere with the repo structure
    PERSIST_DIRECTORY = "./.temp_local_vector_db_build"

    # Clean up any old local build directory so every build starts fresh
    if os.path.exists(PERSIST_DIRECTORY):
        print(f"Removing existing local build database at '{PERSIST_DIRECTORY}' to ensure a clean build.")
        shutil.rmtree(PERSIST_DIRECTORY)

    print(f"Creating database directory: '{PERSIST_DIRECTORY}'")
    os.makedirs(PERSIST_DIRECTORY, exist_ok=True)
    os.chmod(PERSIST_DIRECTORY, 0o777)  # Ensure write permissions for the local build

    print("\nStep 1: Loading processed chunks...")
    with open(INPUT_CHUNKS_PATH, 'r', encoding='utf-8') as f:
        chunks_to_add = json.load(f)
    print(f"Loaded {len(chunks_to_add)} chunks.")

    print("\nStep 2: Setting up persistent vector database (local build)...")
    db = PolicyVectorDB(persist_directory=PERSIST_DIRECTORY)

    print("\nStep 3: Adding chunks to the database...")
    db.add_chunks(chunks_to_add)
    print(f"\n✅ Local vector database setup complete. Total chunks in DB: {db._get_collection().count()}")
    print(f"Database is saved in: {os.path.abspath(PERSIST_DIRECTORY)}")
    print("\n--- Remember: This local build is for testing. The deployed app will build its own DB. ---")

    print("\n--- Running Local Verification Tests ---")
    test_questions = [
        "Who can approve changes to the pay structure?",
        "What is the financial limit for a DGM for works on a limited tender basis?",
        "What's the delegation power of an ED for single tender O&M contracts from an OEM?"
    ]
    for question in test_questions:
        print("\n--- Testing Query ---")
        print(f"Query: {question}")
        search_results = db.search(question, top_k=2)
        if search_results:
            for j, result in enumerate(search_results, 1):
                print(f"  Result {j} (Relevance: {result['relevance_score']:.4f}):")
                print(f"    Text: {result['text'][:300]}...")
                print(f"    Metadata: {result['metadata']}")
        else:
            print("  No results found.")