#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    :   main.py
@Time    :   2022/10/11 19:54:03
@Author  :   zzubqh
@Version :   1.0
@Contact :   baiqh@microport.com
@License :   (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
@Desc    :   Entry point for MaskFormer segmentation: ADE20k / NuImages
             training and a standalone segmentation test.
'''

# here put the import lib
import argparse
import os

# Must be set before torch initialises CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'

import torch
# FIX: this import was commented out, but CfgNode is used in get_args() —
# without it every invocation died with a NameError.
from fvcore.common.config import CfgNode

from configs.config import Config
from maskformer_train import MaskFormer
from dataset.dataset import ADE200kDataset, NuImagesDataset
from Segmentation import Segmentation

# Initialise the NCCL process group only on multi-GPU nodes; single-GPU
# runs skip distributed setup entirely.
if torch.cuda.device_count() > 1:
    torch.distributed.init_process_group(backend='nccl')


def user_scattered_collate(batch):
    """Collate per-sample dicts by concatenating their tensors along dim 0.

    Each batch item is expected to be a dict with pre-batched 'images' and
    'masks' tensors.  (Currently unused — the datasets supply their own
    collate_fn — kept for external callers.)

    :param batch: list of dicts with 'images' and 'masks' tensor values
    :return: dict with concatenated 'images' and 'masks' tensors
    """
    data = [item['images'] for item in batch]
    masks = [item['masks'] for item in batch]
    out = {'images': torch.cat(data, dim=0), 'masks': torch.cat(masks, dim=0)}
    return out


def get_args():
    """Parse CLI arguments, merge them over the YAML config, return a Config.

    Merge order (later wins): base YAML (with _BASE_ inheritance resolved)
    <- project Config <- command-line arguments.  Also records the actual
    GPU count and, in distributed mode, binds this process to its rank's
    CUDA device.

    :return: fully-populated Config object
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/maskformer_nuimages.yaml')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument("--ngpus", default=1, type=int)
    parser.add_argument("--project_name", default='NuImages_swin_base_Seg', type=str)
    args = parser.parse_args()

    cfg_ake150 = Config.fromfile(args.config)
    # allow_unsafe lets the YAML loader accept python-specific tags.
    cfg_base = CfgNode.load_yaml_with_base(args.config, allow_unsafe=True)
    cfg_base.update(cfg_ake150.__dict__.items())

    cfg = cfg_base
    # Command-line values override / extend the file-based configuration.
    for k, v in args.__dict__.items():
        cfg[k] = v

    cfg = Config(cfg)
    cfg.ngpus = torch.cuda.device_count()
    if torch.cuda.device_count() > 1:
        # In DDP the real rank comes from the process group, not argparse.
        cfg.local_rank = torch.distributed.get_rank()
        torch.cuda.set_device(cfg.local_rank)
    return cfg


def train_ade200k():
    """Train MaskFormer on ADE20k using the settings from get_args()."""
    cfg = get_args()
    dataset_train = ADE200kDataset(cfg.DATASETS.TRAIN, cfg, dynamic_batchHW=True)
    if cfg.ngpus > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, rank=cfg.local_rank)
    else:
        train_sampler = None

    # shuffle and sampler are mutually exclusive: the DistributedSampler
    # already shuffles per-epoch when present.
    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        shuffle=False if train_sampler is not None else True,
        collate_fn=dataset_train.collate_fn,
        num_workers=cfg.TRAIN.WORKERS,
        drop_last=True,
        pin_memory=True,
        sampler=train_sampler)

    # Evaluation always runs single-sample, unsharded.
    dataset_eval = ADE200kDataset(cfg.DATASETS.VALID, cfg)
    loader_eval = torch.utils.data.DataLoader(
        dataset_eval,
        batch_size=1,
        shuffle=False,
        collate_fn=dataset_eval.collate_fn,
        num_workers=cfg.TRAIN.WORKERS)

    seg_model = MaskFormer(cfg)
    seg_model.train(train_sampler, loader_train, loader_eval, cfg.TRAIN.EPOCH)


def train_nuimages():
    """Train MaskFormer on NuImages using the settings from get_args()."""
    cfg = get_args()
    dataset_train = NuImagesDataset(cfg.DATASETS.ROOT_DIR, cfg, version='v1.0-train')  # v1.0-mini or v1.0-train
    dataset_eval = NuImagesDataset(cfg.DATASETS.ROOT_DIR, cfg, version='v1.0-val')  # v1.0-mini or v1.0-val
    if cfg.ngpus > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, rank=cfg.local_rank)
        eval_sampler = torch.utils.data.distributed.DistributedSampler(dataset_eval, rank=cfg.local_rank)
    else:
        train_sampler = None
        eval_sampler = None

    # shuffle and sampler are mutually exclusive (see train_ade200k).
    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        shuffle=False if train_sampler is not None else True,
        collate_fn=dataset_train.collate_fn,
        num_workers=cfg.TRAIN.WORKERS,
        drop_last=True,
        pin_memory=True,
        sampler=train_sampler)

    loader_eval = torch.utils.data.DataLoader(
        dataset_eval,
        batch_size=1,
        shuffle=False if eval_sampler is not None else True,
        collate_fn=dataset_eval.collate_fn,
        num_workers=cfg.TRAIN.WORKERS,
        drop_last=False,
        pin_memory=True,
        sampler=eval_sampler)

    seg_model = MaskFormer(cfg)
    seg_model.train(train_sampler, loader_train, loader_eval, cfg.TRAIN.EPOCH)


def segmentation_test():
    """Run the standalone segmentation forward pass from the CLI config."""
    cfg = get_args()
    segmentation_handler = Segmentation(cfg)
    segmentation_handler.forward()


if __name__ == '__main__':
    train_nuimages()
    # segmentation_test()