Spaces:
Running
Running
| r"""Custom layers used in metrics computations""" | |
| import torch | |
| from typing import Optional | |
| from .filters import hann_filter | |
class L2Pool2d(torch.nn.Module):
    r"""Applies L2 pooling with a Hann window of size ``kernel_size`` x ``kernel_size``.

    Each channel is filtered independently (depthwise convolution,
    ``groups = C``) with a Hann window applied to the squared input, and the
    square root of the result is returned::

        sqrt(conv2d(x ** 2, hann, stride, padding) + EPS)

    Args:
        kernel_size: Side length of the square Hann window. Default: 3.
        stride: Stride of the pooling convolution. Default: 2.
        padding: Implicit zero padding added to both spatial sides. Default: 1.

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H_{out}, W_{out})`
    """

    # Keeps the argument of sqrt strictly positive, so the gradient of sqrt
    # stays finite where the filtered energy is exactly zero.
    EPS = 1e-12

    def __init__(self, kernel_size: int = 3, stride: int = 2, padding: int = 1) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # Depthwise kernel with one Hann window per channel. It is built lazily
        # because the channel count, device and dtype are only known from the
        # input. It is deliberately a plain attribute (not a buffer), so it is
        # not part of state_dict and is not moved by Module.to().
        self.kernel: Optional[torch.Tensor] = None

    def _kernel_matches(self, x: torch.Tensor) -> bool:
        """Return True if the cached kernel can be reused for input ``x``."""
        kernel = self.kernel
        return (
            kernel is not None
            and kernel.size(0) == x.size(1)
            and kernel.device == x.device
            and kernel.dtype == x.dtype
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` of shape (N, C, H, W) and return the L2-pooled tensor.

        The cached kernel is (re)built whenever the channel count, device or
        dtype of ``x`` differs from the cached kernel. A stale kernel would
        otherwise make ``conv2d`` fail with a groups/shape or device mismatch.
        """
        channels = x.size(1)
        if not self._kernel_matches(x):
            # Assumes hann_filter(k) returns a (1, 1, k, k) window, which is
            # tiled to (C, 1, k, k) for the depthwise conv. TODO confirm.
            self.kernel = hann_filter(self.kernel_size).repeat((channels, 1, 1, 1)).to(x)
        out = torch.nn.functional.conv2d(
            x ** 2,
            self.kernel,
            stride=self.stride,
            padding=self.padding,
            groups=channels,
        )
        return (out + self.EPS).sqrt()