Spaces:
Build error
Build error
| import numpy as np | |
| import random | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from modules.audio2motion.cnn_models import LambdaLayer | |
class Discriminator1DFactory(nn.Module):
    """1D CNN discriminator over a fixed-length sequence.

    forward takes [B, C, T] and returns a validity score [B, 1] together
    with the list of intermediate feature maps (usable for
    feature-matching losses).
    """

    def __init__(self, time_length, kernel_size=3, in_dim=1, hidden_size=128, norm_type='bn'):
        super(Discriminator1DFactory, self).__init__()
        padding = kernel_size // 2

        def discriminator_block(in_filters, out_filters, first=False):
            """Strided conv block halving the time axis: (B, c, T) -> (B, c', T//2)."""
            layers = [
                nn.Conv1d(in_filters, out_filters, kernel_size, 2, padding),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout2d(0.25),
            ]
            if not first:
                # NOTE(review): 0.8 lands in BatchNorm1d's positional `eps` slot
                # (classic DCGAN-snippet inheritance) — confirm it was not meant
                # to be `momentum`.
                if norm_type == 'bn':
                    layers.append(nn.BatchNorm1d(out_filters, 0.8))
                if norm_type == 'in':
                    layers.append(nn.InstanceNorm1d(out_filters, affine=True))
            return nn.Sequential(*layers)

        if time_length >= 8:
            # Three stride-2 convs: T -> T // 8.
            self.model = nn.ModuleList([
                discriminator_block(in_dim, hidden_size, first=True),
                discriminator_block(hidden_size, hidden_size),
                discriminator_block(hidden_size, hidden_size),
            ])
            ds_size = time_length // (2 ** 3)
        elif time_length == 3:
            # One valid (no-pad) k=3 conv collapses T=3 to 1, then 1x1 convs.
            self.model = nn.ModuleList([
                nn.Sequential(*[
                    nn.Conv1d(in_dim, hidden_size, 3, 1, 0),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.Dropout2d(0.25),
                    nn.Conv1d(hidden_size, hidden_size, 1, 1, 0),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.Dropout2d(0.25),
                    nn.BatchNorm1d(hidden_size, 0.8),
                    nn.Conv1d(hidden_size, hidden_size, 1, 1, 0),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.Dropout2d(0.25),
                    nn.BatchNorm1d(hidden_size, 0.8)
                ])
            ])
            ds_size = 1
        elif time_length == 1:
            # Degenerate single-frame case: operate on [B, C] with Linear layers.
            self.model = nn.ModuleList([
                nn.Sequential(*[
                    nn.Linear(in_dim, hidden_size),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.Dropout2d(0.25),
                    nn.Linear(hidden_size, hidden_size),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.Dropout2d(0.25),
                ])
            ])
            ds_size = 1
        # NOTE(review): time_length in {2, 4, 5, 6, 7} leaves ds_size (and
        # self.model) undefined and raises NameError below — presumably those
        # window sizes are never used; verify against callers.
        self.adv_layer = nn.Linear(hidden_size * ds_size, 1)

    def forward(self, x):
        """
        :param x: [B, C, T]
        :return: validity: [B, 1], h: list of hidden feature maps
        """
        h = []
        if x.shape[-1] == 1:
            # time_length == 1 path: drop the time axis so Linear layers apply.
            x = x.squeeze(-1)
        for layer in self.model:
            x = layer(x)
            h.append(x)
        if x.ndim == 2:
            # Linear path: already flat; this branch also applies a sigmoid.
            batch, feat = x.shape
            use_sigmoid = True
        else:
            batch, channels, steps = x.shape
            feat = channels * steps
            use_sigmoid = False
        validity = self.adv_layer(x.view(batch, feat))  # [B, 1]
        if use_sigmoid:
            validity = torch.sigmoid(validity)
        return validity, h
