import os
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F

sys.path.append(os.getcwd())

from infer.lib.predictors.DJCM.encoder import ResEncoderBlock
from infer.lib.predictors.DJCM.utils import ResConvBlock, BiGRU, init_bn, init_layer


class ResDecoderBlock(nn.Module):
    """One decoder stage: BN -> ReLU -> transposed-conv upsample, skip concat, residual convs.

    Args:
        in_channels: channels of the tensor coming from the previous (deeper) stage.
        out_channels: channels produced by the transposed conv; the skip tensor
            passed to ``forward`` is expected to carry this many channels too,
            since the first ResConvBlock is built for ``out_channels * 2`` inputs
            (assuming ResConvBlock's first argument is its input width — TODO confirm).
        n_blocks: number of ResConvBlocks applied after the concatenation (>= 1).
        stride: used as both kernel size and stride of the transposed conv with
            zero padding, so each spatial dim is multiplied by the matching stride.
    """

    def __init__(self, in_channels, out_channels, n_blocks, stride):
        super().__init__()
        # Construction order is kept fixed: it determines both the parameter
        # order and the sequence in which random initialisers draw from the RNG.
        self.conv1 = nn.ConvTranspose2d(in_channels, out_channels, stride, stride, (0, 0), bias=False)
        # Pre-activation normalisation acts on the *input* width, before upsampling.
        self.bn1 = nn.BatchNorm2d(in_channels, momentum=0.01)
        # The first residual block absorbs the doubled width from the skip concat;
        # the remaining n_blocks - 1 blocks keep the width at out_channels.
        self.conv = nn.ModuleList(
            [ResConvBlock(out_channels * 2, out_channels)]
            + [ResConvBlock(out_channels, out_channels) for _ in range(n_blocks - 1)]
        )
        self.init_weights()

    def init_weights(self):
        """Re-initialise the normalisation layer and the upsampling conv via project helpers."""
        init_bn(self.bn1)
        init_layer(self.conv1)

    def forward(self, x, concat):
        """Upsample ``x``, join it with the skip tensor ``concat`` on dim 1, then refine.

        ``x`` must be 4-D (BatchNorm2d requirement); ``concat`` must match the
        upsampled tensor in every dim except dim 1.
        """
        activated = F.relu_(self.bn1(x))  # in-place ReLU on the fresh BN output
        upsampled = self.conv1(activated)
        out = torch.cat((upsampled, concat), dim=1)
        for res_block in self.conv:
            out = res_block(out)
        return out


class Decoder(nn.Module):
    """Stack of six ResDecoderBlocks narrowing 384 channels down to 32.

    Every stage upsamples with stride (1, 2), i.e. doubles the last spatial
    dim and leaves the other unchanged.
    """

    def __init__(self, n_blocks):
        super().__init__()
        # (in_channels, out_channels) per stage, deepest stage first.
        channel_plan = (
            (384, 384),
            (384, 384),
            (384, 256),
            (256, 128),
            (128, 64),
            (64, 32),
        )
        self.de_blocks = nn.ModuleList(
            [ResDecoderBlock(c_in, c_out, n_blocks, (1, 2)) for c_in, c_out in channel_plan]
        )

    def forward(self, x, concat_tensors):
        """Run all stages, pairing stage k (1-based) with ``concat_tensors[-k]``.

        Skip tensors are consumed from the end of the sequence backwards, so
        ``concat_tensors`` must hold at least as many entries as there are
        stages; otherwise indexing raises ``IndexError``.
        """
        for depth, block in enumerate(self.de_blocks, start=1):
            x = block(x, concat_tensors[-depth])
        return x


class PE_Decoder(nn.Module):
    """Decoder head ending in BiGRU -> Linear -> Sigmoid over ``n_class`` outputs.

    "PE" presumably stands for pitch estimation — TODO confirm against callers.

    Args:
        n_blocks: residual blocks per decoder / post-processing stage.
        seq_layers: forwarded to BiGRU (presumably its number of recurrent layers — TODO confirm).
        window_length: the final Linear expects ``window_length // 2`` input features.
        n_class: output width of the final Linear layer.
    """

    def __init__(self, n_blocks, seq_layers=1, window_length = 1024, n_class = 360):
        super().__init__()
        self.de_blocks = Decoder(n_blocks)
        # ResEncoderBlock with a None last argument — presumably no pooling/downsampling (TODO confirm).
        self.after_conv1 = ResEncoderBlock(32, 32, n_blocks, None)
        # 1x1 conv collapsing the 32 decoder channels to a single channel.
        self.after_conv2 = nn.Conv2d(32, 1, (1, 1))
        half_window = window_length // 2
        self.fc = nn.Sequential(
            BiGRU((1, half_window), 1, seq_layers),
            nn.Linear(half_window, n_class),
            nn.Sigmoid(),
        )
        init_layer(self.after_conv2)

    def forward(self, x, concat_tensors):
        """Decode, refine, project to one channel, apply ``fc``, then ``squeeze(1)``.

        ``squeeze(1)`` only removes dim 1 when it has size 1 (presumably the
        single channel from ``after_conv2``, if BiGRU preserves it — TODO confirm).
        """
        decoded = self.de_blocks(x, concat_tensors)
        refined = self.after_conv1(decoded)
        projected = self.after_conv2(refined)
        return self.fc(projected).squeeze(1)


class SVS_Decoder(nn.Module):
    """Decoder head projecting the 32 decoder channels to ``in_channels * 4`` maps.

    What the factor of 4 represents is decided by the caller (not visible
    here) — TODO confirm.
    """

    def __init__(self, in_channels, n_blocks):
        super().__init__()
        self.de_blocks = Decoder(n_blocks)
        self.after_conv1 = ResEncoderBlock(32, 32, n_blocks, None)
        self.after_conv2 = nn.Conv2d(32, in_channels * 4, (1, 1))
        self.init_weights()

    def init_weights(self):
        """Re-initialise the final 1x1 projection via the project helper."""
        init_layer(self.after_conv2)

    def forward(self, x, concat_tensors):
        """Decode with skip connections, refine, and apply the 1x1 output projection."""
        decoded = self.de_blocks(x, concat_tensors)
        refined = self.after_conv1(decoded)
        return self.after_conv2(refined)