The script below scores precomputed mel spectrograms with the trained classifier and appends the paths of clips whose class-0 probability falls below 0.1 (i.e. clips flagged as noisy) to `ttts/classifier/noise_files.txt`:

```python
import json
import os

import torch
import torch.nn.functional as F
from tqdm import tqdm

from ttts.utils.infer_utils import load_model


def read_jsonl(path):
    """Read a JSON-Lines file into a list of dicts, one per line."""
    path = os.path.expanduser(path)
    data_list = []
    with open(path, 'r') as f:
        for line in f:
            data_list.append(json.loads(line))
    return data_list


def classify_audio_clip(clip, classifier):
    """
    Run the quality classifier on a batch of mel spectrograms.

    :param clip: torch tensor of mel spectrograms (see MelDataset below)
    :return: softmax probabilities over the classifier's classes, one row per clip
    """
    with torch.no_grad():
        return F.softmax(classifier(clip), dim=-1)


class MelDataset(torch.utils.data.Dataset):
    """Loads precomputed mels and crops or pads them to a fixed number of frames."""

    def __init__(self, paths):
        super().__init__()
        self.paths = paths
        self.pad_to = 700

    def __getitem__(self, index):
        path = self.paths[index]
        try:
            mel = torch.load(path + '.mel.pth')
        except Exception:
            # Fall back to an all-zero mel if the file is missing or unreadable.
            mel = torch.zeros((1, 100, self.pad_to))
        if mel.shape[-1] >= self.pad_to:
            # Take a random pad_to-frame crop of longer mels.
            start = torch.randint(0, mel.shape[-1] - self.pad_to + 1, (1,))
            mel = mel[:, :, start:start + self.pad_to]
        else:
            # Zero-pad shorter mels on the right.
            padding_needed = self.pad_to - mel.shape[-1]
            mel = F.pad(mel, (0, padding_needed))
        return mel.squeeze(0), path

    def __len__(self):
        return len(self.paths)


if __name__ == '__main__':
    model_path = '/home/hyc/tortoise_plus_zh/ttts/classifier/logs/2023-11-23-17-34-45/model-9.pt'
    # Expand '~' here in case load_model does not do it internally.
    config_path = os.path.expanduser('~/tortoise_plus_zh/ttts/classifier/config.json')
    device = 'cuda'
    classifier = load_model('classifier', model_path, config_path, device)

    jsonl_path = '~/tortoise_plus_zh/ttts/datasets/all_data.jsonl'
    audiopaths_and_text = read_jsonl(jsonl_path)
    audio_paths = [x['path'] for x in audiopaths_and_text]
    ds = MelDataset(audio_paths)
    dl = torch.utils.data.DataLoader(ds, batch_size=1024, num_workers=16)

    # Open the log once instead of reopening it for every flagged clip.
    with open('ttts/classifier/noise_files.txt', 'a') as f:
        for mels, paths in tqdm(dl):
            mels = mels.to(device)
            probs = classify_audio_clip(mels, classifier)
            for i in range(probs.shape[0]):
                # Class 0 appears to be the "clean" label: clips with a low
                # class-0 probability are recorded as noisy.
                if probs[i][0] < 0.1:
                    f.write(os.path.join(os.getcwd(), paths[i]) + '\n')
```
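As a follow-up, here is a minimal sketch of how the resulting list might be consumed to filter the dataset manifest. It assumes `noise_files.txt` holds one absolute path per line as written by the script above, that each `all_data.jsonl` record carries the same `path` field used above, and that `filtered_data.jsonl` is a hypothetical output name:

```python
import json
import os

# Hypothetical follow-up: drop flagged clips from the dataset manifest.
with open('ttts/classifier/noise_files.txt') as f:
    noisy = {line.strip() for line in f if line.strip()}

kept = []
with open(os.path.expanduser('~/tortoise_plus_zh/ttts/datasets/all_data.jsonl')) as f:
    for line in f:
        record = json.loads(line)
        # The script above wrote paths joined with os.getcwd(), so compare
        # against the same absolute form (this assumes the same working
        # directory as the classification run).
        if os.path.join(os.getcwd(), record['path']) not in noisy:
            kept.append(record)

# 'filtered_data.jsonl' is a made-up name for the cleaned manifest.
with open('filtered_data.jsonl', 'w') as f:
    for record in kept:
        f.write(json.dumps(record, ensure_ascii=False) + '\n')
```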