|
|
""" |
|
|
FILE: 06_reranker.py

PURPOSE:
- Improve ranking accuracy by scoring (query, result) pairs with a CrossEncoder.
- Operates on the top FAISS candidates and reorders them by semantic relevance.

REQUIREMENTS:
pip install sentence-transformers
|
|
""" |
|
|
|
|
|
from sentence_transformers import CrossEncoder |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Model identifier handed to CrossEncoder by Reranker.__init__ (and echoed in
# the message it prints while loading).
RERANK_MODEL = "BAAI/bge-reranker-base" |
|
|
|
|
|
class Reranker:
    """Reorder retrieval candidates by scoring (query, candidate) pairs.

    Wraps a ``sentence_transformers.CrossEncoder``: each candidate is rendered
    to a short text, paired with the query, scored by the model, and the
    candidates are returned sorted by that score, highest first.
    """

    def __init__(self, model_name=None, max_length=512):
        """Load the cross-encoder model.

        Args:
            model_name: Model identifier passed to ``CrossEncoder``. When
                ``None`` (the default) the module-level ``RERANK_MODEL`` is
                used.
            max_length: Forwarded to ``CrossEncoder`` as ``max_length``.
        """
        if model_name is None:
            # Resolved at call time so the module constant stays the default.
            model_name = RERANK_MODEL

        # NOTE(review): the leading characters of this message look like a
        # mis-encoded emoji; kept byte-for-byte - confirm the intended glyph.
        print(f"π€ Loading reranking model: {model_name} (Max Accuracy Mode)")

        self.model = CrossEncoder(model_name, max_length=max_length)

    @staticmethod
    def _field(candidate, key):
        """Return ``candidate[key]``, or ``""`` when it is missing or None."""
        value = candidate.get(key)
        return "" if value is None else value

    @classmethod
    def _candidate_text(cls, candidate):
        """Render one candidate as the passage text scored against the query."""
        name = cls._field(candidate, "name")
        category = cls._field(candidate, "category")
        region = cls._field(candidate, "region")
        # NOTE(review): the 'β’' literal looks like a mis-encoded separator
        # character; kept byte-for-byte - confirm against the stored text.
        text = str(cls._field(candidate, "text")).replace('β’', ', ')
        return f"{name} ({category} in {region}): {text}"

    def rerank(self, query, candidates):
        """Score candidates against ``query`` and return them best-first.

        Args:
            query: The search query paired with every candidate text.
            candidates: Iterable of dicts shaped like::

                {"name": "", "domain": "", "category": "", "region": "",
                 "text": "...", "score": number}

                Only ``name``, ``category``, ``region`` and ``text`` are read;
                a missing or ``None`` field is treated as an empty string so
                one incomplete record does not abort the whole batch.

        Returns:
            A new list holding the same dict objects, sorted by descending
            ``rerank_score``. Each dict is updated in place with a float
            ``rerank_score`` key. Returns ``[]`` for empty or ``None`` input.
        """
        if candidates is None:
            return []

        # Materialise once so generators and other non-list iterables work;
        # the dicts themselves are shared with the caller, not copied.
        candidates = list(candidates)
        if not candidates:
            return []

        pairs = [(query, self._candidate_text(c)) for c in candidates]

        scores = self.model.predict(pairs)

        for candidate, score in zip(candidates, scores):
            candidate["rerank_score"] = float(score)

        return sorted(candidates, key=lambda c: c["rerank_score"], reverse=True)