|
|
"""
|
|
|
FILE: 06_reranker.py
|
|
|
|
|
|
PURPOSE:
|
|
|
- Improve ranking accuracy by comparing query + result pairs using a CrossEncoder
|
|
|
- Works on top FAISS candidates and reorders them based on semantic relevance
|
|
|
|
|
|
REQUIREMENTS:
|
|
|
pip install sentence-transformers
|
|
|
"""
|
|
|
|
|
|
from sentence_transformers import CrossEncoder
|
|
|
|
|
|
|
|
|
|
|
|
# Model name handed to CrossEncoder when a Reranker is constructed.
RERANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-12-v2"
|
|
|
|
|
|
class Reranker:
    """Rescore retrieval candidates against a query with a CrossEncoder.

    The model named by ``RERANK_MODEL`` is loaded once, when the instance is
    created, and reused by every ``rerank`` call.
    """

    def __init__(self):
        """Load the CrossEncoder model named by ``RERANK_MODEL``."""
        # NOTE(review): the leading characters of this message look like a
        # mis-decoded symbol; the literal is kept unchanged -- confirm the
        # intended text before altering it.
        print(f"π€ Loading reranking model: {RERANK_MODEL}")
        self.model = CrossEncoder(RERANK_MODEL)

    @staticmethod
    def _clean_text(text):
        """Return *text* with the separator replaced and whitespace collapsed."""
        # NOTE(review): the separator literal looks like a mis-decoded
        # character (possibly a bullet); it is kept byte-for-byte -- confirm
        # against the stored candidate texts.
        with_commas = text.replace("β’", ", ")
        # split()/join collapses every run of whitespace (including
        # non-breaking spaces) into one space and trims both ends.
        return " ".join(with_commas.split())

    def rerank(self, query, candidates):
        """Score each candidate against *query* and return them best-first.

        Args:
            query: The search query paired with every candidate text.
            candidates: List of dict objects, e.g.::

                [
                    {"name": "", "domain": "", "category": "", "region": "",
                     "text": "...", "score": number}
                ]

                Only the ``"text"`` key is read here.

        Returns:
            A new list holding the same dict objects, ordered by descending
            ``"rerank_score"``. Each input dict is updated in place with a
            float ``"rerank_score"`` key. An empty input yields ``[]``.

        Raises:
            KeyError: If a candidate has no ``"text"`` key.
        """
        # Nothing to score: return early instead of handing the model an
        # empty batch.
        if not candidates:
            return []

        pairs = [
            (query, self._clean_text(candidate["text"]))
            for candidate in candidates
        ]
        scores = self.model.predict(pairs)

        # Scores come back in the same order as the pairs that were sent.
        for candidate, score in zip(candidates, scores):
            candidate["rerank_score"] = float(score)

        return sorted(
            candidates,
            key=lambda candidate: candidate["rerank_score"],
            reverse=True,
        )
|
|
|
|