class CosineDiscriminator1DFactory(nn.Module):
    """Two-tower 1D CNN discriminator scored by cosine similarity.

    Each tower downsamples its input [B, C, T] by 2**3 along time; the
    flattened, ReLU-ed embeddings are compared with cosine similarity, so
    the validity score lies in [-1, 1].
    """

    def __init__(self, time_length, kernel_size=3, in_dim=1, hidden_size=128, norm_type='bn'):
        super().__init__()
        padding = kernel_size // 2

        def discriminator_block(in_filters, out_filters, first=False):
            """Strided conv block halving the time axis: (B, c, T) -> (B, c', T//2)."""
            conv = nn.Conv1d(in_filters, out_filters, kernel_size, 2, padding)
            block = [
                conv,  # padding = kernel//2
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout2d(0.25)
            ]
            # NOTE(review): 0.8 lands in BatchNorm1d's positional `eps` slot —
            # confirm it was not meant to be `momentum`.
            if norm_type == 'bn' and not first:
                block.append(nn.BatchNorm1d(out_filters, 0.8))
            if norm_type == 'in' and not first:
                block.append(nn.InstanceNorm1d(out_filters, affine=True))
            block = nn.Sequential(*block)
            return block

        # Separate (non-shared) towers for the two inputs.
        self.model1 = nn.ModuleList([
            discriminator_block(in_dim, hidden_size, first=True),
            discriminator_block(hidden_size, hidden_size),
            discriminator_block(hidden_size, hidden_size),
        ])
        self.model2 = nn.ModuleList([
            discriminator_block(in_dim, hidden_size, first=True),
            discriminator_block(hidden_size, hidden_size),
            discriminator_block(hidden_size, hidden_size),
        ])
        self.relu = nn.ReLU()

    def forward(self, x1, x2):
        """
        :param x1: [B, C, T]
        :param x2: [B, C, T] (assumed same shape as x1 — the flatten below
            reuses x1's shape for both; TODO confirm)
        :return: validity: [B], h: [h1, h2] lists of per-tower hiddens
        """
        h1, h2 = [], []
        for l in self.model1:
            x1 = l(x1)
            h1.append(x1)
        for l in self.model2:
            x2 = l(x2)
            # BUGFIX: was `h2.append(x1)`, which stored model1's final output
            # instead of model2's hiddens, corrupting feature-matching on h2.
            h2.append(x2)
        b, c, t = x1.shape
        x1 = x1.view(b, c * t)
        x2 = x2.view(b, c * t)
        x1 = self.relu(x1)
        x2 = self.relu(x2)
        # x1 = F.normalize(x1, p=2, dim=1)
        # x2 = F.normalize(x2, p=2, dim=1)
        validity = F.cosine_similarity(x1, x2)
        return validity, [h1, h2]
