|
|
|
|
|
|
|
|
''' |
|
|
@File : main.py |
|
|
@Time : 2022/10/11 19:54:03 |
|
|
@Author : zzubqh |
|
|
@Version : 1.0 |
|
|
@Contact : baiqh@microport.com |
|
|
@License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA |
|
|
@Desc : Entry point for MaskFormer segmentation: nuImages / ADE20K training and a Segmentation test run.
|
|
''' |
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
import os |
|
|
# Expose all 8 GPUs by default, but respect a value already exported by the
# user or the launcher (e.g. `CUDA_VISIBLE_DEVICES=0 python main.py` for a
# single-GPU debug run). This must run before `torch` is imported/initialises
# CUDA, which is why it sits between the imports.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0,1,2,3,4,5,6,7')
|
|
|
|
|
from fvcore.common.config import CfgNode |
|
|
from configs.config import Config |
|
|
import torch |
|
|
from maskformer_train import MaskFormer |
|
|
from dataset.dataset import ADE200kDataset, NuImagesDataset |
|
|
from Segmentation import Segmentation |
|
|
|
|
|
# Join the NCCL process group at import time whenever more than one GPU is
# visible. NOTE(review): this relies on the launcher (torchrun /
# torch.distributed.launch) having set the rendezvous environment variables;
# a plain `python main.py` on a multi-GPU host will fail here.
if torch.cuda.device_count() > 1:
    torch.distributed.init_process_group('nccl')
|
|
|
|
|
def user_scattered_collate(batch):
    """Concatenate per-sample ``'images'`` and ``'masks'`` tensors along dim 0.

    Args:
        batch: sequence of dicts, each holding an ``'images'`` and a
            ``'masks'`` tensor.

    Returns:
        dict: ``{'images': Tensor, 'masks': Tensor}``, each built with
        ``torch.cat(..., dim=0)`` over the batch in its original order.
    """
    # Gather both keys first, then concatenate, in the order images -> masks.
    gathered = {key: [sample[key] for sample in batch] for key in ('images', 'masks')}
    return {key: torch.cat(parts, dim=0) for key, parts in gathered.items()}
