# CPICANN: src/val_single-phase.py
# Evaluates a pretrained single-phase identification model on a validation set.
# Uploaded by caobin in commit 38f7d61 ("Upload 24 files").
import argparse
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from model.CPICANN import CPICANN
from model.dataset import XrdDataset
def get_cs_anno(path=None):
    """Build a lookup table from class label to crystal-system index.

    Reads the structure annotation CSV (first row treated as the header, the
    ``pd.read_csv`` default) and maps the value in its second column
    (position 1) to the value in its seventh column (position 6).
    ``run_one_epoch`` uses the resulting values as indices into 7-element
    lists, so column 6 is presumably an integer crystal-system id in
    ``range(7)`` -- TODO confirm against the annotation file.

    Args:
        path: CSV file to read. Defaults to ``args.anno_struc`` (the
            module-level argparse namespace) so existing callers keep working.

    Returns:
        dict mapping column-1 values to column-6 values. If a key appears in
        several rows, the last row wins.
    """
    if path is None:
        path = args.anno_struc
    rows = pd.read_csv(path).values
    return {row[1]: row[6] for row in rows}
def get_acc(cls, label):
    """Compute top-1 accuracy for one batch of class scores.

    Args:
        cls: score tensor whose dim 1 indexes classes; each row's prediction
            is its argmax over dim 1 (presumably shaped (batch, num_classes)).
        label: ground-truth class indices, one per row of ``cls``; cast with
            ``.int()`` before comparison.

    Returns:
        ``(cls_acc, correct_cnt)`` as 0-d tensors: the fraction of rows whose
        argmax equals the label, and the number of such rows. An empty batch
        yields an accuracy of 0 rather than dividing by zero.
    """
    # Tensor.sum() reduces in a single op; the builtin sum() would iterate the
    # comparison tensor in Python, issuing one tensor op per sample.
    correct_cnt = (cls.argmax(1) == label.int()).sum()
    cls_acc = correct_cnt / max(cls.shape[0], 1)
    return cls_acc, correct_cnt
def run_one_epoch(model, dataloader):
    """Evaluate ``model`` over ``dataloader`` and tally per-crystal-system results.

    Args:
        model: classifier mapping a batch of inputs to per-class scores; the
            argmax over dim 1 is taken as the prediction.
        dataloader: yields batches whose element 0 is the input and element 1
            the class labels; every label and prediction must be a key of the
            ``get_cs_anno()`` table.

    Returns:
        Tuple ``(loss, acc, correct_cnt, total_cnt, cMtrx, csCorrect, csTotal)``:
        - ``loss``: always 0.0 -- no loss is computed here; the slot is kept
          so callers can keep unpacking seven values.
        - ``acc``: accuracy in percent over all evaluated samples
          (0.0 when the loader is empty).
        - ``correct_cnt`` / ``total_cnt``: correct and evaluated sample counts.
        - ``cMtrx``: 7x7 count matrix; row = ``csAnno`` value of the label,
          column = ``csAnno`` value of the prediction.
        - ``csCorrect`` / ``csTotal``: per-row correct and total counts.

    Reads ``args.device`` from module scope.
    """
    model.eval()
    csAnno = get_cs_anno()
    num_cs = 7  # csAnno values are used as indices into lists of this size
    csCorrect = [0] * num_cs
    csTotal = [0] * num_cs
    cMtrx = [[0] * num_cs for _ in range(num_cs)]
    correct_cnt, total_cnt = 0, 0

    pbar = tqdm(total=len(dataloader.dataset), desc='Evaluating... ', unit='data')
    # The `with` closes the progress bar even if evaluation raises.
    with torch.no_grad(), pbar:
        for batch in dataloader:
            data = batch[0].to(args.device)
            label_cls = batch[1].to(args.device)
            logits = model(data)
            pbar.update(len(data))

            _, correct = get_acc(logits, label_cls)
            correct_cnt += correct.item()
            total_cnt += len(data)

            # Copy labels/predictions to host once per batch instead of calling
            # .item() per sample (each .item() on a device tensor synchronizes).
            preds = logits.argmax(1).tolist()
            for gt, pred in zip(label_cls.tolist(), preds):
                cs_gt = csAnno[gt]
                cMtrx[cs_gt][csAnno[pred]] += 1
                csTotal[cs_gt] += 1
                if gt == pred:
                    csCorrect[cs_gt] += 1

    # Pool over all samples: averaging per-batch accuracies would over-weight a
    # smaller final batch.
    acc = correct_cnt * 100 / total_cnt if total_cnt else 0.0
    return 0.0, acc, correct_cnt, total_cnt, cMtrx, csCorrect, csTotal
def main():
    """Load the checkpoint at ``args.load_path``, evaluate it on the validation
    set and print overall and per-crystal-system metrics.

    Reads every setting from the module-level ``args`` namespace.
    """
    model = CPICANN(embed_dim=128, num_classes=args.num_classes)
    # Map tensors to CPU first so a checkpoint saved on GPU also loads on a
    # CPU-only host; the model is moved to args.device right after.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    loaded = torch.load(args.load_path, map_location='cpu')
    model.load_state_dict(loaded['model'])
    model.to(args.device)
    model.eval()
    print('loaded model from {}'.format(args.load_path))
    print(model)

    valset = XrdDataset(args.data_dir, args.anno_val)
    # Evaluation needs no shuffling; a fixed order keeps runs reproducible.
    val_loader = DataLoader(valset, batch_size=128, num_workers=16, pin_memory=True, shuffle=False)
    loss_val, acc_val, correct_cnt, total_cnt, cMtrx, csCorrect, csTotal = run_one_epoch(model, val_loader)
    print("loss_val: ", loss_val)
    print("acc_val: ", acc_val)
    # Scale before rounding: round(x, 5) * 100 leaves float noise like 93.54100000000001.
    overall = correct_cnt / total_cnt * 100 if total_cnt else 0.0
    print("{}% ({}/{})".format(round(overall, 3), correct_cnt, total_cnt))

    # Confusion matrix: one line per ground-truth row, each cell printed as
    # "count(row-percentage%)". Rows with no samples show 0.0% instead of nan.
    for row in cMtrx:
        row_total = sum(row)
        print("".join(
            "{}({}%) ".format(v, round(v / row_total * 100, 2) if row_total else 0.0)
            for v in row
        ))
    print("csCorrect: ", csCorrect)
    print("csTotal: ", csTotal)
if __name__ == '__main__':
    # Command-line options as (flag, add_argument keyword arguments), registered in order.
    cli_options = (
        ('--device', dict(type=str, default='cuda:0')),
        ('--data_dir', dict(type=str, default='data/val/')),
        ('--load_path', dict(type=str, default='pretrained/single-phase_checkpoint_0200.pth',
                             help='path to load pretrained single-phase identification model')),
        # NOTE(review): this file is read by get_cs_anno for the label -> csAnno
        # lookup used on validation data; the "training data" help text may be a
        # copy-paste leftover -- confirm.
        ('--anno_struc', dict(type=str, default='annotation/anno_struc.csv',
                              help='path to annotation file for training data')),
        ('--anno_val', dict(type=str, default='annotation/anno_val.csv',
                            help='path to annotation file for validation data')),
        ('--num_classes', dict(type=int, default=23073, metavar='N')),
    )
    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    # Module-level name read by get_cs_anno, run_one_epoch and main.
    args = parser.parse_args()
    main()
    print('THE END')