#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """ | |
| This script generates speech with our pre-trained ZipVoice-Dialog or | |
| ZipVoice-Dialog-Stereo models. If no local model is specified, | |
| Required files will be automatically downloaded from HuggingFace. | |
| Usage: | |
| Note: If you having trouble connecting to HuggingFace, | |
| try switching endpoint to mirror site: | |
| export HF_ENDPOINT=https://hf-mirror.com | |
| python3 -m zipvoice.bin.infer_zipvoice_dialog \ | |
| --model-name "zipvoice_dialog" \ | |
| --test-list test.tsv \ | |
| --res-dir results | |
| `--model-name` can be `zipvoice_dialog` or `zipvoice_dialog_stereo`, | |
| which generate mono and stereo dialogues, respectively. | |
| Each line of `test.tsv` is in the format of merged conversation: | |
| '{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}' | |
| or splited conversation: | |
| '{wav_name}\t{spk1_prompt_transcription}\t{spk2_prompt_transcription} | |
| \t{spk1_prompt_wav}\t{spk2_prompt_wav}\t{text}' | |
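
For example (hypothetical paths and transcripts, shown only to illustrate
the two layouts; fields are separated by actual tab characters):

'dialog_1\t[S1] Hi there. [S2] Hello!\tprompts/merged_prompt.wav\t[S1] How are you? [S2] Fine, thanks.'
'dialog_2\tHi there.\tHello!\tprompts/spk1.wav\tprompts/spk2.wav\t[S1] How are you? [S2] Fine, thanks.'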
| """ | |
import argparse
import datetime as dt
import json
import os
from typing import List, Optional, Union

import numpy as np
import safetensors.torch
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from lhotse.utils import fix_random_seed
from vocos import Vocos

from zipvoice.models.zipvoice_dialog import ZipVoiceDialog, ZipVoiceDialogStereo
from zipvoice.tokenizer.tokenizer import DialogTokenizer
from zipvoice.utils.checkpoint import load_checkpoint
from zipvoice.utils.common import AttributeDict
from zipvoice.utils.feature import VocosFbank

HUGGINGFACE_REPO = "k2-fsa/ZipVoice"

PRETRAINED_MODEL = {
    "zipvoice_dialog": "zipvoice_dialog/model.pt",
    "zipvoice_dialog_stereo": "zipvoice_dialog_stereo/model.pt",
}

TOKEN_FILE = {
    "zipvoice_dialog": "zipvoice_dialog/tokens.txt",
    "zipvoice_dialog_stereo": "zipvoice_dialog_stereo/tokens.txt",
}

MODEL_CONFIG = {
    "zipvoice_dialog": "zipvoice_dialog/model.json",
    "zipvoice_dialog_stereo": "zipvoice_dialog_stereo/model.json",
}


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="zipvoice_dialog",
        choices=["zipvoice_dialog", "zipvoice_dialog_stereo"],
        help="The model used for inference.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="The model checkpoint. "
        "Will download the pre-trained checkpoint from HuggingFace if not specified.",
    )
    parser.add_argument(
        "--model-config",
        type=str,
        default=None,
        help="The model configuration file. "
        "Will download model.json from HuggingFace if not specified.",
    )
    parser.add_argument(
        "--vocoder-path",
        type=str,
        default=None,
        help="The vocoder checkpoint. "
        "Will download the pre-trained vocoder from HuggingFace if not specified.",
    )
    parser.add_argument(
        "--token-file",
        type=str,
        default=None,
        help="The file that maps tokens to ids, "
        "a text file with one '{token}\t{token_id}' pair per line. "
        "Will download tokens.txt from HuggingFace if not specified.",
    )
    parser.add_argument(
        "--test-list",
        type=str,
        default=None,
        help="The list of prompt speech, prompt transcription, "
        "and text to synthesize, in the format of a merged conversation: "
        "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}' "
        "or a split conversation: "
        "'{wav_name}\t{spk1_prompt_transcription}\t{spk2_prompt_transcription}"
        "\t{spk1_prompt_wav}\t{spk2_prompt_wav}\t{text}'.",
    )
    parser.add_argument(
        "--res-dir",
        type=str,
        default="results",
        help="Path of the directory where generated wavs are saved, "
        "used when test-list is not None.",
    )
    parser.add_argument(
        "--guidance-scale",
        type=float,
        default=1.5,
        help="The scale of classifier-free guidance during inference.",
    )
    parser.add_argument(
        "--num-step",
        type=int,
        default=16,
        help="The number of sampling steps.",
    )
    parser.add_argument(
        "--feat-scale",
        type=float,
        default=0.1,
        help="The scale factor of the fbank feature.",
    )
    parser.add_argument(
        "--speed",
        type=float,
        default=1.0,
        help="Control the speech speed, 1.0 means normal, >1.0 means speed up.",
    )
    parser.add_argument(
        "--t-shift",
        type=float,
        default=0.5,
        help="Shift t to smaller values if t_shift < 1.0.",
    )
    parser.add_argument(
        "--target-rms",
        type=float,
        default=0.1,
        help="Target speech normalization RMS value, set to 0 to disable normalization.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=666,
        help="Random seed.",
    )
    parser.add_argument(
        "--silence-wav",
        type=str,
        default="assets/silence.wav",
        help="Path of the silence wav file, used in two-channel generation "
        "with single-channel prompts.",
    )
    return parser


def get_vocoder(vocos_local_path: Optional[str] = None):
    if vocos_local_path:
        vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
        state_dict = torch.load(
            f"{vocos_local_path}/pytorch_model.bin",
            weights_only=True,
            map_location="cpu",
        )
        vocoder.load_state_dict(state_dict)
    else:
        vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
    return vocoder


def generate_sentence(
    save_path: str,
    prompt_text: str,
    prompt_wav: Union[str, List[str]],
    text: str,
    model: torch.nn.Module,
    vocoder: torch.nn.Module,
    tokenizer: DialogTokenizer,
    feature_extractor: VocosFbank,
    device: torch.device,
    num_step: int = 16,
    guidance_scale: float = 1.0,
    speed: float = 1.0,
    t_shift: float = 0.5,
    target_rms: float = 0.1,
    feat_scale: float = 0.1,
    sampling_rate: int = 24000,
):
    """
    Generate the waveform of a text based on a given prompt
    waveform and its transcription.

    Args:
        save_path (str): Path to save the generated wav.
        prompt_text (str): Transcription of the prompt wav.
        prompt_wav (Union[str, List[str]]): Path to the prompt wav file; can be
            one or two wav files, corresponding to a merged conversation
            or two separate speakers' speech, respectively.
        text (str): Text to be synthesized into a waveform.
        model (torch.nn.Module): The model used for generation.
        vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
        tokenizer (DialogTokenizer): The tokenizer used to convert text to tokens.
        feature_extractor (VocosFbank): The feature extractor used to
            extract acoustic features.
        device (torch.device): The device on which computations are performed.
        num_step (int, optional): Number of steps for decoding. Defaults to 16.
        guidance_scale (float, optional): Scale for classifier-free guidance.
            Defaults to 1.0.
        speed (float, optional): Speed control. Defaults to 1.0.
        t_shift (float, optional): Time shift. Defaults to 0.5.
        target_rms (float, optional): Target RMS for waveform normalization.
            Defaults to 0.1.
        feat_scale (float, optional): Scale for features.
            Defaults to 0.1.
        sampling_rate (int, optional): Sampling rate for the waveform.
            Defaults to 24000.

    Returns:
        metrics (dict): Dictionary containing time and real-time
            factor metrics for processing.
    """
    # Convert text to tokens
    tokens = tokenizer.texts_to_token_ids([text])
    prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])

    # Load and preprocess prompt wav
    if isinstance(prompt_wav, str):
        prompt_wav = [prompt_wav]
    else:
        assert len(prompt_wav) == 2 and isinstance(prompt_wav[0], str)
    loaded_prompt_wavs = prompt_wav
    for i in range(len(prompt_wav)):
        loaded_prompt_wavs[i], prompt_sampling_rate = torchaudio.load(prompt_wav[i])
        if prompt_sampling_rate != sampling_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=prompt_sampling_rate, new_freq=sampling_rate
            )
            loaded_prompt_wavs[i] = resampler(loaded_prompt_wavs[i])
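    # A single prompt wav is already a merged conversation; two separate
    # speaker prompts are concatenated along the time axis to form one.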
    if len(loaded_prompt_wavs) == 1:
        prompt_wav = loaded_prompt_wavs[0]
    else:
        prompt_wav = torch.cat(loaded_prompt_wavs, dim=1)
    prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
    if prompt_rms < target_rms:
        prompt_wav = prompt_wav * target_rms / prompt_rms

    # Extract features from prompt wav
    prompt_features = feature_extractor.extract(
        prompt_wav, sampling_rate=sampling_rate
    ).to(device)
    prompt_features = prompt_features.unsqueeze(0) * feat_scale
    prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)

    # Start timing
    start_t = dt.datetime.now()

    # Generate features
    (
        pred_features,
        pred_features_lens,
        pred_prompt_features,
        pred_prompt_features_lens,
    ) = model.sample(
        tokens=tokens,
        prompt_tokens=prompt_tokens,
        prompt_features=prompt_features,
        prompt_features_lens=prompt_features_lens,
        speed=speed,
        t_shift=t_shift,
        duration="predict",
        num_step=num_step,
        guidance_scale=guidance_scale,
    )

    # Postprocess predicted features
    pred_features = pred_features.permute(0, 2, 1) / feat_scale  # (B, C, T)

    # Start vocoder processing
    start_vocoder_t = dt.datetime.now()
    wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)

    # Calculate processing times and real-time factors
    t = (dt.datetime.now() - start_t).total_seconds()
    t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
    t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
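    # Real-time factor (RTF) is processing time divided by the duration of
    # the generated audio; values below 1.0 mean faster than real time.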
    wav_seconds = wav.shape[-1] / sampling_rate
    rtf = t / wav_seconds
    rtf_no_vocoder = t_no_vocoder / wav_seconds
    rtf_vocoder = t_vocoder / wav_seconds
    metrics = {
        "t": t,
        "t_no_vocoder": t_no_vocoder,
        "t_vocoder": t_vocoder,
        "wav_seconds": wav_seconds,
        "rtf": rtf,
        "rtf_no_vocoder": rtf_no_vocoder,
        "rtf_vocoder": rtf_vocoder,
    }

    # Adjust wav volume if necessary
    if prompt_rms < target_rms:
        wav = wav * prompt_rms / target_rms
    torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
    return metrics


def generate_sentence_stereo(
    save_path: str,
    prompt_text: str,
    prompt_wav: Union[str, List[str]],
    text: str,
    model: torch.nn.Module,
    vocoder: torch.nn.Module,
    tokenizer: DialogTokenizer,
    feature_extractor: VocosFbank,
    device: torch.device,
    num_step: int = 16,
    guidance_scale: float = 1.0,
    speed: float = 1.0,
    t_shift: float = 0.5,
    target_rms: float = 0.1,
    feat_scale: float = 0.1,
    sampling_rate: int = 24000,
    silence_wav: Optional[str] = None,
):
    """
    Generate the stereo waveform of a text based on a given prompt
    waveform and its transcription.

    Args:
        save_path (str): Path to save the generated wav.
        prompt_text (str): Transcription of the prompt wav.
        prompt_wav (Union[str, List[str]]): Path to the prompt wav file; can be
            one or two wav files, corresponding to a merged conversation
            or two separate speakers' speech, respectively.
        text (str): Text to be synthesized into a waveform.
        model (torch.nn.Module): The model used for generation.
        vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
        tokenizer (DialogTokenizer): The tokenizer used to convert text to tokens.
        feature_extractor (VocosFbank): The feature extractor used to
            extract acoustic features.
        device (torch.device): The device on which computations are performed.
        num_step (int, optional): Number of steps for decoding. Defaults to 16.
        guidance_scale (float, optional): Scale for classifier-free guidance.
            Defaults to 1.0.
        speed (float, optional): Speed control. Defaults to 1.0.
        t_shift (float, optional): Time shift. Defaults to 0.5.
        target_rms (float, optional): Target RMS for waveform normalization.
            Defaults to 0.1.
        feat_scale (float, optional): Scale for features.
            Defaults to 0.1.
        sampling_rate (int, optional): Sampling rate for the waveform.
            Defaults to 24000.
        silence_wav (str, optional): Path of the silence wav file, used in
            two-channel generation with single-channel prompts.

    Returns:
        metrics (dict): Dictionary containing time and real-time
            factor metrics for processing.
    """
    # Convert text to tokens
    tokens = tokenizer.texts_to_token_ids([text])
    prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])

    # Load and preprocess prompt wav
    if isinstance(prompt_wav, str):
        prompt_wav = [prompt_wav]
    else:
        assert len(prompt_wav) == 2 and isinstance(prompt_wav[0], str)
    loaded_prompt_wavs = prompt_wav
    for i in range(len(prompt_wav)):
        loaded_prompt_wavs[i], prompt_sampling_rate = torchaudio.load(prompt_wav[i])
        if prompt_sampling_rate != sampling_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=prompt_sampling_rate, new_freq=sampling_rate
            )
            loaded_prompt_wavs[i] = resampler(loaded_prompt_wavs[i])
    if len(loaded_prompt_wavs) == 1:
        assert (
            loaded_prompt_wavs[0].size(0) == 2
        ), "Merged prompt wav must be stereo for stereo dialogue generation"
        prompt_wav = loaded_prompt_wavs[0]
    else:
        assert len(loaded_prompt_wavs) == 2
        if loaded_prompt_wavs[0].size(0) == 2:
            prompt_wav = torch.cat(loaded_prompt_wavs, dim=1)
        else:
            assert loaded_prompt_wavs[0].size(0) == 1
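            # Two mono prompts: build a pseudo-stereo prompt by placing
            # speaker 1 at the start of channel 0 and speaker 2 right after
            # on channel 1, with the pre-recorded silence wav filling the
            # remainder of each channel.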
            silence_wav, silence_sampling_rate = torchaudio.load(silence_wav)
            assert silence_sampling_rate == sampling_rate
            prompt_wav = silence_wav[
                :, : loaded_prompt_wavs[0].size(1) + loaded_prompt_wavs[1].size(1)
            ]
            prompt_wav[0, : loaded_prompt_wavs[0].size(1)] = loaded_prompt_wavs[0]
            prompt_wav[1, loaded_prompt_wavs[0].size(1) :] = loaded_prompt_wavs[1]
    prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
    if prompt_rms < target_rms:
        prompt_wav = prompt_wav * target_rms / prompt_rms

    # Extract features from prompt wav
    prompt_features = feature_extractor.extract(
        prompt_wav, sampling_rate=sampling_rate
    ).to(device)
    prompt_features = prompt_features.unsqueeze(0) * feat_scale
    prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)

    # Start timing
    start_t = dt.datetime.now()

    # Generate features
    (
        pred_features,
        pred_features_lens,
        pred_prompt_features,
        pred_prompt_features_lens,
    ) = model.sample(
        tokens=tokens,
        prompt_tokens=prompt_tokens,
        prompt_features=prompt_features,
        prompt_features_lens=prompt_features_lens,
        speed=speed,
        t_shift=t_shift,
        duration="predict",
        num_step=num_step,
        guidance_scale=guidance_scale,
    )

    # Postprocess predicted features
    pred_features = pred_features.permute(0, 2, 1) / feat_scale  # (B, C, T)

    # Start vocoder processing
    start_vocoder_t = dt.datetime.now()
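    # The stereo model stacks both channels' fbank features along the feature
    # dimension: the first half decodes to one output channel and the second
    # half to the other.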
    feat_dim = pred_features.size(1) // 2
    wav_left = vocoder.decode(pred_features[:, :feat_dim]).squeeze(1).clamp(-1, 1)
    wav_right = (
        vocoder.decode(pred_features[:, feat_dim : feat_dim * 2])
        .squeeze(1)
        .clamp(-1, 1)
    )
    wav = torch.cat([wav_left, wav_right], dim=0)

    # Calculate processing times and real-time factors
    t = (dt.datetime.now() - start_t).total_seconds()
    t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
    t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
    wav_seconds = wav.shape[-1] / sampling_rate
    rtf = t / wav_seconds
    rtf_no_vocoder = t_no_vocoder / wav_seconds
    rtf_vocoder = t_vocoder / wav_seconds
    metrics = {
        "t": t,
        "t_no_vocoder": t_no_vocoder,
        "t_vocoder": t_vocoder,
        "wav_seconds": wav_seconds,
        "rtf": rtf,
        "rtf_no_vocoder": rtf_no_vocoder,
        "rtf_vocoder": rtf_vocoder,
    }

    # Adjust wav volume if necessary
    if prompt_rms < target_rms:
        wav = wav * prompt_rms / target_rms
    torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
    return metrics


def generate_list(
    model_name: str,
    res_dir: str,
    test_list: str,
    model: torch.nn.Module,
    vocoder: torch.nn.Module,
    tokenizer: DialogTokenizer,
    feature_extractor: VocosFbank,
    device: torch.device,
    num_step: int = 16,
    guidance_scale: float = 1.5,
    speed: float = 1.0,
    t_shift: float = 0.5,
    target_rms: float = 0.1,
    feat_scale: float = 0.1,
    sampling_rate: int = 24000,
    silence_wav: Optional[str] = None,
):
    total_t = []
    total_t_no_vocoder = []
    total_t_vocoder = []
    total_wav_seconds = []
    with open(test_list, "r") as fr:
        lines = fr.readlines()
    for i, line in enumerate(lines):
        items = line.strip().split("\t")
        if len(items) == 6:
            (
                wav_name,
                prompt_text_1,
                prompt_text_2,
                prompt_wav_1,
                prompt_wav_2,
                text,
            ) = items
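            # Merge the two speakers' prompt transcriptions into a single
            # tagged transcription; [S1]/[S2] mark the speaker turns.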
| prompt_text = f"[S1]{prompt_text_1}[S2]{prompt_text_2}" | |
| prompt_wav = [prompt_wav_1, prompt_wav_2] | |
| elif len(items) == 4: | |
| wav_name, prompt_text, prompt_wav, text = items | |
| else: | |
| raise ValueError(f"Invalid line: {line}") | |
| assert text.startswith("[S1]") | |
| save_path = f"{res_dir}/{wav_name}.wav" | |
| if model_name == "zipvoice_dialog": | |
| metrics = generate_sentence( | |
| save_path=save_path, | |
| prompt_text=prompt_text, | |
| prompt_wav=prompt_wav, | |
| text=text, | |
| model=model, | |
| vocoder=vocoder, | |
| tokenizer=tokenizer, | |
| feature_extractor=feature_extractor, | |
| device=device, | |
| num_step=num_step, | |
| guidance_scale=guidance_scale, | |
| speed=speed, | |
| t_shift=t_shift, | |
| target_rms=target_rms, | |
| feat_scale=feat_scale, | |
| sampling_rate=sampling_rate, | |
| ) | |
| else: | |
| assert model_name == "zipvoice_dialog_stereo" | |
| metrics = generate_sentence_stereo( | |
| save_path=save_path, | |
| prompt_text=prompt_text, | |
| prompt_wav=prompt_wav, | |
| text=text, | |
| model=model, | |
| vocoder=vocoder, | |
| tokenizer=tokenizer, | |
| feature_extractor=feature_extractor, | |
| device=device, | |
| num_step=num_step, | |
| guidance_scale=guidance_scale, | |
| speed=speed, | |
| t_shift=t_shift, | |
| target_rms=target_rms, | |
| feat_scale=feat_scale, | |
| sampling_rate=sampling_rate, | |
| silence_wav=silence_wav, | |
| ) | |
| print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}") | |
| total_t.append(metrics["t"]) | |
| total_t_no_vocoder.append(metrics["t_no_vocoder"]) | |
| total_t_vocoder.append(metrics["t_vocoder"]) | |
| total_wav_seconds.append(metrics["wav_seconds"]) | |
| print(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}") | |
| print( | |
| f"Average RTF w/o vocoder: " | |
| f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}" | |
| ) | |
| print( | |
| f"Average RTF vocoder: " | |
| f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}" | |
| ) | |


def main():
    parser = get_parser()
    args = parser.parse_args()
    params = AttributeDict()
    params.update(vars(args))
    fix_random_seed(params.seed)

    assert (
        params.test_list is not None
    ), "For inference, please provide prompts and text with '--test-list'"

    if torch.cuda.is_available():
        params.device = torch.device("cuda", 0)
    elif torch.backends.mps.is_available():
        params.device = torch.device("mps")
    else:
        params.device = torch.device("cpu")

    print("Loading model...")

    if params.model_config is None:
        model_config = hf_hub_download(
            HUGGINGFACE_REPO, filename=MODEL_CONFIG[params.model_name]
        )
    else:
        model_config = params.model_config
    with open(model_config, "r") as f:
        model_config = json.load(f)

    if params.token_file is None:
        token_file = hf_hub_download(
            HUGGINGFACE_REPO, filename=TOKEN_FILE[params.model_name]
        )
    else:
        token_file = params.token_file
    tokenizer = DialogTokenizer(token_file=token_file)
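    # Pass the tokenizer's special token ids (padding and the two speaker
    # tags) to the model together with the architecture config.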
    tokenizer_config = {
        "vocab_size": tokenizer.vocab_size,
        "pad_id": tokenizer.pad_id,
        "spk_a_id": tokenizer.spk_a_id,
        "spk_b_id": tokenizer.spk_b_id,
    }

    if params.checkpoint is None:
        model_ckpt = hf_hub_download(
            HUGGINGFACE_REPO,
            filename=PRETRAINED_MODEL[params.model_name],
        )
    else:
        model_ckpt = params.checkpoint

    if params.model_name == "zipvoice_dialog":
        model = ZipVoiceDialog(
            **model_config["model"],
            **tokenizer_config,
        )
    else:
        assert params.model_name == "zipvoice_dialog_stereo"
        model = ZipVoiceDialogStereo(
            **model_config["model"],
            **tokenizer_config,
        )

    if model_ckpt.endswith(".safetensors"):
        safetensors.torch.load_model(model, model_ckpt)
    elif model_ckpt.endswith(".pt"):
        load_checkpoint(filename=model_ckpt, model=model, strict=True)
    else:
        raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")

    model = model.to(params.device)
    model.eval()

    vocoder = get_vocoder(params.vocoder_path)
    vocoder = vocoder.to(params.device)
    vocoder.eval()

    if model_config["feature"]["type"] == "vocos":
        if params.model_name == "zipvoice_dialog":
            num_channels = 1
        else:
            assert params.model_name == "zipvoice_dialog_stereo"
            num_channels = 2
        feature_extractor = VocosFbank(num_channels=num_channels)
    else:
        raise NotImplementedError(
            f"Unsupported feature type: {model_config['feature']['type']}"
        )
    params.sampling_rate = model_config["feature"]["sampling_rate"]

    print("Start generating...")
    os.makedirs(params.res_dir, exist_ok=True)
    generate_list(
        model_name=params.model_name,
        res_dir=params.res_dir,
        test_list=params.test_list,
        model=model,
        vocoder=vocoder,
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        device=params.device,
        num_step=params.num_step,
        guidance_scale=params.guidance_scale,
        speed=params.speed,
        t_shift=params.t_shift,
        target_rms=params.target_rms,
        feat_scale=params.feat_scale,
        sampling_rate=params.sampling_rate,
        silence_wav=params.silence_wav,
    )
    print("Done")


if __name__ == "__main__":
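    # Limit PyTorch to a single thread, presumably to keep CPU inference
    # timing (RTF) measurements stable and reproducible.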
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    main()