Add NVFP4 quantized checkpoint
- README.md +27 -0
- chat_template.jinja +80 -0
- config.json +376 -0
- configuration_step3p5.py +59 -0
- generation_config.json +11 -0
- model-00001-of-00026.safetensors +3 -0
- model-00002-of-00026.safetensors +3 -0
- model-00003-of-00026.safetensors +3 -0
- model-00004-of-00026.safetensors +3 -0
- model-00005-of-00026.safetensors +3 -0
- model-00006-of-00026.safetensors +3 -0
- model-00007-of-00026.safetensors +3 -0
- model-00008-of-00026.safetensors +3 -0
- model-00009-of-00026.safetensors +3 -0
- model-00010-of-00026.safetensors +3 -0
- model-00011-of-00026.safetensors +3 -0
- model-00012-of-00026.safetensors +3 -0
- model-00013-of-00026.safetensors +3 -0
- model-00014-of-00026.safetensors +3 -0
- model-00015-of-00026.safetensors +3 -0
- model-00016-of-00026.safetensors +3 -0
- model-00017-of-00026.safetensors +3 -0
- model-00018-of-00026.safetensors +3 -0
- model-00019-of-00026.safetensors +3 -0
- model-00020-of-00026.safetensors +3 -0
- model-00021-of-00026.safetensors +3 -0
- model-00022-of-00026.safetensors +3 -0
- model-00023-of-00026.safetensors +3 -0
- model-00024-of-00026.safetensors +3 -0
- model-00025-of-00026.safetensors +3 -0
- model-00026-of-00026.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_step3p5.py +900 -0
- recipe.yaml +7 -0
- special_tokens_map.json +23 -0
- tokenizer.json +0 -0
- tokenizer_config.json +0 -0
README.md
ADDED
@@ -0,0 +1,27 @@
---
datasets:
- Rombo-Org/Optimized_Reasoning
base_model:
- stepfun-ai/Step-3.5-Flash
tags:
- nvfp4
- fp4
- quantized
---
# Step-3.5-Flash-nvfp4

**Format:** NVFP4: weights & activations quantized to FP4 with dual scaling.
**Base model:** `stepfun-ai/Step-3.5-Flash`
**How it was made:** One-shot calibration with LLM Compressor (NVFP4 recipe), using long-sequence calibration (1 sample of length 512) from Rombo-Org/Optimized_Reasoning.

> Notes: Keep `lm_head` in high precision; calibrate on long, domain-relevant sequences.

Check the original model card for information about this model.

# Running the model with vLLM in Docker
```sh
sudo docker run --runtime nvidia --gpus all -p 8000:8000 --ipc=host vllm/vllm-openai:nightly --model Firworks/Step-3.5-Flash-nvfp4 --dtype auto --max-model-len 32768
```
This was tested on an RTX Pro 6000 Blackwell cloud instance.

If there are other models you're interested in seeing quantized to NVFP4 for use on the DGX Spark or other modern Blackwell (or newer) cards, let me know. I'm trying to make more NVFP4 models available so that more people can try them out.
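Once the container above is up, the server exposes the OpenAI-compatible API on port 8000. A minimal client sketch; the model name matches the `--model` flag above, and the `api_key` value is arbitrary since the server as launched does not check it:

```python
from openai import OpenAI

# Point the standard OpenAI client at the local vLLM server started above.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="Firworks/Step-3.5-Flash-nvfp4",
    messages=[{"role": "user", "content": "Summarize what NVFP4 quantization is."}],
    max_tokens=256,
)
print(resp.choices[0].message.content)
```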
chat_template.jinja
ADDED
@@ -0,0 +1,80 @@
{% macro render_content(content) %}{% if content is none %}{{- '' }}{% elif content is string %}{{- content }}{% elif content is mapping %}{{- content['value'] if 'value' in content else content['text'] }}{% elif content is iterable %}{% for item in content %}{% if item.type == 'text' %}{{- item['value'] if 'value' in item else item['text'] }}{% elif item.type == 'image' %}<im_patch>{% endif %}{% endfor %}{% endif %}{% endmacro %}
{{bos_token}}{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{{- render_content(messages[0].content) + '\n\n' }}
{%- endif %}
{{- "# Tools\n\nYou have access to the following functions in JSONSchema format:\n\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson(ensure_ascii=False) }}
{%- endfor %}
{{- "\n</tools>\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...>\n...\n</function> block must be nested within <tool_call>\n...\n</tool_call> XML tags\n- Required parameters MUST be specified\n</IMPORTANT><|im_end|>\n" }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + render_content(messages[0].content) + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if ns.multi_step_tool and message.role == "user" and render_content(message.content) is string and not(render_content(message.content).startswith('<tool_response>') and render_content(message.content).endswith('</tool_response>')) %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- endfor %}
{%- for message in messages %}
{%- set content = render_content(message.content) %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{%- set role_name = 'observation' if (message.role == "system" and not loop.first and message.name == 'observation') else message.role %}
{{- '<|im_start|>' + role_name + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{%- if message.reasoning_content is string %}
{%- set reasoning_content = render_content(message.reasoning_content) %}
{%- else %}
{%- if '</think>' in content %}
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- else %}
{%- set reasoning_content = '' %}
{%- endif %}
{%- endif %}
{%- if loop.index0 > ns.last_query_index %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n' + content }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
{%- if tool_call.arguments is defined %}
{%- set arguments = tool_call.arguments %}
{%- for args_name, args_value in arguments|items %}
{{- '<parameter=' + args_name + '>\n' }}
{%- set args_value = args_value | tojson(ensure_ascii=False) | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
{{- args_value }}
{{- '\n</parameter>\n' }}
{%- endfor %}
{%- endif %}
{{- '</function>\n</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>tool_response\n' }}
{%- endif %}
{{- '<tool_response>' }}
{{- content }}
{{- '</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n<think>\n' }}
{%- endif %}
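For reference, here is a rough sketch of how this template gets exercised through `transformers`; the repo id and the example messages are illustrative, not part of this commit:

```python
from transformers import AutoTokenizer

# The checkpoint ships custom code, so trust_remote_code is needed.
tok = AutoTokenizer.from_pretrained("Firworks/Step-3.5-Flash-nvfp4", trust_remote_code=True)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What does NVFP4 mean?"},
]
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# With add_generation_prompt=True the string ends in '<|im_start|>assistant\n<think>\n',
# matching the final block of the template above.
print(prompt)
```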
config.json
ADDED
@@ -0,0 +1,376 @@
{
  "architectures": [
    "Step3p5ForCausalLM"
  ],
  "att_impl_type": "GQA",
  "attention_other_setting": {
    "attention_type": "sliding_attention",
    "head_dim": 128,
    "num_attention_groups": 8,
    "num_attention_heads": 96,
    "true_head_dim": 128
  },
  "auto_map": {
    "AutoConfig": "configuration_step3p5.Step3p5Config",
    "AutoModelForCausalLM": "modeling_step3p5.Step3p5ForCausalLM"
  },
  "bos_token_id": 0,
  "dtype": "bfloat16",
  "eos_token_id": [1, 2, 128007],
  "head_dim": 128,
  "hidden_size": 4096,
  "intermediate_size": 11264,
  "layer_types": [
    "full_attention", "sliding_attention", "sliding_attention", "sliding_attention",
    "full_attention", "sliding_attention", "sliding_attention", "sliding_attention",
    "full_attention", "sliding_attention", "sliding_attention", "sliding_attention",
    "full_attention", "sliding_attention", "sliding_attention", "sliding_attention",
    "full_attention", "sliding_attention", "sliding_attention", "sliding_attention",
    "full_attention", "sliding_attention", "sliding_attention", "sliding_attention",
    "full_attention", "sliding_attention", "sliding_attention", "sliding_attention",
    "full_attention", "sliding_attention", "sliding_attention", "sliding_attention",
    "full_attention", "sliding_attention", "sliding_attention", "sliding_attention",
    "full_attention", "sliding_attention", "sliding_attention", "sliding_attention",
    "full_attention", "sliding_attention", "sliding_attention", "sliding_attention",
    "full_attention", "sliding_attention", "sliding_attention", "sliding_attention"
  ],
  "max_position_embeddings": 262144,
  "max_seq_len": 262144,
  "model_type": "step3p5",
  "moe_every_n_layer": 1,
  "moe_intermediate_size": 1280,
  "moe_layer_offset": 0,
  "moe_layers_enum": "3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44",
  "moe_num_experts": 288,
  "moe_router_activation": "sigmoid",
  "moe_router_scaling_factor": 3.0,
  "moe_top_k": 8,
  "need_fp32_gate": true,
  "norm_expert_weight": true,
  "num_attention_groups": 8,
  "num_attention_heads": 64,
  "num_hidden_layers": 45,
  "num_nextn_predict_layers": 3,
  "partial_rotary_factor": 0.5,
  "partial_rotary_factors": [
    0.5, 1.0, 1.0, 1.0, 0.5, 1.0, 1.0, 1.0, 0.5, 1.0, 1.0, 1.0,
    0.5, 1.0, 1.0, 1.0, 0.5, 1.0, 1.0, 1.0, 0.5, 1.0, 1.0, 1.0,
    0.5, 1.0, 1.0, 1.0, 0.5, 1.0, 1.0, 1.0, 0.5, 1.0, 1.0, 1.0,
    0.5, 1.0, 1.0, 1.0, 0.5, 1.0, 1.0, 1.0, 0.5, 1.0, 1.0, 1.0
  ],
  "quantization_config": {
    "config_groups": {
      "group_0": {
        "format": "nvfp4-pack-quantized",
        "input_activations": {
          "actorder": null,
          "block_structure": null,
          "dynamic": "local",
          "group_size": 16,
          "num_bits": 4,
          "observer": "static_minmax",
          "observer_kwargs": {},
          "scale_dtype": "torch.float8_e4m3fn",
          "strategy": "tensor_group",
          "symmetric": true,
          "type": "float",
          "zp_dtype": null
        },
        "output_activations": null,
        "targets": [
          "Linear",
          "MoELinear"
        ],
        "weights": {
          "actorder": null,
          "block_structure": null,
          "dynamic": false,
          "group_size": 16,
          "num_bits": 4,
          "observer": "static_minmax",
          "observer_kwargs": {},
          "scale_dtype": "torch.float8_e4m3fn",
          "strategy": "tensor_group",
          "symmetric": true,
          "type": "float",
          "zp_dtype": null
        }
      }
    },
    "format": "nvfp4-pack-quantized",
    "global_compression_ratio": null,
    "ignore": [
      "lm_head"
    ],
    "kv_cache_scheme": null,
    "quant_method": "compressed-tensors",
    "quantization_status": "compressed",
    "sparsity_config": {},
    "transform_config": {},
    "version": "0.13.0"
  },
  "rms_norm_eps": 1e-05,
  "rope_parameters": {
    "factor": 2.0,
    "high_freq_factor": 32.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 131072,
    "rope_type": "llama3"
  },
  "rope_scaling": {
    "factor": 2.0,
    "high_freq_factor": 32.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 131072,
    "rope_type": "llama3"
  },
  "rope_theta": [
    5000000.0, 10000.0, 10000.0, 10000.0, 5000000.0, 10000.0, 10000.0, 10000.0,
    5000000.0, 10000.0, 10000.0, 10000.0, 5000000.0, 10000.0, 10000.0, 10000.0,
    5000000.0, 10000.0, 10000.0, 10000.0, 5000000.0, 10000.0, 10000.0, 10000.0,
    5000000.0, 10000.0, 10000.0, 10000.0, 5000000.0, 10000.0, 10000.0, 10000.0,
    5000000.0, 10000.0, 10000.0, 10000.0, 5000000.0, 10000.0, 10000.0, 10000.0,
    5000000.0, 10000.0, 10000.0, 10000.0, 5000000.0, 10000.0, 10000.0, 10000.0
  ],
  "share_expert_dim": 1280,
  "sink": false,
  "sliding_window": 512,
  "swiglu_limits": [
    0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 7, 7, 0.0, 0.0, 0.0
  ],
  "swiglu_limits_shared": [
    0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 16, 0.0, 0.0, 0.0
  ],
  "transformers_version": "4.57.3",
  "use_cache": false,
  "use_head_wise_attn_gate": true,
  "use_moe": true,
  "use_moe_router_bias": true,
  "use_qk_norm": true,
  "use_rope_layers": [],
  "vocab_size": 128896,
  "yarn_only_types": [
    "full_attention"
  ],
  "zero_centered": true
}
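The `quantization_config` block above is the compressed-tensors metadata that LLM Compressor writes for its NVFP4 scheme: FP4 weights and input activations, group size 16, FP8 (`float8_e4m3fn`) group scales, with `lm_head` left unquantized. Below is a rough sketch of the kind of one-shot call that produces such a checkpoint; the exact quantization script is not part of this commit, argument handling varies between llm-compressor versions, and the calibration-data plumbing is simplified:

```python
from datasets import load_dataset
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "stepfun-ai/Step-3.5-Flash"

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# NVFP4: FP4 weights + activations with group_size=16 and FP8 scales; keep lm_head in high precision.
recipe = QuantizationModifier(targets=["Linear"], scheme="NVFP4", ignore=["lm_head"])

# Calibration data: the README mentions 1 sample of length 512 from Rombo-Org/Optimized_Reasoning.
calib = load_dataset("Rombo-Org/Optimized_Reasoning", split="train")

oneshot(
    model=model,
    dataset=calib,
    recipe=recipe,
    max_seq_length=512,
    num_calibration_samples=1,
)

model.save_pretrained("Step-3.5-Flash-nvfp4", save_compressed=True)
tokenizer.save_pretrained("Step-3.5-Flash-nvfp4")
```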
configuration_step3p5.py
ADDED
@@ -0,0 +1,59 @@
from typing import Any, Optional, Union

from transformers.configuration_utils import PretrainedConfig


class Step3p5Config(PretrainedConfig):
    model_type = "step3p5"
    architectures = ["Step3p5ForCausalLM"]

    def __init__(
        self,
        hidden_size: int = 4096,
        intermediate_size: int = 11264,
        num_attention_heads: int = 64,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 45,
        max_seq_len: int = 128000,
        vocab_size: int = 128815,
        rms_norm_eps: float = 1e-5,
        moe_intermediate_size: int = 1280,
        moe_num_experts: int = 288,
        moe_top_k: int = 8,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 128000,
        share_expert_dims: int = 1280,
        head_dim: int = 128,
        norm_expert_weight: bool = True,
        layer_types: list[str] = None,
        sliding_window: Optional[int] = None,
        moe_layers_enum: tuple[int] = (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
                                       15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
                                       25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
                                       35, 36, 37, 38, 39, 40, 41, 42, 43, 44),
        **kwargs,
    ) -> None:
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_hidden_layers = num_hidden_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.max_position_embeddings = max_position_embeddings
        self.share_expert_dim = share_expert_dims
        self.head_dim = head_dim
        self.norm_expert_weight = norm_expert_weight
        self.moe_layers_enum = moe_layers_enum
        self.layer_types = layer_types
        self.sliding_window = sliding_window
        super().__init__(**kwargs)
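Since `config.json` maps `AutoConfig` to this class via `auto_map`, it is picked up automatically when loading with remote code enabled. An illustrative check (repo id assumed from the README):

```python
from transformers import AutoConfig

# Resolves to Step3p5Config through the auto_map entry in config.json.
cfg = AutoConfig.from_pretrained("Firworks/Step-3.5-Flash-nvfp4", trust_remote_code=True)
print(cfg.model_type, cfg.num_hidden_layers, cfg.moe_num_experts, cfg.moe_top_k)
```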
generation_config.json
ADDED
@@ -0,0 +1,11 @@
{
  "_from_model_config": true,
  "bos_token_id": 0,
  "do_sample": true,
  "eos_token_id": [
    1,
    2,
    128007
  ],
  "transformers_version": "4.57.3"
}
model-00001-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a8a4116f97b83f5272cbcf4a14d879bd4cf7cced6cfb1f94c9959e22deee32af
size 4967057968
model-00002-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:00d3c857834bae585f823926408406efe936f96bc96a187bd166f2ea9e984852
size 4388928280
model-00003-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:72845c5cd1603b4642f0ffc8e0f0202e3ffd0365f7dade7dff1c9b4953d43e95
size 4317831672
model-00004-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6f137d8752377ca68fa1c7982528bef8caeeb0141767c40477a7d0f8a005978b
size 4369980176
model-00005-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:589470dfae75ba75232ad47f0f43399ac7a376c60e30c23fd667107793a57c8b
size 4388928400
model-00006-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:78f23487e3bab1f0048b335a3138e15533cac08fc1539247bcae8cbef3d059e6
size 4298883640
model-00007-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0f97bd8e16073c02428151b0171fbacb6579f3d340022b958f17f81c877b2c50
size 4388928376
model-00008-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:402fa96eea50c25016c1751d0893718fcf7097072e60f7b4bac39ce13086a8d6
size 4369980288
model-00009-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d179a8923dab1deb79a13de991191313ef261794fcb96fd0371fd24dcd99c44a
size 4317831736
model-00010-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:92eb378130807584c9f8f2a04b29115205f36c21c49344616a3272d81a3adaaf
size 4388928376
model-00011-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:138c4db286faca86fb22d534d68c90922eb6d9a9f3f6544cbd2cba8756b0e419
size 4369980288
model-00012-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1d7460bbc4db7812de1559686a83d22502c7bcff90fb383429bd2ec82a3fd9b5
size 4317831736
model-00013-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cee0b32bd83e9bae3830457deee6ec5f0b78f11bf5bf70f4d9726711cca97ad0
size 4369980280
model-00014-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:471cab510b6c6bc4c5ca8473542548e0f33556e42dc93fc36148c4da04e97d44
size 4388928384
model-00015-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c058151516351df8dd559daa6322842dcfa3358aa8d009dcd7eb5c84acad571e
size 4317831736
model-00016-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e0bb9ecac2ae3f3c1458c15b3ea09bd6c77102e44666f579d4d5f57de3f48bb0
size 4369980280
model-00017-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c66c30f7e9fb6af842634957b92791caacff7a191ead59a9b768da0fc9eb72df
size 4388928384
model-00018-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7902621aa51b9eaa855da55e89f17a265e3b0351691d58a712eb16fea66e2b88
size 4298883640
model-00019-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:92ae180ac81f7e402b7ddee3f6f51399e9dbe7f684e5399f53dd395a74811e16
size 4388928376
model-00020-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4a5b1b25ee6d4f5a932a50cd17868b81b157de890d8ab47da33af620e45cb491
size 4369980288
model-00021-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a2cbee091b139efe89c3f4aaade8b0f1840640a792a47c484a3a7612b27f814f
size 4317831736
model-00022-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1489fde66e452704854197f8dc833ce9e5f51275addec07d6293728e3100f8c8
size 4388928376
model-00023-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:285f76218c43fce239d54372e2e57040c0ccf7be7992abe6c6709ef656feb3bc
size 4369980288
model-00024-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:17f9c6fd63a52aab3560d9c0e26dc76457527c049b2e488119a1c406983c40b3
size 4317831736
model-00025-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1e037a341d982cc982d4e21c8f3c2c9a1ec48712c6403f080e44b22875b22ad1
size 4369980280
model-00026-of-00026.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3b2542bce70caefe00ef78727dd0adf1c4def65fc3000b342e0401c0abe8c576
size 2763483928
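The 26 shard entries above are Git LFS pointer files: each records only the SHA-256 and byte size of a shard (roughly 112 GB in total). To fetch the actual weights rather than the pointers, something like the following works; the repo id is taken from the README:

```python
from huggingface_hub import snapshot_download

# Download the resolved shards plus config/tokenizer/code files into the local HF cache.
local_dir = snapshot_download(
    repo_id="Firworks/Step-3.5-Flash-nvfp4",
    allow_patterns=["*.safetensors", "*.json", "*.py", "*.jinja"],
)
print(local_dir)
```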
model.safetensors.index.json
ADDED
The diff for this file is too large to render.
modeling_step3p5.py
ADDED
|
@@ -0,0 +1,900 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from typing import Callable, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
from transformers.activations import ACT2FN
|
| 22 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 23 |
+
from transformers.generation import GenerationMixin
|
| 24 |
+
from transformers.masking_utils import (create_causal_mask,
|
| 25 |
+
create_sliding_window_causal_mask)
|
| 26 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 27 |
+
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 28 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
| 29 |
+
from transformers.modeling_rope_utils import (ROPE_INIT_FUNCTIONS,
|
| 30 |
+
dynamic_rope_update)
|
| 31 |
+
from transformers.modeling_utils import (ALL_ATTENTION_FUNCTIONS,
|
| 32 |
+
PreTrainedModel)
|
| 33 |
+
from transformers.processing_utils import Unpack
|
| 34 |
+
from transformers.utils import TransformersKwargs, can_return_tuple, logging
|
| 35 |
+
|
| 36 |
+
from .configuration_step3p5 import Step3p5Config
|
| 37 |
+
|
| 38 |
+
logger = logging.get_logger(__name__)
|
| 39 |
+
|
| 40 |
+
__all__ = ["Step3p5Model", "Step3p5ForCausalLM"]
|
| 41 |
+
|
| 42 |
+
class Step3p5RotaryEmbedding(nn.Module):
|
| 43 |
+
|
| 44 |
+
def __init__(self, config: Step3p5Config, device=None, layer_idx=None):
|
| 45 |
+
super().__init__()
|
| 46 |
+
# BC: "rope_type" was originally "type"
|
| 47 |
+
self.layer_idx = layer_idx
|
| 48 |
+
if config.rope_parameters is not None:
|
| 49 |
+
self.rope_type = config.rope_parameters.get(
|
| 50 |
+
"rope_type", config.rope_parameters.get("type"))
|
| 51 |
+
else:
|
| 52 |
+
self.rope_type = "default"
|
| 53 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 54 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 55 |
+
|
| 56 |
+
partial_rotary_factors = getattr(config, "partial_rotary_factors",
|
| 57 |
+
None)
|
| 58 |
+
if partial_rotary_factors is not None:
|
| 59 |
+
config.partial_rotary_factor = partial_rotary_factors[
|
| 60 |
+
self.layer_idx]
|
| 61 |
+
else:
|
| 62 |
+
config.partial_rotary_factor = 1.0
|
| 63 |
+
|
| 64 |
+
self.rope_theta = config.rope_theta
|
| 65 |
+
if isinstance(config.rope_theta, list):
|
| 66 |
+
self.rope_theta = config.rope_theta.copy()
|
| 67 |
+
config.rope_theta = self.rope_theta[self.layer_idx]
|
| 68 |
+
|
| 69 |
+
self.config = config
|
| 70 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 71 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(
|
| 72 |
+
self.config, device)
|
| 73 |
+
|
| 74 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 75 |
+
self.original_inv_freq = self.inv_freq
|
| 76 |
+
config.rope_theta = self.rope_theta
|
| 77 |
+
|
| 78 |
+
@torch.no_grad()
|
| 79 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 80 |
+
def forward(self, x, position_ids):
|
| 81 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
|
| 82 |
+
position_ids.shape[0], -1, 1).to(x.device)
|
| 83 |
+
position_ids_expanded = position_ids[:, None, :].float().to(x.device)
|
| 84 |
+
|
| 85 |
+
device_type = x.device.type if isinstance(
|
| 86 |
+
x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 87 |
+
with torch.autocast(device_type=device_type,
|
| 88 |
+
enabled=False): # Force float32
|
| 89 |
+
freqs = (inv_freq_expanded.float()
|
| 90 |
+
@ position_ids_expanded.float()).transpose(1, 2)
|
| 91 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 92 |
+
cos = emb.cos() * self.attention_scaling
|
| 93 |
+
sin = emb.sin() * self.attention_scaling
|
| 94 |
+
|
| 95 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def rotate_half(x):
|
| 99 |
+
"""Rotates half the hidden dims of the input."""
|
| 100 |
+
x1 = x[..., :x.shape[-1] // 2]
|
| 101 |
+
x2 = x[..., x.shape[-1] // 2:]
|
| 102 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 106 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
q (`torch.Tensor`): The query tensor.
|
| 110 |
+
k (`torch.Tensor`): The key tensor.
|
| 111 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 112 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 113 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 114 |
+
Deprecated and unused.
|
| 115 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 116 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 117 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 118 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 119 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 120 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 121 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 122 |
+
Returns:
|
| 123 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 124 |
+
"""
|
| 125 |
+
rotary_dim = cos.shape[-1]
|
| 126 |
+
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
|
| 127 |
+
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
|
| 128 |
+
|
| 129 |
+
# Apply rotary embeddings on the first half or full tensor
|
| 130 |
+
q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
|
| 131 |
+
k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
|
| 132 |
+
|
| 133 |
+
# Concatenate back to full shape
|
| 134 |
+
q_embed = torch.cat([q_embed, q_pass], dim=-1)
|
| 135 |
+
k_embed = torch.cat([k_embed, k_pass], dim=-1)
|
| 136 |
+
return q_embed, k_embed
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 140 |
+
"""
|
| 141 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 142 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 143 |
+
"""
|
| 144 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 145 |
+
if n_rep == 1:
|
| 146 |
+
return hidden_states
|
| 147 |
+
hidden_states = hidden_states[:, :,
|
| 148 |
+
None, :, :].expand(batch,
|
| 149 |
+
num_key_value_heads,
|
| 150 |
+
n_rep, slen, head_dim)
|
| 151 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
|
| 152 |
+
head_dim)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# Adapted from transformers.models.llama.modeling_llama.eager_attention_forward -> llama4 doesn't cast attn weights to fp32
|
| 156 |
+
def eager_attention_forward(
|
| 157 |
+
module: nn.Module,
|
| 158 |
+
query: torch.Tensor,
|
| 159 |
+
key: torch.Tensor,
|
| 160 |
+
value: torch.Tensor,
|
| 161 |
+
attention_mask: Optional[torch.Tensor],
|
| 162 |
+
scaling: float,
|
| 163 |
+
dropout: float = 0.0,
|
| 164 |
+
**kwargs,
|
| 165 |
+
):
|
| 166 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 167 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 168 |
+
# breakpoint()
|
| 169 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 170 |
+
if attention_mask is not None:
|
| 171 |
+
causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
|
| 172 |
+
attn_weights = attn_weights + causal_mask
|
| 173 |
+
|
| 174 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
| 175 |
+
attn_weights = nn.functional.dropout(attn_weights,
|
| 176 |
+
p=dropout,
|
| 177 |
+
training=module.training)
|
| 178 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 179 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 180 |
+
|
| 181 |
+
return attn_output, attn_weights
|
| 182 |
+
|
| 183 |
+
@dataclass
|
| 184 |
+
class Step3p5CausalLMOutputWithPast(ModelOutput):
|
| 185 |
+
r"""
|
| 186 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
| 187 |
+
Language modeling loss (for next-token prediction).
|
| 188 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
| 189 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 190 |
+
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 191 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
| 192 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
| 193 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
| 194 |
+
`past_key_values` input) to speed up sequential decoding.
|
| 195 |
+
"""
|
| 196 |
+
|
| 197 |
+
loss: Optional[torch.FloatTensor] = None
|
| 198 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 199 |
+
logits: torch.FloatTensor = None
|
| 200 |
+
past_key_values: Optional[list[torch.FloatTensor]] = None
|
| 201 |
+
hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
| 202 |
+
attentions: Optional[tuple[torch.FloatTensor]] = None
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class Step3p5MLP(nn.Module):
|
| 206 |
+
|
| 207 |
+
def __init__(self, config, intermediate_size=None, swiglu_limit=None):
|
| 208 |
+
super().__init__()
|
| 209 |
+
self.config = config
|
| 210 |
+
self.hidden_size = config.hidden_size
|
| 211 |
+
self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
|
| 212 |
+
self.gate_proj = nn.Linear(self.hidden_size,
|
| 213 |
+
self.intermediate_size,
|
| 214 |
+
bias=False)
|
| 215 |
+
self.up_proj = nn.Linear(self.hidden_size,
|
| 216 |
+
self.intermediate_size,
|
| 217 |
+
bias=False)
|
| 218 |
+
self.down_proj = nn.Linear(self.intermediate_size,
|
| 219 |
+
self.hidden_size,
|
| 220 |
+
bias=False)
|
| 221 |
+
self.act_fn = ACT2FN["silu"]
|
| 222 |
+
self.limit = swiglu_limit
|
| 223 |
+
|
| 224 |
+
def forward(self, x):
|
| 225 |
+
up = self.up_proj(x)
|
| 226 |
+
gate = self.act_fn(self.gate_proj(x))
|
| 227 |
+
if self.limit is not None:
|
| 228 |
+
gate = gate.clamp(min=None, max=self.limit)
|
| 229 |
+
up = up.clamp(min=-self.limit, max=self.limit)
|
| 230 |
+
|
| 231 |
+
return self.down_proj(gate * up)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def sigmoid_routing_function(gating_output: torch.Tensor, topk: int,
|
| 235 |
+
renormalize: bool):
|
| 236 |
+
gating_output = gating_output.float()
|
| 237 |
+
gate_prob = torch.sigmoid(gating_output)
|
| 238 |
+
gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
|
| 239 |
+
topk_prob, indices = torch.topk(gate_prob, k=topk, dim=1)
|
| 240 |
+
expert_topk_weight = topk_prob
|
| 241 |
+
if renormalize:
|
| 242 |
+
expert_topk_weight = expert_topk_weight / torch.sum(
|
| 243 |
+
expert_topk_weight, dim=-1, keepdim=True)
|
| 244 |
+
return expert_topk_weight, indices
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def softmax_routing_function(gating_output: torch.Tensor, top_k: int,
|
| 248 |
+
renormalize: bool):
|
| 249 |
+
gating_output = gating_output.float()
|
| 250 |
+
gate_prob = torch.softmax(gating_output, dim=-1)
|
| 251 |
+
gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
|
| 252 |
+
topk_prob, indices = torch.topk(gate_prob, k=top_k, dim=1)
|
| 253 |
+
expert_topk_weight = topk_prob
|
| 254 |
+
if renormalize:
|
| 255 |
+
expert_topk_weight = expert_topk_weight / torch.sum(
|
| 256 |
+
expert_topk_weight, dim=-1, keepdim=True)
|
| 257 |
+
return expert_topk_weight, indices.to(torch.int32)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
class MoELinear(nn.Module):
|
| 261 |
+
|
| 262 |
+
def __init__(self, num_experts, in_features, out_features):
|
| 263 |
+
super().__init__()
|
| 264 |
+
self.num_experts = num_experts
|
| 265 |
+
self.in_features = in_features
|
| 266 |
+
self.out_features = out_features
|
| 267 |
+
self.weight = nn.Parameter(
|
| 268 |
+
torch.empty(num_experts, out_features, in_features))
|
| 269 |
+
|
| 270 |
+
def forward(self, x, expert_id):
|
| 271 |
+
x = F.linear(x.float(), self.weight[expert_id].float())
|
| 272 |
+
return x
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class Step3p5MoEMLP(nn.Module):
|
| 276 |
+
|
| 277 |
+
def __init__(self, config, swiglu_limit=None):
|
| 278 |
+
super().__init__()
|
| 279 |
+
self.num_experts = config.moe_num_experts
|
| 280 |
+
self.top_k = config.moe_top_k
|
| 281 |
+
self.hidden_size = config.hidden_size
|
| 282 |
+
self.moe_intermediate_size = config.moe_intermediate_size
|
| 283 |
+
|
| 284 |
+
self.use_moe_router_bias = config.use_moe_router_bias
|
| 285 |
+
if self.use_moe_router_bias:
|
| 286 |
+
self.router_bias = nn.Parameter(torch.zeros(config.moe_num_experts,
|
| 287 |
+
dtype=torch.float32),
|
| 288 |
+
requires_grad=False)
|
| 289 |
+
self.custom_routing_function = self.router_bias_func
|
| 290 |
+
elif config.moe_router_activation == "sigmoid":
|
| 291 |
+
self.custom_routing_function = sigmoid_routing_function
|
| 292 |
+
else:
|
| 293 |
+
self.custom_routing_function = None
|
| 294 |
+
self.need_fp32_gate = config.need_fp32_gate
|
| 295 |
+
self.routed_scaling_factor = getattr(config,
|
| 296 |
+
"moe_router_scaling_factor", 1.0)
|
| 297 |
+
|
| 298 |
+
# gating
|
| 299 |
+
self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
|
| 300 |
+
|
| 301 |
+
self.act_fn = ACT2FN["silu"]
|
| 302 |
+
self.limit = swiglu_limit
|
| 303 |
+
|
| 304 |
+
self.up_proj = MoELinear(self.num_experts, self.hidden_size,
|
| 305 |
+
self.moe_intermediate_size)
|
| 306 |
+
self.gate_proj = MoELinear(self.num_experts, self.hidden_size,
|
| 307 |
+
self.moe_intermediate_size)
|
| 308 |
+
self.down_proj = MoELinear(self.num_experts,
|
| 309 |
+
self.moe_intermediate_size,
|
| 310 |
+
self.hidden_size)
|
| 311 |
+
|
| 312 |
+
def router_bias_func(self, gating_output: torch.Tensor, topk: int,
|
| 313 |
+
renormalize: bool):
|
| 314 |
+
gate_prob = torch.sigmoid(gating_output.float())
|
| 315 |
+
gate_prob_with_bias = gate_prob + self.router_bias.unsqueeze(0)
|
| 316 |
+
_, indices = torch.topk(gate_prob_with_bias, k=topk, dim=1)
|
| 317 |
+
topk_prob = torch.gather(gate_prob, 1, indices)
|
| 318 |
+
expert_topk_weight = topk_prob
|
| 319 |
+
if renormalize:
|
| 320 |
+
expert_topk_weight = expert_topk_weight / (
|
| 321 |
+
torch.sum(expert_topk_weight, dim=-1, keepdim=True) + 1e-20)
|
| 322 |
+
return expert_topk_weight, indices
|
| 323 |
+
|
| 324 |
+
def get_expert_output(self, inputs: torch.Tensor, expert_id):
|
| 325 |
+
#if self.limit is None:
|
| 326 |
+
up = self.up_proj(inputs, expert_id)
|
| 327 |
+
gate = self.act_fn(self.gate_proj(inputs, expert_id))
|
| 328 |
+
if self.limit is not None:
|
| 329 |
+
gate = gate.clamp(min=None, max=self.limit)
|
| 330 |
+
up = up.clamp(min=-self.limit, max=self.limit)
|
| 331 |
+
|
| 332 |
+
return self.down_proj(gate * up, expert_id)
|
| 333 |
+
|
| 334 |
+
def forward(self, hidden_states):
|
| 335 |
+
""" """
|
| 336 |
+
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
| 337 |
+
hidden_states = hidden_states.view(-1, hidden_dim)
|
| 338 |
+
if self.need_fp32_gate:
|
| 339 |
+
router_logits = torch.matmul(hidden_states.to(torch.float32), self.gate.weight.t().to(torch.float32))
|
| 340 |
+
else:
|
| 341 |
+
# router_logits: (batch * sequence_length, n_experts)
|
| 342 |
+
router_logits = self.gate(hidden_states)
|
| 343 |
+
|
| 344 |
+
if self.custom_routing_function:
|
| 345 |
+
routing_weights, selected_experts = self.custom_routing_function(
|
| 346 |
+
router_logits, self.top_k, renormalize=True)
|
| 347 |
+
else:
|
| 348 |
+
routing_weights = F.softmax(router_logits,
|
| 349 |
+
dim=1,
|
| 350 |
+
dtype=torch.float)
|
| 351 |
+
routing_weights, selected_experts = torch.topk(routing_weights,
|
| 352 |
+
self.top_k,
|
| 353 |
+
dim=-1)
|
| 354 |
+
|
| 355 |
+
routing_weights = routing_weights * self.routed_scaling_factor
|
| 356 |
+
|
| 357 |
+
final_hidden_states = torch.zeros(
|
| 358 |
+
(batch_size * sequence_length, hidden_dim),
|
| 359 |
+
dtype=hidden_states.dtype,
|
| 360 |
+
device=hidden_states.device)
|
| 361 |
+
|
| 362 |
+
# One hot encode the selected experts to create an expert mask
|
| 363 |
+
# this will be used to easily index which expert is going to be sollicitated
|
| 364 |
+
expert_mask = torch.nn.functional.one_hot(
|
| 365 |
+
selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
|
| 366 |
+
|
| 367 |
+
# Loop over all available experts in the model and perform the computation on each expert
|
| 368 |
+
for expert_idx in range(self.num_experts):
|
| 369 |
+
idx, top_x = torch.where(expert_mask[expert_idx])
|
| 370 |
+
|
| 371 |
+
# Index the correct hidden states and compute the expert hidden state for
|
| 372 |
+
# the current expert. We need to make sure to multiply the output hidden
|
| 373 |
+
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
|
| 374 |
+
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
|
| 375 |
+
current_hidden_states = (
|
| 376 |
+
self.get_expert_output(current_state, expert_idx) *
|
| 377 |
+
routing_weights[top_x, idx, None])
|
| 378 |
+
|
| 379 |
+
# However `index_add_` only support torch tensors for indexing so we'll use
|
| 380 |
+
# the `top_x` tensor here.
|
| 381 |
+
final_hidden_states.index_add_(
|
| 382 |
+
0, top_x, current_hidden_states.to(hidden_states.dtype))
|
| 383 |
+
final_hidden_states = final_hidden_states.reshape(
|
| 384 |
+
batch_size, sequence_length, hidden_dim)
|
| 385 |
+
return final_hidden_states
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
class Step3p5RMSNorm(nn.Module):
|
| 389 |
+
|
| 390 |
+
def __init__(
|
| 391 |
+
self,
|
| 392 |
+
hidden_size: int,
|
| 393 |
+
eps: float = 1e-5,
|
| 394 |
+
) -> None:
|
| 395 |
+
super().__init__()
|
| 396 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 397 |
+
self.variance_epsilon = eps
|
| 398 |
+
|
| 399 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 400 |
+
dtype = x.dtype
|
| 401 |
+
x = x.float()
|
| 402 |
+
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
| 403 |
+
normed = x * torch.rsqrt(variance + self.variance_epsilon)
|
| 404 |
+
normed = normed * (self.weight.float() + 1)
|
| 405 |
+
return normed.to(dtype)
|
| 406 |
+

class Step3p5Attention(nn.Module):

    def __init__(self, config: Step3p5Config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_attention_groups

        layer_types = getattr(config, "layer_types", [])
        if layer_types:
            enable_sliding_window = layer_types[
                self.layer_idx] == "sliding_attention"
        else:
            enable_sliding_window = self.layer_idx % 2 == 0

        if hasattr(config, "yarn_only_types") and layer_types[
                self.layer_idx] not in config.yarn_only_types:
            config.rope_parameters = None
        else:
            config.rope_parameters = getattr(config, "rope_scaling", None)

        self.sliding_window = config.sliding_window
        if enable_sliding_window:
            self.num_attention_heads = config.attention_other_setting[
                "num_attention_heads"]
            self.num_key_value_heads = config.attention_other_setting[
                "num_attention_groups"]

        if self.sliding_window is not None and enable_sliding_window:
            self.sliding_window = self.sliding_window
        else:
            self.sliding_window = None
        self.head_dim = getattr(config, "head_dim",
                                config.hidden_size // self.num_attention_heads)
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads

        self.rotary_emb = Step3p5RotaryEmbedding(config, layer_idx=layer_idx)

        self.q_size = self.num_attention_heads * self.head_dim
        self.kv_size = self.num_key_value_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.q_proj = nn.Linear(config.hidden_size, self.q_size, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
        self.o_proj = nn.Linear(self.q_size, config.hidden_size, bias=False)
        self.q_norm = Step3p5RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Step3p5RMSNorm(self.head_dim, eps=config.rms_norm_eps)

        self.use_head_wise_attn_gate = config.use_head_wise_attn_gate
        if self.use_head_wise_attn_gate:
            self.g_proj = nn.Linear(config.hidden_size,
                                    self.num_attention_heads,
                                    bias=False)

        self.use_rope = True
        use_rope_layers = getattr(config, "use_rope_layers", None)
        if use_rope_layers:
            self.use_rope = use_rope_layers[self.layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_norm(
            self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(
            self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        if self.use_head_wise_attn_gate:
            gate_states = self.g_proj(hidden_states)
        cos, sin = self.rotary_emb(hidden_states, position_ids)

        # cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin)

        # query_states, key_states = apply_rotary_pos_emb(query_norm_states, key_norm_states, cos, sin)
        if past_key_value is not None:
            # sin and cos are specific to RoPE models; position_ids needed for the static cache
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position
            }
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        # TODO: considering FP8;
        # RuntimeError: Expected attn_mask dtype to be bool or float or to match query dtype,
        # but got attn_mask.dtype: long int and query.dtype: c10::BFloat16 instead.
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[
                self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # main diff with Llama
            **kwargs,
        )
        attn_output = attn_output.reshape(*input_shape, -1)
        if self.use_head_wise_attn_gate:
            output = attn_output.view(
                *attn_output.shape[:-1], self.num_attention_heads,
                self.head_dim) * gate_states.unsqueeze(-1).sigmoid()
            attn_output = output.view(*attn_output.shape)
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights

class Step3p5DecoderLayer(GradientCheckpointingLayer):

    def __init__(self, config, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.self_attn = Step3p5Attention(config, layer_idx)
        self.attention_type = config.layer_types[layer_idx]

        moe_layers_enum = getattr(config, "moe_layers_enum", None)
        if moe_layers_enum is not None:
            moe_layers_idx = [
                int(i) for i in moe_layers_enum.strip().split(',')
            ]
        else:
            moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
        self.is_moe_layer = layer_idx in moe_layers_idx
        self.use_moe = False

        if config.swiglu_limits_shared and config.swiglu_limits_shared[
                layer_idx] is not None and config.swiglu_limits_shared[
                    layer_idx] != 0:
            swiglu_limit_shared = config.swiglu_limits_shared[layer_idx]
        else:
            swiglu_limit_shared = None
        if config.swiglu_limits and config.swiglu_limits[
                layer_idx] is not None and config.swiglu_limits[layer_idx] != 0:
            swiglu_limit = config.swiglu_limits[layer_idx]
        else:
            swiglu_limit = None
        if self.is_moe_layer:
            self.moe = Step3p5MoEMLP(config, swiglu_limit=swiglu_limit)
            self.share_expert = Step3p5MLP(
                config,
                intermediate_size=config.share_expert_dim,
                swiglu_limit=swiglu_limit_shared)
            self.use_moe = True
        else:
            self.mlp = Step3p5MLP(config,
                                  intermediate_size=config.intermediate_size,
                                  swiglu_limit=swiglu_limit_shared)

        self.input_layernorm = Step3p5RMSNorm(config.hidden_size,
                                              eps=config.rms_norm_eps)
        self.post_attention_layernorm = Step3p5RMSNorm(config.hidden_size,
                                                       eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[tuple[torch.Tensor]] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.FloatTensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        if self.use_moe:
            share_output = self.share_expert(hidden_states)
            moe_output = self.moe(hidden_states)
            ffn_output = moe_output + share_output
        else:
            ffn_output = self.mlp(hidden_states)
        if isinstance(ffn_output, tuple):
            hidden_states, _ = ffn_output
        else:
            hidden_states = ffn_output

        hidden_states = residual + hidden_states
        return hidden_states

class Step3p5PreTrainedModel(PreTrainedModel):
    # Link this model family to its configuration class so PreTrainedModel.from_pretrained
    # can load the config instead of failing with a NoneType error.
    config_class = Step3p5Config
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = ["past_key_values"]
    _keys_to_ignore_on_load_unexpected = [
        r"model\.layers\.45\.*",
        r"model\.layers\.46\.*",
        r"model\.layers\.47\.*"
    ]
    _supports_flash_attn = False
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_static_cache = True
    _supports_attention_backend = True

class Step3p5Model(Step3p5PreTrainedModel, GenerationMixin):
    _no_split_modules = ["Step3p5DecoderLayer"]
    base_model_prefix = "model"
    _tied_weights_keys = ["lm_head.weight"]
    config: Step3p5Config

    def __init__(self, config: Step3p5Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
                                         self.padding_idx)
        self.layers = nn.ModuleList([
            Step3p5DecoderLayer(config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = Step3p5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(
                input_ids.to(self.embed_tokens.weight.device))

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(past_seen_tokens,
                                          past_seen_tokens +
                                          inputs_embeds.shape[1],
                                          device=inputs_embeds.device)

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        hidden_states = inputs_embeds

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }

            # The sliding window alternating layers are not always activated depending on the config
            if self.has_sliding_layers:
                causal_mask_mapping[
                    "sliding_attention"] = create_sliding_window_causal_mask(
                        **mask_kwargs)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        for decoder_layer in self.layers[:self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states, )

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[
                    decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

            hidden_states = layer_outputs

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

class Step3p5ForCausalLM(Step3p5PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    config: Step3p5Config

    def __init__(self, config: Step3p5Config):
        super().__init__(config)
        self.model = Step3p5Model(config)
        self.lm_head = nn.Linear(config.hidden_size,
                                 config.vocab_size,
                                 bias=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.model.set_decoder(decoder)

    def get_decoder(self):
        return self.model.get_decoder()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        num_patches=None,
        patch_pixel_values=None,
        patch_newline_mask=None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, Step3p5CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Llama4ForCausalLM
        >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")
        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")
        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        # breakpoint()
        outputs = self.model(
            input_ids=input_ids,
            num_patches=num_patches,
            patch_pixel_values=patch_pixel_values,
            patch_newline_mask=patch_newline_mask,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)

        return Step3p5CausalLMOutputWithPast(logits=logits, )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        attention_mask=None,
        cache_position=None,
        logits_to_keep=None,
        **kwargs,
    ):
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        if cache_position[0] == 0:
            # In the cached decoding stage, pixel values should be None because the input ids
            # no longer contain the special image token; otherwise pixel values need to be
            # passed to the model.
            model_inputs["pixel_values"] = pixel_values

        return model_inputs

    def _fix_state_dict_key_on_load(self, key: str) -> tuple[str, bool]:
        if key.startswith("language_model."):
            return key[len("language_model."):], True

        return key, False
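
The class above wires `Step3p5Config` to the model via `config_class` and strips the `language_model.` prefix when loading the state dict, so the checkpoint can also be loaded through the standard `transformers` auto classes with `trust_remote_code=True`. Below is a minimal sketch, assuming a recent `transformers` with the `compressed-tensors` package installed so the NVFP4 weights can be decoded; the vLLM command in the README remains the tested path.

```python
# Minimal loading sketch (not part of the repo files); repo id taken from the README.
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "Firworks/Step-3.5-Flash-nvfp4"
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo,
    trust_remote_code=True,  # picks up configuration_step3p5.py / modeling_step3p5.py
    device_map="auto",
    torch_dtype="auto",
)

inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))
```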
recipe.yaml
ADDED
@@ -0,0 +1,7 @@

default_stage:
  default_modifiers:
    QuantizationModifier:
      targets: [Linear, MoELinear]
      ignore: [lm_head, 're:visual.*', 're:.*vision_tower.*', 're:.*video_tower.*', 're:.*audio_tower.*',
        're:.*multi_modal_projector.*']
      scheme: NVFP4
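
For reference, this is roughly how a recipe like the one above is applied with LLM Compressor's one-shot flow. The call below is a sketch rather than the exact script used for this checkpoint: the `oneshot` argument names follow the llmcompressor documentation, and in practice the calibration data usually needs to be pre-processed into a tokenized `datasets.Dataset` rather than passed as a bare Hub id.

```python
# Sketch of a one-shot NVFP4 quantization run with llmcompressor (assumptions noted above).
from transformers import AutoModelForCausalLM
from llmcompressor import oneshot

model = AutoModelForCausalLM.from_pretrained(
    "stepfun-ai/Step-3.5-Flash", torch_dtype="auto", trust_remote_code=True
)

oneshot(
    model=model,
    recipe="recipe.yaml",                     # the NVFP4 recipe shown above
    dataset="Rombo-Org/Optimized_Reasoning",  # calibration set named in the README (may need preprocessing)
    num_calibration_samples=1,                # README: 1 sample of length 512
    max_seq_length=512,
    output_dir="Step-3.5-Flash-nvfp4",
)
```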
special_tokens_map.json
ADDED
@@ -0,0 +1,23 @@

{
  "bos_token": {
    "content": "<|begin▁of▁sentence|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "<|im_end|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<|end▁of▁sentence|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
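
The same bos/eos/pad definitions surface on the loaded tokenizer, which gives a quick sanity check that the file above is being picked up. Illustrative only; the repo id is the same as in the README.

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Firworks/Step-3.5-Flash-nvfp4", trust_remote_code=True)
print(tok.bos_token)  # <|begin▁of▁sentence|>
print(tok.eos_token)  # <|im_end|>
print(tok.pad_token)  # <|end▁of▁sentence|>
```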
tokenizer.json
ADDED
The diff for this file is too large to render. See raw diff.

tokenizer_config.json
ADDED
The diff for this file is too large to render. See raw diff.