IbrahimSalah commited on
Commit
2732b2e
·
verified ·
1 Parent(s): ee77fc6

Delete configuration_spark_tts.py

Browse files
Files changed (1) hide show
  1. configuration_spark_tts.py +0 -230
configuration_spark_tts.py DELETED
@@ -1,230 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ SparkTTS model configuration"""
16
-
17
- from transformers.configuration_utils import PretrainedConfig
18
- from transformers.utils import logging
19
- from typing import List, Optional
20
-
21
-
22
- logger = logging.get_logger(__name__)
23
-
24
- # --- Define Individual Sub-Component Config Classes ---
25
-
26
- class SparkTTSMelParamsConfig(PretrainedConfig):
27
- """Configuration for Mel Spectrogram parameters."""
28
- model_type = "spark-tts-mel-params"
29
- def __init__(self, sample_rate=16000, n_fft=1024, win_length=640, hop_length=320,
30
- mel_fmin=10, mel_fmax=None, num_mels=128, **kwargs):
31
- super().__init__(**kwargs)
32
- self.sample_rate = sample_rate
33
- self.n_fft = n_fft
34
- self.win_length = win_length
35
- self.hop_length = hop_length
36
- self.mel_fmin = mel_fmin
37
- self.mel_fmax = mel_fmax
38
- self.num_mels = num_mels
39
-
40
- class SparkTTSEncoderConfig(PretrainedConfig):
41
- """Configuration for the BiCodec Feature Encoder."""
42
- model_type = "spark-tts-encoder"
43
- def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
44
- vocos_num_layers=12, out_channels=1024, sample_ratios=[1, 1], **kwargs):
45
- super().__init__(**kwargs)
46
- self.input_channels = input_channels
47
- self.vocos_dim = vocos_dim
48
- self.vocos_intermediate_dim = vocos_intermediate_dim
49
- self.vocos_num_layers = vocos_num_layers
50
- self.out_channels = out_channels
51
- self.sample_ratios = sample_ratios
52
-
53
- class SparkTTSDecoderConfig(PretrainedConfig):
54
- """Configuration for the BiCodec Wave Generator (Decoder)."""
55
- model_type = "spark-tts-decoder"
56
- def __init__(self, input_channel=1024, channels=1536, rates=[8, 5, 4, 2],
57
- kernel_sizes=[16, 11, 8, 4], **kwargs):
58
- super().__init__(**kwargs)
59
- self.input_channel = input_channel
60
- self.channels = channels
61
- self.rates = rates
62
- self.kernel_sizes = kernel_sizes
63
-
64
- class SparkTTSQuantizerConfig(PretrainedConfig):
65
- """Configuration for the BiCodec Factorized Vector Quantizer."""
66
- model_type = "spark-tts-quantizer"
67
- def __init__(self, input_dim=1024, codebook_size=8192, codebook_dim=8,
68
- commitment=0.25, codebook_loss_weight=2.0, decay=0.99,
69
- threshold_ema_dead_code=0.2, **kwargs):
70
- super().__init__(**kwargs)
71
- self.input_dim = input_dim
72
- self.codebook_size = codebook_size
73
- self.codebook_dim = codebook_dim
74
- self.commitment = commitment
75
- self.codebook_loss_weight = codebook_loss_weight
76
- self.decay = decay
77
- self.threshold_ema_dead_code = threshold_ema_dead_code
78
-
79
- class SparkTTSSpeakerEncoderConfig(PretrainedConfig):
80
- """Configuration for the BiCodec Speaker Encoder."""
81
- model_type = "spark-tts-speaker-encoder"
82
- def __init__(self, input_dim=128, out_dim=1024, latent_dim=128, token_num=32,
83
- fsq_levels=[4, 4, 4, 4, 4, 4], fsq_num_quantizers=1, **kwargs):
84
- super().__init__(**kwargs)
85
- self.input_dim = input_dim
86
- self.out_dim = out_dim
87
- self.latent_dim = latent_dim
88
- self.token_num = token_num
89
- self.fsq_levels = fsq_levels
90
- self.fsq_num_quantizers = fsq_num_quantizers
91
-
92
- class SparkTTSPrenetConfig(PretrainedConfig):
93
- """Configuration for the BiCodec Prenet."""
94
- model_type = "spark-tts-prenet"
95
- def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
96
- vocos_num_layers=12, out_channels=1024, condition_dim=1024,
97
- sample_ratios=[1, 1], use_tanh_at_final=False, **kwargs):
98
- super().__init__(**kwargs)
99
- self.input_channels = input_channels
100
- self.vocos_dim = vocos_dim
101
- self.vocos_intermediate_dim = vocos_intermediate_dim
102
- self.vocos_num_layers = vocos_num_layers
103
- self.out_channels = out_channels
104
- self.condition_dim = condition_dim
105
- self.sample_ratios = sample_ratios
106
- self.use_tanh_at_final = use_tanh_at_final
107
-
108
- class SparkTTSPostnetConfig(PretrainedConfig):
109
- """Configuration for the BiCodec Postnet."""
110
- model_type = "spark-tts-postnet"
111
- def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
112
- vocos_num_layers=6, out_channels=1024, use_tanh_at_final=False, **kwargs):
113
- super().__init__(**kwargs)
114
- self.input_channels = input_channels
115
- self.vocos_dim = vocos_dim
116
- self.vocos_intermediate_dim = vocos_intermediate_dim
117
- self.vocos_num_layers = vocos_num_layers
118
- self.out_channels = out_channels
119
- self.use_tanh_at_final = use_tanh_at_final
120
-
121
-
122
- # --- Define the Intermediate BiCodec Config Class ---
123
-
124
- class SparkTTSBiCodecConfig(PretrainedConfig):
125
- """
126
- Intermediate configuration class for the BiCodec component within SparkTTS.
127
- It holds instances of the individual sub-component configurations.
128
- """
129
- model_type = "spark-tts-bicodec"
130
- sub_configs = {
131
- "mel_params": SparkTTSMelParamsConfig,
132
- "encoder_config": SparkTTSEncoderConfig,
133
- "decoder_config": SparkTTSDecoderConfig,
134
- "quantizer_config": SparkTTSQuantizerConfig,
135
- "speaker_encoder_config": SparkTTSSpeakerEncoderConfig,
136
- "prenet_config": SparkTTSPrenetConfig,
137
- "postnet_config": SparkTTSPostnetConfig,
138
- }
139
-
140
- def __init__(
141
- self,
142
- mel_params=None,
143
- encoder_config=None,
144
- decoder_config=None,
145
- quantizer_config=None,
146
- speaker_encoder_config=None,
147
- prenet_config=None,
148
- postnet_config=None,
149
- **kwargs,
150
- ):
151
- super().__init__(**kwargs)
152
- self.mel_params = self._init_sub_config(mel_params, "mel_params")
153
- self.encoder_config = self._init_sub_config(encoder_config, "encoder_config")
154
- self.decoder_config = self._init_sub_config(decoder_config, "decoder_config")
155
- self.quantizer_config = self._init_sub_config(quantizer_config, "quantizer_config")
156
- self.speaker_encoder_config = self._init_sub_config(speaker_encoder_config, "speaker_encoder_config")
157
- self.prenet_config = self._init_sub_config(prenet_config, "prenet_config")
158
- self.postnet_config = self._init_sub_config(postnet_config, "postnet_config")
159
-
160
- def _init_sub_config(self, config_input, config_key):
161
- """Helper to initialize sub-configs."""
162
- config_cls = self.sub_configs[config_key]
163
- if isinstance(config_input, dict):
164
- return config_cls(**config_input)
165
- elif config_input is None:
166
- return config_cls()
167
- elif isinstance(config_input, config_cls):
168
- return config_input
169
- else:
170
- raise TypeError(f"Invalid type for {config_key}: {type(config_input)}. Expected dict, None, or {config_cls.__name__}.")
171
-
172
-
173
- # --- Define the Main SparkTTS Config Class ---
174
-
175
- class SparkTTSConfig(PretrainedConfig):
176
- r"""
177
- Main configuration class for SparkTTSModel, including nested BiCodec configuration.
178
- """
179
- model_type = "spark-tts"
180
- sub_configs = {"bicodec_config": SparkTTSBiCodecConfig}
181
- attribute_map = {"hidden_size": "d_model"}
182
-
183
- def __init__(
184
- self,
185
- llm_model_name_or_path="./LLM",
186
- bicodec_model_name_or_path="./BiCodec",
187
- wav2vec2_model_name_or_path="./wav2vec2-large-xlsr-53",
188
- sample_rate=16000,
189
- highpass_cutoff_freq=40,
190
- latent_hop_length=320,
191
- ref_segment_duration=6.0,
192
- volume_normalize=True,
193
- bicodec_config=None,
194
- torch_dtype="auto",
195
- **kwargs,
196
- ):
197
- self.llm_model_name_or_path = llm_model_name_or_path
198
- self.bicodec_model_name_or_path = bicodec_model_name_or_path
199
- self.wav2vec2_model_name_or_path = wav2vec2_model_name_or_path
200
- self.sample_rate = sample_rate
201
- self.highpass_cutoff_freq = highpass_cutoff_freq
202
- self.latent_hop_length = latent_hop_length
203
- self.ref_segment_duration = ref_segment_duration
204
- self.volume_normalize = volume_normalize
205
- self.torch_dtype = torch_dtype
206
-
207
- if isinstance(bicodec_config, dict):
208
- self.bicodec_config = self.sub_configs["bicodec_config"](**bicodec_config)
209
- elif bicodec_config is None:
210
- logger.info("`bicodec_config` not provided. Initializing with defaults.")
211
- self.bicodec_config = self.sub_configs["bicodec_config"]()
212
- elif isinstance(bicodec_config, self.sub_configs["bicodec_config"]):
213
- self.bicodec_config = bicodec_config
214
- else:
215
- raise TypeError(f"Invalid type for bicodec_config: {type(bicodec_config)}")
216
-
217
- kwargs["processor_class"] = kwargs.get("processor_class", "SparkTTSProcessor")
218
- kwargs["auto_map"] = kwargs.get("auto_map", {
219
- "AutoConfig": "configuration_spark_tts.SparkTTSConfig",
220
- "AutoModel": "modeling_spark_tts.SparkTTSModel",
221
- "AutoProcessor": "processing_spark_tts.SparkTTSProcessor"
222
- })
223
- super().__init__(**kwargs)
224
-
225
- def to_dict(self):
226
- """Serializes this instance to a Python dictionary."""
227
- output = super().to_dict()
228
- if hasattr(self, 'bicodec_config') and hasattr(self.bicodec_config, 'to_dict'):
229
- output['bicodec_config'] = self.bicodec_config.to_dict()
230
- return output