# NOTE(review): Hugging Face Spaces badge text ("Running on Zero") left over
# from a web-page extraction of this file — not part of the original source.
# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| import torch | |
| import torch.nn as nn | |
| from pathlib import Path | |
| from typing import Dict, Any | |
| from omegaconf import DictConfig | |
| from safetensors.torch import load_file | |
| from sparktts.utils.file import load_config | |
| from sparktts.modules.speaker.speaker_encoder import SpeakerEncoder | |
| from sparktts.modules.encoder_decoder.feat_encoder import Encoder | |
| from sparktts.modules.encoder_decoder.feat_decoder import Decoder | |
| from sparktts.modules.encoder_decoder.wave_generator import WaveGenerator | |
| from sparktts.modules.vq.factorized_vector_quantize import FactorizedVectorQuantize | |
class BiCodec(nn.Module):
    """
    BiCodec model for speech synthesis, incorporating a speaker encoder,
    feature encoder/decoder, quantizer, and wave generator.
    """

    def __init__(
        self,
        mel_params: Dict[str, Any],
        encoder: nn.Module,
        decoder: nn.Module,
        quantizer: nn.Module,
        speaker_encoder: nn.Module,
        prenet: nn.Module,
        postnet: nn.Module,
        **kwargs,
    ) -> None:
        """
        Initializes the BiCodec model with the required components.

        Args:
            mel_params (dict): Parameters for the mel-spectrogram transformer.
            encoder (nn.Module): Encoder module.
            decoder (nn.Module): Decoder (wave generator) module.
            quantizer (nn.Module): Quantizer module.
            speaker_encoder (nn.Module): Speaker encoder module.
            prenet (nn.Module): Prenet network.
            postnet (nn.Module): Postnet network.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.quantizer = quantizer
        self.speaker_encoder = speaker_encoder
        self.prenet = prenet
        self.postnet = postnet
        self.init_mel_transformer(mel_params)

    @classmethod
    def load_from_checkpoint(cls, model_dir: Path, **kwargs) -> "BiCodec":
        """
        Loads the model from a checkpoint directory.

        BUGFIX: this alternate constructor uses ``cls`` but was missing the
        ``@classmethod`` decorator, so the class-level call
        ``BiCodec.load_from_checkpoint(model_dir=...)`` in ``__main__``
        would fail (no value supplied for the first positional parameter).

        Args:
            model_dir (Path): Path to the model directory containing
                ``model.safetensors`` and ``config.yaml``.

        Returns:
            BiCodec: The initialized BiCodec model, in eval mode with
            weight normalization removed.
        """
        ckpt_path = f"{model_dir}/model.safetensors"
        config = load_config(f"{model_dir}/config.yaml")["audio_tokenizer"]
        mel_params = config["mel_params"]

        # Build each sub-module from its section of the config.
        encoder = Encoder(**config["encoder"])
        quantizer = FactorizedVectorQuantize(**config["quantizer"])
        prenet = Decoder(**config["prenet"])
        postnet = Decoder(**config["postnet"])
        decoder = WaveGenerator(**config["decoder"])
        speaker_encoder = SpeakerEncoder(**config["speaker_encoder"])

        model = cls(
            mel_params=mel_params,
            encoder=encoder,
            decoder=decoder,
            quantizer=quantizer,
            speaker_encoder=speaker_encoder,
            prenet=prenet,
            postnet=postnet,
        )

        # Non-strict load: report key mismatches instead of raising.
        state_dict = load_file(ckpt_path)
        missing_keys, unexpected_keys = model.load_state_dict(
            state_dict, strict=False
        )
        for key in missing_keys:
            print(f"Missing tensor: {key}")
        for key in unexpected_keys:
            print(f"Unexpected tensor: {key}")

        model.eval()
        model.remove_weight_norm()
        return model

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """
        Performs a forward pass through the model.

        Args:
            batch (dict): A dictionary containing features (``"feat"``),
                reference waveform (``"ref_wav"``), and target waveform
                (``"wav"``).

        Returns:
            dict: A dictionary containing the reconstruction, features, and
            other metrics.
        """
        feat = batch["feat"]
        mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)

        z = self.encoder(feat.transpose(1, 2))
        vq_outputs = self.quantizer(z)

        x_vector, d_vector = self.speaker_encoder(mel.transpose(1, 2))

        # d-vector conditions the prenet and is added to its output below.
        conditions = d_vector
        with_speaker_loss = False

        x = self.prenet(vq_outputs["z_q"], conditions)
        pred_feat = self.postnet(x)
        x = x + conditions.unsqueeze(-1)
        wav_recon = self.decoder(x)

        return {
            "vq_loss": vq_outputs["vq_loss"],
            "perplexity": vq_outputs["perplexity"],
            "cluster_size": vq_outputs["active_num"],
            "recons": wav_recon,
            "pred_feat": pred_feat,
            "x_vector": x_vector,
            "d_vector": d_vector,
            "audios": batch["wav"].unsqueeze(1),
            "with_speaker_loss": with_speaker_loss,
        }

    def tokenize(self, batch: Dict[str, Any]):
        """
        Tokenizes the input audio into semantic and global tokens.

        Args:
            batch (dict): The input audio features (``"feat"``) and reference
                waveform (``"ref_wav"``).

        Returns:
            tuple: Semantic tokens and global tokens.
        """
        feat = batch["feat"]
        mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)

        z = self.encoder(feat.transpose(1, 2))
        semantic_tokens = self.quantizer.tokenize(z)
        global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2))
        return semantic_tokens, global_tokens

    def detokenize(self, semantic_tokens, global_tokens):
        """
        Detokenizes the semantic and global tokens into a waveform.

        Args:
            semantic_tokens (tensor): Semantic tokens.
            global_tokens (tensor): Global tokens.

        Returns:
            tensor: Reconstructed waveform.
        """
        z_q = self.quantizer.detokenize(semantic_tokens)
        d_vector = self.speaker_encoder.detokenize(global_tokens)
        x = self.prenet(z_q, d_vector)
        x = x + d_vector.unsqueeze(-1)
        wav_recon = self.decoder(x)
        return wav_recon

    def init_mel_transformer(self, config: Dict[str, Any]):
        """
        Initializes the MelSpectrogram transformer based on the provided
        configuration.

        Args:
            config (dict): Configuration parameters for MelSpectrogram.
        """
        # Local import keeps torchaudio an optional dependency at module load.
        import torchaudio.transforms as TT

        self.mel_transformer = TT.MelSpectrogram(
            config["sample_rate"],
            config["n_fft"],
            config["win_length"],
            config["hop_length"],
            config["mel_fmin"],
            config["mel_fmax"],
            n_mels=config["num_mels"],
            power=1,
            norm="slaney",
            mel_scale="slaney",
        )

    def remove_weight_norm(self):
        """Removes weight normalization from all layers."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:
                pass  # The module didn't have weight norm

        self.apply(_remove_weight_norm)
# Smoke test: load the pretrained model, run a forward pass, and check that
# tokenize -> detokenize reproduces the forward-pass reconstruction.
if __name__ == "__main__":
    # NOTE: the previous version also called load_config() on the BiCodec
    # config here, but the result was never used — removed.
    model = BiCodec.load_from_checkpoint(
        model_dir="pretrained_models/SparkTTS-0.5B/BiCodec",
    )

    # Generate random inputs for testing
    duration = 0.96
    x = torch.randn(20, 1, int(duration * 16000))
    feat = torch.randn(20, int(duration * 50), 1024)
    inputs = {"feat": feat, "wav": x, "ref_wav": x}

    # Forward pass
    outputs = model(inputs)
    semantic_tokens, global_tokens = model.tokenize(inputs)
    wav_recon = model.detokenize(semantic_tokens, global_tokens)

    # Verify if the reconstruction matches
    if torch.allclose(outputs["recons"].detach(), wav_recon):
        print("Test successful")
    else:
        print("Test failed")