import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.box_utils import match, log_sum_exp
from data import cfg_mnet

GPU = cfg_mnet['gpu_train']

class MultiBoxLoss(nn.Module):
    """SSD Weighted Loss Function
    Compute Targets:
        1) Produce Confidence Target Indices by matching ground truth boxes
           with (default) 'priorboxes' that have jaccard index > threshold parameter
           (default threshold: 0.5).
        2) Produce localization targets by 'encoding' variance into offsets of
           ground truth boxes and their matched 'priorboxes'.
        3) Hard negative mining to filter the excessive number of negative
           examples that come with using a large number of default bounding
           boxes. (default negative:positive ratio 3:1)
    Objective Loss:
        L(x,c,l,g) = (Lconf(x, c) + αLloc(x, l, g)) / N
        where Lconf is the CrossEntropy loss and Lloc is the SmoothL1 loss,
        weighted by α, which is set to 1 by cross-validation.
        Args:
            c: class confidences,
            l: predicted boxes,
            g: ground truth boxes,
            N: number of matched default boxes.
    This variant also regresses 10 facial landmark coordinates per matched
    prior (Smooth L1, normalised separately by the landmark-positive count).
    See https://arxiv.org/pdf/1512.02325.pdf for more details.
    """

    def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label,
                 neg_mining, neg_pos, neg_overlap, encode_target):
        super(MultiBoxLoss, self).__init__()
        self.num_classes = num_classes
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap
        # standard SSD encoding variances for center and size offsets
        self.variance = [0.1, 0.2]

    def forward(self, predictions, priors, targets):
        """Multibox Loss
        Args:
            predictions (tuple): A tuple of loc preds, conf preds, and
                landmark preds from the net.
                loc shape:   torch.Size(batch_size, num_priors, 4)
                conf shape:  torch.Size(batch_size, num_priors, num_classes)
                landm shape: torch.Size(batch_size, num_priors, 10)
            priors (tensor): Prior boxes, shape: torch.Size(num_priors, 4).
            targets (list): Ground truth boxes, landmarks and labels for a
                batch; each element has shape [num_objs, 15]
                (last idx is the label).
        """
        loc_data, conf_data, landm_data = predictions
        num = loc_data.size(0)
        num_priors = priors.size(0)
        # match priors (default boxes) and ground truth boxes
        loc_t = torch.Tensor(num, num_priors, 4)
        landm_t = torch.Tensor(num, num_priors, 10)
        conf_t = torch.LongTensor(num, num_priors)
        for idx in range(num):
            truths = targets[idx][:, :4].data
            labels = targets[idx][:, -1].data
            landms = targets[idx][:, 4:14].data
            defaults = priors.data
            match(self.threshold, truths, defaults, self.variance, labels, landms,
                  loc_t, conf_t, landm_t, idx)
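        # match() has now filled, for every image, encoded box offsets into
        # loc_t, per-prior class labels into conf_t, and encoded landmark
        # offsets into landm_t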
        # move targets to the GPU when training there; keep the comparison
        # scalar on the same device as conf_t
        if GPU:
            loc_t = loc_t.cuda()
            conf_t = conf_t.cuda()
            landm_t = landm_t.cuda()
            zeros = torch.tensor(0).cuda()
        else:
            zeros = torch.tensor(0)
        # landm Loss (Smooth L1)
        # Shape: [batch, num_priors, 10]
        # only priors matched with label > 0 carry valid landmark targets
        pos1 = conf_t > zeros
        num_pos_landm = pos1.long().sum(1, keepdim=True)
        N1 = max(num_pos_landm.data.sum().float(), 1)
        pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
        landm_p = landm_data[pos_idx1].view(-1, 10)
        landm_t = landm_t[pos_idx1].view(-1, 10)
        loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')

        pos = conf_t != zeros
        conf_t[pos] = 1
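        # from here on classification is binary: every matched prior is a
        # face (label 1), everything else is background (label 0)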
        # Localization Loss (Smooth L1)
        # Shape: [batch, num_priors, 4]
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')

        # Compute max conf across batch for hard negative mining
        batch_conf = conf_data.view(-1, self.num_classes)
        loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
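        # loss_c is the per-prior softmax cross-entropy,
        # log(sum(exp(x))) - x[target], computed by hand so each prior can be
        # ranked before any reduction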
        # Hard Negative Mining
        loss_c = loss_c.view(num, -1)
        loss_c[pos] = 0  # filter out pos boxes for now
        _, loss_idx = loss_c.sort(1, descending=True)
        _, idx_rank = loss_idx.sort(1)
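        # double argsort: idx_rank[i] is the rank of prior i when sorted by
        # descending loss, so "idx_rank < num_neg" keeps the hardest negatives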
        num_pos = pos.long().sum(1, keepdim=True)
        num_neg = torch.clamp(self.negpos_ratio * num_pos, max=pos.size(1) - 1)
        neg = idx_rank < num_neg.expand_as(idx_rank)
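        # neg now flags, per image, the negpos_ratio * num_pos highest-loss
        # background priors, capped at num_priors - 1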
        # Confidence Loss Including Positive and Negative Examples
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[pos_idx | neg_idx].view(-1, self.num_classes)
        targets_weighted = conf_t[pos | neg]
        loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')
        # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x, l, g)) / N
        N = max(num_pos.data.sum().float(), 1)
        loss_l /= N
        loss_c /= N
        loss_landm /= N1
        return loss_l, loss_c, loss_landm
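
# ---------------------------------------------------------------------------
# Minimal smoke-test sketch, not part of the training pipeline. Assumptions
# (mine, not this file's): the constructor values mirror common RetinaFace
# training defaults, the shapes are illustrative, and a CUDA device is
# available whenever cfg_mnet['gpu_train'] is True.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    criterion = MultiBoxLoss(num_classes=2, overlap_thresh=0.35,
                             prior_for_matching=True, bkg_label=0,
                             neg_mining=True, neg_pos=7, neg_overlap=0.35,
                             encode_target=False)
    device = torch.device("cuda" if GPU else "cpu")
    batch, num_priors = 2, 16800
    predictions = (
        torch.randn(batch, num_priors, 4, device=device),   # loc: box offsets
        torch.randn(batch, num_priors, 2, device=device),   # conf: bg/face scores
        torch.randn(batch, num_priors, 10, device=device),  # landm: 5 (x, y) points
    )
    # priors in (cx, cy, w, h) form, normalised to [0, 1]; the first prior is
    # made to coincide with the ground-truth face so at least one match exists
    priors = torch.rand(num_priors, 4, device=device) * 0.5 + 0.25
    priors[0] = torch.tensor([0.3, 0.3, 0.4, 0.4], device=device)
    # per-image targets: [x1, y1, x2, y2, 10 landmark coords, label]
    gt = torch.tensor([[0.1, 0.1, 0.5, 0.5] + [0.3] * 10 + [1.0]], device=device)
    targets = [gt.clone() for _ in range(batch)]
    loss_l, loss_c, loss_landm = criterion(predictions, priors, targets)
    print(loss_l.item(), loss_c.item(), loss_landm.item())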