IbrahimSalah committed on
Commit 4668bad · verified · 1 Parent(s): d099eee

Delete processing_spark_tts.py

Files changed (1)
  1. processing_spark_tts.py +0 -889
processing_spark_tts.py DELETED
@@ -1,889 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2025 SparkAudio & The HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """
16
- Processor class for SparkTTS. Combines text tokenization and audio feature extraction/processing.
17
- """
18
-
19
- import os # Needed for save_pretrained
20
- import re # For decoding
21
- import torch
22
- import numpy as np
23
- import soundfile # For audio loading (the code below calls soundfile.read)
24
- import soxr # For resampling
25
-
26
- from pathlib import Path
27
- from typing import Optional, Union, List, Dict, Tuple, Any
28
-
29
- from transformers.processing_utils import ProcessorMixin
30
- from transformers.tokenization_utils_base import BatchEncoding # Return type hint
31
- from transformers.feature_extraction_utils import BatchFeature # Return type hint
32
- from transformers.models.auto.tokenization_auto import AutoTokenizer
33
- from transformers.models.wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
34
- from transformers.utils import logging, PushToHubMixin # Added PushToHubMixin
35
- from numpy.lib.stride_tricks import sliding_window_view
38
- import random
39
-
40
- # Import custom config if needed for defaults
41
- from .configuration_spark_tts import SparkTTSConfig
42
-
43
- logger = logging.get_logger(__name__)
44
-
45
-
46
- # =============================================================================
47
- # >> START: PASTE CODE FROM sparktts/utils/* HERE <<
48
- # =============================================================================
49
- # IMPORTANT: Utility functions needed for processing (audio loading, token parsing)
50
- # must be defined or imported here.
51
-
52
- # --- Paste sparktts/utils/audio.py content here ---
53
-
54
- def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray:
55
- """
56
- Normalize the volume of an audio signal.
57
-
58
- Parameters:
59
- audio (numpy array): Input audio signal array.
60
- coeff (float): Target coefficient for normalization, default is 0.2.
61
-
62
- Returns:
63
- numpy array: The volume-normalized audio signal.
64
- """
65
- # Sort the absolute values of the audio signal
66
- temp = np.sort(np.abs(audio))
67
-
68
- # If the maximum value is less than 0.1, scale the array to have a maximum of 0.1
69
- if temp[-1] < 0.1:
70
- scaling_factor = max(
71
- temp[-1], 1e-3
72
- ) # Prevent division by zero with a small constant
73
- audio = audio / scaling_factor * 0.1
74
-
75
- # Filter out values less than 0.01 from temp
76
- temp = temp[temp > 0.01]
77
- L = temp.shape[0] # Length of the filtered array
78
-
79
- # If there are fewer than or equal to 10 significant values, return the audio without further processing
80
- if L <= 10:
81
- return audio
82
-
83
- # Compute the average of the values between the 90th and 99th percentiles of temp
84
- volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)])
85
-
86
- # Normalize the audio to the target coefficient level, clamping the scale factor between 0.1 and 10
87
- audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10)
88
-
89
- # Ensure the maximum absolute value in the audio does not exceed 1
90
- max_value = np.max(np.abs(audio))
91
- if max_value > 1:
92
- audio = audio / max_value
93
-
94
- return audio
95
-
96
-
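As a quick illustration of the normalization above (the signal is synthetic; the seed, length, and amplitude are arbitrary assumptions):

# Hypothetical check: a quiet random signal gets boosted toward the 0.2 target level.
import numpy as np

rng = np.random.default_rng(0)
quiet = 0.05 * rng.standard_normal(16000).astype(np.float32)
normalized = audio_volume_normalize(quiet, coeff=0.2)
print(np.abs(quiet).max(), np.abs(normalized).max())  # peak grows; never exceeds 1.0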
97
- def load_audio(
98
- adfile: Path,
99
- sampling_rate: int = None,
100
- length: int = None,
101
- volume_normalize: bool = False,
102
- segment_duration: int = None,
103
- ) -> np.ndarray:
104
- r"""Load audio file with target sampling rate and lsength
105
-
106
- Args:
107
- adfile (Path): path to audio file.
108
- sampling_rate (int, optional): target sampling rate. Defaults to None.
109
- length (int, optional): target audio length. Defaults to None.
110
- volume_normalize (bool, optional): whether to perform volume normalization. Defaults to False.
111
- segment_duration (int, optional): randomly select a segment with a duration of {segment_duration} seconds.
112
- Defaults to None, which means the whole audio will be used.
113
-
114
- Returns:
115
- audio (np.ndarray): audio
116
- """
117
-
118
- audio, sr = soundfile.read(adfile)
119
- if len(audio.shape) > 1:
120
- audio = audio[:, 0]
121
-
122
- if sampling_rate is not None and sr != sampling_rate:
123
- audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ")
124
- sr = sampling_rate
125
-
126
- if segment_duration is not None:
127
- seg_length = int(sr * segment_duration)
128
- audio = random_select_audio_segment(audio, seg_length)
129
-
130
- # Audio volume normalize
131
- if volume_normalize:
132
- audio = audio_volume_normalize(audio)
133
- # check the audio length
134
- if length is not None:
135
- assert abs(audio.shape[0] - length) < 1000
136
- if audio.shape[0] > length:
137
- audio = audio[:length]
138
- else:
139
- audio = np.pad(audio, (0, int(length - audio.shape[0])))
140
- return audio
141
-
142
-
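A minimal sketch of how this loader might be called when preparing a cloning prompt; the file name and the 16 kHz target rate are assumptions (the processor takes the real rate from its config):

# Hypothetical usage: resample to 16 kHz, mono, with volume normalization.
prompt_wav = load_audio(
    "speaker_prompt.wav",   # assumed local file
    sampling_rate=16000,    # assumed target rate
    volume_normalize=True,
)
print(prompt_wav.shape)     # 1-D numpy array of samples at 16 kHz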
143
- def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray:
144
- """get an audio segment given the length
145
-
146
- Args:
147
- audio (np.ndarray): input audio signal.
148
- length (int): audio length = sampling_rate * duration
149
- """
150
- if audio.shape[0] < length:
151
- audio = np.pad(audio, (0, int(length - audio.shape[0])))
152
- start_index = random.randint(0, audio.shape[0] - length)
153
- end_index = int(start_index + length)
154
-
155
- return audio[start_index:end_index]
156
-
157
- def get_ref_clip(wav: np.ndarray, config) -> np.ndarray: # Needs access to config attributes
158
- """Get reference audio clip for speaker embedding."""
159
- # Make sure config has sample_rate, ref_segment_duration, latent_hop_length
160
- if not all(hasattr(config, attr) for attr in ['sample_rate', 'ref_segment_duration', 'latent_hop_length']):
161
- raise AttributeError("Config object missing required attributes for get_ref_clip")
162
- ref_segment_length = (
163
- int(config.sample_rate * config.ref_segment_duration)
164
- // config.latent_hop_length
165
- * config.latent_hop_length
166
- )
167
- wav_length = len(wav)
168
- if ref_segment_length > wav_length:
169
- wav = np.tile(wav, ref_segment_length // wav_length + 1)
170
- return wav[:ref_segment_length]
171
-
172
-
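A worked example of the clip-length arithmetic above, with assumed config values (16 kHz sample rate, 6 s reference duration, latent hop length 320); the real values come from the model config:

sample_rate = 16000        # assumed
ref_segment_duration = 6   # seconds, assumed
latent_hop_length = 320    # assumed
# Round down to a whole number of latent hops.
ref_segment_length = (
    int(sample_rate * ref_segment_duration)
    // latent_hop_length
    * latent_hop_length
)
print(ref_segment_length)  # 96000 samples, i.e. exactly 300 hops of 320
# Shorter reference audio is tiled (np.tile) to cover this length, then cropped.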
173
- # --- Paste sparktts/utils/token_parser.py content here ---
174
-
175
- TASK_TOKEN_MAP = {
176
- "vc": "<|task_vc|>",
177
- "tts": "<|task_tts|>",
178
- "asr": "<|task_asr|>",
179
- "s2s": "<|task_s2s|>",
180
- "t2s": "<|task_t2s|>",
181
- "understand": "<|task_understand|>",
182
- "caption": "<|task_cap|>",
183
- "controllable_tts": "<|task_controllable_tts|>",
184
- "prompt_tts": "<|task_prompt_tts|>",
185
- "speech_edit": "<|task_edit|>",
186
- }
187
-
188
- LEVELS_MAP = {
189
- "very_low": 0,
190
- "low": 1,
191
- "moderate": 2,
192
- "high": 3,
193
- "very_high": 4,
194
- }
195
-
196
- LEVELS_MAP_UI = {
197
- 1: 'very_low',
198
- 2: 'low',
199
- 3: 'moderate',
200
- 4: 'high',
201
- 5: 'very_high'
202
- }
203
-
204
- GENDER_MAP = {
205
- "female": 0,
206
- "male": 1,
207
- }
208
-
209
- AGE_MAP = {"Child": 0, "Teenager": 1, "Youth-Adult": 2, "Middle-aged": 3, "Elderly": 4}
210
-
211
- EMO_MAP = {
212
- "UNKNOWN": 0,
213
- "NEUTRAL": 1,
214
- "ANGRY": 2,
215
- "HAPPY": 3,
216
- "SAD": 4,
217
- "FEARFUL": 5,
218
- "DISGUSTED": 6,
219
- "SURPRISED": 7,
220
- "SARCASTIC": 8,
221
- "EXCITED": 9,
222
- "SLEEPY": 10,
223
- "CONFUSED": 11,
224
- "EMPHASIS": 12,
225
- "LAUGHING": 13,
226
- "SINGING": 14,
227
- "WORRIED": 15,
228
- "WHISPER": 16,
229
- "ANXIOUS": 17,
230
- "NO-AGREEMENT": 18,
231
- "APOLOGETIC": 19,
232
- "CONCERNED": 20,
233
- "ENUNCIATED": 21,
234
- "ASSERTIVE": 22,
235
- "ENCOURAGING": 23,
236
- "CONTEMPT": 24,
237
- }
238
-
239
-
240
- class TokenParser:
241
- """Turn label to special token"""
242
-
243
- def __init__(self):
244
- pass
245
-
246
- """Parse the attributes of a person."""
247
-
248
- def __init__(self):
249
- pass
250
-
251
- @staticmethod
252
- def age(age: str) -> str:
253
- """Turn age token."""
254
- age_id = AGE_MAP[age]
255
- return f"<|age_{age_id}|>"
256
-
257
- @staticmethod
258
- def gender(gender: str) -> str:
259
- """Turn gender token."""
260
- gender_id = GENDER_MAP[gender]
261
- return f"<|gender_{gender_id}|>"
262
-
263
- @staticmethod
264
- def mel_value(mel: int):
265
- """Turn special token of mel scale pitch."""
266
- mel = max(0, int(mel))
267
- mel = min(1000, int(mel))
268
- return f"<|pitch_value_{mel}|>"
269
-
270
- @staticmethod
271
- def mel_level(level: str):
272
- """Turn special token of mel level."""
273
- level_tag = LEVELS_MAP[level]
274
- return f"<|pitch_label_{level_tag}|>"
275
-
276
- @staticmethod
277
- def pitch_var_value(pitch_std: int):
278
- """Turn special token of pitch_std value."""
279
- assert isinstance(pitch_std, int)
280
- pitch_std = max(0, int(pitch_std))
281
- pitch_std = min(10, int(pitch_std))
282
- return f"<|pitch_var_value_{pitch_std}|>"
283
-
284
- @staticmethod
285
- def pitch_var_level(level: str):
286
- """Turn special token of pitch std level."""
287
- level_tag = LEVELS_MAP[level]
288
- return f"<|pitch_var_label_{level_tag}|>"
289
-
290
- @staticmethod
291
- def loudness_value(loudness: int):
292
- """Turn special toak of loudness value [0, 30]"""
293
- assert loudness >= 0
294
- loudness = max(0, int(loudness))
295
- loudness = min(30, int(loudness))
296
- return f"<|loudness_value_{loudness}|>"
297
-
298
- @staticmethod
299
- def loudness_level(level: str):
300
- """Turn special token of loudness level."""
301
- level_tag = LEVELS_MAP[level]
302
- return f"<|loudness_label_{level_tag}|>"
303
-
304
- @staticmethod
305
- def speed_value(speed: int):
306
- """Turn special token of speed value."""
307
- speed = max(0, int(speed))
308
- speed = min(10, int(speed))
309
- return f"<|speed_value_{speed}|>"
310
-
311
- @staticmethod
312
- def speed_level(level: str):
313
- """Turn special token of speed level."""
314
- level_tag = LEVELS_MAP[level]
315
- return f"<|speed_label_{level_tag}|>"
316
-
317
- @staticmethod
318
- def task(task: str) -> str:
319
- """Turn special token of task."""
320
- assert task in TASK_TOKEN_MAP.keys()
321
-
322
- return TASK_TOKEN_MAP[task]
323
-
324
- @staticmethod
325
- def emotion(emotion: str):
326
- emo_id = EMO_MAP[emotion]
327
-
328
- return f"<|emotion_{emo_id}|>"
329
-
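A few illustrative calls showing how these labels map onto special-token strings (the outputs follow directly from the maps defined above):

print(TokenParser.task("tts"))         # <|task_tts|>
print(TokenParser.gender("female"))    # <|gender_0|>
print(TokenParser.mel_level("high"))   # <|pitch_label_3|>
print(TokenParser.speed_level("low"))  # <|speed_label_1|>
print(TokenParser.emotion("HAPPY"))    # <|emotion_3|>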
330
- # =============================================================================
331
- # >> END: PASTE CODE FROM sparktts/utils/* HERE <<
332
- # =============================================================================
333
-
334
-
335
- class SparkTTSProcessor(ProcessorMixin, PushToHubMixin): # Added PushToHubMixin
336
- r"""
337
- Constructs a SparkTTS processor which wraps a text tokenizer and relevant audio processing logic.
338
-
339
- Args:
340
- tokenizer ([`PreTrainedTokenizer`]):
341
- An instance of [`PreTrainedTokenizer`]. This handles the text tokenization for the LLM.
342
- feature_extractor ([`Wav2Vec2FeatureExtractor`]):
343
- An instance of [`Wav2Vec2FeatureExtractor`]. Although Wav2Vec2 features are extracted
344
- within the model's `tokenize_audio`, the extractor's configuration (like sampling rate)
345
- is useful, and it aligns with the ProcessorMixin pattern.
346
- config ([`SparkTTSConfig`], *optional*):
347
- An instance of [`SparkTTSConfig`] to access configuration parameters like sample rate.
348
- """
349
- attributes = ["tokenizer", "feature_extractor"]
350
- tokenizer_class = "AutoTokenizer"
351
- feature_extractor_class = "Wav2Vec2FeatureExtractor" # Keep for consistency
352
-
353
- def __init__(self, tokenizer, feature_extractor, config: Optional[SparkTTSConfig] = None, **kwargs):
354
- super().__init__(tokenizer=tokenizer, feature_extractor=feature_extractor, **kwargs)
355
- self.model = None
356
- self.config = config
357
- # Set sampling rate
358
- if config and hasattr(config, 'sample_rate'):
359
- self.sampling_rate = config.sample_rate
360
- elif feature_extractor and hasattr(feature_extractor, 'sampling_rate'):
361
- self.sampling_rate = feature_extractor.sampling_rate
362
- else:
363
- self.sampling_rate = 16000
364
- logger.warning(f"Could not determine sampling rate. Defaulting to {self.sampling_rate} Hz.")
365
-
366
- # # Ensure tokenizer pad token
367
- # if self.tokenizer.pad_token is None:
368
- # if self.tokenizer.eos_token is not None:
369
- # logger.warning("Tokenizer does not have a pad token. Setting pad_token to eos_token.")
370
- # self.tokenizer.pad_token = self.tokenizer.eos_token
371
- # else:
372
- # logger.warning("Tokenizer lacks pad and eos token. Adding default pad token '<|pad|>'.")
373
- # self.tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
374
-
375
- def link_model(self, model):
376
- """Links the processor to a SparkTTSModel instance for audio processing calls."""
377
- if not hasattr(model, 'tokenize_audio') or not hasattr(model, 'detokenize_audio'):
378
- raise TypeError("The provided model instance does not have the required 'tokenize_audio' and 'detokenize_audio' methods.")
379
- if not hasattr(model, 'config'):
380
- logger.warning("Linked model does not have a 'config' attribute. Some processor functionalities might rely on it.")
381
-
382
- self.model = model
383
- logger.info("SparkTTSModel successfully linked to the processor.")
384
- # Update sampling rate based on linked model's config if available
385
- if hasattr(model, 'config') and hasattr(model.config, 'sample_rate'):
386
- if self.sampling_rate != model.config.sample_rate:
387
- logger.info(f"Updating processor sampling rate from {self.sampling_rate} to {model.config.sample_rate} based on linked model config.")
388
- self.sampling_rate = model.config.sample_rate
389
- # Also update feature extractor sampling rate if it differs
390
- if hasattr(self, 'feature_extractor') and self.feature_extractor.sampling_rate != model.config.sample_rate:
391
- logger.info(f"Updating feature_extractor sampling rate from {self.feature_extractor.sampling_rate} to {model.config.sample_rate}.")
392
- self.feature_extractor.sampling_rate = model.config.sample_rate
393
-
394
-
395
- def __call__(
396
- self,
397
- text: str,
398
- prompt_speech_path: Optional[Union[str, Path]] = None,
399
- prompt_text: Optional[str] = None,
400
- gender: Optional[str] = None,
401
- pitch: Optional[str] = None,
402
- speed: Optional[str] = None,
403
- return_tensors: Optional[str] = "pt",
404
- **kwargs, # Allow passing other args like padding, truncation to tokenizer
405
- ) -> BatchEncoding:
406
- """
407
- Processes the input text and optional prompt audio/control parameters into a format suitable for [`SparkTTSModel`].
408
-
409
- Args:
410
- text (`str`):
411
- The main text to be synthesized.
412
- prompt_speech_path (`str` or `Path`, *optional*):
413
- Path to the prompt audio file for voice cloning. Required if `gender` is not set.
414
- prompt_text (`str`, *optional*):
415
- Transcript of the prompt audio. Used only in voice cloning mode.
416
- gender (`str`, *optional*):
417
- Target gender ("male" or "female") for controllable synthesis. If set, enables control mode.
418
- pitch (`str`, *optional*):
419
- Target pitch level ("very_low", "low", "moderate", "high", "very_high") for control mode. Required if `gender` is set.
420
- speed (`str`, *optional*):
421
- Target speed level ("very_low", "low", "moderate", "high", "very_high") for control mode. Required if `gender` is set.
422
- return_tensors (`str`, *optional*, defaults to `"pt"`):
423
- If set, will return tensors instead of list of python integers. Only "pt" (PyTorch) is supported currently.
424
- **kwargs:
425
- Additional arguments passed to the underlying tokenizer's `__call__` method.
426
-
427
- Returns:
428
- [`BatchEncoding`]: A dictionary containing the `input_ids` and `attention_mask` for the LLM.
429
- In voice cloning mode, it also includes `global_token_ids_prompt` (torch.Tensor) representing the
430
- global tokens extracted from the prompt audio.
431
- """
432
-
433
- global_token_ids_prompt = None # Initialize
434
-
435
- # Determine mode: Control TTS or Voice Cloning (Prompt TTS)
436
- is_control_mode = gender is not None
437
- is_cloning_mode = prompt_speech_path is not None and not is_control_mode
438
-
439
- if is_control_mode:
440
- # --- Controllable TTS Prompt Construction ---
441
- if not all([pitch, speed]):
442
- raise ValueError("For controllable TTS, 'gender', 'pitch', and 'speed' must all be provided.")
443
- if prompt_speech_path is not None:
444
- logger.warning("`prompt_speech_path` provided but ignored because `gender` is set (controllable TTS mode).")
445
-
446
- if gender not in GENDER_MAP: # Basic check
447
- raise ValueError(f"Invalid gender provided: {gender}. Must be one of {list(GENDER_MAP.keys())}")
448
- if not all(k in LEVELS_MAP for k in [pitch, speed]): # Basic check
449
- raise ValueError(f"Invalid pitch or speed level provided. Must be one of {list(LEVELS_MAP.keys())}")
450
-
451
- gender_id = GENDER_MAP[gender]
452
- pitch_level_id = LEVELS_MAP[pitch]
453
- speed_level_id = LEVELS_MAP[speed]
454
-
455
- pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>"
456
- speed_label_tokens = f"<|speed_label_{speed_level_id}|>"
457
- gender_tokens = f"<|gender_{gender_id}|>"
458
-
459
- attribute_tokens = "".join([gender_tokens, pitch_label_tokens, speed_label_tokens])
460
-
461
- prompt_list = [
462
- TASK_TOKEN_MAP["controllable_tts"],
463
- "<|start_content|>",
464
- text,
465
- "<|end_content|>",
466
- "<|start_style_label|>",
467
- attribute_tokens,
468
- "<|end_style_label|>",
469
- ]
470
- prompt_string = "".join(prompt_list)
471
-
472
- elif is_cloning_mode:
473
- # --- Voice Cloning Prompt Construction ---
474
- if self.model is None:
475
- raise RuntimeError("Processor must be linked to a SparkTTSModel instance via `processor.link_model(model)` before performing voice cloning.")
476
- prompt_speech_path = Path(prompt_speech_path) # Ensure it's a Path object
477
- if not prompt_speech_path.exists():
478
- raise FileNotFoundError(f"Prompt audio file not found: {prompt_speech_path}")
479
-
480
- # Load and process prompt audio
481
- try:
482
- model_config = self.model.config if self.model and hasattr(self.model, 'config') else self.config
483
- if model_config is None:
484
- raise ValueError("Configuration not available in processor or linked model.")
485
-
486
- # Load main wav
487
- wav = load_audio(
488
- prompt_speech_path,
489
- sampling_rate=self.sampling_rate,
490
- volume_normalize=getattr(model_config, 'volume_normalize', True), # Use getattr for safety
491
- )
492
- # Get reference clip
493
- wav_ref_np = get_ref_clip(wav, model_config) # Pass config object
494
- wav_ref = torch.from_numpy(wav_ref_np).unsqueeze(0).float()
495
- wav_tensor = torch.from_numpy(wav).unsqueeze(0).float()
496
-
497
- # Tokenize using the linked model's method
498
- # Assuming tokenize_audio returns tensors with batch dim 1: [1, N_global], [1, N_semantic]
499
- global_tokens_tensor, semantic_tokens_tensor = self.model.tokenize_audio(wav_tensor, wav_ref)
500
-
501
- # Store the global tokens tensor (with batch dim) for the output dict
502
- global_token_ids_prompt = global_tokens_tensor # Keep batch dim [1, N_global]
503
-
504
- # Convert tensors to lists of ints for string formatting
505
- global_token_list = global_tokens_tensor.squeeze().tolist() # Remove batch dim -> list
506
- semantic_token_list = semantic_tokens_tensor.squeeze().tolist() # Remove batch dim -> list
507
-
508
- except Exception as e:
509
- logger.error(f"Error processing prompt audio {prompt_speech_path}: {e}")
510
- import traceback
511
- traceback.print_exc()
512
- raise
513
-
514
- # ==============================================================
515
- # CORRECTED TOKEN STRING FORMATTING
516
- # ==============================================================
517
- # Create individual token strings for each ID
518
- global_tokens_str = "".join([f"<|bicodec_global_{gid}|>" for gid in global_token_list])
519
- semantic_tokens_str = "".join([f"<|bicodec_semantic_{sid}|>" for sid in semantic_token_list])
520
- # ==============================================================
521
-
522
- # Construct prompt list based on presence of prompt_text
523
- if prompt_text is not None and prompt_text.strip(): # Check if prompt_text is meaningful
524
- logger.info("Using prompt text in voice cloning prompt.")
525
- prompt_list = [
526
- TASK_TOKEN_MAP["tts"], # Or maybe TASK_TOKEN_MAP["prompt_tts"]? Check original logic. Assuming "tts".
527
- "<|start_content|>",
528
- prompt_text, # Transcript first
529
- text, # Then target text
530
- "<|end_content|>",
531
- "<|start_global_token|>",
532
- global_tokens_str,
533
- "<|end_global_token|>",
534
- "<|start_semantic_token|>",
535
- semantic_tokens_str,
536
- # "<|end_semantic_token|>", # Original code didn't have this marker here
537
- ]
538
- else:
539
- # Simpler prompt without semantic tokens if no transcript provided
540
- logger.info("No prompt text provided, using text-only voice cloning prompt.")
541
- prompt_list = [
542
- TASK_TOKEN_MAP["tts"], # Or maybe TASK_TOKEN_MAP["prompt_tts"]?
543
- "<|start_content|>",
544
- text, # Only target text
545
- "<|end_content|>",
546
- "<|start_global_token|>",
547
- global_tokens_str,
548
- "<|end_global_token|>",
549
- ]
550
- prompt_string = "".join(prompt_list)
551
- logger.debug(f"Generated prompt string (cloning): {prompt_string[:200]}...") # Log start of prompt
552
-
553
- else:
554
- raise ValueError("Invalid input combination. Either provide `prompt_speech_path` for cloning or (`gender`, `pitch`, `speed`) for control.")
555
-
556
- # --- Tokenize the final prompt string ---
557
- # print(f"Tokenizing prompt: {prompt_string}")
558
- inputs = self.tokenizer(
559
- prompt_string,
560
- return_tensors=return_tensors,
561
- padding=kwargs.get("padding", False), # Often False for generation prompts unless batching > 1
562
- truncation=kwargs.get("truncation", True),
563
- max_length=kwargs.get("max_length", self.tokenizer.model_max_length),
564
- add_special_tokens=kwargs.get("add_special_tokens", True), # Usually True unless handled manually
565
- return_attention_mask=kwargs.get("return_attention_mask", True), # Need attention mask
566
- **{k: v for k, v in kwargs.items() if k not in ["padding", "truncation", "max_length", "add_special_tokens", "return_attention_mask"]}
567
- )
568
- logger.debug(f"Tokenized input_ids shape: {inputs['input_ids'].shape}")
569
-
570
-
571
- # Add the prompt's global tokens (as tensor with batch dim) to the output if in cloning mode
572
- if is_cloning_mode and global_token_ids_prompt is not None:
573
- if return_tensors == "pt":
574
- inputs["global_token_ids_prompt"] = global_token_ids_prompt # Already has batch dim [1, N_global]
575
- else:
576
- # Handle non-tensor return if necessary
577
- inputs["global_token_ids_prompt"] = global_token_ids_prompt.tolist()
578
-
579
- return inputs
580
-
581
-
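A rough usage sketch for this method; it assumes `processor` and `model` are already loaded (see `from_pretrained` below) and linked via `processor.link_model(model)`, and the text and file names are placeholders:

# Voice cloning mode: target text plus a reference recording (optionally with its transcript).
inputs = processor(
    text="Hello, this is a synthesized sentence.",
    prompt_speech_path="speaker_prompt.wav",      # assumed path
    prompt_text="Transcript of the prompt clip",  # optional
    return_tensors="pt",
)

# Controllable mode: no prompt audio; gender/pitch/speed labels instead.
inputs_ctrl = processor(
    text="Hello, this is a synthesized sentence.",
    gender="female",
    pitch="moderate",
    speed="moderate",
    return_tensors="pt",
)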
582
- def decode(
583
- self,
584
- generated_ids: torch.Tensor,
585
- global_token_ids_prompt: Optional[torch.Tensor] = None,
586
- input_ids_len: Optional[int] = None,
587
- skip_special_tokens: bool = True,
588
- ) -> Dict[str, Any]:
589
- """
590
- Decodes the generated token IDs from [`SparkTTSModel`] into an audio waveform.
591
-
592
- Args:
593
- generated_ids (`torch.Tensor`):
594
- Tensor of token IDs generated by `model.generate()`, including the input prompt part. Shape [B, seq_len].
595
- global_token_ids_prompt (`torch.Tensor`, *optional*):
596
- The global tokens extracted from the prompt audio during the `__call__` step (for voice cloning).
597
- Shape [B, N_global]. Required if the generation was for voice cloning.
598
- input_ids_len (`int`, *optional*):
599
- The length of the original input prompt `input_ids` fed to `model.generate()`. Required to
600
- correctly isolate the newly generated tokens.
601
- skip_special_tokens (`bool`, *optional*, defaults to `True`):
602
- Whether to skip special tokens during the text decoding step (used to extract audio tokens).
603
-
604
- Returns:
605
- Dict[str, Any]: A dictionary containing:
606
- - "audio": The decoded audio waveform as a NumPy array. Shape [T_audio] (if B=1) or [B, T_audio].
607
- - "sampling_rate": The sampling rate of the audio.
608
- """
609
- if self.model is None:
610
- raise RuntimeError("Processor must be linked to a SparkTTSModel instance via `processor.link_model(model)` before decoding.")
611
- if input_ids_len is None:
612
- raise ValueError("`input_ids_len` (length of the prompt input_ids) must be provided for decoding.")
613
-
614
- # --- Isolate generated part and decode text ---
615
- # Assumes generated_ids has shape [B, full_seq_len]
616
- # Handle case where generated sequence is shorter than prompt (shouldn't happen with max_new_tokens > 0)
617
- if generated_ids.shape[1] < input_ids_len:
618
- logger.warning(f"Generated sequence length ({generated_ids.shape[1]}) is shorter than input prompt length ({input_ids_len}). Decoding might be incorrect.")
619
- output_only_ids = generated_ids[:, input_ids_len:] # Empty if no new tokens were generated
622
-
623
-
624
- # Decode the generated part to find audio tokens
625
- # Need to handle batch decoding if B > 1
626
- # print("decode token", self.tokenizer.batch_decode(output_only_ids, skip_special_tokens=False))
627
- decoded_texts = self.tokenizer.batch_decode(output_only_ids, skip_special_tokens=skip_special_tokens)
628
-
629
- # --- Extract Audio Tokens ---
630
- # Handle batch processing correctly
631
- batch_size = generated_ids.shape[0]
632
- all_semantic_ids = []
633
- all_global_tokens = []
634
- successful_indices = [] # Keep track of which batch items were successful
635
-
636
- for i in range(batch_size):
637
- decoded_text = decoded_texts[i]
638
- current_semantic_ids = None
639
- current_global_tokens = None
640
-
641
- # Extract semantic tokens
642
- try:
643
- pred_semantic_indices = [int(token) for token in re.findall(r"bicodec_semantic_(\d+)", decoded_text)]
644
- if not pred_semantic_indices:
645
- logger.warning(f"Batch item {i}: No semantic tokens found in decoded text: '{decoded_text[:200]}...'")
646
- continue # Skip this item
647
-
648
- current_semantic_ids = torch.tensor(pred_semantic_indices).long() # Shape [N_semantic]
649
- except Exception as e:
650
- logger.error(f"Batch item {i}: Error parsing semantic tokens from: '{decoded_text[:200]}...'. Error: {e}")
651
- continue # Skip this item
652
-
653
- # Determine global tokens
654
- if global_token_ids_prompt is not None:
655
- # Cloning mode: Use the provided prompt global tokens for this batch item
656
- if global_token_ids_prompt.shape[0] != batch_size:
657
- raise ValueError(f"Batch size mismatch: generated_ids has {batch_size}, but global_token_ids_prompt has {global_token_ids_prompt.shape[0]}.")
658
- current_global_tokens = global_token_ids_prompt[i] # Shape [N_global]
659
- else:
660
- # Control mode: Extract global tokens from the generated text
661
- try:
662
- pred_global_indices = [int(token) for token in re.findall(r"bicodec_global_(\d+)", decoded_text)]
663
- if not pred_global_indices:
664
- logger.warning(f"Batch item {i}: No global tokens found in decoded text for control mode: '{decoded_text[:200]}...'")
665
- continue # Skip this item
666
-
667
- current_global_tokens = torch.tensor(pred_global_indices).long() # Shape [N_global]
668
-
669
- except Exception as e:
670
- logger.error(f"Batch item {i}: Error parsing global tokens from: '{decoded_text[:200]}...'. Error: {e}")
671
- continue # Skip this item
672
-
673
- # If both tokens extracted successfully
674
- all_semantic_ids.append(current_semantic_ids)
675
- all_global_tokens.append(current_global_tokens)
676
- successful_indices.append(i)
677
-
678
- if not successful_indices:
679
- logger.error("Failed to extract audio tokens for any item in the batch.")
680
- return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}
681
-
682
- # Pad sequences to the max length within the successful batch items for batch detokenization
683
- # Note: BiCodec might not support batching if sequences have different lengths. Check its implementation.
684
- # Assuming BiCodec *can* handle batches if padded (or if lengths are naturally equal).
685
- # This padding might be unnecessary if BiCodec handles variable lengths or if B=1 anyway.
686
- # For now, let's assume B=1 was handled correctly and skip complex padding.
687
- if batch_size > 1 and len(successful_indices) < batch_size:
688
- logger.warning(f"Only successfully decoded {len(successful_indices)} out of {batch_size} batch items.")
689
- # Further processing might need to handle only the successful items.
690
-
691
- # Let's proceed assuming B=1 or BiCodec handles batches appropriately.
692
- # Stack the successful tokens.
693
- try:
694
- # Need to ensure tensors have the same length before stacking if BiCodec requires it.
695
- # If BiCodec handles variable length, stacking might not be needed, just loop and call detokenize.
696
- # Let's assume B=1 for simplicity of the example, matching original code's likely behavior.
697
- if len(successful_indices) != 1:
698
- raise NotImplementedError("Batch decoding (B > 1) requires verification of BiCodec's batch handling and potentially padding.")
699
-
700
- final_semantic_ids = all_semantic_ids[0].unsqueeze(0) # Add batch dim [1, N_semantic]
701
- final_global_tokens = all_global_tokens[0].unsqueeze(0) # Add batch dim [1, N_global]
702
-
703
- except IndexError: # Should not happen if successful_indices is not empty
704
- logger.error("Internal error during token batch preparation.")
705
- return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}
706
-
707
-
708
- # --- Detokenize Audio ---
709
- try:
710
- # Call the linked model's detokenize method
711
- # print(f"DEBUG: Detokenizing audio with global tokens {final_global_tokens.shape}, semantic tokens {final_semantic_ids.shape}")
712
- output_wav = self.model.detokenize_audio(final_global_tokens, final_semantic_ids)
713
- # detokenize_audio now returns numpy array float32 in [-1, 1]
714
-
715
- # Optional: Double-check dtype here if needed, but should be handled by detokenize_audio now
716
- # if output_wav.dtype != np.float32:
717
- # logger.warning(f"Audio dtype after detokenize is {output_wav.dtype}. Converting to float32.")
718
- # output_wav = output_wav.astype(np.float32)
719
- # output_wav = np.clip(output_wav, -1.0, 1.0) # Clipping done in detokenize_audio
720
-
721
- except Exception as e:
722
- logger.error(f"Error during audio detokenization: {e}")
723
- import traceback
724
- traceback.print_exc()
725
- raise RuntimeError("Audio detokenization failed.") from e
726
-
727
- return {"audio": output_wav, "sampling_rate": self.sampling_rate}
728
-
729
-
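And a matching sketch of the generate-then-decode round trip; the generation settings are assumptions, and `inputs` is the dictionary produced by the call sketched after `__call__` above:

import soundfile

input_ids_len = inputs["input_ids"].shape[1]
generated_ids = model.generate(            # assumes the model exposes a standard generate()
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=2048,                   # assumed token budget
    do_sample=True,
)
result = processor.decode(
    generated_ids=generated_ids,
    global_token_ids_prompt=inputs.get("global_token_ids_prompt"),  # present only in cloning mode
    input_ids_len=input_ids_len,
)
soundfile.write("output.wav", result["audio"], result["sampling_rate"])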
730
- @classmethod
731
- def from_pretrained(
732
- cls,
733
- pretrained_model_name_or_path: Union[str, os.PathLike],
734
- cache_dir: Optional[Union[str, os.PathLike]] = None,
735
- force_download: bool = False,
736
- local_files_only: bool = False,
737
- token: Optional[Union[str, bool]] = None,
738
- revision: str = "main",
739
- trust_remote_code: bool = False, # Allow passing this, needed for config potentially
740
- **kwargs,
741
- ):
742
- r"""
743
- Instantiate a SparkTTSProcessor from pretrained components.
744
- """
745
- # Pop specific kwargs for this method
746
- config = kwargs.pop("config", None) # Allow passing config explicitly
747
-
748
- # --- 1. Load Config (to find component paths) ---
749
- # We need the config even if the processor doesn't store it permanently,
750
- # just to find where the tokenizer/feature_extractor live.
751
- loaded_config = None
752
- if not isinstance(config, SparkTTSConfig):
753
- try:
754
- # Load the specific config class
755
- loaded_config = SparkTTSConfig.from_pretrained(
756
- pretrained_model_name_or_path,
757
- cache_dir=cache_dir,
758
- force_download=force_download,
759
- local_files_only=local_files_only,
760
- token=token,
761
- revision=revision,
762
- trust_remote_code=trust_remote_code, # Config might be custom
763
- **kwargs, # Pass relevant kwargs
764
- )
765
- except Exception as e:
766
- logger.warning(
767
- f"Could not load SparkTTSConfig from {pretrained_model_name_or_path}. "
768
- f"Attempting to load components from default relative paths ('LLM', 'wav2vec2-large-xlsr-53'). Error: {e}"
769
- )
770
- loaded_config = None # Fallback
771
- else:
772
- # Config object was passed directly
773
- loaded_config = config
774
-
775
-
776
- # --- 2. Determine Component Paths ---
777
- llm_tokenizer_path_or_id = "./LLM" # Default relative path
778
- w2v_processor_path_or_id = "./wav2vec2-large-xlsr-53" # Default relative path
779
-
780
- if loaded_config:
781
- llm_tokenizer_path_or_id = getattr(loaded_config, 'llm_model_name_or_path', llm_tokenizer_path_or_id)
782
- w2v_processor_path_or_id = getattr(loaded_config, 'wav2vec2_model_name_or_path', w2v_processor_path_or_id)
783
-
784
- # The component `from_pretrained` methods handle resolving these paths/IDs
785
- # whether they are relative subfolders of `pretrained_model_name_or_path`
786
- # or separate Hub IDs.
787
-
788
- # --- 3. Load Components ---
789
- # Pass down relevant kwargs for loading components
790
- component_loading_kwargs = {
791
- "cache_dir": cache_dir,
792
- "force_download": force_download,
793
- "local_files_only": local_files_only,
794
- "token": token,
795
- "revision": revision,
796
- **kwargs # Pass other user kwargs
797
- }
798
- try:
799
- # Tokenizer might require trust_remote_code if its class is custom
800
- tokenizer = AutoTokenizer.from_pretrained(
801
- pretrained_model_name_or_path, # Main path
802
- subfolder=llm_tokenizer_path_or_id.lstrip('./'), # Specify subfolder relative to main path
803
- trust_remote_code=trust_remote_code,
804
- **component_loading_kwargs
805
- )
806
- except Exception as e:
807
- # Fallback: try loading directly using the path/id from config if different
808
- if llm_tokenizer_path_or_id != "./LLM":
809
- try:
810
- logger.info(f"Retrying tokenizer load directly from: {llm_tokenizer_path_or_id}")
811
- tokenizer = AutoTokenizer.from_pretrained(
812
- llm_tokenizer_path_or_id,
813
- trust_remote_code=trust_remote_code,
814
- **component_loading_kwargs
815
- )
816
- except Exception as e2:
817
- raise OSError(f"Could not load tokenizer using main path + subfolder or directly from '{llm_tokenizer_path_or_id}'. Error: {e2}") from e
818
- else:
819
- raise OSError(f"Could not load tokenizer from subfolder '{llm_tokenizer_path_or_id}' within '{pretrained_model_name_or_path}'. Error: {e}")
820
-
821
-
822
- try:
823
- # Feature extractor usually doesn't need trust_remote_code
824
- feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
825
- pretrained_model_name_or_path, # Main path
826
- subfolder=w2v_processor_path_or_id.lstrip('./'), # Specify subfolder relative to main path
827
- **component_loading_kwargs
828
- )
829
- except Exception as e:
830
- # Fallback: try loading directly using the path/id from config if different
831
- if w2v_processor_path_or_id != "./wav2vec2-large-xlsr-53":
832
- try:
833
- logger.info(f"Retrying feature extractor load directly from: {w2v_processor_path_or_id}")
834
- feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
835
- w2v_processor_path_or_id,
836
- **component_loading_kwargs
837
- )
838
- except Exception as e2:
839
- raise OSError(f"Could not load feature extractor using main path + subfolder or directly from '{w2v_processor_path_or_id}'. Error: {e2}") from e
840
- else:
841
- raise OSError(f"Could not load feature extractor from subfolder '{w2v_processor_path_or_id}' within '{pretrained_model_name_or_path}'. Error: {e}")
842
-
843
-
844
- # --- 4. Instantiate processor ---
845
- # Pass the potentially loaded config object (or None)
846
- return cls(tokenizer=tokenizer, feature_extractor=feature_extractor, config=loaded_config)
847
-
848
-
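For orientation, a sketch of loading the processor directly from a checkpoint; by default it looks for the tokenizer under an `LLM/` subfolder and the feature extractor under `wav2vec2-large-xlsr-53/` (the repo id here is an assumption):

processor = SparkTTSProcessor.from_pretrained(
    "SparkAudio/Spark-TTS-0.5B",  # assumed Hub repo id or local directory
    trust_remote_code=True,       # needed when the custom config/tokenizer classes live in the repo
)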
849
- def save_pretrained(
850
- self,
851
- save_directory: Union[str, os.PathLike],
852
- push_to_hub: bool = False,
853
- **kwargs,
854
- ):
855
- """
856
- Save the processor's state (tokenizer and feature extractor files) to a directory.
857
-
858
- Args:
859
- save_directory (`str` or `os.PathLike`):
860
- Directory where the processor files will be saved.
861
- push_to_hub (`bool`, *optional*, defaults to `False`):
862
- Whether or not to push your model to the Hugging Face Hub after saving it.
863
- **kwargs:
864
- Additional keyword arguments passed along to the `push_to_hub` method.
865
- """
866
- save_directory = Path(save_directory)
867
- save_directory.mkdir(parents=True, exist_ok=True)
868
-
869
- # Save tokenizer
870
- self.tokenizer.save_pretrained(str(save_directory), **kwargs)
871
-
872
- # Save feature extractor
873
- self.feature_extractor.save_pretrained(str(save_directory), **kwargs)
874
-
875
- # Save the main processor config (if it exists and has relevant info)
876
- # Note: The SparkTTSConfig is usually saved with the *model*, not the processor.
877
- # However, if the processor holds specific config needed for reloading *itself*,
878
- # it could be saved here. Usually, relying on the model's config is sufficient.
879
- # if self.config:
880
- # self.config.save_pretrained(str(save_directory)) # Example if needed
881
-
882
- logger.info(f"Processor components saved in {save_directory}")
883
-
884
- if push_to_hub:
885
- # Commit message and other hub kwargs can be passed via **kwargs
886
- commit_message = kwargs.pop("commit_message", "Save processor")
887
- return self.push_to_hub(save_directory, commit_message=commit_message, **kwargs)
888
-
889
- return str(save_directory) # Return path consistent with Mixin