import os
import json
import re
from typing import List, Dict, Optional
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.config import Settings
import logging

# --- Basic Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class PolicyVectorDB:
    """
    Manages the connection, population, and querying of a ChromaDB vector database 
    for policy documents.
    """
    def __init__(self, persist_directory: str, top_k_default: int = 5, relevance_threshold: float = 0.5):
        self.persist_directory = persist_directory
        self.client = chromadb.PersistentClient(path=persist_directory, settings=Settings(allow_reset=True))
        self.collection_name = "neepco_dop_policies"
        
        # Using a powerful open-source embedding model.
        # Change 'cpu' to 'cuda' if a GPU is available for significantly faster embedding.
        logger.info("Loading embedding model 'BAAI/bge-large-en-v1.5'. This may take a moment...")
        self.embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5', device='cpu') 
        logger.info("Embedding model loaded successfully.")
        
        self.collection = None # Initialize collection as None for lazy loading
        self.top_k_default = top_k_default
        self.relevance_threshold = relevance_threshold

    def _get_collection(self):
        """
        Retrieves or creates the ChromaDB collection. Implements lazy loading.
        """
        if self.collection is None:
            self.collection = self.client.get_or_create_collection(
                name=self.collection_name,
                # Cosine space makes ChromaDB's distance equal to 1 - cosine similarity,
                # which is what the relevance scoring in search() assumes.
                metadata={"hnsw:space": "cosine", "description": "NEEPCO Delegation of Powers Policy"}
            )
        return self.collection

    def _flatten_metadata(self, metadata: Dict) -> Dict:
        """Ensures all metadata values are strings, as required by some ChromaDB versions."""
        return {key: str(value) for key, value in metadata.items()}

    def expand_query(self, query_text: str) -> List[str]:
        """
        Generates query variations to improve retrieval.
        Uses simple heuristics - zero LLM cost.
        """
        queries = [query_text]
        
        # Expand with synonyms for policy-related terms
        synonyms = {
            "approval": ["approval", "consent", "authorization", "permission"],
            "limit": ["limit", "threshold", "ceiling", "maximum"],
            "authority": ["authority", "official", "person", "representative"],
            "delegate": ["delegate", "authorize", "empower", "assign"],
            "power": ["power", "authority", "delegation", "responsibility"],
            "financial": ["financial", "monetary", "funds", "budget"],
        }
        
        for term, variants in synonyms.items():
            # Match whole words case-insensitively, so "Approval" is expanded and
            # "limitation" is not mangled by a substring replacement.
            pattern = re.compile(rf"\b{re.escape(term)}\b", re.IGNORECASE)
            if pattern.search(query_text):
                for variant in variants:
                    if variant.lower() not in query_text.lower():
                        expanded = pattern.sub(variant, query_text)
                        if expanded not in queries:
                            queries.append(expanded)
                        if len(queries) >= 4:
                            break
            if len(queries) >= 4:
                break
        
        return queries[:4]  # Limit to 4 variations
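
    # Illustrative example (hypothetical query, assuming the synonym map above):
    #   expand_query("what is the approval limit for travel?")
    #   -> ["what is the approval limit for travel?",
    #       "what is the consent limit for travel?",
    #       "what is the authorization limit for travel?",
    #       "what is the permission limit for travel?"]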

    def add_chunks(self, chunks: List[Dict]):
        """
        Adds a list of chunks to the vector database, skipping any that already exist.
        """
        collection = self._get_collection()
        if not chunks:
            logger.info("No chunks provided to add.")
            return

        chunks_with_ids = [c for c in chunks if c.get('id')]
        if len(chunks) != len(chunks_with_ids):
            logger.warning(f"Skipped {len(chunks) - len(chunks_with_ids)} chunks that were missing an 'id'.")
        if not chunks_with_ids:
            return

        existing_ids = set(collection.get(ids=[str(c['id']) for c in chunks_with_ids])['ids'])
        new_chunks = [chunk for chunk in chunks_with_ids if str(chunk.get('id')) not in existing_ids]

        if not new_chunks:
            logger.info("All provided chunks already exist in the database. No new data to add.")
            return
        
        logger.info(f"Adding {len(new_chunks)} new chunks to the vector database...")
        
        # Process in batches to bound peak memory during embedding
        batch_size = 32  # Modest batch size; bge-large produces relatively large embeddings
        for i in range(0, len(new_chunks), batch_size):
            batch = new_chunks[i:i + batch_size]
            
            ids = [str(chunk['id']) for chunk in batch]
            texts = [chunk['text'] for chunk in batch]
            metadatas = [self._flatten_metadata(chunk.get('metadata', {})) for chunk in batch]
            
            # For BGE models, it's recommended not to add instructions to the document embeddings
            embeddings = self.embedding_model.encode(texts, normalize_embeddings=True, show_progress_bar=False).tolist()
            
            collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)
            logger.info(f"Added batch {i//batch_size + 1}/{(len(new_chunks) + batch_size - 1) // batch_size}")
            
        logger.info(f"Finished adding {len(new_chunks)} chunks.")

    def search(self, query_text: str, top_k: Optional[int] = None) -> List[Dict]:
        """
        Searches the vector database for a given query text with expansion.
        Returns a list of results filtered by a relevance threshold.
        """
        collection = self._get_collection()
        k = top_k if top_k is not None else self.top_k_default
        
        # Expand query for better recall
        queries = self.expand_query(query_text)
        all_results = {}
        
        for query in queries:
            # Add the recommended instruction prefix for BGE retrieval models.
            instructed_query = f"Represent this sentence for searching relevant passages: {query}"
            
            # Normalize embeddings for more accurate similarity search.
            query_embedding = self.embedding_model.encode([instructed_query], normalize_embeddings=True).tolist()
            
            # Retrieve more results initially to allow for filtering
            results = collection.query(
                query_embeddings=query_embedding,
                n_results=k * 2, # Retrieve more to filter by threshold
                include=["documents", "metadatas", "distances"]
            )
            
            if results and results.get('documents') and results['documents'][0]:
                for i, doc in enumerate(results['documents'][0]):
                    # With cosine space (set at collection creation), ChromaDB's distance
                    # equals 1 - cosine similarity, so this recovers the similarity score.
                    relevance_score = 1 - results['distances'][0][i]
                    
                    if relevance_score >= self.relevance_threshold:
                        key = doc  # Use document text as key
                        # Keep highest relevance score for duplicate documents
                        if key not in all_results or relevance_score > all_results[key]['relevance_score']:
                            all_results[key] = {
                                'text': doc,
                                'metadata': results['metadatas'][0][i],
                                'relevance_score': relevance_score
                            }
        
        # Sort by relevance score and return the top_k results
        return sorted(all_results.values(), key=lambda x: x['relevance_score'], reverse=True)[:k]

def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str) -> bool:
    """
    Checks if the DB is empty and populates it from a JSONL file if needed.
    """
    try:
        if db_instance._get_collection().count() > 0:
            logger.info("Vector database already contains data. Skipping population.")
            return True

        logger.info("Vector database is empty. Attempting to populate from chunks file.")
        if not os.path.exists(chunks_file_path):
            logger.error(f"Chunks file not found at '{chunks_file_path}'. Cannot populate DB.")
            return False
        
        chunks_to_add = []
        with open(chunks_file_path, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    chunks_to_add.append(json.loads(line))
                except json.JSONDecodeError:
                    logger.warning(f"Skipping malformed line in chunks file: {line.strip()}")
        
        if not chunks_to_add:
            logger.warning(f"Chunks file at '{chunks_file_path}' is empty or invalid. No data to add.")
            return False

        db_instance.add_chunks(chunks_to_add)
        logger.info("Vector database population attempt complete.")
        return True
    except Exception as e:
        logger.error(f"An error occurred during DB population check: {e}", exc_info=True)
        return False
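
# --- Minimal usage sketch (illustrative; the paths and query below are hypothetical) ---
if __name__ == "__main__":
    db = PolicyVectorDB(persist_directory="chroma_db")
    if ensure_db_populated(db, chunks_file_path="policy_chunks.jsonl"):
        for hit in db.search("What is the approval limit for capital expenditure?", top_k=3):
            print(f"[{hit['relevance_score']:.3f}] {hit['text'][:120]}")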