Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from .common.batch_norm_conv import BatchNormConv | |
| from .common.highway_network import HighwayNetwork | |
class CBHG(nn.Module):
    """CBHG module: conv bank -> max-pool -> conv projections (+ residual)
    -> highway stack -> bidirectional GRU.

    Args:
        K: Number of convolutions in the bank; kernel sizes are ``1..K``
            inclusive. Must be >= 1.
        in_channels: Channel count of the input fed to every bank convolution.
        channels: Output channels of each bank convolution; also the width of
            the highway layers and of the GRU input (the GRU uses
            ``channels // 2`` hidden units per direction).
        proj_channels: Sequence whose first two entries are the output
            channels of the two projection convolutions. The output of the
            second projection is added to the input as a residual, so it
            presumably has to equal ``in_channels`` -- TODO confirm.
        num_highways: Number of ``HighwayNetwork`` layers to stack.
    """

    def __init__(self, K, in_channels, channels, proj_channels, num_highways):
        super().__init__()
        if K < 1:
            # An empty bank would only fail later inside ``torch.cat`` in
            # ``forward``; reject it up front with a clear message.
            raise ValueError(f"K must be >= 1, got {K}")

        # List of all RNNs to call `flatten_parameters()` on.
        self._to_flatten = []

        # Convolution bank with kernel sizes 1..K. Modules are created in the
        # same order as before, so random weight initialisation is unchanged.
        self.bank_kernels = list(range(1, K + 1))
        self.conv1d_bank = nn.ModuleList(
            BatchNormConv(in_channels, channels, k) for k in self.bank_kernels
        )

        # kernel_size=2, stride=1, padding=1 yields one extra time step;
        # forward() trims it back to the input length.
        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

        # Width actually produced by conv_project2; used consistently below.
        proj_out = proj_channels[1]
        self.conv_project1 = BatchNormConv(
            len(self.bank_kernels) * channels, proj_channels[0], 3
        )
        self.conv_project2 = BatchNormConv(proj_channels[0], proj_out, 3, relu=False)

        # The highways operate on `channels` features; insert a bias-free
        # linear adapter when the projection output width differs.
        self.highway_mismatch = proj_out != channels
        if self.highway_mismatch:
            self.pre_highway = nn.Linear(proj_out, channels, bias=False)

        self.highways = nn.ModuleList(
            HighwayNetwork(channels) for _ in range(num_highways)
        )

        self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
        self._to_flatten.append(self.rnn)

        # Avoid fragmentation of RNN parameters and the associated warning.
        self._flatten_parameters()

    def forward(self, x):
        """Run the full CBHG stack on ``x``.

        Args:
            x: Input tensor. The last dim is treated as the sequence length
                and dims 1/2 are swapped before the batch-first GRU, so it is
                presumably (batch, in_channels, time) -- TODO confirm.

        Returns:
            The GRU output sequence; under the assumption above its shape is
            (batch, time, 2 * (channels // 2)).
        """
        # Although we flatten on init, DataParallel replicates the model, so
        # the weights are no longer guaranteed to be contiguous in GPU memory.
        # Hence, we must flatten again on every call.
        self._flatten_parameters()

        # Saved for the residual connection and for trimming.
        residual = x
        seq_len = x.size(-1)

        # Convolution bank: trim each output to the input length and stack
        # the results along the channel axis.
        conv_bank = torch.cat(
            [conv(x)[:, :, :seq_len] for conv in self.conv1d_bank], dim=1
        )

        # Drop the extra step added by the max-pool padding so the residual
        # lines up.
        x = self.maxpool(conv_bank)[:, :, :seq_len]

        # Conv1d projections followed by the residual connection.
        x = self.conv_project1(x)
        x = self.conv_project2(x)
        x = x + residual

        # Highways work on the last dim, so move channels there first.
        x = x.transpose(1, 2)
        if self.highway_mismatch:
            x = self.pre_highway(x)
        for highway in self.highways:
            x = highway(x)

        # Bidirectional GRU; the final hidden state is discarded.
        x, _ = self.rnn(x)
        return x

    def _flatten_parameters(self):
        """Call ``flatten_parameters()`` on every RNN listed in ``_to_flatten``.

        Re-packs the RNN weights so cuDNN can use its fast path, which also
        avoids PyTorch's non-contiguous-weights warning. This is a no-op when
        not running on GPU with cuDNN.
        """
        for rnn in self._to_flatten:
            rnn.flatten_parameters()