import os
import sys
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from core.utils.utils import coords_grid, disparity_computation
from core.utils.utils import LoggerCommon

logger = LoggerCommon("LOSS")

try:
    autocast = torch.cuda.amp.autocast
except AttributeError:
    # dummy autocast for PyTorch < 1.6 (torch.cuda.amp does not exist there)
    class autocast:
        def __init__(self, enabled):
            pass

        def __enter__(self):
            pass

        def __exit__(self, *args):
            pass


def _sequence_weight_base(loss_gamma, n_predictions):
    """Return the per-step decay base used to weight a sequence of predictions.

    Prediction ``i`` (0-based) gets weight ``base ** (n_predictions - i - 1)``.
    The exponent is normalised so the earliest prediction always receives
    ``loss_gamma ** 15``, keeping the weighting consistent for any number of
    RAFT-Stereo iterations. With a single prediction its weight is
    ``base ** 0 == 1`` whatever the base, so the normalisation (which would
    divide by zero) is skipped.
    """
    if n_predictions <= 1:
        return 1.0
    return loss_gamma ** (15 / (n_predictions - 1))


def _masked_epe(pred, target, valid):
    """Return the flattened per-pixel end-point error of ``pred`` on ``valid`` pixels.

    The error is the L2 norm over dim 1 of ``pred - target``; ``valid`` is a
    boolean mask whose flattened size matches the flattened error map.
    """
    epe = torch.sum((pred - target) ** 2, dim=1).sqrt()
    return epe.view(-1)[valid.view(-1)]


def sequence_loss(flow_preds, flow_gt, valid, loss_gamma=0.9, max_flow=700):
    """Loss function defined over a sequence of flow predictions.

    Args:
        flow_preds: list of predictions, one per refinement iteration; each is
            compared element-wise against ``flow_gt``.
        flow_gt: ground truth. The shape assertion below forces it to have a
            single channel in dim 1.
        valid: validity map; pixels with ``valid >= 0.5`` and ground-truth
            magnitude below ``max_flow`` contribute to the loss.
        loss_gamma: exponential decay applied to earlier iterations.
        max_flow: ground-truth magnitudes at or above this are ignored.

    Returns:
        ``(flow_loss, metrics)``. ``metrics`` holds the EPE and the 1/3/5-px
        accuracy of the last prediction.
    """
    n_predictions = len(flow_preds)
    assert n_predictions >= 1
    flow_loss = 0.0

    # exclude invalid pixels and extremely large displacements
    mag = torch.sum(flow_gt ** 2, dim=1).sqrt()
    valid = ((valid >= 0.5) & (mag < max_flow)).unsqueeze(1)
    assert valid.shape == flow_gt.shape, [valid.shape, flow_gt.shape]
    assert not torch.isinf(flow_gt[valid]).any()

    adjusted_loss_gamma = _sequence_weight_base(loss_gamma, n_predictions)
    for i in range(n_predictions):
        # Predictions containing NaN/Inf are skipped rather than poisoning the loss.
        if torch.isnan(flow_preds[i]).any() or torch.isinf(flow_preds[i]).any():
            continue
        i_weight = adjusted_loss_gamma ** (n_predictions - i - 1)
        i_loss = (flow_preds[i] - flow_gt).abs()
        assert i_loss.shape == valid.shape, [i_loss.shape, valid.shape, flow_gt.shape, flow_preds[i].shape]
        flow_loss += i_weight * i_loss[valid].mean()

    epe = _masked_epe(flow_preds[-1], flow_gt, valid)
    metrics = {
        'epe': epe.mean().item(),
        '1px': (epe < 1).float().mean().item(),
        '3px': (epe < 3).float().mean().item(),
        '5px': (epe < 5).float().mean().item(),
    }

    return flow_loss, metrics


def my_loss(res, flow_gt, valid, loss_gamma=0.9, max_flow=700):
    """Unimplemented placeholder; currently does nothing and returns None."""
    pass


