| | import torch |
| | import numpy as np |
| | from collections import OrderedDict |
| | import pandas as pd |
| | import os |
| | from tqdm import tqdm |
| | import cv2 |
| | from utils.misc import split_np_imgrid, get_np_imgrid |
| |
|
| |
|
def cal_ber(tn, tp, fn, fp):
    """Balanced Error Rate: mean of the false-positive and false-negative rates.

    Note: raises ZeroDivisionError if either class is absent (tn+fp == 0
    or fn+tp == 0) — callers are expected to pass non-degenerate counts.
    """
    fp_rate = fp / (tn + fp)
    fn_rate = fn / (fn + tp)
    return (fp_rate + fn_rate) * 0.5
| |
|
def cal_acc(tn, tp, fn, fp):
    """Accuracy: fraction of correctly classified samples among all samples."""
    correct = tp + tn
    total = correct + fp + fn
    return correct / total
| |
|
| |
|
def get_binary_classification_metrics(pred, gt, threshold=None):
    """Confusion-matrix counts plus BER and accuracy for a prediction/GT pair.

    If `threshold` is given, both `gt` and `pred` are binarized with
    `> threshold` first; otherwise they are used as boolean arrays as-is.
    Returns an OrderedDict with keys TP, TN, FP, FN, BER, ACC.
    """
    if threshold is not None:
        gt = gt > threshold
        pred = pred > threshold
    gt_neg = np.logical_not(gt)
    pred_neg = np.logical_not(pred)
    tp = np.logical_and(gt, pred).sum()
    tn = np.logical_and(gt_neg, pred_neg).sum()
    fn = np.logical_and(gt, pred_neg).sum()
    fp = np.logical_and(gt_neg, pred).sum()
    metrics = OrderedDict()
    metrics['TP'] = tp
    metrics['TN'] = tn
    metrics['FP'] = fp
    metrics['FN'] = fn
    metrics['BER'] = cal_ber(tn, tp, fn, fp)
    metrics['ACC'] = cal_acc(tn, tp, fn, fp)
    return metrics
| |
|
| |
|
def evaluate(res_root, pred_id, gt_id, nimg, nrow, threshold):
    """Score every result-grid image under `res_root` and aggregate metrics.

    Each file in `res_root` is read as an image grid of `nimg` cells laid
    out with `nrow` per row; cell `pred_id` is the prediction and cell
    `gt_id` the ground truth.  Per-image metrics are averaged, then
    pos_err / neg_err / BER / accuracy are derived from the mean counts.

    Returns (pos_err, neg_err, ber, acc, df) where `df` holds the
    per-image metrics plus an 'ave' column.
    """
    scores = OrderedDict()
    for img_name in os.listdir(res_root):
        grid = cv2.imread(os.path.join(res_root, img_name))
        cells = split_np_imgrid(grid, nimg, nrow)
        scores[img_name] = get_binary_classification_metrics(cells[pred_id],
                                                             cells[gt_id],
                                                             threshold)

    # Columns are images, rows are metric names; 'ave' averages over images.
    df = pd.DataFrame(scores)
    df['ave'] = df.mean(axis=1)

    ave = df['ave']
    tn, tp = ave['TN'], ave['TP']
    fn, fp = ave['FN'], ave['FP']

    pos_err = (1 - tp / (tp + fn)) * 100
    neg_err = (1 - tn / (tn + fp)) * 100
    ber = (pos_err + neg_err) / 2
    acc = (tn + tp) / (tn + tp + fn + fp)

    return pos_err, neg_err, ber, acc, df
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | |
| |
|
class AverageMeter(object):
    """Keeps a running weighted sum and total weight to report their ratio."""

    def __init__(self):
        self.sum = 0    # weighted sum of all values seen so far
        self.count = 0  # total weight accumulated

    def update(self, val, weight=1):
        """Fold one value into the running totals with the given weight."""
        self.sum = self.sum + val * weight
        self.count = self.count + weight

    def average(self):
        """Weighted mean of everything recorded; 0 when nothing was recorded."""
        return self.sum / self.count if self.count else 0

    def clear(self):
        """Reset the meter to its freshly-constructed state."""
        self.sum = 0
        self.count = 0
| |
|
def compute_cm_torch(y_pred, y_label, n_class):
    """Confusion matrix (rows = true labels, cols = predictions) via bincount.

    Entries of `y_label` outside [0, n_class) are ignored; the paired
    predictions are dropped as well.
    """
    valid = (y_label >= 0) & (y_label < n_class)
    # Encode each (label, pred) pair as a single flat index, then count.
    flat_idx = n_class * y_label[valid] + y_pred[valid]
    counts = torch.bincount(flat_idx, minlength=n_class ** 2)
    return counts.reshape(n_class, n_class)
| |
|
class MyConfuseMatrixMeter(AverageMeter):
    """Accumulates per-batch confusion matrices into a running (summed) one."""

    def __init__(self, n_class):
        super(MyConfuseMatrixMeter, self).__init__()
        self.n_class = n_class

    def update_cm(self, y_pred, y_label, weight=1):
        """Add one batch's confusion matrix to the running total."""
        labels = y_label.type(torch.int64)
        batch_cm = compute_cm_torch(y_pred=y_pred.flatten(),
                                    y_label=labels.flatten(),
                                    n_class=self.n_class)
        self.update(batch_cm, weight)

    def get_scores_binary(self):
        """Derive pos_err / neg_err / BER / accuracy from the 2x2 matrix.

        Only valid for binary problems; eps guards against empty classes.
        """
        assert self.n_class == 2, "this function can only be called for binary calssification problem"
        # Row-major 2x2 layout: [[TN, FP], [FN, TP]].
        tn, fp, fn, tp = self.sum.flatten()
        eps = torch.finfo(torch.float32).eps
        pos_err = (1 - tp / (tp + fn + eps)) * 100
        neg_err = (1 - tn / (tn + fp + eps)) * 100
        return {
            'pos_err': pos_err,
            'neg_err': neg_err,
            'ber': (pos_err + neg_err) / 2,
            'acc': (tn + tp) / (tn + tp + fn + fp + eps),
        }
| |
|