import torch

from funasr_detach.register import tables
from funasr_detach.utils.load_utils import extract_fbank, load_audio_text_image_video


class AudioDataset(torch.utils.data.Dataset):
    """Map-style dataset: loads audio, extracts fbank features through the
    frontend, and tokenizes the paired transcript."""

    def __init__(
        self,
        path,
        index_ds: str = None,
        frontend=None,
        tokenizer=None,
        int_pad_value: int = -1,
        float_pad_value: float = 0.0,
        **kwargs
    ):
        super().__init__()
        # Resolve the index dataset (manifest reader) by name from the registry.
        index_ds_class = tables.index_ds_classes.get(index_ds)
        self.index_ds = index_ds_class(path, **kwargs)

        # Optional speech preprocessor, looked up by name; default the config
        # to an empty dict so a missing *_conf entry does not raise TypeError.
        preprocessor_speech = kwargs.get("preprocessor_speech", None)
        if preprocessor_speech:
            preprocessor_speech_class = tables.preprocessor_classes.get(
                preprocessor_speech
            )
            preprocessor_speech = preprocessor_speech_class(
                **kwargs.get("preprocessor_speech_conf", {})
            )
        self.preprocessor_speech = preprocessor_speech

        # Optional text preprocessor, resolved the same way.
        preprocessor_text = kwargs.get("preprocessor_text", None)
        if preprocessor_text:
            preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
            preprocessor_text = preprocessor_text_class(
                **kwargs.get("preprocessor_text_conf", {})
            )
        self.preprocessor_text = preprocessor_text

        self.frontend = frontend
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
        self.tokenizer = tokenizer
        self.int_pad_value = int_pad_value
        self.float_pad_value = float_pad_value

    def get_source_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_source_len(item)

    def get_target_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_target_len(item)

    def __len__(self):
        return len(self.index_ds)

    def __getitem__(self, index):
        item = self.index_ds[index]

        # Load the raw audio and resample it to the frontend sample rate.
        source = item["source"]
        data_src = load_audio_text_image_video(source, fs=self.fs)
        if self.preprocessor_speech:
            data_src = self.preprocessor_speech(data_src, fs=self.fs)
        speech, speech_lengths = extract_fbank(
            data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
        )  # speech: [b, T, d]

        # Tokenize the transcript when a tokenizer is available; otherwise the
        # (optionally preprocessed) target passes through unchanged.
        target = item["target"]
        if self.preprocessor_text:
            target = self.preprocessor_text(target)
        if self.tokenizer:
            ids = self.tokenizer.encode(target)
            text = torch.tensor(ids, dtype=torch.int64)
        else:
            ids = target
            text = ids
        text_lengths = torch.tensor([len(ids)], dtype=torch.int32)

        return {
            "speech": speech[0, :, :],
            "speech_lengths": speech_lengths,
            "text": text,
            "text_lengths": text_lengths,
        }

    def collator(self, samples: list = None):
        # Group each field across the batch into a list.
        outputs = {}
        for sample in samples:
            for key in sample.keys():
                if key not in outputs:
                    outputs[key] = []
                outputs[key].append(sample[key])

        # Pad tensor fields to the batch maximum: int64 tensors (token ids)
        # use int_pad_value, float tensors (features) use float_pad_value.
        # Non-tensor fields are left as plain lists.
        for key, data_list in outputs.items():
            if isinstance(data_list[0], torch.Tensor):
                if data_list[0].dtype == torch.int64:
                    pad_value = self.int_pad_value
                else:
                    pad_value = self.float_pad_value
                outputs[key] = torch.nn.utils.rnn.pad_sequence(
                    data_list, batch_first=True, padding_value=pad_value
                )
        return outputs
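
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It shows how this
# dataset is typically wired into a torch DataLoader, with `collator` passed
# as collate_fn. The manifest path and the "jsonl" index_ds registry key are
# illustrative assumptions, not names defined in this file; frontend and
# tokenizer are left as None here for brevity, in which case audio is loaded
# at 16 kHz and the raw target text passes through untokenized. In practice
# a frontend built from the model config and a matching tokenizer would be
# supplied.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = AudioDataset(
        "data/train.jsonl",  # hypothetical manifest path
        index_ds="jsonl",  # hypothetical registry key for an index-DS class
        frontend=None,  # None -> dataset defaults to fs=16000
        tokenizer=None,  # None -> "text" stays untokenized
    )
    loader = DataLoader(
        dataset,
        batch_size=8,
        shuffle=True,
        collate_fn=dataset.collator,  # pads int64 ids with -1, floats with 0.0
    )
    batch = next(iter(loader))
    print(batch["speech"].shape)  # [B, T_max, D], zero-padded fbank features
    print(batch["text_lengths"])  # per-utterance target lengths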