import json
from typing import List, Optional

import numpy as np
import torch

from soulxsinger.utils.audio_utils import load_wav
class DataProcessor:
    """Data processor for SoulX-Singer: converts note/phoneme metadata into
    model-ready tensors (token ids, pitch, type, and frame-level alignment).
    """

    def __init__(
            self,
            hop_size: int,
            sample_rate: int,
            phoneset_path: str = 'soulxsinger/utils/phoneme/phone_set.json',
            device: str = 'cuda',
            prompt_append_duration: float = 0.5):
        """Initialize data processor.

        Args:
            hop_size (int): Hop size in samples.
            sample_rate (int): Sample rate in Hz.
            phoneset_path (str): Path to phoneme set JSON file.
            device (str): Device to use for tensor operations.
            prompt_append_duration (float): Duration to append to prompt in seconds.
        """
        self.hop_size = hop_size
        self.sample_rate = sample_rate
        self.device = device
        self.prompt_append_duration = prompt_append_duration
        # Seconds -> number of mel frames (one frame per hop_size samples).
        self.prompt_append_length = int(prompt_append_duration * sample_rate / hop_size)
        self.load_phoneme_id_map(phoneset_path)
    def load_phoneme_id_map(self, phoneset_path: str):
        """Load the phoneme list from JSON and build a phoneme -> index map."""
        with open(phoneset_path, "r", encoding='utf-8') as f:
            phoneset = json.load(f)
        self.phone2idx = {ph: idx for idx, ph in enumerate(phoneset)}
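    # Illustration (ordering here is hypothetical; the real phone_set.json
    # fixes the actual indices): a phone set ["<PAD>", "<BOW>", "<EOW>", ...]
    # yields phone2idx == {"<PAD>": 0, "<BOW>": 1, "<EOW>": 2, ...}.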
    def merge_phoneme(self, meta):
        """Parse the space-separated metadata strings into lists and merge
        consecutive identical silence (<SP>) items into a single item.
        See the worked sketch after this method.
        """
        merged_items = []
        duration = [float(x) for x in meta["duration"].split()]
        # Treat aspiration (<AP>) as silence (<SP>).
        phoneme = [str(x).replace("<AP>", "<SP>") for x in meta["phoneme"].split()]
        note_pitch = [int(x) for x in meta["note_pitch"].split()]
        # Silence always gets note type 1.
        note_type = [int(x) if phoneme[i] != "<SP>" else 1
                     for i, x in enumerate(meta["note_type"].split())]
        for i in range(len(phoneme)):
            if (i > 0 and phoneme[i] == phoneme[i - 1] == "<SP>"
                    and note_type[i] == note_type[i - 1]
                    and note_pitch[i] == note_pitch[i - 1]):
                # Extend the previous silence instead of adding a new item.
                merged_items[-1][1] += duration[i]
            else:
                merged_items.append([phoneme[i], duration[i], note_pitch[i], note_type[i]])
        meta['phoneme'] = [x[0] for x in merged_items]
        meta['duration'] = [x[1] for x in merged_items]
        meta['note_pitch'] = [x[2] for x in merged_items]
        meta['note_type'] = [x[3] for x in merged_items]
        return meta
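    # Sketch of the merge (illustrative values, not from a real metadata file):
    # two adjacent silences with identical pitch collapse into one item, so
    #   phoneme: "a <SP> <SP> b"      duration: "0.3 0.2 0.2 0.3"
    # becomes
    #   phoneme: ["a", "<SP>", "b"]   duration: [0.3, 0.4, 0.3]
    # with note_pitch/note_type merged the same way.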
    def preprocess(
            self,
            note_duration: List[float],
            phonemes: List[str],
            note_pitch: List[int],
            note_type: List[int],
    ):
        """
        Insert <BOW> and <EOW> tokens around each note's phonemes and build
        the frame-to-token alignment (mel2note).

        Args:
            note_duration: Duration of each note in seconds.
            phonemes: Phoneme sequence for each note.
            note_pitch: Pitch value for each note.
            note_type: Type value for each note.

        Returns:
            Dict of batched tensors: phoneme ids, note pitch, note type,
            and the mel2note alignment.
        """
        sample_rate = self.sample_rate
        hop_size = self.hop_size
        duration = sum(note_duration) * sample_rate / hop_size
        mel2note = torch.zeros(int(duration), dtype=torch.long)
        ph_locations = []  # (start frame, phoneme count) per note
        new_phonemes = []
        dur_sum = 0
        note2origin = []  # maps each new token back to its source note
        for ph_idx in range(len(phonemes)):
            start = int(np.round(dur_sum * sample_rate / hop_size))
            start = min(start, len(mel2note) - 1)
            new_phonemes.append("<BOW>")
            note2origin.append(ph_idx)
            if phonemes[ph_idx][:3] == "en_":
                # An English note may hold several hyphen-joined phonemes;
                # <SEP> separates English words within one note.
                en_phs = ['en_' + x for x in phonemes[ph_idx][3:].split('-')] + ['<SEP>']
                ph_locations.append([start, max(1, len(en_phs))])
                new_phonemes.extend(en_phs)
                note2origin.extend([ph_idx] * len(en_phs))
            else:
                ph_locations.append([start, 1])
                new_phonemes.append(phonemes[ph_idx])
                note2origin.append(ph_idx)
            new_phonemes.append("<EOW>")
            note2origin.append(ph_idx)
            dur_sum += note_duration[ph_idx]

        # Fill mel2note; token index 0 is reserved for the <PAD> prepended below.
        token_idx = 1
        for idx, (i, j) in enumerate(ph_locations):
            next_phoneme_start = ph_locations[idx + 1][0] if idx < len(ph_locations) - 1 else len(mel2note)
            if i >= len(mel2note) or i + j > len(mel2note):
                break
            # If a previous note already claimed this frame, skip past it.
            while i < len(mel2note) and mel2note[i] > 0:
                i += 1
            if i >= len(mel2note):
                break
            mel2note[i] = token_idx  # <BOW>
            k = i + 1
            while k + j < next_phoneme_start:
                # Repeat the note's phoneme tokens across its frames.
                mel2note[k:k + j] = torch.arange(token_idx, token_idx + j) + 1
                k += j
            mel2note[next_phoneme_start - 1] = token_idx + j + 1  # <EOW>
            token_idx += j + 2  # <BOW> + phoneme repeats + <EOW>

        new_phonemes = ["<PAD>"] + new_phonemes
        new_note_pitch = [0] + [note_pitch[k] for k in note2origin]
        new_note_type = [1] + [note_type[k] for k in note2origin]
        return {
            "phoneme": torch.tensor([self.phone2idx[x] for x in new_phonemes], device=self.device).unsqueeze(0),
            "note_pitch": torch.tensor(new_note_pitch, device=self.device).unsqueeze(0),
            "note_type": torch.tensor(new_note_type, device=self.device).unsqueeze(0),
            "mel2note": mel2note.to(self.device).unsqueeze(0),
        }
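    # Worked sketch of the alignment for a single 1 s note at 24 kHz / hop 480
    # (50 frames; phoneme "a" is illustrative): the token sequence becomes
    #   ["<PAD>", "<BOW>", "a", "<EOW>"]  -> indices [0, 1, 2, 3]
    # and mel2note is
    #   [1, 2, 2, ..., 2, 3]
    # i.e. frame 0 points at <BOW>, the last frame at <EOW>, and every frame
    # in between repeats the phoneme token.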
    def process(
            self,
            meta: dict,
            wav_path: Optional[str] = None
    ):
        """Run the full pipeline on one metadata dict: merge phonemes, build
        token/alignment tensors, then attach f0 and (optionally) the waveform,
        truncated to a common frame count.
        """
        meta = self.merge_phoneme(meta)
        item = self.preprocess(
            meta["duration"],
            meta["phoneme"],
            meta["note_pitch"],
            meta["note_type"],
        )
        f0 = torch.tensor([float(x) for x in meta["f0"].split()])
        # Alignment and f0 may disagree by a frame or two; keep the overlap.
        min_frame = min(item["mel2note"].shape[1], f0.shape[0])
        item['f0'] = f0[:min_frame].unsqueeze(0).float().to(self.device)
        item["mel2note"] = item["mel2note"][:, :min_frame]
        if wav_path is not None:
            waveform = load_wav(wav_path, self.sample_rate)
            item["waveform"] = waveform.to(self.device)[:, :min_frame * self.hop_size]
        return item
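# For reference, process() returns a dict of batched tensors (batch size 1):
#   phoneme    (1, n_tokens)  phoneme indices, <PAD> first
#   note_pitch (1, n_tokens)  note pitch per token, 0 for <PAD>
#   note_type  (1, n_tokens)  note type per token, 1 for <PAD>
#   mel2note   (1, n_frames)  frame -> token index alignment
#   f0         (1, n_frames)  frame-level f0 from the metadata
#   waveform   (channels, up to n_frames * hop_size), only when wav_path is given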
# Smoke test: process the bundled Chinese prompt example.
if __name__ == "__main__":
    with open("example/metadata/zh_prompt.json", "r", encoding="utf-8") as f:
        meta = json.load(f)
    if isinstance(meta, list):
        meta = meta[0]
    processor = DataProcessor(hop_size=480, sample_rate=24000)
    item = processor.process(meta, "example/audio/zh_prompt.wav")
    print(item.keys())
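
    # A minimal synthetic run without an audio file (illustrative values;
    # assumes "<SP>" is present in phone_set.json, which the merge logic
    # implies). All fields are the space-separated strings that
    # merge_phoneme()/process() parse.
    synthetic_meta = {
        "phoneme": "<SP> <SP>",
        "duration": "0.5 0.5",
        "note_pitch": "0 0",
        "note_type": "1 1",
        "f0": " ".join(["0.0"] * 50),  # 1.0 s at 24 kHz with hop 480 -> 50 frames
    }
    item = processor.process(synthetic_meta)
    print({k: tuple(v.shape) for k, v in item.items()})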