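# NOTE: "Spaces: Runtime error" -- status text captured together with this file
# (apparently a hosting-page banner); it is not part of the module itself.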
import torch
from torch import nn

from .heads import ISTFTHead
from .models import VocosBackbone
class SopranoDecoder(nn.Module):
    """Decoder that upsamples its input in time, then runs a Vocos backbone and an ISTFT head.

    The forward pass linearly interpolates the input along its last axis,
    feeds the result to a ``VocosBackbone`` and passes the backbone output to
    an ``ISTFTHead``.

    Args:
        num_input_channels: Forwarded to ``VocosBackbone`` as ``input_channels``.
        decoder_num_layers: Forwarded to ``VocosBackbone`` as ``num_layers``.
        decoder_dim: Forwarded as ``dim`` to both the backbone and the head.
        decoder_intermediate_dim: Forwarded to ``VocosBackbone`` as
            ``intermediate_dim``. Any falsy value (``None`` or ``0``) selects
            the fallback ``decoder_dim * 3``.
        hop_length: Forwarded to ``ISTFTHead`` as ``hop_length``.
        n_fft: Forwarded to ``ISTFTHead`` as ``n_fft``.
        upscale: Temporal upsampling factor applied in ``forward``.
        dw_kernel: Used for both ``input_kernel_size`` and ``dw_kernel_size``
            of the backbone.
    """

    def __init__(
        self,
        num_input_channels=512,
        decoder_num_layers=8,
        decoder_dim=512,
        decoder_intermediate_dim=None,
        hop_length=512,
        n_fft=2048,
        upscale=4,
        dw_kernel=3,
    ):
        super().__init__()

        # Keep the resolved hyper-parameters on the instance for inspection.
        self.decoder_initial_channels = num_input_channels
        self.num_layers = decoder_num_layers
        self.dim = decoder_dim
        # Truthiness (not an ``is None`` test) is intentional: it keeps the
        # historical behaviour where 0 also falls back to three times the width.
        self.intermediate_dim = decoder_intermediate_dim or decoder_dim * 3
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.upscale = upscale
        self.dw_kernel = dw_kernel

        self.decoder = VocosBackbone(
            input_channels=self.decoder_initial_channels,
            dim=self.dim,
            intermediate_dim=self.intermediate_dim,
            num_layers=self.num_layers,
            input_kernel_size=dw_kernel,
            dw_kernel_size=dw_kernel,
        )
        self.head = ISTFTHead(
            dim=self.dim,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
        )

    def forward(self, x: torch.Tensor):
        """Upsample ``x`` along its last axis and decode it.

        Args:
            x: 3-D tensor whose axis 2 is the one being upsampled
                (``mode='linear'`` interpolation operates on 3-D input).
                Presumably laid out as (batch, channels, time) with
                ``channels == num_input_channels`` -- TODO confirm against
                ``VocosBackbone``.

        Returns:
            Whatever ``ISTFTHead`` returns for the backbone output
            (presumably the reconstructed signal -- confirm against the head).
        """
        num_frames = x.size(2)
        # With align_corners=True the first and last input frames map exactly
        # onto the first and last output frames, so T frames become
        # upscale * (T - 1) + 1 frames.
        target_frames = self.upscale * (num_frames - 1) + 1
        x = torch.nn.functional.interpolate(
            x,
            size=target_frames,
            mode='linear',
            align_corners=True,
        )
        x = self.decoder(x)
        reconstructed = self.head(x)
        return reconstructed