import torch
import torch.nn as nn
import numpy as np


def _conv3x3_layer(dim_in, dim_out):
    """Build a ``Layer`` with a 3x3 kernel, stride 1 and padding 1."""
    return Layer(dim_in=dim_in, dim_out=dim_out, kernel_size=3, stride=1, padding=1)


class Layer(nn.Module):
    """Gated 2-D convolution followed by batch normalisation.

    Computes ``bn(conv(x) * sigmoid(gate(x)))`` where ``conv`` and ``gate``
    are two independent ``nn.Conv2d`` modules with identical hyper-parameters.

    Args:
        dim_in: Number of input channels.
        dim_out: Number of output channels.
        kernel_size: Kernel size forwarded to both convolutions.
        stride: Stride forwarded to both convolutions.
        padding: Padding forwarded to both convolutions.
    """

    def __init__(self, dim_in, dim_out, kernel_size, stride, padding):
        super().__init__()
        conv_kwargs = dict(
            kernel_size=kernel_size, stride=stride, padding=padding, bias=True
        )
        # Attribute names and creation order (conv, gate, bn) are kept so
        # state_dict keys and seeded initialisation stay unchanged.
        self.conv = nn.Conv2d(dim_in, dim_out, **conv_kwargs)
        self.gate = nn.Conv2d(dim_in, dim_out, **conv_kwargs)
        self.bn = nn.BatchNorm2d(dim_out)

    def forward(self, x):
        """Return the batch-normalised, sigmoid-gated convolution of ``x``."""
        gated = self.conv(x) * torch.sigmoid(self.gate(x))
        return self.bn(gated)


class Encoder(nn.Module):
    """Stack of gated 3x3 ``Layer`` blocks plus a linear message projection.

    The stack takes a single-channel input, uses 32 hidden channels and ends
    with ``out_dim`` channels. A first and a last layer are always created,
    so at least two layers exist even when ``n_layers < 2``.

    Args:
        out_dim: Output channels of the final layer.
        n_layers: Total number of layers (``n_layers - 2`` hidden ones).
        message_dim: Input feature size of the message projection.
        message_band_size: Output feature size of the message projection.
            Required.
        n_fft: ``transform_message`` pads its output to ``n_fft // 2 + 1``
            entries along dim 2 (presumably the one-sided STFT bin count).
            Required.

    Raises:
        AssertionError: If ``message_band_size`` or ``n_fft`` is ``None``.
    """

    def __init__(self, out_dim=32, n_layers=3, message_dim=0,
                 message_band_size=None, n_fft=None):
        super().__init__()
        assert message_band_size is not None, "message_band_size is required"
        assert n_fft is not None, "n_fft is required"
        self.message_band_size = message_band_size
        self.n_fft = n_fft

        blocks = [_conv3x3_layer(1, 32)]
        blocks.extend(_conv3x3_layer(32, 32) for _ in range(n_layers - 2))
        blocks.append(_conv3x3_layer(32, out_dim))
        self.main = nn.Sequential(*blocks)
        self.linear = nn.Linear(message_dim, message_band_size)

    def forward(self, x):
        """Run ``x`` through the convolutional stack and return the result."""
        return self.main(x)

    def transform_message(self, msg):
        """Project a message along dim 2 and pad it to ``n_fft // 2 + 1``.

        Dims 2 and 3 are swapped so the linear layer acts on what was dim 2
        (size ``message_dim``), then swapped back. If ``message_band_size``
        differs from ``n_fft // 2 + 1`` the end of dim 2 is padded by the
        difference (``F.pad`` crops instead when the difference is negative).

        Args:
            msg: 4-D tensor whose dim 2 has size ``message_dim``.

        Returns:
            Tensor with dim 2 of size ``n_fft // 2 + 1``.
        """
        output = self.linear(msg.transpose(2, 3)).transpose(2, 3)
        target_bins = self.n_fft // 2 + 1
        if self.message_band_size != target_bins:
            extra = target_bins - self.message_band_size
            # Pad tuple is (last-dim left, last-dim right, dim-2 start, dim-2 end).
            output = torch.nn.functional.pad(output, (0, 0, 0, extra))
        return output


