|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
from .common import filter2D |
|
|
|
|
|
|
|
|
class USMSharp(torch.nn.Module):
    """Unsharp-mask (USM) sharpening implemented as a torch module.

    A square 2-D Gaussian kernel is built once at construction time and
    stored as the non-trainable buffer ``kernel``. ``forward`` blurs the
    input with it, adds back a weighted copy of the (input - blur)
    residual, and blends the sharpened result with the original through a
    blurred threshold mask.

    Args:
        radius: Side length of the Gaussian kernel. An even value is
            bumped to the next odd number.
        sigma: Standard deviation handed to ``cv2.getGaussianKernel``.
            Per the OpenCV documentation a non-positive value makes
            OpenCV derive sigma from the kernel size.
    """

    def __init__(self, radius=50, sigma=0):
        super().__init__()
        # Force an odd kernel size so the kernel has a centre tap.
        self.radius = radius + 1 if radius % 2 == 0 else radius

        # cv2 returns a (radius, 1) column vector; multiplying it by its
        # own transpose yields the separable (radius, radius) 2-D kernel.
        column = cv2.getGaussianKernel(self.radius, sigma)
        kernel_2d = torch.FloatTensor(np.dot(column, column.T))

        # Leading singleton dim -> shape (1, radius, radius); registering
        # it as a buffer lets it follow the module across devices.
        self.register_buffer('kernel', kernel_2d.unsqueeze_(0))

    def forward(self, img, weight=0.5, threshold=10):
        """Sharpen ``img`` and return the mask-blended result.

        Args:
            img: Input tensor accepted by ``filter2D``.
                NOTE(review): the ``* 255`` scaling and the clip to
                [0, 1] suggest values are expected in [0, 1] - confirm
                against callers.
            weight: Multiplier applied to the residual before it is added
                back to the input.
            threshold: Cut-off, compared against ``|residual| * 255``,
                above which a pixel is selected for sharpening.

        Returns:
            ``soft_mask * sharp + (1 - soft_mask) * img``, where
            ``soft_mask`` is the blurred binary threshold mask.
        """
        blurred = filter2D(img, self.kernel)
        detail = img - blurred

        # Binary mask of pixels whose residual exceeds the threshold,
        # then softened with the same kernel to avoid hard blend seams.
        strong_detail = (torch.abs(detail) * 255 > threshold).float()
        blend = filter2D(strong_detail, self.kernel)

        # Sharpened candidate, limited to the [0, 1] range.
        sharpened = torch.clip(img + weight * detail, 0, 1)

        return blend * sharpened + (1 - blend) * img
|
|
|