File size: 849 Bytes
94f5c4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from config.rag_config import RAGConfig

class Retriever:
    """Rank stored texts against a query embedding by cosine similarity."""

    def __init__(self, embeddings, texts, config: "RAGConfig"):
        """Store the corpus and the retrieval settings.

        Args:
            embeddings: Document vectors, passed unchanged to
                ``cosine_similarity``; row ``i`` is scored for ``texts[i]``.
            texts: Items handed back to the caller, indexed by embedding row.
            config: Supplies ``top_k`` and ``similarity_threshold``.
        """
        self.embeddings = embeddings
        self.texts = texts
        self.top_k = config.top_k
        self.threshold = config.similarity_threshold

    def retrieve(self, query_embedding):
        """Return the best-matching texts for ``query_embedding``.

        Args:
            query_embedding: A single query vector (wrapped into a one-row
                batch before scoring).

        Returns:
            A list of ``(text, score)`` tuples in descending score order,
            limited to the first ``top_k`` entries whose score is at least
            the similarity threshold. When nothing reaches the threshold the
            single highest-scoring text is returned instead. When there are
            no texts at all the result is ``[]``.
        """
        # Nothing to rank: scoring against an empty corpus would raise, and
        # the best-match fallback below has no candidate to fall back to.
        if len(self.texts) == 0:
            return []

        scores = cosine_similarity([query_embedding], self.embeddings)[0]

        # Threshold filter, applied on the ranked indices with a boolean mask
        # so only the surviving top_k entries are touched in Python.
        ranked = np.argsort(scores)[::-1]
        kept = ranked[scores[ranked] >= self.threshold][: self.top_k]
        results = [(self.texts[i], float(scores[i])) for i in kept]

        if not results:
            # Always give the caller something: fall back to the best match
            # even though it is below the threshold.
            best_idx = int(np.argmax(scores))
            results = [(self.texts[best_idx], float(scores[best_idx]))]
        return results