# NeoPy's picture
# Update infer/lib/predictors/DJCM/decoder.py
# bb6b19f verified
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
sys.path.append(os.getcwd())
from infer.lib.predictors.DJCM.encoder import ResEncoderBlock
from infer.lib.predictors.DJCM.utils import ResConvBlock, BiGRU, init_bn, init_layer
class ResDecoderBlock(nn.Module):
    """Decoder stage: BatchNorm -> ReLU -> transposed conv, then skip fusion.

    ``forward`` batch-normalises its input, applies an in-place ReLU and a
    bias-free ``ConvTranspose2d`` whose kernel size and stride are both
    ``stride`` (zero padding). The result is concatenated with a second
    tensor along dim 1 and passed through the ``ResConvBlock`` layers held
    in ``self.conv``.

    Args:
        in_channels: Channel count of the tensor passed as ``x`` to
            ``forward`` (used for both the batch norm and the transposed
            convolution input).
        out_channels: Output channels of the transposed convolution and of
            every ``ResConvBlock``.
        n_blocks: Requested number of ``ResConvBlock`` layers. One block is
            always created, so values below 1 still yield a single block.
        stride: Passed as both kernel size and stride of the transposed
            convolution.
    """

    def __init__(self, in_channels, out_channels, n_blocks, stride):
        super().__init__()
        # Registration order (conv1, bn1, conv) is kept stable so that
        # state-dict keys and parameter ordering do not change.
        self.conv1 = nn.ConvTranspose2d(
            in_channels, out_channels, stride, stride, (0, 0), bias=False
        )
        # The norm runs *before* conv1, hence it is sized by in_channels.
        self.bn1 = nn.BatchNorm2d(in_channels, momentum=0.01)

        # The first block receives the concatenation built in forward(),
        # so it is given twice out_channels; presumably the first argument
        # of ResConvBlock is its input channel count -- TODO confirm.
        blocks = [ResConvBlock(out_channels * 2, out_channels)]
        blocks.extend(
            ResConvBlock(out_channels, out_channels) for _ in range(n_blocks - 1)
        )
        self.conv = nn.ModuleList(blocks)

        self.init_weights()

    def init_weights(self):
        """Initialise the batch norm and the transposed convolution."""
        init_bn(self.bn1)
        init_layer(self.conv1)

    def forward(self, x, concat):
        """Run the stage on ``x`` and fuse it with ``concat``.

        Args:
            x: Input tensor for the batch norm / transposed convolution.
            concat: Tensor concatenated with the transposed-conv output
                along dim 1; all other dims must match that output.

        Returns:
            The output of the last layer in ``self.conv``.
        """
        expanded = self.conv1(F.relu_(self.bn1(x)))
        fused = torch.cat((expanded, concat), dim=1)
        for block in self.conv:
            fused = block(fused)
        return fused
class Decoder(nn.Module):
    """Chain of six ``ResDecoderBlock`` stages going from 384 to 32 channels.

    Stage channel sizes are 384->384, 384->384, 384->256, 256->128,
    128->64 and 64->32; every stage uses stride ``(1, 2)``.

    Args:
        n_blocks: Forwarded to every ``ResDecoderBlock``.
    """

    def __init__(self, n_blocks):
        super().__init__()
        # (in_channels, out_channels) for each stage, in execution order.
        channel_plan = (
            (384, 384),
            (384, 384),
            (384, 256),
            (256, 128),
            (128, 64),
            (64, 32),
        )
        self.de_blocks = nn.ModuleList(
            [
                ResDecoderBlock(c_in, c_out, n_blocks, (1, 2))
                for c_in, c_out in channel_plan
            ]
        )

    def forward(self, x, concat_tensors):
        """Run all stages, feeding each one a tensor from ``concat_tensors``.

        Args:
            x: Input to the first stage.
            concat_tensors: Sequence of tensors consumed from the END
                backwards: stage ``i`` receives ``concat_tensors[-1 - i]``.
                Entries beyond the number of stages (at the front) are
                ignored.

        Returns:
            Output of the last stage.

        Raises:
            IndexError: If fewer tensors are supplied than there are stages.
        """
        n_stages = len(self.de_blocks)
        n_given = len(concat_tensors)
        if n_given < n_stages:
            # Fail before doing any work instead of part-way through the
            # stack with an uninformative "index out of range".
            raise IndexError(
                f"Decoder has {n_stages} stages but only {n_given} "
                "concat tensors were provided"
            )
        for i, layer in enumerate(self.de_blocks):
            x = layer(x, concat_tensors[-1 - i])
        return x
