# Step-Audio-R1.1 / configuration_step_audio_2.py
from typing import Optional, Union

from transformers import Qwen2Config
from transformers.configuration_utils import PretrainedConfig


class StepAudio2EncoderConfig(PretrainedConfig):
    """Configuration for the audio encoder and the adapter that projects its output to the LLM hidden size (`llm_dim`)."""

    model_type = "step_audio_2_encoder"

    def __init__(
        self,
        n_mels=128,
        n_audio_ctx=1500,
        n_audio_state=512,
        n_audio_head=8,
        n_audio_layer=6,
        llm_dim=4096,
        kernel_size=3,
        adapter_stride=2,
        **kwargs,
    ):
        self.n_mels = n_mels
        self.n_audio_ctx = n_audio_ctx
        self.n_audio_state = n_audio_state
        self.n_audio_head = n_audio_head
        self.n_audio_layer = n_audio_layer
        self.llm_dim = llm_dim
        self.kernel_size = kernel_size
        self.adapter_stride = adapter_stride
        super().__init__(**kwargs)
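
# Illustrative note (not part of the original file): the encoder config is a plain
# PretrainedConfig, so per-checkpoint overrides can be passed as keyword arguments;
# the values below are made up for demonstration.
#
#   enc = StepAudio2EncoderConfig(n_audio_layer=32, llm_dim=5120)
#   enc.kernel_size  # -> 3 (default)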


class StepAudio2TextConfig(PretrainedConfig):
    """Configuration for the text (LLM) backbone, mirrored into a nested `Qwen2Config`."""

    model_type = "step_audio_2_text"

    def __init__(
        self,
        vocab_size=64012,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=48,
        num_attention_heads=32,
        num_attention_groups=4,
        num_key_value_heads=4,
        hidden_act="silu",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        rope_theta=1000000.0,
        rope_scaling=None,
        eos_token_id=None,
        **kwargs,
    ):
        # Always include the model's reserved EOS ids and merge in any user-supplied ones.
        if eos_token_id is not None:
            if isinstance(eos_token_id, list):
                eos_token_id = list(set([151643, 151645, 151665] + eos_token_id))
            else:
                eos_token_id = [151643, 151645, 151665, eos_token_id]
        else:
            eos_token_id = [151643, 151645, 151665]
        super().__init__(eos_token_id=eos_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_key_value_heads = num_key_value_heads
        assert self.num_attention_groups == self.num_key_value_heads, "num_attention_groups must be equal to num_key_value_heads"
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        # Get torch_dtype from kwargs if provided.
        torch_dtype = kwargs.get("torch_dtype", getattr(self, "torch_dtype", "bfloat16"))
        self.text_config = Qwen2Config(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            architectures=["Qwen2ForCausalLM"],
            torch_dtype=torch_dtype,
        )
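
# Illustrative note (not from the original file): the constructor always merges the
# reserved EOS ids with any user-supplied value, e.g.
#
#   cfg = StepAudio2TextConfig(eos_token_id=151666)
#   cfg.eos_token_id  # -> [151643, 151645, 151665, 151666]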


class StepAudio2Config(PretrainedConfig):
    """Top-level configuration combining the audio encoder config and the text backbone config."""

    model_type = "step_audio_2"
    architectures = ["StepAudio2ForCausalLM"]

    # Support alternative model types and architectures for the 32B model.
    # This allows the config to work with both "step_audio_2" and "step_audio_qwen2" model types.
    def __init__(
        self,
        audio_encoder_config: Optional[Union[dict, StepAudio2EncoderConfig]] = None,
        text_config: Optional[Union[dict, StepAudio2TextConfig]] = None,
        use_sliding_window: bool = False,
        sliding_window: Optional[int] = 2048,
        max_window_layers: Optional[int] = None,
        **kwargs,
    ):
        kwargs.setdefault("use_sliding_window", use_sliding_window)
        kwargs.setdefault("sliding_window", sliding_window)
        if max_window_layers is None:
            max_window_layers = kwargs.get("num_hidden_layers", None)
        kwargs.setdefault("max_window_layers", max_window_layers)
        # Save torch_dtype if provided (for the 32B model's flat config).
        if 'torch_dtype' in kwargs:
            self.torch_dtype = kwargs['torch_dtype']
        super().__init__(**kwargs)

        # Support the flat config structure (32B model format):
        # if text_config is None, extract the text parameters directly from kwargs.
        if text_config is None:
            # Collect flat config parameters (32B model format) if present.
            flat_text_params = {}
            text_param_names = [
                'vocab_size', 'hidden_size', 'intermediate_size', 'num_hidden_layers',
                'num_attention_heads', 'num_attention_groups', 'num_key_value_heads',
                'hidden_act', 'max_position_embeddings', 'initializer_range',
                'rms_norm_eps', 'rope_theta', 'rope_scaling', 'eos_token_id', 'pad_token_id',
            ]
            for param_name in text_param_names:
                if param_name in kwargs:
                    flat_text_params[param_name] = kwargs[param_name]
            # Set default hidden_act if not present (the 32B model config doesn't include it).
            if 'hidden_act' not in flat_text_params:
                flat_text_params['hidden_act'] = 'silu'
            # Set default initializer_range if not present.
            if 'initializer_range' not in flat_text_params:
                flat_text_params['initializer_range'] = 0.02
            # Also propagate torch_dtype.
            if 'torch_dtype' in kwargs:
                flat_text_params['torch_dtype'] = kwargs['torch_dtype']
            if flat_text_params:
                # We have a flat config; use it to build text_config.
                text_config = StepAudio2TextConfig(**flat_text_params).text_config
            else:
                # No flat config; fall back to the defaults.
                text_config = StepAudio2TextConfig().text_config
        elif isinstance(text_config, dict):
            text_config = StepAudio2TextConfig(**text_config).text_config
        self.text_config = text_config

        if audio_encoder_config is None:
            # Check if a flat audio_encoder_config was passed through kwargs.
            if 'audio_encoder_config' in kwargs and isinstance(kwargs['audio_encoder_config'], dict):
                self.audio_encoder_config = StepAudio2EncoderConfig(**kwargs['audio_encoder_config'])
            else:
                self.audio_encoder_config = StepAudio2EncoderConfig()
        elif isinstance(audio_encoder_config, dict):
            self.audio_encoder_config = StepAudio2EncoderConfig(**audio_encoder_config)
        else:
            # Already a StepAudio2EncoderConfig instance.
            self.audio_encoder_config = audio_encoder_config
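

if __name__ == "__main__":
    # Illustrative sketch only (not part of the original file): shows the two ways the
    # top-level config can be constructed. All values below are assumptions chosen for
    # demonstration, not the shipped checkpoint's settings.

    # 1) Nested form: explicit sub-configs passed as dicts.
    nested = StepAudio2Config(
        audio_encoder_config={"n_mels": 128, "llm_dim": 4096},
        text_config={"hidden_size": 4096, "num_hidden_layers": 48},
    )
    print(type(nested.text_config).__name__)            # Qwen2Config
    print(nested.audio_encoder_config.adapter_stride)   # 2 (encoder default)

    # 2) Flat form (32B-style): text parameters passed directly as keyword arguments.
    flat = StepAudio2Config(
        vocab_size=64012,
        hidden_size=4096,
        num_hidden_layers=48,
        num_attention_heads=32,
        num_key_value_heads=4,
        num_attention_groups=4,
        torch_dtype="bfloat16",
    )
    print(flat.text_config.hidden_size)  # 4096, rebuilt from the flat kwargs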