|
|
|
|
|
def get_args():
    """Parse CLI options and build the merged run configuration.

    Steps, in order:

    1. Parse ``--config``, ``--local_rank``, ``--ngpus`` and ``--project_name``.
    2. Load the YAML at ``--config`` with the project ``Config.fromfile``.
    3. Load the same YAML with fvcore's ``CfgNode.load_yaml_with_base``
       (``allow_unsafe=True``) and update it with the attributes of the
       ``Config`` object from step 2.
    4. Copy every parsed CLI option into that mapping (so CLI values win),
       then wrap the result in ``Config``.
    5. Overwrite ``ngpus`` with ``torch.cuda.device_count()``, which means the
       ``--ngpus`` flag has no effect. With more than one visible GPU,
       overwrite ``local_rank`` with ``torch.distributed.get_rank()`` and make
       that the current CUDA device.

    Returns:
        Config: the merged configuration object.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/maskformer_nuimages.yaml')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument("--ngpus", default=1, type=int)
    parser.add_argument("--project_name", default='NuImages_swin_base_Seg', type=str)
    cli = parser.parse_args()

    yaml_cfg = Config.fromfile(cli.config)
    merged = CfgNode.load_yaml_with_base(cli.config, allow_unsafe=True)
    merged.update(yaml_cfg.__dict__.items())
    # Written after the YAML merge, so command-line values take precedence.
    for option_name, option_value in vars(cli).items():
        merged[option_name] = option_value

    cfg = Config(merged)

    gpu_count = torch.cuda.device_count()
    cfg.ngpus = gpu_count
    if gpu_count > 1:
        # NOTE(review): get_rank() is the *global* rank; it matches the local
        # GPU index only for single-node runs.
        cfg.local_rank = torch.distributed.get_rank()
        torch.cuda.set_device(cfg.local_rank)
    return cfg
|
|
|
|
|
|
|
|
def train_ade200k():
    """Train MaskFormer on the ADE dataset splits named in ``cfg.DATASETS``.

    The training split (``cfg.DATASETS.TRAIN``) is built with
    ``dynamic_batchHW=True``. When ``cfg.ngpus > 1`` it is sharded with a
    ``DistributedSampler`` keyed on ``cfg.local_rank``; otherwise the loader
    shuffles by itself. The validation split (``cfg.DATASETS.VALID``) is read
    one sample at a time, in order.
    """
    cfg = get_args()

    train_set = ADE200kDataset(cfg.DATASETS.TRAIN, cfg, dynamic_batchHW=True)
    sampler = (
        torch.utils.data.distributed.DistributedSampler(train_set, rank=cfg.local_rank)
        if cfg.ngpus > 1
        else None
    )
    # DataLoader's own shuffle is only used when no sampler is supplied.
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        shuffle=sampler is None,
        collate_fn=train_set.collate_fn,
        num_workers=cfg.TRAIN.WORKERS,
        drop_last=True,
        pin_memory=True,
        sampler=sampler,
    )

    # NOTE(review): no sampler here, so in multi-GPU runs every process
    # iterates the full validation split.
    val_set = ADE200kDataset(cfg.DATASETS.VALID, cfg)
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=1,
        shuffle=False,
        collate_fn=val_set.collate_fn,
        num_workers=cfg.TRAIN.WORKERS,
    )

    model = MaskFormer(cfg)
    model.train(sampler, train_loader, val_loader, cfg.TRAIN.EPOCH)
|
|
|
|
|
def train_nuimages():
    """Train MaskFormer on nuImages ``v1.0-train``, validating on ``v1.0-val``.

    Both splits are read from ``cfg.DATASETS.ROOT_DIR``. When
    ``cfg.ngpus > 1`` each split is sharded with a ``DistributedSampler``
    keyed on ``cfg.local_rank``; otherwise the training loader shuffles on
    its own.

    Validation is always iterated in a fixed order: the evaluation sampler is
    created with ``shuffle=False`` and the single-process evaluation loader
    does not shuffle. This matches ``train_ade200k``.
    """
    cfg = get_args()

    dataset_train = NuImagesDataset(cfg.DATASETS.ROOT_DIR, cfg, version='v1.0-train')
    dataset_eval = NuImagesDataset(cfg.DATASETS.ROOT_DIR, cfg, version='v1.0-val')

    if cfg.ngpus > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset_train, rank=cfg.local_rank)
        # DistributedSampler shuffles by default; evaluation order should be
        # deterministic. NOTE(review): with its default drop_last=False the
        # sampler pads the split with repeated samples so every rank gets the
        # same count, which can slightly skew aggregated metrics.
        eval_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset_eval, rank=cfg.local_rank, shuffle=False)
    else:
        train_sampler = None
        eval_sampler = None

    # DataLoader does not allow shuffle=True together with an explicit
    # sampler, so only shuffle when no sampler is in use.
    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        shuffle=train_sampler is None,
        collate_fn=dataset_train.collate_fn,
        num_workers=cfg.TRAIN.WORKERS,
        drop_last=True,
        pin_memory=True,
        sampler=train_sampler)

    # Validation is never shuffled (previously it was shuffled whenever no
    # sampler was used).
    loader_eval = torch.utils.data.DataLoader(
        dataset_eval,
        batch_size=1,
        shuffle=False,
        collate_fn=dataset_eval.collate_fn,
        num_workers=cfg.TRAIN.WORKERS,
        drop_last=False,
        pin_memory=True,
        sampler=eval_sampler)

    seg_model = MaskFormer(cfg)
    seg_model.train(train_sampler, loader_train, loader_eval, cfg.TRAIN.EPOCH)
|
|
|
|
|
def segmentation_test():
    """Build the run configuration and call ``forward()`` on a ``Segmentation`` handler."""
    handler = Segmentation(get_args())
    handler.forward()
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # nuImages training is the default workflow; point `entry` at
    # train_ade200k or segmentation_test to run one of the others.
    entry = train_nuimages
    entry()
|
|
|