# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Xiao Chen)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import librosa
import logging
import soundfile as sf
import sys
from pathlib import Path

# The bundled sub-packages are imported by bare module name, so their
# directories must be on sys.path before the imports below are executed.
sub_modules = ["", "semantic_tokenizer/f40ms", "semantic_detokenizer"]
for sub in sub_modules:
    sys.path.append(str((Path(__file__).parent / sub).absolute()))

from semantic_tokenizer.f40ms.simple_tokenizer_infer import SpeechTokenizer, TOKENIZER_CFG_NAME
from semantic_detokenizer.chunk_infer import SpeechDetokenizer


class ReconstructionPipeline:
    """Tokenize an input wav and re-synthesize it conditioned on a reference wav.

    The pipeline owns one ``SpeechTokenizer`` and one ``SpeechDetokenizer``.
    All path defaults are resolved relative to the directory of this file.
    """

    def __init__(
        self,
        detok_vocoder: str,
        tokenizer_cfg_name: str = TOKENIZER_CFG_NAME,
        tokenizer_cfg_path: str = str(
            (Path(__file__).parent / "semantic_tokenizer/f40ms/config").absolute()
        ),
        tokenizer_ckpt: str = str(
            (
                Path(__file__).parent / "semantic_tokenizer/f40ms/ckpt/model.pt"
            ).absolute()
        ),
        detok_model_cfg: str = str(
            (Path(__file__).parent / "semantic_detokenizer/ckpt/model.yaml").absolute()
        ),
        detok_ckpt: str = str(
            (Path(__file__).parent / "semantic_detokenizer/ckpt/model.pt").absolute()
        ),
        detok_vocab: str = str(
            (
                Path(__file__).parent / "semantic_detokenizer/ckpt/vocab_4096.txt"
            ).absolute()
        ),
        device: str = "cuda:0",
    ):
        """Build the tokenizer and the detokenizer.

        Args:
            detok_vocoder: Path handed to ``SpeechDetokenizer`` as ``vocoder_path``.
            tokenizer_cfg_name: Config name handed to ``SpeechTokenizer``.
            tokenizer_cfg_path: Config path handed to ``SpeechTokenizer``.
            tokenizer_ckpt: Checkpoint path handed to ``SpeechTokenizer``.
            detok_model_cfg: Model config handed to ``SpeechDetokenizer``.
            detok_ckpt: Checkpoint file handed to ``SpeechDetokenizer``.
            detok_vocab: Vocabulary file handed to ``SpeechDetokenizer``.
            device: Device string handed to ``SpeechDetokenizer``. Defaults to
                ``"cuda:0"``, the value that used to be hard-coded.
        """
        self.tokenizer_cfg_name = tokenizer_cfg_name
        self.tokenizer = SpeechTokenizer(
            ckpt_path=tokenizer_ckpt,
            cfg_path=tokenizer_cfg_path,
            cfg_name=self.tokenizer_cfg_name,
        )

        self.device = device
        self.detoker = SpeechDetokenizer(
            vocoder_path=detok_vocoder,
            model_cfg=detok_model_cfg,
            ckpt_file=detok_ckpt,
            vocab_file=detok_vocab,
            device=self.device,
        )

        # Generation settings forwarded verbatim, in this order, to
        # SpeechDetokenizer.chunk_generate(); their exact semantics are defined
        # by that method (not visible from this file).
        self.token_chunk_len = 75
        self.chunk_cond_proportion = 0.3
        self.chunk_look_ahead = 10
        self.max_ref_duration = 4.5
        self.ref_audio_cut_from_head = False

    def reconstruct(self, ref_wav, input_wav):
        """Reconstruct ``input_wav`` using ``ref_wav`` as the reference audio.

        Both files are loaded at 16 kHz and tokenized in a single
        ``SpeechTokenizer.extract`` call; the resulting token sequences are
        then passed to ``SpeechDetokenizer.chunk_generate``.

        Args:
            ref_wav: Path of the reference audio. It is loaded for tokenization
                and also passed as-is to the detokenizer.
            input_wav: Path of the audio to reconstruct.

        Returns:
            ``(generated_wave, target_sample_rate)`` as returned by the
            detokenizer, or ``(None, None)`` when it produced no waveform.
        """
        # Order matters: index 0 is the reference, index 1 is the input.
        raw_ref_wav, _ = librosa.load(ref_wav, sr=16000)
        raw_input_wav, _ = librosa.load(input_wav, sr=16000)
        _, token_info_list = self.tokenizer.extract([raw_ref_wav, raw_input_wav])

        ref_tokens = token_info_list[0]["reduced_unit_sequence"]
        input_tokens = token_info_list[1]["reduced_unit_sequence"]
        logging.info("tokens for ref wav: %s are [%s]", ref_wav, ref_tokens)
        logging.info("tokens for input wav: %s are [%s]", input_wav, input_tokens)

        # The token sequences are whitespace-separated; split them into lists.
        generated_wave, target_sample_rate = self.detoker.chunk_generate(
            ref_wav,
            ref_tokens.split(),
            input_tokens.split(),
            self.token_chunk_len,
            self.chunk_cond_proportion,
            self.chunk_look_ahead,
            self.max_ref_duration,
            self.ref_audio_cut_from_head,
        )

        if generated_wave is None:
            logging.info("generation FAILED")
            return None, None

        return generated_wave, target_sample_rate