class Loss(nn.Module):
    """Sequence loss for iterative disparity/flow refinement plus auxiliary terms.

    Besides the exponentially weighted L1 flow loss, it can also add:

    * a confidence loss, on 1/4-resolution confidence logits;
    * disparity and refined-disparity losses;
    * plane-parameter losses against ``plane_abc``;
    * a smoothness regulariser (see ``SmoothLoss``).
    """

    def __init__(self, loss_gamma=0.9, max_flow=700, loss_zeta=0.3,
                 smoothness=None, slant=None, slant_norm=False,
                 ner_kernel_size=3, ner_weight_reduce=False,
                 local_rank=None, mixed_precision=True, args=None):
        """
        Args:
            loss_gamma: exponential decay applied to earlier iterations.
            max_flow: ground-truth magnitudes at or above this are ignored.
            loss_zeta: weight of the smoothness term in the total loss.
            smoothness: ``"gradient"``, ``"curvature"``, or None/"" to disable.
            slant, slant_norm: forwarded to ``disparity_computation`` via ``SmoothLoss``.
            ner_kernel_size: neighbourhood definition for ``SmoothLoss``;
                an int or a ``"<type>-<size>"`` string.
            ner_weight_reduce: forwarded to ``SmoothLoss`` as ``ner_weight_reduce``.
            local_rank: unused here; kept for interface compatibility.
            mixed_precision: enables autocast for the confidence/smoothness terms.
            args: namespace read for ``conf_disp`` and ``offset_memory_last_iter``.
        """
        super(Loss, self).__init__()
        self.loss_gamma = loss_gamma
        self.loss_zeta = loss_zeta
        self.max_flow = max_flow
        self.smoothness = smoothness
        self.mixed_precision = mixed_precision
        # ``args`` defaults to None: fall back to "no confidence weighting"
        # instead of failing with AttributeError.
        self.conf_disp = getattr(args, "conf_disp", False)
        self.args = args

        if self.smoothness is not None and len(self.smoothness) > 0:
            self.smooth_loss_computer = SmoothLoss(self.smoothness,
                                                   slant=slant,
                                                   slant_norm=slant_norm,
                                                   kernel_size=ner_kernel_size,
                                                   ner_weight_reduce=ner_weight_reduce)
        logger.info(f"smoothness: {smoothness}, " +
                    f"slant: {slant}, slant_norm: {slant_norm}, " +
                    f"ner_kernel_size: {ner_kernel_size}, " +
                    f"ner_weight_reduce: {ner_weight_reduce}, " +
                    f"conf_disp: {self.conf_disp}. ")

    def _confidence_active(self, i):
        """Return True if the confidence head is supervised at iteration ``i``.

        ``args.offset_memory_last_iter < 0`` supervises every iteration. A positive
        value supervises iterations ``0..offset_memory_last_iter``. ``0`` disables it.
        """
        last_iter = self.args.offset_memory_last_iter
        return last_iter < 0 or (last_iter > 0 and i <= last_iter)

    def _confidence_loss(self, flow_pred, confidence, flow_gt):
        """BCE loss teaching ``confidence`` (logits) to flag large prediction errors.

        The target is 1 where the absolute error of the detached prediction,
        bilinearly downsampled by 4, exceeds 4, and 0 elsewhere. Each pixel's
        BCE is modulated by ``(sigmoid(confidence) - target) ** 2``, and positive
        pixels are up-weighted by 1.5.
        """
        with autocast(enabled=self.mixed_precision):
            gt_error = (flow_pred.detach() - flow_gt).abs().detach()
            gt_error = F.interpolate(gt_error, scale_factor=1/4, mode='bilinear')
            gt_conf = (gt_error > 4).float()
            weight = torch.pow(torch.sigmoid(confidence) - gt_conf, 2)
            per_pixel = (1 + gt_conf * 0.5) * weight * \
                F.binary_cross_entropy_with_logits(confidence, gt_conf, reduction='none')
            return per_pixel.mean()

    def forward(self, flow_preds, flow_gt, valid,
                disp_preds=None, disp_preds_refine=None,
                confidence_list=None, params_list=None, params_list_refine=None,
                plane_abc=None, imgL=None, imgR=None, global_batch_num=None,):
        """Loss function defined over a sequence of flow predictions.

        Args:
            flow_preds: list of per-iteration predictions. They must not contain
                NaN/Inf; this is asserted.
            flow_gt: ground truth. The shape assertion below forces one channel
                in dim 1.
            valid: validity map; see ``sequence_loss``.
            disp_preds / disp_preds_refine: optional per-iteration lists. Entries
                are negated before comparison with ``flow_gt``; None entries are
                skipped.
            confidence_list: optional per-iteration confidence logits at 1/4
                resolution. None, or a None entry, disables the confidence terms.
            params_list / params_list_refine: optional per-iteration plane
                parameters, L1-compared with ``plane_abc`` when it has 3 channels.
            plane_abc: optional ground-truth plane parameters.
            imgL: left image, used by the smoothness term.
            imgR: unused here.
            global_batch_num: training step. Confidence-based re-weighting of the
                flow loss only starts once it exceeds 3; None disables it.

        Returns:
            ``(loss, metrics, flow_loss, disp_loss, disp_refine_loss,
            confidence_loss, smooth_loss, params_loss, params_refine_loss)``.
        """
        n_predictions = len(flow_preds)
        assert n_predictions >= 1
        flow_loss = 0.0
        disp_loss = 0.0
        disp_refine_loss = 0.0
        smooth_loss = 0.0
        confidence_loss = 0.0
        params_loss = 0.0
        params_refine_loss = 0.0

        # exclude invalid pixels and extremely large displacements
        mag = torch.sum(flow_gt ** 2, dim=1).sqrt()
        valid = ((valid >= 0.5) & (mag < self.max_flow)).unsqueeze(1)
        assert valid.shape == flow_gt.shape, [valid.shape, flow_gt.shape]
        assert not torch.isinf(flow_gt[valid]).any()

        # Consistent weighting for any number of RAFT-Stereo iterations.
        adjusted_loss_gamma = _sequence_weight_base(self.loss_gamma, n_predictions)
        use_plane_gt = plane_abc is not None and plane_abc.shape[1] == 3

        for i in range(n_predictions):
            assert not torch.isnan(flow_preds[i]).any() and not torch.isinf(flow_preds[i]).any()
            i_weight = adjusted_loss_gamma ** (n_predictions - i - 1)
            confidence = confidence_list[i] if confidence_list else None

            # confidence loss
            if confidence is not None and self._confidence_active(i):
                confidence_loss += i_weight * self._confidence_loss(flow_preds[i], confidence, flow_gt)

            # flow loss, optionally re-weighted by the (upsampled, detached) confidence
            i_loss = (flow_preds[i] - flow_gt).abs()
            if self.conf_disp and confidence is not None \
                    and global_batch_num is not None and global_batch_num > 3:
                weight = F.interpolate(confidence, scale_factor=4, mode='bilinear')
                i_loss = i_loss * (torch.sigmoid(weight.detach() / 3) * 1.5 + 1)
            assert i_loss.shape == valid.shape, [i_loss.shape, valid.shape, flow_gt.shape, flow_preds[i].shape]
            flow_loss += i_weight * i_loss[valid].mean()

            # disparity loss
            if disp_preds is not None and len(disp_preds) > 0 and disp_preds[i] is not None:
                i_loss = (-disp_preds[i] - flow_gt).abs()
                disp_loss += i_weight * i_loss[valid].mean()

            # plane loss
            if params_list is not None and len(params_list) > 0 and use_plane_gt:
                i_loss = (params_list[i] - plane_abc).abs()
                params_loss += i_weight * 0.5 * i_loss.mean()

            # refinement loss
            if disp_preds_refine is not None and len(disp_preds_refine) > 0 and disp_preds_refine[i] is not None:
                i_loss = (-disp_preds_refine[i] - flow_gt).abs()
                disp_refine_loss += i_weight * i_loss[valid].mean()

            # refined plane loss
            if params_list_refine is not None and len(params_list_refine) > 0 and use_plane_gt:
                i_loss = (params_list_refine[i] - plane_abc).abs()
                params_refine_loss += i_weight * 0.5 * i_loss.mean()

            # smoothness regularisation, only on the later half of the iterations
            if i > n_predictions // 2:
                with autocast(enabled=self.mixed_precision):
                    if self.smoothness == "gradient":
                        smooth_loss += i_weight * self.smooth_loss_computer(flow_preds[i], imgL).mean()
                    elif self.smoothness == "curvature":
                        smooth_loss += i_weight * self.smooth_loss_computer(params_list[i], imgL).mean()

        epe = _masked_epe(flow_preds[-1], flow_gt, valid)
        metrics = {
            'epe': epe.mean().item(),
            '1px': (epe < 1).float().mean().item(),
            '3px': (epe < 3).float().mean().item(),
            '5px': (epe < 5).float().mean().item(),
        }

        if disp_preds is not None and len(disp_preds) > 0 and disp_preds[-1] is not None:
            epe = _masked_epe(-disp_preds[-1], flow_gt, valid)
            metrics.update({'epe_disp': epe.mean().item(),
                            '3px_disp': (epe < 3).float().mean().item(),})

        if disp_preds_refine is not None and len(disp_preds_refine) > 0 and disp_preds_refine[-1] is not None:
            epe = _masked_epe(-disp_preds_refine[-1], flow_gt, valid)
            metrics.update({'epe_disp_refine': epe.mean().item(),
                            '3px_disp_refine': (epe < 3).float().mean().item(),})

        loss = flow_loss + disp_loss + params_loss + disp_refine_loss + params_refine_loss + confidence_loss
        if self.smoothness is not None and len(self.smoothness) > 0:
            loss = loss + self.loss_zeta * smooth_loss
        else:
            # Return a tensor so callers can log/.item() it uniformly.
            smooth_loss = torch.Tensor([0.0]).to(flow_loss.device)
        return loss, metrics, flow_loss, disp_loss, disp_refine_loss, confidence_loss, smooth_loss, params_loss, params_refine_loss


