import math

import torch.nn as nn
from torch.nn import functional as F


def is_power_of_two(num):
    """Return True if num is a positive power of two."""
    if num <= 0:
        return False
    return num & (num - 1) == 0


class CnnEncoder(nn.Module):
    """Simple CNN encoder that encodes a 64x64 image to a flat embedding."""

    def __init__(self, embedding_size, activation_function="relu"):
        super().__init__()
        self.act_fn = getattr(F, activation_function)
        self.embedding_size = embedding_size
        self.fc = nn.Linear(1024, self.embedding_size)
        self.conv1 = nn.Conv2d(3, 32, 4, stride=2)
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
        self.conv3 = nn.Conv2d(64, 128, 4, stride=2)
        self.conv4 = nn.Conv2d(128, 256, 4, stride=2)
        # Kept as a plain list; assigning to `self.modules` would shadow
        # nn.Module.modules().
        self.conv_layers = [self.conv1, self.conv2, self.conv3, self.conv4]

    def forward(self, observation):
        batch_size = observation.shape[0]
        hidden = self.act_fn(self.conv1(observation))
        hidden = self.act_fn(self.conv2(hidden))
        hidden = self.act_fn(self.conv3(hidden))
        hidden = self.act_fn(self.conv4(hidden))
        # Four stride-2 convs reduce a 3x64x64 input to 256x2x2 = 1024 features.
        hidden = self.fc(hidden.view(batch_size, 1024))
        return hidden


class CnnDecoder(nn.Module):
    """Simple CNN decoder that decodes an embedding to a 64x64 image."""

    def __init__(self, embedding_size, activation_function="relu"):
        super().__init__()
        self.act_fn = getattr(F, activation_function)
        self.embedding_size = embedding_size
        self.fc = nn.Linear(embedding_size, 128)
        self.conv1 = nn.ConvTranspose2d(128, 128, 5, stride=2)
        self.conv2 = nn.ConvTranspose2d(128, 64, 5, stride=2)
        self.conv3 = nn.ConvTranspose2d(64, 32, 6, stride=2)
        self.conv4 = nn.ConvTranspose2d(32, 3, 6, stride=2)
        # Kept as a plain list for the same reason as in CnnEncoder.
        self.conv_layers = [self.conv1, self.conv2, self.conv3, self.conv4]

    def forward(self, embedding):
        batch_size = embedding.shape[0]
        hidden = self.fc(embedding)
        hidden = hidden.view(batch_size, 128, 1, 1)
        hidden = self.act_fn(self.conv1(hidden))
        hidden = self.act_fn(self.conv2(hidden))
        hidden = self.act_fn(self.conv3(hidden))
        # No activation on the final layer so pixel values stay unconstrained.
        observation = self.conv4(hidden)
        return observation
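
# A minimal round-trip sketch for the CNN pair above, written as a comment so
# the module stays import-clean; the embedding size is an arbitrary example:
#
#     import torch
#     encoder = CnnEncoder(embedding_size=256)
#     decoder = CnnDecoder(embedding_size=256)
#     images = torch.randn(4, 3, 64, 64)
#     embedding = encoder(images)          # -> (4, 256)
#     reconstruction = decoder(embedding)  # -> (4, 3, 64, 64)
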
class FullyConvEncoder(nn.Module):
    """Simple fully convolutional encoder, with 2D input and 2D output."""

    def __init__(
        self,
        input_shape=(3, 64, 64),
        embedding_shape=(8, 16, 16),
        activation_function="relu",
        init_channels=16,
    ):
        super().__init__()
        assert len(input_shape) == 3, "input_shape must be a tuple of length 3"
        assert len(embedding_shape) == 3, "embedding_shape must be a tuple of length 3"
        assert input_shape[1] == input_shape[2] and is_power_of_two(
            input_shape[1]
        ), "input_shape must be square with a power-of-two side"
        assert (
            embedding_shape[1] == embedding_shape[2]
        ), "embedding_shape must be square"
        assert (
            input_shape[1] % embedding_shape[1] == 0
        ), "input_shape must be divisible by embedding_shape"
        assert is_power_of_two(init_channels), "init_channels must be a power of 2"

        # Reaching the embedding resolution takes log2(input/embedding)
        # stride-2 layers; the +1 accounts for the stride-1 stem below.
        depth = int(math.log2(input_shape[1] // embedding_shape[1])) + 1
        channels_per_layer = [init_channels * (2**i) for i in range(depth)]
        self.act_fn = getattr(F, activation_function)

        self.downs = nn.ModuleList([])
        self.downs.append(
            nn.Conv2d(
                input_shape[0],
                channels_per_layer[0],
                kernel_size=3,
                stride=1,
                padding=1,
            )
        )
        for i in range(1, depth):
            self.downs.append(
                nn.Conv2d(
                    channels_per_layer[i - 1],
                    channels_per_layer[i],
                    kernel_size=3,
                    stride=2,
                    padding=1,
                )
            )
        # Bottleneck layer that maps to the embedding channel count.
        self.downs.append(
            nn.Conv2d(
                channels_per_layer[-1],
                embedding_shape[0],
                kernel_size=1,
                stride=1,
                padding=0,
            )
        )

    def forward(self, observation):
        hidden = observation
        for layer in self.downs:
            hidden = self.act_fn(layer(hidden))
        return hidden


class FullyConvDecoder(nn.Module):
    """Simple fully convolutional decoder, with 2D input and 2D output."""

    def __init__(
        self,
        embedding_shape=(8, 16, 16),
        output_shape=(3, 64, 64),
        activation_function="relu",
        init_channels=16,
    ):
        super().__init__()
        assert len(embedding_shape) == 3, "embedding_shape must be a tuple of length 3"
        assert len(output_shape) == 3, "output_shape must be a tuple of length 3"
        assert output_shape[1] == output_shape[2] and is_power_of_two(
            output_shape[1]
        ), "output_shape must be square with a power-of-two side"
        assert (
            embedding_shape[1] == embedding_shape[2]
        ), "embedding_shape must be square"
        assert (
            output_shape[1] % embedding_shape[1] == 0
        ), "output_shape must be divisible by embedding_shape"
        assert is_power_of_two(init_channels), "init_channels must be a power of 2"

        # Mirror of the encoder: log2(output/embedding) stride-2 layers plus
        # a stride-1 output layer.
        depth = int(math.log2(output_shape[1] // embedding_shape[1])) + 1
        channels_per_layer = [init_channels * (2**i) for i in range(depth)]
        self.act_fn = getattr(F, activation_function)

        self.ups = nn.ModuleList([])
        # Expand from the embedding channels to the widest decoder layer.
        self.ups.append(
            nn.ConvTranspose2d(
                embedding_shape[0],
                channels_per_layer[-1],
                kernel_size=1,
                stride=1,
                padding=0,
            )
        )
        for i in range(1, depth):
            self.ups.append(
                nn.ConvTranspose2d(
                    channels_per_layer[-i],
                    channels_per_layer[-i - 1],
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                )
            )
        # No activation on the output layer, matching CnnDecoder.
        self.output_layer = nn.ConvTranspose2d(
            channels_per_layer[0], output_shape[0], kernel_size=3, stride=1, padding=1
        )

    def forward(self, embedding):
        hidden = embedding
        for layer in self.ups:
            hidden = self.act_fn(layer(hidden))
        return self.output_layer(hidden)
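

# A minimal smoke test, assuming the default shapes above; run this file
# directly to check that each encoder/decoder pair round-trips shapes.
if __name__ == "__main__":
    import torch

    images = torch.randn(2, 3, 64, 64)

    cnn_enc = CnnEncoder(embedding_size=128)
    cnn_dec = CnnDecoder(embedding_size=128)
    flat = cnn_enc(images)
    assert flat.shape == (2, 128)
    assert cnn_dec(flat).shape == images.shape

    conv_enc = FullyConvEncoder(input_shape=(3, 64, 64), embedding_shape=(8, 16, 16))
    conv_dec = FullyConvDecoder(embedding_shape=(8, 16, 16), output_shape=(3, 64, 64))
    grid = conv_enc(images)
    assert grid.shape == (2, 8, 16, 16)
    assert conv_dec(grid).shape == images.shape

    print("shape checks passed")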