kaizen9 commited on
Commit
6250c70
·
verified ·
1 Parent(s): 301552e

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
__pycache__/modeling_bailing_moe_v2.cpython-310.pyc ADDED
Binary file (30.7 kB). View file
 
added_tokens.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</think>": 151668,
3
+ "</tool_call>": 151658,
4
+ "</tool_response>": 151666,
5
+ "<think>": 151667,
6
+ "<tool_call>": 151657,
7
+ "<tool_response>": 151665,
8
+ "<|box_end|>": 151649,
9
+ "<|box_start|>": 151648,
10
+ "<|endoftext|>": 151643,
11
+ "<|file_sep|>": 151664,
12
+ "<|fim_middle|>": 151660,
13
+ "<|fim_pad|>": 151662,
14
+ "<|fim_prefix|>": 151659,
15
+ "<|fim_suffix|>": 151661,
16
+ "<|im_end|>": 151645,
17
+ "<|im_start|>": 151644,
18
+ "<|image_pad|>": 151655,
19
+ "<|object_ref_end|>": 151647,
20
+ "<|object_ref_start|>": 151646,
21
+ "<|quad_end|>": 151651,
22
+ "<|quad_start|>": 151650,
23
+ "<|repo_name|>": 151663,
24
+ "<|video_pad|>": 151656,
25
+ "<|vision_end|>": 151653,
26
+ "<|vision_pad|>": 151654,
27
+ "<|vision_start|>": 151652
28
+ }
chat_template.jinja ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {{- messages[0].content + '\n\n' }}
5
+ {%- endif %}
6
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
+ {%- for tool in tools %}
8
+ {{- "\n" }}
9
+ {{- tool | tojson }}
10
+ {%- endfor %}
11
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
+ {%- else %}
13
+ {%- if messages[0].role == 'system' %}
14
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
+ {%- endif %}
16
+ {%- endif %}
17
+ {%- for message in messages %}
18
+ {%- if message.content is string %}
19
+ {%- set content = message.content %}
20
+ {%- else %}
21
+ {%- set content = '' %}
22
+ {%- endif %}
23
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
24
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
25
+ {%- elif message.role == "assistant" %}
26
+ {{- '<|im_start|>' + message.role + '\n' + content }}
27
+ {%- if message.tool_calls %}
28
+ {%- for tool_call in message.tool_calls %}
29
+ {%- if (loop.first and content) or (not loop.first) %}
30
+ {{- '\n' }}
31
+ {%- endif %}
32
+ {%- if tool_call.function %}
33
+ {%- set tool_call = tool_call.function %}
34
+ {%- endif %}
35
+ {{- '<tool_call>\n{"name": "' }}
36
+ {{- tool_call.name }}
37
+ {{- '", "arguments": ' }}
38
+ {%- if tool_call.arguments is string %}
39
+ {{- tool_call.arguments }}
40
+ {%- else %}
41
+ {{- tool_call.arguments | tojson }}
42
+ {%- endif %}
43
+ {{- '}\n</tool_call>' }}
44
+ {%- endfor %}
45
+ {%- endif %}
46
+ {{- '<|im_end|>\n' }}
47
+ {%- elif message.role == "tool" %}
48
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
49
+ {{- '<|im_start|>user' }}
50
+ {%- endif %}
51
+ {{- '\n<tool_response>\n' }}
52
+ {{- content }}
53
+ {{- '\n</tool_response>' }}
54
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
55
+ {{- '<|im_end|>\n' }}
56
+ {%- endif %}
57
+ {%- endif %}
58
+ {%- endfor %}
59
+ {%- if add_generation_prompt %}
60
+ {{- '<|im_start|>assistant\n' }}
61
+ {%- endif %}
config.json ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_moe_implementation": "fused",
3
+ "architectures": [
4
+ "BailingMoeV2ForCausalLM"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_bailing_moe_v2.BailingMoeV2Config",
9
+ "AutoModel": "modeling_bailing_moe_v2.BailingMoeV2Model",
10
+ "AutoModelForCausalLM": "modeling_bailing_moe_v2.BailingMoeV2ForCausalLM"
11
+ },
12
+ "bos_token_id": 151643,
13
+ "dtype": "bfloat16",
14
+ "embedding_dropout": 0.0,
15
+ "eos_token_id": 151645,
16
+ "first_k_dense_replace": 0,
17
+ "hc": false,
18
+ "head_dim": 128,
19
+ "hidden_act": "silu",
20
+ "hidden_size": 2048,
21
+ "initializer_range": 0.02,
22
+ "intermediate_size": 5120,
23
+ "layer_types": [
24
+ "sliding_attention",
25
+ "sliding_attention",
26
+ "sliding_attention",
27
+ "sliding_attention",
28
+ "sliding_attention",
29
+ "sliding_attention",
30
+ "sliding_attention",
31
+ "full_attention",
32
+ "sliding_attention",
33
+ "sliding_attention",
34
+ "sliding_attention",
35
+ "full_attention",
36
+ "sliding_attention",
37
+ "sliding_attention",
38
+ "full_attention",
39
+ "sliding_attention",
40
+ "sliding_attention",
41
+ "full_attention",
42
+ "sliding_attention",
43
+ "sliding_attention",
44
+ "sliding_attention",
45
+ "full_attention",
46
+ "sliding_attention",
47
+ "sliding_attention",
48
+ "sliding_attention",
49
+ "sliding_attention",
50
+ "sliding_attention",
51
+ "sliding_attention",
52
+ "sliding_attention",
53
+ "sliding_attention"
54
+ ],
55
+ "max_position_embeddings": 32768,
56
+ "max_window_layers": 30,
57
+ "moe_intermediate_size": 512,
58
+ "moe_router_enable_expert_bias": true,
59
+ "moe_shared_expert_intermediate_size": 512,
60
+ "mtp_loss_scaling_factor": 0,
61
+ "n_group": 8,
62
+ "norm_topk_prob": true,
63
+ "num_attention_heads": 16,
64
+ "num_experts": 224,
65
+ "num_experts_per_tok": 8,
66
+ "num_hidden_layers": 30,
67
+ "num_key_value_heads": 4,
68
+ "num_nextn_predict_layers": 0,
69
+ "num_shared_experts": 0,
70
+ "output_dropout": 0.0,
71
+ "output_router_logits": true,
72
+ "pad_token_id": null,
73
+ "partial_rotary_factor": 0.5,
74
+ "pruning_info": {
75
+ "original_experts": 256,
76
+ "original_model_path": "5kling-fuse_heal",
77
+ "pruned_experts": 224,
78
+ "pruning_date": "2026-01-16T05:26:11.661656",
79
+ "pruning_method": "MoP"
80
+ },
81
+ "quantize": false,
82
+ "rms_norm_eps": 1e-06,
83
+ "rope_scaling": null,
84
+ "rope_theta": 10000,
85
+ "routed_scaling_factor": 2.5,
86
+ "router_dtype": "fp32",
87
+ "score_function": "sigmoid",
88
+ "sliding_window": 512,
89
+ "tie_word_embeddings": false,
90
+ "topk_group": 4,
91
+ "transformers_version": "4.57.1",
92
+ "use_bias": false,
93
+ "use_cache": true,
94
+ "use_qk_norm": true,
95
+ "use_qkv_bias": false,
96
+ "use_rmsnorm": true,
97
+ "vocab_size": 151936
98
+ }
configuration_bailing_moe_v2.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Bailing MoE V2 model configuration"""
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+
5
+
6
class BailingMoeV2Config(PretrainedConfig):
    """Configuration for the Bailing MoE V2 model family.

    Holds the transformer geometry (hidden size, layer/head counts), the
    MoE routing parameters (expert counts, grouped top-k settings), sliding
    window attention layout, and optional ternary-quantization flag.

    NOTE(review): no `model_type` class attribute is declared here; confirm
    whether AutoConfig registration relies on `auto_map` alone.
    """

    def __init__(
        self,
        vocab_size=157184,
        hidden_size=2048,
        intermediate_size=5120,
        num_hidden_layers=20,
        num_attention_heads=16,
        num_key_value_heads=4,
        hidden_act="silu",
        use_qkv_bias=False,  # bailing only
        use_bias=False,  # bailing only
        rms_norm_eps=1e-06,
        tie_word_embeddings=False,  # PretrainedConfig key, here change default value.
        embedding_dropout=0.0,
        attention_dropout=0.0,
        output_dropout=0.0,
        initializer_range=0.02,
        max_position_embeddings=32768,
        rope_theta=600000.0,
        use_cache=True,
        max_window_layers=20,
        rope_scaling=None,
        pad_token_id=156892,
        eos_token_id=156892,
        num_experts=256,
        num_shared_experts=1,
        num_experts_per_tok=8,
        n_group=8,
        topk_group=4,
        moe_intermediate_size=512,
        first_k_dense_replace=1,
        head_dim=128,
        output_router_logits=False,
        use_qk_norm=True,
        num_nextn_predict_layers=0,
        mtp_loss_scaling_factor=0,
        moe_router_enable_expert_bias=True,
        routed_scaling_factor=1.0,
        layer_types=None,
        sliding_window=256,
        hc_expand=1,
        quantize=False,
        **kwargs,
    ):
        # --- Core transformer geometry ---
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.use_qkv_bias = use_qkv_bias
        self.use_bias = use_bias
        self.rms_norm_eps = rms_norm_eps
        self.embedding_dropout = embedding_dropout
        self.attention_dropout = attention_dropout
        self.output_dropout = output_dropout
        # Multi-token prediction (MTP) settings; 0 disables the extra heads.
        self.num_nextn_predict_layers = num_nextn_predict_layers
        self.mtp_loss_scaling_factor = mtp_loss_scaling_factor
        self.initializer_range = initializer_range
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.use_cache = use_cache
        self.max_window_layers = max_window_layers
        # Falls back to hidden_size // num_attention_heads when head_dim is falsy.
        self.head_dim = head_dim or self.hidden_size // self.num_attention_heads
        self.rope_scaling = rope_scaling
        self.use_qk_norm = use_qk_norm
        self.moe_router_enable_expert_bias = moe_router_enable_expert_bias
        self.routed_scaling_factor = routed_scaling_factor
        # When True, linear weights are ternary-quantized at inference time.
        self.quantize = quantize

        # SWA configs
        self.layer_types = layer_types
        if self.layer_types is None:
            # Default: every layer uses full attention (no sliding window).
            self.layer_types = ["full_attention" for i in range(self.num_hidden_layers)]
        self.sliding_window = sliding_window

        # HC configs
        # NOTE(review): `hc_expand` is only stored on the instance when > 1,
        # so `config.hc_expand` is undefined otherwise — confirm all readers
        # guard on `config.hc` first.
        if hc_expand > 1:
            self.hc_expand = hc_expand
            self.hc = True
        else:
            self.hc = False

        # MoE configs
        self.num_experts = num_experts
        self.num_shared_experts = num_shared_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.n_group = n_group
        self.topk_group = topk_group
        self.moe_intermediate_size = moe_intermediate_size
        self.first_k_dense_replace = first_k_dense_replace
        self.output_router_logits = output_router_logits

        super().__init__(
            pad_token_id=pad_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
        )
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 151643,
4
+ "eos_token_id": 151645,
5
+ "transformers_version": "4.57.1"
6
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model-00001-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:661fe530d280577508bd59106f8469a541de31028b21312ed13e52976b5c88f8
3
+ size 9999232832
model-00002-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a89c5d4e06b59adfb23443fac1b73d7c106af0640d51586b2604a3f6183946d
3
+ size 9999814432
model-00003-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de6c42187e59abc48612b35ebcba74580e606f9aa162df75be0d4c2689f6c0ee
3
+ size 9999814888
model-00004-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9b1daaf3b7962dfb963c2133e10dc9feb99c05d0a1339a5a07447c656c74a3a
3
+ size 9999812248
model-00005-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e5e71eea2c14ee69fabde54abe4f39870d83e14da75057c914562970789d239
3
+ size 4184064320
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_bailing_moe_v2.py ADDED
@@ -0,0 +1,1172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import triton
9
+ import triton.language as tl
10
+ from torch import nn
11
+ from torch.library import triton_op, wrap_triton
12
+ from transformers.activations import ACT2FN
13
+ from transformers.cache_utils import Cache, DynamicCache
14
+ from transformers.generation.utils import GenerationMixin
15
+ from transformers.modeling_outputs import MoeModelOutputWithPast
16
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
17
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
18
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
19
+ from transformers.utils import (
20
+ ModelOutput,
21
+ add_start_docstrings,
22
+ add_start_docstrings_to_model_forward,
23
+ )
24
+ from transformers.utils import logging as hf_logging
25
+ from transformers.utils.import_utils import is_torch_fx_available
26
+
27
+ from .configuration_bailing_moe_v2 import BailingMoeV2Config
28
+
29
+
30
+ logger = hf_logging.get_logger(__name__)
31
+ _CONFIG_FOR_DOC = "BailingMoeV2Config"
32
+
33
+
34
+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
35
+ if is_torch_fx_available():
36
+ if not is_torch_greater_or_equal_than_1_13:
37
+ import torch.fx # noqa: F401
38
+
39
+
40
+ # quantizers
41
def twn_torch_ref(W):
    """Reference ternary-weight (TWN) quantizer over the last dimension.

    Each element is mapped to -alpha, 0, or +alpha per row: the threshold is
    0.7 * mean(|w|), and alpha is the mean magnitude of the weights that
    survive the threshold. The result keeps the input's dtype.
    """
    w32 = W.float()
    magnitude = w32.abs()
    # TWN heuristic threshold: 0.7 x mean absolute value along the last dim.
    threshold = magnitude.mean(-1, keepdim=True) * 0.7
    keep = (magnitude > threshold).float()
    # alpha = average magnitude of kept entries; clamp guards all-zero rows.
    kept_total = (magnitude * keep).sum(-1, keepdim=True)
    kept_count = keep.sum(-1, keepdim=True).clamp(min=1.0)
    alpha = kept_total / kept_count
    ternary = w32.sign() * keep * alpha
    return ternary.to(W.dtype)
51
+
52
+
53
# Compiled variant of the reference quantizer (compilation is lazy, so this is
# cheap at import time). NOTE(review): the visible code calls twn_torch_ref
# directly (QuantizeTernary.forward) and never this compiled handle — confirm
# whether the compiled version was meant to be used there.
twn_torch_compiled = torch.compile(twn_torch_ref, mode="max-autotune")
54
+
55
+
56
# Triton kernel: per-row ternary (TWN) quantization, bf16 output.
# One program handles one row of the (M, N) matrix and makes three passes:
# threshold, alpha, then the ternarized store. Block size is autotuned on N.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_SIZE": 512}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=8, num_stages=3),
    ],
    key=["N"],
)
@triton.jit
def twn_quant_row_merged_bf16_kernel(
    w_ptr,  # input weights, shape (M, N)
    out_ptr,  # output ternarized weights, shape (M, N), stored as bf16
    M,  # number of rows
    N,  # number of columns
    stride_wm,
    stride_wn,
    stride_om,
    stride_on,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    if pid >= M:
        return

    row_w_ptr = w_ptr + pid * stride_wm
    row_out_ptr = out_ptr + pid * stride_om

    # --- Pass 1: Threshold ---
    # th = 0.7 * mean(|w|) over the row; masked lanes load 0 and are excluded
    # from the count, matching twn_torch_ref.
    sum_abs = 0.0
    count = 0.0
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        val = tl.load(row_w_ptr + cols * stride_wn, mask=mask, other=0.0).to(tl.float32)
        val_abs = tl.abs(val)
        sum_abs += tl.sum(val_abs, axis=0)
        count += tl.sum(mask.to(tl.float32), axis=0)

    th = (sum_abs / tl.maximum(count, 1.0)) * 0.7

    # --- Pass 2: Alpha ---
    # alpha = mean |w| of the entries above the threshold (guarded divide).
    masked_sum = 0.0
    masked_count = 0.0
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        val = tl.load(row_w_ptr + cols * stride_wn, mask=mask, other=0.0).to(tl.float32)
        val_abs = tl.abs(val)
        is_selected = (val_abs > th).to(tl.float32)
        masked_sum += tl.sum(val_abs * is_selected, axis=0)
        masked_count += tl.sum(is_selected, axis=0)

    alpha = masked_sum / tl.maximum(masked_count, 1.0)

    # --- Pass 3: Output ---
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        val = tl.load(row_w_ptr + cols * stride_wn, mask=mask, other=0.0).to(tl.float32)
        is_selected = tl.abs(val) > th

        # Output is -alpha, 0, or +alpha
        sign = tl.where(val >= 0, alpha, -alpha)
        out_val = tl.where(is_selected, sign, 0.0)

        tl.store(row_out_ptr + cols * stride_on, out_val.to(tl.bfloat16), mask=mask)
124
+
125
+
126
@triton_op("grove_kernels::twn_triton", mutates_args={})
def twn_triton(W: torch.Tensor) -> torch.Tensor:
    """Launch the row-wise TWN kernel: one Triton program per row of W (M, N).

    NOTE(review): the output is always bf16 (empty_like with dtype=bfloat16),
    unlike twn_torch_ref which preserves W.dtype — confirm callers expect this.
    """
    M, N = W.shape
    out = torch.empty_like(W, dtype=torch.bfloat16)
    grid = (M,)
    wrap_triton(twn_quant_row_merged_bf16_kernel)[grid](
        W,
        out,
        M,
        N,
        W.stride(0),
        W.stride(1),
        out.stride(0),
        out.stride(1),
    )
    return out
142
+
143
+
144
class QuantizeTernary(torch.autograd.Function):
    """Ternary (TWN) quantization with a straight-through estimator backward.

    forward() maps weights to {-alpha, 0, +alpha} row-wise; backward() passes
    the incoming gradient through unchanged (STE).
    """

    @staticmethod
    def forward(ctx, input):
        # The Triton kernel handles 2-D matrices; 3-D inputs take the pure
        # torch reference path. (Original comment here was truncated:
        # "# fatser when" — presumably "faster when 3-D"; confirm.)
        if len(input.shape) == 3:
            return twn_torch_ref(input)
        else:
            return twn_triton(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-Through Estimator: gradient is just passed through.
        # BUG FIX: forward() takes exactly one input, so backward() must return
        # exactly one gradient. The previous `return grad_output, None` makes
        # autograd raise "returned an incorrect number of gradients
        # (expected 1, got 2)" the first time this op is trained through.
        return grad_output
157
+
158
+
159
def quantize(input: torch.Tensor) -> torch.Tensor:
    """Functional wrapper around QuantizeTernary (ternary forward, STE backward)."""
    return QuantizeTernary.apply(input)
161
+
162
+
163
def conditionally_quantize(input: torch.Tensor, do_quantize: bool) -> torch.Tensor:
    """Return the ternary-quantized tensor when `do_quantize` is set, else the input unchanged."""
    return quantize(input) if do_quantize else input
168
+
169
+
170
def quantize_weight_inplace(owner: nn.Module, weight_name: str, enabled: bool = True) -> bool:
    """
    Quantize `owner.<weight_name>` once and write back to the same Parameter storage.
    Returns True if this call performed quantization, False if skipped/already-done.
    """
    if not enabled:
        return False

    # A per-weight marker attribute on the owner makes this a one-shot operation.
    marker = f"__inplace_quantized_{weight_name}"
    if getattr(owner, marker, False):
        return False

    param = getattr(owner, weight_name)
    with torch.no_grad():
        # Quantize, then restore original device/dtype before overwriting the
        # existing Parameter storage in place.
        ternary = quantize(param).to(device=param.device, dtype=param.dtype)
        param.data.copy_(ternary)
    setattr(owner, marker, True)
    return True
187
+
188
+
189
def conditionally_quantize_inplace_on_prefill(
    owner: nn.Module,
    weight_name: str,
    do_quantize: bool,
    *,
    quantize_inplace_now: bool = False,
) -> torch.Tensor:
    """
    In eval mode, quantize the target weight once (during prefill) and write it back in-place.
    This avoids storing duplicate cached tensors while removing per-token quantization overhead.

    Args:
        owner: module that owns the weight (e.g. an nn.Linear).
        weight_name: attribute name of the Parameter on `owner`.
        do_quantize: master switch; when False the raw weight is returned untouched.
        quantize_inplace_now: when True (and `owner` is in eval mode), perform
            the one-time in-place quantization before returning the weight.

    Returns:
        The weight tensor to use for this forward pass.
    """
    weight = getattr(owner, weight_name)
    if not do_quantize:
        return weight
    if owner.training:
        # Training keeps master weights full precision and quantizes on the fly,
        # so the straight-through estimator can still update them.
        return quantize(weight)

    if not quantize_inplace_now:
        # Eval, but not the prefill step: the weight was (or will be) quantized
        # in place by a prefill call; return it as-is.
        return weight
    quantize_weight_inplace(owner, weight_name, enabled=True)
    return weight
210
+
211
+
212
@dataclass
class MoEV2CausalLMOutputWithPast(ModelOutput):
    # Output container for the causal-LM head. Extends the usual HF fields
    # with MoE router extras and multi-token-prediction (MTP) outputs.
    # (Field semantics inferred from names; populated by the LM forward,
    # which is outside this view.)
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    z_loss: Optional[torch.FloatTensor] = None  # router z-loss (per name)
    aux_loss: Optional[torch.FloatTensor] = None  # router load-balancing loss (per name)
    router_logits: Optional[tuple[torch.FloatTensor]] = None
    mtp_loss: Optional[torch.FloatTensor] = None  # multi-token-prediction loss (per name)
    mtp_logits: Optional[tuple[torch.FloatTensor, ...]] = None
224
+
225
+
226
class MoeV2ModelOutputWithPast(MoeModelOutputWithPast):
    # Thin extension of HF's MoeModelOutputWithPast that additionally carries
    # MTP hidden states and an accumulated auxiliary (router) loss.
    def __init__(self, mtp_hidden_states=None, aux_loss=0.0, **kwargs):
        super().__init__(**kwargs)
        self.mtp_hidden_states = mtp_hidden_states
        self.aux_loss = aux_loss
231
+
232
+
233
class BailingMoeV2RotaryEmbedding(nn.Module):
    """Rotary position embedding module following the modern HF layout.

    Selects a rope-init function from ROPE_INIT_FUNCTIONS based on
    `config.rope_scaling`, caches inverse frequencies as a non-persistent
    buffer, and produces (cos, sin, freqs) for a given batch of position ids.
    """

    def __init__(self, config: BailingMoeV2Config, device=None):
        super().__init__()
        # rope_scaling may use either the new "rope_type" or legacy "type" key.
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so @dynamic_rope_update can restore frequencies after scaling.
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        # x is only used for device/dtype; angles come from position_ids.
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Force float32 math (autocast disabled); "mps" lacks this autocast path.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
            # Duplicated-freqs tensor is returned as a third value (non-standard
            # vs. HF's usual (cos, sin) pair) — presumably consumed by a fused
            # attention path elsewhere in this file; confirm at call sites.
            freqs = torch.cat([freqs, freqs], dim=-1)
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype), freqs.float()
264
+
265
+
266
def rotate_half(x):
    """Rotate the two halves of the last dimension: (x1, x2) -> (-x2, x1)."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
