# NOTE: web-page residue ("Spaces: Running on Zero") removed from the top of
# this file; the content below is FunASR dataset code.
# Standard library
import concurrent.futures
import json
import logging
import os

# Third-party
import librosa
import torch
import torch.distributed as dist

# Local
from funasr_detach.register import tables
class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
    """Map-style dataset over a jsonl file, sharded evenly across ranks.

    Each rank keeps only its contiguous ``total // world_size`` slice of the
    samples; when torch.distributed is not initialized, a single shard
    (the full list) is used.
    """

    def __init__(self, path):
        """Load samples from *path* (one JSON object per line) and shard them.

        Lines containing a ``"text"`` key are kept as plain strings (sft
        data); lines containing a ``"source"`` key are kept as dicts
        (speech-lab pretrain data).
        """
        super().__init__()
        contents = []
        with open(path, encoding="utf-8") as fin:
            for line in fin:
                data = json.loads(line.strip())
                if "text" in data:  # for sft
                    # bugfix: original appended to self.contents, which was
                    # not assigned yet at this point (AttributeError).
                    contents.append(data["text"])
                if "source" in data:  # for speech lab pretrain
                    contents.append(
                        {
                            "source": data["source"],
                            "prompt": data["prompt"],
                            "target": data["target"],
                            "source_len": data["source_len"],
                            "target_len": data["target_len"],
                        }
                    )

        total_num = len(contents)
        # Explicit initialization check instead of the original bare
        # `except:`, which silently swallowed every exception type.
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        else:
            rank = 0
            world_size = 1
            logging.warning("distributed is not initialized, only single shard")
        num_per_rank = total_num // world_size
        # NOTE(review): the trailing total_num % world_size samples are
        # dropped — presumably intentional so every rank sees the same count.
        self.contents = contents[rank * num_per_rank : (rank + 1) * num_per_rank]
        logging.info(
            "in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(
                rank, len(self.contents), len(contents)
            )
        )

    def __len__(self):
        """Number of samples kept on this rank."""
        return len(self.contents)

    def __getitem__(self, index):
        """Return the sample at *index*.

        bugfix: the original caught every exception, printed the index and
        then returned the unbound name ``data`` (NameError).  A bad index is
        now logged and the IndexError propagates, as the Dataset protocol
        expects.
        """
        try:
            return self.contents[index]
        except IndexError:
            logging.error("index out of range: %s", index)
            raise

    def get_source_len(self, data_dict):
        """Source-sequence length (used for length-based batching)."""
        return data_dict["source_len"]

    def get_target_len(self, data_dict):
        """Target-sequence length, or 0 when the sample has none."""
        return data_dict["target_len"] if "target_len" in data_dict else 0
class IndexDSJsonlRankFull(torch.utils.data.Dataset):
    """Map-style dataset over a jsonl file; every rank sees all samples.

    Unlike ``IndexDSJsonlRankSplit``, no sharding is done here, and missing
    ``prompt`` / ``source_len`` / ``target_len`` fields fall back to defaults.
    """

    def __init__(self, path: str, **kwargs):
        """Load all samples from *path*.

        *path* may also be a ``[wav.scp, text.txt]`` pair, in which case a
        jsonl file is generated next to the scp first (reused if it already
        exists on disk).
        """
        super().__init__()
        if isinstance(path, (list, tuple)):  # wav.scp, text.txt/text.trans
            from funasr_detach.datasets.audio_datasets.scp2jsonl import (
                gen_jsonl_from_wav_text_list,
            )

            jsonl_outdir = os.path.dirname(path[0])
            jsonl_name = (
                "datalist_train.jsonl"
                if kwargs.get("is_training", True)
                else "datalist_val.jsonl"
            )
            jsonl_file_out = os.path.join(jsonl_outdir, jsonl_name)
            if not os.path.exists(jsonl_file_out):
                print(f"datalist is: {path}, generate jsonl from it")
                gen_jsonl_from_wav_text_list(
                    path, jsonl_file_out=jsonl_file_out, **kwargs
                )
            path = jsonl_file_out

        contents = []
        with open(path, encoding="utf-8") as fin:
            for line in fin:
                data = json.loads(line.strip())
                if "text" in data:  # for sft
                    # bugfix: original appended to self.contents before it
                    # was assigned, raising AttributeError on sft data.
                    contents.append(data["text"])
                if "source" in data:  # for speech lab pretrain
                    contents.append(
                        {
                            "source": data["source"],
                            "prompt": data.get("prompt", "<ASR>"),
                            "target": data["target"],
                            "source_len": data.get("source_len", 1),
                            "target_len": data.get("target_len", 0),
                        }
                    )
        self.contents = contents
        logging.info(
            "total_num of samplers across ranks: {}".format(len(self.contents))
        )

    def __len__(self):
        """Total number of samples (identical on every rank)."""
        return len(self.contents)

    def __getitem__(self, index):
        """Return the sample at *index*.

        bugfix: the original caught every exception, printed the index and
        then returned the unbound name ``data`` (NameError).  A bad index is
        now logged and the IndexError propagates, as the Dataset protocol
        expects.
        """
        try:
            return self.contents[index]
        except IndexError:
            logging.error("index out of range: %s", index)
            raise

    def get_source_len(self, data_dict):
        """Source-sequence length, defaulting to 1 when absent."""
        return data_dict.get("source_len", 1)

    def get_target_len(self, data_dict):
        """Target-sequence length, defaulting to 0 when absent."""
        return data_dict.get("target_len", 0)