import os
import json
import re
import logging
from typing import List, Dict, Optional

import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer

# --- Basic Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class PolicyVectorDB:
    """
    Manages the connection, population, and querying of a ChromaDB vector database
    for policy documents.
    """

    def __init__(self, persist_directory: str, top_k_default: int = 5, relevance_threshold: float = 0.5):
        self.persist_directory = persist_directory
        self.client = chromadb.PersistentClient(path=persist_directory, settings=Settings(allow_reset=True))
        self.collection_name = "neepco_dop_policies"
        # A strong open-source embedding model. Change 'cpu' to 'cuda' if a GPU
        # is available for significantly faster embedding.
        logger.info("Loading embedding model 'BAAI/bge-large-en-v1.5'. This may take a moment...")
        self.embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5', device='cpu')
        logger.info("Embedding model loaded successfully.")
        self.collection = None  # Initialized lazily in _get_collection()
        self.top_k_default = top_k_default
        self.relevance_threshold = relevance_threshold

    def _get_collection(self):
        """
        Retrieves or creates the ChromaDB collection. Implements lazy loading.
        """
        if self.collection is None:
            self.collection = self.client.get_or_create_collection(
                name=self.collection_name,
                # Use cosine distance so that search() can compute relevance as
                # 1 - distance; ChromaDB's default space is L2, which would make
                # that formula wrong.
                metadata={
                    "description": "NEEPCO Delegation of Powers Policy",
                    "hnsw:space": "cosine",
                },
            )
        return self.collection

    def _flatten_metadata(self, metadata: Dict) -> Dict:
        """Coerces all metadata values to strings; ChromaDB accepts only scalar metadata values."""
        return {key: str(value) for key, value in metadata.items()}

    def expand_query(self, query_text: str) -> List[str]:
        """
        Generates query variations to improve retrieval.
        Uses simple synonym heuristics - zero LLM cost.
        """
        queries = [query_text]
        # Expand with synonyms for policy-related terms
        synonyms = {
            "approval": ["approval", "consent", "authorization", "permission"],
            "limit": ["limit", "threshold", "ceiling", "maximum"],
            "authority": ["authority", "official", "person", "representative"],
            "delegate": ["delegate", "authorize", "empower", "assign"],
            "power": ["power", "authority", "delegation", "responsibility"],
            "financial": ["financial", "monetary", "funds", "budget"],
        }
        for term, variants in synonyms.items():
            # Match whole words case-insensitively: a plain str.replace() would
            # miss capitalized occurrences like "Approval" and mangle words
            # that merely contain the term, such as "powerful".
            pattern = re.compile(rf"\b{re.escape(term)}\b", re.IGNORECASE)
            if pattern.search(query_text):
                for variant in variants:
                    if variant.lower() not in query_text.lower():
                        expanded = pattern.sub(variant, query_text)
                        if expanded not in queries:
                            queries.append(expanded)
                    if len(queries) >= 4:
                        break
            if len(queries) >= 4:
                break
        return queries[:4]  # Limit to 4 variations
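
    # Illustrative example of the expansion above (hypothetical query text):
    #   expand_query("approval limit for works contracts")
    #   -> ["approval limit for works contracts",
    #       "consent limit for works contracts",
    #       "authorization limit for works contracts",
    #       "permission limit for works contracts"]
    # The first matching term fills the 4-variation budget, so later terms
    # ("limit" here) are not expanded.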

    def add_chunks(self, chunks: List[Dict]):
        """
        Adds a list of chunks to the vector database, skipping any that already exist.
        """
        collection = self._get_collection()
        if not chunks:
            logger.info("No chunks provided to add.")
            return
        chunks_with_ids = [c for c in chunks if c.get('id')]
        if len(chunks) != len(chunks_with_ids):
            logger.warning(f"Skipped {len(chunks) - len(chunks_with_ids)} chunks that were missing an 'id'.")
        if not chunks_with_ids:
            return
        existing_ids = set(collection.get(ids=[str(c['id']) for c in chunks_with_ids])['ids'])
        new_chunks = [chunk for chunk in chunks_with_ids if str(chunk['id']) not in existing_ids]
        if not new_chunks:
            logger.info("All provided chunks already exist in the database. No new data to add.")
            return
        logger.info(f"Adding {len(new_chunks)} new chunks to the vector database...")
        # Process in batches to bound memory use while embedding
        batch_size = 32
        for i in range(0, len(new_chunks), batch_size):
            batch = new_chunks[i:i + batch_size]
            ids = [str(chunk['id']) for chunk in batch]
            texts = [chunk['text'] for chunk in batch]
            metadatas = [self._flatten_metadata(chunk.get('metadata', {})) for chunk in batch]
            # BGE models need no instruction prefix on the document side; the
            # normalized embeddings pair with the cosine space set above.
            embeddings = self.embedding_model.encode(texts, normalize_embeddings=True, show_progress_bar=False).tolist()
            collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)
            logger.info(f"Added batch {i//batch_size + 1}/{(len(new_chunks) + batch_size - 1) // batch_size}")
        logger.info(f"Finished adding {len(new_chunks)} chunks.")

    def search(self, query_text: str, top_k: Optional[int] = None) -> List[Dict]:
        """
        Searches the vector database for a given query text with expansion.
        Returns a list of results filtered by a relevance threshold.
        """
        collection = self._get_collection()
        k = top_k if top_k is not None else self.top_k_default
        # Expand the query for better recall
        queries = self.expand_query(query_text)
        all_results = {}
        for query in queries:
            # Add the instruction prefix recommended for BGE retrieval models.
            instructed_query = f"Represent this sentence for searching relevant passages: {query}"
            # Normalize the query embedding to match the normalized document embeddings.
            query_embedding = self.embedding_model.encode([instructed_query], normalize_embeddings=True).tolist()
            results = collection.query(
                query_embeddings=query_embedding,
                n_results=k * 2,  # Retrieve extra results to leave room for threshold filtering
                include=["documents", "metadatas", "distances"]
            )
            if results and results.get('documents') and results['documents'][0]:
                for i, doc in enumerate(results['documents'][0]):
                    # With the cosine space configured in _get_collection(),
                    # distance = 1 - cosine similarity, so relevance = 1 - distance.
                    relevance_score = 1 - results['distances'][0][i]
                    if relevance_score >= self.relevance_threshold:
                        key = doc  # Use document text as the deduplication key
                        # Keep the highest relevance score for duplicate documents
                        if key not in all_results or relevance_score > all_results[key]['relevance_score']:
                            all_results[key] = {
                                'text': doc,
                                'metadata': results['metadatas'][0][i],
                                'relevance_score': relevance_score
                            }
        # Sort by relevance score and return the top_k results
        return sorted(all_results.values(), key=lambda x: x['relevance_score'], reverse=True)[:k]


def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str) -> bool:
    """
    Checks if the DB is empty and populates it from a JSONL file if needed.
    """
    try:
        if db_instance._get_collection().count() > 0:
            logger.info("Vector database already contains data. Skipping population.")
            return True
        logger.info("Vector database is empty. Attempting to populate from chunks file.")
        if not os.path.exists(chunks_file_path):
            logger.error(f"Chunks file not found at '{chunks_file_path}'. Cannot populate DB.")
            return False
        chunks_to_add = []
        with open(chunks_file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # Skip blank lines without logging a warning
                try:
                    chunks_to_add.append(json.loads(line))
                except json.JSONDecodeError:
                    logger.warning(f"Skipping malformed line in chunks file: {line}")
        if not chunks_to_add:
            logger.warning(f"Chunks file at '{chunks_file_path}' is empty or invalid. No data to add.")
            return False
        db_instance.add_chunks(chunks_to_add)
        logger.info("Vector database population attempt complete.")
        return True
    except Exception as e:
        logger.error(f"An error occurred during DB population check: {e}", exc_info=True)
        return False
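

if __name__ == "__main__":
    # Minimal usage sketch. The persist directory, chunks file name, and query
    # below are illustrative assumptions, not values taken from the deployed app.
    db = PolicyVectorDB(persist_directory="./chroma_db")
    if ensure_db_populated(db, "policy_chunks.jsonl"):
        for hit in db.search("approval limit for capital expenditure", top_k=3):
            print(f"[{hit['relevance_score']:.3f}] {hit['text'][:100]}")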