class MultiWindowDiscriminator(nn.Module):
    """Runs a per-window 1D discriminator on several randomly clipped windows.

    One Discriminator1DFactory (or CosineDiscriminator1DFactory) is built per
    entry of `time_lengths`; forward clips the input to each window length,
    scores each clip, and reduces the scores by sum or stack.
    """

    def __init__(self, time_lengths, cond_dim=80, in_dim=64, kernel_size=3, hidden_size=128, disc_type='standard', norm_type='bn', reduction='sum'):
        super(MultiWindowDiscriminator, self).__init__()
        self.win_lengths = time_lengths
        self.reduction = reduction
        self.disc_type = disc_type
        if cond_dim > 0:
            # Conditional mode: both input and condition are projected to
            # 64-dim per window before being fed to the discriminators.
            self.use_cond = True
            self.cond_proj_layers = nn.ModuleList()
            self.in_proj_layers = nn.ModuleList()
        else:
            self.use_cond = False

        self.conv_layers = nn.ModuleList()
        for time_length in time_lengths:
            # NOTE(review): in_dim is hard-coded to 64 here to match the
            # 64-dim projections below; when use_cond is False the raw input
            # (in_dim channels) is fed directly, so this presumably requires
            # in_dim == 64 in that case — confirm against callers.
            conv_layer = [
                Discriminator1DFactory(
                    time_length, kernel_size, in_dim=64, hidden_size=hidden_size,
                    norm_type=norm_type) if self.disc_type == 'standard'
                else CosineDiscriminator1DFactory(time_length, kernel_size, in_dim=64,
                                                  hidden_size=hidden_size, norm_type=norm_type)
            ]
            self.conv_layers += conv_layer
            if self.use_cond:
                self.cond_proj_layers.append(nn.Linear(cond_dim, 64))
                self.in_proj_layers.append(nn.Linear(in_dim, 64))

    def clip(self, x, cond, x_len, win_length, start_frames=None):
        '''Randomly clip x (and cond) to win_length frames along time.

        Args:
            x (tensor) : (B, T, C).
            cond (tensor) : (B, T, H), or None.
            x_len (tensor) : (B,) valid lengths per sample.
            win_length (int): target clip length.
            start_frames (list[int] or None): if given, reuse these start
                positions (one per batch item) instead of sampling new ones.
        Returns:
            (x_batch, c_batch, start_frames); (None, None, start_frames) when
            the shortest sample is too short for win_length.
        '''
        clip_from_same_frame = start_frames is None
        T_start = 0
        # T_end = x_len.max() - win_length
        # Use the batch minimum so every sample can supply a full window.
        T_end = x_len.min() - win_length
        if T_end < 0:
            return None, None, start_frames
        T_end = T_end.item()
        if start_frames is None:
            # Single shared random start for the whole batch.
            start_frame = np.random.randint(low=T_start, high=T_end + 1)
            start_frames = [start_frame] * x.size(0)
        else:
            start_frame = start_frames[0]

        if clip_from_same_frame:
            x_batch = x[:, start_frame: start_frame + win_length, :]
            c_batch = cond[:, start_frame: start_frame + win_length, :] if cond is not None else None
        else:
            # Per-sample start frames (replayed from a previous call).
            x_lst = []
            c_lst = []
            for i, start_frame in enumerate(start_frames):
                x_lst.append(x[i, start_frame: start_frame + win_length, :])
                if cond is not None:
                    c_lst.append(cond[i, start_frame: start_frame + win_length, :])
            x_batch = torch.stack(x_lst, dim=0)
            if cond is None:
                c_batch = None
            else:
                c_batch = torch.stack(c_lst, dim=0)
        return x_batch, c_batch, start_frames

    def forward(self, x, x_len, cond=None, start_frames_wins=None):
        '''Score every window length and reduce.

        Args:
            x (tensor): input sequence, (B, T, C).
            x_len (tensor): valid length per sample, (B,).
            cond (tensor): optional condition, (B, T, H).
            start_frames_wins (list or None): per-window start frames to
                replay; mutated in place with the frames actually used.
        Returns:
            (validity, start_frames_wins, h); validity is None when any
            window could not be clipped (input too short).
        '''
        validity = []
        if start_frames_wins is None:
            start_frames_wins = [None] * len(self.conv_layers)
        h = []
        for i, start_frames in zip(range(len(self.conv_layers)), start_frames_wins):
            x_clip, c_clip, start_frames = self.clip(
                x, cond, x_len, self.win_lengths[i], start_frames)  # (B, win_length, C)
            start_frames_wins[i] = start_frames
            if x_clip is None:
                # Input shorter than this window; skip (validity becomes None below).
                continue
            if self.disc_type == 'standard':
                if self.use_cond:
                    # Fuse condition by addition in the shared 64-dim space.
                    x_clip = self.in_proj_layers[i](x_clip)  # (B, T, C)
                    c_clip = self.cond_proj_layers[i](c_clip)
                    x_clip = x_clip + c_clip
                validity_pred, h_ = self.conv_layers[i](x_clip.transpose(1, 2))
            elif self.disc_type == 'cosine':
                assert self.use_cond is True
                x_clip = self.in_proj_layers[i](x_clip)  # (B, T, C)
                c_clip = self.cond_proj_layers[i](c_clip)
                validity_pred, h_ = self.conv_layers[i](x_clip.transpose(1, 2), c_clip.transpose(1, 2))
            else:
                raise NotImplementedError
            # NOTE(review): for 'standard' h_ is a flat list of tensors, for
            # 'cosine' it is [h1, h2] (a list of two lists), so h is
            # heterogeneous across disc types — downstream must handle both.
            h += h_
            validity.append(validity_pred)
        if len(validity) != len(self.conv_layers):
            # At least one window was skipped: signal failure with None.
            return None, start_frames_wins, h
        if self.reduction == 'sum':
            validity = sum(validity)  # [B]
        elif self.reduction == 'stack':
            validity = torch.stack(validity, -1)  # [B, W_L]
        return validity, start_frames_wins, h
