|
|
from transformers import MistralConfig |
|
|
from transformers.utils import logging |
|
|
|
|
|
# Module-level logger, namespaced to this module (HF transformers convention).
logger = logging.get_logger(__name__)
|
|
|
|
|
# Default hyperparameters for the SSM (state-space model) blocks.
# AprielHConfig falls back to this mapping when no `ssm_cfg` is supplied,
# and uses it to fill in any keys missing from a supplied one.
ssm_config_default = dict(
    d_state=64,
    n_qk_heads=32,
    expand=1,
    chunk_size=128,
    activation="identity",
    bias=False,
    d_conv=4,
    d_inner=32 * 128,  # presumably n_qk_heads * d_state-related sizing — confirm against the SSM layer
    d_xb=None,
    dt_rank="auto",
    dt_min=0.001,
    dt_max=0.1,
    dt_init="random",
    dt_scale=1.0,
    dt_init_floor=1e-4,
    conv_bias=True,
)
|
|
|
|
|
|
|
|
class AprielHConfig(MistralConfig):
    """Configuration for the Apriel-H hybrid attention/SSM model.

    Extends :class:`MistralConfig` with a per-layer hybrid block layout and
    a dictionary of SSM hyperparameters (defaults in ``ssm_config_default``).
    """

    model_type = "apriel_h"

    def __init__(self, hybrid_block_layout=None, ssm_cfg=None, **kwargs):
        """
        Args:
            hybrid_block_layout: per-layer block type codes (e.g. ``"m2"``).
                Defaults to ``["m2"]``.
            ssm_cfg: optional dict of SSM hyperparameters; any keys missing
                from it are filled in from ``ssm_config_default``.
            **kwargs: forwarded to ``MistralConfig.__init__``.
        """
        super().__init__(**kwargs)
        # Default to a fresh list per instance; a mutable default argument
        # (the original ["m2"]) would be shared by every config constructed
        # without the argument.
        self.hybrid_block_layout = ["m2"] if hybrid_block_layout is None else hybrid_block_layout
        # MistralConfig may leave head_dim as None; derive it when unset.
        self.head_dim = self.head_dim or self.hidden_size // self.num_attention_heads
        # Copy the supplied (or default) cfg instead of aliasing it: the
        # original code stored the module-level default dict itself (and
        # mutated the caller's dict when one was passed), so edits to one
        # config's ssm_cfg leaked into every other instance.
        self.ssm_cfg = dict(ssm_cfg) if ssm_cfg else {}
        # Fill in any hyperparameters the caller did not specify.
        for k, v in ssm_config_default.items():
            self.ssm_cfg.setdefault(k, v)
|
|
|