Training in progress - step 500
- asr_config.py +24 -18
- asr_modeling.py +105 -77
- asr_pipeline.py +439 -6
- asr_processing.py +25 -3
- chat_template.jinja +89 -94
- config.json +165 -109
- generation_config.json +12 -8
- model.safetensors +2 -2
- preprocessor_config.json +1 -1
- projectors.py +276 -295
- tokenizer.json +2 -2
- tokenizer_config.json +0 -0
asr_config.py
CHANGED

@@ -14,29 +14,34 @@ class ASRConfig(transformers.PretrainedConfig):
         attn_implementation: str = "flash_attention_2",
         model_dtype: str = "bfloat16",
         num_beams: Optional[int] = None,
-        system_prompt: str = "
-        user_prompt: str = "
+        system_prompt: str = "You are a helpful assistant.",
+        user_prompt: str = "Please transcribe this English audio into text: <audio>",
         encoder_dim: Optional[int] = None,
         llm_dim: Optional[int] = None,
+        # Encoder conv layers: list of (padding, kernel_size, stride) tuples
+        # Default is Whisper/GLM-ASR structure: conv1(k=3,s=1,p=1) + conv2(k=3,s=2,p=1)
+        encoder_conv_layers: Optional[list] = None,
         audio_sample_rate: int = 16000,
         projector_init_std: float = 0.02,
-        projector_pool_stride: int =
-        downsample_rate: int =
+        projector_pool_stride: int = 4,
+        downsample_rate: int = 5,  # Granite default
         projector_hidden_dim: Optional[int] = None,
-        projector_type: str = "moe",  # "moe", "swiglu", "residual", "shared_moe", "mlp"
+        projector_type: str = "moe",  # "moe", "swiglu", "residual", "shared_moe", "mlp", "qformer"
         projector_num_layers: int = 2,  # Number of layers (for residual projector)
-        projector_dropout: float = 0.
-        projector_input_noise: float = 0.02,  # Input noise for projector
+        projector_dropout: float = 0.0,  # Dropout rate for projector layers
         # MoE-specific configuration
         num_experts: int = 4,  # Number of experts in MoE projectors
         num_experts_per_tok: int = 2,  # Top-k experts per token
         router_aux_loss_coef: float = 0.01,  # Auxiliary loss coefficient for load balancing
+        # QFormer-specific configuration (Granite defaults)
+        qformer_window_size: int = 15,  # Window size for QFormer processing
+        qformer_hidden_size: Optional[int] = None,  # QFormer hidden size (defaults to encoder_dim)
+        qformer_num_layers: int = 2,  # Number of QFormer transformer layers
+        qformer_num_heads: int = 16,  # Number of attention heads in QFormer
+        qformer_intermediate_size: Optional[int] = None,  # FFN size (defaults to 4x hidden)
         label_smoothing: float = 0.0,  # Label smoothing for cross-entropy loss
-        inference_diversity_penalty: float = 0.0,
         inference_warmup_tokens: int = 10,
         max_new_tokens: Optional[int] = None,
-        min_new_tokens: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         length_penalty: Optional[float] = None,
         no_repeat_ngram_size: Optional[int] = None,

@@ -46,8 +51,7 @@ class ASRConfig(transformers.PretrainedConfig):
         # Set default generation parameters (greedy decoding only)
         generation_defaults = {
             "num_beams": 1,
-            "max_new_tokens":
-            "min_new_tokens": 0,
+            "max_new_tokens": 256,
             "repetition_penalty": 1.0,
             "length_penalty": 1.0,
             "no_repeat_ngram_size": 0,

@@ -65,6 +69,8 @@ class ASRConfig(transformers.PretrainedConfig):
         self.user_prompt = user_prompt
         self.encoder_dim = encoder_dim
         self.llm_dim = llm_dim
+        # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
+        self.encoder_conv_layers = encoder_conv_layers or [(1, 3, 1), (1, 3, 2)]
         self.audio_sample_rate = audio_sample_rate
         self.projector_init_std = projector_init_std
         self.projector_pool_stride = projector_pool_stride

@@ -73,14 +79,17 @@ class ASRConfig(transformers.PretrainedConfig):
         self.projector_type = projector_type
         self.projector_num_layers = projector_num_layers
         self.projector_dropout = projector_dropout
-        self.projector_input_noise = projector_input_noise
         # MoE-specific configuration
         self.num_experts = num_experts
         self.num_experts_per_tok = num_experts_per_tok
         self.router_aux_loss_coef = router_aux_loss_coef
+        # QFormer-specific configuration
+        self.qformer_window_size = qformer_window_size
+        self.qformer_hidden_size = qformer_hidden_size
+        self.qformer_num_layers = qformer_num_layers
+        self.qformer_num_heads = qformer_num_heads
+        self.qformer_intermediate_size = qformer_intermediate_size
         self.label_smoothing = label_smoothing
-        self.inference_diversity_penalty = inference_diversity_penalty
         self.inference_warmup_tokens = inference_warmup_tokens

         # Generation parameters (use explicit value if provided, else use default)

@@ -88,9 +97,6 @@ class ASRConfig(transformers.PretrainedConfig):
         self.max_new_tokens = (
             max_new_tokens if max_new_tokens is not None else generation_defaults["max_new_tokens"]
         )
-        self.min_new_tokens = (
-            min_new_tokens if min_new_tokens is not None else generation_defaults["min_new_tokens"]
-        )
         self.repetition_penalty = (
             repetition_penalty
             if repetition_penalty is not None
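The new `encoder_conv_layers` default and `projector_pool_stride: int = 4` together determine how many `<audio>` placeholder tokens a clip needs. A minimal sketch of that arithmetic (not code from the repo): the 3000-frame figure assumes Whisper's 30-second window at 100 mel frames per second, and the final division assumes the pooling projector honors `projector_pool_stride=4`.

```python
# Sketch of the mel-frames -> encoder-frames -> audio-tokens arithmetic implied by the config.
def conv_output_length(length: int, conv_layers) -> int:
    # Same closed form as the diff: out = (in + 2*pad - (kernel - 1) - 1) // stride + 1
    for padding, kernel_size, stride in conv_layers:
        length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
    return length

encoder_conv_layers = [(1, 3, 1), (1, 3, 2)]  # Whisper/GLM-ASR default from the config
projector_pool_stride = 4                     # assumed to divide the encoder frames

mel_frames = 3000                                                     # 30 s at 100 frames/s
encoder_frames = conv_output_length(mel_frames, encoder_conv_layers)  # -> 1500
audio_tokens = encoder_frames // projector_pool_stride                # -> 375
print(encoder_frames, audio_tokens)
```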
asr_modeling.py
CHANGED

@@ -13,9 +13,6 @@ from transformers import (
 )
 from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.whisper.modeling_whisper import (
-    _compute_mask_indices,
-)

 try:
     from .asr_config import ASRConfig

@@ -75,6 +72,21 @@ class ASRModel(PreTrainedModel, GenerationMixin):
                 state_dict = load_file(model_file)
                 model.load_state_dict(state_dict, strict=False)

+            # Load LoRA adapter if present
+            adapter_config = cached_file(
+                pretrained_model_name_or_path,
+                "adapter_config.json",
+                _raise_exceptions_for_missing_entries=False,
+                **cache_kwargs,
+            )
+            if adapter_config is not None:
+                from peft import PeftModel
+
+                # Pass original repo ID to PEFT, let it handle caching
+                model.language_model = PeftModel.from_pretrained(
+                    model.language_model, pretrained_model_name_or_path, is_trainable=False
+                )
+
             return model
         finally:
             cls._is_loading_from_pretrained = False

@@ -108,7 +120,10 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         self.generation_config.length_penalty = config.length_penalty
         self.generation_config.repetition_penalty = config.repetition_penalty
         self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
-        self.generation_config.eos_token_id =
+        self.generation_config.eos_token_id = [
+            self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
+            self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
+        ]
         self.generation_config.pad_token_id = self.tokenizer.pad_token_id

         # Feature extractor for audio preprocessing

@@ -141,6 +156,22 @@ class ASRModel(PreTrainedModel, GenerationMixin):
             full_model = WhisperModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
             encoder = full_model.encoder
             del full_model
+        elif "glm" in config.audio_model_id.lower():
+            # GLM-ASR models use audio_tower as the encoder
+            # Requires transformers >= 5.x or installed from source
+            from transformers import AutoModelForSeq2SeqLM
+
+            full_model = AutoModelForSeq2SeqLM.from_pretrained(
+                config.audio_model_id, trust_remote_code=True, **encoder_kwargs
+            )
+            # GLM stores encoder at audio_tower (GlmAsrEncoder)
+            encoder = full_model.audio_tower
+            # Clear references to free VRAM from the LLM decoder
+            full_model.language_model = None
+            full_model.multi_modal_projector = None
+            del full_model
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
         else:
             encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)

@@ -210,12 +241,15 @@ class ASRModel(PreTrainedModel, GenerationMixin):
             self.tokenizer.pad_token = "<|finetune_right_pad_id|>"

         # Add audio token
-        existing_special = self.tokenizer
+        existing_special = getattr(self.tokenizer, "additional_special_tokens", None) or []
         if "<audio>" not in existing_special:
             self.tokenizer.add_special_tokens(
                 {"additional_special_tokens": existing_special + ["<audio>"]}
             )
             self.language_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=False)
+            # Ensure lm_head stays tied to embeddings (e.g., SmolLM3)
+            if hasattr(self.language_model, "tie_weights"):
+                self.language_model.tie_weights()

         self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
         self.tokenizer.padding_side = "right"

@@ -263,92 +297,80 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         except ImportError:
             from asr_processing import ASRProcessor  # type: ignore[no-redef]

-        return ASRProcessor(
+        return ASRProcessor(
+            feature_extractor=self.feature_extractor,
+            tokenizer=self.tokenizer,
+            projector=self.projector,
+            encoder_conv_layers=self.config.encoder_conv_layers,
+        )

     def state_dict(self, *args, **kwargs):
         """Only save trainable projector weights."""
         return {f"projector.{k}": v for k, v in self.projector.state_dict().items()}

-    def _apply_specaugment(
+    def _compute_encoder_output_lengths(
         self,
-        input_features: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
+        audio_attention_mask: torch.Tensor,
     ) -> torch.Tensor:
-            return input_features
-
-        # Input shape: (batch_size, num_mel_bins, sequence_length) for Whisper
-        batch_size, hidden_size, sequence_length = input_features.size()
-
-        mask_time_prob = getattr(self.config, "mask_time_prob", 0.05)
-        mask_time_length = getattr(self.config, "mask_time_length", 10)
-        mask_feature_prob = getattr(self.config, "mask_feature_prob", 0.0)
-        mask_feature_length = getattr(self.config, "mask_feature_length", 10)
-
-        # Time masking
-        if mask_time_prob > 0:
-            mask_time_np = _compute_mask_indices(
-                (batch_size, sequence_length),
-                mask_prob=mask_time_prob,
-                mask_length=mask_time_length,
-                attention_mask=attention_mask,
-                min_masks=2,
-            )
-            mask_time_indices = torch.tensor(
-                mask_time_np, device=input_features.device, dtype=torch.bool
-            )
-            # Expand to cover all features: (batch, seq) -> (batch, features, seq)
-            mask_time_expanded = mask_time_indices[:, None].expand(-1, hidden_size, -1)
-            input_features = input_features.masked_fill(mask_time_expanded, 0.0)
-
-        # Feature masking
-        if mask_feature_prob > 0:
-            mask_feature_np = _compute_mask_indices(
-                (batch_size, hidden_size),
-                mask_prob=mask_feature_prob,
-                mask_length=mask_feature_length,
-                min_masks=2,
-            )
-            mask_feature_indices = torch.tensor(
-                mask_feature_np, device=input_features.device, dtype=torch.bool
-            )
-            # Expand: (batch, features) -> (batch, features, seq)
-            mask_feature_expanded = mask_feature_indices[:, :, None].expand(-1, -1, sequence_length)
-            input_features = input_features.masked_fill(mask_feature_expanded, 0.0)
+        """Compute per-sample encoder output lengths using conv layer formulas.
+
+        Args:
+            audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
+
+        Returns:
+            Tensor of encoder output lengths per sample (batch,)
+        """
+        # Get mel frame lengths from attention mask
+        lengths = audio_attention_mask.sum(dim=-1)
+
+        # Apply conv layer formulas: output = (input + 2*pad - (kernel-1) - 1) // stride + 1
+        for padding, kernel_size, stride in self.config.encoder_conv_layers:
+            lengths = (lengths + 2 * padding - (kernel_size - 1) - 1) // stride + 1
+
+        return lengths

     def _encode_audio(
         self,
         audio_features: torch.Tensor,
-        audio_attention_mask:
+        audio_attention_mask: torch.Tensor,
     ) -> torch.Tensor:
         """Encode audio and project to LLM embedding space.

-        audio_features = self._apply_specaugment(audio_features, audio_attention_mask)
+        Args:
+            audio_features: Mel spectrogram features (batch, n_mels, mel_len)
+            audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)

+        Returns:
+            Flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
+        """
         with torch.no_grad():
-            encoder_out = self.audio_tower(
-                input_features=audio_features, attention_mask=audio_attention_mask
-            )
+            encoder_out = self.audio_tower(input_features=audio_features)
             hidden_states = encoder_out.last_hidden_state

+        # Compute per-sample encoder output lengths using conv formulas
+        encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
+
+        # Project to LLM space
         audio_embeds = self.projector(hidden_states)

+        # Compute per-sample projector output lengths
+        projector_lengths = torch.tensor(
+            [self.projector.get_output_length(int(length.item())) for length in encoder_lengths],
+            device=audio_embeds.device,
+        )
+
+        # Create valid mask for variable-length samples and extract only real embeddings
+        max_len = audio_embeds.shape[1]
+        valid_mask = (
+            torch.arange(max_len, device=audio_embeds.device)[None, :] < projector_lengths[:, None]
+        )
+        return audio_embeds[valid_mask]

     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
         input_features: Optional[torch.Tensor] = None,
+        audio_attention_mask: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         past_key_values: Optional[torch.Tensor] = None,

@@ -356,7 +378,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         labels: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
-        audio_attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> CausalLMOutputWithPast:
         """Forward pass for training and inference."""

@@ -408,23 +429,27 @@ class ASRModel(PreTrainedModel, GenerationMixin):

         return model_inputs

-    def _get_num_audio_tokens(
-        MLP projector adds another stride-2 for 4x total downsampling
+    def _get_num_audio_tokens(
+        self,
+        audio_attention_mask: torch.Tensor,
+    ) -> int:
+        """Calculate number of audio tokens based on actual audio length.
+
+        Uses attention mask to get real audio length, then computes:
+        mel_frames -> encoder_frames (via conv formulas) -> projector output tokens
         """
+        encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
+        # Use max length for batch (all samples should have same token count for generation)
+        encoder_output_len = int(encoder_lengths.max().item())
+        return int(self.projector.get_output_length(encoder_output_len))

     @torch.no_grad()
     def generate(
         self,
         input_ids: Optional[torch.Tensor] = None,
         input_features: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
         audio_attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         system_prompt: Optional[str] = None,
         **generate_kwargs,
     ) -> torch.Tensor:

@@ -436,6 +461,8 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         """
         if input_features is None:
            raise ValueError("input_features required for generation")
+        if audio_attention_mask is None:
+            raise ValueError("audio_attention_mask required for generation")

         device = input_features.device
         batch_size = input_features.shape[0]

@@ -445,7 +472,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):

         # If input_ids not provided, build prompt with correct number of audio tokens
         if input_ids is None:
-            num_audio_tokens = self._get_num_audio_tokens(
+            num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
             audio_placeholder = "<audio>" * num_audio_tokens

             system_prompt = system_prompt or self.system_prompt

@@ -455,12 +482,13 @@ class ASRModel(PreTrainedModel, GenerationMixin):
             messages.append({"role": "system", "content": system_prompt})
             messages.append({"role": "user", "content": self.TRANSCRIBE_PROMPT + audio_placeholder})

+            chat_result = self.tokenizer.apply_chat_template(
                 messages,
                 tokenize=True,
                 add_generation_prompt=True,
                 return_tensors="pt",
-            )
+            )
+            input_ids = chat_result.input_ids.to(device)

             if input_ids.dim() == 1:
                 input_ids = input_ids.unsqueeze(0)
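The new `_encode_audio` flattens a padded batch of projector outputs into a single `(total_audio_tokens, hidden)` tensor by masking out padded positions with an arange broadcast. A standalone toy sketch of that trick (shapes here are made up for illustration, not taken from the model):

```python
import torch

# Toy stand-ins: 3 samples whose projector outputs have 4, 2 and 3 real tokens,
# padded to a common max length of 5 with hidden size 8.
projector_lengths = torch.tensor([4, 2, 3])
audio_embeds = torch.randn(3, 5, 8)  # (batch, max_len, hidden)

# Broadcast compare: position index < per-sample length -> (batch, max_len) bool mask
valid_mask = torch.arange(audio_embeds.shape[1])[None, :] < projector_lengths[:, None]

# Boolean indexing drops the padded slots and flattens the batch dimension
flat = audio_embeds[valid_mask]
print(valid_mask.sum().item(), flat.shape)  # 9 torch.Size([9, 8])
```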
asr_pipeline.py
CHANGED

@@ -1,5 +1,8 @@
+import re
+from pathlib import Path
 from typing import Any

+import numpy as np
 import torch
 import transformers

@@ -9,6 +12,284 @@
     from asr_modeling import ASRModel  # type: ignore[no-redef]


+class ForcedAligner:
+    """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2."""
+
+    _bundle = None
+    _model = None
+    _labels = None
+    _dictionary = None
+
+    @classmethod
+    def get_instance(cls, device: str = "cuda"):
+        if cls._model is None:
+            import torchaudio
+
+            cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
+            cls._model = cls._bundle.get_model().to(device)
+            cls._model.eval()
+            cls._labels = cls._bundle.get_labels()
+            cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
+        return cls._model, cls._labels, cls._dictionary
+
+    @classmethod
+    def align(
+        cls,
+        audio: np.ndarray,
+        text: str,
+        sample_rate: int = 16000,
+        language: str = "eng",
+        batch_size: int = 16,
+    ) -> list[dict]:
+        """Align transcript to audio and return word-level timestamps.
+
+        Args:
+            audio: Audio waveform as numpy array
+            text: Transcript text to align
+            sample_rate: Audio sample rate (default 16000)
+            language: ISO-639-3 language code (default "eng" for English, unused)
+            batch_size: Batch size for alignment model (unused)
+
+        Returns:
+            List of dicts with 'word', 'start', 'end' keys
+        """
+        import torchaudio
+        from torchaudio.functional import forced_align, merge_tokens
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        model, labels, dictionary = cls.get_instance(device)
+
+        # Convert audio to tensor (copy to ensure array is writable)
+        if isinstance(audio, np.ndarray):
+            waveform = torch.from_numpy(audio.copy()).float()
+        else:
+            waveform = audio.clone().float()
+
+        # Ensure 2D (channels, time)
+        if waveform.dim() == 1:
+            waveform = waveform.unsqueeze(0)
+
+        # Resample if needed (wav2vec2 expects 16kHz)
+        if sample_rate != cls._bundle.sample_rate:
+            waveform = torchaudio.functional.resample(
+                waveform, sample_rate, cls._bundle.sample_rate
+            )
+
+        waveform = waveform.to(device)
+
+        # Get emissions from model
+        with torch.inference_mode():
+            emissions, _ = model(waveform)
+            emissions = torch.log_softmax(emissions, dim=-1)
+
+        emission = emissions[0].cpu()
+
+        # Normalize text: uppercase, keep only valid characters
+        transcript = text.upper()
+        # Build tokens from transcript
+        tokens = []
+        for char in transcript:
+            if char in dictionary:
+                tokens.append(dictionary[char])
+            elif char == " ":
+                tokens.append(dictionary.get("|", dictionary.get(" ", 0)))
+
+        if not tokens:
+            return []
+
+        targets = torch.tensor([tokens], dtype=torch.int32)
+
+        # Run forced alignment
+        # Note: forced_align is deprecated in torchaudio 2.6+ and will be removed in 2.9 (late 2025)
+        # No official replacement announced yet. See https://github.com/pytorch/audio/issues/3902
+        aligned_tokens, scores = forced_align(emission.unsqueeze(0), targets, blank=0)
+
+        # Use torchaudio's merge_tokens to get token spans (removes blanks and merges repeats)
+        token_spans = merge_tokens(aligned_tokens[0], scores[0])
+
+        # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
+        frame_duration = 320 / cls._bundle.sample_rate
+
+        # Group token spans into words based on pipe separator
+        words = text.split()
+        word_timestamps = []
+        current_word_start = None
+        current_word_end = None
+        word_idx = 0
+
+        for span in token_spans:
+            token_char = labels[span.token]
+            if token_char == "|":  # Word separator
+                if current_word_start is not None and word_idx < len(words):
+                    word_timestamps.append(
+                        {
+                            "word": words[word_idx],
+                            "start": current_word_start * frame_duration,
+                            "end": current_word_end * frame_duration,
+                        }
+                    )
+                    word_idx += 1
+                current_word_start = None
+                current_word_end = None
+            else:
+                if current_word_start is None:
+                    current_word_start = span.start
+                current_word_end = span.end
+
+        # Don't forget the last word
+        if current_word_start is not None and word_idx < len(words):
+            word_timestamps.append(
+                {
+                    "word": words[word_idx],
+                    "start": current_word_start * frame_duration,
+                    "end": current_word_end * frame_duration,
+                }
+            )
+
+        return word_timestamps
+
+
+class SpeakerDiarizer:
+    """Lazy-loaded speaker diarization using pyannote-audio."""
+
+    _pipeline = None
+
+    @classmethod
+    def get_instance(cls, hf_token: str | None = None):
+        """Get or create the diarization pipeline.
+
+        Args:
+            hf_token: HuggingFace token with access to pyannote models.
+                Can also be set via HF_TOKEN environment variable.
+        """
+        if cls._pipeline is None:
+            from pyannote.audio import Pipeline
+
+            cls._pipeline = Pipeline.from_pretrained(
+                "pyannote/speaker-diarization-3.1",
+            )
+
+            # Move to GPU if available
+            if torch.cuda.is_available():
+                cls._pipeline.to(torch.device("cuda"))
+            elif torch.backends.mps.is_available():
+                cls._pipeline.to(torch.device("mps"))
+
+        return cls._pipeline
+
+    @classmethod
+    def diarize(
+        cls,
+        audio: np.ndarray | str,
+        sample_rate: int = 16000,
+        num_speakers: int | None = None,
+        min_speakers: int | None = None,
+        max_speakers: int | None = None,
+        hf_token: str | None = None,
+    ) -> list[dict]:
+        """Run speaker diarization on audio.
+
+        Args:
+            audio: Audio waveform as numpy array or path to audio file
+            sample_rate: Audio sample rate (default 16000)
+            num_speakers: Exact number of speakers (if known)
+            min_speakers: Minimum number of speakers
+            max_speakers: Maximum number of speakers
+            hf_token: HuggingFace token for pyannote models
+
+        Returns:
+            List of dicts with 'speaker', 'start', 'end' keys
+        """
+        pipeline = cls.get_instance(hf_token)
+
+        # Prepare audio input
+        if isinstance(audio, np.ndarray):
+            # pyannote expects {"waveform": tensor, "sample_rate": int}
+            waveform = torch.from_numpy(audio).unsqueeze(0)  # Add channel dim
+            if waveform.dim() == 1:
+                waveform = waveform.unsqueeze(0)
+            audio_input = {"waveform": waveform, "sample_rate": sample_rate}
+        else:
+            # File path
+            audio_input = audio
+
+        # Run diarization
+        diarization_args = {}
+        if num_speakers is not None:
+            diarization_args["num_speakers"] = num_speakers
+        if min_speakers is not None:
+            diarization_args["min_speakers"] = min_speakers
+        if max_speakers is not None:
+            diarization_args["max_speakers"] = max_speakers
+
+        diarization = pipeline(audio_input, **diarization_args)
+
+        # Handle different pyannote return types
+        # pyannote 3.x returns DiarizeOutput dataclass, older versions return Annotation
+        if hasattr(diarization, "itertracks"):
+            annotation = diarization
+        elif hasattr(diarization, "speaker_diarization"):
+            # pyannote 3.x DiarizeOutput dataclass
+            annotation = diarization.speaker_diarization
+        elif isinstance(diarization, tuple):
+            # Some versions return (annotation, embeddings) tuple
+            annotation = diarization[0]
+        else:
+            raise TypeError(f"Unexpected diarization output type: {type(diarization)}")
+
+        # Convert to simple format
+        segments = []
+        for turn, _, speaker in annotation.itertracks(yield_label=True):
+            segments.append(
+                {
+                    "speaker": speaker,
+                    "start": turn.start,
+                    "end": turn.end,
+                }
+            )
+
+        return segments
+
+    @classmethod
+    def assign_speakers_to_words(
+        cls,
+        words: list[dict],
+        speaker_segments: list[dict],
+    ) -> list[dict]:
+        """Assign speaker labels to words based on timestamp overlap.
+
+        Args:
+            words: List of word dicts with 'word', 'start', 'end' keys
+            speaker_segments: List of speaker dicts with 'speaker', 'start', 'end' keys
+
+        Returns:
+            Words list with 'speaker' key added to each word
+        """
+        for word in words:
+            word_mid = (word["start"] + word["end"]) / 2
+
+            # Find the speaker segment that contains this word's midpoint
+            best_speaker = None
+            for seg in speaker_segments:
+                if seg["start"] <= word_mid <= seg["end"]:
+                    best_speaker = seg["speaker"]
+                    break
+
+            # If no exact match, find closest segment
+            if best_speaker is None and speaker_segments:
+                min_dist = float("inf")
+                for seg in speaker_segments:
+                    seg_mid = (seg["start"] + seg["end"]) / 2
+                    dist = abs(word_mid - seg_mid)
+                    if dist < min_dist:
+                        min_dist = dist
+                        best_speaker = seg["speaker"]
+
+            word["speaker"] = best_speaker
+
+        return words
+
+
 class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
     """ASR Pipeline for audio-to-text transcription."""

@@ -24,6 +305,131 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         super().__init__(
             model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
         )
+        self._current_audio = None
+
+    def _sanitize_parameters(self, **kwargs):
+        """Intercept our custom parameters before parent class validates them."""
+        # Remove our custom parameters so parent doesn't see them
+        kwargs.pop("return_timestamps", None)
+        kwargs.pop("return_speakers", None)
+        kwargs.pop("num_speakers", None)
+        kwargs.pop("min_speakers", None)
+        kwargs.pop("max_speakers", None)
+        kwargs.pop("hf_token", None)
+
+        return super()._sanitize_parameters(**kwargs)
+
+    def __call__(
+        self,
+        inputs,
+        **kwargs,
+    ):
+        """Transcribe audio with optional word-level timestamps and speaker diarization.
+
+        Args:
+            inputs: Audio input (file path, dict with array/sampling_rate, etc.)
+            return_timestamps: If True, return word-level timestamps using forced alignment
+            return_speakers: If True, return speaker labels for each word
+            num_speakers: Exact number of speakers (if known, for diarization)
+            min_speakers: Minimum number of speakers (for diarization)
+            max_speakers: Maximum number of speakers (for diarization)
+            hf_token: HuggingFace token for pyannote models (or set HF_TOKEN env var)
+            **kwargs: Additional arguments passed to the pipeline
+
+        Returns:
+            Dict with 'text' key, 'words' key if return_timestamps=True,
+            and speaker labels on words if return_speakers=True
+        """
+        # Extract our params before super().__call__ (which will also call _sanitize_parameters)
+        return_timestamps = kwargs.pop("return_timestamps", False)
+        return_speakers = kwargs.pop("return_speakers", False)
+        diarization_params = {
+            "num_speakers": kwargs.pop("num_speakers", None),
+            "min_speakers": kwargs.pop("min_speakers", None),
+            "max_speakers": kwargs.pop("max_speakers", None),
+            "hf_token": kwargs.pop("hf_token", None),
+        }
+
+        if return_speakers:
+            return_timestamps = True
+
+        # Store audio for timestamp alignment and diarization
+        if return_timestamps or return_speakers:
+            self._current_audio = self._extract_audio(inputs)
+
+        # Run standard transcription
+        result = super().__call__(inputs, **kwargs)
+
+        # Add timestamps if requested
+        if return_timestamps and self._current_audio is not None:
+            text = result.get("text", "")
+            if text:
+                try:
+                    words = ForcedAligner.align(
+                        self._current_audio["array"],
+                        text,
+                        sample_rate=self._current_audio.get("sampling_rate", 16000),
+                    )
+                    result["words"] = words
+                except Exception as e:
+                    result["words"] = []
+                    result["timestamp_error"] = str(e)
+            else:
+                result["words"] = []
+
+        # Add speaker diarization if requested
+        if return_speakers and self._current_audio is not None:
+            try:
+                # Run diarization
+                speaker_segments = SpeakerDiarizer.diarize(
+                    self._current_audio["array"],
+                    sample_rate=self._current_audio.get("sampling_rate", 16000),
+                    **{k: v for k, v in diarization_params.items() if v is not None},
+                )
+                result["speaker_segments"] = speaker_segments
+
+                # Assign speakers to words
+                if result.get("words"):
+                    result["words"] = SpeakerDiarizer.assign_speakers_to_words(
+                        result["words"],
+                        speaker_segments,
+                    )
+            except Exception as e:
+                result["speaker_segments"] = []
+                result["diarization_error"] = str(e)
+
+        # Clean up
+        self._current_audio = None
+
+        return result
+
+    def _extract_audio(self, inputs) -> dict | None:
+        """Extract audio array from various input formats using HF utilities."""
+        from transformers.pipelines.audio_utils import ffmpeg_read
+
+        if isinstance(inputs, dict):
+            if "array" in inputs:
+                return {
+                    "array": inputs["array"],
+                    "sampling_rate": inputs.get("sampling_rate", 16000),
+                }
+            if "raw" in inputs:
+                return {
+                    "array": inputs["raw"],
+                    "sampling_rate": inputs.get("sampling_rate", 16000),
+                }
+        elif isinstance(inputs, str):
+            # File path - load audio using ffmpeg (same as HF pipeline)
+            with Path(inputs).open("rb") as f:
+                audio = ffmpeg_read(f.read(), sampling_rate=16000)
+            return {"array": audio, "sampling_rate": 16000}
+        elif isinstance(inputs, bytes):
+            audio = ffmpeg_read(inputs, sampling_rate=16000)
+            return {"array": audio, "sampling_rate": 16000}
+        elif isinstance(inputs, np.ndarray):
+            return {"array": inputs, "sampling_rate": 16000}
+
+        return None

     def preprocess(self, inputs, **preprocess_params):
         # Handle dict with "array" key (from datasets)

@@ -42,15 +448,12 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         # Extract audio features and is_last flag
         is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True

-        if input_features is not None:
-            input_features = input_features.to(self.model.device)
-        else:
-            input_features = model_inputs.to(self.model.device)
+        input_features = model_inputs["input_features"].to(self.model.device)
+        audio_attention_mask = model_inputs["attention_mask"].to(self.model.device)

         generated_ids = self.model.generate(
             input_features=input_features,
+            audio_attention_mask=audio_attention_mask,
             **generate_kwargs,
         )

@@ -71,4 +474,34 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
             tokens = tokens[0]

         text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
+        # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
+        text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
+        # Truncate if a word repeats more than 3 times consecutively
+        text = self._truncate_repetitions(text, max_repeats=3)
         return {"text": text}
+
+    def _truncate_repetitions(self, text: str, max_repeats: int = 3) -> str:
+        """Truncate text when a word repeats more than max_repeats times consecutively.
+
+        Args:
+            text: Input text to check for repetitions
+            max_repeats: Maximum allowed consecutive repetitions (default 3)
+
+        Returns:
+            Truncated text if repetition detected, otherwise original text
+        """
+        words = text.split()
+        if len(words) <= max_repeats:
+            return text
+
+        repeat_count = 1
+        for i in range(1, len(words)):
+            if words[i].lower() == words[i - 1].lower():
+                repeat_count += 1
+                if repeat_count > max_repeats:
+                    # Keep up to max_repeats of the repeated word
+                    return " ".join(words[:i])
+            else:
+                repeat_count = 1
+
+        return text
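A hedged usage sketch of the extended pipeline, assuming the repo registers `ASRPipeline` as its custom automatic-speech-recognition pipeline; the repo id below is a placeholder, and the speaker path additionally needs an accepted pyannote license plus an HF token.

```python
# Hypothetical usage; "your-org/your-asr-checkpoint" is a placeholder, not the real repo id.
import transformers

asr = transformers.pipeline(
    "automatic-speech-recognition",
    model="your-org/your-asr-checkpoint",
    trust_remote_code=True,
)

result = asr(
    "meeting.wav",
    return_timestamps=True,   # word-level timestamps via wav2vec2 forced alignment
    return_speakers=True,     # pyannote diarization, speaker label attached per word
    max_speakers=3,
)
print(result["text"])
for w in result.get("words", [])[:5]:
    print(f'{w["speaker"]}: {w["word"]} [{w["start"]:.2f}-{w["end"]:.2f}s]')
```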
asr_processing.py
CHANGED

@@ -18,11 +18,28 @@ class ASRProcessor(ProcessorMixin):
     tokenizer_class = "AutoTokenizer"
     AUDIO_TOKEN = "<audio>"
     TRANSCRIBE_PROMPT = "Transcribe: "
+    # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
+    DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]

-    def __init__(
+    def __init__(
+        self,
+        feature_extractor,
+        tokenizer,
+        projector=None,
+        encoder_conv_layers: Optional[list] = None,
+    ):
         self.feature_extractor = feature_extractor
         self.tokenizer = tokenizer
         self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
+        self.projector = projector
+        self.encoder_conv_layers = encoder_conv_layers or self.DEFAULT_ENCODER_CONV_LAYERS
+
+    def _compute_encoder_output_length(self, mel_length: int) -> int:
+        """Compute encoder output length using conv layer formulas."""
+        length = mel_length
+        for padding, kernel_size, stride in self.encoder_conv_layers:
+            length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
+        return length

     def __call__(
         self,

@@ -50,12 +67,17 @@ class ASRProcessor(ProcessorMixin):
             audio_inputs = self.feature_extractor(
                 audio,
                 sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
+                return_attention_mask=True,
                 return_tensors=return_tensors,
                 **kwargs,
             )
             result["input_features"] = audio_inputs["input_features"]
+            result["audio_attention_mask"] = audio_inputs["attention_mask"]
+
+            # Use actual audio length (from attention mask) for token count
+            real_mel_len = int(audio_inputs["attention_mask"].sum(dim=-1).max().item())
+            encoder_output_len = self._compute_encoder_output_length(real_mel_len)
+            num_audio_tokens = self.projector.get_output_length(encoder_output_len)
         else:
             num_audio_tokens = 0
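As a quick sanity check (not part of the repo), the closed-form length helper used above can be compared against what an actual `nn.Conv1d` stack produces; channel counts are irrelevant for the length and are chosen here only to echo the 128 mel bins in the config.

```python
import torch
import torch.nn as nn

# Whisper-style conv stack: conv1(k=3, s=1, p=1) then conv2(k=3, s=2, p=1)
conv1 = nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1)
conv2 = nn.Conv1d(128, 128, kernel_size=3, stride=2, padding=1)

def formula(length, layers=((1, 3, 1), (1, 3, 2))):
    # Same arithmetic as _compute_encoder_output_length above
    for padding, kernel, stride in layers:
        length = (length + 2 * padding - (kernel - 1) - 1) // stride + 1
    return length

x = torch.randn(1, 128, 1234)          # odd mel length on purpose
y = conv2(conv1(x))
print(y.shape[-1], formula(1234))      # both print 617
```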
chat_template.jinja
CHANGED
|
@@ -1,94 +1,89 @@
|
|
| 1 |
-
{
|
| 2 |
-
{
|
| 3 |
-
{%-
|
| 4 |
-
{
|
| 5 |
-
|
| 6 |
-
{#
|
| 7 |
-
{%-
|
| 8 |
-
|
| 9 |
-
{
|
| 10 |
-
|
| 11 |
-
{
|
| 12 |
-
|
| 13 |
-
{
|
| 14 |
-
{{-
|
| 15 |
-
|
| 16 |
-
{%-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
{%- set
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
{%-
|
| 25 |
-
|
| 26 |
-
{%- if
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
{%-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
{%-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
{%-
|
| 70 |
-
{
|
| 71 |
-
{
|
| 72 |
-
{%-
|
| 73 |
-
|
-… (old lines 74-89 of the previous template are truncated in this view)
-{{ "<|im_start|>assistant\n" }}
-{%- else -%}
-{{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" }}
-{%- endif -%}
-{%- endif -%}
+{%- if tools %}
+    {{- '<|im_start|>system\n' }}
+    {%- if messages[0].role == 'system' %}
+        {{- messages[0].content + '\n\n' }}
+    {%- endif %}
+    {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
+    {%- for tool in tools %}
+        {{- "\n" }}
+        {{- tool | tojson }}
+    {%- endfor %}
+    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
+{%- else %}
+    {%- if messages[0].role == 'system' %}
+        {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
+    {%- endif %}
+{%- endif %}
+{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
+{%- for message in messages[::-1] %}
+    {%- set index = (messages|length - 1) - loop.index0 %}
+    {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
+        {%- set ns.multi_step_tool = false %}
+        {%- set ns.last_query_index = index %}
+    {%- endif %}
+{%- endfor %}
+{%- for message in messages %}
+    {%- if message.content is string %}
+        {%- set content = message.content %}
+    {%- else %}
+        {%- set content = '' %}
+    {%- endif %}
+    {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
+        {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
+    {%- elif message.role == "assistant" %}
+        {%- set reasoning_content = '' %}
+        {%- if message.reasoning_content is string %}
+            {%- set reasoning_content = message.reasoning_content %}
+        {%- else %}
+            {%- if '</think>' in content %}
+                {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
+                {%- set content = content.split('</think>')[-1].lstrip('\n') %}
+            {%- endif %}
+        {%- endif %}
+        {%- if loop.index0 > ns.last_query_index %}
+            {%- if loop.last or (not loop.last and reasoning_content) %}
+                {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
+            {%- else %}
+                {{- '<|im_start|>' + message.role + '\n' + content }}
+            {%- endif %}
+        {%- else %}
+            {{- '<|im_start|>' + message.role + '\n' + content }}
+        {%- endif %}
+        {%- if message.tool_calls %}
+            {%- for tool_call in message.tool_calls %}
+                {%- if (loop.first and content) or (not loop.first) %}
+                    {{- '\n' }}
+                {%- endif %}
+                {%- if tool_call.function %}
+                    {%- set tool_call = tool_call.function %}
+                {%- endif %}
+                {{- '<tool_call>\n{"name": "' }}
+                {{- tool_call.name }}
+                {{- '", "arguments": ' }}
+                {%- if tool_call.arguments is string %}
+                    {{- tool_call.arguments }}
+                {%- else %}
+                    {{- tool_call.arguments | tojson }}
+                {%- endif %}
+                {{- '}\n</tool_call>' }}
+            {%- endfor %}
+        {%- endif %}
+        {{- '<|im_end|>\n' }}
+    {%- elif message.role == "tool" %}
+        {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
+            {{- '<|im_start|>user' }}
+        {%- endif %}
+        {{- '\n<tool_response>\n' }}
+        {{- content }}
+        {{- '\n</tool_response>' }}
+        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+            {{- '<|im_end|>\n' }}
+        {%- endif %}
+    {%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+    {{- '<|im_start|>assistant\n' }}
+    {%- if enable_thinking is defined and enable_thinking is false %}
+        {{- '<think>\n\n</think>\n\n' }}
+    {%- endif %}
+{%- endif %}
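For a quick sanity check of the template above, something along these lines can be used. This is a minimal sketch, not part of the commit: it assumes the checkpoint directory ships this tokenizer and template, and that enable_thinking is forwarded to the template as a variable (which the `{%- if enable_thinking is defined ... %}` branch expects); the message contents are only illustrative.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./", trust_remote_code=True)  # hypothetical local checkpoint path

messages = [
    {"role": "system", "content": "/no_think /system_override"},
    {"role": "user", "content": "Please transcribe this English audio into text: <audio>"},
]

# add_generation_prompt=True appends "<|im_start|>assistant\n";
# enable_thinking=False additionally emits the empty "<think>\n\n</think>\n\n" block.
prompt = tok.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,
)
print(prompt)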
config.json
CHANGED
@@ -4,49 +4,126 @@
   ],
   "attn_implementation": "flash_attention_2",
   "audio_config": {
-    "_name_or_path": …,
-    "activation_dropout": 0.0,
-    "activation_function": "gelu",
-    "apply_spec_augment": false,
+    "_name_or_path": "zai-org/GLM-ASR-Nano-2512",
     "architectures": [
-      …
+      "GlmAsrForConditionalGeneration"
     ],
-    … (further removed encoder keys are truncated in this view)
+    "audio_config": {
+      "_name_or_path": "",
+      "add_cross_attention": false,
+      "architectures": null,
+      "attention_dropout": 0.0,
+      "bos_token_id": null,
+      "chunk_size_feed_forward": 0,
+      "cross_attention_hidden_size": null,
+      "decoder_start_token_id": null,
+      "dtype": null,
+      "eos_token_id": null,
+      "finetuning_task": null,
+      "head_dim": 64,
+      "hidden_act": "gelu",
+      "hidden_size": 1280,
+      "id2label": { "0": "LABEL_0", "1": "LABEL_1" },
+      "initializer_range": 0.02,
+      "intermediate_size": 5120,
+      "is_decoder": false,
+      "is_encoder_decoder": false,
+      "label2id": { "LABEL_0": 0, "LABEL_1": 1 },
+      "max_position_embeddings": 1500,
+      "model_type": "glmasr_encoder",
+      "num_attention_heads": 20,
+      "num_hidden_layers": 32,
+      "num_key_value_heads": 20,
+      "num_mel_bins": 128,
+      "output_attentions": false,
+      "output_hidden_states": false,
+      "pad_token_id": null,
+      "partial_rotary_factor": 0.5,
+      "prefix": null,
+      "problem_type": null,
+      "return_dict": true,
+      "rope_parameters": { "partial_rotary_factor": 0.5, "rope_theta": 10000.0, "rope_type": "default" },
+      "sep_token_id": null,
+      "task_specific_params": null,
+      "tie_word_embeddings": true,
+      "tokenizer_class": null
+    },
+    "audio_token_id": 59260,
     "dtype": "bfloat16",
-    "encoder_layerdrop": 0.0,
-    "encoder_layers": 32,
-    "eos_token_id": 50257,
-    "init_std": 0.02,
-    "mask_feature_length": 10,
-    "mask_feature_min_masks": 0,
-    "mask_feature_prob": 0.0,
-    "mask_time_length": 10,
-    "mask_time_min_masks": 2,
-    "mask_time_prob": 0.05,
-    "max_source_positions": 1500,
-    "max_target_positions": 448,
-    "median_filter_width": 7,
-    "model_type": "whisper",
-    "num_hidden_layers": 32,
+    "hidden_size": 2048,
+    "model_type": "glmasr",
     "num_mel_bins": 128,
-    … (further removed encoder keys are truncated in this view)
+    "projector_hidden_act": "gelu",
+    "text_config": {
+      "_name_or_path": "",
+      "add_cross_attention": false,
+      "architectures": null,
+      "attention_bias": false,
+      "attention_dropout": 0.0,
+      "bos_token_id": 1,
+      "chunk_size_feed_forward": 0,
+      "cross_attention_hidden_size": null,
+      "decoder_start_token_id": null,
+      "dtype": null,
+      "eos_token_id": [ 59246, 59253, 59255 ],
+      "finetuning_task": null,
+      "head_dim": 128,
+      "hidden_act": "silu",
+      "hidden_size": 2048,
+      "id2label": { "0": "LABEL_0", "1": "LABEL_1" },
+      "initializer_range": 0.02,
+      "intermediate_size": 6144,
+      "is_decoder": false,
+      "is_encoder_decoder": false,
+      "label2id": { "LABEL_0": 0, "LABEL_1": 1 },
+      "max_position_embeddings": 8192,
+      "mlp_bias": false,
+      "model_type": "llama",
+      "num_attention_heads": 16,
+      "num_hidden_layers": 28,
+      "num_key_value_heads": 4,
+      "output_attentions": false,
+      "output_hidden_states": false,
+      "pad_token_id": null,
+      "prefix": null,
+      "pretraining_tp": 1,
+      "problem_type": null,
+      "return_dict": true,
+      "rms_norm_eps": 1e-05,
+      "rope_parameters": { "rope_theta": 10000.0, "rope_type": "default" },
+      "sep_token_id": null,
+      "task_specific_params": null,
+      "tie_word_embeddings": false,
+      "tokenizer_class": null,
+      "use_cache": true,
+      "vocab_size": 59264
+    },
+    "vocab_size": 59264
   },
-  "audio_model_id": …,
+  "audio_model_id": "zai-org/GLM-ASR-Nano-2512",
   "audio_sample_rate": 16000,
   "auto_map": {
     "AutoConfig": "asr_config.ASRConfig",
@@ -64,17 +141,34 @@
       "type": "audio"
     }
   },
-  "downsample_rate": …,
+  "downsample_rate": 5,
   "dtype": "bfloat16",
+  "encoder_conv_layers": [ [ 1, 3, 1 ], [ 1, 3, 2 ] ],
   "encoder_dim": 1280,
-  "inference_diversity_penalty": 0.0,
   "inference_warmup_tokens": 10,
   "label_smoothing": 0.0,
+  "length_penalty": 1.0,
   "llm_dim": 2048,
+  "lora_alpha": 128,
+  "lora_dropout": 0.05,
+  "lora_r": 64,
+  "lora_target_modules": "all-linear",
+  "max_new_tokens": 256,
   "model_dtype": "bfloat16",
   "model_type": "asr_model",
+  "no_repeat_ngram_size": 0,
+  "num_beams": 1,
   "num_experts": 4,
   "num_experts_per_tok": 2,
   "pipeline_tag": "automatic-speech-recognition",
@@ -83,24 +177,30 @@
   "projector_init_std": 0.02,
   "projector_input_noise": 0.0,
   "projector_num_layers": 2,
-  "projector_pool_stride": …,
+  "projector_pool_stride": 4,
   "projector_type": "mlp",
+  "qformer_hidden_size": null,
+  "qformer_intermediate_size": null,
+  "qformer_num_heads": 16,
+  "qformer_num_layers": 2,
+  "qformer_window_size": 15,
+  "repetition_penalty": 1.0,
   "router_aux_loss_coef": 0.01,
   "system_prompt": "/no_think /system_override",
   "text_config": {
-    "_name_or_path": …,
+    "_name_or_path": "Qwen/Qwen3-1.7B",
     "architectures": [
-      …
+      "Qwen3ForCausalLM"
    ],
     "attention_bias": false,
     "attention_dropout": 0.0,
-    "bos_token_id": null,
     "dtype": "bfloat16",
-    "eos_token_id": …,
+    "eos_token_id": 151645,
+    "head_dim": 128,
     "hidden_act": "silu",
     "hidden_size": 2048,
     "initializer_range": 0.02,
-    "intermediate_size": …,
+    "intermediate_size": 6144,
    "layer_types": [
      "full_attention",
      "full_attention",
@@ -129,75 +229,31 @@
      "full_attention",
      "full_attention",
      "full_attention",
-      "full_attention",
-      "full_attention",
-      "full_attention",
-      "full_attention",
-      "full_attention",
-      "full_attention",
-      "full_attention",
-      "full_attention",
      "full_attention"
    ],
-    "max_position_embeddings": …,
+    "max_position_embeddings": 40960,
    "max_window_layers": 28,
-    "model_type": "smollm3",
-    "no_rope_layer_interval": 4,
-    "no_rope_layers": [ 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0 ],
+    "model_type": "qwen3",
    "num_attention_heads": 16,
-    "num_hidden_layers": …,
-    "num_key_value_heads": …,
+    "num_hidden_layers": 28,
+    "num_key_value_heads": 8,
+    "pad_token_id": 151643,
    "rms_norm_eps": 1e-06,
+    "rope_parameters": { "rope_theta": 1000000, "rope_type": "default" },
    "sliding_window": null,
+    "tie_word_embeddings": true,
+    "use_cache": true,
    "use_sliding_window": false,
-    "vocab_size": …
+    "vocab_size": 151670
  },
-  "text_model_id": …,
-  "transformers_version": …,
+  "text_model_id": "Qwen/Qwen3-1.7B",
+  "transformers_version": "5.0.0.dev0",
  "use_cache": false,
+  "use_lora": true,
  "use_specaugment": true,
-  "user_prompt": …,
+  "user_prompt": "Please transcribe this English audio into text: <audio>",
-  "vocab_size": …
+  "vocab_size": 151670
 }
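Because auto_map points AutoConfig at the repo's asr_config.ASRConfig (and preprocessor_config.json below does the same for asr_processing.ASRProcessor), the checkpoint is loaded with trust_remote_code. A minimal sketch; the repo path is a placeholder, and the printed values simply echo fields added in this commit.

from transformers import AutoConfig, AutoProcessor

repo = "path/or/repo-id-of-this-checkpoint"  # placeholder, not part of the commit

config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)

print(config.projector_type, config.projector_pool_stride)  # mlp 4
print(config.audio_model_id)                                # zai-org/GLM-ASR-Nano-2512
print(config.text_model_id)                                 # Qwen/Qwen3-1.7B
print(config.use_lora, config.lora_r, config.lora_alpha)    # True 64 128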
generation_config.json
CHANGED
@@ -1,10 +1,14 @@
 {
-  "bos_token_id": …,
-  "eos_token_id": …,
-  … (remaining old generation settings are truncated in this view)
+  "bos_token_id": 151643,
+  "eos_token_id": [ 151645, 151643 ],
+  "length_penalty": 1.0,
+  "max_new_tokens": 256,
+  "no_repeat_ngram_size": 0,
+  "num_beams": 1,
+  "pad_token_id": 151643,
+  "repetition_penalty": 1.0,
+  "transformers_version": "5.0.0.dev0"
 }
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:…
-size …
+oid sha256:5f325deceeb565a0764abd09b46d51706bff2d643c0dd96b38070f246b0410de
+size 58732960
preprocessor_config.json
CHANGED
@@ -9,9 +9,9 @@
   "nb_max_frames": 3000,
   "padding_side": "right",
   "padding_value": 0.0,
-  "processor_class": "ASRProcessor",
   "return_attention_mask": false,
   "sampling_rate": 16000,
+  "processor_class": "ASRProcessor",
   "auto_map": {
     "AutoProcessor": "asr_processing.ASRProcessor"
   }
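The preprocessor still expects 16 kHz mono input (sampling_rate above), so audio generally needs to be resampled before it reaches the feature extractor. A small hedged sketch using torchaudio; the file name is a placeholder, and the 10 ms hop implied by nb_max_frames=3000 is an assumption rather than something stated in the config.

import torchaudio

wav, sr = torchaudio.load("sample.wav")    # hypothetical input file
wav = wav.mean(dim=0, keepdim=True)        # downmix to mono
if sr != 16000:
    wav = torchaudio.functional.resample(wav, sr, 16000)
# nb_max_frames=3000 caps the log-mel features (about 30 s if the hop is 10 ms, an assumption).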
projectors.py
CHANGED
-… (removed in this commit, largely truncated in this view: the old pooling-based MLPAudioProjector
-   and its Conv1d weight-init helper; the dropout-carrying SimpleAdapter; the previous dense-MoE
-   projector built on it, with a 512-wide router, explicit std=0.02 init loops, an ln_post norm,
-   a per-expert weighted-sum loop, a get_aux_loss() that returned 0.0, and a pointer to
-   arXiv:2508.18998; the SwiGLU and SwiGLUAudioProjector classes plus the
-   "AudioProjector = SwiGLUAudioProjector" alias; the ResidualMLP and ResidualAudioProjector
-   classes with their LlamaRMSNorm stacks; the previous routing and weight-normalization lines in
-   SharedMoEBlock.forward; the top-k one-hot formulation of load_balancing_loss; and the "swiglu"
-   and "residual" entries of PROJECTOR_CLASSES.)
The new text of the changed regions follows, hunk by hunk:
@@ -1,16 +1,18 @@
"""Audio projector modules for bridging encoder and decoder embeddings.

This module contains all projector architectures:
- MLPAudioProjector: Simple 2-layer MLP with frame stacking downsampling
- MOSAProjector: MOSA-style dense mixture of experts
- SharedMoEAudioProjector: Shared expert + sparse routed experts
- QFormerAudioProjector: BLIP-2 QFormer with learnable queries (Granite-style)
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812
from transformers import AutoModel, Blip2QFormerConfig
from transformers.models.llama.modeling_llama import LlamaRMSNorm

# =============================================================================
@@ -19,40 +21,36 @@


class MLPAudioProjector(nn.Module):
    """2-layer MLP projector with frame-stacking downsampling (matches GLM-ASR)."""

    def __init__(self, config):
        super().__init__()

        encoder_dim = getattr(config, "encoder_dim", 768)
        llm_dim = getattr(config, "llm_dim", 2048)
        self.k = getattr(config, "projector_pool_stride", 4)

        # Frame stacking: concat k adjacent frames then project
        # Matches GLM-ASR: in_dim -> 2*llm_dim -> llm_dim
        in_dim = encoder_dim * self.k
        hidden_dim = llm_dim * 2
        self.linear_1 = nn.Linear(in_dim, hidden_dim)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(hidden_dim, llm_dim)

    def get_output_length(self, input_length: int) -> int:
        """Calculate output sequence length given input length."""
        return input_length // self.k

    def forward(self, x):
        """
        x: [Batch, Seq_Len, Dim]
        Returns: [Batch, Seq_Len // k, llm_dim]
        """
        batch, seq, dim = x.shape
        # Reshape to combine k frames: [B, S, D] -> [B, -1, D*k]
        # -1 infers sequence length, implicitly downsampling by factor k
        x = x.reshape(batch, -1, dim * self.k)

        x = self.linear_1(x)
        x = self.act(x)
@@ -65,291 +63,146 @@


class SimpleAdapter(nn.Module):
    """Simple 2-layer ReLU adapter (from MOSA paper)."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))


class SwiGLUExpert(nn.Module):
    """SwiGLU expert (gated MLP with SiLU activation)."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(input_dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, output_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class MOSAProjector(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder_dim = getattr(config, "encoder_dim", None) or 1280
        self.llm_dim = getattr(config, "llm_dim", None) or 2048
        self.num_experts = getattr(config, "num_experts", None) or 8
        adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096

        # Auxiliary loss coefficients (MOSA paper uses only cross-entropy, no aux losses)
        self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.0)
        self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.0)

        # Store router state for aux loss computation
        self.last_router_logits = None
        self.last_routing_weights = None

        # --- 1. Pre-Norms (CRITICAL for stability) ---
        self.in_norm = LlamaRMSNorm(self.encoder_dim, eps=1e-8)

        # --- 2. Convolutional Subsampling (Stride 4) ---
        self.conv = nn.Sequential(
            nn.Conv1d(self.encoder_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv1d(self.llm_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
        )

        # --- 3. Deep Router (ReLU per MOSA paper) ---
        self.router = nn.Sequential(
            nn.Linear(self.encoder_dim, 2560),
            nn.ReLU(),
            nn.Linear(2560, 5120),
            nn.ReLU(),
            nn.Linear(5120, 2560),
            nn.ReLU(),
            nn.Linear(2560, 1280),
            nn.ReLU(),
            nn.Linear(1280, self.num_experts),
        )

        # --- 4. Experts (Simple 2-layer ReLU adapters per MOSA paper) ---
        self.experts = nn.ModuleList(
            [
                SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim)
                for _ in range(self.num_experts)
            ]
        )

        # --- 5. Output Norm ---
        # Projects often drift in magnitude; this clamps them before the LLM.
        self.out_norm = LlamaRMSNorm(self.llm_dim, eps=1e-8)

        # Using PyTorch default initialization (like MOSA paper)

    def forward(self, x):
        # x: (B, S, 1280)
        batch_size, seq_len, _ = x.shape

        # Apply Input Norm
        x = self.in_norm(x)

        # --- 1. Conv Branch ---
        x_trans = x.permute(0, 2, 1)  # (B, D, S)
        h_conv = self.conv(x_trans).permute(0, 2, 1)  # (B, S//4, llm_dim)

        # --- 2. Router Branch ---
        pad_amt = (4 - (seq_len % 4)) % 4
        x_padded = F.pad(x, (0, 0, 0, pad_amt)) if pad_amt > 0 else x

        # Mean pool to align receptive fields
        x_pooled = x_padded.view(batch_size, -1, 4, self.encoder_dim).mean(dim=2)  # (B, S//4, D)

        # Router Logits
        router_logits = self.router(x_pooled)  # (B, S//4, num_experts)

        # Softmax for Dense MoE (Soft Mixing)
        routing_weights = F.softmax(router_logits, dim=-1)

        # Store for aux loss computation
        self.last_router_logits = router_logits
        self.last_routing_weights = routing_weights

        # --- 3. Expert Mixture (Dense Execution) ---
        # Warning: High VRAM usage. Runs all experts.
        # h_conv: (B, S//4, llm_dim)

        # Stack approach is clean but memory hungry.
        # Checkpointing could be added here if OOM occurs.
        expert_outputs = torch.stack([expert(h_conv) for expert in self.experts])  # (E, B, S//4, D)

        # Weighted Sum
        # (Experts, Batch, Seq, Dim) * (Batch, Seq, Experts) -> (Batch, Seq, Dim)
        final_out = torch.einsum("ebsd, bse -> bsd", expert_outputs, routing_weights)

        return self.out_norm(final_out)

    def get_output_length(self, input_length: int) -> int:
        """Calculate output sequence length given input length."""
        # Two conv layers with stride=2 each = stride 4 total
        padded = input_length + (4 - input_length % 4) % 4
        return padded // 4

    def get_aux_loss(self) -> torch.Tensor:
        """Compute auxiliary losses: load balancing + z-loss."""
        if self.last_router_logits is None:
            return torch.tensor(0.0, device=self.conv[0].weight.device)

        # Flatten for loss computation: (B, S, E) -> (B*S, E)
        logits_flat = self.last_router_logits.view(-1, self.num_experts)
        probs_flat = self.last_routing_weights.view(-1, self.num_experts)

        balance = load_balancing_loss(probs_flat, self.num_experts, top_k=self.num_experts)
        z = z_loss(logits_flat)

        return self.aux_loss_coef * balance + self.z_loss_coef * z


# =============================================================================
@@ -357,22 +210,8 @@
# =============================================================================


class SharedMoEBlock(nn.Module):
    """MoE block with Shared + Sigmoid-Routed Experts."""

    def __init__(
        self,
@@ -387,8 +226,11 @@
        self.top_k = top_k
        self.output_dim = output_dim

        # RMSNorm before routing
        self.norm = LlamaRMSNorm(input_dim, eps=1e-8)

        self.router = nn.Linear(input_dim, num_experts, bias=False)
        nn.init.normal_(self.router.weight, mean=0.0, std=0.02)

        self.shared_expert = SwiGLUExpert(input_dim, hidden_dim, output_dim)
        self.experts = nn.ModuleList(
@@ -401,19 +243,28 @@
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, dim = hidden_states.shape

        # 1. Apply Shared Expert
        normed_states = self.norm(hidden_states)
        shared_out = self.shared_expert(normed_states)

        # 2. Router Logic (Sigmoid Style)
        flat_hidden = normed_states.view(-1, dim)
        router_logits = self.router(flat_hidden)

        # Sigmoid routing
        router_probs = torch.sigmoid(router_logits)

        self.last_router_logits = router_logits
        self.last_router_probs = router_probs

        # 3. Top-K Selection
        top_k_scores, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)

        # Normalize weights
        top_k_weights = top_k_scores / (top_k_scores.sum(dim=-1, keepdim=True) + 1e-6)
        top_k_weights = top_k_weights.to(hidden_states.dtype)

        # 4. Dispatch
        routed_out = self._dispatch_experts(flat_hidden, top_k_indices, top_k_weights)
        routed_out = routed_out.view(batch_size, seq_len, -1)

@@ -437,7 +288,7 @@

        token_indices, slot_indices = torch.where(expert_mask)
        expert_input = hidden_states[token_indices]
        expert_output = expert(expert_input).to(output.dtype)
        weights = top_k_weights[token_indices, slot_indices].unsqueeze(-1)
        output.index_add_(0, token_indices, expert_output * weights)

@@ -446,11 +297,9 @@

def load_balancing_loss(router_probs: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
    """Auxiliary loss to encourage balanced expert usage."""
    prob_per_expert = router_probs.mean(dim=0)
    target_mean = prob_per_expert.mean()
    return (prob_per_expert - target_mean).square().sum() * num_experts


def z_loss(router_logits: torch.Tensor) -> torch.Tensor:
@@ -465,8 +314,13 @@
        super().__init__()

        self.k = getattr(config, "projector_pool_stride", 4)
        encoder_dim = config.encoder_dim

        # Depthwise Conv for temporal mixing
        self.temporal_conv = nn.Conv1d(
            encoder_dim, encoder_dim, kernel_size=3, padding=1, groups=encoder_dim
        )

        in_dim = encoder_dim * self.k
        out_dim = config.llm_dim
        hidden_dim = getattr(config, "projector_hidden_dim", None) or in_dim
@@ -477,9 +331,9 @@
        self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.001)

        self.moe = SharedMoEBlock(in_dim, hidden_dim, out_dim, self.num_experts, self.top_k)
        self._init_weights()

    def _init_weights(self):
        with torch.no_grad():
            nn.init.orthogonal_(self.moe.shared_expert.gate_proj.weight)
            nn.init.orthogonal_(self.moe.shared_expert.up_proj.weight)
@@ -490,6 +344,13 @@
            nn.init.orthogonal_(expert.up_proj.weight)
            nn.init.orthogonal_(expert.down_proj.weight, gain=0.01)

    def get_output_length(self, input_length: int) -> int:
        """Calculate output sequence length given input length."""
        # Temporal pooling with stride k
        if input_length % self.k:
            input_length += self.k - input_length % self.k
        return input_length // self.k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, dim = x.size()

@@ -497,6 +358,11 @@
        if x.dtype != target_dtype:
            x = x.to(target_dtype)

        # Temporal Context Injection
        x_ctx = x.transpose(1, 2)
        x_ctx = self.temporal_conv(x_ctx)
        x = x + x_ctx.transpose(1, 2)

        if seq_len % self.k:
            x = F.pad(x, (0, 0, 0, self.k - seq_len % self.k))

@@ -514,14 +380,129 @@
        return self.aux_loss_coef * balance + self.z_loss_coef * z


# =============================================================================
# QFormer Projector (Granite-style)
# =============================================================================


class QFormerAudioProjector(nn.Module):
    """
    BLIP-2 QFormer projector with learnable queries.

    Based on GraniteSpeechEncoderProjector - uses a QFormer model with learnable
    query embeddings to compress and project audio encoder outputs. The audio
    sequence is processed in windows and downsampled via cross-attention.
    """

    def __init__(self, config):
        super().__init__()

        encoder_dim = config.encoder_dim
        llm_dim = config.llm_dim

        # Window and downsampling parameters (Granite defaults: window=15, downsample=5)
        self.window_size = getattr(config, "qformer_window_size", 15)
        self.downsample_rate = getattr(config, "downsample_rate", 5)
        self.num_queries = self.window_size // self.downsample_rate

        # QFormer hidden size (matches encoder for cross-attention)
        qformer_hidden = getattr(config, "qformer_hidden_size", None) or encoder_dim
        qformer_num_layers = getattr(config, "qformer_num_layers", 2)
        qformer_num_heads = getattr(config, "qformer_num_heads", 16)
        qformer_intermediate = getattr(config, "qformer_intermediate_size", None) or (
            qformer_hidden * 4
        )

        # Learnable query embeddings (Granite uses std=1.0)
        self.query = nn.Parameter(torch.zeros(1, self.num_queries, qformer_hidden))
        self.query.data.normal_(mean=0.0, std=1.0)

        # Optional projection if encoder dim != qformer hidden
        if encoder_dim != qformer_hidden:
            self.encoder_proj = nn.Linear(encoder_dim, qformer_hidden, bias=False)
        else:
            self.encoder_proj = None

        # Configure QFormer to match Granite's exact config
        qformer_config = Blip2QFormerConfig(
            hidden_size=qformer_hidden,
            num_hidden_layers=qformer_num_layers,
            num_attention_heads=qformer_num_heads,
            intermediate_size=qformer_intermediate,
            encoder_hidden_size=qformer_hidden,
            cross_attention_frequency=1,
            # Granite-specific settings
            hidden_act="gelu",
            attention_probs_dropout_prob=0.1,
            hidden_dropout_prob=0.1,
            layer_norm_eps=1e-12,
            initializer_range=0.02,
        )
        self.qformer = AutoModel.from_config(qformer_config)

        # Final projection to LLM dimension (Granite uses bias=True)
        self.linear = nn.Linear(qformer_hidden, llm_dim)

    def get_output_length(self, input_length: int) -> int:
        """Calculate output sequence length given input length."""
        # QFormer uses window-based processing with num_queries per window
        nblocks = math.ceil(input_length / self.window_size)
        return nblocks * self.num_queries

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: [batch_size, seq_len, encoder_dim]

        Returns:
            projected: [batch_size, num_output_tokens, llm_dim]
        """
        batch_size, seq_len, dim = hidden_states.size()

        # Ensure float dtype for QFormer
        target_dtype = self.query.dtype
        if hidden_states.dtype != target_dtype:
            hidden_states = hidden_states.to(target_dtype)

        # Optional encoder projection
        if self.encoder_proj is not None:
            hidden_states = self.encoder_proj(hidden_states)

        # Compute number of windows and pad to fit
        nblocks = math.ceil(seq_len / self.window_size)
        pad = nblocks * self.window_size - seq_len
        if pad > 0:
            hidden_states = F.pad(hidden_states, (0, 0, 0, pad), "constant", 0)

        # Reshape to process each window: [batch*nblocks, window_size, dim]
        effective_batch = batch_size * nblocks
        hidden_states = hidden_states.view(effective_batch, self.window_size, -1)

        # Expand queries to match batch size
        query_embeds = self.query.expand(effective_batch, -1, -1)

        # QFormer cross-attention
        query_output = self.qformer(
            query_embeds=query_embeds,
            encoder_hidden_states=hidden_states,
            return_dict=True,
        )

        # Reshape back: [batch, nblocks * num_queries, hidden]
        output_tokens = nblocks * self.num_queries
        query_proj = query_output.last_hidden_state.view(batch_size, output_tokens, -1)

        # Project to LLM dimension
        return self.linear(query_proj)


# =============================================================================
# Projector Registry
# =============================================================================

PROJECTOR_CLASSES = {
    "mlp": MLPAudioProjector,
    "mosa": MOSAProjector,
    "shared_moe": SharedMoEAudioProjector,
    "qformer": QFormerAudioProjector,
}
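The get_output_length helpers above make it easy to compare how many audio tokens each projector hands to the LLM. A small worked example, assuming a 1500-frame encoder output (the maximum the audio_config above allows); the arithmetic follows directly from the code and the defaults shown, nothing else is repo-specific.

import math

frames = 1500  # hypothetical encoder output length

# MLPAudioProjector: frame stacking with k = projector_pool_stride = 4
mlp_tokens = frames // 4                              # 375

# MOSAProjector: two stride-2 convolutions, i.e. overall stride 4 after padding
mosa_tokens = (frames + (4 - frames % 4) % 4) // 4    # 375

# QFormerAudioProjector: window_size=15, downsample_rate=5 -> 3 queries per window
qformer_tokens = math.ceil(frames / 15) * (15 // 5)   # 100 windows * 3 = 300

print(mlp_tokens, mosa_tokens, qformer_tokens)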
tokenizer.json
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:…
-size …
+oid sha256:33b674fb8444e2553eae8f1b261093371920a28ef75b5c18f4deb3f9217ed0ba
+size 11422834
tokenizer_config.json
CHANGED
Binary files a/tokenizer_config.json and b/tokenizer_config.json differ