|
|
from collections import OrderedDict |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
def shape_to_num_params(shapes):
    """Return the total number of scalar parameters described by *shapes*.

    Args:
        shapes: iterable of per-tensor shapes; each entry may be a 1-D
            ``torch.Tensor`` of dimension sizes (original contract) or any
            sequence of ints such as a ``tuple``/``list``/``torch.Size``
            (generalization — ``torch.as_tensor`` accepts both).

    Returns:
        int: sum over all shapes of the product of their dimensions;
        0 for an empty iterable.
    """
    # torch.as_tensor is a no-op on tensors, so tensor inputs behave
    # exactly as before; plain sequences are converted on the fly.
    return int(sum(torch.prod(torch.as_tensor(s)) for s in shapes))
|
|
|
|
|
|
|
|
class WeightRegressor(nn.Module):
    """Regress a pair of feature maps into a convolution weight kernel.

    Two feature volumes are fused channel-wise, compressed by a small
    strided CNN, and mapped through two learned linear projections to a
    tensor of shape ``(batch, out_channels, in_channels, k, k)``.

    NOTE(review): the flatten in ``forward`` assumes the extractor output
    is spatially 1x1, i.e. the inputs are 8x8 (three stride-2 convs) —
    confirm against callers.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size=3, out_channels=16, in_channels=16):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.in_channels = in_channels

        # 1x1 conv squeezes the concatenated pair back to input_dim channels.
        self.fusion = nn.Sequential(
            nn.Conv2d(2 * self.input_dim, self.input_dim, kernel_size=1, padding=0, stride=1, bias=True),
            nn.InstanceNorm2d(self.input_dim),
            nn.ReLU(),
        )

        # One stride-1 conv followed by three stride-2 downsampling convs;
        # built in a loop but in the same order as a literal Sequential, so
        # state-dict keys and init RNG draws are unchanged.
        extractor = [
            nn.Conv2d(self.input_dim, 64, kernel_size=3, padding=1, stride=1, bias=True),
            nn.ReLU(),
        ]
        prev_ch = 64
        for next_ch in (64, 64, self.hidden_dim):
            extractor.append(nn.Conv2d(prev_ch, next_ch, kernel_size=4, stride=2, padding=1, bias=True))
            extractor.append(nn.ReLU())
            prev_ch = next_ch
        self.feature_extractor = nn.Sequential(*extractor)

        # Two explicit linear maps kept as raw Parameters (not nn.Linear)
        # so checkpoint keys stay w1/b1/w2/b2.
        mid_dim = self.in_channels * self.hidden_dim
        flat_dim = self.out_channels * self.kernel_size * self.kernel_size
        self.w1 = nn.Parameter(torch.randn(self.hidden_dim, mid_dim))
        self.b1 = nn.Parameter(torch.randn(mid_dim))
        self.w2 = nn.Parameter(torch.randn(self.hidden_dim, flat_dim))
        self.b2 = nn.Parameter(torch.randn(flat_dim))

        self.weight_init()

    def weight_init(self):
        """Kaiming-normal weights, zero biases (same draw order as before)."""
        for weight, bias in ((self.w1, self.b1), (self.w2, self.b2)):
            nn.init.kaiming_normal_(weight)
            nn.init.zeros_(bias)

    def forward(self, w_image_codes, w_bar_codes):
        """Predict one kernel per batch element.

        Args:
            w_image_codes: ``(batch, input_dim, H, W)`` feature volume.
            w_bar_codes: matching ``(batch, input_dim, H, W)`` volume.

        Returns:
            Tensor of shape ``(batch, out_channels, in_channels,
            kernel_size, kernel_size)``.
        """
        batch = w_image_codes.size(0)

        fused = self.fusion(torch.cat([w_image_codes, w_bar_codes], dim=1))
        feats = self.feature_extractor(fused).view(batch, -1)

        # hidden: (batch, hidden_dim) -> (batch, in_channels * hidden_dim)
        hidden = torch.matmul(feats, self.w1) + self.b1
        hidden = hidden.view(batch, self.in_channels, self.hidden_dim)

        # per-in-channel projection onto the flattened kernel slice
        flat = torch.matmul(hidden, self.w2) + self.b2
        return flat.view(batch, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)
|
|
|
|
|
|
|
|
class Hypernetwork(nn.Module):
    """Predict a convolution kernel for every target layer via one
    :class:`WeightRegressor` per layer.

    Args:
        input_dim: channel count of the incoming code volumes.
        hidden_dim: bottleneck width of each regressor.
        target_shape: mapping ``layer_name -> {"shape": ..., "w_idx": ...}``
            describing the tensors to predict. 4-D shapes are treated as
            conv kernels; 2-D shapes as linear weights (kernel_size=1).
    """

    def __init__(self, input_dim=512, hidden_dim=64, target_shape=None):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # Guard the default: iterating None used to raise TypeError.
        self.target_shape = target_shape if target_shape is not None else {}

        # nn.ModuleDict keys may not contain ".", so layer names are
        # sanitized. Keep an explicit reverse map instead of the previous
        # lossy '"."/"_" string round-trip, which corrupted any layer name
        # that legitimately contained an underscore (e.g. "conv_block.weight"
        # became "conv.block.weight"). NOTE: distinct names that sanitize to
        # the same key (e.g. "a.b" and "a_b") still collide, as before.
        self._sanitized_to_original = {}

        num_predicted_weights = 0
        weight_regressors = OrderedDict()
        for layer_name, spec in self.target_shape.items():
            safe_name = layer_name.replace(".", "_")
            self._sanitized_to_original[safe_name] = layer_name
            shape = spec["shape"]

            if len(shape) == 4:
                # (out_channels, in_channels, k, k) conv kernel
                out_channels, in_channels, kernel_size = shape[:3]
            else:
                # (out_features, in_features) linear weight -> 1x1 kernel
                out_channels, in_channels = shape
                kernel_size = 1

            num_predicted_weights += shape_to_num_params([torch.tensor(list(shape))])
            weight_regressors[safe_name] = WeightRegressor(
                input_dim=self.input_dim,
                hidden_dim=self.hidden_dim,
                kernel_size=kernel_size,
                out_channels=out_channels,
                in_channels=in_channels,
            )
        self.weight_regressors = nn.ModuleDict(weight_regressors)
        self.num_predicted_weights = num_predicted_weights

    def forward(self, w_image_codes, w_bar_codes):
        """Run every regressor on its code slice.

        Args:
            w_image_codes: 5-D tensor indexed as ``[:, w_idx, :, :, :]``
                per layer — presumably (batch, n_codes, C, H, W); verify
                against callers.
            w_bar_codes: tensor of the same layout.

        Returns:
            dict mapping each original layer name to a weight tensor of
            shape ``(batch, *target_shape[name]["shape"])``.
        """
        bs = w_image_codes.size(0)
        out_weights = {}

        for safe_name, regressor in self.weight_regressors.items():
            layer_name = self._sanitized_to_original[safe_name]
            spec = self.target_shape[layer_name]
            w_idx = spec["w_idx"]
            weights = regressor(
                w_image_codes[:, w_idx, :, :, :], w_bar_codes[:, w_idx, :, :, :]
            )
            # Regressor output is (bs, out, in, k, k); reshape to the
            # layer's true shape (drops the dummy 1x1 dims for linears).
            out_weights[layer_name] = weights.view(bs, *list(spec["shape"]))

        return out_weights