import os
import sys
import time
import argparse
import traceback

import torch
import torch.nn as nn

from lib import utility
from lib.utils import AverageMeter, convert_secs2time

os.environ["MKL_THREADING_LAYER"] = "GNU"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"


def train(args):
    device_ids = args.device_ids
    nprocs = len(device_ids)
    if nprocs > 1:
        # One worker per GPU on a single node; spawn forwards (world_size, nodes_size, args).
        torch.multiprocessing.spawn(
            train_worker, args=(nprocs, 1, args), nprocs=nprocs,
            join=True)
    elif nprocs == 1:
        train_worker(device_ids[0], nprocs, 1, args)
    else:
        assert False, "args.device_ids must contain at least one GPU id"
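
# Note: torch.multiprocessing.spawn passes the spawned process index (0..nprocs-1)
# as the first positional argument, so `world_rank` below is the per-node rank;
# with nodes_size fixed to 1 above it is also the global rank.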


def train_worker(world_rank, world_size, nodes_size, args):
    # initialize config.
    config = utility.get_config(args)
    config.device_id = world_rank if nodes_size == 1 else world_rank % torch.cuda.device_count()

    # set environment
    utility.set_environment(config)

    # initialize instances, such as writer, logger and wandb.
    if world_rank == 0:
        config.init_instance()
        if config.logger is not None:
            config.logger.info("\n" + "\n".join(["%s: %s" % item for item in config.__dict__.items()]))
            config.logger.info("Loaded configuration file %s: %s" % (config.type, config.id))

    # worker communication
    if world_size > 1:
        torch.distributed.init_process_group(
            backend="nccl", init_method="tcp://localhost:23456" if nodes_size == 1 else "env://",
            rank=world_rank, world_size=world_size)
        torch.cuda.set_device(config.device)
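        # Rendezvous note: on a single node the workers meet over a fixed local TCP port;
        # for multi-node runs the "env://" method reads MASTER_ADDR/MASTER_PORT from the
        # launcher environment (rank and world_size are already passed explicitly above).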

    # model
    net = utility.get_net(config)
    if world_size > 1:
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = net.float().to(config.device)
    net.train(True)
    if config.ema and world_rank == 0:
        net_ema = utility.get_net(config)
        if world_size > 1:
            net_ema = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net_ema)
        net_ema = net_ema.float().to(config.device)
        net_ema.eval()
        utility.accumulate_net(net_ema, net, 0)
    else:
        net_ema = None
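    # `utility.accumulate_net` is not shown in this file; presumably it blends parameters
    # as ema = decay * ema + (1 - decay) * src, so a decay of 0 copies the freshly built
    # student weights into the EMA model (an assumption about lib.utility).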

    # multi-GPU training
    if world_size > 1:
        net_module = nn.parallel.DistributedDataParallel(net, device_ids=[config.device_id],
                                                         output_device=config.device_id,
                                                         find_unused_parameters=True)
    else:
        net_module = net

    criterions = utility.get_criterions(config)
    optimizer = utility.get_optimizer(config, net_module)
    scheduler = utility.get_scheduler(config, optimizer)

    # load pretrained model
    if args.pretrained_weight is not None:
        if not os.path.exists(args.pretrained_weight):
            pretrained_weight = os.path.join(config.work_dir, args.pretrained_weight)
        else:
            pretrained_weight = args.pretrained_weight
        try:
            checkpoint = torch.load(pretrained_weight)
            net.load_state_dict(checkpoint["net"], strict=False)
            if net_ema is not None:
                net_ema.load_state_dict(checkpoint["net_ema"], strict=False)
            if config.logger is not None:
                config.logger.warning("Succeeded to load pretrained model %s." % pretrained_weight)
            start_epoch = checkpoint["epoch"]
            optimizer.load_state_dict(checkpoint["optimizer"])
            scheduler.load_state_dict(checkpoint["scheduler"])
        except Exception:
            # Fall back to training from scratch if the checkpoint is missing keys or corrupt.
            start_epoch = 0
            if config.logger is not None:
                config.logger.warning("Failed to load pretrained model %s." % pretrained_weight)
    else:
        start_epoch = 0

    if config.logger is not None:
        config.logger.info("Loaded network")

    # data - train, val
    train_loader = utility.get_dataloader(config, "train", world_rank, world_size)
    if world_rank == 0:
        val_loader = utility.get_dataloader(config, "val")
    if config.logger is not None:
        config.logger.info("Loaded data")

    # forward & backward
    if config.logger is not None:
        config.logger.info("Optimizer type %s. Start training..." % (config.optimizer))
    if not os.path.exists(config.model_dir) and world_rank == 0:
        os.makedirs(config.model_dir)

    # training
    best_metric, best_net = None, None
    epoch_time, eval_time = AverageMeter(), AverageMeter()
    for i_epoch, epoch in enumerate(range(config.max_epoch + 1)):
        try:
            epoch_start_time = time.time()
            if epoch >= start_epoch:
                # forward and backward
                if epoch != start_epoch:
                    utility.forward_backward(config, train_loader, net_module, net, net_ema, criterions, optimizer,
                                             epoch)
                if world_size > 1:
                    torch.distributed.barrier()

                # validating
                if epoch % config.val_epoch == 0 and epoch != 0 and world_rank == 0:
                    eval_start_time = time.time()
                    epoch_nets = {"net": net, "net_ema": net_ema}
                    for net_name, epoch_net in epoch_nets.items():
                        if epoch_net is None:
                            continue
                        result, metrics = utility.forward(config, val_loader, epoch_net)
                        for k, metric in enumerate(metrics):
                            if config.logger is not None and len(metric) != 0:
                                config.logger.info(
                                    "Val_{}/Metric{:3d} in this epoch: [NME {:.6f}, FR {:.6f}, AUC {:.6f}]".format(
                                        net_name, k, metric[0], metric[1], metric[2]))
                        # update best model.
                        cur_metric = metrics[config.key_metric_index][0]
                        if best_metric is None or best_metric > cur_metric:
                            best_metric = cur_metric
                            best_net = epoch_net
                            current_pytorch_model_path = os.path.join(config.model_dir, "best_model.pkl")
                            # current_onnx_model_path = os.path.join(config.model_dir, "train.onnx")
                            utility.save_model(
                                config,
                                epoch,
                                best_net,
                                net_ema,
                                optimizer,
                                scheduler,
                                current_pytorch_model_path)
                    if best_metric is not None:
                        config.logger.info(
                            "Val/Best_Metric%03d in this epoch: %.6f" % (config.key_metric_index, best_metric))
                    eval_time.update(time.time() - eval_start_time)

                # saving model
                if epoch == config.max_epoch and world_rank == 0:
                    current_pytorch_model_path = os.path.join(config.model_dir, "last_model.pkl")
                    # current_onnx_model_path = os.path.join(config.model_dir, "model_epoch_%s.onnx" % epoch)
                    utility.save_model(
                        config,
                        epoch,
                        net,
                        net_ema,
                        optimizer,
                        scheduler,
                        current_pytorch_model_path)
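                    # Judging by the resume branch earlier in this function, the checkpoint
                    # written by utility.save_model is expected to hold at least the keys
                    # "net", "net_ema", "epoch", "optimizer" and "scheduler" (layout not shown here).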
                if world_size > 1:
                    torch.distributed.barrier()

            # adjusting learning rate
            if epoch > 0:
                scheduler.step()
            epoch_time.update(time.time() - epoch_start_time)
            last_time = convert_secs2time(epoch_time.avg * (config.max_epoch - i_epoch), True)
            if config.logger is not None:
                config.logger.info(
                    "Train/Epoch: %d/%d, Learning rate decays to %s, " % (
                        epoch, config.max_epoch, str(scheduler.get_last_lr()))
                    + last_time + 'eval_time: {:4.2f}, '.format(eval_time.avg) + '\n\n')
        except Exception:
            traceback.print_exc()
            if config.logger is not None:
                config.logger.error("Exception happened in training steps")

    if config.logger is not None:
        config.logger.info("Training finished")

    # rename the work dir so the final best metric is visible in the folder name
    try:
        if config.logger is not None and best_metric is not None:
            new_folder_name = config.folder + '-fin-{:.4f}'.format(best_metric)
            new_work_dir = os.path.join(config.ckpt_dir, config.data_definition, new_folder_name)
            os.system('mv {} {}'.format(config.work_dir, new_work_dir))
    except Exception:
        traceback.print_exc()

    if world_size > 1:
        torch.distributed.destroy_process_group()
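

# Hypothetical entry point (not shown in this file): only `device_ids` and
# `pretrained_weight` are read from `args` directly in this module, so the sketch
# below exposes just those two flags; any additional options consumed by
# utility.get_config(args) are assumptions and would need to be added to match
# the real configuration loader.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="training launcher (sketch)")
    parser.add_argument("--device_ids", type=int, nargs="+", default=[0],
                        help="GPU ids; more than one id spawns one worker per GPU")
    parser.add_argument("--pretrained_weight", type=str, default=None,
                        help="optional checkpoint to resume from")
    train(parser.parse_args())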