import json
import os
import math
import logging
from pathlib import Path
from datetime import datetime
from multiprocessing import cpu_count
from random import random
from functools import partial
from collections import namedtuple

import torch
import torch.nn.functional as F
from torch import expm1, nn, einsum
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from matplotlib import pyplot as plt
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange
from tqdm.auto import tqdm
from accelerate import Accelerator, DistributedDataParallelKwargs
from vocos import Vocos

import ttts.diffusion.commons as commons
from ttts.diffusion.operations import OPERATIONS_ENCODER
from ttts.unet1d.embeddings import TextTimeEmbedding
from ttts.unet1d.unet_1d_condition import UNet1DConditionModel

TACOTRON_MEL_MAX = 5.5451774444795624753378569716654
TACOTRON_MEL_MIN = -16.118095650958319788125940182791
# TACOTRON_MEL_MIN = -11.512925464970228420089957273422


def denormalize_tacotron_mel(norm_mel):
    return ((norm_mel + 1) / 2) * (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) + TACOTRON_MEL_MIN


def normalize_tacotron_mel(mel):
    return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1

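# Note: these helpers map log-mel values from the empirical Tacotron range
# [TACOTRON_MEL_MIN, TACOTRON_MEL_MAX] onto [-1, 1] and back, e.g.
#   normalize_tacotron_mel(torch.tensor(TACOTRON_MEL_MIN)) -> -1.0
#   normalize_tacotron_mel(torch.tensor(TACOTRON_MEL_MAX)) ->  1.0
# so denormalize_tacotron_mel(normalize_tacotron_mel(mel)) recovers mel for in-range inputs.
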
def exists(x):
    return x is not None


def cycle(dl):
    while True:
        for data in dl:
            yield data


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

class TransformerEncoderLayer(nn.Module):
    def __init__(self, layer, hidden_size, dropout):
        super().__init__()
        self.layer = layer
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.op = OPERATIONS_ENCODER[layer](hidden_size, dropout)

    def forward(self, x, **kwargs):
        return self.op(x, **kwargs)


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)

class ConvTBC(nn.Module):
    """1D convolution over inputs in TBC (time, batch, channel) layout."""
    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super(ConvTBC, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.weight = torch.nn.Parameter(torch.Tensor(
            self.kernel_size, in_channels, out_channels))
        self.bias = torch.nn.Parameter(torch.Tensor(out_channels))

    def forward(self, input):
        return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding)


class ConvLayer(nn.Module):
    def __init__(self, c_in, c_out, kernel_size, dropout=0):
        super().__init__()
        self.layer_norm = LayerNorm(c_in)
        conv = ConvTBC(c_in, c_out, kernel_size, padding=kernel_size // 2)
        std = math.sqrt((4 * (1.0 - dropout)) / (kernel_size * c_in))
        nn.init.normal_(conv.weight, mean=0, std=std)
        nn.init.constant_(conv.bias, 0)
        self.conv = conv

    def forward(self, x, encoder_padding_mask=None, **kwargs):
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            self.layer_norm.training = layer_norm_training
        if encoder_padding_mask is not None:
            # padding mask is (B, T); transpose to (T, B, 1) to match the TBC layout
            x = x.masked_fill(encoder_padding_mask.t().unsqueeze(-1), 0)
        x = self.layer_norm(x)
        x = self.conv(x)
        return x

class PhoneEncoder(nn.Module):
    def __init__(self,
                 in_channels=128,
                 hidden_channels=512,
                 out_channels=512,
                 n_layers=6,
                 p_dropout=0.2,
                 last_ln=True):
        super().__init__()
        self.arch = [8 for _ in range(n_layers)]
        self.num_layers = n_layers
        self.hidden_size = hidden_channels
        self.padding_idx = 0
        self.dropout = p_dropout
        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerEncoderLayer(self.arch[i], self.hidden_size, self.dropout)
            for i in range(self.num_layers)
        ])
        self.last_ln = last_ln
        # NOTE: spk_proj projects the 100-channel mel input (plus the reference
        # embedding g) to hidden_channels before self.pre, so the configured
        # in_channels is expected to equal hidden_channels.
        self.pre = ConvLayer(in_channels, hidden_channels, 1, p_dropout)
        # self.prompt_proj = ConvLayer(in_channels, hidden_channels, 1, p_dropout)
        self.out_proj = ConvLayer(hidden_channels, out_channels, 1, p_dropout)
        if last_ln:
            self.layer_norm = LayerNorm(out_channels)
        self.spk_proj = nn.Conv1d(100, hidden_channels, 1)

    def forward(self, src_tokens, lengths, g=None):
        # add the speaker/reference embedding, then B x C x T -> T x B x C
        src_tokens = self.spk_proj(src_tokens + g)
        src_tokens = rearrange(src_tokens, 'b c t -> t b c')
        # compute padding mask
        encoder_padding_mask = ~commons.sequence_mask(lengths, src_tokens.size(0)).to(torch.bool)
        # prompt_mask = ~commons.sequence_mask(prompt_lengths, prompt.size(0)).to(torch.bool)
        x = src_tokens
        x = self.pre(x, encoder_padding_mask=encoder_padding_mask)
        x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
        # prompt = self.prompt_proj(prompt, encoder_padding_mask=prompt_mask)
        # encoder layers
        for i in range(self.num_layers):
            x = self.layers[i](x, encoder_padding_mask=encoder_padding_mask)
            # x = x + self.attn_blocks[i](x, prompt, prompt, key_padding_mask=prompt_mask)[0]
        x = self.out_proj(x, encoder_padding_mask=encoder_padding_mask)
        if self.last_ln:
            x = self.layer_norm(x)
        x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
        x = rearrange(x, 't b c -> b c t')
        return x

class PromptEncoder(nn.Module):
    def __init__(self,
                 in_channels=128,
                 hidden_channels=256,
                 out_channels=512,
                 n_layers=6,
                 p_dropout=0.2,
                 last_ln=True):
        super().__init__()
        self.arch = [8 for _ in range(n_layers)]
        self.num_layers = n_layers
        self.hidden_size = hidden_channels
        self.padding_idx = 0
        self.dropout = p_dropout
        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerEncoderLayer(self.arch[i], self.hidden_size, self.dropout)
            for i in range(self.num_layers)
        ])
        self.last_ln = last_ln
        if last_ln:
            self.layer_norm = LayerNorm(out_channels)
        self.pre = ConvLayer(in_channels, hidden_channels, 1, p_dropout)
        self.out_proj = ConvLayer(hidden_channels, out_channels, 1, p_dropout)

    def forward(self, src_tokens, lengths=None):
        # B x C x T -> T x B x C
        src_tokens = rearrange(src_tokens, 'b c t -> t b c')
        # compute padding mask
        encoder_padding_mask = ~commons.sequence_mask(lengths, src_tokens.size(0)).to(torch.bool)
        x = src_tokens
        x = self.pre(x, encoder_padding_mask=encoder_padding_mask)
        x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask=encoder_padding_mask)
        x = self.out_proj(x, encoder_padding_mask=encoder_padding_mask)
        if self.last_ln:
            x = self.layer_norm(x)
        x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
        x = rearrange(x, 't b c -> b c t')
        return x

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


def silu(x):
    return x * torch.sigmoid(x)

class ResidualBlock(nn.Module):
    def __init__(self, n_mels, residual_channels, dilation, kernel_size, dropout):
        '''
        :param n_mels: inplanes of conv1x1 for spectrogram conditional
        :param residual_channels: audio conv
        :param dilation: audio conv dilation
        :param uncond: disable spectrogram conditional
        '''
        super().__init__()
        if dilation == 1:
            padding = kernel_size // 2
        else:
            padding = dilation
        self.dilated_conv = ConvLayer(residual_channels, 2 * residual_channels, kernel_size)
        self.conditioner_projection = ConvLayer(n_mels, 2 * residual_channels, 1)
        # self.output_projection = ConvLayer(residual_channels, 2 * residual_channels, 1)
        self.output_projection = ConvLayer(residual_channels, residual_channels, 1)
        self.t_proj = ConvLayer(residual_channels, residual_channels, 1)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, diffusion_step, conditioner, x_mask):
        assert (conditioner is None and self.conditioner_projection is None) or \
               (conditioner is not None and self.conditioner_projection is not None)
        # T B C
        y = x + self.t_proj(diffusion_step.unsqueeze(0))
        y = y.masked_fill(x_mask.t().unsqueeze(-1), 0)
        conditioner = self.conditioner_projection(conditioner)
        conditioner = self.drop(conditioner)
        y = self.dilated_conv(y) + conditioner
        y = y.masked_fill(x_mask.t().unsqueeze(-1), 0)
        gate, filter_ = torch.chunk(y, 2, dim=-1)
        y = torch.sigmoid(gate) * torch.tanh(filter_)
        y = y.masked_fill(x_mask.t().unsqueeze(-1), 0)
        y = self.output_projection(y)
        return y
        # y = y.masked_fill(x_mask.t().unsqueeze(-1), 0)
        # residual, skip = torch.chunk(y, 2, dim=-1)
        # return (x + residual) / math.sqrt(2.0), skip

class Pre_model(nn.Module):
    def __init__(self, cfg) -> None:
        super().__init__()
        self.cfg = cfg
        self.phoneme_encoder = PhoneEncoder(**self.cfg['phoneme_encoder'])
        print("phoneme params:", count_parameters(self.phoneme_encoder))
        self.prompt_encoder = PromptEncoder(**self.cfg['prompt_encoder'])
        print("prompt params:", count_parameters(self.prompt_encoder))
        dim = self.cfg['phoneme_encoder']['out_channels']
        self.ref_enc = TextTimeEmbedding(100, 100, 1)

    def forward(self, data, g=None):
        mel_recon_padded, mel_padded, mel_lengths, refer_padded, refer_lengths = data
        mel_recon_padded, refer_padded = normalize_tacotron_mel(mel_recon_padded), normalize_tacotron_mel(refer_padded)
        g = self.ref_enc(refer_padded.transpose(1, 2)).unsqueeze(-1)
        audio_prompt = self.prompt_encoder(refer_padded, refer_lengths)
        content = self.phoneme_encoder(mel_recon_padded, mel_lengths, g)
        return content, audio_prompt

    def infer(self, data):
        mel_recon_padded, refer_padded, mel_lengths, refer_lengths = data
        mel_recon_padded, refer_padded = normalize_tacotron_mel(mel_recon_padded), normalize_tacotron_mel(refer_padded)
        g = self.ref_enc(refer_padded.transpose(1, 2)).unsqueeze(-1)
        audio_prompt = self.prompt_encoder(refer_padded, refer_lengths)
        content = self.phoneme_encoder(mel_recon_padded, mel_lengths, g)
        return content, audio_prompt

class Diffusion_Encoder(nn.Module):
    def __init__(self,
                 in_channels=128,
                 out_channels=128,
                 hidden_channels=256,
                 block_out_channels=[128, 256, 384, 512],
                 n_heads=8,
                 p_dropout=0.2,
                 ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.n_heads = n_heads
        self.unet = UNet1DConditionModel(
            in_channels=in_channels + hidden_channels,
            out_channels=out_channels,
            block_out_channels=block_out_channels,
            norm_num_groups=8,
            cross_attention_dim=hidden_channels,
            attention_head_dim=n_heads,
            addition_embed_type='text',
            resnet_time_scale_shift='scale_shift',
        )

    def forward(self, x, data, t):
        assert not torch.isnan(x).any()
        contentvec, prompt, contentvec_lengths, prompt_lengths = data
        prompt = rearrange(prompt, 'b c t -> b t c')
        # concatenate the noisy mel with the content encoding along channels;
        # the prompt encoding is injected through cross-attention
        x = torch.cat([x, contentvec], dim=1)
        prompt_mask = commons.sequence_mask(prompt_lengths, prompt.size(1)).to(torch.bool)
        x = self.unet(x, t, prompt, encoder_attention_mask=prompt_mask)
        return x.sample

# tensor helper functions

def log(t, eps=1e-20):
    return torch.log(t.clamp(min=eps))


def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

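# Example: with a of shape (T,) holding per-timestep coefficients, t of shape (B,)
# holding integer timesteps, and x_shape == (B, C, N), extract(a, t, x_shape) returns
# a[t] reshaped to (B, 1, 1) so it broadcasts over a batch of (C, N) tensors.
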
def linear_beta_schedule(timesteps):
    """
    linear schedule, proposed in original ddpm paper
    """
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)

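# Note: the 1000 / timesteps scale keeps the schedule equivalent to the original DDPM
# settings (beta in [1e-4, 2e-2] at 1000 steps) when a different number of training
# timesteps is configured.
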
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d


ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])

class Diffuser(nn.Module):
    def __init__(self,
                 cfg,
                 ddim_sampling_eta=0,
                 min_snr_loss_weight=False,
                 min_snr_gamma=5,
                 conditioning_free=True,
                 conditioning_free_k=1.0
                 ):
        super().__init__()
        self.pre_model = Pre_model(cfg)
        print("pre params: ", count_parameters(self.pre_model))
        self.diff_model = Diffusion_Encoder(**cfg['diffusion'])
        print("diff params: ", count_parameters(self.diff_model))
        self.dim = self.diff_model.in_channels
        timesteps = cfg['train']['timesteps']
        beta_schedule_fn = linear_beta_schedule
        betas = beta_schedule_fn(timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)
        timesteps, = betas.shape
        self.num_timesteps = timesteps
        self.unconditioned_content = nn.Parameter(torch.randn(1, cfg['phoneme_encoder']['out_channels'], 1))
        # self.sampling_timesteps = cfg['train']['sampling_timesteps']
        self.ddim_sampling_eta = ddim_sampling_eta
        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        register_buffer('posterior_variance', posterior_variance)
        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
        snr = alphas_cumprod / (1 - alphas_cumprod)
        maybe_clipped_snr = snr.clone()
        if min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max=min_snr_gamma)
        register_buffer('loss_weight', maybe_clipped_snr)
        self.conditioning_free = conditioning_free
        self.conditioning_free_k = conditioning_free_k

    def predict_noise_from_start(self, x_t, t, x0):
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) /
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

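    # Derivation: since x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps, solving for
    # eps gives eps = (x_t / sqrt(abar_t) - x_0) / sqrt(1 / abar_t - 1), which is the
    # expression above written with the precomputed buffers.
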
    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

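    # Note: this is the standard DDPM posterior q(x_{t-1} | x_t, x_0), with mean
    # coef1 * x_0 + coef2 * x_t and variance beta_t * (1 - abar_{t-1}) / (1 - abar_t),
    # all taken from the buffers registered in __init__.
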
    def model_predictions(self, x, t, data=None):
        model_output = self.diff_model(x, data, t)
        t = t.type(torch.int64)
        x_start = model_output
        pred_noise = self.predict_noise_from_start(x, t, x_start)
        return ModelPrediction(pred_noise, x_start)

    def sample_fun(self, x, t, data=None):
        if self.conditioning_free:
            # NOTE: the unconditional pass currently reuses the same conditioning;
            # the line that would swap in an unconditional reference is left commented out.
            # data[1] = self.unconditioned_refer[]
            model_output_no_conditioning = self.diff_model(x, data, t)
        model_output = self.diff_model(x, data, t)
        t = t.type(torch.int64)
        if self.conditioning_free:
            cfk = self.conditioning_free_k
            # classifier-free guidance: blend the conditional and unconditional predictions
            model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning
        return model_output

    def p_mean_variance(self, x, t, data):
        preds = self.model_predictions(x, t, data)
        x_start = preds.pred_x_start
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    def p_sample(self, x, t: int, data):
        b, *_, device = *x.shape, x.device
        batched_times = torch.full((b,), t, device=device, dtype=torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x=x, t=batched_times, data=data)
        noise = torch.randn_like(x) if t > 0 else 0.  # no noise if t == 0
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start

    def p_sample_loop(self, content, refer, lengths, refer_lengths, f0, uv, auto_predict_f0=True):
        # f0, uv and auto_predict_f0 are unused legacy arguments kept for signature
        # compatibility; Pre_model.infer expects (mel_recon, refer, mel_lengths, refer_lengths)
        data = (content, refer, lengths, refer_lengths)
        content, refer = self.pre_model.infer(data)
        # content is (B, C, T) after the encoders, so the noise shape is (B, mel_dim, T)
        shape = (content.shape[0], self.dim, content.shape[2])
        batch, device = shape[0], refer.device
        img = torch.randn(shape, device=device)
        imgs = [img]
        x_start = None
        for t in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
            img, x_start = self.p_sample(img, t, (content, refer, lengths, refer_lengths))
            imgs.append(img)
        ret = img
        return ret

    def ddim_sample(self, content, refer, lengths, refer_lengths, f0, uv, auto_predict_f0=True):
        # as in p_sample_loop, f0/uv/auto_predict_f0 are unused legacy arguments;
        # self.sampling_timesteps is set in sample() before this is called
        data = (content, refer, lengths, refer_lengths)
        content, refer = self.pre_model.infer(data)
        shape = (content.shape[0], self.dim, content.shape[2])
        batch, device, total_timesteps, sampling_timesteps, eta = shape[0], refer.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta
        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)  # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))  # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
        img = torch.randn(shape, device=device)
        imgs = [img]
        x_start = None
        for time, time_next in tqdm(time_pairs, desc='sampling loop time step'):
            time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
            pred_noise, x_start, *_ = self.model_predictions(img, time_cond, (content, refer, lengths, refer_lengths))
            if time_next < 0:
                img = x_start
                imgs.append(img)
                continue
            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]
            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()
            noise = torch.randn_like(img)
            img = x_start * alpha_next.sqrt() + \
                  c * pred_noise + \
                  sigma * noise
            imgs.append(img)
        ret = img
        return ret

    def sample(self,
               mel_recon, refer, lengths, refer_lengths,
               sampling_timesteps=100, sample_method='unipc'
               ):
        # normalization is applied inside Pre_model.infer, so it is not repeated here
        if refer.shape[0] == 2:
            refer = refer[0].unsqueeze(0)
        self.sampling_timesteps = sampling_timesteps
        if sample_method == 'ddpm':
            sample_fn = self.p_sample_loop
            # f0/uv are unused legacy arguments of the sampling functions
            mel = sample_fn(mel_recon, refer, lengths, refer_lengths, None, None)
        elif sample_method == 'ddim':
            sample_fn = self.ddim_sample
            mel = sample_fn(mel_recon, refer, lengths, refer_lengths, None, None)
        elif sample_method == 'dpmsolver':
            from sampler.dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
            noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas)

            def my_wrapper(fn):
                def wrapped(x, t, **kwargs):
                    ret = fn(x, t, **kwargs)
                    self.bar.update(1)
                    return ret
                return wrapped

            data = (mel_recon, refer, lengths, refer_lengths)
            content, refer = self.pre_model.infer(data)
            shape = (content.shape[0], self.dim, content.shape[2])
            batch, device, total_timesteps, sampling_timesteps, eta = shape[0], refer.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta
            audio = torch.randn(shape, device=device)
            model_fn = model_wrapper(
                my_wrapper(self.sample_fun),
                noise_schedule,
                model_type="x_start",  # "noise" or "x_start" or "v" or "score"
                model_kwargs={"data": (content, refer, lengths, refer_lengths)}
            )
            dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
            steps = 40
            self.bar = tqdm(desc="sample time step", total=steps)
            mel = dpm_solver.sample(
                audio,
                steps=steps,
                order=2,
                skip_type="time_uniform",
                method="multistep",
            )
            self.bar.close()
        elif sample_method == 'unipc':
            from ttts.sampler.uni_pc import NoiseScheduleVP, model_wrapper, UniPC
            noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas)

            def my_wrapper(fn):
                def wrapped(x, t, **kwargs):
                    ret = fn(x, t, **kwargs)
                    self.bar.update(1)
                    return ret
                return wrapped

            data = (mel_recon, refer, lengths, refer_lengths)
            content, refer = self.pre_model.infer(data)
            shape = (content.shape[0], self.dim, content.shape[2])
            batch, device, total_timesteps, sampling_timesteps, eta = shape[0], refer.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta
            audio = torch.randn(shape, device=device)
            model_fn = model_wrapper(
                my_wrapper(self.sample_fun),
                noise_schedule,
                model_type="noise",  # "noise" or "x_start" or "v" or "score"
                model_kwargs={"data": (content, refer, lengths, refer_lengths)}
            )
            uni_pc = UniPC(model_fn, noise_schedule, variant='bh2')
            steps = 30
            self.bar = tqdm(desc="sample time step", total=steps)
            mel = uni_pc.sample(
                audio,
                steps=steps,
                order=2,
                skip_type="time_uniform",
                method="multistep",
            )
            self.bar.close()
        # vocos.to(mel.device)
        # audio = vocos.decode(mel)
        # if audio.ndim == 3:
        #     audio = rearrange(audio, 'b 1 n -> b n')
        return denormalize_tacotron_mel(mel)

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

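    # Note: this is the closed-form forward process
    # x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps, with eps ~ N(0, I),
    # so a noisy sample at any timestep t is drawn directly from x_0.
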
    def forward(self, data, conditioning_free=False):
        unused_params = []
        mel_recon_padded, mel_padded, mel_lengths, refer_padded, refer_lengths = data
        # normalize both the reconstructed and the target mel to [-1, 1]
        mel_recon_padded, mel_padded = normalize_tacotron_mel(mel_recon_padded), normalize_tacotron_mel(mel_padded)
        assert mel_recon_padded.shape[2] == mel_padded.shape[2]
        b, d, n, device = *mel_padded.shape, mel_padded.device
        x_mask = torch.unsqueeze(commons.sequence_mask(mel_lengths, mel_padded.size(2)), 1).to(mel_padded.dtype)
        x_start = mel_padded * x_mask
        # get pre model outputs
        content, refer = self.pre_model(data)
        if conditioning_free:
            # replace the reference with the learned unconditional embedding for
            # classifier-free guidance training
            refer = self.unconditioned_content.repeat(data[0].shape[0], 1, 1) + refer.mean() * 0
        else:
            unused_params.append(self.unconditioned_content)
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
        noise = torch.randn_like(x_start) * x_mask
        # noise sample
        x = self.q_sample(x_start=x_start, t=t, noise=noise)
        # predict and take gradient step
        model_out = self.diff_model(x, (content, refer, mel_lengths, refer_lengths), t)
        target = noise
        loss = F.mse_loss(model_out, target, reduction='none')
        loss_diff = reduce(loss, 'b ... -> b (...)', 'mean')
        # per-sample weighting; loss_diff.shape makes the weight broadcast as (b, 1)
        loss_diff = loss_diff * extract(self.loss_weight, t, loss_diff.shape)
        loss_diff = loss_diff.mean()
        loss = loss_diff
        extraneous_addition = 0
        for p in unused_params:
            extraneous_addition = extraneous_addition + p.mean()
        loss = loss + extraneous_addition * 0
        return loss

def get_grad_norm(model):
    total_norm = 0
    for name, p in model.named_parameters():
        if p.grad is None:
            # parameters that did not receive a gradient are reported and skipped
            print(name)
            continue
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
    total_norm = total_norm ** (1. / 2)
    return total_norm

logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.getLogger('numba').setLevel(logging.WARNING)
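
# Illustrative usage sketch (not part of the original file): the config keys follow the
# cfg lookups above and the 100-bin mel shape follows ref_enc/spk_proj; the file name,
# tensor sizes, and sampler settings below are assumptions for illustration only.
#
#   cfg = json.load(open('config.json'))          # must provide 'phoneme_encoder',
#                                                 # 'prompt_encoder', 'diffusion', 'train'
#   model = Diffuser(cfg).cuda().eval()
#   mel_recon = torch.randn(1, 100, 200).cuda()   # (B, n_mels, T) coarse/reconstructed mel
#   refer = torch.randn(1, 100, 150).cuda()       # (B, n_mels, T_ref) reference mel
#   lengths = torch.tensor([200]).cuda()
#   refer_lengths = torch.tensor([150]).cuda()
#   with torch.no_grad():
#       mel = model.sample(mel_recon, refer, lengths, refer_lengths,
#                          sampling_timesteps=100, sample_method='unipc')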