File size: 1,710 Bytes
63d4ab6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
from torch import nn

from .models import VocosBackbone
from .heads import ISTFTHead


class SopranoDecoder(nn.Module):
    """Upsample a feature sequence, then decode it with a backbone and an ISTFT head.

    The forward pass linearly interpolates the input along dimension 2,
    feeds the result to a ``VocosBackbone`` and passes the backbone output
    to an ``ISTFTHead``.

    Args:
        num_input_channels: Forwarded to the backbone as ``input_channels``.
        decoder_num_layers: Forwarded to the backbone as ``num_layers``.
        decoder_dim: Forwarded as ``dim`` to both the backbone and the head.
        decoder_intermediate_dim: Forwarded to the backbone as
            ``intermediate_dim``; any falsy value (``None``, ``0``) is
            replaced by ``decoder_dim * 3``.
        hop_length: Forwarded to the head.
        n_fft: Forwarded to the head.
        upscale: Interpolation factor used in ``forward``.
        dw_kernel: Forwarded to the backbone as both ``input_kernel_size``
            and ``dw_kernel_size``.
    """

    def __init__(
        self,
        num_input_channels=512,
        decoder_num_layers=8,
        decoder_dim=512,
        decoder_intermediate_dim=None,
        hop_length=512,
        n_fft=2048,
        upscale=4,
        dw_kernel=3,
    ):
        super().__init__()

        # Keep every hyperparameter on the instance.
        self.decoder_initial_channels = num_input_channels
        self.num_layers = decoder_num_layers
        self.dim = decoder_dim
        # Falsy width -> default to three times the model width.
        self.intermediate_dim = decoder_intermediate_dim or decoder_dim * 3
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.upscale = upscale
        self.dw_kernel = dw_kernel

        # Submodules are registered in this order: backbone first, head second.
        backbone_kwargs = dict(
            input_channels=self.decoder_initial_channels,
            dim=self.dim,
            intermediate_dim=self.intermediate_dim,
            num_layers=self.num_layers,
            input_kernel_size=self.dw_kernel,
            dw_kernel_size=self.dw_kernel,
        )
        self.decoder = VocosBackbone(**backbone_kwargs)

        head_kwargs = dict(
            dim=self.dim,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
        )
        self.head = ISTFTHead(**head_kwargs)

    def forward(self, x):
        """Interpolate ``x`` along dimension 2 and decode it.

        Args:
            x: 3-D tensor (``mode='linear'`` interpolation requires 3-D
                input); presumably ``(batch, channels, frames)`` — TODO
                confirm the axis meaning against the caller.

        Returns:
            Whatever ``self.head`` returns for the backbone output.
        """
        frames_in = x.shape[2]
        # With align_corners=True the first and last input samples map to the
        # first and last output samples, hence (frames_in - 1) intervals,
        # each stretched by ``upscale``, plus the final endpoint.
        frames_out = self.upscale * (frames_in - 1) + 1
        upsampled = nn.functional.interpolate(
            x,
            size=frames_out,
            mode='linear',
            align_corners=True,
        )
        hidden = self.decoder(upsampled)
        return self.head(hidden)