270
+
271
+
272
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply (possibly partial) rotary embeddings to query and key tensors.

    Only the first `cos.shape[-1]` channels of the head dimension are rotated;
    the remaining "pass-through" channels are concatenated back unchanged.
    `unsqueeze_dim` inserts the broadcast axis (default: the head axis).
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rot_dim = cos.shape[-1]

    def _rotate(t):
        # (t1, t2) -> (-t2, t1) on the last dimension.
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    q_rot, q_pass = q[..., :rot_dim], q[..., rot_dim:]
    k_rot, k_pass = k[..., :rot_dim], k[..., rot_dim:]
    q_embed = torch.cat([(q_rot * cos) + (_rotate(q_rot) * sin), q_pass], dim=-1)
    k_embed = torch.cat([(k_rot * cos) + (_rotate(k_rot) * sin), k_pass], dim=-1)
    return q_embed, k_embed
286
+
287
+
288
class BailingMoeV2MLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block with optional ternary quantization.

    Computes down( act(gate(x)) * up(x) ); each projection's weight is passed
    through the conditional-quantization helper so eval-mode prefill can
    quantize it in place when `config.quantize` is enabled.
    """

    def __init__(self, config: BailingMoeV2Config, intermediate_size: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x, quantize_inplace_now: bool = False):
        def _resolved_weight(proj: nn.Linear) -> torch.Tensor:
            # Returns the raw or (in-place) quantized weight for this pass.
            return conditionally_quantize_inplace_on_prefill(
                proj,
                "weight",
                self.config.quantize,
                quantize_inplace_now=quantize_inplace_now,
            )

        gate_w = _resolved_weight(self.gate_proj)
        up_w = _resolved_weight(self.up_proj)
        down_w = _resolved_weight(self.down_proj)

        hidden = self.act_fn(torch.nn.functional.linear(x, gate_w)) * torch.nn.functional.linear(x, up_w)
        return torch.nn.functional.linear(hidden, down_w)
325
+
326
+
327
+ class BailingMoeV2RMSNorm(nn.Module):
328
+ def __init__(self, hidden_size, eps=1e-6):
329
+ super().__init__()
330
+ self.weight = nn.Parameter(torch.ones(hidden_size))
331
+ self.variance_epsilon = eps
332
+
333
+ def forward(self, hidden_states):
334
+ input_dtype = hidden_states.dtype
335
+ hidden_states = hidden_states.to(torch.float32)
336
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
337
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
338
+ return self.weight * hidden_states.to(input_dtype)
339
+
340
+
341
# Optionally replace the pure-PyTorch RMSNorm with liger_kernel's fused
# implementation when the package is installed.
try:
    from liger_kernel.transformers.rms_norm import LigerRMSNorm

    BailingMoeV2RMSNorm = LigerRMSNorm
except Exception:
    # FIX: the original bare `except:` also swallowed BaseException
    # (KeyboardInterrupt/SystemExit), and `print` wrote to stdout from library
    # code. `Exception` keeps the best-effort fallback; log instead of print.
    logger.info("liger_kernel not available; using the built-in BailingMoeV2RMSNorm")
347
+
348
+
349
class BailingMoeV2Gate(nn.Module):
    """Sigmoid-scored, group-limited top-k expert router.

    Experts are split into `n_group` groups; routing first keeps the
    `topk_group` best groups (each group scored by the sum of its top-2
    expert scores), then selects the global top-`top_k` experts from the
    surviving groups.
    """

    # Non-learned load-balancing bias over experts (registered buffer); added
    # to scores only for expert *selection*, not for the returned weights.
    expert_bias: torch.Tensor

    def __init__(self, config: BailingMoeV2Config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_experts

        self.n_group = config.n_group
        self.topk_group = config.topk_group

        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(torch.empty((self.num_experts, self.gating_dim)))
        # self.bias = nn.Parameter(torch.zeros((self.num_experts)))
        self.routed_scaling_factor = config.routed_scaling_factor

        self.register_buffer("expert_bias", torch.zeros((self.num_experts)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Kaiming-uniform init matching nn.Linear's default weight init.
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def group_limited_topk(self, scores: torch.Tensor):
        """Pick top-`top_k` experts, restricted to the best `topk_group` groups.

        Returns (probs, top_indices); probs come from the (biased) scores and
        are discarded by the caller, which re-gathers unbiased scores.
        """
        num_tokens, _ = scores.size()
        # Group score = sum of each group's top-2 expert scores.
        # NOTE(review): the 2 is hard-coded rather than config-driven — confirm.
        group_scores = scores.view(num_tokens, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)

        # Broadcast the per-group mask back to per-expert granularity.
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(num_tokens, self.n_group, self.num_experts // self.n_group)
            .reshape(num_tokens, -1)
        )

        # Experts in dropped groups can never win the final top-k.
        masked_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))
        probs, top_indices = torch.topk(masked_scores, k=self.top_k, dim=-1)
        return probs, top_indices

    def forward(self, hidden_states: torch.Tensor):
        """Route flattened tokens; returns (topk_idx, topk_weight, logits)."""
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        # Gating matmul in float32 for routing stability.
        logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))

        scores = torch.sigmoid(logits.float()).type_as(logits)
        # Bias influences which experts are chosen, but the combine weights
        # below are gathered from the *unbiased* scores.
        scores_for_routing = scores + self.expert_bias
        _, topk_idx = self.group_limited_topk(scores_for_routing)

        scores = torch.gather(scores, dim=1, index=topk_idx).type_as(logits)
        # Normalize selected scores to sum to 1 (when routing to >1 expert),
        # then apply the configured scaling factor.
        topk_weight = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if self.top_k > 1 else scores
        topk_weight = topk_weight * self.routed_scaling_factor

        return topk_idx, topk_weight, logits
404
+
405
+
406
class BailingMoeV2SparseMoeBlock(nn.Module):
    """
    Unfused MoE block matching Ling-mini HF layout (ModuleList experts).
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok
        self._setup_experts()
        self.gate = BailingMoeV2Gate(config)
        # NOTE(review): checks `is not None`, so a shared-expert MLP is built
        # even when num_shared_experts == 0 (intermediate_size becomes 0) —
        # confirm that is intended for configs with 0 shared experts.
        if config.num_shared_experts is not None:
            self.shared_experts = BailingMoeV2MLP(
                config=config,
                intermediate_size=config.moe_intermediate_size * config.num_shared_experts,
            )

    def _setup_experts(self):
        # One independent SwiGLU MLP per routed expert.
        self.experts = nn.ModuleList(
            [
                BailingMoeV2MLP(
                    config=self.config,
                    intermediate_size=self.config.moe_intermediate_size,
                )
                for _ in range(self.config.num_experts)
            ]
        )

    def forward(
        self, hidden_states: torch.Tensor, quantize_inplace_now: bool = False
    ) -> torch.Tensor:
        # NOTE: return annotation corrected — the body returns a single tensor
        # of shape (bsz, seq_len, hidden), not a tuple.
        original_shape = hidden_states.shape  # NOTE(review): unused
        identity = hidden_states

        bsz, seq_len, h = hidden_states.shape
        topk_idx, topk_weight, router_logits = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        flat_topk_idx = topk_idx.view(-1)

        if self.training:
            # Differentiable (dense-loop) path: replicate each token once per
            # selected expert, run every expert over its assigned rows, then
            # combine with the routing weights.
            hidden_states = hidden_states.repeat_interleave(self.num_experts_per_tok, dim=0)
            y = torch.empty_like(hidden_states)
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.to(hidden_states.dtype).view(bsz, seq_len, h)
        else:
            y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(bsz, seq_len, h)

        if self.config.num_shared_experts is not None:
            # Shared expert(s) process every token and are added to the output.
            y = y + self.shared_experts(identity)

        return y

    @torch.no_grad()
    def moe_infer(self, x, topk_ids, topk_weight):
        """Inference-only dispatch: group tokens by expert, run each expert
        once on its batch, then scatter results back and combine by weight.

        x: (num_tokens, hidden); topk_ids/topk_weight: (num_tokens, top_k).
        """
        # Per-expert token counts via a one-hot scatter.
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        # Sort all (token, expert) assignments by expert id so each expert's
        # inputs are contiguous; idxs // top_k recovers the source token row.
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        # Host-side counts drive the per-expert slicing loop (syncs with CPU).
        tokens_per_expert = tokens_per_expert.cpu().numpy()
        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert(tokens_for_this_expert)
            outputs.append(expert_out.to(x.device))
            start_idx = end_idx

        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
        # Undo the sort: place each expert output back at its assignment slot.
        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        # Weighted combine over the top_k expert outputs per token.
        final_out = (
            new_x.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out
491
+
492
+
493
class BailingMoeV2Attention(nn.Module):
    """Fixed wiring for modern HF attention APIs: uses prepared causal_mask + cache_position + Cache.update().

    Fuses Q/K/V into a single `query_key_value` projection, optionally applies
    per-head RMSNorm (qk-norm), and dispatches the actual attention kernel through
    `ALL_ATTENTION_FUNCTIONS[config._attn_implementation]`.
    """

    def __init__(self, config: BailingMoeV2Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please pass `layer_idx`."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        # Explicit head_dim from config wins; fall back to hidden_size / num_heads.
        self.head_dim = config.head_dim or self.hidden_size // self.num_heads
        self.scaling = self.head_dim**-0.5

        # Partial-RoPE support: only the first rope_dim dims of each head are rotated.
        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
        self.rope_dim = int(self.head_dim * partial_rotary_factor)

        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.is_causal = True

        # Sliding-window attention only for layers tagged "sliding_attention".
        self.sliding_window = None
        self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
        self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None

        # Fused QKV projection: num_heads query heads + 2 * num_key_value_heads KV heads.
        self.query_key_value = nn.Linear(
            self.hidden_size,
            (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
            bias=config.use_qkv_bias,
        )

        if self.config.use_qk_norm:
            self.query_layernorm = BailingMoeV2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.key_layernorm = BailingMoeV2RMSNorm(self.head_dim, eps=config.rms_norm_eps)

        self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.use_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,  # IMPORTANT: pass prepared causal_mask here
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,  # IMPORTANT: needed for modern cache update
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """Run one attention layer.

        Returns:
            (attn_output, attn_weights_or_None, past_key_value).
        """
        # Popped (not read from kwargs downstream) so the attention backend never
        # sees this project-specific flag.
        quantize_inplace_now = bool(kwargs.pop("quantize_inplace_now", False))
        bsz, q_len, _ = hidden_states.size()
        # One-time in-place quantization of the QKV weight during prefill; returns
        # the (possibly quantized) weight tensor to use for this call.
        qkv_weight = conditionally_quantize_inplace_on_prefill(
            self.query_key_value,
            "weight",
            self.config.quantize,
            quantize_inplace_now=quantize_inplace_now,
        )
        out_qkv = torch.nn.functional.linear(hidden_states, qkv_weight)
        # position_embeddings is (cos, sin, freqs); freqs only used by the fused path.
        cos, sin, _freqs = position_embeddings
        # fused path
        # if self.sliding_window is not None:
        #     query_states, key_states, value_states = functional_fused_split_transpose_rope_qknorm(
        #         out_qkv,
        #         self.query_layernorm.weight,
        #         self.key_layernorm.weight,
        #         _freqs.contiguous(),
        #         self.config.num_attention_heads,
        #         self.config.num_key_value_heads,
        #     )
        # else:
        #     query_states, key_states, value_states = functional_fused_split_transpose_qknorm(
        #         out_qkv,
        #         self.query_layernorm.weight,
        #         self.key_layernorm.weight,
        #         _freqs.contiguous(),
        #         self.config.num_attention_heads,
        #         self.config.num_key_value_heads,
        #     )
        # Unfused path: split fused projection into per-head Q/K/V and move the
        # head axis in front of the sequence axis (B, H, S, D).
        qkv = out_qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)

        query_states, key_states, value_states = qkv.split(
            [self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=-2
        )
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if self.config.use_qk_norm:
            query_states = self.query_layernorm(query_states)
            key_states = self.key_layernorm(key_states)
        # NOTE(review): rotary embeddings are applied only on sliding-window layers;
        # full-attention layers skip RoPE entirely here. Mirrors the commented-out
        # fused path above, but confirm this is the intended position encoding.
        if self.sliding_window is not None:
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # ---- Modern Cache update wiring (DynamicCache / StaticCache compatible) ----
        if use_cache and past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # fa should transpose internally
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,  # prepared causal mask (or None for varlen flash path)
            dropout=0.0,
            position_ids=position_ids,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # keep your prototype behavior
            **kwargs,
        )

        # Merge heads back to (B, S, H*D) before the output projection.
        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        dense_weight = conditionally_quantize_inplace_on_prefill(
            self.dense,
            "weight",
            self.config.quantize,
            quantize_inplace_now=quantize_inplace_now,
        )
        attn_output = torch.nn.functional.linear(attn_output, dense_weight)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
624
+
625
+
626
class BailingMoeV2DecoderLayer(nn.Module):
    """One transformer block: pre-norm attention + pre-norm MLP/MoE, each with a
    residual connection. Layers below `first_k_dense_replace` use a dense MLP;
    the rest use the sparse MoE block (when `num_experts` is set)."""

    def __init__(self, config: BailingMoeV2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx

        self.attention = BailingMoeV2Attention(config=config, layer_idx=layer_idx)

        # Dense MLP for the first k layers, sparse MoE afterwards.
        self.mlp = (
            BailingMoeV2SparseMoeBlock(config)
            if (config.num_experts is not None and layer_idx >= config.first_k_dense_replace)
            else BailingMoeV2MLP(config=config, intermediate_size=config.intermediate_size)
        )

        self.input_layernorm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,  # prepared causal mask
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,  # your MOE doesn't return router logits; kept for API
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[Cache],
        torch.Tensor,
        Optional[torch.Tensor],
    ]:
        """Returns (hidden_states, attn_weights, present_key_value, aux_loss, router_logits)."""
        # `get` (not `pop`) on purpose: the flag is also forwarded to attention via
        # **kwargs, where it is popped before reaching the attention backend.
        quantize_inplace_now = bool(kwargs.get("quantize_inplace_now", False))
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        attn_out, self_attn_weights, present_key_value = self.attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=bool(output_attentions),
            use_cache=bool(use_cache),
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + attn_out

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        # MLP may return either a bare tensor or (tensor, aux_loss); normalize.
        mlp_out = self.mlp(hidden_states, quantize_inplace_now=quantize_inplace_now)
        if isinstance(mlp_out, tuple):
            hidden_states, aux_loss = mlp_out
        else:
            hidden_states, aux_loss = mlp_out, 0.0

        # `.to(residual.device)` guards against experts placed on another device.
        hidden_states = residual + hidden_states.to(residual.device)

        # Your MOE path does not provide router logits; keep placeholder.
        router_logits = None

        return (
            hidden_states,
            self_attn_weights,
            present_key_value,
            aux_loss,
            router_logits,
        )
700
+
701
+
702
# Shared docstring header injected into the model classes below via
# @add_start_docstrings.
BAILINGMOEV2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`].
"""
705
+
706
+
707
@add_start_docstrings(
    "The bare BailingMoeV2 Model outputting raw hidden-states without any specific head on top.",
    BAILINGMOEV2_START_DOCSTRING,
)
class BailingMoeV2PreTrainedModel(PreTrainedModel):
    """Base class carrying capability flags and weight initialization shared by
    every BailingMoeV2 model head."""

    config_class = BailingMoeV2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BailingMoeV2DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_attention_backend = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def _init_weights(self, module):
        # Normal(0, initializer_range) init for Linear/Embedding weights; Linear
        # biases and the Embedding padding row are zeroed. Other modules untouched.
        std = self.config.initializer_range
        is_linear = isinstance(module, nn.Linear)
        is_embedding = isinstance(module, nn.Embedding)
        if not (is_linear or is_embedding):
            return
        module.weight.data.normal_(mean=0.0, std=std)
        if is_linear and module.bias is not None:
            module.bias.data.zero_()
        if is_embedding and module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
732
+
733
+
734
# Placeholder for the forward() inputs documentation (not written yet); consumed
# by @add_start_docstrings_to_model_forward below.
BAILINGMOEV2_INPUTS_DOCSTRING = r"""NA"""
735
+
736
+
737
@add_start_docstrings(
    "The bare BailingMoeV2 Model outputting raw hidden-states without any specific head on top.",
    BAILINGMOEV2_START_DOCSTRING,
)
class BailingMoeV2Model(BailingMoeV2PreTrainedModel):
    """Decoder-only transformer backbone: embeddings -> N decoder layers -> final
    RMSNorm. Also hosts one-shot in-place weight quantization and FlashAttention-2
    varlen metadata preparation."""

    def __init__(self, config: BailingMoeV2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        # Multi-token-prediction layer count; 0 when the config doesn't define it.
        self.num_nextn_predict_layers = getattr(config, "num_nextn_predict_layers", 0)

        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)

        layers = []
        for layer_idx in range(config.num_hidden_layers + self.num_nextn_predict_layers):
            # NOTE: your prototype referenced BailingMoeV2MTPLayer but didn't include it.
            # Keep behavior: only decoder layers here unless you add MTP layers yourself.
            if layer_idx < config.num_hidden_layers:
                layers.append(BailingMoeV2DecoderLayer(config, layer_idx))
            else:
                raise NotImplementedError("BailingMoeV2MTPLayer not included in this prototype file.")
        self.layers = nn.ModuleList(layers)
        self.config = config

        self.norm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = BailingMoeV2RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # Counter used only by the debug_cache diagnostics in forward().
        self._cache_debug_calls = 0
        self.post_init()

    def get_input_embeddings(self):
        return self.word_embeddings

    def set_input_embeddings(self, value):
        self.word_embeddings = value

    def quantize_inplace(self, verbose: bool = False) -> int:
        """
        Quantize this base model in-place once (for inference).
        Returns the number of tensors newly quantized in this call.
        """
        if not self.config.quantize:
            if verbose:
                print("[quantize-inplace] config.quantize is False; nothing to do")
            return 0

        quantized_count = 0
        for layer in self.layers:
            # Attention projections.
            quantized_count += int(quantize_weight_inplace(layer.attention.query_key_value, "weight", enabled=True))
            quantized_count += int(quantize_weight_inplace(layer.attention.dense, "weight", enabled=True))

            # MLP path: dense MLP or MoE experts (+ optional shared experts).
            if isinstance(layer.mlp, BailingMoeV2MLP):
                quantized_count += int(quantize_weight_inplace(layer.mlp.down_proj, "weight", enabled=True))
                quantized_count += int(quantize_weight_inplace(layer.mlp.gate_proj, "weight", enabled=True))
                quantized_count += int(quantize_weight_inplace(layer.mlp.up_proj, "weight", enabled=True))
            # NOTE(review): __init__ builds MoE layers as BailingMoeV2SparseMoeBlock;
            # unless LingSonicMoe is a base class or alias of it, this branch never
            # matches and MoE expert weights are not quantized here — confirm.
            elif isinstance(layer.mlp, LingSonicMoe):
                quantized_count += int(quantize_weight_inplace(layer.mlp.experts, "gate_up_proj", enabled=True))
                quantized_count += int(quantize_weight_inplace(layer.mlp.experts, "down_proj", enabled=True))
                if hasattr(layer.mlp, "shared_experts"):
                    quantized_count += int(
                        quantize_weight_inplace(layer.mlp.shared_experts.down_proj, "weight", enabled=True)
                    )
                    quantized_count += int(
                        quantize_weight_inplace(layer.mlp.shared_experts.gate_proj, "weight", enabled=True)
                    )
                    quantized_count += int(
                        quantize_weight_inplace(layer.mlp.shared_experts.up_proj, "weight", enabled=True)
                    )

        if verbose:
            print(f"[quantize-inplace] newly quantized tensors: {quantized_count}")
        return quantized_count

    def prepare_fa2_from_position_ids(self, position_ids: torch.Tensor):
        """Build FlashAttention-2 varlen metadata (indices, cu_seqlens, max_len)
        from packed position_ids, treating each position-id reset to 0 as the
        start of a new sequence."""
        position_ids = position_ids.flatten()
        T = position_ids.numel()
        indices_q = torch.arange(T, device=position_ids.device, dtype=torch.int32)

        # Segment starts: token positions whose position id is 0.
        starts = indices_q[position_ids == 0]

        # If no segment-start markers exist (common in decoding where pos ids are offset),
        # treat as a single sequence.
        if starts.numel() == 0:
            cu_seq_lens = torch.tensor([0, T], device=position_ids.device, dtype=torch.int32)
        else:
            # ensure boundaries valid
            if starts[0].item() != 0:
                starts = torch.cat([starts.new_zeros(1), starts], dim=0)
            if starts[-1].item() != T:
                starts = torch.cat([starts, starts.new_tensor([T])], dim=0)
            cu_seq_lens = starts

        # Longest segment; .item() here avoids a per-layer host sync downstream.
        max_length = (cu_seq_lens[1:] - cu_seq_lens[:-1]).max().item()
        return (indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))

    # def prepare_fa2_from_position_ids(self, position_ids: torch.Tensor):
    #     position_ids = position_ids.flatten()
    #     indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)

    #     cu_seq_lens = torch.cat(
    #         (
    #             indices_q[position_ids == 0],
    #             torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
    #         )
    #     )

    #     # max_length has a different type in different models:
    #     # tensor in modeling_qwen3_moe_foundation / modeling_qwen2_5_omni,
    #     # int in modeling_qwen2_vl.
    #     # Use the .item() form here so an int max_length is available before the
    #     # decoder layers; otherwise every decoder layer would trigger .item() again.
    #     max_length = cu_seq_lens.diff().max().item()

    #     return (indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))

    @add_start_docstrings_to_model_forward(BAILINGMOEV2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,  # 2D padding mask (B, S) coming in
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, MoeV2ModelOutputWithPast]:
        """Backbone forward pass.

        NOTE(review): `attention_mask` is currently unused because causal_mask is
        hard-disabled below (see TODO); the flash/varlen path relies on
        cu_seq_lens metadata instead.
        """
        debug_cache = bool(kwargs.pop("debug_cache", False))
        if debug_cache:
            print(f"Debug cache enabled for call {self._cache_debug_calls}")
        debug_call_id = self._cache_debug_calls
        self._cache_debug_calls += 1

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # exactly one of input_ids / inputs_embeds
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # SGLang transformers backend passes `forward_batch`; use it to identify
        # decode mode (can have S>1 tokens due to token packing) and avoid
        # decode-only dynamic metadata that harms CUDA graph capture.
        forward_batch = kwargs.get("forward_batch", None)
        is_decode_step = False
        forward_mode = getattr(forward_batch, "forward_mode", None) if forward_batch is not None else None
        if forward_mode is not None:
            for mode_name in (
                "is_decode",
                "is_decode_or_idle",
                "is_target_verify",
                "is_draft_decode",
            ):
                mode_fn = getattr(forward_mode, mode_name, None)
                if callable(mode_fn) and bool(mode_fn()):
                    is_decode_step = True
                    break

        # Perform one-time in-place weight quantization during prefill (S > 1),
        # then reuse the mutated weights for decode without extra memory cache.
        kwargs["quantize_inplace_now"] = bool(
            self.config.quantize and (not self.training) and (not is_decode_step) and inputs_embeds.shape[1] > 1
        )

        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0

        if cache_position is None:
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        if position_ids is not None:
            # For bsh cases, expand [1, S] position_ids to [B, S] before FA2 metadata prep.
            batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
            if position_ids.shape[0] != batch_size:
                position_ids = position_ids.expand(batch_size, -1)

            # Decode does not need cu_seq_lens/max_length metadata and creating
            # them every step hurts CUDA graph capture stability.
            if (not is_decode_step) and inputs_embeds.shape[1] > 1:
                _, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = self.prepare_fa2_from_position_ids(
                    position_ids
                )
                kwargs["cu_seq_lens_q"] = cu_seq_lens_q
                kwargs["cu_seq_lens_k"] = cu_seq_lens_k
                kwargs["max_length_q"] = max_length_q
                kwargs["max_length_k"] = max_length_k

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # IMPORTANT: build prepared causal_mask and pass it into layers (NOT raw attention_mask)
        # mask_function = create_causal_mask  # swap to create_sliding_window_causal_mask if you enable sliding window
        # causal_mask = mask_function(
        #     config=self.config,
        #     input_embeds=inputs_embeds,
        #     attention_mask=attention_mask,
        #     cache_position=cache_position,
        #     past_key_values=past_key_values,
        #     position_ids=position_ids,
        # )
        # TODO: Im just disabling causal mask right now idk fix this later when we need SWA
        causal_mask = None

        hidden_states = inputs_embeds
        # Rotary (cos, sin, freqs) computed once and shared by every layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        # if self.config.hc:
        #     hidden_states = self.expand_streams(hidden_states)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None

        aux_loss_sum = 0.0

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    output_router_logits,
                    use_cache,
                    cache_position,
                    position_embeddings,
                    **kwargs,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,  # <-- FIXED
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    output_router_logits=output_router_logits,
                    use_cache=use_cache,
                    cache_position=cache_position,  # <-- FIXED
                    position_embeddings=position_embeddings,
                    **kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            # aux loss is at index 3 in our layer return
            aux_loss_sum = aux_loss_sum + layer_outputs[3]

            if output_router_logits:
                all_router_logits += (layer_outputs[4],)

        hidden_states = self.norm(hidden_states)

        if debug_cache:
            past_after_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_start = int(cache_position[0].item()) if cache_position.numel() > 0 else -1
            cache_end = int(cache_position[-1].item()) if cache_position.numel() > 0 else -1
            cache_hit = bool(use_cache and inputs_embeds.shape[1] == 1 and past_seen_tokens > 0)
            print(
                "[cache-debug] "
                f"call={debug_call_id} use_cache={use_cache} "
                f"input_len={inputs_embeds.shape[1]} "
                f"past_before={past_seen_tokens} past_after={past_after_tokens} "
                f"cache_pos=[{cache_start},{cache_end}] "
                f"cache_hit_expected={cache_hit}"
            )

        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        # NOTE(review): averages aux loss over len(layers) - 1, not the number of
        # MoE layers; kept as-is per the prototype comment below.
        moe_layer_count = len(self.layers) - 1
        out = MoeV2ModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            router_logits=all_router_logits,
            aux_loss=aux_loss_sum / moe_layer_count,  # keeping your prototype behavior
        )
        # Tuple return drops router_logits / aux_loss by design.
        return (
            out
            if return_dict
            else (
                out.last_hidden_state,
                out.past_key_values,
                out.hidden_states,
                out.attentions,
            )
        )
1059
+
1060
+
1061
class BailingMoeV2ForCausalLM(BailingMoeV2PreTrainedModel, GenerationMixin):
    """Causal-LM head on top of BailingMoeV2Model: backbone -> lm_head logits,
    with optional fused loss when labels are provided."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: BailingMoeV2Config):
        super().__init__(config)
        self.model = BailingMoeV2Model(config)
        self.vocab_size = config.vocab_size

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Router auxiliary-loss weight (hard-coded in this prototype).
        self.router_aux_loss_coef = 0.001
        self.post_init()

    def get_input_embeddings(self):
        return self.model.word_embeddings

    def set_input_embeddings(self, value):
        self.model.word_embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def quantize_inplace(self, verbose: bool = False) -> int:
        """
        Quantize model (and lm_head) in-place once for inference.
        Returns the number of tensors newly quantized in this call.
        """
        quantized_count = self.model.quantize_inplace(verbose=verbose)
        if self.config.quantize:
            quantized_count += int(quantize_weight_inplace(self.lm_head, "weight", enabled=True))
        if verbose:
            print(f"[quantize-inplace] total newly quantized tensors (with lm_head): {quantized_count}")
        return quantized_count

    # @add_start_docstrings_to_model_forward(BAILINGMOEV2_INPUTS_DOCSTRING)
    # @replace_return_docstrings(output_type=MoEV2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> Union[Tuple, MoEV2CausalLMOutputWithPast]:
        """Run the backbone and project to vocabulary logits.

        NOTE(review): `logits_to_keep` is accepted for GenerationMixin
        compatibility but is not used below — full-sequence logits are computed.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=True,  # ensure attribute access
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        assert isinstance(hidden_states, torch.Tensor)

        # slice logits if requested
        loss = None
        logits = None
        if labels is not None:
            # NOTE(review): assumes self.loss_function implements a fused
            # linear + cross-entropy taking (hidden, lm_head_weight, labels) and
            # returning (loss, logits) — confirm against the loss implementation.
            loss, logits = self.loss_function(hidden_states, self.lm_head.weight, labels)
        else:
            logits = self.lm_head(hidden_states)
        out = MoEV2CausalLMOutputWithPast(
            loss=loss,
            aux_loss=getattr(outputs, "aux_loss", 0.0),
            logits=logits,
            past_key_values=outputs.past_key_values if hasattr(outputs, "past_key_values") else None,
            hidden_states=outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
            attentions=outputs.attentions if hasattr(outputs, "attentions") else None,
            router_logits=outputs.router_logits if hasattr(outputs, "router_logits") else None,
        )
        return out
1164
+
1165
+
1166
# Generic alias for loaders that look up a `ModelClass` symbol in this module.
ModelClass = BailingMoeV2ForCausalLM

# Public re-export surface of this module.
__all__ = [
    "BailingMoeV2ForCausalLM",
    "BailingMoeV2Model",
    "BailingMoeV2PreTrainedModel",
]
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
state_dict_shape.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"model.word_embeddings.weight": [[151936, 2048], "torch.bfloat16"], "model.layers.0.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.0.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.0.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.0.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.0.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.0.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.0.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.0.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.0.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.0.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.1.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.1.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.1.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.1.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.1.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.1.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.1.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.1.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.1.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.1.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.2.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.2.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.2.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.2.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.2.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.2.mlp.gate.expert_bias": 
[[224], "torch.bfloat16"], "model.layers.2.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.2.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.2.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.2.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.3.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.3.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.3.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.3.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.3.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.3.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.3.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.3.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.3.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.3.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.4.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.4.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.4.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.4.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.4.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.4.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.4.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.4.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.4.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.4.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.5.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.5.attention.query_layernorm.weight": [[128], 
"torch.bfloat16"], "model.layers.5.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.5.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.5.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.5.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.5.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.5.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.5.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.5.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.6.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.6.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.6.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.6.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.6.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.6.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.6.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.6.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.6.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.6.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.7.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.7.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.7.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.7.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.7.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.7.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.7.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.7.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], 
"model.layers.7.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.7.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.8.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.8.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.8.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.8.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.8.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.8.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.8.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.8.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.8.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.8.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.9.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.9.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.9.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.9.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.9.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.9.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.9.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.9.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.9.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.9.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.10.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.10.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.10.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.10.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], 
"model.layers.10.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.10.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.10.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.10.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.10.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.10.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.11.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.11.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.11.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.11.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.11.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.11.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.11.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.11.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.11.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.11.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.12.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.12.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.12.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.12.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.12.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.12.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.12.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.12.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.12.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.12.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], 
"model.layers.13.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.13.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.13.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.13.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.13.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.13.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.13.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.13.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.13.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.13.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.14.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.14.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.14.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.14.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.14.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.14.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.14.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.14.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.14.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.14.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.15.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.15.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.15.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.15.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.15.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.15.mlp.gate.expert_bias": [[224], "torch.bfloat16"], 
"model.layers.15.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.15.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.15.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.15.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.16.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.16.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.16.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.16.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.16.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.16.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.16.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.16.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.16.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.16.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.17.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.17.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.17.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.17.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.17.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.17.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.17.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.17.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.17.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.17.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.18.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.18.attention.query_layernorm.weight": [[128], 
"torch.bfloat16"], "model.layers.18.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.18.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.18.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.18.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.18.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.18.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.18.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.18.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.19.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.19.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.19.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.19.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.19.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.19.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.19.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.19.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.19.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.19.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.20.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.20.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.20.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.20.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.20.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.20.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.20.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.20.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], 
"model.layers.20.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.20.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.21.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.21.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.21.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.21.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.21.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.21.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.21.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.21.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.21.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.21.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.22.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.22.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.22.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.22.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.22.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.22.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.22.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.22.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.22.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.22.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.23.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.23.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.23.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.23.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], 
"model.layers.23.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.23.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.23.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.23.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.23.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.23.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.24.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.24.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.24.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.24.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.24.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.24.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.24.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.24.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.24.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.24.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.25.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.25.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.25.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.25.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.25.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.25.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.25.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.25.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.25.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.25.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], 
"model.layers.26.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.26.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.26.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.26.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.26.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.26.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.26.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.26.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.26.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.26.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.27.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.27.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.27.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.27.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.27.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.27.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.27.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.27.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.27.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.27.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.28.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.28.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.28.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.28.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.28.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.28.mlp.gate.expert_bias": [[224], "torch.bfloat16"], 
"model.layers.28.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.28.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.28.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.28.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.29.attention.query_key_value.weight": [[3072, 2048], "torch.bfloat16"], "model.layers.29.attention.query_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.29.attention.key_layernorm.weight": [[128], "torch.bfloat16"], "model.layers.29.attention.dense.weight": [[2048, 2048], "torch.bfloat16"], "model.layers.29.mlp.gate.weight": [[224, 2048], "torch.bfloat16"], "model.layers.29.mlp.gate.expert_bias": [[224], "torch.bfloat16"], "model.layers.29.mlp.experts.gate_up_proj": [[224, 1024, 2048], "torch.bfloat16"], "model.layers.29.mlp.experts.down_proj": [[224, 2048, 512], "torch.bfloat16"], "model.layers.29.input_layernorm.weight": [[2048], "torch.bfloat16"], "model.layers.29.post_attention_layernorm.weight": [[2048], "torch.bfloat16"], "model.norm.weight": [[2048], "torch.bfloat16"], "lm_head.weight": [[151936, 2048], "torch.bfloat16"]}
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4
3
+ size 11422654
tokenizer_config.json ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|vision_start|>",
224
+ "<|vision_end|>",
225
+ "<|vision_pad|>",
226
+ "<|image_pad|>",
227
+ "<|video_pad|>"
228
+ ],
229
+ "bos_token": null,
230
+ "clean_up_tokenization_spaces": false,
231
+ "eos_token": "<|im_end|>",
232
+ "errors": "replace",
233
+ "extra_special_tokens": {},
234
+ "model_max_length": 1010000,
235
+ "pad_token": "<|endoftext|>",
236
+ "padding_side": "right",
237
+ "split_special_tokens": false,
238
+ "tokenizer_class": "Qwen2Tokenizer",
239
+ "unk_token": null
240
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff