|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import random |
|
|
import logging |
|
|
import subprocess |
|
|
import numpy as np |
|
|
|
|
|
# Configure the root logger once at import time: INFO level, with each record
# prefixed by timestamp, padded level name and the emitting file:line.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
)
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.distributed as dist |
|
|
import torch.multiprocessing as mp |
|
|
from torch.nn import DataParallel as DP |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
from torch.utils.data.distributed import DistributedSampler as DS |
|
|
|
|
|
from core.raft_stereo import RAFTStereo |
|
|
from core.raft_stereo_disp import RAFTStereoDisp |
|
|
from core.raft_stereo_mast3r import RAFTStereoMast3r |
|
|
from core.raft_stereo_depthany import RAFTStereoDepthAny |
|
|
from core.raft_stereo_noctx import RAFTStereoNoCTX |
|
|
from core.raft_stereo_depthfusion import RAFTStereoDepthFusion |
|
|
from core.raft_stereo_depthbeta import RAFTStereoDepthBeta |
|
|
from core.raft_stereo_depthbeta_nolbp import RAFTStereoDepthBetaNoLBP |
|
|
from core.raft_stereo_depthmatch import RAFTStereoDepthMatch |
|
|
from core.raft_stereo_depthbeta_refine import RAFTStereoDepthBetaRefine |
|
|
from core.raft_stereo_depth_postfusion import RAFTStereoDepthPostFusion |
|
|
from core.raft_stereo_metric3d import RAFTStereoMetric3D |
|
|
|
|
|
|
|
|
def setup_distributed(args):
    """Initialise the default NCCL process group for this process.

    Reads the ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE`` environment
    variables, stores them as ints on ``args`` (``args.rank``,
    ``args.local_rank``, ``args.world_size``), initialises the process group
    with the ``nccl`` backend, selects CUDA device ``args.local_rank`` and
    raises the tensor print precision to 10 digits.

    Args:
        args: Mutable namespace-like object that receives the three rank
            attributes.

    Raises:
        TypeError: If one of the environment variables is unset, because
            ``int(None)`` is not valid.
    """
    # (attribute on args, environment variable) pairs, applied in order.
    env_fields = (
        ("rank", "RANK"),
        ("local_rank", "LOCAL_RANK"),
        ("world_size", "WORLD_SIZE"),
    )
    for attr_name, env_name in env_fields:
        setattr(args, attr_name, int(os.getenv(env_name)))

    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(args.local_rank)
    torch.set_printoptions(precision=10)
|
|
|
|
|
def ddp_init(args):
    """Prepare the process for (optionally distributed) deterministic training.

    If ``WORLD_SIZE`` is present in the environment, ``args.distributed`` is
    overwritten with ``int(WORLD_SIZE) >= 1``; otherwise the existing value
    of ``args.distributed`` is used as-is. In distributed mode the process
    group is set up via ``setup_distributed``. cuDNN is then switched to
    deterministic, non-benchmark mode and the torch, NumPy and Python RNGs
    are all seeded with ``args.local_rank``.

    Args:
        args: Namespace-like object; must already provide ``local_rank``
            (and ``distributed`` when ``WORLD_SIZE`` is unset) unless they
            are filled in by ``setup_distributed``.
    """
    world_size_env = os.environ.get('WORLD_SIZE')
    if world_size_env is not None:
        args.distributed = int(world_size_env) >= 1

    if args.distributed:
        setup_distributed(args)

    # Favour reproducibility over autotuned kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Every RNG in use gets the same per-process seed (the local rank).
    seed = args.local_rank
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    print("complete initialization of local_rank:{}".format(args.local_rank))
|
|
|
|
|
def ddp_close():
    """Tear down the default process group, if one was initialised.

    In non-distributed runs no process group exists and calling
    ``dist.destroy_process_group()`` would raise, so the call is skipped in
    that case. This makes the function safe to call unconditionally at the
    end of a run.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
|
|
|
|
|
def to_python_float(t):
    """Extract a plain Python scalar from ``t``.

    Objects exposing an ``item`` attribute are converted with ``t.item()``;
    anything else is treated as a sequence and its first element is returned.
    """
    if not hasattr(t, 'item'):
        return t[0]
    return t.item()
|
|
|
|
|
def reduce_tensor(tensor, world_size):
    """Return the all-reduced sum of ``tensor`` divided by ``world_size``.

    The input is cloned first, so ``tensor`` itself is left unmodified. The
    clone is summed across the default process group and then divided in
    place by ``world_size``.

    Args:
        tensor: Tensor to reduce.
        world_size: Divisor applied after the sum.

    Returns:
        A new tensor holding the reduced value.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    return averaged.div_(world_size)
|
|
|
|
|
def reduce_tensors(*tensors, world_size):
    """Apply ``reduce_tensor`` to each positional tensor.

    Args:
        *tensors: Tensors to reduce, processed in the order given.
        world_size: Keyword-only divisor forwarded to ``reduce_tensor``.

    Returns:
        A list with one reduced tensor per input, in input order.
    """
    reduced = []
    for current in tensors:
        reduced.append(reduce_tensor(current, world_size))
    return reduced
