Spaces:
Runtime error
Runtime error
File size: 4,830 Bytes
a22eb82 a86a2b8 a22eb82 defda6e a22eb82 9ab094a a22eb82 9ab094a 1dce2dd 416263d 1dce2dd 416263d 1dce2dd 416263d a22eb82 0ce42bd 9ab094a 0ce42bd 9ab094a 0ce42bd a22eb82 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import os
from tqdm import tqdm
import torch
import numpy as np
import random
import scipy.io as scio
import src.utils.audio as audio
def crop_pad_audio(wav, audio_length):
    """Trim or zero-pad ``wav`` so it holds exactly ``audio_length`` samples.

    Assumes a 1-D waveform (the pad spec ``[0, n]`` pads only at the end).
    When the length already matches, the input object is returned unchanged.
    """
    shortfall = audio_length - len(wav)
    if shortfall < 0:
        return wav[:audio_length]
    if shortfall > 0:
        return np.pad(wav, [0, shortfall], mode='constant', constant_values=0)
    return wav
def parse_audio_length(audio_length, sr, fps):
    """Round ``audio_length`` samples down to a whole number of video frames.

    Returns:
        (trimmed_sample_count, frame_count), where each frame spans ``sr / fps``
        samples and ``trimmed_sample_count`` covers exactly ``frame_count`` frames.
    """
    samples_per_frame = sr / fps
    frame_count = int(audio_length / samples_per_frame)
    return int(frame_count * samples_per_frame), frame_count
def generate_blink_seq(num_frames):
    """Return a ``(num_frames, 1)`` blink-ratio array with evenly spaced blinks.

    Starting at frame 0, each blink is placed 80 frames after the previous
    cursor and lasts 9 frames. A blink is only written while its end index
    is at most ``num_frames - 1``. All other frames stay 0.
    """
    blink_shape = [0.5, 0.6, 0.7, 0.9, 1, 0.9, 0.7, 0.6, 0.5]
    gap = 80
    ratio = np.zeros((num_frames, 1))
    cursor = 0
    while 0 <= cursor < num_frames:
        begin = cursor + gap
        end = begin + len(blink_shape)
        if end > num_frames - 1:
            break
        ratio[begin:end, 0] = blink_shape
        cursor = end
    return ratio
def generate_blink_seq_randomly(num_frames):
    """Return a ``(num_frames, 1)`` blink-ratio array with randomly spaced blinks.

    Each blink is a 5-frame ramp ``[0.5, 0.9, 1.0, 0.9, 0.5]``. The gap before
    each blink is drawn with ``random.choice`` from
    ``range(min(10, num_frames), min(num_frames // 2, 70))``. Blinks are added
    while one still ends at or before index ``num_frames - 1``. Frames without
    a blink are 0.

    Clips of 20 frames or fewer get no blinks. The same applies to any length
    whose gap range is empty (e.g. 21 frames, where the range is ``10..10``).
    """
    blink_shape = [0.5, 0.9, 1.0, 0.9, 0.5]
    ratio = np.zeros((num_frames, 1))
    if num_frames <= 20:
        return ratio
    # The gap bounds depend only on num_frames, so build the candidate range once.
    gap_choices = range(min(10, num_frames), min(int(num_frames / 2), 70))
    # An empty range would make random.choice raise IndexError; return no blinks instead.
    if not gap_choices:
        return ratio
    frame_id = 0
    while frame_id < num_frames:
        start = random.choice(gap_choices)
        if frame_id + start + 5 <= num_frames - 1:
            ratio[frame_id + start:frame_id + start + 5, 0] = blink_shape
            frame_id = frame_id + start + 5
        else:
            break
    return ratio
def _mel_windows(orig_mel, num_frames, fps, step_size):
    """Slice one ``step_size``-wide mel window per video frame.

    Args:
        orig_mel: mel spectrogram indexed as (mel_frame, mel_bin), i.e. the
            transpose of ``audio.melspectrogram`` output.
        num_frames: number of video frames to produce windows for.
        fps: video frame rate.
        step_size: number of mel frames in each window.

    Returns:
        np.ndarray of shape (num_frames, n_mel_bins, step_size).
    """
    last = orig_mel.shape[0] - 1
    windows = []
    for i in tqdm(range(num_frames), 'mel:'):
        # Windows start two video frames early. 80. is treated as mel frames per
        # second (TODO confirm against the hop size used in src.utils.audio).
        start_idx = int(80. * ((i - 2) / float(fps)))
        # Clamp so windows at the edges repeat the first/last mel frame.
        seq = np.clip(np.arange(start_idx, start_idx + step_size), 0, last)
        windows.append(orig_mel[seq, :].T)
    # reshape keeps a well-formed (0, n_mel_bins, step_size) array when num_frames == 0.
    return np.asarray(windows).reshape(num_frames, orig_mel.shape[1], step_size)


