"""
Hybrid Retrieval System
Combines semantic search (HNSW over sentence embeddings) with keyword search (BM25),
fusing the two result sets with a weighted score combination
"""
import json
import numpy as np
import hnswlib
from typing import List, Dict, Tuple
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
import pickle
import time
from dataclasses import dataclass
@dataclass
class RetrievalResult:
"""Represents a retrieval result with metadata"""
chunk_id: str
text: str
source_title: str
source_url: str
semantic_score: float
keyword_score: float
combined_score: float
community_id: int
rank: int
class HybridRetriever:
"""Hybrid retrieval combining semantic and keyword search"""
def __init__(
self,
chunks_file: str,
graphrag_index_file: str,
embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
embedding_dim: int = 384
):
self.chunks_file = chunks_file
self.graphrag_index_file = graphrag_index_file
self.embedding_dim = embedding_dim
# Load components
print("[INFO] Loading hybrid retriever components...")
self.embedding_model = SentenceTransformer(embedding_model)
self.chunks = self._load_chunks()
self.graphrag_index = self._load_graphrag_index()
# Build indexes
self.hnsw_index = None
self.bm25 = None
self.chunk_embeddings = None
print("[SUCCESS] Hybrid retriever initialized")
def _load_chunks(self) -> List[Dict]:
"""Load chunks from file"""
with open(self.chunks_file, 'r', encoding='utf-8') as f:
chunks = json.load(f)
print(f"[INFO] Loaded {len(chunks)} chunks")
return chunks
def _load_graphrag_index(self) -> Dict:
"""Load GraphRAG index"""
with open(self.graphrag_index_file, 'r', encoding='utf-8') as f:
index = json.load(f)
print(f"[INFO] Loaded GraphRAG index with {index['metadata']['total_communities']} communities")
return index
def build_semantic_index(self):
"""Build HNSW semantic search index"""
print("[INFO] Building semantic index with HNSW...")
# Generate embeddings for all chunks
chunk_texts = [chunk['text'] for chunk in self.chunks]
print(f"[INFO] Generating embeddings for {len(chunk_texts)} chunks...")
self.chunk_embeddings = self.embedding_model.encode(
chunk_texts,
show_progress_bar=True,
convert_to_numpy=True,
normalize_embeddings=True # L2 normalization for cosine similarity
)
        # Build HNSW index with optimized parameters
        n_chunks = len(self.chunks)
        print(f"[INFO] Building HNSW index for {n_chunks} chunks...")
        start_build = time.time()
# Initialize HNSW index
# ef_construction: controls index build time/accuracy tradeoff (higher = more accurate but slower)
# M: number of bi-directional links per element (higher = better recall but more memory)
self.hnsw_index = hnswlib.Index(space='cosine', dim=self.embedding_dim)
        # For ~86K vectors, these parameters trade speed against accuracy:
        # M=64 gives high recall at a moderate memory cost
        # ef_construction=200 balances build time and index quality
self.hnsw_index.init_index(
max_elements=n_chunks,
ef_construction=200, # Higher = better quality, slower build
M=64, # Higher = better recall, more memory
random_seed=42
)
# Set number of threads for parallel insertion
self.hnsw_index.set_num_threads(8)
# Add all vectors to index
print(f"[INFO] Adding {n_chunks} vectors to index (using 8 threads)...")
self.hnsw_index.add_items(self.chunk_embeddings, np.arange(n_chunks))
build_time = time.time() - start_build
print(f"[SUCCESS] HNSW index built in {build_time:.1f} seconds ({build_time/60:.2f} minutes)")
print(f"[SUCCESS] Index contains {self.hnsw_index.get_current_count()} vectors")
def build_keyword_index(self):
"""Build BM25 keyword search index"""
print("[INFO] Building BM25 keyword index...")
        # Tokenize chunks for BM25 (simple lowercase whitespace split;
        # punctuation stays attached to tokens — see the note after keyword_search)
        tokenized_chunks = [chunk['text'].lower().split() for chunk in self.chunks]
# Build BM25 index
self.bm25 = BM25Okapi(tokenized_chunks)
print(f"[SUCCESS] BM25 index built for {len(tokenized_chunks)} chunks")
def semantic_search(self, query: str, top_k: int = 10) -> List[Tuple[int, float]]:
"""Semantic search using HNSW"""
# Encode query
query_embedding = self.embedding_model.encode(
[query],
convert_to_numpy=True,
normalize_embeddings=True
)
        # Set ef (search-time exploration factor): must be >= k, and larger
        # values improve recall at the cost of query latency; use at least 100
        self.hnsw_index.set_ef(max(top_k * 2, 100))
# Search in HNSW index
indices, distances = self.hnsw_index.knn_query(query_embedding, k=top_k)
        # hnswlib returns cosine distances; convert to similarities (1 - distance)
        scores = 1 - distances[0]
# Return (index, score) tuples
results = [(int(idx), float(score)) for idx, score in zip(indices[0], scores)]
return results
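    # Worked example for the score conversion above (illustrative numbers):
    # a cosine distance of 0.25 from knn_query becomes similarity
    # 1 - 0.25 = 0.75, so higher scores always mean closer vectors.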
def keyword_search(self, query: str, top_k: int = 10) -> List[Tuple[int, float]]:
"""Keyword search using BM25"""
# Tokenize query
query_tokens = query.lower().split()
# Get BM25 scores
scores = self.bm25.get_scores(query_tokens)
# Get top-k indices
top_indices = np.argsort(scores)[::-1][:top_k]
# Return (index, score) tuples
results = [(int(idx), float(scores[idx])) for idx in top_indices]
return results
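    # Note on tokenization (illustrative, not used by keyword_search above):
    # plain str.split() keeps punctuation attached, so the chunk token
    # "dublin," will never match the query token "dublin". A hypothetical
    # regex tokenizer would avoid this, provided it is applied to both
    # documents and queries:
    #
    #   import re
    #   def tokenize(text: str) -> List[str]:
    #       return re.findall(r"[a-z0-9]+", text.lower())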
def hybrid_search(
self,
query: str,
top_k: int = 10,
semantic_weight: float = 0.7,
keyword_weight: float = 0.3,
rerank: bool = True
) -> List[RetrievalResult]:
"""
Hybrid search combining semantic and keyword search
Args:
query: Search query
top_k: Number of results to return
semantic_weight: Weight for semantic scores (0-1)
keyword_weight: Weight for keyword scores (0-1)
rerank: Whether to rerank by community relevance
"""
# Get results from both search methods
semantic_results = self.semantic_search(query, top_k * 2) # Get more for fusion
keyword_results = self.keyword_search(query, top_k * 2)
# Normalize scores to [0, 1] range
def normalize_scores(results):
if not results:
return []
scores = [score for _, score in results]
min_score, max_score = min(scores), max(scores)
if max_score == min_score:
return [(idx, 1.0) for idx, _ in results]
return [(idx, (score - min_score) / (max_score - min_score))
for idx, score in results]
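        # Example: raw scores [2.0, 1.0, 0.5] normalize to [1.0, ~0.33, 0.0],
        # so both result lists land on a comparable [0, 1] scale before fusion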
semantic_results = normalize_scores(semantic_results)
keyword_results = normalize_scores(keyword_results)
        # Combine scores as a weighted sum of the normalized semantic and keyword scores
combined_scores = {}
for idx, score in semantic_results:
combined_scores[idx] = {
'semantic': score * semantic_weight,
'keyword': 0.0,
'combined': score * semantic_weight
}
for idx, score in keyword_results:
if idx in combined_scores:
combined_scores[idx]['keyword'] = score * keyword_weight
combined_scores[idx]['combined'] += score * keyword_weight
else:
combined_scores[idx] = {
'semantic': 0.0,
'keyword': score * keyword_weight,
'combined': score * keyword_weight
}
# Sort by combined score
sorted_indices = sorted(
combined_scores.items(),
key=lambda x: x[1]['combined'],
reverse=True
)[:top_k]
# Build retrieval results
results = []
for rank, (idx, scores) in enumerate(sorted_indices):
chunk = self.chunks[idx]
community_id = self.graphrag_index['node_to_community'].get(chunk['chunk_id'], -1)
result = RetrievalResult(
chunk_id=chunk['chunk_id'],
text=chunk['text'],
source_title=chunk['source_title'],
source_url=chunk['source_url'],
semantic_score=scores['semantic'],
keyword_score=scores['keyword'],
combined_score=scores['combined'],
community_id=community_id,
rank=rank + 1
)
results.append(result)
return results
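    @staticmethod
    def reciprocal_rank_fusion(rankings: List[List[int]], k: int = 60) -> Dict[int, float]:
        """Illustrative alternative to the weighted score fusion used in
        hybrid_search: true reciprocal rank fusion (RRF). This is a minimal
        sketch for comparison only and is not called anywhere in this module.
        Each chunk index accumulates 1 / (k + rank) over every ranking it
        appears in; k=60 is the conventional constant from the RRF paper."""
        fused: Dict[int, float] = {}
        for ranking in rankings:
            for rank, idx in enumerate(ranking, start=1):
                fused[idx] = fused.get(idx, 0.0) + 1.0 / (k + rank)
        return fused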
    def get_community_context(self, community_id: int) -> Dict:
        """Get summary context for a community from the GraphRAG index"""
        return self.graphrag_index['communities'].get(str(community_id), {})
def save_indexes(self, output_dir: str = "dataset/wikipedia_ireland"):
"""Save indexes for fast loading"""
print("[INFO] Saving indexes...")
# Save HNSW index
self.hnsw_index.save_index(f"{output_dir}/hybrid_hnsw_index.bin")
# Save BM25 and embeddings
with open(f"{output_dir}/hybrid_indexes.pkl", 'wb') as f:
pickle.dump({
'bm25': self.bm25,
'embeddings': self.chunk_embeddings
}, f)
print(f"[SUCCESS] Indexes saved to {output_dir}")
def load_indexes(self, output_dir: str = "dataset/wikipedia_ireland"):
"""Load pre-built indexes"""
print("[INFO] Loading pre-built indexes...")
# Load HNSW index
self.hnsw_index = hnswlib.Index(space='cosine', dim=self.embedding_dim)
self.hnsw_index.load_index(f"{output_dir}/hybrid_hnsw_index.bin")
self.hnsw_index.set_num_threads(8) # Enable multi-threading for search
# Load BM25 and embeddings
with open(f"{output_dir}/hybrid_indexes.pkl", 'rb') as f:
data = pickle.load(f)
self.bm25 = data['bm25']
self.chunk_embeddings = data['embeddings']
print("[SUCCESS] Indexes loaded successfully")
if __name__ == "__main__":
# Build and save indexes
retriever = HybridRetriever(
chunks_file="dataset/wikipedia_ireland/chunks.json",
graphrag_index_file="dataset/wikipedia_ireland/graphrag_index.json"
)
retriever.build_semantic_index()
retriever.build_keyword_index()
retriever.save_indexes()
# Test hybrid search
query = "What is the capital of Ireland?"
results = retriever.hybrid_search(query, top_k=5)
print("\nHybrid Search Results:")
for result in results:
print(f"\nRank {result.rank}: {result.source_title}")
print(f"Score: {result.combined_score:.3f} (semantic: {result.semantic_score:.3f}, keyword: {result.keyword_score:.3f})")
print(f"Text: {result.text[:200]}...")