# coding=utf-8
# Copyright 2025 OpenMOSS and HuggingFace Inc. teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import random
import uuid as uuid_module
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import onnxruntime
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from hyperpyyaml import load_hyperpyyaml
from safetensors.torch import load_file
from torch import nn
from transformers import PreTrainedModel, WhisperFeatureExtractor

from .configuration_moss_speech_codec import MossSpeechCodecConfig
from .modeling_whisper import WhisperVQConfig, WhisperVQEncoder
from .utils import extract_speech_token

logger = logging.getLogger(__name__)


def set_seed(seed: int) -> None:
    if not isinstance(seed, int):
        raise TypeError("Seed must be an integer.")
    logger.info("Setting random seed to %s", seed)
    random.seed(seed)
    np.random.seed(seed)
    # Seed the CPU generator unconditionally; CUDA generators are seeded in addition when available.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    os.environ["PYTHONHASHSEED"] = str(seed)
    os.environ["TF_CUDNN_DETERMINISTIC"] = "1"


def fade_in_out(fade_in_mel, fade_out_mel, window):
    device = fade_in_mel.device
    fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
    mel_overlap_len = int(window.shape[0] / 2)
    fade_in_mel[..., :mel_overlap_len] = (
        fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len]
        + fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
    )
    return fade_in_mel.to(device)


tts_speech_prev = None
tts_mel_prev = None


class AudioDecoder(nn.Module):
    def __init__(
        self,
        config_path: Union[str, os.PathLike],
        flow_ckpt_path: Union[str, os.PathLike],
        hift_ckpt_path: Union[str, os.PathLike],
        campplus_model: Union[str, os.PathLike],
        device: Union[str, torch.device] = "cuda",
    ) -> None:
        super().__init__()
        self.device = torch.device(device) if isinstance(device, str) else device

        with open(config_path, "r", encoding="utf-8") as config_file:
            logger.info("Loading decoder configurations from %s", config_path)
            self.scratch_configs = load_hyperpyyaml(config_file)

        # Load models
        self.flow = self.scratch_configs["flow"]
        self.flow.load_state_dict(torch.load(flow_ckpt_path, map_location=self.device), strict=False)
        self.hift = self.scratch_configs["hift"]
        self.hift.load_state_dict(torch.load(hift_ckpt_path, map_location=self.device))
        self.hift = self.hift.eval()
        self.sample_rate = self.scratch_configs["sample_rate"]
        self.feat_extractor = self.scratch_configs["feat_extractor"]

        # Move models to the appropriate device
        self.flow.to(self.device)
        self.hift.to(self.device)

        self.mel_overlap_dict = defaultdict(lambda: None)
        self.hift_cache_dict = defaultdict(lambda: None)
        self.token_min_hop_len = 2 * self.flow.input_frame_rate
        self.token_max_hop_len = 4 * self.flow.input_frame_rate
        self.token_overlap_len = 3.5
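        # Convert the token-level overlap into a mel-frame overlap. The constants below appear to
        # assume 24 kHz output audio with a mel hop of 480 * 2 samples (25 mel frames per second);
        # the actual rates come from the flow/HiFT configs loaded above, not from this module.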
        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 24000 / (480 * 2))
        self.mel_window = np.hamming(2 * self.mel_overlap_len)
        # hift cache
        self.mel_cache_len = 1
        self.source_cache_len = int(self.mel_cache_len * 480)
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)

        session_options = onnxruntime.SessionOptions()
        session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        session_options.intra_op_num_threads = 1
        self.campplus_session = onnxruntime.InferenceSession(
            str(campplus_model),
            sess_options=session_options,
            providers=["CPUExecutionProvider"],
        )

    def token2wav(
        self,
        token: torch.Tensor,
        uuid: str,
        prompt_token: Optional[torch.Tensor] = None,
        prompt_feat: Optional[torch.Tensor] = None,
        embedding: Optional[torch.Tensor] = None,
        finalize: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        prompt_token = prompt_token if prompt_token is not None else torch.zeros(1, 0, dtype=torch.int32)
        prompt_feat = prompt_feat if prompt_feat is not None else torch.zeros(1, 0, 80)
        embedding = embedding if embedding is not None else torch.zeros(1, 192)

        tts_mel = self.flow.inference(
            token=token.to(self.device),
            token_len=torch.tensor([token.shape[1]], dtype=torch.int32, device=self.device),
            prompt_token=prompt_token.to(self.device),
            prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32, device=self.device),
            prompt_feat=prompt_feat.to(self.device),
            prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32, device=self.device),
            embedding=embedding.to(self.device),
            streaming=False,
            finalize=finalize,
        )
        tts_mel = tts_mel[0]
        if self.mel_overlap_dict[uuid] is not None:
            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)

        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = (
                self.hift_cache_dict[uuid]["mel"],
                self.hift_cache_dict[uuid]["source"],
            )
            tts_mel = torch.cat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)

        # keep overlap mel and hift cache
        if not finalize:
            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            self.hift_cache_dict[uuid] = {
                "mel": tts_mel[:, :, -self.mel_cache_len:],
                "source": tts_source[:, :, -self.source_cache_len:],
                "speech": tts_speech[:, -self.source_cache_len:],
            }
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            del self.hift_cache_dict[uuid]
            del self.mel_overlap_dict[uuid]
        return tts_speech, tts_mel

    def offline_inference(self, token: torch.Tensor) -> torch.Tensor:
        this_uuid = str(uuid_module.uuid1())
        tts_speech, _ = self.token2wav(token, uuid=this_uuid, finalize=True)
        return tts_speech.cpu()

    def stream_inference(
        self,
        token: torch.Tensor,
        prompt_token: Optional[torch.Tensor] = None,
        prompt_feat: Optional[torch.Tensor] = None,
        embedding: Optional[torch.Tensor] = None,
        block_size: int = 8,
    ) -> torch.Tensor:
        token = token.to(self.device)
        this_uuid = str(uuid_module.uuid1())
        prompt_tensor = (
            prompt_token.to(self.device)
            if prompt_token is not None
            else torch.zeros(1, 0, dtype=torch.int32, device=self.device)
        )
        prompt_speech_feat = (
            prompt_feat.to(self.device)
            if prompt_feat is not None
            else torch.zeros(1, 0, 80, device=self.device)
        )
        embedding = (
            embedding.to(self.device)
            if embedding is not None
            else torch.zeros(1, 192, device=self.device)
        )
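        # Block-wise decoding: each block of `block_size` tokens is synthesized with all
        # previously decoded tokens (and their mels) prepended to the original prompt, so the
        # flow model keeps the full context while audio is emitted incrementally.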
        base_prompt_tensor = prompt_tensor
        base_prompt_feat = prompt_speech_feat

        tts_speechs: List[torch.Tensor] = []
        tts_mels: List[torch.Tensor] = []
        prev_mel: Optional[torch.Tensor] = None

        for idx in range(0, token.size(1), block_size):
            tts_token = token[:, idx : idx + block_size]
            prompt_tensor_current = base_prompt_tensor
            prompt_feat_current = base_prompt_feat
            if prev_mel is not None:
                prompt_feat_current = torch.cat(
                    [base_prompt_feat.transpose(1, 2)] + tts_mels,
                    dim=-1,
                ).transpose(1, 2)
                prompt_tensor_current = torch.cat([base_prompt_tensor, token[:, :idx]], dim=-1)
            is_finalize = idx + block_size >= token.size(-1)
            tts_speech, tts_mel = self.token2wav(
                tts_token,
                uuid=this_uuid,
                prompt_token=prompt_tensor_current,
                prompt_feat=prompt_feat_current,
                embedding=embedding,
                finalize=is_finalize,
            )
            prev_mel = tts_mel
            tts_speechs.append(tts_speech)
            tts_mels.append(tts_mel)

        tts_speech = torch.cat(tts_speechs, dim=-1).cpu()
        return tts_speech

    def streaming_inference(
        self,
        token: torch.Tensor,
        prompt_token: Optional[torch.Tensor] = None,
        prompt_feat: Optional[torch.Tensor] = None,
        embedding: Optional[torch.Tensor] = None,
        uuid: Optional[str] = None,
        prev_mel: Optional[torch.Tensor] = None,
        prev_token: Optional[torch.Tensor] = None,
        is_finalize: bool = True,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        token = token.to(self.device)
        this_uuid = uuid or str(uuid_module.uuid1())
        prompt_speech_feat = (
            prompt_feat.to(self.device)
            if prompt_feat is not None
            else torch.zeros(1, 0, 80, device=self.device)
        )
        flow_prompt_speech_token = (
            prompt_token.to(self.device)
            if prompt_token is not None
            else torch.zeros(1, 0, dtype=torch.int32, device=self.device)
        )
        embedding_tensor = (
            embedding.to(self.device) if embedding is not None else torch.zeros(1, 192, device=self.device)
        )
        if prev_mel is not None:
            prompt_speech_feat = prev_mel
        if prev_token is not None:
            flow_prompt_speech_token = prev_token

        tts_speech, tts_mel = self.token2wav(
            token,
            uuid=this_uuid,
            prompt_token=flow_prompt_speech_token,
            prompt_feat=prompt_speech_feat,
            embedding=embedding_tensor,
            finalize=is_finalize,
        )

        if prev_mel is not None:
            prev_mel = torch.cat([prev_mel, tts_mel], dim=1)
        else:
            prev_mel = tts_mel
        if prev_token is not None:
            prev_token = torch.cat([prev_token, token], dim=-1)
        else:
            prev_token = token
        return tts_speech.cpu(), prev_mel, prev_token


class MossSpeechCodec(PreTrainedModel):
    """MossSpeech codec model (Whisper-VQ encoder + Flow/HiFT decoder).

    Notes
    - The API stays compatible with existing `MossSpeechProcessor` usage while adopting a
      Transformers-style layout similar to HF codec models (`xcodec`, `encodec`).
    - `encode` accepts raw audio tensors or file paths. It returns a Python list of codec
      token ids per input sample for backward compatibility.
    - `decode` accepts either a 3D LongTensor `(B, 1, T)` or a nested list of token ids,
      and returns a dict with a list of waveforms under `"syn_wav_list"` (matching current
      processor expectations).
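
    Example (illustrative sketch; the checkpoint path and wav files are placeholders):
        codec = MossSpeechCodec.from_pretrained("/path/to/moss-speech-codec")
        codes = codec.encode(["speech.wav"])                       # -> List[List[int]]
        outputs = codec.decode(codes, prompt_speech="prompt.wav")  # enrollment audio for conditioning
        waveform = outputs["syn_wav_list"][0]                      # 1D waveform Tensor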
""" config_class = MossSpeechCodecConfig def __init__( self, encoder_weight_path: Union[str, os.PathLike], encoder_config_path: Union[str, os.PathLike], encoder_feature_extractor_path: Union[str, os.PathLike], flow_path: Union[str, os.PathLike], ) -> None: super().__init__(config=MossSpeechCodecConfig()) # Whisper-VQ encoder self.sample_rate = 16000 config = WhisperVQConfig.from_pretrained(str(encoder_config_path)) self.whisper_vqmodel = WhisperVQEncoder(config) state_dict = load_file(str(encoder_weight_path)) new_state_dict: OrderedDict[str, torch.Tensor] = OrderedDict() for k, v in state_dict.items(): if k.startswith("encoder."): new_state_dict[k[len("encoder."):]] = v self.whisper_vqmodel.load_state_dict(new_state_dict, strict=False) self.feature_extractor = WhisperFeatureExtractor.from_pretrained( str(encoder_feature_extractor_path) ) # Flow / HiFT decoder stack self.flow_path = str(flow_path) self.audio_decoder = AudioDecoder( config_path=os.path.join(self.flow_path, "config.yaml"), flow_ckpt_path=os.path.join(self.flow_path, "flow.pt"), hift_ckpt_path=os.path.join(self.flow_path, "hift.pt"), campplus_model=os.path.join(self.flow_path, "campplus.onnx"), ).eval() @torch.no_grad() def encode( self, inputs: Union[ Sequence[Union[str, os.PathLike, Tuple[torch.Tensor, int], torch.Tensor]], torch.Tensor, ], *, sampling_rate: Optional[int] = None, batch_size: int = 128, ) -> List[List[int]]: """Encode audio into codec token ids. Accepts one of: - a list of file paths - a list of `(waveform, sr)` tuples - a list of 1D/2D waveforms (sr assumed 16k) - a batched tensor with shape `(B, C, T)` or `(B, T)` """ # Normalize to a list the helper can consume if isinstance(inputs, torch.Tensor): if inputs.dim() == 2: inputs = inputs.unsqueeze(1) # (B, 1, T) if inputs.dim() != 3: raise ValueError("`inputs` must be (B, C, T) when passing a tensor.") sr = sampling_rate or self.sample_rate items: List[Tuple[torch.Tensor, int]] = [ (inputs[i].squeeze(0).cpu(), sr) for i in range(inputs.size(0)) ] else: items = list(inputs) # type: ignore[assignment] # Use the existing utility (supports file paths, tuples, tensors) audio_tokens: List[List[int]] = extract_speech_token( self.whisper_vqmodel, self.feature_extractor, items, batch_size=batch_size ) return audio_tokens def _extract_speech_feat(self, speech: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: speech_feat = self.audio_decoder.feat_extractor(speech).squeeze(dim=0).transpose(0, 1) speech_feat = speech_feat.unsqueeze(dim=0) speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32) return speech_feat, speech_feat_len def _extract_spk_embedding(self, speech_16k: torch.Tensor) -> torch.Tensor: feat = kaldi.fbank(speech_16k, num_mel_bins=80, dither=0, sample_frequency=16000) feat = feat - feat.mean(dim=0, keepdim=True) embedding = self.audio_decoder.campplus_session.run( None, {self.audio_decoder.campplus_session.get_inputs()[0].name: feat.unsqueeze(0).cpu().numpy()}, )[0].flatten().tolist() return torch.tensor([embedding]) @torch.no_grad() def decode( self, audio_codes: Union[Sequence[Sequence[int]], torch.LongTensor], *, prompt_speech: Optional[Union[str, os.PathLike]] = None, prompt_speech_sample_rate: Optional[int] = None, use_spk_embedding: bool = True, use_prompt_speech: bool = True, finalize: bool = True, device: torch.device = torch.device("cuda"), ) -> dict: """Decode codec token ids back to waveform(s). Args - audio_codes: `(B, 1, T)` or Python nested lists per sample. 
        - prompt_speech: path to the enrollment audio used for conditioning.

        Returns
        - {"syn_wav_list": List[Tensor(T)]}
        """
        if isinstance(audio_codes, torch.Tensor):
            if audio_codes.dim() == 3 and audio_codes.size(1) == 1:
                codes_list: List[List[int]] = [
                    audio_codes[i, 0].detach().cpu().tolist() for i in range(audio_codes.size(0))
                ]
            elif audio_codes.dim() == 2:
                codes_list = [row.detach().cpu().tolist() for row in audio_codes]
            else:
                raise ValueError("`audio_codes` must be (B, 1, T) or (B, T) when passing a tensor.")
        else:
            codes_list = [list(c) for c in audio_codes]

        if prompt_speech is None or not os.path.exists(str(prompt_speech)):
            raise ValueError("`prompt_speech` path is required for decoding and must exist.")

        prompt_wav, orig_sr = torchaudio.load(str(prompt_speech))
        target_sr = self.audio_decoder.sample_rate
        if orig_sr != target_sr:
            prompt_wav = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)(prompt_wav)

        device = device if torch.cuda.is_available() or device.type == "cpu" else torch.device("cpu")

        speech_token = torch.tensor(self.encode([str(prompt_speech)])[0], device=device).unsqueeze(0)
        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_wav)
        if target_sr == 24000:
            # Keep the prompt mel features and prompt tokens aligned (4 mel frames per codec token).
            token_len = min(int(speech_feat.shape[1] / 4), speech_token.shape[1])
            speech_feat = speech_feat[:, : 4 * token_len]
            speech_feat_len[:] = 4 * token_len
            speech_token = speech_token[:, :token_len]

        prompt_16k = torchaudio.transforms.Resample(orig_freq=target_sr, new_freq=16000)(prompt_wav)
        embedding = self._extract_spk_embedding(prompt_16k).to(device)
        speech_feat = speech_feat.to(device)
        speech_feat_len = speech_feat_len.to(device)

        syn_wav_list: List[torch.Tensor] = []
        for codes in codes_list:
            codes_t = torch.tensor(codes, device=device).unsqueeze(0)
            uuid = os.urandom(16).hex()
            kwargs = {"uuid": uuid, "finalize": finalize}
            if use_prompt_speech:
                kwargs.update({"prompt_token": speech_token, "prompt_feat": speech_feat})
            if use_spk_embedding:
                kwargs.update({"embedding": embedding})
            tts_speech, _ = self.audio_decoder.token2wav(codes_t, **kwargs)
            syn_wav_list.append(tts_speech.squeeze())

        return {"syn_wav_list": syn_wav_list}

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        *,
        revision: Optional[str] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        use_auth_token: Optional[Union[str, bool]] = None,  # back-compat with HF Transformers kwarg
        subfolder: Optional[str] = None,
        **kwargs,
    ):
        """Instantiate the codec from a local directory or a Hugging Face Hub repo.

        This mirrors the typical Hugging Face ``from_pretrained`` behavior:
        - If ``pretrained_model_name_or_path`` is a local folder, files are loaded from it.
        - Otherwise, it is treated as a Hub repo ID and downloaded with ``snapshot_download``.

        Expected layout inside the resolved base folder:
        - ``model.safetensors`` (Whisper VQ encoder weights)
        - ``config.json`` (Whisper VQ config)
        - ``preprocessor_config.json`` (WhisperFeatureExtractor params)
        - ``flow/{config.yaml, flow.pt, hift.pt, campplus.onnx}``
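
        Example (illustrative; the repo id and local path below are placeholders):
            codec = MossSpeechCodec.from_pretrained("org-name/moss-speech-codec", revision="main")
            codec = MossSpeechCodec.from_pretrained("/local/checkpoints/dir", subfolder="codec")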
        """
        # Resolve local directory vs HF Hub repo.
        base: Path
        path_str = str(pretrained_model_name_or_path)
        if os.path.isdir(path_str):
            base = Path(path_str)
        else:
            try:
                # Lazy import to avoid a hard dependency at import time.
                from huggingface_hub import snapshot_download
            except Exception as exc:  # pragma: no cover
                raise RuntimeError(
                    "huggingface_hub is required to load from a repo id; please `pip install huggingface_hub`."
                ) from exc
            # HF Transformers historically supports both `token` and the deprecated `use_auth_token`.
            if token is None and use_auth_token is not None:
                token = use_auth_token
            snapshot_path = snapshot_download(
                repo_id=path_str,
                revision=revision,
                cache_dir=str(cache_dir) if cache_dir is not None else None,
                force_download=force_download,
                local_files_only=local_files_only,
                token=token,
            )
            base = Path(snapshot_path)

        if subfolder:
            base = base / subfolder

        tokenizer_dir = base
        flow_dir = base / "flow"

        # Validate expected files and provide actionable error messages, similar to HF patterns.
        missing: List[str] = []
        if not (tokenizer_dir / "model.safetensors").exists():
            missing.append(str(tokenizer_dir / "model.safetensors"))
        if not (tokenizer_dir / "config.json").exists():
            missing.append(str(tokenizer_dir / "config.json"))
        if not (tokenizer_dir / "preprocessor_config.json").exists():
            missing.append(str(tokenizer_dir / "preprocessor_config.json"))
        for fname in ("config.yaml", "flow.pt", "hift.pt"):
            if not (flow_dir / fname).exists():
                missing.append(str(flow_dir / fname))
        # `campplus.onnx` may be named differently in some drops; only warn if absent.
        has_campplus = (flow_dir / "campplus.onnx").exists()

        if missing:
            raise FileNotFoundError(
                "Missing required codec assets under resolved path. The following files were not found: "
                + ", ".join(missing)
            )
        if not has_campplus:
            logger.warning("campplus.onnx not found under %s; decoding speaker embedding may fail.", flow_dir)

        encoder_weight_path = str(tokenizer_dir / "model.safetensors")
        encoder_config_path = str(tokenizer_dir / "config.json")
        encoder_feature_extractor_path = str(tokenizer_dir)
        flow_path = str(flow_dir)

        return cls(
            encoder_weight_path=encoder_weight_path,
            encoder_config_path=encoder_config_path,
            encoder_feature_extractor_path=encoder_feature_extractor_path,
            flow_path=flow_path,
        )