def _tile_rows(coeff, num_rows):
    """Return exactly ``num_rows`` rows of ``coeff``, cycling from the top when it is shorter."""
    if coeff.shape[0] == 0 and num_rows > 0:
        raise ValueError("reference eyeblink coefficients are empty")
    return coeff[np.arange(num_rows) % max(coeff.shape[0], 1)]


def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False, idlemode=False, length_of_audio=False, use_blink=True):
    """Assemble the per-frame inputs for one source-coefficient / audio pair.

    Args:
        first_coeff_path: .mat file. The first row (first 70 columns) of its
            'coeff_3dmm' entry is repeated as the reference for every frame.
        audio_path: audio file, loaded at 16 kHz unless ``idlemode`` is set.
            Only its file stem is used in idle mode.
        device: torch device that the returned tensors are moved to.
        ref_eyeblink_coeff_path: optional .mat file. When given, the first 64
            columns of its 'coeff_3dmm' replace the first 64 reference columns
            (cycled to the clip length), and the blink ratio is zeroed.
        still: when True, no random blinks are generated.
        idlemode: when True, the audio is not read. ``length_of_audio`` seconds
            of all-zero mel windows are produced instead.
        length_of_audio: clip duration in seconds, used only in idle mode.
        use_blink: when False, the blink ratio is all zeros.

    Returns:
        dict with keys:
          'indiv_mels': FloatTensor (1, T, 1, 80, 16) of per-frame mel windows,
          'ref': FloatTensor (1, T, 70) of reference coefficients,
          'num_frames': T,
          'ratio_gt': FloatTensor (1, T, 1) of the blink ratio,
          'audio_name' / 'pic_name': file stems of the two input paths.

    Raises:
        ValueError: if the eyeblink reference file has no coefficient rows.
    """
    syncnet_mel_step_size = 16
    fps = 25
    sample_rate = 16000
    pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0]
    audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]

    if idlemode:
        num_frames = int(length_of_audio * fps)
        indiv_mels = np.zeros((num_frames, 80, syncnet_mel_step_size))
    else:
        wav = audio.load_wav(audio_path, sample_rate)
        wav_length, num_frames = parse_audio_length(len(wav), sample_rate, fps)
        wav = crop_pad_audio(wav, wav_length)
        orig_mel = audio.melspectrogram(wav).T  # (mel frames, 80)
        indiv_mels = _mel_windows(orig_mel, num_frames, fps, syncnet_mel_step_size)

    if num_frames < 20:
        # Too short for random blinks: fall back to a still, blink-free result.
        print(f"[WARN] num_frames={num_frames} too small, enable still_mode / skip blink.")
        still = True
        use_blink = False

    # Blink ratio, shape (T, 1).
    if use_blink and not still:
        ratio = generate_blink_seq_randomly(num_frames)
    else:
        ratio = np.zeros((num_frames, 1))  # no blink

    source_semantics_dict = scio.loadmat(first_coeff_path)
    ref_coeff = source_semantics_dict['coeff_3dmm'][:1, :70]  # (1, 70)
    ref_coeff = np.repeat(ref_coeff, num_frames, axis=0)  # (T, 70)

    if ref_eyeblink_coeff_path is not None:
        # With an eyeblink reference, the synthetic blink ratio is disabled and
        # the first 64 coefficients come from the reference file instead.
        ratio[:num_frames] = 0
        refeyeblink_coeff = scio.loadmat(ref_eyeblink_coeff_path)['coeff_3dmm'][:, :64]
        ref_coeff[:, :64] = _tile_rows(refeyeblink_coeff, num_frames)

    indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1).unsqueeze(0).to(device)  # (1, T, 1, 80, 16)
    # ratio is already all zeros whenever use_blink is False, so no extra fill is needed.
    ratio = torch.FloatTensor(ratio).unsqueeze(0).to(device)  # (1, T, 1)
    ref_coeff = torch.FloatTensor(ref_coeff).unsqueeze(0).to(device)  # (1, T, 70)

    return {'indiv_mels': indiv_mels,
            'ref': ref_coeff,
            'num_frames': num_frames,
            'ratio_gt': ratio,
            'audio_name': audio_name, 'pic_name': pic_name}
|