Firworks committed on
Commit 2d48701 · verified · 1 Parent(s): fd0a929

Add NVFP4 quantized checkpoint

README.md ADDED
@@ -0,0 +1,27 @@
+ ---
+ datasets:
+ - Rombo-Org/Optimized_Reasoning
+ base_model:
+ - stepfun-ai/Step-3.5-Flash
+ tags:
+ - nvfp4
+ - fp4
+ - quantized
+ ---
+ # Step-3.5-Flash-nvfp4
+
+ **Format:** NVFP4. Weights and activations are quantized to FP4 with dual scaling (per-16-element FP8 block scales plus a global per-tensor scale).
+ **Base model:** `stepfun-ai/Step-3.5-Flash`
+ **How it was made:** One-shot calibration with LLM Compressor (NVFP4 recipe), using a long-sequence calibration set (1 sample of length 512) from Rombo-Org/Optimized_Reasoning.
+
+ > Notes: Keep `lm_head` in high precision; calibrate on long, domain-relevant sequences.
+
+ Check the original model card for information about this model.
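+
+ The quantization flow was roughly the following. This is a minimal sketch, not the exact script that was run: the preprocessing of the calibration sample is omitted, the output directory name is illustrative, and the argument names follow the `llmcompressor` `oneshot`/`QuantizationModifier` API; the `targets` and `ignore` lists mirror `recipe.yaml` in this repo.
+ ```python
+ from datasets import load_dataset
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from llmcompressor import oneshot
+ from llmcompressor.modifiers.quantization import QuantizationModifier
+
+ BASE = "stepfun-ai/Step-3.5-Flash"
+ model = AutoModelForCausalLM.from_pretrained(BASE, torch_dtype="auto", trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(BASE, trust_remote_code=True)
+
+ # Calibration data: in practice the sample is chat-formatted and tokenized first,
+ # following the llm-compressor calibration examples; only the count/length matter here.
+ ds = load_dataset("Rombo-Org/Optimized_Reasoning", split="train").select(range(1))
+
+ # NVFP4 recipe matching recipe.yaml: quantize Linear/MoELinear, keep lm_head in high precision.
+ recipe = QuantizationModifier(targets=["Linear", "MoELinear"], scheme="NVFP4", ignore=["lm_head"])
+
+ # One-shot calibration (no training), then save the compressed checkpoint.
+ oneshot(
+     model=model,
+     dataset=ds,
+     recipe=recipe,
+     max_seq_length=512,
+     num_calibration_samples=1,
+ )
+
+ model.save_pretrained("Step-3.5-Flash-nvfp4", save_compressed=True)
+ tokenizer.save_pretrained("Step-3.5-Flash-nvfp4")
+ ```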
+
+ # Running the model with vLLM in Docker
+ ```sh
+ sudo docker run --runtime nvidia --gpus all -p 8000:8000 --ipc=host vllm/vllm-openai:nightly --model Firworks/Step-3.5-Flash-nvfp4 --dtype auto --max-model-len 32768
+ ```
+ This was tested on an RTX Pro 6000 Blackwell cloud instance.
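+
+ Once the container is up, it exposes the standard OpenAI-compatible API on port 8000. A quick sanity check (the prompt is only an example):
+ ```python
+ from openai import OpenAI
+
+ # vLLM's OpenAI-compatible server does not require a real API key by default.
+ client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
+ resp = client.chat.completions.create(
+     model="Firworks/Step-3.5-Flash-nvfp4",
+     messages=[{"role": "user", "content": "Summarize what NVFP4 quantization changes about a model."}],
+     max_tokens=256,
+ )
+ print(resp.choices[0].message.content)
+ ```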
+
+ If there are other models you're interested in seeing quantized to NVFP4 for use on the DGX Spark or other modern Blackwell (or newer) cards, let me know. I'm trying to make more NVFP4 models available to allow more people to try them out.
chat_template.jinja ADDED
@@ -0,0 +1,80 @@
1
+ {% macro render_content(content) %}{% if content is none %}{{- '' }}{% elif content is string %}{{- content }}{% elif content is mapping %}{{- content['value'] if 'value' in content else content['text'] }}{% elif content is iterable %}{% for item in content %}{% if item.type == 'text' %}{{- item['value'] if 'value' in item else item['text'] }}{% elif item.type == 'image' %}<im_patch>{% endif %}{% endfor %}{% endif %}{% endmacro %}
2
+ {{bos_token}}{%- if tools %}
3
+ {{- '<|im_start|>system\n' }}
4
+ {%- if messages[0].role == 'system' %}
5
+ {{- render_content(messages[0].content) + '\n\n' }}
6
+ {%- endif %}
7
+ {{- "# Tools\n\nYou have access to the following functions in JSONSchema format:\n\n<tools>" }}
8
+ {%- for tool in tools %}
9
+ {{- "\n" }}
10
+ {{- tool | tojson(ensure_ascii=False) }}
11
+ {%- endfor %}
12
+ {{- "\n</tools>\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...>\n...\n</function> block must be nested within <tool_call>\n...\n</tool_call> XML tags\n- Required parameters MUST be specified\n</IMPORTANT><|im_end|>\n" }}
13
+ {%- else %}
14
+ {%- if messages[0].role == 'system' %}
15
+ {{- '<|im_start|>system\n' + render_content(messages[0].content) + '<|im_end|>\n' }}
16
+ {%- endif %}
17
+ {%- endif %}
18
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
19
+ {%- for message in messages[::-1] %}
20
+ {%- set index = (messages|length - 1) - loop.index0 %}
21
+ {%- if ns.multi_step_tool and message.role == "user" and render_content(message.content) is string and not(render_content(message.content).startswith('<tool_response>') and render_content(message.content).endswith('</tool_response>')) %}
22
+ {%- set ns.multi_step_tool = false %}
23
+ {%- set ns.last_query_index = index %}
24
+ {%- endif %}
25
+ {%- endfor %}
26
+ {%- for message in messages %}
27
+ {%- set content = render_content(message.content) %}
28
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
29
+ {%- set role_name = 'observation' if (message.role == "system" and not loop.first and message.name == 'observation') else message.role %}
30
+ {{- '<|im_start|>' + role_name + '\n' + content + '<|im_end|>' + '\n' }}
31
+ {%- elif message.role == "assistant" %}
32
+ {%- if message.reasoning_content is string %}
33
+ {%- set reasoning_content = render_content(message.reasoning_content) %}
34
+ {%- else %}
35
+ {%- if '</think>' in content %}
36
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
37
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
38
+ {%- else %}
39
+ {%- set reasoning_content = '' %}
40
+ {%- endif %}
41
+ {%- endif %}
42
+ {%- if loop.index0 > ns.last_query_index %}
43
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n' + content }}
44
+ {%- else %}
45
+ {{- '<|im_start|>' + message.role + '\n' + content }}
46
+ {%- endif %}
47
+ {%- if message.tool_calls %}
48
+ {%- for tool_call in message.tool_calls %}
49
+ {%- if tool_call.function is defined %}
50
+ {%- set tool_call = tool_call.function %}
51
+ {%- endif %}
52
+ {{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
53
+ {%- if tool_call.arguments is defined %}
54
+ {%- set arguments = tool_call.arguments %}
55
+ {%- for args_name, args_value in arguments|items %}
56
+ {{- '<parameter=' + args_name + '>\n' }}
57
+ {%- set args_value = args_value | tojson(ensure_ascii=False) | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
58
+ {{- args_value }}
59
+ {{- '\n</parameter>\n' }}
60
+ {%- endfor %}
61
+ {%- endif %}
62
+ {{- '</function>\n</tool_call>' }}
63
+ {%- endfor %}
64
+ {%- endif %}
65
+ {{- '<|im_end|>\n' }}
66
+ {%- elif message.role == "tool" %}
67
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
68
+ {{- '<|im_start|>tool_response\n' }}
69
+ {%- endif %}
70
+ {{- '<tool_response>' }}
71
+ {{- content }}
72
+ {{- '</tool_response>' }}
73
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
74
+ {{- '<|im_end|>\n' }}
75
+ {%- endif %}
76
+ {%- endif %}
77
+ {%- endfor %}
78
+ {%- if add_generation_prompt %}
79
+ {{- '<|im_start|>assistant\n<think>\n' }}
80
+ {%- endif %}
config.json ADDED
@@ -0,0 +1,376 @@
1
+ {
2
+ "architectures": [
3
+ "Step3p5ForCausalLM"
4
+ ],
5
+ "att_impl_type": "GQA",
6
+ "attention_other_setting": {
7
+ "attention_type": "sliding_attention",
8
+ "head_dim": 128,
9
+ "num_attention_groups": 8,
10
+ "num_attention_heads": 96,
11
+ "true_head_dim": 128
12
+ },
13
+ "auto_map": {
14
+ "AutoConfig": "configuration_step3p5.Step3p5Config",
15
+ "AutoModelForCausalLM": "modeling_step3p5.Step3p5ForCausalLM"
16
+ },
17
+ "bos_token_id": 0,
18
+ "dtype": "bfloat16",
19
+ "eos_token_id": [
20
+ 1,
21
+ 2,
22
+ 128007
23
+ ],
24
+ "head_dim": 128,
25
+ "hidden_size": 4096,
26
+ "intermediate_size": 11264,
27
+ "layer_types": [
28
+ "full_attention",
29
+ "sliding_attention",
30
+ "sliding_attention",
31
+ "sliding_attention",
32
+ "full_attention",
33
+ "sliding_attention",
34
+ "sliding_attention",
35
+ "sliding_attention",
36
+ "full_attention",
37
+ "sliding_attention",
38
+ "sliding_attention",
39
+ "sliding_attention",
40
+ "full_attention",
41
+ "sliding_attention",
42
+ "sliding_attention",
43
+ "sliding_attention",
44
+ "full_attention",
45
+ "sliding_attention",
46
+ "sliding_attention",
47
+ "sliding_attention",
48
+ "full_attention",
49
+ "sliding_attention",
50
+ "sliding_attention",
51
+ "sliding_attention",
52
+ "full_attention",
53
+ "sliding_attention",
54
+ "sliding_attention",
55
+ "sliding_attention",
56
+ "full_attention",
57
+ "sliding_attention",
58
+ "sliding_attention",
59
+ "sliding_attention",
60
+ "full_attention",
61
+ "sliding_attention",
62
+ "sliding_attention",
63
+ "sliding_attention",
64
+ "full_attention",
65
+ "sliding_attention",
66
+ "sliding_attention",
67
+ "sliding_attention",
68
+ "full_attention",
69
+ "sliding_attention",
70
+ "sliding_attention",
71
+ "sliding_attention",
72
+ "full_attention",
73
+ "sliding_attention",
74
+ "sliding_attention",
75
+ "sliding_attention"
76
+ ],
77
+ "max_position_embeddings": 262144,
78
+ "max_seq_len": 262144,
79
+ "model_type": "step3p5",
80
+ "moe_every_n_layer": 1,
81
+ "moe_intermediate_size": 1280,
82
+ "moe_layer_offset": 0,
83
+ "moe_layers_enum": "3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44",
84
+ "moe_num_experts": 288,
85
+ "moe_router_activation": "sigmoid",
86
+ "moe_router_scaling_factor": 3.0,
87
+ "moe_top_k": 8,
88
+ "need_fp32_gate": true,
89
+ "norm_expert_weight": true,
90
+ "num_attention_groups": 8,
91
+ "num_attention_heads": 64,
92
+ "num_hidden_layers": 45,
93
+ "num_nextn_predict_layers": 3,
94
+ "partial_rotary_factor": 0.5,
95
+ "partial_rotary_factors": [
96
+ 0.5,
97
+ 1.0,
98
+ 1.0,
99
+ 1.0,
100
+ 0.5,
101
+ 1.0,
102
+ 1.0,
103
+ 1.0,
104
+ 0.5,
105
+ 1.0,
106
+ 1.0,
107
+ 1.0,
108
+ 0.5,
109
+ 1.0,
110
+ 1.0,
111
+ 1.0,
112
+ 0.5,
113
+ 1.0,
114
+ 1.0,
115
+ 1.0,
116
+ 0.5,
117
+ 1.0,
118
+ 1.0,
119
+ 1.0,
120
+ 0.5,
121
+ 1.0,
122
+ 1.0,
123
+ 1.0,
124
+ 0.5,
125
+ 1.0,
126
+ 1.0,
127
+ 1.0,
128
+ 0.5,
129
+ 1.0,
130
+ 1.0,
131
+ 1.0,
132
+ 0.5,
133
+ 1.0,
134
+ 1.0,
135
+ 1.0,
136
+ 0.5,
137
+ 1.0,
138
+ 1.0,
139
+ 1.0,
140
+ 0.5,
141
+ 1.0,
142
+ 1.0,
143
+ 1.0
144
+ ],
145
+ "quantization_config": {
146
+ "config_groups": {
147
+ "group_0": {
148
+ "format": "nvfp4-pack-quantized",
149
+ "input_activations": {
150
+ "actorder": null,
151
+ "block_structure": null,
152
+ "dynamic": "local",
153
+ "group_size": 16,
154
+ "num_bits": 4,
155
+ "observer": "static_minmax",
156
+ "observer_kwargs": {},
157
+ "scale_dtype": "torch.float8_e4m3fn",
158
+ "strategy": "tensor_group",
159
+ "symmetric": true,
160
+ "type": "float",
161
+ "zp_dtype": null
162
+ },
163
+ "output_activations": null,
164
+ "targets": [
165
+ "Linear",
166
+ "MoELinear"
167
+ ],
168
+ "weights": {
169
+ "actorder": null,
170
+ "block_structure": null,
171
+ "dynamic": false,
172
+ "group_size": 16,
173
+ "num_bits": 4,
174
+ "observer": "static_minmax",
175
+ "observer_kwargs": {},
176
+ "scale_dtype": "torch.float8_e4m3fn",
177
+ "strategy": "tensor_group",
178
+ "symmetric": true,
179
+ "type": "float",
180
+ "zp_dtype": null
181
+ }
182
+ }
183
+ },
184
+ "format": "nvfp4-pack-quantized",
185
+ "global_compression_ratio": null,
186
+ "ignore": [
187
+ "lm_head"
188
+ ],
189
+ "kv_cache_scheme": null,
190
+ "quant_method": "compressed-tensors",
191
+ "quantization_status": "compressed",
192
+ "sparsity_config": {},
193
+ "transform_config": {},
194
+ "version": "0.13.0"
195
+ },
196
+ "rms_norm_eps": 1e-05,
197
+ "rope_parameters": {
198
+ "factor": 2.0,
199
+ "high_freq_factor": 32.0,
200
+ "low_freq_factor": 1.0,
201
+ "original_max_position_embeddings": 131072,
202
+ "rope_type": "llama3"
203
+ },
204
+ "rope_scaling": {
205
+ "factor": 2.0,
206
+ "high_freq_factor": 32.0,
207
+ "low_freq_factor": 1.0,
208
+ "original_max_position_embeddings": 131072,
209
+ "rope_type": "llama3"
210
+ },
211
+ "rope_theta": [
212
+ 5000000.0,
213
+ 10000.0,
214
+ 10000.0,
215
+ 10000.0,
216
+ 5000000.0,
217
+ 10000.0,
218
+ 10000.0,
219
+ 10000.0,
220
+ 5000000.0,
221
+ 10000.0,
222
+ 10000.0,
223
+ 10000.0,
224
+ 5000000.0,
225
+ 10000.0,
226
+ 10000.0,
227
+ 10000.0,
228
+ 5000000.0,
229
+ 10000.0,
230
+ 10000.0,
231
+ 10000.0,
232
+ 5000000.0,
233
+ 10000.0,
234
+ 10000.0,
235
+ 10000.0,
236
+ 5000000.0,
237
+ 10000.0,
238
+ 10000.0,
239
+ 10000.0,
240
+ 5000000.0,
241
+ 10000.0,
242
+ 10000.0,
243
+ 10000.0,
244
+ 5000000.0,
245
+ 10000.0,
246
+ 10000.0,
247
+ 10000.0,
248
+ 5000000.0,
249
+ 10000.0,
250
+ 10000.0,
251
+ 10000.0,
252
+ 5000000.0,
253
+ 10000.0,
254
+ 10000.0,
255
+ 10000.0,
256
+ 5000000.0,
257
+ 10000.0,
258
+ 10000.0,
259
+ 10000.0
260
+ ],
261
+ "share_expert_dim": 1280,
262
+ "sink": false,
263
+ "sliding_window": 512,
264
+ "swiglu_limits": [
265
+ 0.0,
266
+ 0.0,
267
+ 0.0,
268
+ 0.0,
269
+ 0.0,
270
+ 0.0,
271
+ 0.0,
272
+ 0.0,
273
+ 0.0,
274
+ 0.0,
275
+ 0.0,
276
+ 0.0,
277
+ 0.0,
278
+ 0.0,
279
+ 0.0,
280
+ 0.0,
281
+ 0.0,
282
+ 0.0,
283
+ 0.0,
284
+ 0.0,
285
+ 0.0,
286
+ 0.0,
287
+ 0.0,
288
+ 0.0,
289
+ 0.0,
290
+ 0.0,
291
+ 0.0,
292
+ 0.0,
293
+ 0.0,
294
+ 0.0,
295
+ 0.0,
296
+ 0.0,
297
+ 0.0,
298
+ 0.0,
299
+ 0.0,
300
+ 0.0,
301
+ 0.0,
302
+ 0.0,
303
+ 0.0,
304
+ 0.0,
305
+ 0.0,
306
+ 0.0,
307
+ 0.0,
308
+ 7,
309
+ 7,
310
+ 0.0,
311
+ 0.0,
312
+ 0.0
313
+ ],
314
+ "swiglu_limits_shared": [
315
+ 0.0,
316
+ 0.0,
317
+ 0.0,
318
+ 0.0,
319
+ 0.0,
320
+ 0.0,
321
+ 0.0,
322
+ 0.0,
323
+ 0.0,
324
+ 0.0,
325
+ 0.0,
326
+ 0.0,
327
+ 0.0,
328
+ 0.0,
329
+ 0.0,
330
+ 0.0,
331
+ 0.0,
332
+ 0.0,
333
+ 0.0,
334
+ 0.0,
335
+ 0.0,
336
+ 0.0,
337
+ 0.0,
338
+ 0.0,
339
+ 0.0,
340
+ 0.0,
341
+ 0.0,
342
+ 0.0,
343
+ 0.0,
344
+ 0.0,
345
+ 0.0,
346
+ 0.0,
347
+ 0.0,
348
+ 0.0,
349
+ 0.0,
350
+ 0.0,
351
+ 0.0,
352
+ 0.0,
353
+ 0.0,
354
+ 0.0,
355
+ 0.0,
356
+ 0.0,
357
+ 0.0,
358
+ 0.0,
359
+ 16,
360
+ 0.0,
361
+ 0.0,
362
+ 0.0
363
+ ],
364
+ "transformers_version": "4.57.3",
365
+ "use_cache": false,
366
+ "use_head_wise_attn_gate": true,
367
+ "use_moe": true,
368
+ "use_moe_router_bias": true,
369
+ "use_qk_norm": true,
370
+ "use_rope_layers": [],
371
+ "vocab_size": 128896,
372
+ "yarn_only_types": [
373
+ "full_attention"
374
+ ],
375
+ "zero_centered": true
376
+ }
configuration_step3p5.py ADDED
@@ -0,0 +1,59 @@
+ from typing import Any, Optional, Union
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class Step3p5Config(PretrainedConfig):
+     model_type = "step3p5"
+     architectures = ["Step3p5ForCausalLM"]
+
+     def __init__(
+         self,
+         hidden_size: int = 4096,
+         intermediate_size: int = 11264,
+         num_attention_heads: int = 64,
+         num_attention_groups: int = 8,
+         num_hidden_layers: int = 45,
+         max_seq_len: int = 128000,
+         vocab_size: int = 128815,
+         rms_norm_eps: float = 1e-5,
+         moe_intermediate_size: int = 1280,
+         moe_num_experts: int = 288,
+         moe_top_k: int = 8,
+         rope_theta: float = 10000,
+         rope_scaling: Optional[dict[str, Any]] = None,
+         max_position_embeddings: int = 128000,
+         share_expert_dims: int = 1280,
+         head_dim: int = 128,
+         norm_expert_weight: bool = True,
+         layer_types: list[str] = None,
+         sliding_window: Optional[int] = None,
+         moe_layers_enum: tuple[int] = (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
+                                        15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
+                                        25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
+                                        35, 36, 37, 38, 39, 40, 41, 42, 43, 44),
+         **kwargs,
+     ) -> None:
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_attention_heads = num_attention_heads
+         self.num_attention_groups = num_attention_groups
+         self.num_hidden_layers = num_hidden_layers
+         self.max_seq_len = max_seq_len
+         self.vocab_size = vocab_size
+         self.rms_norm_eps = rms_norm_eps
+         self.moe_intermediate_size = moe_intermediate_size
+         self.moe_num_experts = moe_num_experts
+         self.moe_top_k = moe_top_k
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+         self.max_position_embeddings = max_position_embeddings
+         self.share_expert_dim = share_expert_dims
+         self.head_dim = head_dim
+         self.norm_expert_weight = norm_expert_weight
+         self.moe_layers_enum = moe_layers_enum
+         self.layer_types = layer_types
+         self.sliding_window = sliding_window
+         super().__init__(**kwargs)
+
generation_config.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 0,
+   "do_sample": true,
+   "eos_token_id": [
+     1,
+     2,
+     128007
+   ],
+   "transformers_version": "4.57.3"
+ }
model-00001-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8a4116f97b83f5272cbcf4a14d879bd4cf7cced6cfb1f94c9959e22deee32af
3
+ size 4967057968
model-00002-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00d3c857834bae585f823926408406efe936f96bc96a187bd166f2ea9e984852
3
+ size 4388928280
model-00003-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:72845c5cd1603b4642f0ffc8e0f0202e3ffd0365f7dade7dff1c9b4953d43e95
3
+ size 4317831672
model-00004-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f137d8752377ca68fa1c7982528bef8caeeb0141767c40477a7d0f8a005978b
3
+ size 4369980176
model-00005-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:589470dfae75ba75232ad47f0f43399ac7a376c60e30c23fd667107793a57c8b
3
+ size 4388928400
model-00006-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78f23487e3bab1f0048b335a3138e15533cac08fc1539247bcae8cbef3d059e6
3
+ size 4298883640
model-00007-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f97bd8e16073c02428151b0171fbacb6579f3d340022b958f17f81c877b2c50
3
+ size 4388928376
model-00008-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:402fa96eea50c25016c1751d0893718fcf7097072e60f7b4bac39ce13086a8d6
3
+ size 4369980288
model-00009-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d179a8923dab1deb79a13de991191313ef261794fcb96fd0371fd24dcd99c44a
3
+ size 4317831736
model-00010-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92eb378130807584c9f8f2a04b29115205f36c21c49344616a3272d81a3adaaf
3
+ size 4388928376
model-00011-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:138c4db286faca86fb22d534d68c90922eb6d9a9f3f6544cbd2cba8756b0e419
3
+ size 4369980288
model-00012-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d7460bbc4db7812de1559686a83d22502c7bcff90fb383429bd2ec82a3fd9b5
3
+ size 4317831736
model-00013-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cee0b32bd83e9bae3830457deee6ec5f0b78f11bf5bf70f4d9726711cca97ad0
3
+ size 4369980280
model-00014-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:471cab510b6c6bc4c5ca8473542548e0f33556e42dc93fc36148c4da04e97d44
3
+ size 4388928384
model-00015-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c058151516351df8dd559daa6322842dcfa3358aa8d009dcd7eb5c84acad571e
3
+ size 4317831736
model-00016-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0bb9ecac2ae3f3c1458c15b3ea09bd6c77102e44666f579d4d5f57de3f48bb0
3
+ size 4369980280
model-00017-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c66c30f7e9fb6af842634957b92791caacff7a191ead59a9b768da0fc9eb72df
3
+ size 4388928384
model-00018-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7902621aa51b9eaa855da55e89f17a265e3b0351691d58a712eb16fea66e2b88
3
+ size 4298883640
model-00019-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92ae180ac81f7e402b7ddee3f6f51399e9dbe7f684e5399f53dd395a74811e16
3
+ size 4388928376
model-00020-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a5b1b25ee6d4f5a932a50cd17868b81b157de890d8ab47da33af620e45cb491
3
+ size 4369980288
model-00021-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2cbee091b139efe89c3f4aaade8b0f1840640a792a47c484a3a7612b27f814f
3
+ size 4317831736
model-00022-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1489fde66e452704854197f8dc833ce9e5f51275addec07d6293728e3100f8c8
3
+ size 4388928376
model-00023-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:285f76218c43fce239d54372e2e57040c0ccf7be7992abe6c6709ef656feb3bc
3
+ size 4369980288
model-00024-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17f9c6fd63a52aab3560d9c0e26dc76457527c049b2e488119a1c406983c40b3
3
+ size 4317831736
model-00025-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e037a341d982cc982d4e21c8f3c2c9a1ec48712c6403f080e44b22875b22ad1
3
+ size 4369980280
model-00026-of-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b2542bce70caefe00ef78727dd0adf1c4def65fc3000b342e0401c0abe8c576
3
+ size 2763483928
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_step3p5.py ADDED
@@ -0,0 +1,900 @@
1
+ # Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from dataclasses import dataclass
16
+ from typing import Callable, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from transformers.activations import ACT2FN
22
+ from transformers.cache_utils import Cache, DynamicCache
23
+ from transformers.generation import GenerationMixin
24
+ from transformers.masking_utils import (create_causal_mask,
25
+ create_sliding_window_causal_mask)
26
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
27
+ from transformers.modeling_layers import GradientCheckpointingLayer
28
+ from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
29
+ from transformers.modeling_rope_utils import (ROPE_INIT_FUNCTIONS,
30
+ dynamic_rope_update)
31
+ from transformers.modeling_utils import (ALL_ATTENTION_FUNCTIONS,
32
+ PreTrainedModel)
33
+ from transformers.processing_utils import Unpack
34
+ from transformers.utils import TransformersKwargs, can_return_tuple, logging
35
+
36
+ from .configuration_step3p5 import Step3p5Config
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+ __all__ = ["Step3p5Model", "Step3p5ForCausalLM"]
41
+
42
+ class Step3p5RotaryEmbedding(nn.Module):
43
+
44
+ def __init__(self, config: Step3p5Config, device=None, layer_idx=None):
45
+ super().__init__()
46
+ # BC: "rope_type" was originally "type"
47
+ self.layer_idx = layer_idx
48
+ if config.rope_parameters is not None:
49
+ self.rope_type = config.rope_parameters.get(
50
+ "rope_type", config.rope_parameters.get("type"))
51
+ else:
52
+ self.rope_type = "default"
53
+ self.max_seq_len_cached = config.max_position_embeddings
54
+ self.original_max_seq_len = config.max_position_embeddings
55
+
56
+ partial_rotary_factors = getattr(config, "partial_rotary_factors",
57
+ None)
58
+ if partial_rotary_factors is not None:
59
+ config.partial_rotary_factor = partial_rotary_factors[
60
+ self.layer_idx]
61
+ else:
62
+ config.partial_rotary_factor = 1.0
63
+
64
+ self.rope_theta = config.rope_theta
65
+ if isinstance(config.rope_theta, list):
66
+ self.rope_theta = config.rope_theta.copy()
67
+ config.rope_theta = self.rope_theta[self.layer_idx]
68
+
69
+ self.config = config
70
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
71
+ inv_freq, self.attention_scaling = self.rope_init_fn(
72
+ self.config, device)
73
+
74
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
75
+ self.original_inv_freq = self.inv_freq
76
+ config.rope_theta = self.rope_theta
77
+
78
+ @torch.no_grad()
79
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
80
+ def forward(self, x, position_ids):
81
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
82
+ position_ids.shape[0], -1, 1).to(x.device)
83
+ position_ids_expanded = position_ids[:, None, :].float().to(x.device)
84
+
85
+ device_type = x.device.type if isinstance(
86
+ x.device.type, str) and x.device.type != "mps" else "cpu"
87
+ with torch.autocast(device_type=device_type,
88
+ enabled=False): # Force float32
89
+ freqs = (inv_freq_expanded.float()
90
+ @ position_ids_expanded.float()).transpose(1, 2)
91
+ emb = torch.cat((freqs, freqs), dim=-1)
92
+ cos = emb.cos() * self.attention_scaling
93
+ sin = emb.sin() * self.attention_scaling
94
+
95
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
96
+
97
+
98
+ def rotate_half(x):
99
+ """Rotates half the hidden dims of the input."""
100
+ x1 = x[..., :x.shape[-1] // 2]
101
+ x2 = x[..., x.shape[-1] // 2:]
102
+ return torch.cat((-x2, x1), dim=-1)
103
+
104
+
105
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
106
+ """Applies Rotary Position Embedding to the query and key tensors.
107
+
108
+ Args:
109
+ q (`torch.Tensor`): The query tensor.
110
+ k (`torch.Tensor`): The key tensor.
111
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
112
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
113
+ position_ids (`torch.Tensor`, *optional*):
114
+ Deprecated and unused.
115
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
116
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
117
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
118
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
119
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
120
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
121
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
122
+ Returns:
123
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
124
+ """
125
+ rotary_dim = cos.shape[-1]
126
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
127
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
128
+
129
+ # Apply rotary embeddings on the first half or full tensor
130
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
131
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
132
+
133
+ # Concatenate back to full shape
134
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
135
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
136
+ return q_embed, k_embed
137
+
138
+
139
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
140
+ """
141
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
142
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
143
+ """
144
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
145
+ if n_rep == 1:
146
+ return hidden_states
147
+ hidden_states = hidden_states[:, :,
148
+ None, :, :].expand(batch,
149
+ num_key_value_heads,
150
+ n_rep, slen, head_dim)
151
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
152
+ head_dim)
153
+
154
+
155
+ # Adapted from transformers.models.llama.modeling_llama.eager_attention_forward -> llama4 doesn't cast attn weights to fp32
156
+ def eager_attention_forward(
157
+ module: nn.Module,
158
+ query: torch.Tensor,
159
+ key: torch.Tensor,
160
+ value: torch.Tensor,
161
+ attention_mask: Optional[torch.Tensor],
162
+ scaling: float,
163
+ dropout: float = 0.0,
164
+ **kwargs,
165
+ ):
166
+ key_states = repeat_kv(key, module.num_key_value_groups)
167
+ value_states = repeat_kv(value, module.num_key_value_groups)
168
+ # breakpoint()
169
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
170
+ if attention_mask is not None:
171
+ causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
172
+ attn_weights = attn_weights + causal_mask
173
+
174
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
175
+ attn_weights = nn.functional.dropout(attn_weights,
176
+ p=dropout,
177
+ training=module.training)
178
+ attn_output = torch.matmul(attn_weights, value_states)
179
+ attn_output = attn_output.transpose(1, 2).contiguous()
180
+
181
+ return attn_output, attn_weights
182
+
183
+ @dataclass
184
+ class Step3p5CausalLMOutputWithPast(ModelOutput):
185
+ r"""
186
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
187
+ Language modeling loss (for next-token prediction).
188
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
189
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
190
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
191
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
192
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
193
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
194
+ `past_key_values` input) to speed up sequential decoding.
195
+ """
196
+
197
+ loss: Optional[torch.FloatTensor] = None
198
+ last_hidden_state: Optional[torch.FloatTensor] = None
199
+ logits: torch.FloatTensor = None
200
+ past_key_values: Optional[list[torch.FloatTensor]] = None
201
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
202
+ attentions: Optional[tuple[torch.FloatTensor]] = None
203
+
204
+
205
+ class Step3p5MLP(nn.Module):
206
+
207
+ def __init__(self, config, intermediate_size=None, swiglu_limit=None):
208
+ super().__init__()
209
+ self.config = config
210
+ self.hidden_size = config.hidden_size
211
+ self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
212
+ self.gate_proj = nn.Linear(self.hidden_size,
213
+ self.intermediate_size,
214
+ bias=False)
215
+ self.up_proj = nn.Linear(self.hidden_size,
216
+ self.intermediate_size,
217
+ bias=False)
218
+ self.down_proj = nn.Linear(self.intermediate_size,
219
+ self.hidden_size,
220
+ bias=False)
221
+ self.act_fn = ACT2FN["silu"]
222
+ self.limit = swiglu_limit
223
+
224
+ def forward(self, x):
225
+ up = self.up_proj(x)
226
+ gate = self.act_fn(self.gate_proj(x))
227
+ if self.limit is not None:
228
+ gate = gate.clamp(min=None, max=self.limit)
229
+ up = up.clamp(min=-self.limit, max=self.limit)
230
+
231
+ return self.down_proj(gate * up)
232
+
233
+
234
+ def sigmoid_routing_function(gating_output: torch.Tensor, topk: int,
235
+ renormalize: bool):
236
+ gating_output = gating_output.float()
237
+ gate_prob = torch.sigmoid(gating_output)
238
+ gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
239
+ topk_prob, indices = torch.topk(gate_prob, k=topk, dim=1)
240
+ expert_topk_weight = topk_prob
241
+ if renormalize:
242
+ expert_topk_weight = expert_topk_weight / torch.sum(
243
+ expert_topk_weight, dim=-1, keepdim=True)
244
+ return expert_topk_weight, indices
245
+
246
+
247
+ def softmax_routing_function(gating_output: torch.Tensor, top_k: int,
248
+ renormalize: bool):
249
+ gating_output = gating_output.float()
250
+ gate_prob = torch.softmax(gating_output, dim=-1)
251
+ gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
252
+ topk_prob, indices = torch.topk(gate_prob, k=top_k, dim=1)
253
+ expert_topk_weight = topk_prob
254
+ if renormalize:
255
+ expert_topk_weight = expert_topk_weight / torch.sum(
256
+ expert_topk_weight, dim=-1, keepdim=True)
257
+ return expert_topk_weight, indices.to(torch.int32)
258
+
259
+
260
+ class MoELinear(nn.Module):
261
+
262
+ def __init__(self, num_experts, in_features, out_features):
263
+ super().__init__()
264
+ self.num_experts = num_experts
265
+ self.in_features = in_features
266
+ self.out_features = out_features
267
+ self.weight = nn.Parameter(
268
+ torch.empty(num_experts, out_features, in_features))
269
+
270
+ def forward(self, x, expert_id):
271
+ x = F.linear(x.float(), self.weight[expert_id].float())
272
+ return x
273
+
274
+
275
+ class Step3p5MoEMLP(nn.Module):
276
+
277
+ def __init__(self, config, swiglu_limit=None):
278
+ super().__init__()
279
+ self.num_experts = config.moe_num_experts
280
+ self.top_k = config.moe_top_k
281
+ self.hidden_size = config.hidden_size
282
+ self.moe_intermediate_size = config.moe_intermediate_size
283
+
284
+ self.use_moe_router_bias = config.use_moe_router_bias
285
+ if self.use_moe_router_bias:
286
+ self.router_bias = nn.Parameter(torch.zeros(config.moe_num_experts,
287
+ dtype=torch.float32),
288
+ requires_grad=False)
289
+ self.custom_routing_function = self.router_bias_func
290
+ elif config.moe_router_activation == "sigmoid":
291
+ self.custom_routing_function = sigmoid_routing_function
292
+ else:
293
+ self.custom_routing_function = None
294
+ self.need_fp32_gate = config.need_fp32_gate
295
+ self.routed_scaling_factor = getattr(config,
296
+ "moe_router_scaling_factor", 1.0)
297
+
298
+ # gating
299
+ self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
300
+
301
+ self.act_fn = ACT2FN["silu"]
302
+ self.limit = swiglu_limit
303
+
304
+ self.up_proj = MoELinear(self.num_experts, self.hidden_size,
305
+ self.moe_intermediate_size)
306
+ self.gate_proj = MoELinear(self.num_experts, self.hidden_size,
307
+ self.moe_intermediate_size)
308
+ self.down_proj = MoELinear(self.num_experts,
309
+ self.moe_intermediate_size,
310
+ self.hidden_size)
311
+
312
+ def router_bias_func(self, gating_output: torch.Tensor, topk: int,
313
+ renormalize: bool):
314
+ gate_prob = torch.sigmoid(gating_output.float())
315
+ gate_prob_with_bias = gate_prob + self.router_bias.unsqueeze(0)
316
+ _, indices = torch.topk(gate_prob_with_bias, k=topk, dim=1)
317
+ topk_prob = torch.gather(gate_prob, 1, indices)
318
+ expert_topk_weight = topk_prob
319
+ if renormalize:
320
+ expert_topk_weight = expert_topk_weight / (
321
+ torch.sum(expert_topk_weight, dim=-1, keepdim=True) + 1e-20)
322
+ return expert_topk_weight, indices
323
+
324
+ def get_expert_output(self, inputs: torch.Tensor, expert_id):
325
+ #if self.limit is None:
326
+ up = self.up_proj(inputs, expert_id)
327
+ gate = self.act_fn(self.gate_proj(inputs, expert_id))
328
+ if self.limit is not None:
329
+ gate = gate.clamp(min=None, max=self.limit)
330
+ up = up.clamp(min=-self.limit, max=self.limit)
331
+
332
+ return self.down_proj(gate * up, expert_id)
333
+
334
+ def forward(self, hidden_states):
335
+ """ """
336
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
337
+ hidden_states = hidden_states.view(-1, hidden_dim)
338
+ if self.need_fp32_gate:
339
+ router_logits = torch.matmul(hidden_states.to(torch.float32), self.gate.weight.t().to(torch.float32))
340
+ else:
341
+ # router_logits: (batch * sequence_length, n_experts)
342
+ router_logits = self.gate(hidden_states)
343
+
344
+ if self.custom_routing_function:
345
+ routing_weights, selected_experts = self.custom_routing_function(
346
+ router_logits, self.top_k, renormalize=True)
347
+ else:
348
+ routing_weights = F.softmax(router_logits,
349
+ dim=1,
350
+ dtype=torch.float)
351
+ routing_weights, selected_experts = torch.topk(routing_weights,
352
+ self.top_k,
353
+ dim=-1)
354
+
355
+ routing_weights = routing_weights * self.routed_scaling_factor
356
+
357
+ final_hidden_states = torch.zeros(
358
+ (batch_size * sequence_length, hidden_dim),
359
+ dtype=hidden_states.dtype,
360
+ device=hidden_states.device)
361
+
362
+ # One hot encode the selected experts to create an expert mask
363
+ # this will be used to easily index which expert is going to be sollicitated
364
+ expert_mask = torch.nn.functional.one_hot(
365
+ selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
366
+
367
+ # Loop over all available experts in the model and perform the computation on each expert
368
+ for expert_idx in range(self.num_experts):
369
+ idx, top_x = torch.where(expert_mask[expert_idx])
370
+
371
+ # Index the correct hidden states and compute the expert hidden state for
372
+ # the current expert. We need to make sure to multiply the output hidden
373
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
374
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
375
+ current_hidden_states = (
376
+ self.get_expert_output(current_state, expert_idx) *
377
+ routing_weights[top_x, idx, None])
378
+
379
+ # However `index_add_` only support torch tensors for indexing so we'll use
380
+ # the `top_x` tensor here.
381
+ final_hidden_states.index_add_(
382
+ 0, top_x, current_hidden_states.to(hidden_states.dtype))
383
+ final_hidden_states = final_hidden_states.reshape(
384
+ batch_size, sequence_length, hidden_dim)
385
+ return final_hidden_states
386
+
387
+
388
+ class Step3p5RMSNorm(nn.Module):
389
+
390
+ def __init__(
391
+ self,
392
+ hidden_size: int,
393
+ eps: float = 1e-5,
394
+ ) -> None:
395
+ super().__init__()
396
+ self.weight = nn.Parameter(torch.ones(hidden_size))
397
+ self.variance_epsilon = eps
398
+
399
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
400
+ dtype = x.dtype
401
+ x = x.float()
402
+ variance = x.pow(2).mean(dim=-1, keepdim=True)
403
+ normed = x * torch.rsqrt(variance + self.variance_epsilon)
404
+ normed = normed * (self.weight.float() + 1)
405
+ return normed.to(dtype)
406
+ class Step3p5Attention(nn.Module):
407
+
408
+ def __init__(self, config: Step3p5Config, layer_idx):
409
+ super().__init__()
410
+ self.config = config
411
+ self.layer_idx = layer_idx
412
+ self.num_attention_heads = config.num_attention_heads
413
+ self.num_key_value_heads = config.num_attention_groups
414
+
415
+ layer_types = getattr(config, "layer_types", [])
416
+ if layer_types:
417
+ enable_sliding_window = layer_types[
418
+ self.layer_idx] == "sliding_attention"
419
+ else:
420
+ enable_sliding_window = self.layer_idx % 2 == 0
421
+
422
+ if hasattr(config, "yarn_only_types") and layer_types[
423
+ self.layer_idx] not in config.yarn_only_types:
424
+ config.rope_parameters = None
425
+ else:
426
+ config.rope_parameters = getattr(config, "rope_scaling", None)
427
+
428
+ self.sliding_window = config.sliding_window
429
+ if enable_sliding_window:
430
+ self.num_attention_heads = config.attention_other_setting[
431
+ "num_attention_heads"]
432
+ self.num_key_value_heads = config.attention_other_setting[
433
+ "num_attention_groups"]
434
+
435
+ if self.sliding_window is not None and enable_sliding_window:
436
+ self.sliding_window = (self.sliding_window)
437
+ else:
438
+ self.sliding_window = None
439
+ self.head_dim = getattr(config, "head_dim",
440
+ config.hidden_size // self.num_attention_heads)
441
+ self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
442
+
443
+ self.rotary_emb = Step3p5RotaryEmbedding(config, layer_idx=layer_idx)
444
+
445
+ self.q_size = self.num_attention_heads * self.head_dim
446
+ self.kv_size = self.num_key_value_heads * self.head_dim
447
+ self.scaling = self.head_dim**-0.5
448
+
449
+ self.q_proj = nn.Linear(config.hidden_size, self.q_size, bias=False)
450
+ self.k_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
451
+ self.v_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
452
+ self.o_proj = nn.Linear(self.q_size, config.hidden_size, bias=False)
453
+ self.q_norm = Step3p5RMSNorm(self.head_dim,
454
+ eps=config.rms_norm_eps)
455
+ self.k_norm = Step3p5RMSNorm(self.head_dim,
456
+ eps=config.rms_norm_eps)
457
+
458
+ self.use_head_wise_attn_gate = config.use_head_wise_attn_gate
459
+ if self.use_head_wise_attn_gate:
460
+ self.g_proj = nn.Linear(config.hidden_size,
461
+ self.num_attention_heads,
462
+ bias=False)
463
+
464
+ self.use_rope = True
465
+ use_rope_layers = getattr(config, "use_rope_layers", None)
466
+ if use_rope_layers:
467
+ self.use_rope = use_rope_layers[self.layer_idx]
468
+
469
+ def forward(
470
+ self,
471
+ hidden_states: torch.Tensor,
472
+ attention_mask: Optional[torch.Tensor],
473
+ past_key_value: Optional[Cache] = None,
474
+ cache_position: Optional[torch.LongTensor] = None,
475
+ position_ids: Optional[torch.LongTensor] = None,
476
+ **kwargs: Unpack[FlashAttentionKwargs],
477
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
478
+ Optional[Tuple[torch.Tensor]]]:
479
+ input_shape = hidden_states.shape[:-1]
480
+ hidden_shape = (*input_shape, -1, self.head_dim)
481
+
482
+ query_states = self.q_norm(
483
+ self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
484
+ key_states = self.k_norm(
485
+ self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
486
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(
487
+ 1, 2)
488
+ if self.use_head_wise_attn_gate:
489
+ gate_states = self.g_proj(hidden_states)
490
+ cos, sin = self.rotary_emb(hidden_states, position_ids)
491
+
492
+ # cos, sin = position_embeddings
493
+ query_states, key_states = apply_rotary_pos_emb(
494
+ query_states, key_states, cos, sin)
495
+
496
+ # query_states, key_states = apply_rotary_pos_emb(query_norm_states, key_norm_states, cos, sin)
497
+ if past_key_value is not None:
498
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
499
+ cache_kwargs = {
500
+ "sin": sin,
501
+ "cos": cos,
502
+ "cache_position": cache_position
503
+ }
504
+ key_states, value_states = past_key_value.update(
505
+ key_states, value_states, self.layer_idx, cache_kwargs)
506
+
507
+ attention_interface: Callable = eager_attention_forward
508
+ # TODO: considering FP8;
509
+ # RuntimeError: Expected attn_mask dtype to be bool or float or to match query dtype,
510
+ # but got attn_mask.dtype: long int and query.dtype: c10::BFloat16 instead.
511
+ if self.config._attn_implementation != "eager":
512
+ attention_interface = ALL_ATTENTION_FUNCTIONS[
513
+ self.config._attn_implementation]
514
+
515
+ attn_output, attn_weights = attention_interface(
516
+ self,
517
+ query_states,
518
+ key_states,
519
+ value_states,
520
+ attention_mask,
521
+ dropout=0.0 if not self.training else self.attention_dropout,
522
+ scaling=self.scaling,
523
+ sliding_window=self.sliding_window, # main diff with Llama
524
+ **kwargs,
525
+ )
526
+ attn_output = attn_output.reshape(*input_shape, -1)
527
+ if self.use_head_wise_attn_gate:
528
+ output = attn_output.view(
529
+ *attn_output.shape[:-1], self.num_attention_heads,
530
+ self.head_dim) * gate_states.unsqueeze(-1).sigmoid()
531
+ attn_output = output.view(*attn_output.shape)
532
+ attn_output = self.o_proj(attn_output)
533
+
534
+ return attn_output, attn_weights
535
+
536
+
537
+ class Step3p5DecoderLayer(GradientCheckpointingLayer):
538
+
539
+ def __init__(self, config, layer_idx):
540
+ super().__init__()
541
+ self.hidden_size = config.hidden_size
542
+ self.layer_idx = layer_idx
543
+ self.self_attn = Step3p5Attention(config, layer_idx)
544
+ self.attention_type = config.layer_types[layer_idx]
545
+
546
+ moe_layers_enum = getattr(config, "moe_layers_enum", None)
547
+ if moe_layers_enum is not None:
548
+ moe_layers_idx = [
549
+ int(i) for i in moe_layers_enum.strip().split(',')
550
+ ]
551
+ else:
552
+ moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
553
+ self.is_moe_layer = layer_idx in moe_layers_idx
554
+ self.use_moe = False
555
+
556
+ if config.swiglu_limits_shared and config.swiglu_limits_shared[
557
+ layer_idx] is not None and config.swiglu_limits_shared[
558
+ layer_idx] != 0:
559
+ swiglu_limit_shared = config.swiglu_limits_shared[layer_idx]
560
+ else:
561
+ swiglu_limit_shared = None
562
+ if config.swiglu_limits and config.swiglu_limits[
563
+ layer_idx] is not None and config.swiglu_limits[layer_idx] != 0:
564
+ swiglu_limit = config.swiglu_limits[layer_idx]
565
+ else:
566
+ swiglu_limit = None
567
+ if self.is_moe_layer:
568
+ self.moe = Step3p5MoEMLP(config, swiglu_limit=swiglu_limit) #
569
+ self.share_expert = Step3p5MLP(
570
+ config,
571
+ intermediate_size=config.share_expert_dim,
572
+ swiglu_limit=swiglu_limit_shared)
573
+ self.use_moe = True
574
+ else:
575
+ self.mlp = Step3p5MLP(config,
576
+ intermediate_size=config.intermediate_size,
577
+ swiglu_limit=swiglu_limit_shared)
578
+
579
+ self.input_layernorm = Step3p5RMSNorm(
580
+ config.hidden_size,
581
+ eps=config.rms_norm_eps)
582
+ self.post_attention_layernorm = Step3p5RMSNorm(
583
+ config.hidden_size,
584
+ eps=config.rms_norm_eps)
585
+
586
+ def forward(
587
+ self,
588
+ hidden_states: torch.Tensor,
589
+ attention_mask: Optional[torch.Tensor] = None,
590
+ position_ids: Optional[torch.LongTensor] = None,
591
+ past_key_value: Optional[tuple[torch.Tensor]] = None,
592
+ cache_position: Optional[torch.LongTensor] = None,
593
+ **kwargs: Unpack[FlashAttentionKwargs],
594
+ ) -> torch.FloatTensor:
595
+ residual = hidden_states
596
+ hidden_states = self.input_layernorm(hidden_states)
597
+ hidden_states, _ = self.self_attn(
598
+ hidden_states=hidden_states,
599
+ attention_mask=attention_mask,
600
+ position_ids=position_ids,
601
+ past_key_value=past_key_value,
602
+ cache_position=cache_position,
603
+ **kwargs,
604
+ )
605
+ hidden_states = residual + hidden_states
606
+
607
+ # Fully Connected
608
+ residual = hidden_states
609
+ hidden_states = self.post_attention_layernorm(hidden_states)
610
+ if self.use_moe:
611
+ share_output = self.share_expert(hidden_states)
612
+ moe_output = self.moe(hidden_states)
613
+ ffn_output = moe_output + share_output
614
+ else:
615
+ ffn_output = self.mlp(hidden_states)
616
+ if isinstance(ffn_output, tuple):
617
+ hidden_states, _ = ffn_output
618
+ else:
619
+ hidden_states = ffn_output
620
+
621
+ hidden_states = residual + hidden_states
622
+ return hidden_states
623
+
624
+
625
+ class Step3p5PreTrainedModel(PreTrainedModel):
626
+ # Link this model family to its configuration class so PreTrainedModel.from_pretrained
627
+ # can load the config instead of failing with a NoneType error.
628
+ config_class = Step3p5Config
629
+ supports_gradient_checkpointing = True
630
+ _skip_keys_device_placement = ["past_key_values"]
631
+ _keys_to_ignore_on_load_unexpected = [
632
+ r"model\.layers\.45\.*",
633
+ r"model\.layers\.46\.*",
634
+ r"model\.layers\.47\.*"
635
+ ]
636
+ _supports_flash_attn = False
637
+ _supports_sdpa = True
638
+ _supports_flex_attn = True
639
+ _supports_static_cache = True
640
+ _supports_attention_backend = True
641
+
642
+
643
+ class Step3p5Model(Step3p5PreTrainedModel, GenerationMixin):
644
+ _no_split_modules = ["Step3p5DecoderLayer"]
645
+ base_model_prefix = "model"
646
+ _tied_weights_keys = ["lm_head.weight"]
647
+ config: Step3p5Config
648
+ def __init__(self, config: Step3p5Config):
649
+ super().__init__(config)
650
+ self.padding_idx = config.pad_token_id
651
+ self.vocab_size = config.vocab_size
652
+
653
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
654
+ self.padding_idx)
655
+ self.layers = nn.ModuleList([
656
+ Step3p5DecoderLayer(config, layer_idx)
657
+ for layer_idx in range(config.num_hidden_layers)
658
+ ])
659
+ self.norm = Step3p5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
660
+ self.gradient_checkpointing = False
661
+ self.has_sliding_layers = "sliding_attention" in self.config.layer_types
662
+
663
+ # Initialize weights and apply final processing
664
+ self.post_init()
665
+
666
+ def get_input_embeddings(self):
667
+ return self.embed_tokens
668
+
669
+ @can_return_tuple
670
+ def forward(
671
+ self,
672
+ input_ids: torch.LongTensor = None,
673
+ attention_mask: Optional[torch.Tensor] = None,
674
+ position_ids: Optional[torch.LongTensor] = None,
675
+ past_key_values: Optional[Cache] = None,
676
+ inputs_embeds: Optional[torch.FloatTensor] = None,
677
+ use_cache: Optional[bool] = None,
678
+ output_attentions: Optional[bool] = None,
679
+ output_hidden_states: Optional[bool] = None,
680
+ return_dict: Optional[bool] = None,
681
+ cache_position: Optional[torch.LongTensor] = None,
682
+ **kwargs: Unpack[TransformersKwargs],
683
+ ) -> Union[tuple, BaseModelOutputWithPast]:
684
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
685
+ output_hidden_states = (output_hidden_states
686
+ if output_hidden_states is not None else
687
+ self.config.output_hidden_states)
688
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
689
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
690
+ if (input_ids is None) ^ (inputs_embeds is not None):
691
+ raise ValueError(
692
+ "You must specify exactly one of input_ids or inputs_embeds")
693
+
694
+ if self.gradient_checkpointing and self.training and use_cache:
695
+ logger.warning_once(
696
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
697
+ )
698
+ use_cache = False
699
+
700
+ if inputs_embeds is None:
701
+ inputs_embeds = self.embed_tokens(
702
+ input_ids.to(self.embed_tokens.weight.device))
703
+
704
+ if use_cache and past_key_values is None:
705
+ past_key_values = DynamicCache()
706
+
707
+ if cache_position is None:
708
+ past_seen_tokens = past_key_values.get_seq_length(
709
+ ) if past_key_values is not None else 0
710
+ cache_position = torch.arange(past_seen_tokens,
711
+ past_seen_tokens +
712
+ inputs_embeds.shape[1],
713
+ device=inputs_embeds.device)
714
+
715
+ if position_ids is None:
716
+ position_ids = cache_position.unsqueeze(0)
717
+
718
+ hidden_states = inputs_embeds
719
+
720
+ # It may already have been prepared by e.g. `generate`
721
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
722
+ # Prepare mask arguments
723
+ mask_kwargs = {
724
+ "config": self.config,
725
+ "input_embeds": inputs_embeds,
726
+ "attention_mask": attention_mask,
727
+ "cache_position": cache_position,
728
+ "past_key_values": past_key_values,
729
+ "position_ids": position_ids,
730
+ }
731
+ # Create the masks
732
+ causal_mask_mapping = {
733
+ "full_attention": create_causal_mask(**mask_kwargs),
734
+ }
735
+
736
+ # The sliding window alternating layers are not always activated depending on the config
737
+ if self.has_sliding_layers:
738
+ causal_mask_mapping[
739
+ "sliding_attention"] = create_sliding_window_causal_mask(
740
+ **mask_kwargs)
741
+
742
+ # # create position embeddings to be shared across the decoder layers
743
+ # decoder layers
744
+ all_hidden_states = () if output_hidden_states else None
745
+ all_self_attns = () if output_attentions else None
746
+ for decoder_layer in self.layers[:self.config.num_hidden_layers]:
747
+ if output_hidden_states:
748
+ all_hidden_states += (hidden_states, )
749
+
750
+ layer_outputs = decoder_layer(
751
+ hidden_states,
752
+ attention_mask=causal_mask_mapping[
753
+ decoder_layer.attention_type],
754
+ position_ids=position_ids,
755
+ past_key_value=past_key_values,
756
+ output_attentions=output_attentions,
757
+ use_cache=use_cache,
758
+ cache_position=cache_position,
759
+ **kwargs,
760
+ )
761
+
762
+ hidden_states = layer_outputs
763
+
764
+ hidden_states = self.norm(hidden_states)
765
+
766
+ return BaseModelOutputWithPast(
767
+ last_hidden_state=hidden_states,
768
+ past_key_values=past_key_values if use_cache else None,
769
+ hidden_states=all_hidden_states,
770
+ attentions=all_self_attns,
771
+ )
772
+
773
+
774
+ class Step3p5ForCausalLM(Step3p5PreTrainedModel, GenerationMixin):
775
+ _tied_weights_keys = ["lm_head.weight"]
776
+ config: Step3p5Config
777
+
778
+ def __init__(self, config: Step3p5Config):
779
+ super().__init__(config)
780
+ self.model = Step3p5Model(config)
781
+ self.lm_head = nn.Linear(config.hidden_size,
782
+ config.vocab_size,
783
+ bias=False)
784
+
785
+ self.post_init()
786
+
787
+ def get_input_embeddings(self):
788
+ return self.model.get_input_embeddings()
789
+
790
+ def set_input_embeddings(self, value):
791
+ self.model.set_input_embeddings(value)
792
+
793
+ def get_output_embeddings(self):
794
+ return self.model.get_output_embeddings()
795
+
796
+ def set_output_embeddings(self, new_embeddings):
797
+ self.model.set_output_embeddings(new_embeddings)
798
+
799
+ def set_decoder(self, decoder):
800
+ self.model.set_decoder(decoder)
801
+
802
+ def get_decoder(self):
803
+ return self.model.get_decoder()
804
+
805
+ def forward(
806
+ self,
807
+ input_ids: torch.LongTensor = None,
808
+ num_patches=None,
809
+ patch_pixel_values=None,
810
+ patch_newline_mask=None,
811
+ attention_mask: Optional[torch.Tensor] = None,
812
+ position_ids: Optional[torch.LongTensor] = None,
813
+ past_key_values: Optional[Cache] = None,
814
+ inputs_embeds: Optional[torch.FloatTensor] = None,
815
+ labels: Optional[torch.LongTensor] = None,
816
+ use_cache: Optional[bool] = None,
817
+ output_attentions: Optional[bool] = None,
818
+ output_hidden_states: Optional[bool] = None,
819
+ return_dict: Optional[bool] = None,
820
+ cache_position: Optional[torch.LongTensor] = None,
821
+ **kwargs: Unpack[TransformersKwargs],
822
+ ) -> Union[tuple, Step3p5CausalLMOutputWithPast]:
823
+ r"""
824
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
825
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
826
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
827
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
828
+ Example:
829
+ ```python
830
+ >>> from transformers import AutoTokenizer, Llama4ForCausalLM
831
+ >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
832
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")
833
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
834
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
835
+ >>> # Generate
836
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
837
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
838
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
839
+ ```"""
840
+
841
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
842
+ output_hidden_states = (output_hidden_states
843
+ if output_hidden_states is not None else
844
+ self.config.output_hidden_states)
845
+ # breakpoint()
846
+ outputs = self.model(
847
+ input_ids=input_ids,
848
+ num_patches=num_patches,
849
+ patch_pixel_values=patch_pixel_values,
850
+ patch_newline_mask=patch_newline_mask,
851
+ position_ids=position_ids,
852
+ attention_mask=attention_mask,
853
+ past_key_values=past_key_values,
854
+ inputs_embeds=inputs_embeds,
855
+ use_cache=use_cache,
856
+ output_attentions=output_attentions,
857
+ output_hidden_states=output_hidden_states,
858
+ return_dict=return_dict,
859
+ cache_position=cache_position,
860
+ **kwargs,
861
+ )
862
+ hidden_states = outputs.last_hidden_state
863
+ logits = self.lm_head(hidden_states)
864
+
865
+ return Step3p5CausalLMOutputWithPast(logits=logits, )
866
+
867
+ def prepare_inputs_for_generation(
868
+ self,
869
+ input_ids,
870
+ past_key_values=None,
871
+ inputs_embeds=None,
872
+ pixel_values=None,
873
+ attention_mask=None,
874
+ cache_position=None,
875
+ logits_to_keep=None,
876
+ **kwargs,
877
+ ):
878
+
879
+ model_inputs = super().prepare_inputs_for_generation(
880
+ input_ids,
881
+ past_key_values=past_key_values,
882
+ inputs_embeds=inputs_embeds,
883
+ attention_mask=attention_mask,
884
+ cache_position=cache_position,
885
+ logits_to_keep=logits_to_keep,
886
+ **kwargs,
887
+ )
888
+
889
+ if cache_position[0] == 0:
890
+ # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
891
+ # Otherwise we need pixel values to be passed to model
892
+ model_inputs["pixel_values"] = pixel_values
893
+
894
+ return model_inputs
895
+
896
+ def _fix_state_dict_key_on_load(self, key: str) -> tuple[str, bool]:
897
+ if key.startswith("language_model."):
898
+ return key[len("language_model."):], True
899
+
900
+ return key, False
recipe.yaml ADDED
@@ -0,0 +1,7 @@
+ default_stage:
+   default_modifiers:
+     QuantizationModifier:
+       targets: [Linear, MoELinear]
+       ignore: [lm_head, 're:visual.*', 're:.*vision_tower.*', 're:.*video_tower.*', 're:.*audio_tower.*',
+         're:.*multi_modal_projector.*']
+       scheme: NVFP4
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "bos_token": {
+     "content": "<|begin▁of▁sentence|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "<|im_end|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|end▁of▁sentence|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff