Spaces:
Running
on
Zero
Running
on
Zero
| import logging | |
| import kaldiio | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from torch.utils.data import Dataset | |
def custom_collate(batch):
    """Collate ``(key, speech, speaker_label, order)`` samples into one batch.

    Nothing is padded or stacked: each field becomes a Python list holding one
    tensor per sample, so samples of different lengths can share a batch.

    Args:
        batch: Sequence of 4-tuples ``(key, speech, speaker_label, order)``
            whose last three items are numpy arrays.

    Returns:
        ``(keys, batch_dict)`` where ``keys`` is a tuple of the sample keys and
        ``batch_dict`` maps ``"speech"`` and ``"speaker_labels"`` to lists of
        float32 tensors and ``"orders"`` to a list of int64 tensors.
    """

    def as_tensors(arrays, dtype):
        # torch.from_numpy shares memory with its input; np.copy gives each
        # tensor its own buffer instead of aliasing the sample's array.
        return [torch.from_numpy(np.copy(arr)).to(dtype) for arr in arrays]

    keys, speech, speaker_labels, orders = zip(*batch)
    collated = {
        "speech": as_tensors(speech, torch.float32),
        "speaker_labels": as_tensors(speaker_labels, torch.float32),
        "orders": as_tensors(orders, torch.int64),
    }
    return keys, collated
class EENDOLADataset(Dataset):
    """Map-style dataset over a plain-text manifest of EEND-OLA samples.

    Each non-blank line of ``data_file`` must contain exactly three
    whitespace-separated fields::

        <key> <speech_path> <speaker_label_path>

    Both paths are read with ``kaldiio.load_mat`` when an item is fetched.

    Args:
        data_file: Path to the manifest file.

    Raises:
        ValueError: If a non-blank manifest line does not have three fields.
    """

    def __init__(
        self,
        data_file,
    ):
        self.data_file = data_file
        # One ``[key, speech_path, speaker_label_path]`` list per sample.
        self.samples = []
        with open(data_file, encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    # Blank lines (e.g. trailing newlines) used to become empty
                    # samples that inflated len() and crashed __getitem__.
                    continue
                if len(fields) != 3:
                    # Fail at construction with a precise location instead of
                    # an opaque unpack error inside a DataLoader worker.
                    raise ValueError(
                        "{}:{}: expected 3 fields "
                        "(key speech_path speaker_label_path), got {}".format(
                            data_file, lineno, len(fields)
                        )
                    )
                self.samples.append(fields)
        logging.info("total samples: %d", len(self.samples))

    def __len__(self):
        """Return the number of samples in the manifest."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(key, speech, speaker_label, order)``:
            ``speech`` is the matrix loaded from the speech path;
            ``speaker_label`` is the label matrix reshaped to
            ``(speech.shape[0], -1)``, one row per row of ``speech``;
            ``order`` is a random permutation of ``range(speech.shape[0])``.
        """
        key, speech_path, speaker_label_path = self.samples[idx]
        speech = kaldiio.load_mat(speech_path)
        speaker_label = kaldiio.load_mat(speaker_label_path).reshape(
            speech.shape[0], -1
        )
        # Drawn from numpy's global RNG, so a new permutation is produced on
        # every fetch.
        order = np.arange(speech.shape[0])
        np.random.shuffle(order)
        return key, speech, speaker_label, order
class EENDOLADataLoader:
    """Build a torch ``DataLoader`` over an EEND-OLA manifest.

    Args:
        data_file: Manifest path handed to ``EENDOLADataset``.
        batch_size: Number of samples per batch.
        shuffle: Forwarded to ``DataLoader`` (reshuffle sample order each pass).
        num_workers: Number of ``DataLoader`` worker processes.
    """

    def __init__(self, data_file, batch_size, shuffle=True, num_workers=8):
        loader_options = {
            "batch_size": batch_size,
            "collate_fn": custom_collate,
            "shuffle": shuffle,
            "num_workers": num_workers,
        }
        self.data_loader = DataLoader(EENDOLADataset(data_file), **loader_options)

    def build_iter(self, epoch):
        """Return the loader to iterate for ``epoch``.

        ``epoch`` is currently ignored; the same ``DataLoader`` object is
        returned on every call.
        """
        return self.data_loader