import argparse

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from model.CPICANN import CPICANN
from model.dataset import XrdDataset
def get_cs_anno(path=None):
    """Build a lookup table from class label to its ``cs`` bucket index.

    Reads a CSV with ``pandas.read_csv`` (first row treated as the header)
    and maps each row's second column (index 1) to its seventh column
    (index 6). ``run_one_epoch`` uses the resulting values as indices into
    7-slot tallies; presumably column 6 holds a crystal-system index in
    0..6 -- confirm against the CSV schema.

    Args:
        path: CSV file to read. When omitted, falls back to
            ``args.anno_struc`` (the module-level parsed command-line
            arguments), preserving the original no-argument call.

    Returns:
        dict mapping column-1 values to column-6 values. If a key repeats,
        the last row wins.
    """
    if path is None:
        path = args.anno_struc
    rows = pd.read_csv(path).values
    return {row[1]: row[6] for row in rows}
| def get_acc(cls, label): | |
| correct_cnt = sum(cls.argmax(1) == label.int()) | |
| cls_acc = correct_cnt / cls.shape[0] | |
| return cls_acc, correct_cnt | |
def run_one_epoch(model, dataloader):
    """Evaluate ``model`` on every batch of ``dataloader``.

    Each batch is moved to ``args.device``. The model runs under
    ``torch.no_grad()``. The function accumulates overall accuracy plus
    per-bucket tallies, where each class label is mapped to a bucket via
    ``get_cs_anno()``.

    Args:
        model: module called as ``model(data)``; its output is treated as
            class scores with classes along dimension 1.
        dataloader: iterable of batches where ``batch[0]`` is the input and
            ``batch[1]`` the ground-truth class labels; must expose
            ``len()`` and ``.dataset``.

    Returns:
        A tuple ``(loss, acc, correct_cnt, total_cnt, cMtrx, csCorrect,
        csTotal)``:

        - ``loss``: always 0, because no loss is computed.
        - ``acc``: mean per-batch accuracy in percent. Every batch weighs
          equally, including a smaller final batch.
        - ``correct_cnt`` / ``total_cnt``: sample-level counts.
        - ``cMtrx``: a 7x7 confusion matrix, indexed
          ``[bucket(gt)][bucket(pred)]``.
        - ``csCorrect`` / ``csTotal``: per-bucket correct and total counts,
          bucketed by ground truth.

        For an empty loader, ``loss`` and ``acc`` are 0.0.
    """
    num_cs = 7  # number of buckets; presumably the 7 crystal systems -- confirm
    model.eval()
    csAnno = get_cs_anno()
    csCorrect = [0] * num_cs
    csTotal = [0] * num_cs
    cMtrx = [[0] * num_cs for _ in range(num_cs)]
    # NOTE(review): epoch_loss is never accumulated (no criterion is applied),
    # so the returned loss is always 0; kept for interface compatibility.
    epoch_loss, cls_acc = 0, 0
    correct_cnt, total_cnt = 0, 0
    iters = len(dataloader)
    # Context manager guarantees the progress bar is closed.
    with tqdm(total=len(dataloader.dataset), desc='Evaluating... ', unit='data') as pbar:
        for batch in dataloader:
            data = batch[0].to(args.device)
            label_cls = batch[1].to(args.device)
            with torch.no_grad():
                logits = model(data)
            pbar.update(len(data))
            _cls_acc, correct = get_acc(logits, label_cls)
            cls_acc += _cls_acc.item()
            correct_cnt += correct.item()
            total_cnt += len(data)
            # One host transfer per batch instead of a .item() sync per sample.
            preds = logits.argmax(1).tolist()
            gts = label_cls.tolist()
            for gt, pred in zip(gts, preds):
                cs_gt = csAnno[gt]
                cMtrx[cs_gt][csAnno[pred]] += 1
                csTotal[cs_gt] += 1
                if gt == pred:
                    csCorrect[cs_gt] += 1
    if iters == 0:
        # Nothing evaluated; avoid ZeroDivisionError on the averages.
        return 0.0, 0.0, correct_cnt, total_cnt, cMtrx, csCorrect, csTotal
    return epoch_loss / iters, cls_acc * 100 / iters, correct_cnt, total_cnt, cMtrx, csCorrect, csTotal
def main():
    """Load the pretrained CPICANN checkpoint, evaluate it on the validation
    set, and print overall accuracy, the bucket confusion matrix, and
    per-bucket correct/total counts.

    All settings come from the module-level ``args`` namespace.
    """
    model = CPICANN(embed_dim=128, num_classes=args.num_classes)
    # The model is built on CPU and load_state_dict copies into its
    # parameters, so mapping to CPU lets GPU-saved checkpoints load on any
    # machine. It also avoids a transient GPU copy of the checkpoint.
    loaded = torch.load(args.load_path, map_location='cpu')
    model.load_state_dict(loaded['model'])
    model.to(args.device)
    model.eval()
    print('loaded model from {}'.format(args.load_path))
    print(model)

    valset = XrdDataset(args.data_dir, args.anno_val)
    # No shuffling for evaluation. The per-batch-averaged acc_val depends on
    # batch composition, so shuffling made it vary from run to run.
    val_loader = DataLoader(valset, batch_size=128, num_workers=16, pin_memory=True, shuffle=False)
    loss_val, acc_val, correct_cnt, total_cnt, cMtrx, csCorrect, csTotal = run_one_epoch(model, val_loader)
    print("loss_val: ", loss_val)
    print("acc_val: ", acc_val)

    # Round the percentage itself; round(frac, 5) * 100 produced float
    # artifacts such as 97.12300000000001.
    overall_pct = round(correct_cnt / total_cnt * 100, 3) if total_cnt else 0.0
    print("{}% ({}/{})".format(overall_pct, correct_cnt, total_cnt))

    # Confusion matrix: each cell shows its count and its share of the row
    # (ground-truth bucket) total. Empty rows print 0.0% instead of nan.
    sums = np.array(cMtrx).sum(axis=1)
    for row, row_total in zip(cMtrx, sums):
        cells = []
        for v in row:
            pct = round(v / row_total * 100, 2) if row_total else 0.0
            cells.append("{}({}%) ".format(v, pct))
        print("".join(cells))
    print("csCorrect: ", csCorrect)
    print("csTotal: ", csTotal)
if __name__ == '__main__':
    # Parse command-line options. The resulting namespace is bound to the
    # module-level name ``args``, which get_cs_anno, run_one_epoch and main
    # read directly.
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--data_dir', type=str, default='data/val/')
    parser.add_argument(
        '--load_path',
        type=str,
        default='pretrained/single-phase_checkpoint_0200.pth',
        help='path to load pretrained single-phase identification model',
    )
    parser.add_argument(
        '--anno_struc',
        type=str,
        default='annotation/anno_struc.csv',
        help='path to annotation file for training data',
    )
    parser.add_argument(
        '--anno_val',
        type=str,
        default='annotation/anno_val.csv',
        help='path to annotation file for validation data',
    )
    parser.add_argument('--num_classes', type=int, default=23073, metavar='N')
    args = parser.parse_args()

    main()
    print('THE END')