# Apriel-H1-15b-Thinker-SFT / configuration_apriel_h.py
# nitsanluke's picture
# initial commit
# 501730b
from transformers import MistralConfig
from transformers.utils import logging
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)
# Fallback SSM settings. AprielHConfig copies any key missing from a
# user-supplied ``ssm_cfg`` out of this table, so every key listed here is
# always present on a constructed config.
ssm_config_default = dict(
    d_state=64,
    n_qk_heads=32,
    expand=1,
    chunk_size=128,
    activation="identity",
    bias=False,
    d_conv=4,
    d_inner=32 * 128,
    d_xb=None,  # will be set to model dim
    dt_rank="auto",
    dt_min=0.001,
    dt_max=0.1,
    dt_init="random",
    dt_scale=1.0,
    dt_init_floor=1e-4,
    conv_bias=True,
)
class AprielHConfig(MistralConfig):
    """Configuration for the ``apriel_h`` model type, built on ``MistralConfig``.

    On top of the Mistral settings this stores a hybrid block layout and a
    complete SSM settings dict (``ssm_cfg``).

    Args:
        hybrid_block_layout: List of block-type identifiers; defaults to
            ``["m2"]``. (Presumably one entry per decoder block -- confirm
            against the modeling code.)
        ssm_cfg: Optional dict of SSM settings. Any key from
            ``ssm_config_default`` that is absent is filled in with its
            default value. ``None`` or an empty dict yields the defaults.
        **kwargs: Forwarded unchanged to ``MistralConfig.__init__``.
    """

    model_type = "apriel_h"

    def __init__(self, hybrid_block_layout=None, ssm_cfg=None, **kwargs):
        super().__init__(**kwargs)

        # Build the default list per instance: a list literal in the signature
        # would be a single object shared (and mutable) across all configs.
        self.hybrid_block_layout = ["m2"] if hybrid_block_layout is None else hybrid_block_layout

        # As in transformers 4.51.3: fall back to the per-head width derived
        # from the hidden size. ``getattr`` keeps this working if the parent
        # class did not define ``head_dim`` at all.
        self.head_dim = getattr(self, "head_dim", None) or self.hidden_size // self.num_attention_heads

        # Work on a private copy so that neither the caller's dict nor the
        # module-level ``ssm_config_default`` is aliased or mutated by this
        # instance. User-supplied keys keep their position; missing defaults
        # are appended afterwards.
        merged_ssm_cfg = dict(ssm_cfg) if ssm_cfg else {}
        for key, default_value in ssm_config_default.items():
            merged_ssm_cfg.setdefault(key, default_value)
        self.ssm_cfg = merged_ssm_cfg