# --------------------------------------------------------
# SenseTime
# Copyright (c) 2025 SenseTime
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import copy

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers import Qwen2Config

logger = logging.get_logger(__name__)

class VoiceLMConfig(PretrainedConfig):
    """Configuration for the VoiceLM speech language model.

    Stores the LLM input/output projection sizes, the speech token vocabulary
    size, loss settings (length normalization, label smoothing weight), the
    speech-token sampling parameters, and the backbone LLM config (`Qwen2Config`).
    """

    def __init__(
            self,
            llm_input_size=896,
            llm_output_size=896,
            speech_token_size=6561,
            length_normalized_loss=True,
            lsm_weight=0,
            llm_config=None,
            sampling_config=None,
            **kwargs):
        super().__init__(**kwargs)

        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size
        self.length_normalized_loss = length_normalized_loss
        self.lsm_weight = lsm_weight

        # Avoid a shared mutable default: fall back to the default sampling
        # parameters only when the caller does not supply any.
        if sampling_config is None:
            sampling_config = {
                'top_p': 0.8,
                'top_k': 25,
                'win_size': 10,
                'tau_r': 0.1,
            }
        self.sampling_config = sampling_config

        if llm_config is None:
            llm_config = {}
            logger.info('llm_config is None. Initializing the backbone LLM config with default values (`Qwen2Config`).')

        self.llm_config = Qwen2Config(**llm_config)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        # The nested backbone config is a `Qwen2Config` object and must be
        # serialized explicitly so the resulting dictionary is JSON-compatible;
        # all other attributes are plain values already captured by the deepcopy.
        output['llm_config'] = self.llm_config.to_dict()

        return output
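

# A minimal usage sketch: constructing the config with a custom backbone and
# round-tripping it through `to_dict`. The backbone field values shown here
# (hidden_size, num_hidden_layers) are illustrative assumptions, passed through
# as standard `Qwen2Config` keyword arguments.
if __name__ == '__main__':
    config = VoiceLMConfig(
        speech_token_size=6561,
        llm_config={'hidden_size': 896, 'num_hidden_layers': 24},
        sampling_config={'top_p': 0.8, 'top_k': 25, 'win_size': 10, 'tau_r': 0.1},
    )
    serialized = config.to_dict()
    print(serialized['llm_config']['hidden_size'])  # 896
    print(serialized['sampling_config'])            # {'top_p': 0.8, 'top_k': 25, 'win_size': 10, 'tau_r': 0.1}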