import os
import sys

import torch.nn as nn

# Make the project root importable when this module is run from the repo root.
sys.path.append(os.getcwd())

from main.library.predictors.DJCM.utils import ResConvBlock


class ResEncoderBlock(nn.Module):
    """Stack of ``ResConvBlock`` layers, optionally followed by max pooling.

    Args:
        in_channels: Channels fed to the first ``ResConvBlock``.
        out_channels: Channels produced by every ``ResConvBlock`` in the stack.
        n_blocks: Number of ``ResConvBlock`` layers. The first block is always
            built, so values below 1 still yield a single block.
        kernel_size: Kernel size of the trailing ``nn.MaxPool2d``. The stride
            defaults to the kernel size, so e.g. ``(1, 2)`` halves the last
            spatial dimension. ``None`` disables pooling.

    ``forward`` returns ``(features, pooled)`` when pooling is enabled and
    just ``features`` when ``kernel_size`` is ``None``.
    """

    def __init__(self, in_channels, out_channels, n_blocks, kernel_size):
        super().__init__()
        # Attribute names (``conv``, ``pool``) are kept stable so existing
        # checkpoints / state_dict keys still load.
        self.conv = nn.ModuleList(
            [ResConvBlock(in_channels, out_channels)]
            + [ResConvBlock(out_channels, out_channels) for _ in range(n_blocks - 1)]
        )
        self.pool = nn.MaxPool2d(kernel_size) if kernel_size is not None else None

    def forward(self, x):
        for layer in self.conv:
            x = layer(x)
        if self.pool is None:
            return x
        # Return the pre-pool features too, so callers can keep them as
        # skip connections.
        return x, self.pool(x)


class Encoder(nn.Module):
    """Chain of pooled ``ResEncoderBlock`` stages that also collects skip features.

    Args:
        in_channels: Channels of the encoder input.
        n_blocks: ``ResConvBlock`` count per stage (see ``ResEncoderBlock``).
        channels: Output channels of each successive stage. The default
            reproduces the original six-stage schedule.
        pool_size: Max-pool kernel applied after every stage. It must not be
            ``None``, because ``forward`` unpacks a ``(features, pooled)`` pair
            from each stage.

    Raises:
        ValueError: If ``pool_size`` is ``None``.

    ``forward(x)`` returns ``(pooled_output, skips)``. ``skips`` lists each
    stage's pre-pool features in stage order. NOTE(review): the input is
    presumably 4-D ``(batch, channels, H, W)`` as required by ``nn.MaxPool2d``;
    confirm against ``ResConvBlock``.
    """

    def __init__(
        self,
        in_channels,
        n_blocks,
        channels=(32, 64, 128, 256, 384, 384),
        pool_size=(1, 2),
    ):
        super().__init__()
        if pool_size is None:
            raise ValueError("Encoder requires a pool_size; got None")

        stages = []
        stage_in = in_channels
        for stage_out in channels:
            stages.append(ResEncoderBlock(stage_in, stage_out, n_blocks, pool_size))
            stage_in = stage_out
        # ``en_blocks`` name preserved for state_dict compatibility.
        self.en_blocks = nn.ModuleList(stages)

    def forward(self, x):
        skips = []
        for stage in self.en_blocks:
            features, x = stage(x)
            skips.append(features)
        return x, skips