import json
import os
import random

import torch
import torch.nn.functional as F
import torch.utils.data
import torchaudio
from torch import LongTensor
from tqdm import tqdm
from pypinyin import Style, lazy_pinyin

from ttts.gpt.voice_tokenizer import VoiceBpeTokenizer

def read_jsonl(path):
    """Read a JSONL file into a list of dicts, one parsed object per line."""
    with open(path, 'r', encoding='utf-8') as f:
        json_str = f.read()
    data_list = []
    for line in json_str.splitlines():
        data = json.loads(line)
        data_list.append(data)
    return data_list

def write_jsonl(path, all_paths):
    """Write a list of dicts to a JSONL file, one JSON object per line."""
    with open(path, 'w', encoding='utf-8') as file:
        for item in all_paths:
            json.dump(item, file, ensure_ascii=False)
            file.write('\n')

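# Sketch of the record format expected in the jsonl file: GptTtsDataset only
# reads the 'path' and 'text' fields; the filename below is a hypothetical example.
# {"path": "wavs/000001.wav", "text": "..."}
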
class GptTtsDataset(torch.utils.data.Dataset):
    def __init__(self, opt):
        self.tok = VoiceBpeTokenizer('ttts/gpt/gpt_tts_tokenizer.json')
        self.jsonl_path = opt['dataset']['path']
        self.audiopaths_and_text = read_jsonl(self.jsonl_path)

    def __getitem__(self, index):
        try:
            # Fetch the text, convert it to tone-numbered pinyin and tokenize it,
            # e.g. lazy_pinyin('你好吗', style=Style.TONE3, neutral_tone_with_five=True)
            # -> ['ni3', 'hao3', 'ma5'].
            audiopath_and_text = self.audiopaths_and_text[index]
            audiopath, text = audiopath_and_text['path'], audiopath_and_text['text']
            text = ' '.join(lazy_pinyin(text, style=Style.TONE3, neutral_tone_with_five=True))
            text = self.tok.encode(text)
            text = LongTensor(text)

            # Fetch the quantized mel tokens and the raw mel spectrogram.
            quant_path = audiopath + '.melvq.pth'
            qmel = LongTensor(torch.load(quant_path)[0])
            mel_path = audiopath + '.mel.pth'
            mel = torch.load(mel_path)[0]
            wav_length = mel.shape[1] * 256  # waveform length, assuming 256 samples per mel frame

            # Keep a random half of the raw mel: split somewhere in its middle
            # third and keep either the left or the right part.
            split = random.randint(int(mel.shape[1] // 3), int(mel.shape[1] // 3 * 2))
            if random.random() > 0.5:
                mel = mel[:, :split]
            else:
                mel = mel[:, split:]
        except Exception:
            # Skip samples with missing or corrupt files; the collater drops None entries.
            return None
        # load wav
        # wav, sr = torchaudio.load(audiopath)
        # wav = torchaudio.transforms.Resample(sr, 24000)(wav)
        # Drop overly long samples.
        if text.shape[0] > 400 or qmel.shape[0] > 600:
            return None
        return text, qmel, mel, wav_length

    def __len__(self):
        return len(self.audiopaths_and_text)

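# Minimal construction sketch (illustrative only): the dataset reads just
# opt['dataset']['path'], so a small dict is enough for a quick test; the jsonl
# path below is hypothetical, and ttts/gpt/gpt_tts_tokenizer.json must exist.
# opt = {'dataset': {'path': 'data/train.jsonl'}}
# ds = GptTtsDataset(opt)
# sample = ds[0]  # (text, qmel, mel, wav_length), or None if the item was filtered
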
class GptTtsCollater:
    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        # Drop samples that failed to load; return None if nothing is left.
        batch = [x for x in batch if x is not None]
        if len(batch) == 0:
            return None
        text_lens = [len(x[0]) for x in batch]
        max_text_len = max(text_lens)
        # max_text_len = self.cfg['gpt']['max_text_tokens']
        qmel_lens = [len(x[1]) for x in batch]
        max_qmel_len = max(qmel_lens)
        # max_qmel_len = self.cfg['gpt']['max_mel_tokens']
        raw_mel_lens = [x[2].shape[1] for x in batch]
        max_raw_mel_len = max(raw_mel_lens)
        wav_lens = [x[3] for x in batch]

        texts = []
        qmels = []
        raw_mels = []
        # Pad every sequence up to the longest one in the batch. Token 0 serves as
        # the sequential "background" padding token for text tokens, as specified
        # in the DALL-E paper.
        for b in batch:
            text, qmel, raw_mel, _wav_len = b
            texts.append(F.pad(text, (0, max_text_len - len(text)), value=0))
            qmels.append(F.pad(qmel, (0, max_qmel_len - len(qmel)), value=0))
            raw_mels.append(F.pad(raw_mel, (0, max_raw_mel_len - raw_mel.shape[1]), value=0))

        padded_qmel = torch.stack(qmels)
        padded_raw_mel = torch.stack(raw_mels)
        padded_texts = torch.stack(texts)
        return {
            'padded_text': padded_texts,
            'text_lengths': LongTensor(text_lens),
            'padded_qmel': padded_qmel,
            'qmel_lengths': LongTensor(qmel_lens),
            'padded_raw_mel': padded_raw_mel,
            'raw_mel_lengths': LongTensor(raw_mel_lens),
            'wav_lens': LongTensor(wav_lens),
        }

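# Batch layout returned by GptTtsCollater (shapes follow from the padding above;
# n_mels is whatever the saved .mel.pth spectrograms use):
#   padded_text     [B, max_text_len]             padded pinyin token ids
#   padded_qmel     [B, max_qmel_len]             padded quantized mel token ids
#   padded_raw_mel  [B, n_mels, max_raw_mel_len]  padded raw mel spectrograms
#   text_lengths / qmel_lengths / raw_mel_lengths / wav_lens: 1-D LongTensors
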
if __name__ == '__main__':
    # Example config dict; not used below (settings are read from config.json).
    params = {
        'mode': 'gpt_tts',
        'path': 'E:\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt',
        'phase': 'train',
        'n_workers': 0,
        'batch_size': 16,
        'mel_vocab_size': 512,
    }
    with open('ttts/gpt/config.json') as f:
        cfg = json.load(f)
    ds = GptTtsDataset(cfg)
    dl = torch.utils.data.DataLoader(ds, **cfg['dataloader'], collate_fn=GptTtsCollater(cfg))
    # Quick smoke test: fetch a single batch from the dataloader.
    for b in tqdm(dl):
        break
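    # Illustrative only: inspect the first batch's tensor shapes. `b` may be None
    # if every sample in that batch was filtered out or failed to load.
    if b is not None:
        for k, v in b.items():
            print(k, tuple(v.shape))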