|
|
|
|
|
|
|
|
def get_loader(dataset, args):
    """Build the training ``DataLoader`` for ``dataset``.

    In distributed mode the loader draws from a shuffling
    ``DistributedSampler`` (created with ``drop_last=True``); otherwise a
    plain shuffling loader that drops the last incomplete batch is returned.

    Args:
        dataset: Dataset to iterate over.
        args: Namespace-like object providing ``distributed``,
            ``local_rank``, ``batch_size`` and ``num_workers``.

    Returns:
        The configured ``DataLoader``.
    """
    if args.distributed:
        if args.local_rank == 0:
            print('use distributed sampler')
        data_sampler = DS(dataset, shuffle=True, drop_last=True)
        data_loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            sampler=data_sampler,
            num_workers=args.num_workers,
            pin_memory=True,
            # DataLoader rejects persistent_workers=True when there are no
            # worker processes, so only enable it when workers exist.
            persistent_workers=args.num_workers > 0,
        )
    else:
        if args.local_rank == 0:
            print("use default sampler")
        # Fixed: this branch previously referenced the undefined name
        # `train_dataset` and assigned to `train_loader`, so the function
        # raised NameError instead of returning a loader.
        data_loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            pin_memory=True,
            shuffle=True,
            num_workers=args.num_workers,
            drop_last=True,
        )
    return data_loader
|
|
|
|
|
# Rank identifiers captured from the environment at import time.
# NOTE: each value is the raw environment string when the variable is set and
# the int 0 when it is not, so the type differs between the two cases.
NODE_RANK = os.environ.get('NODE_RANK', 0)
LOCAL_RANK = os.environ.get("LOCAL_RANK", 0)
|
|
|
|
|
def get_model_ddp(args):
    """Build the model named by ``args.model_name`` and wrap it for multi-GPU use.

    The model name is matched case-insensitively. The model is constructed
    with ``args``, its BatchNorm layers are converted to ``SyncBatchNorm``
    and it is moved to CUDA device ``args.local_rank``. If
    ``args.restore_ckpt`` is given, its weights are loaded non-strictly
    after stripping ``'module.'`` from the keys. The result is wrapped in
    ``DistributedDataParallel`` when ``args.distributed`` is true, and in
    ``DataParallel`` otherwise.

    Args:
        args: Namespace-like object providing ``model_name``, ``local_rank``,
            ``restore_ckpt`` and ``distributed``; it is also passed to the
            model constructor.

    Returns:
        The wrapped model.

    Raises:
        Exception: If ``args.model_name`` does not name a known model.
        AssertionError: If ``args.restore_ckpt`` does not end in ``.pth``
            or ``.tar``.
    """
    # Dispatch table keyed by the lower-cased model name.
    model_classes = {
        "raftstereo": RAFTStereo,
        "raftstereodisp": RAFTStereoDisp,
        "raftstereomast3r": RAFTStereoMast3r,
        "raftstereodepthany": RAFTStereoDepthAny,
        "raftstereodepthfusion": RAFTStereoDepthFusion,
        "raftstereodepthbeta": RAFTStereoDepthBeta,
        "raftstereodepthbetanolbp": RAFTStereoDepthBetaNoLBP,
        "raftstereodepthmatch": RAFTStereoDepthMatch,
        "raftstereodepthbetarefine": RAFTStereoDepthBetaRefine,
        "raftstereodepthpostfusion": RAFTStereoDepthPostFusion,
        "raftstereometric3d": RAFTStereoMetric3D,
    }
    model_cls = model_classes.get(args.model_name.lower())
    if model_cls is None:
        raise Exception("No such model: {}".format(args.model_name))
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model_cls(args))

    device = torch.device("cuda", args.local_rank)
    model = model.to(device)

    if args.restore_ckpt is not None:
        assert args.restore_ckpt.endswith((".pth", ".tar"))
        is_main_process = args.local_rank == 0
        if is_main_process:
            logging.info("Loading checkpoint from {} ...".format(args.restore_ckpt))
        # Deserialise onto the CPU so that every rank does not first
        # materialise the weights on the GPU recorded in the checkpoint;
        # load_state_dict then copies them into the model on `device`.
        checkpoint = torch.load(args.restore_ckpt, map_location="cpu")
        # Drop the 'module.' marker that DP/DDP wrappers add to parameter
        # names so the keys match the unwrapped model.
        new_state_dict = {
            key.replace('module.', ''): value
            for key, value in checkpoint.items()
        }
        # strict=False: missing or unexpected keys do not raise.
        model.load_state_dict(new_state_dict, strict=False)
        if is_main_process:
            logging.info("Done loading checkpoint")

    # Synchronise ranks before wrapping. Skipped when no process group
    # exists (non-distributed run), where barrier() would raise.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()

    if args.distributed:
        model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank,
                    find_unused_parameters=True)
    else:
        # Fixed: wrap the model built (and possibly restored) above instead
        # of constructing a fresh, untrained RAFTStereo.
        model = DP(model)
    return model
|
|
|