import torch
import torch.nn.functional as F
import numpy as np
import os
from einops import rearrange
import random

from utils.commons.base_task import BaseTask
# NOTE: BaseConcatDataset / WeightedConcatDataset are used in train_dataloader below but were
# not imported in the original; they are assumed to live in utils.commons.dataset_utils
# (adjust this import if they are defined elsewhere in the codebase).
from utils.commons.dataset_utils import data_loader, BaseConcatDataset, WeightedConcatDataset
from utils.commons.hparams import hparams
from utils.commons.ckpt_utils import load_ckpt
from utils.commons.tensor_utils import tensors_to_scalars, convert_to_np
from utils.nn.model_utils import print_arch, get_device_of_model, not_requires_grad, num_params
from utils.nn.schedulers import ExponentialSchedule
from utils.nn.grad import get_grad_norm
from utils.commons.face_alignment_utils import mouth_idx_in_mediapipe_mesh

from modules.audio2motion.vae import VAEModel, PitchContourVAEModel
from tasks.os_avatar.dataset_utils.audio2motion_dataset import Audio2Motion_Dataset
from data_util.face3d_helper import Face3DHelper
from data_gen.utils.mp_feature_extractors.face_landmarker import index_lm68_from_lm478
from modules.syncnet.models import LandmarkHubertSyncNet


class Audio2MotionTask(BaseTask):
    def __init__(self):
        super().__init__()
        self.dataset_cls = Audio2Motion_Dataset
        if hparams["motion_type"] == 'id_exp':
            self.in_out_dim = 80 + 64
        elif hparams["motion_type"] == 'exp':
            self.in_out_dim = 64

    def build_model(self):
        if hparams['audio_type'] == 'hubert':
            audio_in_dim = 1024  # hubert feature dim
        elif hparams['audio_type'] == 'mfcc':
            audio_in_dim = 13  # mfcc feature dim

        if hparams.get("use_pitch", False) is True:
            self.model = PitchContourVAEModel(hparams, in_out_dim=self.in_out_dim, audio_in_dim=audio_in_dim, use_prior_flow=hparams.get("use_flow", True))
        else:
            self.model = VAEModel(in_out_dim=self.in_out_dim, audio_in_dim=audio_in_dim, use_prior_flow=hparams.get("use_flow", True))
        if hparams.get('init_from_ckpt', '') != '':
            ckpt_dir = hparams.get('init_from_ckpt', '')
            load_ckpt(self.model, ckpt_dir, model_name='model', strict=True)

        self.face3d_helper = Face3DHelper(keypoint_mode='mediapipe', use_gpu=False)
        lm_dim = 468 * 3  # full idexp_lm3d landmarks
        # lm_dim = 20*3  # lip-only part of idexp_lm3d
        hparams['syncnet_num_layers_per_block'] = 3
        hparams['syncnet_base_hid_size'] = 128
        hparams['syncnet_out_hid_size'] = 1024
        self.syncnet = LandmarkHubertSyncNet(lm_dim, audio_in_dim, num_layers_per_block=hparams['syncnet_num_layers_per_block'], base_hid_size=hparams['syncnet_base_hid_size'], out_dim=hparams['syncnet_out_hid_size'])
        if hparams['syncnet_ckpt_dir']:
            load_ckpt(self.syncnet, hparams['syncnet_ckpt_dir'])
        return self.model

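    # NOTE (added comment): the LandmarkHubertSyncNet built above acts as a pretrained lip-sync
    # critic. Since build_optimizer below only registers the parameters of the model it receives
    # (the VAE returned by build_model), the SyncNet weights loaded from syncnet_ckpt_dir are
    # expected to stay frozen and only contribute the `sync_lip_lm3d` loss computed in run_model.
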
    def on_train_start(self):
        for n, m in self.model.named_children():
            num_params(m, model_name=n)
        for n, m in self.model.vae.named_children():
            num_params(m, model_name='vae.' + n)

    def build_optimizer(self, model):
        self.optimizer = optimizer = torch.optim.Adam(
            model.parameters(),
            lr=hparams['lr'],
            betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']))
        return optimizer

    def build_scheduler(self, optimizer):
        return ExponentialSchedule(optimizer, hparams['lr'], hparams['warmup_updates'])

    def train_dataloader(self):
        if hparams['ds_name'] == 'Concat_voxceleb2_CMLR':
            train_dataset1 = self.dataset_cls(prefix='train', data_dir='data/binary/voxceleb2_audio2motion_kv')
            train_dataset2 = self.dataset_cls(prefix='train', data_dir='data/binary/CMLR_audio2motion_kv')
            train_dataset = BaseConcatDataset([train_dataset1, train_dataset2], prefix='train')
        elif hparams['ds_name'] == 'Weighted_Concat_voxceleb2_CMLR':
            train_dataset1 = self.dataset_cls(prefix='train', data_dir='data/binary/voxceleb2_audio2motion_kv')
            train_dataset2 = self.dataset_cls(prefix='train', data_dir='data/binary/CMLR_audio2motion_kv')
            train_dataset = WeightedConcatDataset([train_dataset1, train_dataset2], [0.5, 0.5], prefix='train')
        else:
            train_dataset = self.dataset_cls(prefix='train')
        self.train_dl = train_dataset.get_dataloader()
        return self.train_dl

    def val_dataloader(self):
        val_dataset = self.dataset_cls(prefix='val')
        self.val_dl = val_dataset.get_dataloader()
        return self.val_dl

    def test_dataloader(self):
        val_dataset = self.dataset_cls(prefix='val')
        self.val_dl = val_dataset.get_dataloader()
        return self.val_dl

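    # NOTE (added comment): 'Concat_voxceleb2_CMLR' simply concatenates the two corpora, while
    # 'Weighted_Concat_voxceleb2_CMLR' mixes them with [0.5, 0.5] weights, which are presumably
    # per-dataset sampling probabilities inside WeightedConcatDataset.
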
    ##########################
    # training and validation
    ##########################
    def run_model(self, sample, infer=False, temperature=1.0, sync_batch_size=1024):
        """
        Run the audio2motion VAE on a batch of sequences.
        :param sample: a batch of data
        :param infer: bool, run in infer mode
        :return:
            if not infer:
                return losses, model_out
            if infer:
                return model_out
        """
        model_out = {}
        if hparams['audio_type'] == 'hubert':
            sample['audio'] = sample['hubert']
        elif hparams['audio_type'] == 'mfcc':
            sample['audio'] = sample['mfcc'] / 100
        elif hparams['audio_type'] == 'mel':
            sample['audio'] = sample['mel']  # [b, 2*t, C_mel]
        if hparams.get("blink_mode", 'none') != 'none':  # eye_area_percnet or blink_unit
            # interpolate the blink signal to 2x length, presumably to match the audio feature rate
            blink = F.interpolate(sample[hparams['blink_mode']].permute(0, 2, 1).float(), scale_factor=2).permute(0, 2, 1).long()
            sample['blink'] = blink

        bs = sample['audio'].shape[0]
        if infer:
            self.model(sample, model_out, train=False, temperature=temperature)
            return model_out
        else:
            losses_out = {}
            if hparams["motion_type"] == 'id_exp':
                x_gt = torch.cat([sample['id'], sample['exp']], dim=-1)
                sample['y'] = x_gt
                self.model(sample, model_out, train=True)
                x_pred = model_out['pred'].reshape([bs, -1, 80 + 64])
                x_mask = model_out['mask'].reshape([bs, -1])
                losses_out['kl'] = model_out['loss_kl']
                id_pred = x_pred[:, :, :80]
                exp_pred = x_pred[:, :, 80:]
                losses_out['lap_id'] = self.lap_loss(id_pred, x_mask)
                losses_out['lap_exp'] = self.lap_loss(exp_pred, x_mask)
                pred_idexp_lm3d = self.face3d_helper.reconstruct_idexp_lm3d(id_pred, exp_pred).reshape([bs, x_mask.shape[1], -1])
                gt_idexp_lm3d = self.face3d_helper.reconstruct_idexp_lm3d(sample['id'], sample['exp']).reshape([bs, x_mask.shape[1], -1])
                losses_out['mse_idexp_lm3d'] = self.lm468_mse_loss(gt_idexp_lm3d, pred_idexp_lm3d, x_mask)
                losses_out['l2_reg_id'] = self.l2_reg_loss(id_pred, x_mask)
                losses_out['l2_reg_exp'] = self.l2_reg_loss(exp_pred, x_mask)
                gt_lm2d = self.face3d_helper.reconstruct_lm2d(sample['id'], sample['exp'], sample['euler'], sample['trans']).reshape([bs, x_mask.shape[1], -1])
                pred_lm2d = self.face3d_helper.reconstruct_lm2d(id_pred, exp_pred, sample['euler'], sample['trans']).reshape([bs, x_mask.shape[1], -1])
                losses_out['mse_lm2d'] = self.lm468_mse_loss(gt_lm2d, pred_lm2d, x_mask)
| elif hparams["motion_type"] == 'exp': | |
| x_gt = sample['exp'] | |
| sample['y'] = x_gt | |
| self.model(sample, model_out, train=True) | |
| x_pred = model_out['pred'].reshape([bs, -1, 64]) | |
| x_mask = model_out['mask'].reshape([bs, -1]) | |
| losses_out['kl'] = model_out['loss_kl'] | |
| exp_pred = x_pred[:, :, :] | |
| losses_out['lap_exp'] = self.lap_loss(exp_pred, x_mask) | |
| if hparams.get("ref_id_mode",'first_frame') == 'first_frame': | |
| id_pred = sample['id'][:,0:1, :].repeat([1,exp_pred.shape[1],1]) | |
| elif hparams.get("ref_id_mode",'first_frame') == 'random_frame': | |
| max_y = x_mask.sum(dim=1).min().item() | |
| idx = random.randint(0, max_y-1) | |
| id_pred = sample['id'][:,idx:idx+1, :].repeat([1,exp_pred.shape[1],1]) | |
| gt_idexp_lm3d = self.face3d_helper.reconstruct_idexp_lm3d(sample['id'], sample['exp']).reshape([bs, x_mask.shape[1], -1]) | |
| pred_idexp_lm3d = self.face3d_helper.reconstruct_idexp_lm3d(id_pred, exp_pred).reshape([bs, x_mask.shape[1], -1]) | |
| losses_out['mse_exp'] = self.mse_loss(x_gt, x_pred, x_mask) | |
| losses_out['mse_idexp_lm3d'] = self.lm468_mse_loss(gt_idexp_lm3d, pred_idexp_lm3d, x_mask) | |
| losses_out['l2_reg_exp'] = self.l2_reg_loss(exp_pred, x_mask) | |
| gt_lm2d = self.face3d_helper.reconstruct_lm2d(sample['id'], sample['exp'], sample['euler'], sample['trans']).reshape([bs, x_mask.shape[1], -1]) | |
| pred_lm2d = self.face3d_helper.reconstruct_lm2d(id_pred, exp_pred, sample['euler'], sample['trans']).reshape([bs, x_mask.shape[1], -1]) | |
| # losses_out['mse_lm2d'] = self.lm468_mse_loss(gt_lm2d, pred_lm2d, x_mask) | |
            # calculate the sync loss between the predicted lip landmarks and the audio features
            mouth_lm3d = pred_idexp_lm3d.reshape([bs, x_pred.shape[1], 468 * 3])  # [b, t, 468*3]
            # mouth_lm3d = pred_idexp_lm3d.reshape([bs, x_pred.shape[1], 468, 3])[:, :, index_lm68_from_lm478, :][:, :, 48:68].reshape([bs, x_pred.shape[1], 20*3])  # [b, t, 60]
            if hparams['audio_type'] == 'hubert':
                mel = sample['hubert']  # [b, 2*t, 1024]
            elif hparams['audio_type'] == 'mfcc':
                mel = sample['mfcc'] / 100  # [b, 2*t, 13]
            elif hparams['audio_type'] == 'mel':
                mel = sample['mel']  # [b, 2*t, C_mel]
            num_clips_for_syncnet = 8096
            len_mouth_slice = 5
            len_mel_slice = len_mouth_slice * 2  # audio features run at 2x the motion frame rate
            num_iters = max(1, num_clips_for_syncnet // len(mouth_lm3d))
            mouth_clip_lst = []
            mel_clip_lst = []
            x_mask_clip_lst = []
            # sample random 5-frame landmark windows and the corresponding 10-frame audio windows
            for i in range(num_iters):
                t_start = random.randint(0, x_pred.shape[1] - len_mouth_slice - 1)
                mouth_clip = mouth_lm3d[:, t_start: t_start + len_mouth_slice]
                x_mask_clip = x_mask[:, t_start: t_start + len_mouth_slice]
                assert mouth_clip.shape[1] == len_mouth_slice
                mel_clip = mel[:, t_start * 2: t_start * 2 + len_mel_slice]
                mouth_clip_lst.append(mouth_clip)
                mel_clip_lst.append(mel_clip)
                x_mask_clip_lst.append(x_mask_clip)
            mouth_clips = torch.cat(mouth_clip_lst)  # [num_iters*b, T=5, 468*3]
            mel_clips = torch.cat(mel_clip_lst)  # [num_iters*b, T=10, C_audio]
            x_mask_clips = torch.cat(x_mask_clip_lst)  # [num_iters*b, T=5]
            # a clip only counts if all of its frames are valid
            x_mask_clips = (x_mask_clips.sum(dim=1) == x_mask_clips.shape[1]).float()  # [num_iters*b,]
            audio_embedding, mouth_embedding = self.syncnet.forward(mel_clips, mouth_clips)  # normalized audio / mouth embeddings
            sync_loss, _ = self.syncnet.cal_sync_loss(audio_embedding, mouth_embedding, 1., reduction='none')
            losses_out['sync_lip_lm3d'] = (sync_loss * x_mask_clips).sum() / x_mask_clips.sum()

            return losses_out, model_out

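    # NOTE on the sync loss in run_model (added comment): 5-frame landmark windows are paired
    # with 10-frame audio-feature windows because the audio features are assumed to run at
    # twice the motion frame rate (e.g. 50 Hz HuBERT features vs. 25 fps motion), which is why
    # the audio slice starts at t_start*2 and spans 2*len_mouth_slice frames.
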
    def kl_annealing(self, num_updates, max_lambda=0.4, t1=2000, t2=2000):
        """
        Cyclical Annealing Schedule: A Simple Approach to Mitigating KL Vanishing
        https://aclanthology.org/N19-1021.pdf
        """
        T = t1 + t2
        num_updates = num_updates % T
        if num_updates < t1:
            return num_updates / t1 * max_lambda
        else:
            return max_lambda

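    # Worked example of the schedule above (added note): with max_lambda=0.4 and t1=t2=2000,
    # the KL weight ramps linearly from 0 to 0.4 over the first 2000 updates of each
    # 4000-update cycle, holds at 0.4 for the next 2000 updates, then the cycle restarts
    # (e.g. update 1000 -> 0.2, update 2500 -> 0.4, update 4500 -> 0.1).
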
    def _training_step(self, sample, batch_idx, optimizer_idx):
        loss_output, model_out = self.run_model(sample)
        loss_weights = {
            'kl': self.kl_annealing(self.global_step, max_lambda=hparams['lambda_kl'], t1=hparams['lambda_kl_t1'], t2=hparams['lambda_kl_t2']),
            'mse_exp': hparams.get("lambda_mse_exp", 0.1),
            'mse_idexp_lm3d': hparams.get("lambda_mse_lm3d", 1.),
            'lap_id': hparams.get("lambda_lap_id", 1.),
            'lap_exp': hparams.get("lambda_lap_exp", 1.),
            'l2_reg_id': hparams.get("lambda_l2_reg_id", 0.),
            'l2_reg_exp': hparams.get("lambda_l2_reg_exp", 0.0),
            'sync_lip_lm3d': hparams.get("lambda_sync_lm3d", 0.2),
            'mse_lm2d': hparams.get("lambda_mse_lm2d", 0.),
        }
        total_loss = sum([loss_weights.get(k, 1) * v for k, v in loss_output.items() if isinstance(v, torch.Tensor) and v.requires_grad])
        return total_loss, loss_output

    def validation_start(self):
        pass

    def validation_step(self, sample, batch_idx):
        outputs = {}
        outputs['losses'], model_out = self.run_model(sample, infer=False, sync_batch_size=10000)
        outputs = tensors_to_scalars(outputs)
        return outputs

    def validation_end(self, outputs):
        return super().validation_end(outputs)

    #####################
    # Testing
    #####################
    def test_start(self):
        self.gen_dir = os.path.join(hparams['work_dir'], f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}')
        os.makedirs(self.gen_dir, exist_ok=True)

    def test_step(self, sample, batch_idx):
        """
        Run inference on a batch and dump the predicted (and optionally ground-truth)
        expression sequences as .npy files under gen_dir.
        """
        outputs = {}
        # in infer mode, run_model only returns model_out (no losses)
        model_out = self.run_model(sample, infer=True)
        pred_exp = model_out['pred']
        self.save_result(pred_exp, "pred_exp_val", self.gen_dir)
        if hparams['save_gt']:
            base_fn = "gt_exp_val"
            self.save_result(sample['exp'], base_fn, self.gen_dir)
        return outputs

    def test_end(self, outputs):
        pass

    def save_result(self, exp_arr, base_fname, gen_dir):
        exp_arr = convert_to_np(exp_arr)
        np.save(f"{gen_dir}/{base_fname}.npy", exp_arr)

    def get_grad(self, opt_idx):
        grad_dict = {
            'grad/model': get_grad_norm(self.model),
        }
        return grad_dict

    def lm468_mse_loss(self, proj_lan, gt_lan, x_mask):
        # proj_lan, gt_lan: [B, T, 468*3]
        b, t, c = proj_lan.shape
        loss = ((proj_lan - gt_lan) ** 2) * x_mask[:, :, None]
        loss = loss.reshape([b, t, 468, -1])
        unmatch_mask = [93, 127, 132, 234, 323, 356, 361, 454]
        upper_eye = [161, 160, 159, 158, 157] + [388, 387, 386, 385, 384]
        eye = [33, 246, 161, 160, 159, 158, 157, 173, 133, 155, 154, 153, 145, 144, 163, 7] + [263, 466, 388, 387, 386, 385, 384, 398, 362, 382, 381, 380, 374, 373, 390, 249]
        inner_lip = [78, 191, 80, 81, 82, 13, 312, 311, 310, 415, 308, 324, 318, 402, 317, 14, 87, 178, 88, 95]
        outer_lip = [61, 185, 40, 39, 37, 0, 267, 269, 270, 409, 291, 375, 321, 405, 314, 17, 84, 181, 91, 146]
        weights = torch.ones_like(loss)
        weights[:, :, eye] = 3
        weights[:, :, upper_eye] = 20
        weights[:, :, inner_lip] = 5
        weights[:, :, outer_lip] = 5
        weights[:, :, unmatch_mask] = 0
        loss = loss.reshape([b, t, c])
        weights = weights.reshape([b, t, c])
        return (loss * weights).sum() / (x_mask.sum() * c)

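    # NOTE (added comment): the index lists above are vertex ids of the MediaPipe 468-point
    # face mesh; perceptually important regions (the eyes, especially the upper eyelids, and
    # the lips) receive larger weights, while the vertices in `unmatch_mask` (presumably
    # points that do not align well between the 3DMM reconstruction and the MediaPipe mesh)
    # are excluded from the loss.
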
    def lm68_mse_loss(self, proj_lan, gt_lan, x_mask):
        # proj_lan, gt_lan: [B, T, 68*3]
        b, t, c = proj_lan.shape
        loss = ((proj_lan - gt_lan) ** 2) * x_mask[:, :, None]
        loss = loss.reshape([b, t, 68, 3])
        weights = torch.ones_like(loss)
        weights[:, :, 36:48, :] = 5  # eye, 12 points
        weights[:, :, -8:, :] = 5  # inner lip, 8 points
        weights[:, :, 28:31, :] = 5  # nose, 3 points
        loss = loss.reshape([b, t, c])
        weights = weights.reshape([b, t, c])
        return (loss * weights).sum() / (x_mask.sum() * c)

    def l2_reg_loss(self, x_pred, x_mask):
        # squared-magnitude (l2) regularization on the predicted coefficients
        error = (x_pred ** 2) * x_mask[:, :, None]
        num_frame = x_mask.sum()
        return error.sum() / (num_frame * self.in_out_dim)

    def lap_loss(self, in_tensor, x_mask):
        # in_tensor: [b, t, c]
        b, t, c = in_tensor.shape
        in_tensor = F.pad(in_tensor, pad=(0, 0, 1, 1))
        in_tensor = rearrange(in_tensor, "b t c -> (b c) t").unsqueeze(1)  # [B*C, 1, T+2]
        lap_kernel = torch.Tensor((-0.5, 1.0, -0.5)).reshape([1, 1, 3]).float().to(in_tensor.device)  # [1, 1, kw]
        out_tensor = F.conv1d(in_tensor, lap_kernel)  # [B*C, 1, T]
        out_tensor = out_tensor.squeeze(1)
        out_tensor = rearrange(out_tensor, "(b c) t -> b t c", b=b, t=t)
        loss_lap = (out_tensor ** 2) * x_mask.unsqueeze(-1)
        return loss_lap.sum() / (x_mask.sum() * c)

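    # Illustration of lap_loss (added note): the [-0.5, 1.0, -0.5] kernel is a discrete 1-D
    # Laplacian, i.e. x[t] - 0.5*(x[t-1] + x[t+1]), so the loss penalizes frame-to-frame
    # jitter (large second temporal differences) while leaving constant or linearly changing
    # motion untouched. For the interior frames of the sequence [0, 1, 2, 3, 10] the response
    # is 0 except around the jump from 3 to 10 (boundary frames also respond because of the
    # zero padding).
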
    def mse_loss(self, x_gt, x_pred, x_mask):
        # mean squared error, l2 loss
        error = (x_pred - x_gt) * x_mask[:, :, None]
        num_frame = x_mask.sum()
        return (error ** 2).sum() / (num_frame * self.in_out_dim)

    def mae_loss(self, x_gt, x_pred, x_mask):
        # mean absolute error, l1 loss
        error = (x_pred - x_gt) * x_mask[:, :, None]
        num_frame = x_mask.sum()
        return error.abs().sum() / (num_frame * self.in_out_dim)

    def vel_loss(self, x_pred, x_mask):
        # l1 penalty on the frame-to-frame velocity of the prediction
        error = (x_pred[:, 1:] - x_pred[:, :-1]) * x_mask[:, 1:, None]
        num_frame = x_mask.sum()
        return error.abs().sum() / (num_frame * self.in_out_dim)

    def continuity_loss(self, x_gt, x_pred, x_mask):
        # continuity loss, borrowed from <FACIAL: Synthesizing Dynamic Talking Face with Implicit Attribute Learning>
        diff_x_pred = x_pred[:, 1:] - x_pred[:, :-1]
        diff_x_gt = x_gt[:, 1:] - x_gt[:, :-1]
        error = (diff_x_pred - diff_x_gt) * x_mask[:, 1:, None]
        init_error = x_pred[:, 0, :] - x_gt[:, 0, :]
        num_frame = x_mask.sum()
        return (error.pow(2).sum() + init_error.pow(2).sum()) / (num_frame * self.in_out_dim)