InteractiveOmni-8B / configuration_voicelm.py
tongww's picture
upload initial model
4cffcdc verified
# --------------------------------------------------------
# SenseTime
# Copyright (c) 2025 SenseTime
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import copy
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers import LlamaConfig, Qwen2Config
logger = logging.get_logger(__name__)
class VoiceLMConfig(PretrainedConfig):
    """Configuration for the VoiceLM model.

    Stores a handful of scalar hyperparameters, a dict of sampling parameters
    and a nested ``Qwen2Config`` (``self.llm_config``) for the backbone LLM.

    Args:
        llm_input_size: Stored as-is (presumably the hidden size fed into the
            LLM -- confirm against the model code).
        llm_output_size: Stored as-is (presumably the LLM output hidden size).
        speech_token_size: Stored as-is (presumably the speech-token
            vocabulary size -- confirm against the model code).
        length_normalized_loss: Stored as-is; a flag for the model's loss.
        lsm_weight: Stored as-is (presumably a label-smoothing weight).
        llm_config: A dict of ``Qwen2Config`` keyword arguments, an existing
            ``PretrainedConfig`` instance (converted via its ``to_dict()``),
            or ``None`` to use ``Qwen2Config`` defaults.
        sampling_config: Dict of sampling parameters. ``None`` (the default)
            yields a fresh ``{'top_p': 0.8, 'top_k': 25, 'win_size': 10,
            'tau_r': 0.1}`` for each instance.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        llm_input_size=896,
        llm_output_size=896,
        speech_token_size=6561,
        length_normalized_loss=True,
        lsm_weight=0,
        llm_config=None,
        sampling_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size
        self.length_normalized_loss = length_normalized_loss
        self.lsm_weight = lsm_weight

        # Build the default inside the call: a dict literal used as a default
        # argument is created once and would be shared by every instance.
        if sampling_config is None:
            sampling_config = {
                'top_p': 0.8,
                'top_k': 25,
                'win_size': 10,
                'tau_r': 0.1,
            }
        self.sampling_config = sampling_config

        if llm_config is None:
            llm_config = {}
            logger.info('llm_config is None. Initializing the Qwen2Config config with default values (`Qwen2Config`).')
        elif isinstance(llm_config, PretrainedConfig):
            # `**` unpacking needs a mapping; accept ready-made config objects too.
            llm_config = llm_config.to_dict()
        self.llm_config = Qwen2Config(**llm_config)

    def to_dict(self):
        """Serialize this instance to a Python dictionary.

        Overrides ``PretrainedConfig.to_dict`` so that the nested
        ``llm_config`` is emitted as a plain dict rather than a config object.

        Returns:
            `Dict[str, any]`: A deep copy of all instance attributes, with
            ``llm_config`` replaced by ``self.llm_config.to_dict()``.
        """
        # The deep copy already holds every scalar attribute and
        # `sampling_config`. Re-assigning them from `self` would put the live
        # `sampling_config` dict back into the output and defeat the copy.
        output = copy.deepcopy(self.__dict__)
        output['llm_config'] = self.llm_config.to_dict()
        return output