Delete configuration_spark_tts.py
Browse files- configuration_spark_tts.py +0 -230
configuration_spark_tts.py
DELETED
|
@@ -1,230 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
""" SparkTTS model configuration"""
|
| 16 |
-
|
| 17 |
-
from transformers.configuration_utils import PretrainedConfig
|
| 18 |
-
from transformers.utils import logging
|
| 19 |
-
from typing import List, Optional
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
logger = logging.get_logger(__name__)
|
| 23 |
-
|
| 24 |
-
# --- Define Individual Sub-Component Config Classes ---
|
| 25 |
-
|
| 26 |
-
class SparkTTSMelParamsConfig(PretrainedConfig):
|
| 27 |
-
"""Configuration for Mel Spectrogram parameters."""
|
| 28 |
-
model_type = "spark-tts-mel-params"
|
| 29 |
-
def __init__(self, sample_rate=16000, n_fft=1024, win_length=640, hop_length=320,
|
| 30 |
-
mel_fmin=10, mel_fmax=None, num_mels=128, **kwargs):
|
| 31 |
-
super().__init__(**kwargs)
|
| 32 |
-
self.sample_rate = sample_rate
|
| 33 |
-
self.n_fft = n_fft
|
| 34 |
-
self.win_length = win_length
|
| 35 |
-
self.hop_length = hop_length
|
| 36 |
-
self.mel_fmin = mel_fmin
|
| 37 |
-
self.mel_fmax = mel_fmax
|
| 38 |
-
self.num_mels = num_mels
|
| 39 |
-
|
| 40 |
-
class SparkTTSEncoderConfig(PretrainedConfig):
|
| 41 |
-
"""Configuration for the BiCodec Feature Encoder."""
|
| 42 |
-
model_type = "spark-tts-encoder"
|
| 43 |
-
def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
|
| 44 |
-
vocos_num_layers=12, out_channels=1024, sample_ratios=[1, 1], **kwargs):
|
| 45 |
-
super().__init__(**kwargs)
|
| 46 |
-
self.input_channels = input_channels
|
| 47 |
-
self.vocos_dim = vocos_dim
|
| 48 |
-
self.vocos_intermediate_dim = vocos_intermediate_dim
|
| 49 |
-
self.vocos_num_layers = vocos_num_layers
|
| 50 |
-
self.out_channels = out_channels
|
| 51 |
-
self.sample_ratios = sample_ratios
|
| 52 |
-
|
| 53 |
-
class SparkTTSDecoderConfig(PretrainedConfig):
|
| 54 |
-
"""Configuration for the BiCodec Wave Generator (Decoder)."""
|
| 55 |
-
model_type = "spark-tts-decoder"
|
| 56 |
-
def __init__(self, input_channel=1024, channels=1536, rates=[8, 5, 4, 2],
|
| 57 |
-
kernel_sizes=[16, 11, 8, 4], **kwargs):
|
| 58 |
-
super().__init__(**kwargs)
|
| 59 |
-
self.input_channel = input_channel
|
| 60 |
-
self.channels = channels
|
| 61 |
-
self.rates = rates
|
| 62 |
-
self.kernel_sizes = kernel_sizes
|
| 63 |
-
|
| 64 |
-
class SparkTTSQuantizerConfig(PretrainedConfig):
|
| 65 |
-
"""Configuration for the BiCodec Factorized Vector Quantizer."""
|
| 66 |
-
model_type = "spark-tts-quantizer"
|
| 67 |
-
def __init__(self, input_dim=1024, codebook_size=8192, codebook_dim=8,
|
| 68 |
-
commitment=0.25, codebook_loss_weight=2.0, decay=0.99,
|
| 69 |
-
threshold_ema_dead_code=0.2, **kwargs):
|
| 70 |
-
super().__init__(**kwargs)
|
| 71 |
-
self.input_dim = input_dim
|
| 72 |
-
self.codebook_size = codebook_size
|
| 73 |
-
self.codebook_dim = codebook_dim
|
| 74 |
-
self.commitment = commitment
|
| 75 |
-
self.codebook_loss_weight = codebook_loss_weight
|
| 76 |
-
self.decay = decay
|
| 77 |
-
self.threshold_ema_dead_code = threshold_ema_dead_code
|
| 78 |
-
|
| 79 |
-
class SparkTTSSpeakerEncoderConfig(PretrainedConfig):
|
| 80 |
-
"""Configuration for the BiCodec Speaker Encoder."""
|
| 81 |
-
model_type = "spark-tts-speaker-encoder"
|
| 82 |
-
def __init__(self, input_dim=128, out_dim=1024, latent_dim=128, token_num=32,
|
| 83 |
-
fsq_levels=[4, 4, 4, 4, 4, 4], fsq_num_quantizers=1, **kwargs):
|
| 84 |
-
super().__init__(**kwargs)
|
| 85 |
-
self.input_dim = input_dim
|
| 86 |
-
self.out_dim = out_dim
|
| 87 |
-
self.latent_dim = latent_dim
|
| 88 |
-
self.token_num = token_num
|
| 89 |
-
self.fsq_levels = fsq_levels
|
| 90 |
-
self.fsq_num_quantizers = fsq_num_quantizers
|
| 91 |
-
|
| 92 |
-
class SparkTTSPrenetConfig(PretrainedConfig):
|
| 93 |
-
"""Configuration for the BiCodec Prenet."""
|
| 94 |
-
model_type = "spark-tts-prenet"
|
| 95 |
-
def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
|
| 96 |
-
vocos_num_layers=12, out_channels=1024, condition_dim=1024,
|
| 97 |
-
sample_ratios=[1, 1], use_tanh_at_final=False, **kwargs):
|
| 98 |
-
super().__init__(**kwargs)
|
| 99 |
-
self.input_channels = input_channels
|
| 100 |
-
self.vocos_dim = vocos_dim
|
| 101 |
-
self.vocos_intermediate_dim = vocos_intermediate_dim
|
| 102 |
-
self.vocos_num_layers = vocos_num_layers
|
| 103 |
-
self.out_channels = out_channels
|
| 104 |
-
self.condition_dim = condition_dim
|
| 105 |
-
self.sample_ratios = sample_ratios
|
| 106 |
-
self.use_tanh_at_final = use_tanh_at_final
|
| 107 |
-
|
| 108 |
-
class SparkTTSPostnetConfig(PretrainedConfig):
|
| 109 |
-
"""Configuration for the BiCodec Postnet."""
|
| 110 |
-
model_type = "spark-tts-postnet"
|
| 111 |
-
def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
|
| 112 |
-
vocos_num_layers=6, out_channels=1024, use_tanh_at_final=False, **kwargs):
|
| 113 |
-
super().__init__(**kwargs)
|
| 114 |
-
self.input_channels = input_channels
|
| 115 |
-
self.vocos_dim = vocos_dim
|
| 116 |
-
self.vocos_intermediate_dim = vocos_intermediate_dim
|
| 117 |
-
self.vocos_num_layers = vocos_num_layers
|
| 118 |
-
self.out_channels = out_channels
|
| 119 |
-
self.use_tanh_at_final = use_tanh_at_final
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
# --- Define the Intermediate BiCodec Config Class ---
|
| 123 |
-
|
| 124 |
-
class SparkTTSBiCodecConfig(PretrainedConfig):
|
| 125 |
-
"""
|
| 126 |
-
Intermediate configuration class for the BiCodec component within SparkTTS.
|
| 127 |
-
It holds instances of the individual sub-component configurations.
|
| 128 |
-
"""
|
| 129 |
-
model_type = "spark-tts-bicodec"
|
| 130 |
-
sub_configs = {
|
| 131 |
-
"mel_params": SparkTTSMelParamsConfig,
|
| 132 |
-
"encoder_config": SparkTTSEncoderConfig,
|
| 133 |
-
"decoder_config": SparkTTSDecoderConfig,
|
| 134 |
-
"quantizer_config": SparkTTSQuantizerConfig,
|
| 135 |
-
"speaker_encoder_config": SparkTTSSpeakerEncoderConfig,
|
| 136 |
-
"prenet_config": SparkTTSPrenetConfig,
|
| 137 |
-
"postnet_config": SparkTTSPostnetConfig,
|
| 138 |
-
}
|
| 139 |
-
|
| 140 |
-
def __init__(
|
| 141 |
-
self,
|
| 142 |
-
mel_params=None,
|
| 143 |
-
encoder_config=None,
|
| 144 |
-
decoder_config=None,
|
| 145 |
-
quantizer_config=None,
|
| 146 |
-
speaker_encoder_config=None,
|
| 147 |
-
prenet_config=None,
|
| 148 |
-
postnet_config=None,
|
| 149 |
-
**kwargs,
|
| 150 |
-
):
|
| 151 |
-
super().__init__(**kwargs)
|
| 152 |
-
self.mel_params = self._init_sub_config(mel_params, "mel_params")
|
| 153 |
-
self.encoder_config = self._init_sub_config(encoder_config, "encoder_config")
|
| 154 |
-
self.decoder_config = self._init_sub_config(decoder_config, "decoder_config")
|
| 155 |
-
self.quantizer_config = self._init_sub_config(quantizer_config, "quantizer_config")
|
| 156 |
-
self.speaker_encoder_config = self._init_sub_config(speaker_encoder_config, "speaker_encoder_config")
|
| 157 |
-
self.prenet_config = self._init_sub_config(prenet_config, "prenet_config")
|
| 158 |
-
self.postnet_config = self._init_sub_config(postnet_config, "postnet_config")
|
| 159 |
-
|
| 160 |
-
def _init_sub_config(self, config_input, config_key):
|
| 161 |
-
"""Helper to initialize sub-configs."""
|
| 162 |
-
config_cls = self.sub_configs[config_key]
|
| 163 |
-
if isinstance(config_input, dict):
|
| 164 |
-
return config_cls(**config_input)
|
| 165 |
-
elif config_input is None:
|
| 166 |
-
return config_cls()
|
| 167 |
-
elif isinstance(config_input, config_cls):
|
| 168 |
-
return config_input
|
| 169 |
-
else:
|
| 170 |
-
raise TypeError(f"Invalid type for {config_key}: {type(config_input)}. Expected dict, None, or {config_cls.__name__}.")
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
# --- Define the Main SparkTTS Config Class ---
|
| 174 |
-
|
| 175 |
-
class SparkTTSConfig(PretrainedConfig):
|
| 176 |
-
r"""
|
| 177 |
-
Main configuration class for SparkTTSModel, including nested BiCodec configuration.
|
| 178 |
-
"""
|
| 179 |
-
model_type = "spark-tts"
|
| 180 |
-
sub_configs = {"bicodec_config": SparkTTSBiCodecConfig}
|
| 181 |
-
attribute_map = {"hidden_size": "d_model"}
|
| 182 |
-
|
| 183 |
-
def __init__(
|
| 184 |
-
self,
|
| 185 |
-
llm_model_name_or_path="./LLM",
|
| 186 |
-
bicodec_model_name_or_path="./BiCodec",
|
| 187 |
-
wav2vec2_model_name_or_path="./wav2vec2-large-xlsr-53",
|
| 188 |
-
sample_rate=16000,
|
| 189 |
-
highpass_cutoff_freq=40,
|
| 190 |
-
latent_hop_length=320,
|
| 191 |
-
ref_segment_duration=6.0,
|
| 192 |
-
volume_normalize=True,
|
| 193 |
-
bicodec_config=None,
|
| 194 |
-
torch_dtype="auto",
|
| 195 |
-
**kwargs,
|
| 196 |
-
):
|
| 197 |
-
self.llm_model_name_or_path = llm_model_name_or_path
|
| 198 |
-
self.bicodec_model_name_or_path = bicodec_model_name_or_path
|
| 199 |
-
self.wav2vec2_model_name_or_path = wav2vec2_model_name_or_path
|
| 200 |
-
self.sample_rate = sample_rate
|
| 201 |
-
self.highpass_cutoff_freq = highpass_cutoff_freq
|
| 202 |
-
self.latent_hop_length = latent_hop_length
|
| 203 |
-
self.ref_segment_duration = ref_segment_duration
|
| 204 |
-
self.volume_normalize = volume_normalize
|
| 205 |
-
self.torch_dtype = torch_dtype
|
| 206 |
-
|
| 207 |
-
if isinstance(bicodec_config, dict):
|
| 208 |
-
self.bicodec_config = self.sub_configs["bicodec_config"](**bicodec_config)
|
| 209 |
-
elif bicodec_config is None:
|
| 210 |
-
logger.info("`bicodec_config` not provided. Initializing with defaults.")
|
| 211 |
-
self.bicodec_config = self.sub_configs["bicodec_config"]()
|
| 212 |
-
elif isinstance(bicodec_config, self.sub_configs["bicodec_config"]):
|
| 213 |
-
self.bicodec_config = bicodec_config
|
| 214 |
-
else:
|
| 215 |
-
raise TypeError(f"Invalid type for bicodec_config: {type(bicodec_config)}")
|
| 216 |
-
|
| 217 |
-
kwargs["processor_class"] = kwargs.get("processor_class", "SparkTTSProcessor")
|
| 218 |
-
kwargs["auto_map"] = kwargs.get("auto_map", {
|
| 219 |
-
"AutoConfig": "configuration_spark_tts.SparkTTSConfig",
|
| 220 |
-
"AutoModel": "modeling_spark_tts.SparkTTSModel",
|
| 221 |
-
"AutoProcessor": "processing_spark_tts.SparkTTSProcessor"
|
| 222 |
-
})
|
| 223 |
-
super().__init__(**kwargs)
|
| 224 |
-
|
| 225 |
-
def to_dict(self):
|
| 226 |
-
"""Serializes this instance to a Python dictionary."""
|
| 227 |
-
output = super().to_dict()
|
| 228 |
-
if hasattr(self, 'bicodec_config') and hasattr(self.bicodec_config, 'to_dict'):
|
| 229 |
-
output['bicodec_config'] = self.bicodec_config.to_dict()
|
| 230 |
-
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|