class SmoothLoss(nn.Module):
    r"""Smoothness constraint for the prediction.

    - gradient-based smooth regularization:
        \psi_{pq} = max(w_{pq},\epsilon) min(\hat{\psi}_{pq}(f_p,f_q), \tau_{dis}) \\
        w_{pq} = e^{-||I_L(p)-I_L(q)||_1 / \eta} \\
        \hat{\psi}_{pq} = |d_p(f_p) - d_q(f_q)| \\
        d_p(f_p) = a_p p_u + b_p p_v + c_p \\
        d_q(f_q) = a_q q_u + b_q q_v + c_q

    - curvature-based smooth regularization:
        \psi_{pq} = max(w_{pq},\epsilon) min(\hat{\psi}_{pq}(f_p,f_q), \tau_{dis}) \\
        w_{pq} = e^{-||I_L(p)-I_L(q)||_1 / \eta} \\
        \hat{\psi}_{pq} = |d_p(f_p) - d_p(f_q)| + |d_q(f_q) - d_q(f_p)| \\
        d_p(f_p) = a_p p_u + b_p p_v + c_p \\
        d_p(f_q) = a_p q_u + b_p q_v + c_p

    In the implementation the hard ``min(., \tau)`` is replaced by the soft
    truncation ``\tau \cdot sigmoid(8 \hat{\psi} / \tau - 4)``.
    """

    def __init__(self, smoothness, slant=None, slant_norm=False,
                 kernel_size=3, ner_weight_reduce=False,
                 epsilon=0.01, tau=3, eta=10):
        """
        Args:
            smoothness: ``"gradient"`` or ``"curvature"``.
            slant, slant_norm: forwarded to ``disparity_computation``.
            kernel_size: neighbourhood definition for ``NerghborExtractor``.
            ner_weight_reduce: compute image differences inside the convolution
                (see ``NerghborExtractor``'s ``reduce``).
            epsilon: lower bound on the image-based weight ``w_pq``.
            tau: truncation scale of ``\\hat{\\psi}``.
            eta: temperature of the image-based weight.
        """
        super(SmoothLoss, self).__init__()
        self.smoothness = smoothness
        self.slant = slant
        self.slant_norm = slant_norm
        self.eta = eta
        self.tau = tau
        self.epsilon = epsilon
        self.reduce = ner_weight_reduce
        self.img_ner_extractor = NerghborExtractor(3, kernel_size, reduce=self.reduce)
        self.coord_ner_extractor = NerghborExtractor(2, kernel_size)
        self.params_ner_extractor = NerghborExtractor(3, kernel_size)

    def forward(self, params, imgL):
        """Compute the smoothness loss.

        Args:
            params: (B,3,H,W) per-pixel plane parameters.
            imgL: (B,3,H,W) left image.

        Returns:
            Scalar tensor: mean of the weighted, soft-truncated penalty over all
            pixels and neighbours.

        Raises:
            ValueError: if ``self.smoothness`` is not "gradient" or "curvature".
        """
        img_ner = self.img_ner_extractor(imgL)  # B,3,N,H,W (B,1,N,H,W when reduce)
        B, _, H, W = imgL.shape
        coord = coords_grid(B, H, W).to(imgL.device)  # B,2,H,W
        coord_ner = self.coord_ner_extractor(coord)  # B,2,N,H,W
        coord = coord.unsqueeze(2)  # B,2,1,H,W
        params_ner = self.params_ner_extractor(params)  # B,3,N,H,W
        params = params.unsqueeze(2)  # B,3,1,H,W

        # w_{pq} = e^{-||I_L(p)-I_L(q)||_1 / \eta}
        if not self.reduce:
            weight = torch.exp(-torch.abs(img_ner - imgL.unsqueeze(2)).mean(dim=1) / self.eta)  # B,N,H,W
        else:
            # the neighbour-minus-centre difference was already taken inside the conv
            weight = torch.exp(-torch.abs(img_ner).mean(dim=1) / self.eta)  # B,N,H,W

        if self.smoothness == "gradient":
            # \hat{\psi}_{pq} = |d_p(f_p) - d_q(f_q)|
            psi_p = disparity_computation(params, coords0=coord, slant=self.slant, slant_norm=self.slant_norm) - \
                disparity_computation(params_ner, coords0=coord_ner, slant=self.slant, slant_norm=self.slant_norm)
            psi = torch.abs(psi_p)  # B,N,H,W
        elif self.smoothness == "curvature":
            # d_p(f_p) - d_p(f_q)
            psi_p = disparity_computation(params, coords0=coord, slant=self.slant, slant_norm=self.slant_norm) - \
                disparity_computation(params, coords0=coord_ner, slant=self.slant, slant_norm=self.slant_norm)
            # d_q(f_q) - d_q(f_p)
            psi_q = disparity_computation(params_ner, coords0=coord_ner, slant=self.slant, slant_norm=self.slant_norm) - \
                disparity_computation(params_ner, coords0=coord, slant=self.slant, slant_norm=self.slant_norm)
            # \hat{\psi} = |d_p(f_p) - d_p(f_q)| + |d_q(f_q) - d_q(f_p)|
            psi = torch.abs(psi_p) + torch.abs(psi_q)  # B,N,H,W
        else:
            raise ValueError(f"unsupported smoothness: {self.smoothness!r}")

        # \psi_{pq} = max(w_{pq},\epsilon) * soft_min(\hat{\psi}_{pq}, \tau_{dis})
        smooth_loss = torch.clip(weight, min=self.epsilon,) * \
            torch.sigmoid(psi / self.tau * 8 - 4) * self.tau
        return smooth_loss.mean()


