Upload 3 files
- configuration_spark_tts.py +233 -0
- modeling_spark_tts.py +0 -0
- processing_spark_tts.py +889 -0
configuration_spark_tts.py
ADDED
@@ -0,0 +1,233 @@
# coding=utf-8
# Copyright 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
# ... (License headers remain the same) ...
""" SparkTTS model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from typing import List, Optional  # Added typing


logger = logging.get_logger(__name__)

# --- Define Individual Sub-Component Config Classes ---

class SparkTTSMelParamsConfig(PretrainedConfig):
    """Configuration for Mel Spectrogram parameters."""
    model_type = "spark-tts-mel-params"
    def __init__(self, sample_rate=16000, n_fft=1024, win_length=640, hop_length=320,
                 mel_fmin=10, mel_fmax=None, num_mels=128, **kwargs):
        super().__init__(**kwargs)
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.num_mels = num_mels

class SparkTTSEncoderConfig(PretrainedConfig):
    """Configuration for the BiCodec Feature Encoder."""
    model_type = "spark-tts-encoder"
    def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
                 vocos_num_layers=12, out_channels=1024, sample_ratios=[1, 1], **kwargs):
        super().__init__(**kwargs)
        self.input_channels = input_channels
        self.vocos_dim = vocos_dim
        self.vocos_intermediate_dim = vocos_intermediate_dim
        self.vocos_num_layers = vocos_num_layers
        self.out_channels = out_channels
        self.sample_ratios = sample_ratios

class SparkTTSDecoderConfig(PretrainedConfig):
    """Configuration for the BiCodec Wave Generator (Decoder)."""
    model_type = "spark-tts-decoder"
    def __init__(self, input_channel=1024, channels=1536, rates=[8, 5, 4, 2],
                 kernel_sizes=[16, 11, 8, 4], **kwargs):
        super().__init__(**kwargs)
        self.input_channel = input_channel
        self.channels = channels
        self.rates = rates
        self.kernel_sizes = kernel_sizes

class SparkTTSQuantizerConfig(PretrainedConfig):
    """Configuration for the BiCodec Factorized Vector Quantizer."""
    model_type = "spark-tts-quantizer"
    def __init__(self, input_dim=1024, codebook_size=8192, codebook_dim=8,
                 commitment=0.25, codebook_loss_weight=2.0, decay=0.99,
                 threshold_ema_dead_code=0.2, **kwargs):
        # Note: Removed use_l2_normlize as it wasn't in the original class __init__ args
        # Add it back if it's actually used by the FactorizedVectorQuantize class init
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment
        self.codebook_loss_weight = codebook_loss_weight
        self.decay = decay
        self.threshold_ema_dead_code = threshold_ema_dead_code

class SparkTTSSpeakerEncoderConfig(PretrainedConfig):
    """Configuration for the BiCodec Speaker Encoder."""
    model_type = "spark-tts-speaker-encoder"
    def __init__(self, input_dim=128, out_dim=1024, latent_dim=128, token_num=32,
                 fsq_levels=[4, 4, 4, 4, 4, 4], fsq_num_quantizers=1, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.latent_dim = latent_dim
        self.token_num = token_num
        self.fsq_levels = fsq_levels
        self.fsq_num_quantizers = fsq_num_quantizers

class SparkTTSPrenetConfig(PretrainedConfig):
    """Configuration for the BiCodec Prenet."""
    model_type = "spark-tts-prenet"
    def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
                 vocos_num_layers=12, out_channels=1024, condition_dim=1024,
                 sample_ratios=[1, 1], use_tanh_at_final=False, **kwargs):
        super().__init__(**kwargs)
        self.input_channels = input_channels
        self.vocos_dim = vocos_dim
        self.vocos_intermediate_dim = vocos_intermediate_dim
        self.vocos_num_layers = vocos_num_layers
        self.out_channels = out_channels
        self.condition_dim = condition_dim
        self.sample_ratios = sample_ratios
        self.use_tanh_at_final = use_tanh_at_final

class SparkTTSPostnetConfig(PretrainedConfig):
    """Configuration for the BiCodec Postnet."""
    model_type = "spark-tts-postnet"
    def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
                 vocos_num_layers=6, out_channels=1024, use_tanh_at_final=False, **kwargs):
        # Note: Removed condition_dim as it wasn't in the original config example for postnet
        super().__init__(**kwargs)
        self.input_channels = input_channels
        self.vocos_dim = vocos_dim
        self.vocos_intermediate_dim = vocos_intermediate_dim
        self.vocos_num_layers = vocos_num_layers
        self.out_channels = out_channels
        self.use_tanh_at_final = use_tanh_at_final


# --- Define the Intermediate BiCodec Config Class ---

class SparkTTSBiCodecConfig(PretrainedConfig):
    """
    Intermediate configuration class for the BiCodec component within SparkTTS.
    It holds instances of the individual sub-component configurations.
    """
    model_type = "spark-tts-bicodec"
    # Map keys in the 'bicodec_config' dict to their respective classes
    sub_configs = {
        "mel_params": SparkTTSMelParamsConfig,
        "encoder_config": SparkTTSEncoderConfig,
        "decoder_config": SparkTTSDecoderConfig,
        "quantizer_config": SparkTTSQuantizerConfig,
        "speaker_encoder_config": SparkTTSSpeakerEncoderConfig,
        "prenet_config": SparkTTSPrenetConfig,
        "postnet_config": SparkTTSPostnetConfig,
    }

    def __init__(
        self,
        mel_params=None,
        encoder_config=None,
        decoder_config=None,
        quantizer_config=None,
        speaker_encoder_config=None,
        prenet_config=None,
        postnet_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Instantiate sub-configs from dictionaries or use defaults/provided instances
        self.mel_params = self._init_sub_config(mel_params, "mel_params")
        self.encoder_config = self._init_sub_config(encoder_config, "encoder_config")
        self.decoder_config = self._init_sub_config(decoder_config, "decoder_config")
        self.quantizer_config = self._init_sub_config(quantizer_config, "quantizer_config")
        self.speaker_encoder_config = self._init_sub_config(speaker_encoder_config, "speaker_encoder_config")
        self.prenet_config = self._init_sub_config(prenet_config, "prenet_config")
        self.postnet_config = self._init_sub_config(postnet_config, "postnet_config")

    def _init_sub_config(self, config_input, config_key):
        """Helper to initialize sub-configs."""
        config_cls = self.sub_configs[config_key]
        if isinstance(config_input, dict):
            return config_cls(**config_input)
        elif config_input is None:
            return config_cls()  # Initialize with defaults
        elif isinstance(config_input, config_cls):
            return config_input  # Already an instance
        else:
            raise TypeError(f"Invalid type for {config_key}: {type(config_input)}. Expected dict, None, or {config_cls.__name__}.")


# --- Define the Main SparkTTS Config Class ---

class SparkTTSConfig(PretrainedConfig):
    r"""
    Main configuration class for SparkTTSModel, including nested BiCodec configuration.

    Args:
        llm_model_name_or_path (`str`, *optional*, defaults to `"./LLM"`): Path/ID for LLM.
        bicodec_model_name_or_path (`str`, *optional*, defaults to `"./BiCodec"`): Path/ID for BiCodec checkpoint.
        wav2vec2_model_name_or_path (`str`, *optional*, defaults to `"./wav2vec2-large-xlsr-53"`): Path/ID for Wav2Vec2.
        sample_rate (`int`, *optional*, defaults to 16000): Audio sample rate.
        # ... (other top-level args: highpass_cutoff_freq, latent_hop_length, ref_segment_duration, volume_normalize) ...
        bicodec_config (`dict`, *optional*): Dictionary to initialize `SparkTTSBiCodecConfig`.
        torch_dtype (`str`, *optional*, defaults to `"auto"`): Torch dtype.
        kwargs (*optional*): Dictionary of keyword arguments.
    """
    model_type = "spark-tts"
    # Map the key in config.json to the intermediate BiCodec config class
    sub_configs = {"bicodec_config": SparkTTSBiCodecConfig}
    attribute_map = {"hidden_size": "d_model"}  # Example

    def __init__(
        self,
        llm_model_name_or_path="./LLM",
        bicodec_model_name_or_path="./BiCodec",
        wav2vec2_model_name_or_path="./wav2vec2-large-xlsr-53",
        sample_rate=16000,
        highpass_cutoff_freq=40,
        latent_hop_length=320,
        ref_segment_duration=6.0,
        volume_normalize=True,
        bicodec_config=None,  # Expects a dictionary or None
        torch_dtype="auto",
        **kwargs,
    ):
        # --- Top-level parameters ---
        self.llm_model_name_or_path = llm_model_name_or_path
        self.bicodec_model_name_or_path = bicodec_model_name_or_path
        self.wav2vec2_model_name_or_path = wav2vec2_model_name_or_path
        self.sample_rate = sample_rate
        self.highpass_cutoff_freq = highpass_cutoff_freq
        self.latent_hop_length = latent_hop_length
        self.ref_segment_duration = ref_segment_duration
        self.volume_normalize = volume_normalize
        self.torch_dtype = torch_dtype

        # --- Nested BiCodec Configuration ---
        # Instantiate the intermediate BiCodec config class, which will handle its own sub-configs
        if isinstance(bicodec_config, dict):
            self.bicodec_config = self.sub_configs["bicodec_config"](**bicodec_config)
        elif bicodec_config is None:
            logger.info("`bicodec_config` not provided. Initializing `SparkTTSBiCodecConfig` with its defaults.")
            self.bicodec_config = self.sub_configs["bicodec_config"]()
        elif isinstance(bicodec_config, self.sub_configs["bicodec_config"]):
            self.bicodec_config = bicodec_config  # Use existing instance
        else:
            raise TypeError(f"Invalid type for bicodec_config: {type(bicodec_config)}. Expected dict, None, or SparkTTSBiCodecConfig.")

        # Set processor class and auto_map
        kwargs["processor_class"] = kwargs.get("processor_class", "SparkTTSProcessor")
        kwargs["auto_map"] = kwargs.get("auto_map", {
            "AutoConfig": "configuration_spark_tts.SparkTTSConfig",
            "AutoModel": "modeling_spark_tts.SparkTTSModel",
            "AutoProcessor": "processing_spark_tts.SparkTTSProcessor"
        })
        super().__init__(**kwargs)
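
For reference, a minimal sketch of how the nested configuration resolves. The override values below are purely illustrative and not taken from a released checkpoint; the untouched sub-configs fall back to the defaults defined in this file.

from configuration_spark_tts import SparkTTSConfig, SparkTTSQuantizerConfig

# A partial dict is enough: each nested dict is promoted to its config class,
# and anything omitted is filled in with the defaults above.
config = SparkTTSConfig(
    bicodec_config={
        "quantizer_config": {"codebook_size": 8192, "codebook_dim": 8},
        "speaker_encoder_config": {"token_num": 32},
    },
)
assert isinstance(config.bicodec_config.quantizer_config, SparkTTSQuantizerConfig)
assert config.bicodec_config.mel_params.hop_length == 320        # default
assert config.bicodec_config.encoder_config.vocos_num_layers == 12  # default
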
modeling_spark_tts.py
ADDED
The diff for this file is too large to render. See the raw diff.
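
Since the modeling diff is not rendered here, the implementation is not shown. From the way processing_spark_tts.py below calls into the model, it is expected to expose at least the interface sketched here; the signatures and shapes are inferred from the processor code, not confirmed against modeling_spark_tts.py.

from typing import Protocol, Tuple
import numpy as np
import torch

class SparkTTSModelLike(Protocol):
    # The processor reads sample_rate, ref_segment_duration and latent_hop_length from this config.
    config: object

    def tokenize_audio(self, wav: torch.Tensor, ref_wav: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (global_token_ids [1, N_global], semantic_token_ids [1, N_semantic]) for a prompt waveform."""
        ...

    def detokenize_audio(self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor) -> np.ndarray:
        """Return a float32 waveform in [-1, 1] reconstructed from BiCodec tokens."""
        ...

    def generate(self, **kwargs) -> torch.Tensor:
        """Standard `transformers` generation over the underlying LLM (referenced by `processor.decode`)."""
        ...
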
processing_spark_tts.py
ADDED
@@ -0,0 +1,889 @@
# coding=utf-8
# Copyright 2025 SparkAudio & The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for SparkTTS. Combines text tokenization and audio feature extraction/processing.
"""

import os  # Needed for save_pretrained
import re  # For decoding
import random
import torch
import numpy as np
import soundfile as sf  # For audio loading
import soxr  # For resampling

from pathlib import Path
from typing import Optional, Union, List, Dict, Tuple, Any

from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import BatchEncoding  # Return type hint
from transformers.feature_extraction_utils import BatchFeature  # Return type hint
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.models.wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from transformers.utils import logging, PushToHubMixin  # Added PushToHubMixin
from numpy.lib.stride_tricks import sliding_window_view

# Import custom config if needed for defaults
from .configuration_spark_tts import SparkTTSConfig

logger = logging.get_logger(__name__)


# =============================================================================
# >> START: PASTE CODE FROM sparktts/utils/* HERE <<
# =============================================================================
# IMPORTANT: Utility functions needed for processing (audio loading, token parsing)
# must be defined or imported here.

# --- Paste sparktts/utils/audio.py content here ---

def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray:
    """
    Normalize the volume of an audio signal.

    Parameters:
        audio (numpy array): Input audio signal array.
        coeff (float): Target coefficient for normalization, default is 0.2.

    Returns:
        numpy array: The volume-normalized audio signal.
    """
    # Sort the absolute values of the audio signal
    temp = np.sort(np.abs(audio))

    # If the maximum value is less than 0.1, scale the array to have a maximum of 0.1
    if temp[-1] < 0.1:
        scaling_factor = max(
            temp[-1], 1e-3
        )  # Prevent division by zero with a small constant
        audio = audio / scaling_factor * 0.1

    # Filter out values less than 0.01 from temp
    temp = temp[temp > 0.01]
    L = temp.shape[0]  # Length of the filtered array

    # If there are fewer than or equal to 10 significant values, return the audio without further processing
    if L <= 10:
        return audio

    # Compute the average of the top 10% to 1% of values in temp
    volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)])

    # Normalize the audio to the target coefficient level, clamping the scale factor between 0.1 and 10
    audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10)

    # Ensure the maximum absolute value in the audio does not exceed 1
    max_value = np.max(np.abs(audio))
    if max_value > 1:
        audio = audio / max_value

    return audio


def load_audio(
    adfile: Path,
    sampling_rate: int = None,
    length: int = None,
    volume_normalize: bool = False,
    segment_duration: int = None,
) -> np.ndarray:
    r"""Load an audio file with a target sampling rate and length.

    Args:
        adfile (Path): path to audio file.
        sampling_rate (int, optional): target sampling rate. Defaults to None.
        length (int, optional): target audio length. Defaults to None.
        volume_normalize (bool, optional): whether to perform volume normalization. Defaults to False.
        segment_duration (int): randomly select a segment with a duration of {segment_duration}s.
            Defaults to None, which means the whole audio will be used.

    Returns:
        audio (np.ndarray): audio
    """

    audio, sr = sf.read(adfile)
    if len(audio.shape) > 1:
        audio = audio[:, 0]

    if sampling_rate is not None and sr != sampling_rate:
        audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ")
        sr = sampling_rate

    if segment_duration is not None:
        seg_length = int(sr * segment_duration)
        audio = random_select_audio_segment(audio, seg_length)

    # Audio volume normalize
    if volume_normalize:
        audio = audio_volume_normalize(audio)
    # check the audio length
    if length is not None:
        assert abs(audio.shape[0] - length) < 1000
        if audio.shape[0] > length:
            audio = audio[:length]
        else:
            audio = np.pad(audio, (0, int(length - audio.shape[0])))
    return audio


def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray:
    """Get an audio segment of the given length.

    Args:
        audio (np.ndarray): input audio.
        length (int): audio length = sampling_rate * duration
    """
    if audio.shape[0] < length:
        audio = np.pad(audio, (0, int(length - audio.shape[0])))
    start_index = random.randint(0, audio.shape[0] - length)
    end_index = int(start_index + length)

    return audio[start_index:end_index]

def get_ref_clip(wav: np.ndarray, config) -> np.ndarray:  # Needs access to config attributes
    """Get reference audio clip for speaker embedding."""
    # Make sure config has sample_rate, ref_segment_duration, latent_hop_length
    if not all(hasattr(config, attr) for attr in ['sample_rate', 'ref_segment_duration', 'latent_hop_length']):
        raise AttributeError("Config object missing required attributes for get_ref_clip")
    ref_segment_length = (
        int(config.sample_rate * config.ref_segment_duration)
        // config.latent_hop_length
        * config.latent_hop_length
    )
    wav_length = len(wav)
    if ref_segment_length > wav_length:
        wav = np.tile(wav, ref_segment_length // wav_length + 1)
    return wav[:ref_segment_length]


# --- Paste sparktts/utils/token_parser.py content here ---

TASK_TOKEN_MAP = {
    "vc": "<|task_vc|>",
    "tts": "<|task_tts|>",
    "asr": "<|task_asr|>",
    "s2s": "<|task_s2s|>",
    "t2s": "<|task_t2s|>",
    "understand": "<|task_understand|>",
    "caption": "<|task_cap|>",
    "controllable_tts": "<|task_controllable_tts|>",
    "prompt_tts": "<|task_prompt_tts|>",
    "speech_edit": "<|task_edit|>",
}

LEVELS_MAP = {
    "very_low": 0,
    "low": 1,
    "moderate": 2,
    "high": 3,
    "very_high": 4,
}

LEVELS_MAP_UI = {
    1: 'very_low',
    2: 'low',
    3: 'moderate',
    4: 'high',
    5: 'very_high'
}

GENDER_MAP = {
    "female": 0,
    "male": 1,
}

AGE_MAP = {"Child": 0, "Teenager": 1, "Youth-Adult": 2, "Middle-aged": 3, "Elderly": 4}

EMO_MAP = {
    "UNKNOWN": 0,
    "NEUTRAL": 1,
    "ANGRY": 2,
    "HAPPY": 3,
    "SAD": 4,
    "FEARFUL": 5,
    "DISGUSTED": 6,
    "SURPRISED": 7,
    "SARCASTIC": 8,
    "EXCITED": 9,
    "SLEEPY": 10,
    "CONFUSED": 11,
    "EMPHASIS": 12,
    "LAUGHING": 13,
    "SINGING": 14,
    "WORRIED": 15,
    "WHISPER": 16,
    "ANXIOUS": 17,
    "NO-AGREEMENT": 18,
    "APOLOGETIC": 19,
    "CONCERNED": 20,
    "ENUNCIATED": 21,
    "ASSERTIVE": 22,
    "ENCOURAGING": 23,
    "CONTEMPT": 24,
}


class TokenParser:
    """Turn labels (task, speaker attributes, prosody) into special tokens."""

    def __init__(self):
        pass

    @staticmethod
    def age(age: str) -> str:
        """Turn age token."""
        age_id = AGE_MAP[age]
        return f"<|age_{age_id}|>"

    @staticmethod
    def gender(gender: str) -> str:
        """Turn gender token."""
        gender_id = GENDER_MAP[gender]
        return f"<|gender_{gender_id}|>"

    @staticmethod
    def mel_value(mel: int):
        """Turn special token of mel scale pitch."""
        mel = max(0, int(mel))
        mel = min(1000, int(mel))
        return f"<|pitch_value_{mel}|>"

    @staticmethod
    def mel_level(level: str):
        """Turn special token of mel level."""
        level_tag = LEVELS_MAP[level]
        return f"<|pitch_label_{level_tag}|>"

    @staticmethod
    def pitch_var_value(pitch_std: int):
        """Turn special token of pitch_std value."""
        assert isinstance(pitch_std, int)
        pitch_std = max(0, int(pitch_std))
        pitch_std = min(10, int(pitch_std))
        return f"<|pitch_var_value_{pitch_std}|>"

    @staticmethod
    def pitch_var_level(level: str):
        """Turn special token of pitch std level."""
        level_tag = LEVELS_MAP[level]
        return f"<|pitch_var_label_{level_tag}|>"

    @staticmethod
    def loudness_value(loudness: int):
        """Turn special token of loudness value [0, 30]."""
        assert loudness >= 0
        loudness = max(0, int(loudness))
        loudness = min(30, int(loudness))
        return f"<|loudness_value_{loudness}|>"

    @staticmethod
    def loudness_level(level: str):
        """Turn special token of loudness level."""
        level_tag = LEVELS_MAP[level]
        return f"<|loudness_label_{level_tag}|>"

    @staticmethod
    def speed_value(speed: int):
        """Turn special token of speed value."""
        speed = max(0, int(speed))
        speed = min(10, int(speed))
        return f"<|speed_value_{speed}|>"

    @staticmethod
    def speed_level(level: str):
        """Turn special token of speed level."""
        level_tag = LEVELS_MAP[level]
        return f"<|speed_label_{level_tag}|>"

    @staticmethod
    def task(task: str) -> str:
        """Turn special token of task."""
        assert task in TASK_TOKEN_MAP.keys()

        return TASK_TOKEN_MAP[task]

    @staticmethod
    def emotion(emotion: str):
        emo_id = EMO_MAP[emotion]

        return f"<|emotion_{emo_id}|>"

# =============================================================================
# >> END: PASTE CODE FROM sparktts/utils/* HERE <<
# =============================================================================


class SparkTTSProcessor(ProcessorMixin, PushToHubMixin):  # Added PushToHubMixin
    r"""
    Constructs a SparkTTS processor which wraps a text tokenizer and relevant audio processing logic.

    Args:
        tokenizer ([`PreTrainedTokenizer`]):
            An instance of [`PreTrainedTokenizer`]. This handles the text tokenization for the LLM.
        feature_extractor ([`Wav2Vec2FeatureExtractor`]):
            An instance of [`Wav2Vec2FeatureExtractor`]. Although Wav2Vec2 features are extracted
            within the model's `tokenize_audio`, the extractor's configuration (like sampling rate)
            is useful, and it aligns with the ProcessorMixin pattern.
        config ([`SparkTTSConfig`], *optional*):
            An instance of [`SparkTTSConfig`] to access configuration parameters like sample rate.
    """
    attributes = ["tokenizer", "feature_extractor"]
    tokenizer_class = "AutoTokenizer"
    feature_extractor_class = "Wav2Vec2FeatureExtractor"  # Keep for consistency

    def __init__(self, tokenizer, feature_extractor, config: Optional[SparkTTSConfig] = None, **kwargs):
        super().__init__(tokenizer=tokenizer, feature_extractor=feature_extractor, **kwargs)
        self.model = None
        self.config = config
        # Set sampling rate
        if config and hasattr(config, 'sample_rate'):
            self.sampling_rate = config.sample_rate
        elif feature_extractor and hasattr(feature_extractor, 'sampling_rate'):
            self.sampling_rate = feature_extractor.sampling_rate
        else:
            self.sampling_rate = 16000
            logger.warning(f"Could not determine sampling rate. Defaulting to {self.sampling_rate} Hz.")

        # # Ensure tokenizer pad token
        # if self.tokenizer.pad_token is None:
        #     if self.tokenizer.eos_token is not None:
        #         logger.warning("Tokenizer does not have a pad token. Setting pad_token to eos_token.")
        #         self.tokenizer.pad_token = self.tokenizer.eos_token
        #     else:
        #         logger.warning("Tokenizer lacks pad and eos token. Adding default pad token '<|pad|>'.")
        #         self.tokenizer.add_special_tokens({'pad_token': '<|pad|>'})

    def link_model(self, model):
        """Links the processor to a SparkTTSModel instance for audio processing calls."""
        if not hasattr(model, 'tokenize_audio') or not hasattr(model, 'detokenize_audio'):
            raise TypeError("The provided model instance does not have the required 'tokenize_audio' and 'detokenize_audio' methods.")
        if not hasattr(model, 'config'):
            logger.warning("Linked model does not have a 'config' attribute. Some processor functionalities might rely on it.")

        self.model = model
        logger.info("SparkTTSModel successfully linked to the processor.")
        # Update sampling rate based on linked model's config if available
        if hasattr(model, 'config') and hasattr(model.config, 'sample_rate'):
            if self.sampling_rate != model.config.sample_rate:
                logger.info(f"Updating processor sampling rate from {self.sampling_rate} to {model.config.sample_rate} based on linked model config.")
                self.sampling_rate = model.config.sample_rate
                # Also update feature extractor sampling rate if it differs
                if hasattr(self, 'feature_extractor') and self.feature_extractor.sampling_rate != model.config.sample_rate:
                    logger.info(f"Updating feature_extractor sampling rate from {self.feature_extractor.sampling_rate} to {model.config.sample_rate}.")
                    self.feature_extractor.sampling_rate = model.config.sample_rate


    def __call__(
        self,
        text: str,
        prompt_speech_path: Optional[Union[str, Path]] = None,
        prompt_text: Optional[str] = None,
        gender: Optional[str] = None,
        pitch: Optional[str] = None,
        speed: Optional[str] = None,
        return_tensors: Optional[str] = "pt",
        **kwargs,  # Allow passing other args like padding, truncation to tokenizer
    ) -> BatchEncoding:
        """
        Processes the input text and optional prompt audio/control parameters into a format suitable for [`SparkTTSModel`].

        Args:
            text (`str`):
                The main text to be synthesized.
            prompt_speech_path (`str` or `Path`, *optional*):
                Path to the prompt audio file for voice cloning. Required if `gender` is not set.
            prompt_text (`str`, *optional*):
                Transcript of the prompt audio. Used only in voice cloning mode.
            gender (`str`, *optional*):
                Target gender ("male" or "female") for controllable synthesis. If set, enables control mode.
            pitch (`str`, *optional*):
                Target pitch level ("very_low", "low", "moderate", "high", "very_high") for control mode. Required if `gender` is set.
            speed (`str`, *optional*):
                Target speed level ("very_low", "low", "moderate", "high", "very_high") for control mode. Required if `gender` is set.
            return_tensors (`str`, *optional*, defaults to `"pt"`):
                If set, will return tensors instead of lists of python integers. Only "pt" (PyTorch) is supported currently.
            **kwargs:
                Additional arguments passed to the underlying tokenizer's `__call__` method.

        Returns:
            [`BatchEncoding`]: A dictionary containing the `input_ids` and `attention_mask` for the LLM.
                In voice cloning mode, it also includes `global_token_ids_prompt` (torch.Tensor) representing the
                global tokens extracted from the prompt audio.
        """

        global_token_ids_prompt = None  # Initialize

        # Determine mode: Control TTS or Voice Cloning (Prompt TTS)
        is_control_mode = gender is not None
        is_cloning_mode = prompt_speech_path is not None and not is_control_mode

        if is_control_mode:
            # --- Controllable TTS Prompt Construction ---
            if not all([pitch, speed]):
                raise ValueError("For controllable TTS, 'gender', 'pitch', and 'speed' must all be provided.")
            if prompt_speech_path is not None:
                logger.warning("`prompt_speech_path` provided but ignored because `gender` is set (controllable TTS mode).")

            if not all(k in GENDER_MAP for k in [gender]):  # Basic check
                raise ValueError(f"Invalid gender provided: {gender}. Must be one of {list(GENDER_MAP.keys())}")
            if not all(k in LEVELS_MAP for k in [pitch, speed]):  # Basic check
                raise ValueError(f"Invalid pitch or speed level provided. Must be one of {list(LEVELS_MAP.keys())}")

            gender_id = GENDER_MAP[gender]
            pitch_level_id = LEVELS_MAP[pitch]
            speed_level_id = LEVELS_MAP[speed]

            pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>"
            speed_label_tokens = f"<|speed_label_{speed_level_id}|>"
            gender_tokens = f"<|gender_{gender_id}|>"

            attribute_tokens = "".join([gender_tokens, pitch_label_tokens, speed_label_tokens])

            prompt_list = [
                TASK_TOKEN_MAP["controllable_tts"],
                "<|start_content|>",
                text,
                "<|end_content|>",
                "<|start_style_label|>",
                attribute_tokens,
                "<|end_style_label|>",
            ]
            prompt_string = "".join(prompt_list)

        elif is_cloning_mode:
            # --- Voice Cloning Prompt Construction ---
            if self.model is None:
                raise RuntimeError("Processor must be linked to a SparkTTSModel instance via `processor.link_model(model)` before performing voice cloning.")
            prompt_speech_path = Path(prompt_speech_path)  # Ensure it's a Path object
            if not prompt_speech_path.exists():
                raise FileNotFoundError(f"Prompt audio file not found: {prompt_speech_path}")

            # Load and process prompt audio
            try:
                model_config = self.model.config if self.model and hasattr(self.model, 'config') else self.config
                if model_config is None:
                    raise ValueError("Configuration not available in processor or linked model.")

                # Load main wav
                wav = load_audio(
                    prompt_speech_path,
                    sampling_rate=self.sampling_rate,
                    volume_normalize=getattr(model_config, 'volume_normalize', True),  # Use getattr for safety
                )
                # Get reference clip
                wav_ref_np = get_ref_clip(wav, model_config)  # Pass config object
                wav_ref = torch.from_numpy(wav_ref_np).unsqueeze(0).float()
                wav_tensor = torch.from_numpy(wav).unsqueeze(0).float()

                # Tokenize using the linked model's method
                # Assuming tokenize_audio returns tensors with batch dim 1: [1, N_global], [1, N_semantic]
                global_tokens_tensor, semantic_tokens_tensor = self.model.tokenize_audio(wav_tensor, wav_ref)

                # Store the global tokens tensor (with batch dim) for the output dict
                global_token_ids_prompt = global_tokens_tensor  # Keep batch dim [1, N_global]

                # Convert tensors to lists of ints for string formatting
                global_token_list = global_tokens_tensor.squeeze().tolist()  # Remove batch dim -> list
                semantic_token_list = semantic_tokens_tensor.squeeze().tolist()  # Remove batch dim -> list

            except Exception as e:
                logger.error(f"Error processing prompt audio {prompt_speech_path}: {e}")
                import traceback
                traceback.print_exc()
                raise

            # ==============================================================
            # CORRECTED TOKEN STRING FORMATTING
            # ==============================================================
            # Create individual token strings for each ID
            global_tokens_str = "".join([f"<|bicodec_global_{gid}|>" for gid in global_token_list])
            semantic_tokens_str = "".join([f"<|bicodec_semantic_{sid}|>" for sid in semantic_token_list])
            # ==============================================================

            # Construct prompt list based on presence of prompt_text
            if prompt_text is not None and prompt_text.strip():  # Check if prompt_text is meaningful
                logger.info("Using prompt text in voice cloning prompt.")
                prompt_list = [
                    TASK_TOKEN_MAP["tts"],  # Or maybe TASK_TOKEN_MAP["prompt_tts"]? Check original logic. Assuming "tts".
                    "<|start_content|>",
                    prompt_text,  # Transcript first
                    text,  # Then target text
                    "<|end_content|>",
                    "<|start_global_token|>",
                    global_tokens_str,
                    "<|end_global_token|>",
                    "<|start_semantic_token|>",
                    semantic_tokens_str,
                    # "<|end_semantic_token|>",  # Original code didn't have this marker here
                ]
            else:
                # Simpler prompt without semantic tokens if no transcript provided
                logger.info("No prompt text provided, using text-only voice cloning prompt.")
                prompt_list = [
                    TASK_TOKEN_MAP["tts"],  # Or maybe TASK_TOKEN_MAP["prompt_tts"]?
                    "<|start_content|>",
                    text,  # Only target text
                    "<|end_content|>",
                    "<|start_global_token|>",
                    global_tokens_str,
                    "<|end_global_token|>",
                ]
            prompt_string = "".join(prompt_list)
            logger.debug(f"Generated prompt string (cloning): {prompt_string[:200]}...")  # Log start of prompt

        else:
            raise ValueError("Invalid input combination. Either provide `prompt_speech_path` for cloning or (`gender`, `pitch`, `speed`) for control.")

        # --- Tokenize the final prompt string ---
        # print(f"Tokenizing prompt: {prompt_string}")
        inputs = self.tokenizer(
            prompt_string,
            return_tensors=return_tensors,
            padding=kwargs.get("padding", False),  # Often False for generation prompts unless batching > 1
            truncation=kwargs.get("truncation", True),
            max_length=kwargs.get("max_length", self.tokenizer.model_max_length),
            add_special_tokens=kwargs.get("add_special_tokens", True),  # Usually True unless handled manually
            return_attention_mask=kwargs.get("return_attention_mask", True),  # Need attention mask
            **{k: v for k, v in kwargs.items() if k not in ["padding", "truncation", "max_length", "add_special_tokens", "return_attention_mask"]}
        )
        logger.debug(f"Tokenized input_ids shape: {inputs['input_ids'].shape}")

        # Add the prompt's global tokens (as tensor with batch dim) to the output if in cloning mode
        if is_cloning_mode and global_token_ids_prompt is not None:
            if return_tensors == "pt":
                inputs["global_token_ids_prompt"] = global_token_ids_prompt  # Already has batch dim [1, N_global]
            else:
                # Handle non-tensor return if necessary
                inputs["global_token_ids_prompt"] = global_token_ids_prompt.tolist()

        return inputs


    def decode(
        self,
        generated_ids: torch.Tensor,
        global_token_ids_prompt: Optional[torch.Tensor] = None,
        input_ids_len: Optional[int] = None,
        skip_special_tokens: bool = True,
    ) -> Dict[str, Any]:
        """
        Decodes the generated token IDs from [`SparkTTSModel`] into an audio waveform.

        Args:
            generated_ids (`torch.Tensor`):
                Tensor of token IDs generated by `model.generate()`, including the input prompt part. Shape [B, seq_len].
            global_token_ids_prompt (`torch.Tensor`, *optional*):
                The global tokens extracted from the prompt audio during the `__call__` step (for voice cloning).
                Shape [B, N_global]. Required if the generation was for voice cloning.
            input_ids_len (`int`, *optional*):
                The length of the original input prompt `input_ids` fed to `model.generate()`. Required to
                correctly isolate the newly generated tokens.
            skip_special_tokens (`bool`, *optional*, defaults to `True`):
                Whether to skip special tokens during the text decoding step (used to extract audio tokens).

        Returns:
            Dict[str, Any]: A dictionary containing:
                - "audio": The decoded audio waveform as a NumPy array. Shape [T_audio] (if B=1) or [B, T_audio].
                - "sampling_rate": The sampling rate of the audio.
        """
        if self.model is None:
            raise RuntimeError("Processor must be linked to a SparkTTSModel instance via `processor.link_model(model)` before decoding.")
        if input_ids_len is None:
            raise ValueError("`input_ids_len` (length of the prompt input_ids) must be provided for decoding.")

        # --- Isolate generated part and decode text ---
        # Assumes generated_ids has shape [B, full_seq_len]
        # Handle case where generated sequence is shorter than prompt (shouldn't happen with max_new_tokens > 0)
        if generated_ids.shape[1] < input_ids_len:
            logger.warning(f"Generated sequence length ({generated_ids.shape[1]}) is shorter than input prompt length ({input_ids_len}). Decoding might be incorrect.")
            output_only_ids = generated_ids[:, input_ids_len:]  # Will be empty if equal
        else:
            output_only_ids = generated_ids[:, input_ids_len:]

        # Decode the generated part to find audio tokens
        # Need to handle batch decoding if B > 1
        # print("decode token", self.tokenizer.batch_decode(output_only_ids, skip_special_tokens=False))
        decoded_texts = self.tokenizer.batch_decode(output_only_ids, skip_special_tokens=skip_special_tokens)

        # --- Extract Audio Tokens ---
        # Handle batch processing correctly
        batch_size = generated_ids.shape[0]
        all_semantic_ids = []
        all_global_tokens = []
        successful_indices = []  # Keep track of which batch items were successful

        for i in range(batch_size):
            decoded_text = decoded_texts[i]
            current_semantic_ids = None
            current_global_tokens = None

            # Extract semantic tokens
            try:
                pred_semantic_indices = [int(token) for token in re.findall(r"bicodec_semantic_(\d+)", decoded_text)]
                if not pred_semantic_indices:
                    logger.warning(f"Batch item {i}: No semantic tokens found in decoded text: '{decoded_text[:200]}...'")
                    continue  # Skip this item

                current_semantic_ids = torch.tensor(pred_semantic_indices).long()  # Shape [N_semantic]
            except Exception as e:
                logger.error(f"Batch item {i}: Error parsing semantic tokens from: '{decoded_text[:200]}...'. Error: {e}")
                continue  # Skip this item

            # Determine global tokens
            if global_token_ids_prompt is not None:
                # Cloning mode: Use the provided prompt global tokens for this batch item
                if global_token_ids_prompt.shape[0] != batch_size:
                    raise ValueError(f"Batch size mismatch: generated_ids has {batch_size}, but global_token_ids_prompt has {global_token_ids_prompt.shape[0]}.")
                current_global_tokens = global_token_ids_prompt[i]  # Shape [N_global]
            else:
                # Control mode: Extract global tokens from the generated text
                try:
                    pred_global_indices = [int(token) for token in re.findall(r"bicodec_global_(\d+)", decoded_text)]
                    if not pred_global_indices:
                        logger.warning(f"Batch item {i}: No global tokens found in decoded text for control mode: '{decoded_text[:200]}...'")
                        continue  # Skip this item

                    current_global_tokens = torch.tensor(pred_global_indices).long()  # Shape [N_global]

                except Exception as e:
                    logger.error(f"Batch item {i}: Error parsing global tokens from: '{decoded_text[:200]}...'. Error: {e}")
                    continue  # Skip this item

            # If both tokens extracted successfully
            all_semantic_ids.append(current_semantic_ids)
            all_global_tokens.append(current_global_tokens)
            successful_indices.append(i)

        if not successful_indices:
            logger.error("Failed to extract audio tokens for any item in the batch.")
            return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}

        # Pad sequences to the max length within the successful batch items for batch detokenization
        # Note: BiCodec might not support batching if sequences have different lengths. Check its implementation.
        # Assuming BiCodec *can* handle batches if padded (or if lengths are naturally equal).
        # This padding might be unnecessary if BiCodec handles variable lengths or if B=1 anyway.
        # For now, let's assume B=1 was handled correctly and skip complex padding.
        if batch_size > 1 and len(successful_indices) < batch_size:
            logger.warning(f"Only successfully decoded {len(successful_indices)} out of {batch_size} batch items.")
            # Further processing might need to handle only the successful items.

        # Let's proceed assuming B=1 or BiCodec handles batches appropriately.
        # Stack the successful tokens.
        try:
            # Need to ensure tensors have the same length before stacking if BiCodec requires it.
            # If BiCodec handles variable length, stacking might not be needed, just loop and call detokenize.
            # Let's assume B=1 for simplicity of the example, matching original code's likely behavior.
            if len(successful_indices) != 1:
                raise NotImplementedError("Batch decoding (B > 1) requires verification of BiCodec's batch handling and potentially padding.")

            final_semantic_ids = all_semantic_ids[0].unsqueeze(0)  # Add batch dim [1, N_semantic]
            final_global_tokens = all_global_tokens[0].unsqueeze(0)  # Add batch dim [1, N_global]

        except IndexError:  # Should not happen if successful_indices is not empty
            logger.error("Internal error during token batch preparation.")
            return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}

        # --- Detokenize Audio ---
        try:
            # Call the linked model's detokenize method
            # print(f"DEBUG: Detokenizing audio with global tokens {final_global_tokens.shape}, semantic tokens {final_semantic_ids.shape}")
            output_wav = self.model.detokenize_audio(final_global_tokens, final_semantic_ids)
            # detokenize_audio now returns numpy array float32 in [-1, 1]

            # Optional: Double-check dtype here if needed, but should be handled by detokenize_audio now
            # if output_wav.dtype != np.float32:
            #     logger.warning(f"Audio dtype after detokenize is {output_wav.dtype}. Converting to float32.")
            #     output_wav = output_wav.astype(np.float32)
            # output_wav = np.clip(output_wav, -1.0, 1.0)  # Clipping done in detokenize_audio

        except Exception as e:
            logger.error(f"Error during audio detokenization: {e}")
            import traceback
            traceback.print_exc()
            raise RuntimeError("Audio detokenization failed.") from e

        return {"audio": output_wav, "sampling_rate": self.sampling_rate}


    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        trust_remote_code: bool = False,  # Allow passing this, needed for config potentially
        **kwargs,
    ):
        r"""
        Instantiate a SparkTTSProcessor from pretrained components.
        """
        # Pop specific kwargs for this method
        config = kwargs.pop("config", None)  # Allow passing config explicitly

        # --- 1. Load Config (to find component paths) ---
        # We need the config even if the processor doesn't store it permanently,
        # just to find where the tokenizer/feature_extractor live.
        loaded_config = None
        if not isinstance(config, SparkTTSConfig):
            try:
                # Load the specific config class
                loaded_config = SparkTTSConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    trust_remote_code=trust_remote_code,  # Config might be custom
                    **kwargs,  # Pass relevant kwargs
                )
            except Exception as e:
                logger.warning(
                    f"Could not load SparkTTSConfig from {pretrained_model_name_or_path}. "
                    f"Attempting to load components from default relative paths ('LLM', 'wav2vec2-large-xlsr-53'). Error: {e}"
                )
                loaded_config = None  # Fallback
        else:
            # Config object was passed directly
            loaded_config = config

        # --- 2. Determine Component Paths ---
        llm_tokenizer_path_or_id = "./LLM"  # Default relative path
        w2v_processor_path_or_id = "./wav2vec2-large-xlsr-53"  # Default relative path

        if loaded_config:
            llm_tokenizer_path_or_id = getattr(loaded_config, 'llm_model_name_or_path', llm_tokenizer_path_or_id)
            w2v_processor_path_or_id = getattr(loaded_config, 'wav2vec2_model_name_or_path', w2v_processor_path_or_id)

        # The component `from_pretrained` methods handle resolving these paths/IDs
        # whether they are relative subfolders of `pretrained_model_name_or_path`
        # or separate Hub IDs.

        # --- 3. Load Components ---
        # Pass down relevant kwargs for loading components
        component_loading_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "local_files_only": local_files_only,
            "token": token,
            "revision": revision,
            **kwargs  # Pass other user kwargs
        }
        try:
            # Tokenizer might require trust_remote_code if its class is custom
            tokenizer = AutoTokenizer.from_pretrained(
                pretrained_model_name_or_path,  # Main path
                subfolder=llm_tokenizer_path_or_id.lstrip('./'),  # Specify subfolder relative to main path
                trust_remote_code=trust_remote_code,
                **component_loading_kwargs
            )
        except Exception as e:
            # Fallback: try loading directly using the path/id from config if different
            if llm_tokenizer_path_or_id != "./LLM":
                try:
                    logger.info(f"Retrying tokenizer load directly from: {llm_tokenizer_path_or_id}")
                    tokenizer = AutoTokenizer.from_pretrained(
                        llm_tokenizer_path_or_id,
                        trust_remote_code=trust_remote_code,
                        **component_loading_kwargs
                    )
                except Exception as e2:
                    raise OSError(f"Could not load tokenizer using main path + subfolder or directly from '{llm_tokenizer_path_or_id}'. Error: {e2}") from e
            else:
                raise OSError(f"Could not load tokenizer from subfolder '{llm_tokenizer_path_or_id}' within '{pretrained_model_name_or_path}'. Error: {e}")

        try:
            # Feature extractor usually doesn't need trust_remote_code
            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
                pretrained_model_name_or_path,  # Main path
                subfolder=w2v_processor_path_or_id.lstrip('./'),  # Specify subfolder relative to main path
                **component_loading_kwargs
            )
        except Exception as e:
            # Fallback: try loading directly using the path/id from config if different
            if w2v_processor_path_or_id != "./wav2vec2-large-xlsr-53":
                try:
                    logger.info(f"Retrying feature extractor load directly from: {w2v_processor_path_or_id}")
                    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
                        w2v_processor_path_or_id,
                        **component_loading_kwargs
                    )
                except Exception as e2:
                    raise OSError(f"Could not load feature extractor using main path + subfolder or directly from '{w2v_processor_path_or_id}'. Error: {e2}") from e
            else:
                raise OSError(f"Could not load feature extractor from subfolder '{w2v_processor_path_or_id}' within '{pretrained_model_name_or_path}'. Error: {e}")

        # --- 4. Instantiate processor ---
        # Pass the potentially loaded config object (or None)
        return cls(tokenizer=tokenizer, feature_extractor=feature_extractor, config=loaded_config)


    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        push_to_hub: bool = False,
        **kwargs,
    ):
        """
        Save the processor's state (tokenizer and feature extractor files) to a directory.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the processor files will be saved.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it.
            **kwargs:
                Additional keyword arguments passed along to the `push_to_hub` method.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # Save tokenizer
        self.tokenizer.save_pretrained(str(save_directory), **kwargs)

        # Save feature extractor
        self.feature_extractor.save_pretrained(str(save_directory), **kwargs)

        # Save the main processor config (if it exists and has relevant info)
        # Note: The SparkTTSConfig is usually saved with the *model*, not the processor.
        # However, if the processor holds specific config needed for reloading *itself*,
        # it could be saved here. Usually, relying on the model's config is sufficient.
        # if self.config:
        #     self.config.save_pretrained(str(save_directory))  # Example if needed

        logger.info(f"Processor components saved in {save_directory}")

        if push_to_hub:
            # Commit message and other hub kwargs can be passed via **kwargs
            commit_message = kwargs.pop("commit_message", "Save processor")
            return self.push_to_hub(save_directory, commit_message=commit_message, **kwargs)

        return str(save_directory)  # Return path consistent with Mixin
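
Taken together, the intended flow is: build the prompt with the processor, generate audio tokens with the LLM, then decode them back to a waveform through BiCodec. Below is a minimal end-to-end sketch, assuming the upload is consumed through the Auto classes registered in SparkTTSConfig.auto_map; the repo id, file paths, and generation settings are placeholders, not values taken from this upload.

import soundfile as sf
from transformers import AutoModel, AutoProcessor

repo_id = "path/or/repo-id-of-this-upload"  # placeholder
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
processor.link_model(model)  # required before voice cloning and decoding

# Voice-cloning mode: reference audio (optionally with its transcript) plus the target text.
inputs = processor(
    text="Hello, this is a test of Spark TTS.",
    prompt_speech_path="prompt.wav",
    prompt_text="Transcript of the prompt audio.",
    return_tensors="pt",
)
# Controllable mode instead:
#   inputs = processor(text=..., gender="female", pitch="moderate", speed="moderate")

input_ids_len = inputs["input_ids"].shape[1]
generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=1024,  # illustrative value
)

output = processor.decode(
    generated_ids=generated_ids,
    global_token_ids_prompt=inputs.get("global_token_ids_prompt"),
    input_ids_len=input_ids_len,
)
sf.write("output.wav", output["audio"], output["sampling_rate"])

Note that decode currently assumes a batch of one; batch decoding raises NotImplementedError.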