#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File : main.py
@Time : 2022/10/11 19:54:03
@Author : zzubqh
@Version : 1.0
@Contact : baiqh@microport.com
@License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
@Desc : None
'''
# imports
import argparse
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'  # expose GPUs 0-7 to this process
from fvcore.common.config import CfgNode
from configs.config import Config
import torch
from maskformer_train import MaskFormer
from dataset.dataset import ADE200kDataset, NuImagesDataset
from Segmentation import Segmentation
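
# NOTE (assumption): for multi-GPU training this script is presumably launched with a
# distributed launcher, e.g. `python -m torch.distributed.launch --nproc_per_node=<N> main.py`,
# which starts one process per GPU and supplies the --local_rank argument parsed in get_args().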
if torch.cuda.device_count() > 1:
    torch.distributed.init_process_group(backend='nccl')


def user_scattered_collate(batch):
    # Concatenate per-sample image and mask tensors along the batch dimension.
    data = [item['images'] for item in batch]
    masks = [item['masks'] for item in batch]
    out = {'images': torch.cat(data, dim=0), 'masks': torch.cat(masks, dim=0)}
    return out


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/maskformer_nuimages.yaml')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--ngpus', default=1, type=int)
    parser.add_argument('--project_name', default='NuImages_swin_base_Seg', type=str)
    args = parser.parse_args()

    # Merge the configuration sources: the YAML file (resolving any _BASE_ references),
    # the project Config wrapper, and finally the command-line arguments, which win.
    cfg_ake150 = Config.fromfile(args.config)
    cfg_base = CfgNode.load_yaml_with_base(args.config, allow_unsafe=True)
    cfg_base.update(cfg_ake150.__dict__.items())
    cfg = cfg_base
    for k, v in args.__dict__.items():
        cfg[k] = v
    cfg = Config(cfg)

    cfg.ngpus = torch.cuda.device_count()
    if torch.cuda.device_count() > 1:
        # Under a distributed launcher, take the rank from the process group
        # and pin this process to its own GPU.
        cfg.local_rank = torch.distributed.get_rank()
        torch.cuda.set_device(cfg.local_rank)
    return cfg


def train_ade200k():
    cfg = get_args()

    dataset_train = ADE200kDataset(cfg.DATASETS.TRAIN, cfg, dynamic_batchHW=True)
    if cfg.ngpus > 1:
        # Give each process its own shard of the training set.
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, rank=cfg.local_rank)
    else:
        train_sampler = None
    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        shuffle=train_sampler is None,  # the distributed sampler already shuffles
        collate_fn=dataset_train.collate_fn,
        num_workers=cfg.TRAIN.WORKERS,
        drop_last=True,
        pin_memory=True,
        sampler=train_sampler)

    dataset_eval = ADE200kDataset(cfg.DATASETS.VALID, cfg)
    loader_eval = torch.utils.data.DataLoader(
        dataset_eval,
        batch_size=1,
        shuffle=False,
        collate_fn=dataset_eval.collate_fn,
        num_workers=cfg.TRAIN.WORKERS)

    seg_model = MaskFormer(cfg)
    seg_model.train(train_sampler, loader_train, loader_eval, cfg.TRAIN.EPOCH)


def train_nuimages():
    cfg = get_args()

    dataset_train = NuImagesDataset(cfg.DATASETS.ROOT_DIR, cfg, version='v1.0-train')  # v1.0-mini or v1.0-train
    dataset_eval = NuImagesDataset(cfg.DATASETS.ROOT_DIR, cfg, version='v1.0-val')  # v1.0-mini or v1.0-val
    if cfg.ngpus > 1:
        # Shard both splits across processes in distributed mode.
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, rank=cfg.local_rank)
        eval_sampler = torch.utils.data.distributed.DistributedSampler(dataset_eval, rank=cfg.local_rank)
    else:
        train_sampler = None
        eval_sampler = None
    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        shuffle=train_sampler is None,  # the distributed sampler already shuffles
        collate_fn=dataset_train.collate_fn,
        num_workers=cfg.TRAIN.WORKERS,
        drop_last=True,
        pin_memory=True,
        sampler=train_sampler)
    loader_eval = torch.utils.data.DataLoader(
        dataset_eval,
        batch_size=1,
        shuffle=eval_sampler is None,
        collate_fn=dataset_eval.collate_fn,
        num_workers=cfg.TRAIN.WORKERS,
        drop_last=False,
        pin_memory=True,
        sampler=eval_sampler)

    seg_model = MaskFormer(cfg)
    seg_model.train(train_sampler, loader_train, loader_eval, cfg.TRAIN.EPOCH)


def segmentation_test():
    # Run the Segmentation handler once with the parsed config (inference entry point).
    cfg = get_args()
    segmentation_handler = Segmentation(cfg)
    segmentation_handler.forward()


if __name__ == '__main__':
    train_nuimages()
    # segmentation_test()
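
# Example launch commands (assumed usage, not taken from the repo; adjust --nproc_per_node
# to the number of GPUs made visible via CUDA_VISIBLE_DEVICES above):
#   single GPU: python main.py --config configs/maskformer_nuimages.yaml
#   multi GPU : python -m torch.distributed.launch --nproc_per_node=8 main.py --config configs/maskformer_nuimages.yaml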