| import math | |
class RecMetric:
    """Accumulate recall@k, NDCG@k and MRR@k over a stream of labels.

    Running sums are stored in ``self.metric`` under the keys
    ``"recall@{k}"``, ``"ndcg@{k}"`` and ``"mrr@{k}"`` for every ``k`` in
    ``k_list``, plus ``"count"`` for the number of labels scored so far.
    Call :meth:`report` to turn the sums into per-label averages.
    """

    def __init__(self, k_list=(1, 10, 50)):
        """Create the accumulator and zero all counters.

        Args:
            k_list: Cut-off values ``k`` at which each metric is computed.
        """
        self.k_list = k_list
        self.metric = {}
        self.reset_metric()

    def evaluate(self, preds, labels):
        """Score every label against ``preds`` and add to the running sums.

        Every label is compared against the same ``preds`` sequence.
        Labels equal to ``-100`` are skipped and do not increase ``count``
        (presumably an ignore/padding marker; confirm against the caller).

        Args:
            preds: Ordered sequence of predictions; must support slicing
                and ``.index`` (e.g. a ``list`` or ``tuple``).
            labels: Iterable of target values.
        """
        for label in labels:
            if label == -100:
                continue
            for k in self.k_list:
                self.metric[f"recall@{k}"] += self.compute_recall(preds, label, k)
                self.metric[f"ndcg@{k}"] += self.compute_ndcg(preds, label, k)
                self.metric[f"mrr@{k}"] += self.compute_mrr(preds, label, k)
            # One sample per scored label, independent of how many k values.
            self.metric["count"] += 1

    @staticmethod
    def _hit_rank(pred_list, label, k):
        """Return the 0-based position of ``label`` in ``pred_list[:k]``, or ``None``."""
        try:
            # The slice starts at 0, so the index inside the slice equals
            # the index inside the full sequence.
            return pred_list[:k].index(label)
        except ValueError:
            return None

    def compute_recall(self, pred_list, label, k):
        """Return 1 if ``label`` is among the first ``k`` predictions, else 0."""
        return int(label in pred_list[:k])

    def compute_mrr(self, pred_list, label, k):
        """Return ``1 / (rank + 1)`` for a hit in the first ``k`` predictions, else 0.

        ``rank`` is the 0-based position of the first occurrence of ``label``.
        """
        rank = self._hit_rank(pred_list, label, k)
        if rank is None:
            return 0
        return 1 / (rank + 1)

    def compute_ndcg(self, pred_list, label, k):
        """Return ``1 / log2(rank + 2)`` for a hit in the first ``k`` predictions, else 0.

        ``rank`` is the 0-based position of the first occurrence of ``label``.
        """
        rank = self._hit_rank(pred_list, label, k)
        if rank is None:
            return 0
        return 1 / math.log2(rank + 2)

    def reset_metric(self):
        """Set every ``metric@k`` sum and the sample ``count`` back to zero."""
        for metric in ["recall", "ndcg", "mrr"]:
            for k in self.k_list:
                self.metric[f"{metric}@{k}"] = 0
        self.metric["count"] = 0

    def report(self):
        """Return the averaged metrics together with the raw sample count.

        Returns:
            A dict with the same keys as ``self.metric``. Every key except
            ``"count"`` holds the accumulated sum divided by ``count``;
            ``"count"`` is returned unchanged. When no label has been scored
            (``count == 0``) the averages are reported as ``0.0`` instead of
            raising ``ZeroDivisionError``.
        """
        count = self.metric["count"]
        report = {}
        for name, total in self.metric.items():
            if name == "count":
                report[name] = total
            elif count:
                report[name] = total / count
            else:
                # Nothing was scored yet: avoid dividing by zero.
                report[name] = 0.0
        return report