class Discriminator(nn.Module):
    def __init__(self, x_dim=80, y_dim=64, disc_type='standard',
                 uncond_disc=False, kernel_size=3, hidden_size=128, norm_type='bn', reduction='sum', time_lengths=(8, 16, 32)):
        """Audio-conditioned multi-window discriminator for facial motion.

        Args:
            x_dim (int, optional): dim of audio features. Defaults to 80, corresponding to mel-spec.
            y_dim (int, optional): dim of facial coeff. Defaults to 64, correspond to exp; other options can be 7 (pose) or 71 (exp+pose).
            disc_type (str, optional): 'standard' or 'cosine'. Defaults to 'standard'.
            uncond_disc (bool, optional): if True, no mel-condition branch is built. Defaults to False.
            kernel_size (int, optional): conv kernel size. Defaults to 3.
            hidden_size (int, optional): discriminator hidden dim. Defaults to 128.
            norm_type (str, optional): 'bn' or 'in'. Defaults to 'bn'.
            reduction (str, optional): 'sum' or 'stack' over windows. Defaults to 'sum'.
            time_lengths (tuple, optional): window sizes. Defaults to (8, 16, 32).
        """
        super(Discriminator, self).__init__()
        self.time_lengths = time_lengths
        self.x_dim, self.y_dim = x_dim, y_dim
        self.disc_type = disc_type
        self.reduction = reduction
        self.uncond_disc = uncond_disc

        if uncond_disc:
            # Unconditional: no mel encoder at all.
            self.x_dim = 0
            cond_dim = 0
        else:
            cond_dim = 64
            self.mel_encoder = nn.Sequential(*[
                nn.Conv1d(self.x_dim, 64, 3, 1, 1, bias=False),
                nn.BatchNorm1d(64),
                nn.GELU(),
                nn.Conv1d(64, cond_dim, 3, 1, 1, bias=False)
            ])

        self.disc = MultiWindowDiscriminator(
            time_lengths=self.time_lengths,
            in_dim=self.y_dim,
            cond_dim=cond_dim,
            kernel_size=kernel_size,
            hidden_size=hidden_size, norm_type=norm_type,
            reduction=reduction,
            disc_type=disc_type
        )
        # Halve the mel frame rate so it lines up with the motion sequence.
        self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1, 2), scale_factor=0.5, mode='nearest').transpose(1, 2))

    @property
    def device(self):
        """Device of the discriminator's parameters.

        BUGFIX: this was a plain method, but forward reads it as an
        attribute (`x.to(self.device)`), which would pass a bound method to
        Tensor.to and raise; it must be a property.
        """
        return next(self.disc.parameters()).device

    def forward(self, x, batch, start_frames_wins=None):
        """
        :param x: [B, T, C] motion-coefficient sequence
        :param batch: dict providing 'mel' ([B, T_mel, x_dim]) when conditional
        :param start_frames_wins: optional per-window clip starts to replay
        :return: disc confidence; [B] for reduction='sum', [B, n_windows] for 'stack'
        """
        x = x.to(self.device)
        if not self.uncond_disc:
            mel = self.downsampler(batch['mel'].to(self.device))
            mel_feat = self.mel_encoder(mel.transpose(1, 2)).transpose(1, 2)
        else:
            mel_feat = None
        # Valid length per sample: frames whose feature-sum is nonzero
        # (all-zero rows are treated as padding).
        x_len = x.sum(-1).ne(0).int().sum([1])
        disc_confidence, start_frames_wins, h = self.disc(x, x_len, mel_feat, start_frames_wins=start_frames_wins)
        return disc_confidence