| import torch | |
| import torch.nn.functional as F | |
| class Accuracy: | |
| def __init__(self): | |
| pass | |
| def __repr__(self): | |
| return "Accuracy()" | |
| def test(self, label_pd, label_gt, ignore_label=-1): | |
| correct_cnt = 0 | |
| total_cnt = 0 | |
| with torch.no_grad(): | |
| label_pd = F.softmax(label_pd, dim=1) | |
| label_pd = torch.max(label_pd, 1)[1] | |
| label_gt = label_gt.long() | |
| c = (label_pd == label_gt) | |
| correct_cnt = torch.sum(c).item() | |
| total_cnt = c.size(0) - torch.sum(label_gt==ignore_label).item() | |
| return correct_cnt, total_cnt | |