Spaces:
Configuration error
Configuration error
| # train.py | |
| #!/usr/bin/env python3 | |
| """ valuate network using pytorch | |
| Junde Wu | |
| """ | |
| import os | |
| import sys | |
| import argparse | |
| from datetime import datetime | |
| from collections import OrderedDict | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from sklearn.metrics import roc_auc_score, accuracy_score,confusion_matrix | |
| import torchvision | |
| import torchvision.transforms as transforms | |
| from skimage import io | |
| from torch.utils.data import DataLoader | |
| #from dataset import * | |
| from torch.autograd import Variable | |
| from PIL import Image | |
| from tensorboardX import SummaryWriter | |
| #from models.discriminatorlayer import discriminator | |
| from dataset import * | |
| from conf import settings | |
| import time | |
| import cfg | |
| from tqdm import tqdm | |
| from torch.utils.data import DataLoader, random_split | |
| from utils import * | |
| import function | |
def _prefix_module_keys(state_dict):
    """Return a copy of *state_dict* with every key prefixed by ``'module.'``.

    Used when ``args.distributed != 'none'``: the network returned by
    ``get_network`` is then presumably wrapped so that its parameter names
    carry a ``module.`` prefix the checkpoint keys lack (TODO confirm against
    ``get_network``).
    """
    return OrderedDict(('module.' + key, value) for key, value in state_dict.items())


def main():
    """Load a trained checkpoint and evaluate it on the test split.

    All settings come from ``cfg.parse_args()``. ``args.weights`` must be the
    path of an existing checkpoint file holding at least the ``'epoch'`` and
    ``'state_dict'`` entries. Results are written through the logger created
    by ``create_logger``.

    Raises:
        ValueError: if no checkpoint path was given (``args.weights`` is falsy).
        FileNotFoundError: if ``args.weights`` does not exist on disk.
    """
    args = cfg.parse_args()
    if args.dataset in ('refuge', 'refuge2'):
        args.data_path = '../dataset'

    device = torch.device('cuda', args.gpu_device)
    net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=device,
                      distribution=args.distributed)

    # --- load pretrained model ---
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    # The falsy check must come first: os.path.exists() also accepts an int
    # file descriptor, so os.path.exists(0) can succeed on stdin.
    if not args.weights:
        raise ValueError('no checkpoint given: set args.weights to the path of a trained model')
    print(f'=> resuming from {args.weights}')
    if not os.path.exists(args.weights):
        raise FileNotFoundError(f'checkpoint not found: {args.weights}')

    checkpoint = torch.load(args.weights, map_location=f'cuda:{args.gpu_device}')
    start_epoch = checkpoint['epoch']  # only used to label the log lines below
    state_dict = checkpoint['state_dict']
    if args.distributed != 'none':
        state_dict = _prefix_module_keys(state_dict)
    net.load_state_dict(state_dict)

    args.path_helper = set_log_dir('logs', args.exp_name)
    logger = create_logger(args.path_helper['log_path'])
    logger.info(args)

    # --- segmentation data: only the test split is needed for evaluation ---
    _, nice_test_loader = get_dataloader(args)

    # --- begin evaluation ---
    if args.mod != 'sam_adpt':
        # Previously this case fell through silently; make it visible.
        logger.info(f'No evaluation implemented for mod={args.mod!r}; nothing was evaluated.')
        return

    net.eval()
    # NOTE(review): the data-path check above matches lowercase 'refuge'/'refuge2',
    # while this is an exact match on 'REFUGE', so a lowercase run takes the
    # two-metric branch. Confirm which spelling get_dataloader/validation_sam expect.
    if args.dataset != 'REFUGE':
        tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, start_epoch, net)
        logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {start_epoch}.')
    else:
        # NOTE(review): the order of the four metrics is defined by
        # function.validation_sam (not visible here); verify the cup/disc labels.
        tol, (eiou_cup, eiou_disc, edice_cup, edice_disc) = function.validation_sam(
            args, nice_test_loader, start_epoch, net)
        logger.info(f'Total score: {tol}, IOU_CUP: {eiou_cup}, IOU_DISC: {eiou_disc}, DICE_CUP: {edice_cup}, DICE_DISC: {edice_disc} || @ epoch {start_epoch}.')
# Script entry point: evaluate only when executed directly, never on import.
if __name__ == "__main__":
    main()