davidtran999 committed on
Commit 5aa8ea6 · verified · 1 Parent(s): c105380

Upload backend/hue_portal/core/reranker.py with huggingface_hub

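The commit message notes the file was pushed with huggingface_hub. For context, a minimal sketch of the kind of call that produces such a commit; the repo id and token handling are placeholders, not taken from this page, and only the file path and commit message come from the commit above.

# Hypothetical reconstruction of the upload call; repo_id is an assumption.
from huggingface_hub import HfApi

api = HfApi()  # uses the token from `huggingface-cli login` or the HF_TOKEN env var
api.upload_file(
    path_or_fileobj="backend/hue_portal/core/reranker.py",  # local file to upload
    path_in_repo="backend/hue_portal/core/reranker.py",     # destination path in the repo
    repo_id="davidtran999/<repo-name>",                      # placeholder repo id
    commit_message="Upload backend/hue_portal/core/reranker.py with huggingface_hub",
)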
Files changed (1)
  1. backend/hue_portal/core/reranker.py +199 -0
backend/hue_portal/core/reranker.py ADDED
@@ -0,0 +1,199 @@
+ """
+ Reranker module using BGE Reranker v2 M3 for improved document ranking.
+ Reduces top-8 results to top-3 most relevant chunks, cutting prompt size by ~40%.
+ """
+ import logging
+ import os
+ from typing import Any, List, Optional
+
+ logger = logging.getLogger(__name__)
+
+ # Global reranker instance (lazy loaded)
+ _reranker = None
+ _reranker_model_name = None
+
+
+ def get_reranker(model_name: Optional[str] = None):
+     """
+     Get or initialize the BGE Reranker model.
+
+     Args:
+         model_name: Model name (default: BAAI/bge-reranker-v2-m3).
+
+     Returns:
+         Reranker model instance or None if unavailable.
+     """
+     global _reranker, _reranker_model_name
+
+     model_name = model_name or os.environ.get(
+         "RERANKER_MODEL",
+         "BAAI/bge-reranker-v2-m3"
+     )
+
+     # Return cached model if already loaded
+     if _reranker is not None and _reranker_model_name == model_name:
+         return _reranker
+
+     # Try FlagEmbedding first (best performance)
+     try:
+         from FlagEmbedding import FlagReranker
+
+         print(f"[RERANKER] Loading FlagEmbedding model: {model_name}", flush=True)
+         logger.info("[RERANKER] Loading FlagEmbedding model: %s", model_name)
+
+         _reranker = FlagReranker(model_name, use_fp16=False)  # Use FP32 for CPU compatibility
+         _reranker_model_name = model_name
+
+         print("[RERANKER] ✅ FlagEmbedding model loaded successfully", flush=True)
+         logger.info("[RERANKER] ✅ FlagEmbedding model loaded successfully")
+
+         return _reranker
+     except ImportError:
+         print("[RERANKER] ⚠️ FlagEmbedding not available, trying sentence-transformers CrossEncoder...", flush=True)
+         logger.warning("[RERANKER] FlagEmbedding not available, trying CrossEncoder")
+     except Exception as e:
+         print(f"[RERANKER] ⚠️ FlagEmbedding failed: {e}, trying CrossEncoder...", flush=True)
+         logger.warning("[RERANKER] FlagEmbedding failed: %s, trying CrossEncoder", e)
+
+     # Fallback: use sentence-transformers CrossEncoder (compatible with modern transformers)
+     try:
+         from sentence_transformers import CrossEncoder
+
+         # Use a lightweight cross-encoder model
+         fallback_model = "cross-encoder/ms-marco-MiniLM-L-6-v2"
+         print(f"[RERANKER] Loading CrossEncoder fallback: {fallback_model}", flush=True)
+         logger.info("[RERANKER] Loading CrossEncoder fallback: %s", fallback_model)
+
+         # Set timeout for model download (30 seconds)
+         os.environ.setdefault("HF_HUB_DOWNLOAD_TIMEOUT", "30")
+
+         _reranker = CrossEncoder(fallback_model, max_length=512)
+         _reranker_model_name = fallback_model
+
+         print("[RERANKER] ✅ CrossEncoder loaded successfully", flush=True)
+         logger.info("[RERANKER] ✅ CrossEncoder loaded successfully")
+
+         return _reranker
+     except ImportError:
+         print("[RERANKER] ❌ sentence-transformers not installed. Install with: pip install sentence-transformers", flush=True)
+         logger.error("[RERANKER] sentence-transformers not installed")
+         return None
+     except Exception as e:
+         print(f"[RERANKER] ❌ Failed to load CrossEncoder fallback: {e}", flush=True)
+         logger.error("[RERANKER] Failed to load CrossEncoder fallback: %s", e)
+         return None
+
+
+ def rerank_documents(
+     query: str,
+     documents: List[Any],
+     top_k: int = 3,
+     model_name: Optional[str] = None
+ ) -> List[Any]:
+     """
+     Rerank documents using BGE Reranker v2 M3.
+
+     Args:
+         query: Search query.
+         documents: List of document objects (must have a 'data' attribute with content).
+         top_k: Number of top results to return (default: 3).
+         model_name: Optional model name override.
+
+     Returns:
+         Top-k reranked documents.
+     """
+     if not documents or not query:
+         return documents[:top_k]
+
+     if len(documents) <= top_k:
+         # No need to rerank if we already have <= top_k results
+         return documents
+
+     reranker = get_reranker(model_name)
+     if reranker is None:
+         # Fallback: return top-k by original score
+         return documents[:top_k]
+
+     try:
+         # Prepare pairs for reranking: (query, document_text)
+         pairs = []
+         doc_objects = []
+
+         for doc in documents:
+             # Extract text from the document (fall back to the object itself if it has no 'data' attribute)
+             doc_data = getattr(doc, "data", doc)
+
+             # Build text representation
+             text_parts = []
+             if hasattr(doc_data, "content"):
+                 text_parts.append(str(doc_data.content))
+             if hasattr(doc_data, "section_title"):
+                 text_parts.append(str(doc_data.section_title))
+             if hasattr(doc_data, "section_code"):
+                 text_parts.append(str(doc_data.section_code))
+
+             # Fallback: try to get text from dict
+             if not text_parts and isinstance(doc_data, dict):
+                 text_parts.append(str(doc_data.get("content", "")))
+                 text_parts.append(str(doc_data.get("section_title", "")))
+
+             doc_text = " ".join(text_parts).strip()
+             if doc_text:
+                 pairs.append((query, doc_text))
+                 doc_objects.append(doc)
+
+         if not pairs:
+             return documents[:top_k]
+
+         # Rerank using cross-encoder
+         print(f"[RERANKER] Reranking {len(pairs)} documents...", flush=True)
+         logger.debug("[RERANKER] Reranking %d documents", len(pairs))
+
+         # Dispatch on the class name so we don't have to import a backend
+         # that may not be installed (only one of the two is required).
+         reranker_cls = type(reranker).__name__
+
+         if reranker_cls == "FlagReranker":
+             # FlagReranker.compute_score returns a list of scores for multiple pairs
+             scores = reranker.compute_score(pairs, normalize=True)
+
+             # Handle both a single score (float) and a list of scores
+             if isinstance(scores, (int, float)):
+                 scored_docs = [(doc_objects[0], float(scores))]
+             elif isinstance(scores, list):
+                 scored_docs = list(zip(doc_objects, scores))
+             else:
+                 logger.warning("[RERANKER] Unexpected score type: %s", type(scores))
+                 return documents[:top_k]
+         elif reranker_cls == "CrossEncoder":
+             # CrossEncoder.predict returns a numpy array
+             scores = reranker.predict(pairs)
+             if hasattr(scores, "tolist"):
+                 scores = scores.tolist()
+             elif not isinstance(scores, list):
+                 scores = [float(scores)] if len(pairs) == 1 else list(scores)
+             scored_docs = list(zip(doc_objects, scores))
+         else:
+             logger.warning("[RERANKER] Unknown reranker type: %s", type(reranker))
+             return documents[:top_k]
+
+         # Sort by score (descending)
+         scored_docs.sort(key=lambda x: x[1], reverse=True)
+
+         # Return top-k
+         reranked = [doc for doc, score in scored_docs[:top_k]]
+
+         print(f"[RERANKER] ✅ Reranked to top-{top_k} (scores: {[f'{s:.3f}' for _, s in scored_docs[:top_k]]})", flush=True)
+         logger.debug(
+             "[RERANKER] ✅ Reranked to top-%d (scores: %s)",
+             top_k,
+             [f"{s:.3f}" for _, s in scored_docs[:top_k]]
+         )
+
+         return reranked
+
+     except Exception as e:
+         print(f"[RERANKER] ❌ Reranking failed: {e}, falling back to original order", flush=True)
+         logger.error("[RERANKER] Reranking failed: %s", e, exc_info=True)
+         return documents[:top_k]
+
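
For reference, a minimal sketch of how this module might be called from a retrieval pipeline. It assumes backend/ is on the Python path; the SimpleNamespace documents and the query are illustrative stand-ins for whatever the portal's retriever actually returns. Only rerank_documents, its parameters, and the RERANKER_MODEL environment variable come from the file above.

# Hypothetical usage sketch (not part of the commit): rerank first-stage
# retrieval candidates before building an LLM prompt.
import os
from types import SimpleNamespace

from hue_portal.core.reranker import rerank_documents

# Optional: override the default BAAI/bge-reranker-v2-m3 model via environment.
os.environ.setdefault("RERANKER_MODEL", "BAAI/bge-reranker-v2-m3")

# Fake top-8 candidates from a first-stage retriever; the module only needs a
# 'data' attribute (or dict) exposing 'content' / 'section_title' / 'section_code'.
candidates = [
    SimpleNamespace(data=SimpleNamespace(
        content=f"Chunk {i} describing an administrative procedure",
        section_title=f"Section {i}",
    ))
    for i in range(8)
]

query = "How do I renew a temporary residence card?"

# Keep only the 3 most relevant chunks; falls back to the original order
# if neither FlagEmbedding nor sentence-transformers is available.
top_chunks = rerank_documents(query, candidates, top_k=3)

for doc in top_chunks:
    print(doc.data.section_title, "->", doc.data.content[:60])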