hf-rag-multi / src /retriever.py
siyu618's picture
Upload 18 files
94f5c4b verified
raw
history blame contribute delete
849 Bytes
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from config.rag_config import RAGConfig
class Retriever:
    """Rank stored texts against a query embedding by cosine similarity.

    ``embeddings`` and ``texts`` are kept side by side; row ``i`` of
    ``embeddings`` is used as the vector for ``texts[i]``.
    """

    def __init__(self, embeddings, texts, config: RAGConfig):
        """Store the corpus and read retrieval settings from *config*.

        Args:
            embeddings: Corpus embeddings, one row per entry in *texts*, in
                any form ``sklearn.metrics.pairwise.cosine_similarity``
                accepts as its second argument (e.g. a 2-D array).
            texts: Indexable sequence of texts aligned with *embeddings*.
            config: Provides ``top_k`` (cap on the number of results) and
                ``similarity_threshold`` (minimum cosine score to keep).
        """
        self.embeddings = embeddings
        self.texts = texts
        self.top_k = config.top_k
        self.threshold = config.similarity_threshold

    def retrieve(self, query_embedding):
        """Return the texts most similar to *query_embedding*.

        Texts whose cosine similarity is at least ``self.threshold`` are
        returned in descending score order, sliced with ``[:self.top_k]``.
        Equal scores keep their corpus order. If that leaves nothing (no
        text clears the threshold, or the slice is empty), the single
        best-scoring text is returned anyway.

        Args:
            query_embedding: One query vector, shaped ``(d,)`` or ``(1, d)``.

        Returns:
            A list of ``(text, score)`` tuples with ``score`` as a Python
            ``float``. The list is empty only when there are no texts.
        """
        # Nothing can be returned from an empty corpus; bail out before
        # cosine_similarity / np.argmax raise on empty input.
        if len(self.texts) == 0:
            return []

        # cosine_similarity needs a 2-D query; accept both (d,) and (1, d).
        query = np.asarray(query_embedding).reshape(1, -1)
        scores = cosine_similarity(query, self.embeddings)[0]

        # Threshold filtering: keep only indices at or above the cutoff,
        # then sort just those by descending score. The stable sort makes
        # ties resolve in corpus order instead of an unspecified order.
        candidates = np.flatnonzero(scores >= self.threshold)
        ranked = candidates[np.argsort(-scores[candidates], kind="stable")]
        results = [
            (self.texts[i], float(scores[i])) for i in ranked[: self.top_k]
        ]

        if not results:
            # Fallback: nothing survived, so return the single best match
            # regardless of the threshold.
            best_idx = int(np.argmax(scores))
            results = [(self.texts[best_idx], float(scores[best_idx]))]
        return results