| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| class AWingLoss(nn.Module): | |
| def __init__(self, omega=14, theta=0.5, epsilon=1, alpha=2.1, use_weight_map=True): | |
| super(AWingLoss, self).__init__() | |
| self.omega = omega | |
| self.theta = theta | |
| self.epsilon = epsilon | |
| self.alpha = alpha | |
| self.use_weight_map = use_weight_map | |
| def __repr__(self): | |
| return "AWingLoss()" | |
| def generate_weight_map(self, heatmap, k_size=3, w=10): | |
| dilate = F.max_pool2d(heatmap, kernel_size=k_size, stride=1, padding=1) | |
| weight_map = torch.where(dilate < 0.2, torch.zeros_like(heatmap), torch.ones_like(heatmap)) | |
| return w * weight_map + 1 | |
| def forward(self, output, groundtruth): | |
| """ | |
| input: b x n x h x w | |
| output: b x n x h x w => 1 | |
| """ | |
| delta = (output - groundtruth).abs() | |
| A = self.omega * (1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - groundtruth))) * (self.alpha - groundtruth) * \ | |
| (torch.pow(self.theta / self.epsilon, self.alpha - groundtruth - 1)) * (1 / self.epsilon) | |
| C = self.theta * A - self.omega * \ | |
| torch.log(1 + torch.pow(self.theta / self.epsilon, self.alpha - groundtruth)) | |
| loss = torch.where(delta < self.theta, | |
| self.omega * torch.log(1 + torch.pow(delta / self.epsilon, self.alpha - groundtruth)), | |
| (A * delta - C)) | |
| if self.use_weight_map: | |
| weight = self.generate_weight_map(groundtruth) | |
| loss = loss * weight | |
| return loss.mean() | |