import copy

from transformers import Qwen2Config
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)

class VoiceLMConfig(PretrainedConfig):
    """Configuration for the VoiceLM speech language model.

    Bundles the speech-token vocabulary size, loss settings, and sampling
    defaults together with a nested `Qwen2Config` for the LLM backbone.
    """

    def __init__(
        self,
        llm_input_size=896,
        llm_output_size=896,
        speech_token_size=6561,
        length_normalized_loss=True,
        lsm_weight=0,
        llm_config=None,
        sampling_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size
        self.length_normalized_loss = length_normalized_loss
        self.lsm_weight = lsm_weight

        # Default to None rather than a dict literal so that one mutable
        # default is not shared across every instance of the class.
        if sampling_config is None:
            sampling_config = {
                'top_p': 0.8,
                'top_k': 25,
                'win_size': 10,
                'tau_r': 0.1,
            }
        self.sampling_config = sampling_config

        if llm_config is None:
            llm_config = {}
            logger.info('llm_config is None. Initializing the LLM backbone config with default values (`Qwen2Config`).')

        self.llm_config = Qwen2Config(**llm_config)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        # `copy.deepcopy` already carries over every plain attribute; only the
        # nested `Qwen2Config` needs explicit conversion so the result is
        # JSON-serializable.
        output['llm_config'] = self.llm_config.to_dict()
        return output
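
# --- Usage sketch (not part of the original module) ---------------------------
# A minimal example assuming only the class above and an installed
# `transformers` package. The backbone override values below are illustrative,
# not required by VoiceLMConfig itself.
if __name__ == '__main__':
    # Build a config, overriding the nested Qwen2 backbone via a plain dict.
    config = VoiceLMConfig(llm_config={'hidden_size': 896, 'num_hidden_layers': 24})
    assert config.llm_config.hidden_size == 896

    # Round-trip through to_dict(): the nested backbone config comes back as
    # a plain dict, ready for JSON serialization.
    serialized = config.to_dict()
    assert isinstance(serialized['llm_config'], dict)
    assert serialized['sampling_config']['top_k'] == 25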