import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from core.update_disp import DispBasicMultiUpdateBlock
from core.extractor import BasicEncoder, ResidualBlock
from core.extractor_metric3d import Metric3DExtractor
from core.corr import CorrBlock1D, PytorchAlternateCorrBlock1D, CorrBlockFast1D, AlternateCorrBlock
from core.utils.utils import hor_coords_grid, rescale_modulation
from core.geometry import LBPEncoder
from core.fusion import BetaModulator, RefinementMonStereo

try:
    autocast = torch.cuda.amp.autocast
except AttributeError:
    # Dummy no-op autocast for PyTorch < 1.6, where torch.cuda.amp does not exist.
    class autocast:
        def __init__(self, enabled):
            pass

        def __enter__(self):
            pass

        def __exit__(self, *args):
            pass


class RAFTStereoMetric3D(nn.Module):
    """RAFT-Stereo style iterative disparity estimator with a Metric3D context branch.

    ``Metric3DExtractor`` (``self.cnet``) returns ``net_list``, ``inp_list`` and a
    ``depth`` map from the first image; ``BasicEncoder`` (``self.fnet``) produces
    matching features for both images; ``DispBasicMultiUpdateBlock`` iteratively
    updates the disparity; and ``RefinementMonStereo`` refines the final disparity
    using ``depth``.
    """

    def __init__(self, args):
        """Build the sub-networks from the config namespace ``args``.

        Attributes read here: ``hidden_dims``, ``n_gru_layers``,
        ``shared_backbone`` and ``n_downsample`` (the sub-modules may read more).
        """
        super().__init__()
        self.args = args

        context_dims = args.hidden_dims

        self.cnet = Metric3DExtractor(args)
        self.update_block = DispBasicMultiUpdateBlock(self.args, hidden_dims=args.hidden_dims)
        # NOTE(review): not used by forward() (the code applying these convs is
        # commented out there); presumably kept so existing checkpoints still load.
        self.context_zqr_convs = nn.ModuleList([
            nn.Conv2d(context_dims[i], args.hidden_dims[i] * 3, 3, padding=3 // 2)
            for i in range(self.args.n_gru_layers)
        ])
        self.refinement = RefinementMonStereo(args, hidden_dim=args.hidden_dims[-1])

        if args.shared_backbone:
            self.conv2 = nn.Sequential(
                ResidualBlock(128, 128, 'instance', stride=1),
                nn.Conv2d(128, 256, 3, padding=1))
        else:
            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', downsample=args.n_downsample)

        # # Freeze all parameters of every module except `refinement`
        # # (self.fnet only exists when shared_backbone is False):
        # for module in [self.cnet, self.update_block, self.context_zqr_convs, self.fnet]:
        #     for param in module.parameters():
        #         param.requires_grad = False

    def freeze_bn(self):
        """Put every ``nn.BatchNorm2d`` sub-module into eval mode.

        A later call to ``self.train()`` switches them back to training mode.
        """
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_disp(self, img):
        """Return two horizontal coordinate grids matching ``img``'s batch size and resolution.

        Disparity is represented as the difference between the two grids,
        ``disp = hor_coords1 - hor_coords0``. Both grids come from identical
        ``hor_coords_grid`` calls. Only ``img``'s shape (N, _, H, W) and device
        are used.
        """
        N, _, H, W = img.shape

        hor_coords0 = hor_coords_grid(N, H, W).to(img.device)
        hor_coords1 = hor_coords_grid(N, H, W).to(img.device)

        return hor_coords0, hor_coords1

    def upsample_disp(self, disp, mask):
        """Upsample disparity by ``factor = 2 ** args.n_downsample`` using convex combination.

        Each fine pixel is a softmax-weighted combination of the 3x3
        neighbourhood of its coarse pixel.

        Args:
            disp: (N, D, H, W) coarse disparity.
            mask: tensor viewable as (N, 1, 9, factor, factor, H, W); the weights
                are softmax-normalised over the 9 neighbourhood taps.

        Returns:
            (N, D, factor * H, factor * W) tensor. Values are multiplied by
            ``factor`` so they are expressed in full-resolution pixels.
        """
        N, D, H, W = disp.shape
        factor = 2 ** self.args.n_downsample
        mask = mask.view(N, 1, 9, factor, factor, H, W)
        mask = torch.softmax(mask, dim=2)

        # Gather each pixel's 3x3 neighbourhood (after rescaling the values).
        up_disp = F.unfold(factor * disp, [3, 3], padding=1)
        up_disp = up_disp.view(N, D, 9, 1, 1, H, W)

        up_disp = torch.sum(mask * up_disp, dim=2)
        # (N, D, factor, factor, H, W) -> (N, D, H, factor, W, factor)
        up_disp = up_disp.permute(0, 1, 4, 2, 5, 3)
        return up_disp.reshape(N, D, factor * H, factor * W)

    def _build_corr_fn(self, fmap1, fmap2):
        """Instantiate the correlation block selected by ``args.corr_implementation``.

        Raises:
            ValueError: if the implementation name is not one of
                'reg', 'alt', 'reg_cuda', 'alt_cuda'.
        """
        impl = self.args.corr_implementation
        if impl == "reg":  # Default
            corr_block = CorrBlock1D
            fmap1, fmap2 = fmap1.float(), fmap2.float()
        elif impl == "alt":  # More memory efficient than reg
            corr_block = PytorchAlternateCorrBlock1D
            fmap1, fmap2 = fmap1.float(), fmap2.float()
        elif impl == "reg_cuda":  # Faster version of reg
            corr_block = CorrBlockFast1D
        elif impl == "alt_cuda":  # Faster version of alt
            corr_block = AlternateCorrBlock
        else:
            raise ValueError(
                "Unknown corr_implementation {!r}; expected one of "
                "'reg', 'alt', 'reg_cuda', 'alt_cuda'".format(impl))
        return corr_block(fmap1, fmap2, radius=self.args.corr_radius, num_levels=self.args.corr_levels)

    def forward(self, image1, image2, iters=12, disp_init=None, test_mode=False, vis_mode=False, intrinsic=None):
        """Estimate disparity for a stereo image pair.

        Args:
            image1, image2: the two images of the pair. Both are mapped with
                ``2 * (x / 255) - 1``, so pixel values in [0, 255] are presumably
                expected (TODO confirm). ``image1`` also feeds the context/depth network.
            iters: number of GRU update iterations.
            disp_init: optional initial offset added to ``hor_coords1``. It uses the
                same sign convention as ``hor_coords1 - hor_coords0``.
            test_mode: if True, intermediate iterations are not upsampled, and a
                tuple is returned instead of a dict.
            vis_mode: if True (and not test_mode), also return the depth map.
            intrinsic: forwarded to ``Metric3DExtractor``. Presumably the camera
                intrinsics; TODO confirm.

        Returns:
            test_mode: ``(hor_coords1 - hor_coords0, disp_up)``, i.e. the
                low-resolution disparity after the last GRU iteration and the
                refined, upsampled disparity.
            vis_mode: ``{"disp_predictions": ..., "depth": ...}``.
            otherwise: ``{"disp_predictions": ..., "conf": ...}``.
            ``disp_predictions`` contains the upsampled disparity of each GRU
            iteration (only the last one in test_mode). It is followed by the
            upsampled ``depth_registered`` and then the upsampled refined disparity.

        Raises:
            NotImplementedError: if ``args.shared_backbone`` is set.
            ValueError: if ``args.corr_implementation`` is unknown.
        """
        if self.args.shared_backbone:
            # The shared-backbone path never produced net_list / inp_list / depth,
            # which the update loop and refinement need. It therefore always crashed
            # later with UnboundLocalError; fail fast with a clear message instead.
            raise NotImplementedError(
                "RAFTStereoMetric3D does not support shared_backbone=True: that path "
                "does not produce the net_list/inp_list/depth required by forward()")

        image1 = (2 * (image1 / 255.0) - 1.0).contiguous()
        image2 = (2 * (image2 / 255.0) - 1.0).contiguous()

        # Run the context network and the feature network.
        with autocast(enabled=self.args.mixed_precision):
            # Shapes noted by the author (example input):
            #   cnet levels: (128,248,360), (128,124,180), (128,62,90)
            #   fmap1 / fmap2: (128,248,360)
            net_list, inp_list, depth = self.cnet(image1, intrinsic)
            fmap1, fmap2 = self.fnet([image1, image2])
            # net_list = [torch.tanh(x[0]) for x in cnet_list]
            # inp_list = [torch.relu(x[1]) for x in cnet_list]

            # # Rather than running the GRU's conv layers on the context features multiple times, we do it once at the beginning
            # inp_list = [list(conv(i).split(split_size=conv.out_channels//3, dim=1)) for i,conv in zip(inp_list, self.context_zqr_convs)]

        corr_fn = self._build_corr_fn(fmap1, fmap2)

        hor_coords0, hor_coords1 = self.initialize_disp(net_list[0])

        if disp_init is not None:
            hor_coords1 = hor_coords1 + disp_init

        n_layers = self.args.n_gru_layers
        disp_predictions = []
        for itr in range(iters):
            hor_coords1 = hor_coords1.detach()
            corr = corr_fn(hor_coords1)  # index correlation volume
            disp = hor_coords1 - hor_coords0
            with autocast(enabled=self.args.mixed_precision):
                if n_layers == 3 and self.args.slow_fast_gru:  # Update low-res GRU
                    net_list = self.update_block(net_list, inp_list, iter32=True, iter16=False, iter08=False, update=False)
                if n_layers >= 2 and self.args.slow_fast_gru:  # Update low-res GRU and mid-res GRU
                    net_list = self.update_block(net_list, inp_list, iter32=n_layers == 3, iter16=True, iter08=False, update=False)
                net_list, up_mask, delta_disp = self.update_block(net_list, inp_list, corr, disp, iter32=n_layers == 3, iter16=n_layers >= 2)

            # F(t+1) = F(t) + \Delta(t)
            hor_coords1 = hor_coords1 + delta_disp

            # We do not need to upsample or output intermediate results in test_mode
            if test_mode and itr < iters - 1:
                continue

            # upsample predictions
            disp_up = self.upsample_disp(hor_coords1 - hor_coords0, up_mask)
            disp_predictions.append(disp_up)

        # Refinement. The refinement module receives the NEGATED loop disparity
        # (hor_coords0 - hor_coords1). Its outputs are negated back before
        # upsampling, so every entry of disp_predictions uses the same sign.
        corr = corr_fn(hor_coords1)
        disp = -hor_coords1 + hor_coords0
        disp_refine, up_mask, depth_registered, conf = self.refinement(disp, depth, net_list[0], corr)
        disp_up = self.upsample_disp(-disp_refine, up_mask)
        depth_registered_up = self.upsample_disp(-depth_registered, up_mask)
        disp_predictions.append(depth_registered_up)
        disp_predictions.append(disp_up)

        if test_mode:
            return hor_coords1 - hor_coords0, disp_up
        # Alternative test-mode output:
        # if test_mode:
        #     return hor_coords1 - hor_coords0, depth_registered_up
        if vis_mode:
            return {"disp_predictions": disp_predictions, "depth": depth}

        return {"disp_predictions": disp_predictions, "conf": conf}