def main(args):
    """Run one reconstruction described by the parsed command-line ``args``.

    ``args`` must provide ``detok_vocoder``, ``ref_wav``, ``input_wav`` and
    ``output_wav``. The optional attributes ``tokenizer_cfg_path``,
    ``tokenizer_ckpt``, ``detok_model_cfg``, ``detok_ckpt``, ``detok_vocab``
    and ``device`` override the pipeline defaults when present and not None.

    Exits the process with status 1 when no waveform could be generated.
    """
    # Forward only the options the user actually supplied so that the
    # pipeline's own (file-relative, absolute) defaults stay in effect otherwise.
    optional_names = (
        "tokenizer_cfg_path",
        "tokenizer_ckpt",
        "detok_model_cfg",
        "detok_ckpt",
        "detok_vocab",
        "device",
    )
    overrides = {
        name: getattr(args, name)
        for name in optional_names
        if getattr(args, name, None) is not None
    }

    # initialize
    reconstructor = ReconstructionPipeline(
        detok_vocoder=args.detok_vocoder,
        **overrides,
    )

    generated_wave, target_sample_rate = reconstructor.reconstruct(
        args.ref_wav, args.input_wav
    )
    if generated_wave is None:
        # reconstruct() signals failure with (None, None); writing that out
        # would only fail later with a confusing error from soundfile.
        logging.error("reconstruction failed, nothing written to: %s", args.output_wav)
        sys.exit(1)

    # soundfile opens the destination itself; no extra file handle is needed.
    sf.write(args.output_wav, generated_wave, target_sample_rate)
    logging.info("write output to: %s", args.output_wav)

    logging.info("Finished")
    return


if __name__ == "__main__":
    # Without a configured handler the INFO messages above are dropped.
    # basicConfig is a no-op if an imported module already set up logging.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tokenizer-ckpt",
        required=False,
        help="path to ckpt",
    )
    parser.add_argument(
        "--tokenizer-cfg-path",
        required=False,
        default=None,
        help="path to config",
    )
    parser.add_argument(
        "--detok-ckpt",
        required=False,
        help="path to ckpt",
    )
    parser.add_argument(
        "--detok-model-cfg",
        required=False,
        help="path to model_cfg",
    )
    parser.add_argument(
        "--detok-vocab",
        required=False,
        help="path to vocab",
    )
    parser.add_argument(
        "--detok-vocoder",
        required=True,
        help="path to vocoder",
    )
    parser.add_argument(
        "--device",
        required=False,
        default=None,
        help="device for the detokenizer (pipeline default: cuda:0)",
    )
    parser.add_argument(
        "--ref-wav",
        required=True,
        help="path to ref wav",
    )
    parser.add_argument(
        "--output-wav",
        required=True,
        help="path to output reconstructed wav",
    )
    parser.add_argument(
        "--input-wav",
        required=True,
        help="input wav to reconstruction",
    )
    args = parser.parse_args()

    main(args)