| | import math |
| |
|
def get_topk_results(predictions, scores, targets, k, all_items=None):
    """Rank each target's ``k`` candidates by score and flag exact matches.

    ``predictions`` and ``scores`` are flat, index-aligned sequences holding
    ``k`` consecutive candidates per target: entries ``b*k`` .. ``b*k + k - 1``
    belong to ``targets[b]``.

    Before comparison, each prediction is reduced to the text after its last
    ``"Response:"`` marker, stripped of surrounding whitespace, and has all
    spaces removed. Targets are compared as given (not normalised).

    Args:
        predictions: Generated strings, ``k`` per target.
        scores: Scores aligned with ``predictions``; higher ranks first.
            This argument is NOT modified.
        targets: One ground-truth item string per group of ``k`` candidates.
        k: Number of candidates per target.
        all_items: Optional container of valid items. When given, any cleaned
            prediction not in it is ranked with a score of -1000, which pushes
            it below valid candidates whose scores exceed -1000.

    Returns:
        One list per target of 0/1 flags in descending-score order, where 1
        marks a candidate equal to the target. Ties keep their original
        relative order because ``sorted`` is stable.
    """
    cleaned = [
        pred.split("Response:")[-1].strip().replace(" ", "")
        for pred in predictions
    ]

    # Apply the out-of-catalogue penalty to a local list instead of writing
    # into the caller's ``scores`` object.
    if all_items is not None:
        effective_scores = [
            score if seq in all_items else -1000
            for seq, score in zip(cleaned, scores)
        ]
    else:
        effective_scores = scores

    results = []
    for b, target_item in enumerate(targets):
        start = b * k
        batch_seqs = cleaned[start:start + k]
        batch_scores = effective_scores[start:start + k]
        ranked = sorted(
            zip(batch_seqs, batch_scores),
            key=lambda pair: pair[1],
            reverse=True,
        )
        results.append([1 if seq == target_item else 0 for seq, _ in ranked])

    return results
| |
|
def get_metrics_results(topk_results, metrics):
    """Compute every requested ranking metric over ``topk_results``.

    Args:
        topk_results: Per-query lists of 0/1 relevance flags in ranked order.
        metrics: Metric names of the form ``"hit@K"`` or ``"ndcg@K"``. The
            prefix is matched case-insensitively; ``K`` is parsed with ``int``.

    Returns:
        Dict mapping each metric name to the value returned by ``hit_k`` or
        ``ndcg_k`` for that ``K``.

    Raises:
        NotImplementedError: If a metric name starts with neither ``hit`` nor
            ``ndcg``. The message names the offending metric.
    """
    res = {}
    for m in metrics:
        prefix = m.lower()
        if prefix.startswith("hit"):
            res[m] = hit_k(topk_results, int(m.split("@")[1]))
        elif prefix.startswith("ndcg"):
            res[m] = ndcg_k(topk_results, int(m.split("@")[1]))
        else:
            raise NotImplementedError(
                "Unsupported metric: {!r} (expected 'hit@K' or 'ndcg@K')".format(m)
            )
    return res
| |
|
| |
|
def ndcg_k(topk_results, k):
    """Return the sum over queries of the DCG of each query's top ``k`` flags.

    Each row is a ranked list of 0/1 relevance flags. A flag at zero-based
    rank ``i`` contributes ``flag / log2(i + 2)``. No ideal-DCG normalisation
    is applied, so each row's value equals NDCG@k only when a query has a
    single relevant item (ideal DCG of 1).

    The result is a sum over rows, not a mean; an empty input gives 0.0.
    """
    total = 0.0
    for row in topk_results:
        # math.log2 is the documented, more accurate form of math.log(x, 2).
        total += sum(rel / math.log2(rank + 2) for rank, rel in enumerate(row[:k]))
    return total
| |
|
| |
|
def hit_k(topk_results, k):
    """Count queries whose first ``k`` relevance flags sum to more than zero.

    Each row of ``topk_results`` is a ranked list of relevance flags. The count
    is returned as a float and is a total over rows, not a hit rate.
    """
    hits = sum(1 for flags in topk_results if sum(flags[:k]) > 0)
    return float(hits)
| |
|
| |
|