def diamond(n):
    """Return an (n, n) boolean mask of a filled diamond (L1 ball) centred in the grid."""
    a = np.arange(n)
    b = np.minimum(a, a[::-1])
    return (b[:, None] + b) >= (n - 1) // 2


def diamond_edge(n):
    """Return an (n, n) 0/1 float mask of a diamond outline (for odd ``n``)."""
    arr = np.diagflat(np.ones(n // 2 + 1), n // 2)
    arr = np.maximum(arr, np.flip(arr, 1))
    return np.maximum(arr, np.flip(arr, 0))


# Named neighbourhood shapes usable as ``"<name>-<size>"`` kernel specs.
kernel_dict = {
    "diamond": diamond,
    "diamond_edge": diamond_edge,
}


class NerghborExtractor(nn.Module):
    """Extract the neighbors of each pixel using a fixed (non-trainable) convolution.

    Input: (B,C,H,W). Output: (B,C,N,H,W), where N is the number of neighbours.

    With ``reduce=True`` the convolution instead computes, for every neighbour,
    ``sum_c (x_q - x_p)``. The result has shape (B,1,N,H,W) and is divided by C.
    """

    def __init__(self, input_channel, kernel_size=3, reduce=False):
        """
        Args:
            input_channel: number of input channels C.
            kernel_size: either an int k (all k*k positions of a k x k window),
                or a ``"<type>-<size>"`` string naming a shape in ``kernel_dict``
                (e.g. ``"diamond-5"``).
            reduce: compute channel-averaged neighbour-minus-centre differences.

        Raises:
            TypeError: if ``kernel_size`` is neither an int nor a str.
        """
        super(NerghborExtractor, self).__init__()
        self.reduce = reduce
        self.input_channel = input_channel

        # build one one-hot kernel per neighbour position
        if isinstance(kernel_size, int):
            H = W = kernel_size
            self.neighbors_num = kernel_size * kernel_size
            neighbor_kernel = np.zeros((self.neighbors_num, H, W), dtype=np.float16)
            for idx in range(self.neighbors_num):
                row, col = divmod(idx, W)
                neighbor_kernel[idx, row, col] = 1
        elif isinstance(kernel_size, str):
            # obtain the compressed kernel
            kernel_type, size = kernel_size.split("-")
            kernel_size = int(size)
            compressed_kernel = kernel_dict[kernel_type](kernel_size)
            # decode the compressed kernel into a series of one-hot kernels
            H, W = compressed_kernel.shape
            self.neighbors_num = np.count_nonzero(compressed_kernel)
            rows, cols = np.nonzero(compressed_kernel)
            neighbor_kernel = np.zeros((self.neighbors_num, H, W), dtype=np.float16)
            for idx_k, (idx_h, idx_w) in enumerate(zip(rows, cols)):
                neighbor_kernel[idx_k, idx_h, idx_w] = compressed_kernel[idx_h, idx_w]
        else:
            raise TypeError("kernel_size must be an int or a '<type>-<size>' string")

        if not self.reduce:
            # depthwise: each input channel gets its own copy of the N kernels
            neighbor_kernel = np.tile(neighbor_kernel, (input_channel, 1, 1))
            neighbor_kernel = neighbor_kernel[:, np.newaxis]  # in*neighbors_num, 1, k, k
            output_channel = input_channel * self.neighbors_num
            groups = input_channel
        else:
            # Turn each kernel into "neighbour minus centre". Subtracting (rather
            # than assigning -1) keeps the centre neighbour's own kernel at 0, so
            # its difference is 0 as in the non-reduce path.
            neighbor_kernel[:, H // 2, W // 2] -= 1
            neighbor_kernel = np.tile(neighbor_kernel[:, np.newaxis], (1, input_channel, 1, 1))  # neighbors_num, in, k, k
            output_channel = self.neighbors_num
            groups = 1

        # extract neighbors through a fixed convolution
        self.conv = nn.Conv2d(input_channel, output_channel, kernel_size=kernel_size,
                              padding=kernel_size // 2, bias=False, groups=groups,
                              padding_mode="reflect")
        self.conv.weight = nn.Parameter(torch.Tensor(neighbor_kernel), requires_grad=False)

    def forward(self, x):
        """Return the neighbours of ``x`` as (B, C or 1, N, H, W); see the class docstring."""
        B, C, H, W = x.shape
        neighbors = self.conv(x).reshape((B, -1, self.neighbors_num, H, W))
        if self.reduce:
            neighbors = neighbors / self.input_channel
        return neighbors