class PE_Decoder(nn.Module):
    """Decoder head ending in per-class sigmoid activations.

    Pipeline: ``Decoder`` -> ``ResEncoderBlock(32, 32)`` -> 1x1 conv to a
    single channel -> ``BiGRU`` -> ``Linear`` to ``n_class`` -> ``Sigmoid``,
    followed by ``squeeze(1)``.
    NOTE(review): "PE" presumably stands for pitch estimation -- confirm.

    Args:
        n_blocks: Forwarded to ``Decoder`` and ``ResEncoderBlock``.
        seq_layers: Third argument given to ``BiGRU``.
        window_length: Half of this value sizes the ``BiGRU`` input
            descriptor ``(1, window_length // 2)`` and the ``Linear`` input
            width.
        n_class: Output width of the final ``Linear`` layer.
    """

    def __init__(self, n_blocks, seq_layers=1, window_length=1024, n_class=360):
        super().__init__()
        self.de_blocks = Decoder(n_blocks)
        self.after_conv1 = ResEncoderBlock(32, 32, n_blocks, None)
        self.after_conv2 = nn.Conv2d(32, 1, (1, 1))

        half_window = window_length // 2
        self.fc = nn.Sequential(
            BiGRU((1, half_window), 1, seq_layers),
            nn.Linear(half_window, n_class),
            nn.Sigmoid(),
        )

        self.init_weights()

    def init_weights(self):
        """Initialise the 1x1 output convolution.

        Provided as a method so this class matches the ``init_weights``
        convention of the other decoder classes in this module.
        """
        init_layer(self.after_conv2)

    def forward(self, x, concat_tensors):
        """Decode ``x`` and return the sigmoid outputs.

        Args:
            x: Input passed to the ``Decoder`` stack.
            concat_tensors: Sequence of tensors passed to the ``Decoder``
                stack alongside ``x``.

        Returns:
            Output of ``self.fc`` with dim 1 removed when it has size 1.
        """
        features = self.de_blocks(x, concat_tensors)
        features = self.after_conv2(self.after_conv1(features))
        return self.fc(features).squeeze(1)
class SVS_Decoder(nn.Module):
    """Decoder head producing ``in_channels * 4`` output channels.

    Pipeline: ``Decoder`` -> ``ResEncoderBlock(32, 32)`` -> 1x1 conv.
    NOTE(review): "SVS" presumably stands for singing voice separation,
    and the factor 4 presumably packs several mask/phase components per
    input channel -- confirm against the caller.

    Args:
        in_channels: Scales the output channel count of the final 1x1
            convolution (``in_channels * 4``).
        n_blocks: Forwarded to ``Decoder`` and ``ResEncoderBlock``.
    """

    def __init__(self, in_channels, n_blocks):
        super().__init__()
        self.de_blocks = Decoder(n_blocks)
        self.after_conv1 = ResEncoderBlock(32, 32, n_blocks, None)
        self.after_conv2 = nn.Conv2d(32, in_channels * 4, (1, 1))
        self.init_weights()

    def init_weights(self):
        """Initialise the 1x1 output convolution."""
        init_layer(self.after_conv2)

    def forward(self, x, concat_tensors):
        """Decode ``x`` and project it with the 1x1 convolution.

        Args:
            x: Input passed to the ``Decoder`` stack.
            concat_tensors: Sequence of tensors passed to the ``Decoder``
                stack alongside ``x``.

        Returns:
            Output of ``self.after_conv2``.
        """
        decoded = self.de_blocks(x, concat_tensors)
        refined = self.after_conv1(decoded)
        return self.after_conv2(refined)