# Upload metadata (page residue, kept for provenance):
#   uploader: ekwek — "Upload 10 files" — commit 63d4ab6 (verified)
import torch
from torch import nn
from .models import VocosBackbone
from .heads import ISTFTHead
class SopranoDecoder(nn.Module):
    """Upsample features in time, then decode them with a Vocos backbone and an ISTFT head.

    ``forward`` linearly interpolates its input along dim 2 so that ``T``
    frames become ``upscale * (T - 1) + 1`` frames. It passes the result
    through ``VocosBackbone`` and returns whatever ``ISTFTHead`` produces.

    Args:
        num_input_channels: Forwarded to the backbone as ``input_channels``.
        decoder_num_layers: Forwarded to the backbone as ``num_layers``.
        decoder_dim: Backbone ``dim``; also used as the head's ``dim``.
        decoder_intermediate_dim: Backbone ``intermediate_dim``. Any falsy
            value (``None`` *or* ``0``) falls back to ``3 * decoder_dim``.
        hop_length: Forwarded to the ISTFT head.
        n_fft: Forwarded to the ISTFT head.
        upscale: Temporal upsampling factor applied in ``forward``.
        dw_kernel: Used as both ``input_kernel_size`` and ``dw_kernel_size``
            of the backbone.
    """

    def __init__(self,
                 num_input_channels=512,
                 decoder_num_layers=8,
                 decoder_dim=512,
                 decoder_intermediate_dim=None,
                 hop_length=512,
                 n_fft=2048,
                 upscale=4,
                 dw_kernel=3,
                 ):
        super().__init__()

        # Hyperparameters, exposed as attributes.
        self.decoder_initial_channels = num_input_channels
        self.num_layers = decoder_num_layers
        self.dim = decoder_dim
        # `or` is a truthiness test, so both None and 0 pick the 3x default.
        self.intermediate_dim = decoder_intermediate_dim or decoder_dim * 3
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.upscale = upscale
        self.dw_kernel = dw_kernel

        # Register the backbone before the head. This order determines
        # parameter / state_dict ordering, so keep it.
        backbone_kwargs = dict(
            input_channels=self.decoder_initial_channels,
            dim=self.dim,
            intermediate_dim=self.intermediate_dim,
            num_layers=self.num_layers,
            input_kernel_size=dw_kernel,
            dw_kernel_size=dw_kernel,
        )
        self.decoder = VocosBackbone(**backbone_kwargs)
        self.head = ISTFTHead(dim=self.dim, n_fft=self.n_fft, hop_length=self.hop_length)

    def forward(self, x):
        """Upsample ``x`` along dim 2, run the backbone, then the head.

        Args:
            x: Tensor whose dim 2 is treated as time. ``mode='linear'``
                interpolation only accepts 3-D input, so this is presumably
                ``(batch, num_input_channels, frames)``. TODO: confirm.

        Returns:
            The output of ``self.head`` applied to the backbone output.
        """
        frames = x.shape[2]
        # With align_corners=True and upscale*(T-1)+1 output points, output
        # index j samples input position j/upscale. For an integer upscale,
        # every upscale-th output lands exactly on an original frame.
        target_frames = self.upscale * (frames - 1) + 1
        upsampled = nn.functional.interpolate(
            x, size=target_frames, mode='linear', align_corners=True
        )
        return self.head(self.decoder(upsampled))