| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.utils.checkpoint import checkpoint |
| | import warnings |
| | import numpy as np |
| |
|
| |
|
def xavier_init(module: nn.Module,
                gain: float = 1,
                bias: float = 0,
                distribution: str = 'normal') -> None:
    """Xavier/Glorot-initialise ``module.weight`` and fill ``module.bias``.

    Args:
        module: Module whose ``weight`` / ``bias`` are written in place.
            Attributes that are missing or ``None`` are skipped.
        gain: Scaling factor forwarded to ``nn.init.xavier_*_``.
        bias: Constant written into ``module.bias``.
        distribution: ``'uniform'`` or ``'normal'``; any other value fails
            the assertion.
    """
    assert distribution in ['uniform', 'normal']
    weight = getattr(module, 'weight', None)
    if weight is not None:
        if distribution == 'uniform':
            init_fn = nn.init.xavier_uniform_
        else:
            init_fn = nn.init.xavier_normal_
        init_fn(weight, gain=gain)
    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)
| |
|
def carafe(x, normed_mask, kernel_size, group=1, up=1):
    """Content-aware reassembly of features (CARAFE) with plain tensor ops.

    Each output pixel is a weighted sum over the ``kernel_size x kernel_size``
    neighbourhood of its source pixel in ``x``. Borders are reflect-padded.
    The weights are read from ``normed_mask``. For ``up > 1`` the unfolded
    neighbourhoods are repeated with nearest-neighbour interpolation, so every
    source pixel feeds ``up x up`` output pixels, each with its own kernel.

    Args:
        x: Features of shape ``(B, C, H, W)``.
        normed_mask: Kernels of shape ``(B, G * kernel_size**2, up*H, up*W)``.
            ``G`` is either ``group`` (one kernel per channel group) or 1 (one
            kernel shared by all channels). The mask is used as-is; any
            normalisation is the caller's responsibility.
        kernel_size: Side of the reassembly window. It should be odd so that
            the padded unfold keeps ``H x W``.
        group: Number of channel groups. Channel ``c`` uses the kernel of
            group ``c // (C // G)``.
        up: Integer upsampling factor.

    Returns:
        Tensor of shape ``(B, C, up*H, up*W)``.

    Raises:
        AssertionError: If the mask resolution is not ``up`` times that of ``x``.
        ValueError: If the mask channel count is inconsistent with
            ``kernel_size`` / ``group``, or ``C`` is not divisible by the
            number of mask groups.
    """
    b, c, h, w = x.shape
    _, m_c, m_h, m_w = normed_mask.shape
    assert m_h == up * h
    assert m_w == up * w
    kk = kernel_size * kernel_size
    if m_c % kk != 0:
        raise ValueError(
            f'mask has {m_c} channels, not a multiple of kernel_size**2={kk}')
    mask_groups = m_c // kk
    if mask_groups not in (1, group):
        raise ValueError(
            f'mask encodes {mask_groups} kernel groups but group={group}')
    if c % mask_groups != 0:
        raise ValueError(
            f'{c} channels cannot be split into {mask_groups} groups')

    pad = kernel_size // 2
    pad_x = F.pad(x, pad=[pad] * 4, mode='reflect')
    unfold_x = F.unfold(pad_x, kernel_size=(kernel_size, kernel_size), stride=1, padding=0)
    unfold_x = unfold_x.reshape(b, c * kk, h, w)
    unfold_x = F.interpolate(unfold_x, scale_factor=up, mode='nearest')
    # F.unfold orders channels channel-major / window-minor, so splitting C into
    # (groups, C // groups) keeps each channel's k*k window contiguous.
    unfold_x = unfold_x.reshape(b, mask_groups, c // mask_groups, kk, m_h, m_w)
    mask = normed_mask.reshape(b, mask_groups, 1, kk, m_h, m_w)
    res = (unfold_x * mask).sum(dim=3)
    return res.reshape(b, c, m_h, m_w)
| |
|
def normal_init(module, mean=0, std=1, bias=0):
    """Draw ``module.weight`` from N(mean, std**2) and fill ``module.bias``.

    Attributes that are absent or ``None`` are left untouched.
    """
    if getattr(module, 'weight', None) is not None:
        nn.init.normal_(module.weight, mean, std)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
| |
|
| |
|
def constant_init(module, val, bias=0):
    """Fill ``module.weight`` with ``val`` and ``module.bias`` with ``bias``.

    Attributes that are absent or ``None`` are left untouched.
    """
    for attr_name, fill_value in (('weight', val), ('bias', bias)):
        param = getattr(module, attr_name, None)
        if param is not None:
            nn.init.constant_(param, fill_value)
| |
|
def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    """Call ``F.interpolate``, optionally warning about corner misalignment.

    A ``UserWarning`` is emitted only when all of the following hold:
    ``warning`` is true, ``size`` is given, ``align_corners`` is truthy, at
    least one output side is larger than the input side, every input and
    output side is > 1, and neither ``out_h - 1`` nor ``out_w - 1`` is a
    multiple of ``in_h - 1`` / ``in_w - 1`` respectively.

    All interpolation arguments are forwarded positionally to
    ``F.interpolate``.
    """
    if warning and size is not None and align_corners:
        in_h, in_w = (int(dim) for dim in input.shape[2:])
        out_h, out_w = (int(dim) for dim in size)
        upsampling = out_h > in_h or out_w > in_w
        # The modulo below is only safe once every side is known to be > 1.
        all_sides_gt_one = min(out_h, out_w, in_h, in_w) > 1
        if (upsampling and all_sides_gt_one
                and (out_h - 1) % (in_h - 1)
                and (out_w - 1) % (in_w - 1)):
            warnings.warn(
                f'When align_corners={align_corners}, '
                'the output would more aligned if '
                f'input size {(in_h, in_w)} is `x+1` and '
                f'out size {(out_h, out_w)} is `nx+1`')
    return F.interpolate(input, size, scale_factor, mode, align_corners)
| |
|
def hamming2D(M, N):
    """Return the separable 2-D Hamming window of shape ``(M, N)``.

    Entry ``[i, j]`` equals ``np.hamming(M)[i] * np.hamming(N)[j]``.
    """
    return np.outer(np.hamming(M), np.hamming(N))
| |
|
class DesneFusion(nn.Module):
    """Fuse a high-resolution (HR) and a low-resolution (LR) feature map with
    predicted, content-aware low-pass / high-pass kernels.

    Pipeline (see ``_forward``):

    1. Both inputs are compressed to ``compressed_channels`` by 1x1 convs.
    2. ``content_encoder`` predicts ``lowpass_kernel x lowpass_kernel``
       kernels. ``content_encoder2`` (only when ``use_high_pass``) predicts
       ``highpass_kernel x highpass_kernel`` kernels. Both are normalised by
       :meth:`kernel_normalizer`: softmax, optional Hamming window, then
       renormalisation.
    3. ``lr_feat`` is brought to the HR grid with :func:`carafe` using the
       low-pass kernels.
    4. With ``use_high_pass``, ``hr_feat - carafe(hr_feat, high-pass kernels)``
       is taken as a high-frequency component. With ``hr_residual`` it is
       added back to ``hr_feat``; otherwise it replaces it.
    5. With ``feature_resample``, the upsampled LR features are re-sampled by
       a :class:`LocalSimGuidedSampler`.

    ``forward`` returns ``(mask_lr, hr_feat, lr_feat)``: the normalised
    low-pass kernels, the processed HR features and the upsampled LR features.
    After each call, detached copies of selected tensors are kept in
    ``self.intermediate_results`` under the keys ``'hr_feat_before'``,
    ``'lr_feat_before'``, ``'lr_feat_after'``, ``'hr_feat_hf_component'`` and
    ``'hr_feat_after'``.

    Args:
        hr_channels: Channel count of the HR input.
        lr_channels: Channel count of the LR input.
        scale_factor: Multiplies the number of kernels each encoder predicts.
            NOTE(review): ``_forward`` never pixel-shuffles the masks, so
            only the default of 1 produces masks that ``carafe`` accepts.
        lowpass_kernel: Window size of the low-pass reassembly kernels.
        highpass_kernel: Window size of the high-pass reassembly kernels.
        up_group: Number of channel groups for the reassembly kernels.
        encoder_kernel: Kernel size of the content encoders. With an odd
            size, the padding keeps the spatial size unchanged.
        encoder_dilation: Dilation of the content encoders.
        compressed_channels: Width of the compressed features.
        align_corners: Interpolation setting used to bring ``lr_feat`` to HR
            size when ``semi_conv`` is False.
        upsample_mode: Interpolation mode for the same resize.
        feature_resample: Enable the :class:`LocalSimGuidedSampler` stage.
        feature_resample_group: Offset groups of that sampler.
        feature_resample_norm: Whether that sampler applies its GroupNorm.
        comp_feat_upsample: In the ``semi_conv`` path, CARAFE-upsample the
            LR-side kernel predictions with HR-guided kernels instead of
            nearest interpolation. Requires ``use_high_pass``.
        use_high_pass: Enable the high-frequency branch.
        use_low_pass: Stored only; not consulted by this class.
        hr_residual: Add the high-frequency component to ``hr_feat`` instead
            of replacing ``hr_feat`` with it.
        semi_conv: If True, ``lr_feat`` is upsampled 2x directly by CARAFE,
            so the HR input must be exactly twice the LR size (``carafe``
            asserts this). Otherwise ``lr_feat`` is first resized to the HR
            size.
        hamming_window: Multiply kernels by a 2-D Hamming window before
            renormalising.
        **kwargs: Accepted and ignored.

    The class name spelling is kept unchanged so existing references keep
    working.
    """

    def __init__(self,
                 hr_channels,
                 lr_channels,
                 scale_factor=1,
                 lowpass_kernel=5,
                 highpass_kernel=3,
                 up_group=1,
                 encoder_kernel=3,
                 encoder_dilation=1,
                 compressed_channels=64,
                 align_corners=False,
                 upsample_mode='nearest',
                 feature_resample=False,
                 feature_resample_group=4,
                 comp_feat_upsample=True,
                 use_high_pass=True,
                 use_low_pass=True,
                 hr_residual=True,
                 semi_conv=True,
                 hamming_window=True,
                 feature_resample_norm=True,
                 **kwargs):
        super().__init__()
        self.scale_factor = scale_factor
        self.lowpass_kernel = lowpass_kernel
        self.highpass_kernel = highpass_kernel
        self.up_group = up_group
        self.encoder_kernel = encoder_kernel
        self.encoder_dilation = encoder_dilation
        self.compressed_channels = compressed_channels
        # Sub-modules are created in the original order. That order fixes the
        # RNG draws of PyTorch's default initialisers and the iteration order
        # used by init_weights().
        self.hr_channel_compressor = nn.Conv2d(hr_channels, self.compressed_channels, 1)
        self.lr_channel_compressor = nn.Conv2d(lr_channels, self.compressed_channels, 1)
        # "Same" padding for an odd encoder kernel at the given dilation.
        encoder_padding = int((self.encoder_kernel - 1) * self.encoder_dilation / 2)
        self.content_encoder = nn.Conv2d(
            self.compressed_channels,
            lowpass_kernel ** 2 * self.up_group * self.scale_factor * self.scale_factor,
            self.encoder_kernel,
            padding=encoder_padding,
            dilation=self.encoder_dilation,
            groups=1)

        self.align_corners = align_corners
        self.upsample_mode = upsample_mode
        self.hr_residual = hr_residual
        self.use_high_pass = use_high_pass
        self.use_low_pass = use_low_pass
        self.semi_conv = semi_conv
        self.feature_resample = feature_resample
        self.comp_feat_upsample = comp_feat_upsample
        if self.feature_resample:
            self.dysampler = LocalSimGuidedSampler(
                in_channels=compressed_channels, scale=2, style='lp',
                groups=feature_resample_group, use_direct_scale=True,
                kernel_size=encoder_kernel, norm=feature_resample_norm)
        if self.use_high_pass:
            self.content_encoder2 = nn.Conv2d(
                self.compressed_channels,
                highpass_kernel ** 2 * self.up_group * self.scale_factor * self.scale_factor,
                self.encoder_kernel,
                padding=encoder_padding,
                dilation=self.encoder_dilation,
                groups=1)
        self.hamming_window = hamming_window
        if self.hamming_window:
            # Shape (1, 1, k, k) so it broadcasts over the (n, -1, k, k)
            # kernel view in kernel_normalizer.
            self.register_buffer(
                'hamming_lowpass',
                torch.FloatTensor(hamming2D(lowpass_kernel, lowpass_kernel))[None, None])
            self.register_buffer(
                'hamming_highpass',
                torch.FloatTensor(hamming2D(highpass_kernel, highpass_kernel))[None, None])
        else:
            self.register_buffer('hamming_lowpass', torch.FloatTensor([1.0]))
            self.register_buffer('hamming_highpass', torch.FloatTensor([1.0]))
        self.init_weights()
        self.intermediate_results = {}

    def init_weights(self):
        """Xavier-uniform init for every Conv2d, then small-std normal init
        for the content encoders.

        NOTE(review): ``self.modules()`` recurses into ``self.dysampler``, so
        this also overwrites the std=0.001 / constant-zero init that
        LocalSimGuidedSampler gives its own convs. Confirm this is intended.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform')
        normal_init(self.content_encoder, std=0.001)
        if self.use_high_pass:
            normal_init(self.content_encoder2, std=0.001)

    def kernel_normalizer(self, mask, kernel, scale_factor=None, hamming=1):
        """Turn raw kernel logits into normalised reassembly kernels.

        Args:
            mask: Logits of shape ``(n, G * kernel**2, h, w)``, or the
                pre-pixel-shuffle layout when ``scale_factor`` is given.
            kernel: Kernel side length.
            scale_factor: If not None, ``mask`` is first pixel-shuffled by
                this factor.
            hamming: Window broadcast against ``(n, -1, kernel, kernel)``
                (a ``(1, 1, k, k)`` tensor or a scalar).

        Returns:
            Tensor of shape ``(n, G * kernel**2, h, w)``. For each group and
            pixel, the softmaxed, windowed ``kernel x kernel`` weights are
            renormalised to sum to 1.
        """
        if scale_factor is not None:
            mask = F.pixel_shuffle(mask, scale_factor)
        n, mask_c, h, w = mask.size()
        mask_channel = mask_c // (kernel ** 2)

        mask = mask.reshape(n, mask_channel, -1, h, w)
        mask = F.softmax(mask, dim=2, dtype=mask.dtype)
        mask = mask.reshape(n, mask_channel, kernel, kernel, h, w)
        # reshape (not view): after the permute the (group, h, w) dims are
        # not mergeable in place once there is more than one group.
        mask = mask.permute(0, 1, 4, 5, 2, 3).reshape(n, -1, kernel, kernel)

        mask = mask * hamming
        mask = mask / mask.sum(dim=(-1, -2), keepdim=True)

        mask = mask.reshape(n, mask_channel, h, w, -1)
        mask = mask.permute(0, 1, 4, 2, 3).reshape(n, -1, h, w).contiguous()
        return mask

    def forward(self, hr_feat, lr_feat, use_checkpoint=False):
        """Run ``_forward``, optionally under ``torch.utils.checkpoint``."""
        if use_checkpoint:
            return checkpoint(self._forward, hr_feat, lr_feat)
        return self._forward(hr_feat, lr_feat)

    def _record(self, key, tensor):
        """Store a detached copy of ``tensor`` in ``intermediate_results``.

        Detaching keeps the stored copies from holding on to the autograd
        graph until the next forward call clears the dict.
        """
        self.intermediate_results[key] = tensor.detach().clone()

    def _compute_masks(self, compressed_hr_feat, compressed_lr_feat):
        """Predict the unnormalised low-pass and high-pass kernel logits.

        Returns:
            ``(mask_lr, mask_hr, compressed_hr_feat)``. ``mask_hr`` is None
            when ``use_high_pass`` is off. ``compressed_hr_feat`` is returned
            because the semi-conv/comp-feat path replaces it with a sharpened
            version that the feature re-sampler consumes later.

        Raises:
            NotImplementedError: For ``semi_conv`` + ``comp_feat_upsample``
                without ``use_high_pass``.
        """
        hr_size = compressed_hr_feat.shape[-2:]
        mask_hr = None
        if self.semi_conv and self.comp_feat_upsample:
            if not self.use_high_pass:
                raise NotImplementedError
            # High-pass logits from HR features, then sharpen the compressed HR
            # features: x + (x - lowpass(x)).
            mask_hr_hr_feat = self.content_encoder2(compressed_hr_feat)
            mask_hr_init = self.kernel_normalizer(
                mask_hr_hr_feat, self.highpass_kernel, hamming=self.hamming_highpass)
            compressed_hr_feat = compressed_hr_feat + compressed_hr_feat - carafe(
                compressed_hr_feat, mask_hr_init, self.highpass_kernel, self.up_group, 1)

            # Low-pass logits: HR prediction plus the LR prediction upsampled
            # with HR-guided low-pass kernels.
            mask_lr_hr_feat = self.content_encoder(compressed_hr_feat)
            mask_lr_init = self.kernel_normalizer(
                mask_lr_hr_feat, self.lowpass_kernel, hamming=self.hamming_lowpass)
            mask_lr_lr_feat_lr = self.content_encoder(compressed_lr_feat)
            mask_lr_lr_feat = F.interpolate(
                carafe(mask_lr_lr_feat_lr, mask_lr_init, self.lowpass_kernel, self.up_group, 2),
                size=hr_size, mode='nearest')
            mask_lr = mask_lr_hr_feat + mask_lr_lr_feat

            # High-pass logits: the earlier HR prediction plus the LR
            # prediction upsampled with the fused low-pass kernels.
            mask_lr_init = self.kernel_normalizer(
                mask_lr, self.lowpass_kernel, hamming=self.hamming_lowpass)
            mask_hr_lr_feat = F.interpolate(
                carafe(self.content_encoder2(compressed_lr_feat), mask_lr_init,
                       self.lowpass_kernel, self.up_group, 2),
                size=hr_size, mode='nearest')
            mask_hr = mask_hr_hr_feat + mask_hr_lr_feat
        elif self.semi_conv:
            mask_lr = self.content_encoder(compressed_hr_feat) + F.interpolate(
                self.content_encoder(compressed_lr_feat), size=hr_size, mode='nearest')
            if self.use_high_pass:
                mask_hr = self.content_encoder2(compressed_hr_feat) + F.interpolate(
                    self.content_encoder2(compressed_lr_feat), size=hr_size, mode='nearest')
        else:
            compressed_x = F.interpolate(
                compressed_lr_feat, size=hr_size, mode='nearest') + compressed_hr_feat
            mask_lr = self.content_encoder(compressed_x)
            if self.use_high_pass:
                mask_hr = self.content_encoder2(compressed_x)
        return mask_lr, mask_hr, compressed_hr_feat

    def _forward(self, hr_feat, lr_feat):
        """Fuse ``hr_feat`` and ``lr_feat``; see the class docstring."""
        self.intermediate_results.clear()
        self._record('hr_feat_before', hr_feat)
        self._record('lr_feat_before', lr_feat)

        compressed_hr_feat = self.hr_channel_compressor(hr_feat)
        compressed_lr_feat = self.lr_channel_compressor(lr_feat)
        mask_lr, mask_hr, compressed_hr_feat = self._compute_masks(
            compressed_hr_feat, compressed_lr_feat)

        mask_lr = self.kernel_normalizer(mask_lr, self.lowpass_kernel, hamming=self.hamming_lowpass)

        if self.semi_conv:
            # mask_lr is at HR resolution, so CARAFE upsamples lr_feat 2x
            # directly; carafe asserts HR == 2 * LR.
            lr_feat = carafe(lr_feat, mask_lr, self.lowpass_kernel, self.up_group, 2)
        else:
            lr_feat = resize(
                input=lr_feat,
                size=hr_feat.shape[2:],
                mode=self.upsample_mode,
                align_corners=None if self.upsample_mode == 'nearest' else self.align_corners)
            lr_feat = carafe(lr_feat, mask_lr, self.lowpass_kernel, self.up_group, 1)
        # Record the tensor that is actually returned in both paths.
        self._record('lr_feat_after', lr_feat)

        if self.use_high_pass:
            mask_hr = self.kernel_normalizer(mask_hr, self.highpass_kernel, hamming=self.hamming_highpass)
            hr_feat_hf = hr_feat - carafe(hr_feat, mask_hr, self.highpass_kernel, self.up_group, 1)
            self._record('hr_feat_hf_component', hr_feat_hf)
            hr_feat = hr_feat_hf + hr_feat if self.hr_residual else hr_feat_hf
        else:
            self.intermediate_results['hr_feat_hf_component'] = torch.zeros_like(hr_feat)
        self._record('hr_feat_after', hr_feat)

        if self.feature_resample:
            lr_feat = self.dysampler(hr_x=compressed_hr_feat,
                                     lr_x=compressed_lr_feat, feat2sample=lr_feat)
            self._record('lr_feat_after', lr_feat)

        return mask_lr, hr_feat, lr_feat
| |
|
| |
|
| |
|
class LocalSimGuidedSampler(nn.Module):
    """
    Offset generator in DesneFusion.

    Predicts per-pixel sampling offsets from local-similarity maps of an HR
    and an LR guidance feature. It then re-samples ``feat2sample`` on a grid
    ``scale`` times the LR size with ``F.grid_sample`` (bilinear, border
    padding). Only ``scale=2`` and ``style='lp'`` are supported; both are
    asserted.

    Args:
        in_channels: Channels of the guidance features ``hr_x`` / ``lr_x``.
            Must be divisible by ``groups``. With ``norm`` it must also be
            at least 8, since GroupNorm uses ``in_channels // 8`` groups.
        scale: Upsampling factor (must be 2).
        style: Offset style (must be ``'lp'``).
        groups: Number of offset groups. ``sample`` splits the channels of
            ``feat2sample`` into this many groups.
        use_direct_scale: If True, offsets are modulated by a learned sigmoid
            gate (``direct_scale`` / ``hr_direct_scale``). Otherwise they are
            multiplied by a constant 0.25.
        kernel_size: Kernel size of the offset / gate convolutions.
        local_window: Similarity window; the maps have
            ``local_window**2 - 1`` channels.
        sim_type: Similarity passed to ``compute_similarity`` (``'cos'`` or
            ``'dot'``).
        norm: Apply GroupNorm to both guidance features.
        direction_feat: ``'sim'`` feeds only the similarity maps to the offset
            convs. ``'sim_concat'`` feeds features concatenated with the maps.
    """
    def __init__(self, in_channels, scale=2, style='lp', groups=4, use_direct_scale=True, kernel_size=1, local_window=3, sim_type='cos', norm=True, direction_feat='sim_concat'):
        super().__init__()
        assert scale == 2
        assert style == 'lp'

        self.scale = scale
        self.style = style
        self.groups = groups
        self.local_window = local_window
        self.sim_type = sim_type
        self.direction_feat = direction_feat

        assert in_channels >= groups and in_channels % groups == 0

        sim_channels = local_window ** 2 - 1
        if direction_feat == 'sim':
            # Offsets from similarity maps; gates from the raw features.
            offset_in, gate_in = sim_channels, in_channels
        elif direction_feat == 'sim_concat':
            offset_in = gate_in = in_channels + sim_channels
        else:
            raise NotImplementedError
        padding = kernel_size // 2

        # LR branch: 2 coordinates x groups x scale**2 sub-pixels per LR pixel.
        # Creation/initialisation order is kept stable (it fixes RNG draws).
        lr_out = 2 * groups * scale ** 2
        self.offset = nn.Conv2d(offset_in, lr_out, kernel_size=kernel_size, padding=padding)
        normal_init(self.offset, std=0.001)
        if use_direct_scale:
            self.direct_scale = nn.Conv2d(gate_in, lr_out, kernel_size=kernel_size, padding=padding)
            constant_init(self.direct_scale, val=0.)

        # HR branch: 2 coordinates x groups per HR pixel; get_offset_lp
        # pixel-unshuffles it onto the LR grid.
        hr_out = 2 * groups
        self.hr_offset = nn.Conv2d(offset_in, hr_out, kernel_size=kernel_size, padding=padding)
        normal_init(self.hr_offset, std=0.001)
        if use_direct_scale:
            self.hr_direct_scale = nn.Conv2d(gate_in, hr_out, kernel_size=kernel_size, padding=padding)
            constant_init(self.hr_direct_scale, val=0.)

        self.norm = norm
        if self.norm:
            self.norm_hr = nn.GroupNorm(in_channels // 8, in_channels)
            self.norm_lr = nn.GroupNorm(in_channels // 8, in_channels)
        else:
            self.norm_hr = nn.Identity()
            self.norm_lr = nn.Identity()
        self.register_buffer('init_pos', self._init_pos())

    def _init_pos(self):
        """Base sub-pixel offsets of the ``scale x scale`` children of each LR
        pixel, in LR-pixel units (±0.25 for scale=2).

        The result is repeated per group and shaped ``(1, 2*groups*scale**2, 1, 1)``.
        """
        h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale
        return torch.stack(torch.meshgrid([h, h])).transpose(1, 2).repeat(1, self.groups, 1).reshape(1, -1, 1, 1)

    def sample(self, x, offset, scale=None):
        """Bilinearly sample ``x`` at LR pixel centres plus ``offset``.

        Args:
            x: Feature map to sample. Its channels are split into
                ``self.groups`` groups, each with its own offsets.
            offset: ``(B, 2*groups*scale**2, H, W)`` offsets on the LR grid,
                in LR-pixel units. The first half holds x-offsets and the
                second half y-offsets.
            scale: Upsampling factor; defaults to ``self.scale``.

        Returns:
            Tensor of shape ``(B, C, scale*H, scale*W)``.
        """
        if scale is None:
            scale = self.scale
        B, _, H, W = offset.shape
        offset = offset.view(B, 2, -1, H, W)
        coords_h = torch.arange(H) + 0.5
        coords_w = torch.arange(W) + 0.5
        # (1, 2, 1, H, W): channel 0 = x (width) centre, channel 1 = y (height).
        coords = torch.stack(torch.meshgrid([coords_w, coords_h])
                             ).transpose(1, 2).unsqueeze(1).unsqueeze(0).type(x.dtype).to(x.device)
        normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1, 1)
        # Map absolute LR-pixel coordinates into grid_sample's [-1, 1] range.
        coords = 2 * (coords + offset) / normalizer - 1
        # Spread the scale**2 sub-pixel coordinates onto the upsampled grid.
        coords = F.pixel_shuffle(coords.view(B, -1, H, W), scale).view(
            B, 2, -1, scale * H, scale * W).permute(0, 2, 3, 4, 1).contiguous().flatten(0, 1)
        return F.grid_sample(x.reshape(B * self.groups, -1, x.size(-2), x.size(-1)), coords, mode='bilinear',
                             align_corners=False, padding_mode="border").view(B, -1, scale * H, scale * W)

    def forward(self, hr_x, lr_x, feat2sample):
        """Predict offsets from ``hr_x`` / ``lr_x`` and re-sample ``feat2sample``.

        ``hr_x`` is expected at ``scale`` times the resolution of ``lr_x``,
        since the HR-branch predictions are pixel-unshuffled onto the LR grid.
        """
        hr_x = self.norm_hr(hr_x)
        lr_x = self.norm_lr(lr_x)

        hr_sim = compute_similarity(hr_x, self.local_window, dilation=2, sim=self.sim_type)
        lr_sim = compute_similarity(lr_x, self.local_window, dilation=2, sim=self.sim_type)
        if self.direction_feat == 'sim_concat':
            hr_sim = torch.cat([hr_x, hr_sim], dim=1)
            lr_sim = torch.cat([lr_x, lr_sim], dim=1)
            hr_x, lr_x = hr_sim, lr_sim

        offset = self.get_offset_lp(hr_x, lr_x, hr_sim, lr_sim)
        return self.sample(feat2sample, offset)

    def get_offset_lp(self, hr_x, lr_x, hr_sim=None, lr_sim=None):
        """Combine LR- and (pixel-unshuffled) HR-branch predictions into offsets.

        ``hr_sim`` / ``lr_sim`` feed the offset convs, and ``hr_x`` / ``lr_x``
        feed the gate convs. When the sim arguments are omitted, the feature
        tensors are used for both, which is exactly the ``'sim_concat'`` case.
        """
        if hr_sim is None:
            hr_sim = hr_x
        if lr_sim is None:
            lr_sim = lr_x
        raw_offset = self.offset(lr_sim) + F.pixel_unshuffle(self.hr_offset(hr_sim), self.scale)
        if hasattr(self, 'direct_scale'):
            gate = (self.direct_scale(lr_x) + F.pixel_unshuffle(self.hr_direct_scale(hr_x), self.scale)).sigmoid()
            offset = raw_offset * gate + self.init_pos
        else:
            offset = raw_offset * 0.25 + self.init_pos
        return offset

    def get_offset(self, hr_x, lr_x):
        """Offsets for the configured style; only ``'lp'`` is implemented."""
        if self.style == 'pl':
            raise NotImplementedError
        return self.get_offset_lp(hr_x, lr_x)
| | |
| |
|
def compute_similarity(input_tensor, k=3, dilation=1, sim='cos'):
    """
    Compute the similarity between each position of the input tensor and the
    other points of its (dilated) K x K neighbourhood.

    Neighbours that fall outside the image come from the zero padding applied
    by ``F.unfold``.

    Args:
    - input_tensor: input tensor of shape [B, C, H, W]
    - k: neighbourhood size, i.e. the K x K window around each point. It
      should be odd so the unfolded output keeps the H x W grid.
    - dilation: spacing between the sampled neighbours
    - sim: 'cos' for cosine similarity over channels, 'dot' for the
      channel-wise dot product

    Returns:
    - tensor of shape [B, K*K-1, H, W]; the centre-with-itself entry is dropped

    Raises:
    - NotImplementedError: for any other value of ``sim``
    """
    B, C, H, W = input_tensor.shape
    window = k * k
    centre = window // 2

    patches = F.unfold(input_tensor, k, padding=(k // 2) * dilation, dilation=dilation)
    patches = patches.reshape(B, C, window, H, W)
    centre_vec = patches[:, :, centre:centre + 1]

    if sim == 'cos':
        scores = F.cosine_similarity(centre_vec, patches, dim=1)
    elif sim == 'dot':
        scores = (centre_vec * patches).sum(dim=1)
    else:
        raise NotImplementedError

    before, after = scores[:, :centre], scores[:, centre + 1:]
    return torch.cat((before, after), dim=1).view(B, window - 1, H, W)
| |
|
| |
|
if __name__ == '__main__':
    # Smoke test. With the default semi_conv=True, the HR input must be
    # exactly 2x the LR input, because lr_feat is upsampled by carafe(up=2).
    # The spatial sizes are kept small because the CARAFE unfold materialises
    # channels * kernel**2 * H * W elements; at the former 512x512 HR size
    # that is several GB.
    hr_input = torch.rand(1, 128, 64, 64)
    lr_input = torch.rand(1, 128, 32, 32)
    model = DesneFusion(hr_channels=128, lr_channels=128)
    with torch.no_grad():
        mask_lr, hr_out, lr_out = model(hr_feat=hr_input, lr_feat=lr_input)
    print(mask_lr.shape)
    print(hr_out.shape, lr_out.shape)