# --------------------------------------------------------
# SenseTime
# Copyright (c) 2025 SenseTime
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import copy
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers import Qwen2Config

logger = logging.get_logger(__name__)


class VoiceLMConfig(PretrainedConfig):
    def __init__(
        self,
        llm_input_size=896,
        llm_output_size=896,
        speech_token_size=6561,
        length_normalized_loss=True,
        lsm_weight=0,
        llm_config=None,
        sampling_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size
        self.length_normalized_loss = length_normalized_loss
        self.lsm_weight = lsm_weight
        # Avoid a mutable default argument: fall back to the default
        # sampling parameters only when none are supplied.
        if sampling_config is None:
            sampling_config = {
                'top_p': 0.8,
                'top_k': 25,
                'win_size': 10,
                'tau_r': 0.1,
            }
        self.sampling_config = sampling_config
        if llm_config is None:
            llm_config = {}
            logger.info('llm_config is None. Initializing the LLM config with default values (`Qwen2Config`).')
        self.llm_config = Qwen2Config(**llm_config)
    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        # `deepcopy(self.__dict__)` already carries every scalar attribute;
        # only the nested Qwen2Config needs explicit conversion to a plain dict.
        output = copy.deepcopy(self.__dict__)
        output['llm_config'] = self.llm_config.to_dict()
        return output
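

# --------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original
# module). It shows the config round-tripping through `to_dict`; the
# `hidden_size` override passed via `llm_config` is a hypothetical
# example value, not one taken from the original file.
# --------------------------------------------------------
if __name__ == "__main__":
    config = VoiceLMConfig(
        speech_token_size=6561,
        llm_config={'hidden_size': 896},
    )
    serialized = config.to_dict()
    assert isinstance(serialized['llm_config'], dict)
    assert serialized['speech_token_size'] == 6561
    print(serialized['sampling_config'])  # -> the default sampling parameters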