| """ | |
| Implementation of Content loss, Style loss, LPIPS and DISTS metrics | |
| References: | |
| .. [1] Gatys, Leon and Ecker, Alexander and Bethge, Matthias | |
| (2016). A Neural Algorithm of Artistic Style} | |
| Association for Research in Vision and Ophthalmology (ARVO) | |
| https://arxiv.org/abs/1508.06576 | |
| .. [2] Zhang, Richard and Isola, Phillip and Efros, et al. | |
| (2018) The Unreasonable Effectiveness of Deep Features as a Perceptual Metric | |
| 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition | |
| https://arxiv.org/abs/1801.03924 | |
| """ | |

from typing import List, Union, Collection

import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
from torchvision.models import vgg16, vgg19, VGG16_Weights, VGG19_Weights

from .utils import _validate_input, _reduce
from .functional import similarity_map, L2Pool2d

# Map VGG layer names to the corresponding module indices in torchvision's ``features`` Sequential
VGG16_LAYERS = {
    "conv1_1": '0', "relu1_1": '1',
    "conv1_2": '2', "relu1_2": '3',
    "pool1": '4',
    "conv2_1": '5', "relu2_1": '6',
    "conv2_2": '7', "relu2_2": '8',
    "pool2": '9',
    "conv3_1": '10', "relu3_1": '11',
    "conv3_2": '12', "relu3_2": '13',
    "conv3_3": '14', "relu3_3": '15',
    "pool3": '16',
    "conv4_1": '17', "relu4_1": '18',
    "conv4_2": '19', "relu4_2": '20',
    "conv4_3": '21', "relu4_3": '22',
    "pool4": '23',
    "conv5_1": '24', "relu5_1": '25',
    "conv5_2": '26', "relu5_2": '27',
    "conv5_3": '28', "relu5_3": '29',
    "pool5": '30',
}

VGG19_LAYERS = {
    "conv1_1": '0', "relu1_1": '1',
    "conv1_2": '2', "relu1_2": '3',
    "pool1": '4',
    "conv2_1": '5', "relu2_1": '6',
    "conv2_2": '7', "relu2_2": '8',
    "pool2": '9',
    "conv3_1": '10', "relu3_1": '11',
    "conv3_2": '12', "relu3_2": '13',
    "conv3_3": '14', "relu3_3": '15',
    "conv3_4": '16', "relu3_4": '17',
    "pool3": '18',
    "conv4_1": '19', "relu4_1": '20',
    "conv4_2": '21', "relu4_2": '22',
    "conv4_3": '23', "relu4_3": '24',
    "conv4_4": '25', "relu4_4": '26',
    "pool4": '27',
    "conv5_1": '28', "relu5_1": '29',
    "conv5_2": '30', "relu5_2": '31',
    "conv5_3": '32', "relu5_3": '33',
    "conv5_4": '34', "relu5_4": '35',
    "pool5": '36',
}

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Constant used in feature normalization to avoid zero division
EPS = 1e-10
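

# Illustrative sketch (hypothetical helper, not part of the module's API): layer
# indices above are stored as strings because torchvision's ``vgg16().features``
# is an ``nn.Sequential`` whose submodules are registered under string keys;
# ``ContentLoss.get_features`` below matches those keys during the forward pass.
def _demo_layer_lookup() -> None:
    """Show how a human-readable layer name resolves to a Sequential key."""
    extractor = vgg16(weights=VGG16_Weights.DEFAULT, progress=False).features
    key = VGG16_LAYERS["relu3_3"]  # -> '15'
    # The same string key appears when iterating over registered submodules.
    assert key in extractor._modules
    print(f"relu3_3 lives at features[{key}]:", extractor._modules[key])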


class ContentLoss(_Loss):
    r"""Creates Content loss that can be used for image style transfer or as a measure for image-to-image tasks.
    Uses pretrained VGG models from torchvision.
    By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1].
    If no normalisation is required, change `mean` and `std` values accordingly.

    Args:
        feature_extractor: Model to extract features or model name: ``'vgg16'`` | ``'vgg19'``.
        layers: List of strings with layer names. Default: ``'relu3_3'``
        weights: List of float weights to balance different layers.
        replace_pooling: Flag to replace MaxPooling layers with AveragePooling. See references for details.
        distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        mean: List of float values used for data standardization. Default: ImageNet mean.
            If there is no need to normalize data, use [0., 0., 0.].
        std: List of float values used for data standardization. Default: ImageNet std.
            If there is no need to normalize data, use [1., 1., 1.].
        normalize_features: If true, unit-normalize each feature in channel dimension before scaling
            and computing distance. See references for details.

    Examples:
        >>> loss = ContentLoss()
        >>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
        >>> y = torch.rand(3, 3, 256, 256)
        >>> output = loss(x, y)
        >>> output.backward()

    References:
        Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
        A Neural Algorithm of Artistic Style
        Association for Research in Vision and Ophthalmology (ARVO)
        https://arxiv.org/abs/1508.06576

        Zhang, Richard and Isola, Phillip and Efros, et al. (2018)
        The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
        IEEE/CVF Conference on Computer Vision and Pattern Recognition
        https://arxiv.org/abs/1801.03924
    """

    def __init__(self, feature_extractor: Union[str, torch.nn.Module] = "vgg16", layers: Collection[str] = ("relu3_3",),
                 weights: List[Union[float, torch.Tensor]] = [1.], replace_pooling: bool = False,
                 distance: str = "mse", reduction: str = "mean", mean: List[float] = IMAGENET_MEAN,
                 std: List[float] = IMAGENET_STD, normalize_features: bool = False,
                 allow_layers_weights_mismatch: bool = False) -> None:
        assert allow_layers_weights_mismatch or len(layers) == len(weights), \
            f'Lengths of provided layers and weights mismatch ({len(weights)} weights and {len(layers)} layers), ' \
            f'which will cause incorrect results. Please provide a weight for each layer.'

        super().__init__()

        if callable(feature_extractor):
            self.model = feature_extractor
            self.layers = layers
        else:
            if feature_extractor == "vgg16":
                self.model = vgg16(weights=VGG16_Weights.DEFAULT, progress=False).features
                self.layers = [VGG16_LAYERS[l] for l in layers]
            elif feature_extractor == "vgg19":
                self.model = vgg19(weights=VGG19_Weights.DEFAULT, progress=False).features
                self.layers = [VGG19_LAYERS[l] for l in layers]
            else:
                raise ValueError("Unknown feature extractor")

        if replace_pooling:
            self.model = self.replace_pooling(self.model)

        # Disable gradients for the frozen feature extractor
        for param in self.model.parameters():
            param.requires_grad_(False)

        self.distance = {
            "mse": nn.MSELoss,
            "mae": nn.L1Loss,
        }[distance](reduction='none')

        self.weights = [torch.tensor(w) if not isinstance(w, torch.Tensor) else w for w in weights]

        mean = torch.tensor(mean)
        std = torch.tensor(std)
        self.mean = mean.view(1, -1, 1, 1)
        self.std = std.view(1, -1, 1, 1)
        self.normalize_features = normalize_features
        self.reduction = reduction

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        r"""Computation of Content loss between feature representations of prediction :math:`x` and
        target :math:`y` tensors.

        Args:
            x: An input tensor. Shape :math:`(N, C, H, W)`.
            y: A target tensor. Shape :math:`(N, C, H, W)`.

        Returns:
            Content loss between feature representations
        """
        _validate_input([x, y], dim_range=(4, 4), data_range=(0, -1))

        self.model.to(x)
        x_features = self.get_features(x)
        y_features = self.get_features(y)

        distances = self.compute_distance(x_features, y_features)

        # Scale distances, then average in spatial dimensions, then stack and sum in channels dimension
        loss = torch.cat([(d * w.to(d)).mean(dim=[2, 3]) for d, w in zip(distances, self.weights)], dim=1).sum(dim=1)

        return _reduce(loss, self.reduction)

    def compute_distance(self, x_features: List[torch.Tensor], y_features: List[torch.Tensor]) -> List[torch.Tensor]:
        r"""Take L2 or L1 distance between feature maps depending on ``distance``.

        Args:
            x_features: Features of the input tensor.
            y_features: Features of the target tensor.

        Returns:
            Distance between feature maps
        """
        return [self.distance(x, y) for x, y in zip(x_features, y_features)]

    def get_features(self, x: torch.Tensor) -> List[torch.Tensor]:
        r"""
        Args:
            x: Tensor. Shape :math:`(N, C, H, W)`.

        Returns:
            List of features extracted from intermediate layers
        """
        # Normalize input
        x = (x - self.mean.to(x)) / self.std.to(x)

        features = []
        for name, module in self.model._modules.items():
            x = module(x)
            if name in self.layers:
                features.append(self.normalize(x) if self.normalize_features else x)

        return features

    @staticmethod
    def normalize(x: torch.Tensor) -> torch.Tensor:
        r"""Normalize feature maps in channel direction to unit length.

        Args:
            x: Tensor. Shape :math:`(N, C, H, W)`.

        Returns:
            Normalized input
        """
        norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
        return x / (norm_factor + EPS)

    def replace_pooling(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""Turn all MaxPool layers into AveragePool.

        Args:
            module: Module in which to change MaxPool into AveragePool

        Returns:
            Module with AveragePool instead of MaxPool
        """
        module_output = module
        if isinstance(module, torch.nn.MaxPool2d):
            module_output = torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0)

        for name, child in module.named_children():
            module_output.add_module(name, self.replace_pooling(child))

        return module_output
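

# A minimal usage sketch (hypothetical, not part of the library): combining
# several layers with per-layer weights, in the spirit of the Gatys et al.
# style-transfer recipe. Layer names and weights are illustrative choices only.
def _demo_weighted_content_loss() -> torch.Tensor:
    """Compute a multi-layer, weighted content loss on random images."""
    criterion = ContentLoss(
        feature_extractor="vgg19",
        layers=("relu1_1", "relu2_1", "relu3_1"),
        weights=[0.5, 0.3, 0.2],
        replace_pooling=True,  # AvgPool instead of MaxPool, as in Gatys et al.
    )
    x = torch.rand(2, 3, 128, 128)
    y = torch.rand(2, 3, 128, 128)
    return criterion(x, y)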


class StyleLoss(ContentLoss):
    r"""Creates Style loss that can be used for image style transfer or as a measure in
    image-to-image tasks. Computes distance between Gram matrices of feature maps.
    Uses pretrained VGG models from torchvision.
    By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1].
    If no normalisation is required, change `mean` and `std` values accordingly.

    Args:
        feature_extractor: Model to extract features or model name: ``'vgg16'`` | ``'vgg19'``.
        layers: List of strings with layer names. Default: ``'relu3_3'``
        weights: List of float weights to balance different layers.
        replace_pooling: Flag to replace MaxPooling layers with AveragePooling. See references for details.
        distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        mean: List of float values used for data standardization. Default: ImageNet mean.
            If there is no need to normalize data, use [0., 0., 0.].
        std: List of float values used for data standardization. Default: ImageNet std.
            If there is no need to normalize data, use [1., 1., 1.].
        normalize_features: If true, unit-normalize each feature in channel dimension before scaling
            and computing distance. See references for details.

    Examples:
        >>> loss = StyleLoss()
        >>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
        >>> y = torch.rand(3, 3, 256, 256)
        >>> output = loss(x, y)
        >>> output.backward()

    References:
        Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
        A Neural Algorithm of Artistic Style
        Association for Research in Vision and Ophthalmology (ARVO)
        https://arxiv.org/abs/1508.06576

        Zhang, Richard and Isola, Phillip and Efros, et al. (2018)
        The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
        IEEE/CVF Conference on Computer Vision and Pattern Recognition
        https://arxiv.org/abs/1801.03924
    """

    def compute_distance(self, x_features: List[torch.Tensor], y_features: List[torch.Tensor]) -> List[torch.Tensor]:
        r"""Take L2 or L1 distance between Gram matrices of feature maps depending on ``distance``.

        Args:
            x_features: Features of the input tensor.
            y_features: Features of the target tensor.

        Returns:
            Distance between Gram matrices
        """
        x_gram = [self.gram_matrix(x) for x in x_features]
        y_gram = [self.gram_matrix(x) for x in y_features]
        return [self.distance(x, y) for x, y in zip(x_gram, y_gram)]

    @staticmethod
    def gram_matrix(x: torch.Tensor) -> torch.Tensor:
        r"""Compute Gram matrix for batch of features.

        Args:
            x: Tensor. Shape :math:`(N, C, H, W)`.

        Returns:
            Gram matrix for given input
        """
        B, C, H, W = x.size()
        gram = []
        for i in range(B):
            features = x[i].view(C, H * W)
            # Add fake channel dimension
            gram.append(torch.mm(features, features.t()).unsqueeze(0))

        return torch.stack(gram)
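

# Shape check and a vectorized equivalent of ``gram_matrix`` (illustrative
# sketch, not used by the class): the per-sample loop equals one batched matrix
# multiply, and the fake channel dimension keeps the result shaped like a
# feature map so that ``forward``'s ``.mean(dim=[2, 3])`` applies uniformly.
def _demo_gram_matrix_equivalence() -> None:
    x = torch.rand(4, 8, 16, 16)          # (N, C, H, W)
    gram_loop = StyleLoss.gram_matrix(x)  # (N, 1, C, C)
    flat = x.view(4, 8, -1)               # (N, C, H * W)
    gram_bmm = torch.bmm(flat, flat.transpose(1, 2)).unsqueeze(1)
    assert gram_loop.shape == (4, 1, 8, 8)
    assert torch.allclose(gram_loop, gram_bmm, atol=1e-4)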


class LPIPS(ContentLoss):
    r"""Learned Perceptual Image Patch Similarity metric. Only VGG16 learned weights are supported.
    By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1].
    If no normalisation is required, change `mean` and `std` values accordingly.

    Args:
        replace_pooling: Flag to replace MaxPooling layers with AveragePooling. See references for details.
        distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        mean: List of float values used for data standardization. Default: ImageNet mean.
            If there is no need to normalize data, use [0., 0., 0.].
        std: List of float values used for data standardization. Default: ImageNet std.
            If there is no need to normalize data, use [1., 1., 1.].

    Examples:
        >>> loss = LPIPS()
        >>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
        >>> y = torch.rand(3, 3, 256, 256)
        >>> output = loss(x, y)
        >>> output.backward()

    References:
        Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
        A Neural Algorithm of Artistic Style
        Association for Research in Vision and Ophthalmology (ARVO)
        https://arxiv.org/abs/1508.06576

        Zhang, Richard and Isola, Phillip and Efros, et al. (2018)
        The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
        IEEE/CVF Conference on Computer Vision and Pattern Recognition
        https://arxiv.org/abs/1801.03924
        https://github.com/richzhang/PerceptualSimilarity
    """
    _weights_url = "https://github.com/photosynthesis-team/" + \
                   "photosynthesis.metrics/releases/download/v0.4.0/lpips_weights.pt"

    def __init__(self, replace_pooling: bool = False, distance: str = "mse", reduction: str = "mean",
                 mean: List[float] = IMAGENET_MEAN, std: List[float] = IMAGENET_STD) -> None:
        lpips_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
        lpips_weights = torch.hub.load_state_dict_from_url(self._weights_url, progress=False)
        super().__init__("vgg16", layers=lpips_layers, weights=lpips_weights,
                         replace_pooling=replace_pooling, distance=distance,
                         reduction=reduction, mean=mean, std=std,
                         normalize_features=True)
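

# Hedged usage sketch (illustrative only): LPIPS is a distance, so a heavily
# distorted copy of an image should score higher (less similar) than a mildly
# distorted one. The random tensors and noise levels are arbitrary stand-ins.
def _demo_lpips_ordering() -> None:
    metric = LPIPS()
    y = torch.rand(1, 3, 128, 128)
    x_close = (y + 0.01 * torch.randn_like(y)).clamp(0, 1)
    x_far = (y + 0.30 * torch.randn_like(y)).clamp(0, 1)
    print("small noise:", metric(x_close, y).item())  # expected: smaller value
    print("large noise:", metric(x_far, y).item())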


class DISTS(ContentLoss):
    r"""Deep Image Structure and Texture Similarity metric.
    By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1].
    If no normalisation is required, change `mean` and `std` values accordingly.

    Args:
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        mean: List of float values used for data standardization. Default: ImageNet mean.
            If there is no need to normalize data, use [0., 0., 0.].
        std: List of float values used for data standardization. Default: ImageNet std.
            If there is no need to normalize data, use [1., 1., 1.].

    Examples:
        >>> loss = DISTS()
        >>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
        >>> y = torch.rand(3, 3, 256, 256)
        >>> output = loss(x, y)
        >>> output.backward()

    References:
        Keyan Ding, Kede Ma, Shiqi Wang, Eero P. Simoncelli (2020).
        Image Quality Assessment: Unifying Structure and Texture Similarity.
        https://arxiv.org/abs/2004.07728
        https://github.com/dingkeyan93/DISTS
    """
    _weights_url = "https://github.com/photosynthesis-team/piq/releases/download/v0.4.1/dists_weights.pt"

    def __init__(self, reduction: str = "mean", mean: List[float] = IMAGENET_MEAN,
                 std: List[float] = IMAGENET_STD) -> None:
        dists_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
        channels = [3, 64, 128, 256, 512, 512]

        weights = torch.hub.load_state_dict_from_url(self._weights_url, progress=False)
        dists_weights = list(torch.split(weights['alpha'], channels, dim=1))
        dists_weights.extend(torch.split(weights['beta'], channels, dim=1))

        super().__init__("vgg16", layers=dists_layers, weights=dists_weights,
                         replace_pooling=True, reduction=reduction, mean=mean, std=std,
                         normalize_features=False, allow_layers_weights_mismatch=True)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            x: An input tensor. Shape :math:`(N, C, H, W)`.
            y: A target tensor. Shape :math:`(N, C, H, W)`.

        Returns:
            Deep Image Structure and Texture Similarity loss, i.e. ``1 - DISTS`` in range [0, 1].
        """
        _, _, H, W = x.shape
        if min(H, W) > 256:
            x = torch.nn.functional.interpolate(
                x, scale_factor=256 / min(H, W), recompute_scale_factor=False, mode='bilinear')
            y = torch.nn.functional.interpolate(
                y, scale_factor=256 / min(H, W), recompute_scale_factor=False, mode='bilinear')

        loss = super().forward(x, y)
        return 1 - loss

    def compute_distance(self, x_features: List[torch.Tensor], y_features: List[torch.Tensor]) -> List[torch.Tensor]:
        r"""Compute structure and texture similarity between feature maps.

        Args:
            x_features: Features of the input tensor.
            y_features: Features of the target tensor.

        Returns:
            Structure and texture similarity between feature maps
        """
        structure_distance, texture_distance = [], []
        # Small constant for numerical stability
        EPS = 1e-6

        for x, y in zip(x_features, y_features):
            x_mean = x.mean([2, 3], keepdim=True)
            y_mean = y.mean([2, 3], keepdim=True)
            structure_distance.append(similarity_map(x_mean, y_mean, constant=EPS))

            x_var = ((x - x_mean) ** 2).mean([2, 3], keepdim=True)
            y_var = ((y - y_mean) ** 2).mean([2, 3], keepdim=True)
            xy_cov = (x * y).mean([2, 3], keepdim=True) - x_mean * y_mean
            texture_distance.append((2 * xy_cov + EPS) / (x_var + y_var + EPS))

        return structure_distance + texture_distance

    def get_features(self, x: torch.Tensor) -> List[torch.Tensor]:
        r"""
        Args:
            x: Input tensor

        Returns:
            List of features extracted from input tensor
        """
        features = super().get_features(x)

        # Add input tensor as an additional feature
        features.insert(0, x)
        return features

    def replace_pooling(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""Turn all MaxPool layers into L2Pool.

        Args:
            module: Module in which to change MaxPool into L2Pool

        Returns:
            Module with L2Pool instead of MaxPool
        """
        module_output = module
        if isinstance(module, torch.nn.MaxPool2d):
            module_output = L2Pool2d(kernel_size=3, stride=2, padding=1)

        for name, child in module.named_children():
            module_output.add_module(name, self.replace_pooling(child))

        return module_output
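

# A minimal end-to-end sketch under a ``__main__`` guard (illustrative only;
# shapes and the distortion are arbitrary). DISTS is reported as
# ``1 - similarity``, so like the other three measures, lower means more similar.
if __name__ == "__main__":
    x = torch.rand(2, 3, 96, 96)
    y = (x + 0.1 * torch.randn_like(x)).clamp(0, 1)
    for metric in (ContentLoss(), StyleLoss(), LPIPS(), DISTS()):
        print(f"{metric.__class__.__name__}: {metric(x, y).mean().item():.4f}")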