File size: 546 Bytes
7d4ed22
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
from sentence_transformers import CrossEncoder


class Reranker:
    """Re-order retrieved results by cross-encoder relevance to a query."""

    def __init__(self, model_name="cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """Create the cross-encoder used for scoring.

        Args:
            model_name: Name or path passed straight to
                ``sentence_transformers.CrossEncoder``.
        """
        self.model = CrossEncoder(model_name)

    def rerank_results(self, query: str, results: list[dict], top_n: int = 5) -> list[dict]:
        """Return up to ``top_n`` results, most relevant to ``query`` first.

        Each result is scored by running the cross-encoder on the pair
        ``(query, result["text"])``. A result whose ``"text"`` is missing or
        empty cannot be scored, so it is dropped from the output.

        Args:
            query: The search query.
            results: Candidate result dicts; their ``"text"`` value is scored.
            top_n: Maximum number of results to return.

        Returns:
            Up to ``top_n`` of the scorable input dicts (the same objects,
            not copies), ordered from highest to lowest score. Equal scores
            keep their input order.
        """
        # Filter once and score exactly this list, so each score stays
        # aligned with the result it was computed for. Zipping scores against
        # the unfiltered input would mis-assign scores whenever a result
        # lacks text.
        scorable = [r for r in results if r.get("text")]
        if not scorable:
            # Nothing to score; avoid calling the model with an empty batch.
            return []
        scores = self.model.predict([(query, r["text"]) for r in scorable])
        ranked = sorted(zip(scores, scorable), key=lambda pair: pair[0], reverse=True)
        return [r for _, r in ranked[:top_n]]