class CarrierDecoder(nn.Module):
    """Gated conv stack producing a single-channel, band-limited output.

    Args:
        config: Object exposing the boolean attributes
            ``ensure_negative_message`` and ``no_normalization``.
        conv_dim: Number of input channels.
        n_layers: Total number of layers (``n_layers - 2`` hidden ones; a
            first and a last layer are always created).
        message_band_size: Index along dim 2 from which the output is zeroed.
    """

    def __init__(self, config, conv_dim, n_layers=4, message_band_size=1024):
        super().__init__()
        self.config = config
        self.message_band_size = message_band_size

        blocks = [_conv3x3_layer(conv_dim, 96)]
        blocks.extend(_conv3x3_layer(96, 96) for _ in range(n_layers - 2))
        # The final layer is a 1x1 projection down to one channel.
        blocks.append(Layer(dim_in=96, dim_out=1, kernel_size=1, stride=1, padding=0))
        self.main = nn.Sequential(*blocks)

    def forward(self, x, message_sdr):
        """Decode ``x``, zero everything past the band and normalise.

        Steps:
            1. Run the conv stack.
            2. If ``config.ensure_negative_message`` is set, take the absolute
               value (despite the flag's name, this makes the output
               non-negative).
            3. Zero all entries at index ``message_band_size`` and above on
               dim 2.
            4. Unless ``config.no_normalization`` is set, divide by the RMS
               over dim 2 and by ``10 ** (message_sdr / 20)``.

        Args:
            x: 4-D input tensor with ``conv_dim`` channels.
            message_sdr: Number (or tensor) used as a 20*log10-scale
                attenuation of the normalised output.

        Returns:
            4-D tensor with one channel.
        """
        h = self.main(x)
        if self.config.ensure_negative_message:
            h = torch.abs(h)
        h[:, :, self.message_band_size:, :] = 0
        if not self.config.no_normalization:
            mean_sq = torch.mean(h ** 2, dim=2, keepdim=True)
            # A column that is all zero along dim 2 would yield 0/0 = NaN
            # (and a NaN gradient through the square root). Clamping to the
            # smallest positive normal value leaves every other column
            # untouched and maps all-zero columns to zero.
            mean_sq = torch.clamp(mean_sq, min=torch.finfo(h.dtype).tiny)
            h = h / mean_sq ** 0.5 / (10 ** (message_sdr / 20))
        return h


class MsgDecoder(nn.Module):
    """Gated conv stack that reduces the band axis to one value per step.

    Args:
        message_dim: Output channels of the final conv layer.
        message_band_size: Number of leading dim-2 entries of the input that
            are used; also the input size of the final linear layer. Required.
        channel_dim: Hidden channel width.
        num_layers: Total number of conv layers (``num_layers - 2`` hidden
            ones; a first and a last layer are always created).

    Raises:
        AssertionError: If ``message_band_size`` is ``None``.
    """

    def __init__(self, message_dim=0, message_band_size=None, channel_dim=128,
                 num_layers=10):
        super().__init__()
        assert message_band_size is not None, "message_band_size is required"
        self.message_band_size = message_band_size

        stages = [_conv3x3_layer(1, channel_dim)]
        stages.extend(
            _conv3x3_layer(channel_dim, channel_dim) for _ in range(num_layers - 2)
        )
        stages.append(_conv3x3_layer(channel_dim, message_dim))

        # Each stage is preceded by a p=0 dropout. They are no-ops but are
        # kept so the Sequential indices (and thus state_dict keys) do not
        # change.
        modules = []
        for stage in stages:
            modules += [nn.Dropout(0), stage]
        self.main = nn.Sequential(*modules)
        self.linear = nn.Linear(self.message_band_size, 1)

    def forward(self, x):
        """Decode a message from the first ``message_band_size`` rows of ``x``.

        Args:
            x: 4-D single-channel tensor; only ``x[:, :, :message_band_size]``
                is used.

        Returns:
            Tensor of shape ``(batch, 1, message_dim, x.shape[3])``.
        """
        h = self.main(x[:, :, :self.message_band_size])
        # Move dim 2 last so the linear layer collapses it to a single value.
        h = self.linear(h.transpose(2, 3))
        return h.squeeze(3).unsqueeze(1)