lazarevich committed on
Commit
81c98a5
·
0 Parent(s):

upload checkpoint

Files changed (44)
  1. README.md +125 -0
  2. chat_template.jinja +159 -0
  3. config.json +113 -0
  4. configuration_minimax_m2.py +200 -0
  5. merges.txt +0 -0
  6. model-00001-of-00033.safetensors +3 -0
  7. model-00002-of-00033.safetensors +3 -0
  8. model-00003-of-00033.safetensors +3 -0
  9. model-00004-of-00033.safetensors +3 -0
  10. model-00005-of-00033.safetensors +3 -0
  11. model-00006-of-00033.safetensors +3 -0
  12. model-00007-of-00033.safetensors +3 -0
  13. model-00008-of-00033.safetensors +3 -0
  14. model-00009-of-00033.safetensors +3 -0
  15. model-00010-of-00033.safetensors +3 -0
  16. model-00011-of-00033.safetensors +3 -0
  17. model-00012-of-00033.safetensors +3 -0
  18. model-00013-of-00033.safetensors +3 -0
  19. model-00014-of-00033.safetensors +3 -0
  20. model-00015-of-00033.safetensors +3 -0
  21. model-00016-of-00033.safetensors +3 -0
  22. model-00017-of-00033.safetensors +3 -0
  23. model-00018-of-00033.safetensors +3 -0
  24. model-00019-of-00033.safetensors +3 -0
  25. model-00020-of-00033.safetensors +3 -0
  26. model-00021-of-00033.safetensors +3 -0
  27. model-00022-of-00033.safetensors +3 -0
  28. model-00023-of-00033.safetensors +3 -0
  29. model-00024-of-00033.safetensors +3 -0
  30. model-00025-of-00033.safetensors +3 -0
  31. model-00026-of-00033.safetensors +3 -0
  32. model-00027-of-00033.safetensors +3 -0
  33. model-00028-of-00033.safetensors +3 -0
  34. model-00029-of-00033.safetensors +3 -0
  35. model-00030-of-00033.safetensors +3 -0
  36. model-00031-of-00033.safetensors +3 -0
  37. model-00032-of-00033.safetensors +3 -0
  38. model-00033-of-00033.safetensors +3 -0
  39. model.safetensors.index.json +0 -0
  40. modeling_minimax_m2.py +697 -0
  41. special_tokens_map.json +75 -0
  42. tokenizer.json +3 -0
  43. tokenizer_config.json +496 -0
  44. vocab.json +0 -0
README.md ADDED
@@ -0,0 +1,125 @@
1
+ ---
2
+ language:
3
+ - en
4
+ library_name: transformers
5
+ tags:
6
+ - minimax
7
+ - MOE
8
+ - pruning
9
+ - compression
10
+ license: other
11
+ name: cerebras/MiniMax-M2.5-REAP-172B-A10B
12
+ description: >
13
+ This model was obtained by uniformly pruning 25% of experts in MiniMax-M2.5 using the REAP method.
14
+ readme: >
15
+ https://huggingface.co/cerebras/MiniMax-M2.5-REAP-172B-A10B/blob/main/README.md
16
+ pipeline_tag: text-generation
17
+ base_model:
18
+ - MiniMaxAI/MiniMax-M2.5
19
+ ---
20
+
21
+ <p align="center">
22
+ <em>𓌳 <strong>REAP</strong>𓌳 the Experts: Why Pruning Prevails for One-Shot MoE Compression</em><br>
23
+ <img src="https://i.imgur.com/rmzG3gg.png" alt="REAP" width="75%">
24
+ </p>
25
+
26
+ # MiniMax-M2.5-REAP-172B-A10B
27
+
28
+ ## ✨ Highlights
29
+
30
+ Introducing **MiniMax-M2.5-REAP-172B-A10B**, a **memory-efficient compressed variant** of MiniMax-M2.5 that maintains near-identical performance while being **25% lighter**.
31
+
32
+ This model was created using **REAP (Router-weighted Expert Activation Pruning)**, a novel expert pruning method that selectively removes redundant experts while preserving the router's independent control over remaining experts. Key features include:
33
+
34
+ - **Near-Lossless Performance**: Maintains almost identical accuracy on code generation, agentic coding, and function calling tasks compared to the full 230B model
35
+ - **25% Memory Reduction**: Compressed from 230B to 172B parameters, significantly lowering deployment costs and memory requirements
36
+ - **Preserved Capabilities**: Retains all core functionalities, including code generation, math and reasoning, and tool calling
37
+ - **Drop-in Compatibility**: Works with vanilla vLLM - no source modifications or custom patches required
38
+ - **Optimized for Real-World Use**: Particularly effective for resource-constrained environments, local deployments, and academic research
39
+ ---
40
+ ## 📋 Model Overview
41
+
42
+ **MiniMax-M2.5-REAP-172B-A10B** has the following specifications:
43
+
44
+ - **Base Model**: MiniMax-M2.5
45
+ - **Compression Method**: REAP (Router-weighted Expert Activation Pruning)
46
+ - **Compression Ratio**: 25% expert pruning
47
+ - **Type**: Sparse Mixture-of-Experts (SMoE) Causal Language Model
48
+ - **Number of Parameters**: 172B total, 10B activated per token
49
+ - **Number of Layers**: 62
50
+ - **Number of Attention Heads**: 48
51
+ - **Number of Experts**: 192 (uniformly pruned from 256)
52
+ - **Number of Activated Experts**: 8 per token
53
+ - **Context Length**: 196,608 tokens
54
+ - **License**: Modified MIT
55
+
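As a quick sanity check on the figures listed above, the following sketch recomputes the per-layer expert count after uniform 25% pruning and the ratio of total parameters. The constant and function names are illustrative only and do not come from the REAP codebase.

```python
# Figures quoted in the model card above.
ORIGINAL_EXPERTS = 256       # experts per MoE layer before pruning
PRUNING_RATE = 0.25          # fraction of experts removed in every layer
TOTAL_PARAMS_BEFORE_B = 230  # total parameters before pruning, in billions
TOTAL_PARAMS_AFTER_B = 172   # total parameters after pruning, in billions


def experts_after_pruning(num_experts: int, rate: float) -> int:
    """Experts kept per layer when the same pruning rate is applied to each layer."""
    return num_experts - int(num_experts * rate)


remaining = experts_after_pruning(ORIGINAL_EXPERTS, PRUNING_RATE)
param_ratio = TOTAL_PARAMS_AFTER_B / TOTAL_PARAMS_BEFORE_B

print(remaining)              # 192
print(round(param_ratio, 3))  # 0.748
```

The parameter ratio lands close to 0.75 because the expert weights account for most of the model's parameters, so removing a quarter of the experts removes roughly a quarter of the total.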
56
+ ---
57
+
58
+ ## 📊 Evaluations
59
+
60
+ TBD
61
+
62
+ ---
63
+
64
+ ## 🚀 Deployment
65
+
66
+ You can deploy the model directly with the **latest vLLM** release that supports MiniMax-M2.5; no source modifications or custom patches are required.
67
+
68
+ ```bash
69
+ vllm serve cerebras/MiniMax-M2.5-REAP-172B-A10B \
70
+ --tensor-parallel-size 8 \
71
+ --tool-call-parser minimax_m2 \
72
+ --reasoning-parser minimax_m2_append_think \
73
+ --trust-remote-code \
74
+ --enable-expert-parallel \
75
+ --enable-auto-tool-choice
76
+ ```
77
+
78
+ If you run out of memory when serving this model, try lowering `--max-num-seqs` (for example, to 64). For more information, refer to the [official vLLM deployment guide](https://huggingface.co/MiniMaxAI/MiniMax-M2.5/blob/main/docs/vllm_deploy_guide.md).
79
+
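Once the server is running, it can be queried through vLLM's OpenAI-compatible HTTP API. The sketch below only builds the request with the standard library; it assumes the server listens on the default `http://localhost:8000`, and the helper name `build_chat_request` and the `max_tokens` value are illustrative choices rather than anything specified by the model card.

```python
import json
import urllib.request

MODEL_ID = "cerebras/MiniMax-M2.5-REAP-172B-A10B"


def build_chat_request(prompt: str, base_url: str = "http://localhost:8000") -> urllib.request.Request:
    """Build (but do not send) a chat completion request for a local vLLM server."""
    payload = {
        "model": MODEL_ID,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 256,  # illustrative limit, adjust as needed
    }
    return urllib.request.Request(
        url=f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


request = build_chat_request("Write a function that reverses a string.")

# To send it once the server is up:
#     with urllib.request.urlopen(request) as response:
#         print(json.load(response)["choices"][0]["message"]["content"])
```

Keeping request construction separate from sending makes it easy to inspect the payload before pointing it at a live deployment.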
80
+ ## 🧩 Model Creation
81
+
82
+ This checkpoint was created by applying the **REAP (Router-weighted Expert Activation Pruning)** method uniformly across all Mixture-of-Experts (MoE) blocks of **MiniMax-M2.5**, with a **25% pruning rate**.
83
+
84
+ ### How REAP Works
85
+
86
+ REAP selects experts to prune based on a novel **saliency criterion** that considers both:
87
+ - **Router gate values**: How frequently and strongly the router activates each expert
88
+ - **Expert activation norms**: The magnitude of each expert's output contributions
89
+
90
+ This dual consideration ensures that experts contributing minimally to the layer's output are pruned, while preserving those that play critical roles in the model's computations.
91
+
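The saliency idea described above can be sketched in a few lines. This is a simplified illustration, assuming the score for an expert is the mean of (router gate value × L2 norm of the expert output) over the calibration tokens routed to it; the exact definition is in the REAP paper, and every name and number below is hypothetical.

```python
import math
from typing import Dict, List, Sequence, Tuple

# One routed token for one expert: (router gate value, expert output vector).
RoutedToken = Tuple[float, Sequence[float]]


def l2_norm(vector: Sequence[float]) -> float:
    """Euclidean norm of an expert's output vector."""
    return math.sqrt(sum(x * x for x in vector))


def expert_saliency(routed: List[RoutedToken]) -> float:
    """Mean of gate * output norm over the tokens routed to this expert."""
    if not routed:
        return 0.0  # an expert that is never selected contributes nothing
    return sum(gate * l2_norm(output) for gate, output in routed) / len(routed)


def select_experts_to_prune(per_expert: Dict[int, List[RoutedToken]], pruning_rate: float) -> List[int]:
    """Return the ids of the lowest-saliency experts for one MoE layer."""
    num_to_prune = int(len(per_expert) * pruning_rate)
    ranked = sorted(per_expert, key=lambda expert_id: expert_saliency(per_expert[expert_id]))
    return sorted(ranked[:num_to_prune])


# Toy calibration statistics for a layer with four experts (made-up values).
layer_stats = {
    0: [(0.9, [3.0, 4.0])],
    1: [(0.1, [0.3, 0.4])],
    2: [(0.5, [1.0, 0.0]), (0.7, [0.0, 2.0])],
    3: [(0.2, [6.0, 8.0])],
}

pruned = select_experts_to_prune(layer_stats, pruning_rate=0.25)
print(pruned)  # [1]
```

Expert 1 is removed because it combines a small gate value with a small output norm. The surviving experts keep their own router logits, which matches the "preserved router control" property described in this README.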
92
+ ### Key Advantages
93
+
94
+ - **One-Shot Compression**: No fine-tuning required after pruning - the model is immediately ready for deployment
95
+ - **Preserved Router Control**: Unlike expert merging methods, REAP maintains the router's independent, input-dependent control over remaining experts, avoiding "functional subspace collapse"
96
+ - **Generative Task Superiority**: REAP significantly outperforms expert merging approaches on generative benchmarks (code generation, creative writing, mathematical reasoning) while maintaining competitive performance on discriminative tasks
97
+
98
+ 📚 For more details, refer to the following resources:
99
+
100
+ - [🧾 arXiv Preprint](https://arxiv.org/abs/2510.13999)
101
+ - [🧾 REAP Blog](https://www.cerebras.ai/blog/reap)
102
+ - [💻 REAP Codebase (GitHub)](https://github.com/CerebrasResearch/reap)
103
+
104
+ ---
105
+
106
+ ## ⚖️ License
107
+
108
+ This model is derived from
109
+ **[`MiniMaxAI/MiniMax-M2.5`](https://huggingface.co/MiniMaxAI/MiniMax-M2.5)**
110
+ and distributed under the **modified MIT license**.
111
+
112
+ ---
113
+
114
+ ## 🧾 Citation
115
+
116
+ If you use this checkpoint, please cite the REAP paper:
117
+
118
+ ```bibtex
119
+ @article{lasby-reap,
120
+ title={REAP the Experts: Why Pruning Prevails for One-Shot MoE Compression},
121
+ author={Lasby, Mike and Lazarevich, Ivan and Sinnadurai, Nish and Lie, Sean and Ioannou, Yani and Thangarasa, Vithursan},
122
+ journal={arXiv preprint arXiv:2510.13999},
123
+ year={2025}
124
+ }
125
+ ```
chat_template.jinja ADDED
@@ -0,0 +1,159 @@
1
+ {# ----------‑‑‑ special token variables ‑‑‑---------- #}
2
+ {%- set toolcall_begin_token = '<minimax:tool_call>' -%}
3
+ {%- set toolcall_end_token = '</minimax:tool_call>' -%}
4
+ {#- Tool Rendering Functions ============================================== -#}
5
+ {%- macro render_tool_namespace(namespace_name, tool_list) -%}
6
+ {%- for tool in tool_list -%}
7
+ <tool>{{ tool.function | tojson(ensure_ascii=False) }}</tool>
8
+ {% endfor -%}
9
+ {%- endmacro -%}
10
+ {%- macro visible_text(content) -%}
11
+ {%- if content is string -%}
12
+ {{ content }}
13
+ {%- elif content is iterable and content is not mapping -%}
14
+ {%- for item in content -%}
15
+ {%- if item is mapping and item.type == 'text' -%}
16
+ {{- item.text }}
17
+ {%- elif item is string -%}
18
+ {{- item }}
19
+ {%- endif -%}
20
+ {%- endfor -%}
21
+ {%- else -%}
22
+ {{- content }}
23
+ {%- endif -%}
24
+ {%- endmacro -%}
25
+ {#- System Message Construction ============================================ -#}
26
+ {%- macro build_system_message(system_message) -%}
27
+ {%- if system_message and system_message.content -%}
28
+ {{- visible_text(system_message.content) }}
29
+ {%- else -%}
30
+ {%- if model_identity is not defined -%}
31
+ {%- set model_identity = "You are a helpful assistant. Your name is MiniMax-M2.5 and is built by MiniMax." -%}
32
+ {%- endif -%}
33
+ {{- model_identity }}
34
+ {%- endif -%}
35
+
36
+ {#- Handle current_date -#}
37
+ {%- if system_message and system_message.current_date -%}
38
+ {{- '\n' ~ 'Current date: ' + system_message.current_date }}
39
+ {%- endif -%}
40
+ {#- Handle current_location -#}
41
+ {%- if system_message and system_message.current_location -%}
42
+ {{- '\n' ~ 'Current location: ' + system_message.current_location }}
43
+ {%- endif -%}
44
+ {%- endmacro -%}
45
+ {#- Main Template Logic ================================================= -#}
46
+ {#- Extract system message (only first message if it's system) -#}
47
+ {%- set system_message = none -%}
48
+ {%- set conversation_messages = messages -%}
49
+ {%- if messages and messages[0].role == "system" -%}
50
+ {%- set system_message = messages[0] -%}
51
+ {%- set conversation_messages = messages[1:] -%}
52
+ {%- endif -%}
53
+ {#- Get the last user message turn, for interleaved thinking -#}
54
+ {%- set ns = namespace(last_user_index=-1) %}
55
+ {% for m in conversation_messages %}
56
+ {%- if m.role == 'user' %}
57
+ {% set ns.last_user_index = loop.index0 -%}
58
+ {%- endif %}
59
+ {%- endfor %}
60
+ {#- Render system message -#}
61
+ {{- ']~!b[' ~ ']~b]system' ~ '\n' }}
62
+ {{- build_system_message(system_message) }}
63
+ {#- Render tools if available -#}
64
+ {%- if tools -%}
65
+ {{- '\n\n' ~ '# Tools' ~ '\n' ~ 'You may call one or more tools to assist with the user query.\nHere are the tools available in JSONSchema format:' ~ '\n' }}
66
+ {{- '\n' ~ '<tools>' ~ '\n' }}
67
+ {{- render_tool_namespace("functions", tools) }}
68
+ {{- '</tools>' ~ '\n\n' }}
69
+ {{- 'When making tool calls, use XML format to invoke tools and pass parameters:' ~ '\n' }}
70
+ {{- '\n' ~ toolcall_begin_token }}
71
+ <invoke name="tool-name-1">
72
+ <parameter name="param-key-1">param-value-1</parameter>
73
+ <parameter name="param-key-2">param-value-2</parameter>
74
+ ...
75
+ </invoke>
76
+ {{- '\n' ~ toolcall_end_token }}
77
+ {%- endif -%}
78
+ {{- '[e~[\n' }}
79
+
80
+ {#- Render messages -#}
81
+ {%- set last_tool_call = namespace(name=none) -%}
82
+ {%- for message in conversation_messages -%}
83
+ {%- if message.role == 'assistant' -%}
84
+ {#- Only render reasoning_content if no user message follows -#}
85
+ {{- ']~b]ai' ~ '\n' }}
86
+
87
+ {%- set reasoning_content = '' %}
88
+ {%- set content = visible_text(message.content) %}
89
+ {%- if message.reasoning_content is string %}
90
+ {%- set reasoning_content = message.reasoning_content %}
91
+ {%- else %}
92
+ {%- if '</think>' in content %}
93
+ {%- set reasoning_content = content.split('</think>')[0].strip('\n').split('<think>')[-1].strip('\n') %}
94
+ {%- set content = content.split('</think>')[-1].strip('\n') %}
95
+ {%- endif %}
96
+ {%- endif %}
97
+ {%- if reasoning_content and loop.index0 > ns.last_user_index -%}
98
+ {{- '<think>' ~ '\n' ~ reasoning_content ~ '\n' ~ '</think>' ~ '\n\n' }}
99
+ {%- endif -%}
100
+ {%- if content -%}
101
+ {{- content }}
102
+ {%- endif -%}
103
+ {%- if message.tool_calls -%}
104
+ {{- '\n' ~ toolcall_begin_token ~ '\n' }}
105
+
106
+ {%- for tool_call in message.tool_calls -%}
107
+ {%- if tool_call.function %}
108
+ {%- set tool_call = tool_call.function %}
109
+ {%- endif %}
110
+ {{- '<invoke name="' + tool_call.name + '">' }}
111
+ {% set _args = tool_call.arguments %}
112
+ {%- for k, v in _args.items() %}
113
+ {{- '<parameter name="' + k + '">' }}
114
+ {{- v | tojson(ensure_ascii=False) if v is not string else v }}
115
+ {{- '</parameter>' }}
116
+ {% endfor %}
117
+ {{- '</invoke>' ~ '\n' }}
118
+ {%- endfor -%}
119
+
120
+ {{- toolcall_end_token}}
121
+ {%- set last_tool_call.name = message.tool_calls[-1].name -%}
122
+ {%- else -%}
123
+ {%- set last_tool_call.name = none -%}
124
+ {%- endif -%}
125
+ {{- '[e~[' ~ '\n' }}
126
+
127
+ {%- elif message.role == 'tool' -%}
128
+ {%- if last_tool_call.name is none -%}
129
+ {{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }}
130
+ {%- endif -%}
131
+ {%- if loop.first or (conversation_messages[loop.index0 - 1].role != 'tool') -%}
132
+ {{- ']~b]tool' }}
133
+ {%- endif -%}
134
+ {%- if message.content is string -%}
135
+ {{- '\n<response>' }}
136
+ {{- message.content }}
137
+ {{- '</response>' }}
138
+ {%- else -%}
139
+ {%- for tr in message.content -%}
140
+ {{- '\n<response>' }}
141
+ {{- tr.output if tr.output is defined else (tr.text if tr.type == 'text' and tr.text is defined else tr) }}
142
+ {{- '\n</response>' }}
143
+ {%- endfor -%}
144
+ {%- endif -%}
145
+ {%- if loop.last or (conversation_messages[loop.index0 + 1].role != 'tool') -%}
146
+ {{- '[e~[\n' -}}
147
+ {%- endif -%}
148
+
149
+ {%- elif message.role == 'user' -%}
150
+ {{- ']~b]user' ~ '\n' }}
151
+ {{- visible_text(message.content) }}
152
+ {{- '[e~[' ~ '\n' }}
153
+ {%- endif -%}
154
+ {%- endfor -%}
155
+
156
+ {#- Generation prompt -#}
157
+ {%- if add_generation_prompt -%}
158
+ {{- ']~b]ai' ~ '\n' ~ '<think>' ~ '\n' }}
159
+ {%- endif -%}
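To make the prompt layout produced by this template easier to follow, here is a simplified Python sketch of the text-only path. It is an approximation under stated assumptions: message contents are plain strings, there are no tools, no `reasoning_content` field, and no `current_date` or `current_location`, and whitespace is assumed to be trimmed the way the template's `-` markers suggest. The function names are hypothetical; the real rendering is done by the Jinja template.

```python
from typing import Dict, List, Tuple

DEFAULT_IDENTITY = "You are a helpful assistant. Your name is MiniMax-M2.5 and is built by MiniMax."


def split_reasoning(content: str) -> Tuple[str, str]:
    """Mirror the template's handling of an inline <think>...</think> block."""
    if "</think>" not in content:
        return "", content
    reasoning = content.split("</think>")[0].strip("\n").split("<think>")[-1].strip("\n")
    answer = content.split("</think>")[-1].strip("\n")
    return reasoning, answer


def render_prompt(messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
    """Approximate the template for plain system/user/assistant turns."""
    system, turns = DEFAULT_IDENTITY, messages
    if messages and messages[0]["role"] == "system":
        system = messages[0]["content"] or DEFAULT_IDENTITY
        turns = messages[1:]
    # Reasoning is only kept for assistant turns after the last user message.
    last_user = max((i for i, m in enumerate(turns) if m["role"] == "user"), default=-1)
    out = "]~!b[]~b]system\n" + system + "[e~[\n"
    for i, message in enumerate(turns):
        if message["role"] == "user":
            out += "]~b]user\n" + message["content"] + "[e~[\n"
        elif message["role"] == "assistant":
            reasoning, answer = split_reasoning(message["content"])
            out += "]~b]ai\n"
            if reasoning and i > last_user:
                out += "<think>\n" + reasoning + "\n</think>\n\n"
            out += answer + "[e~[\n"
    if add_generation_prompt:
        out += "]~b]ai\n<think>\n"
    return out


history = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "<think>\nplan\n</think>\n\nHello"},
    {"role": "user", "content": "More"},
]
print(render_prompt(history))
```

In this example the assistant's earlier reasoning is dropped because a user message follows it, while the generation prompt reopens a `<think>` block for the next reply.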
config.json ADDED
@@ -0,0 +1,113 @@
1
+ {
2
+ "architectures": [
3
+ "MiniMaxM2ForCausalLM"
4
+ ],
5
+ "attn_type_list": [
6
+ 1,
7
+ 1,
8
+ 1,
9
+ 1,
10
+ 1,
11
+ 1,
12
+ 1,
13
+ 1,
14
+ 1,
15
+ 1,
16
+ 1,
17
+ 1,
18
+ 1,
19
+ 1,
20
+ 1,
21
+ 1,
22
+ 1,
23
+ 1,
24
+ 1,
25
+ 1,
26
+ 1,
27
+ 1,
28
+ 1,
29
+ 1,
30
+ 1,
31
+ 1,
32
+ 1,
33
+ 1,
34
+ 1,
35
+ 1,
36
+ 1,
37
+ 1,
38
+ 1,
39
+ 1,
40
+ 1,
41
+ 1,
42
+ 1,
43
+ 1,
44
+ 1,
45
+ 1,
46
+ 1,
47
+ 1,
48
+ 1,
49
+ 1,
50
+ 1,
51
+ 1,
52
+ 1,
53
+ 1,
54
+ 1,
55
+ 1,
56
+ 1,
57
+ 1,
58
+ 1,
59
+ 1,
60
+ 1,
61
+ 1,
62
+ 1,
63
+ 1,
64
+ 1,
65
+ 1,
66
+ 1,
67
+ 1
68
+ ],
69
+ "auto_map": {
70
+ "AutoConfig": "configuration_minimax_m2.MiniMaxM2Config",
71
+ "AutoModelForCausalLM": "modeling_minimax_m2.MiniMaxM2ForCausalLM"
72
+ },
73
+ "head_dim": 128,
74
+ "hidden_act": "silu",
75
+ "hidden_size": 3072,
76
+ "intermediate_size": 1536,
77
+ "max_position_embeddings": 196608,
78
+ "model_type": "minimax_m2",
79
+ "mtp_transformer_layers": 1,
80
+ "num_attention_heads": 48,
81
+ "num_experts_per_tok": 8,
82
+ "num_hidden_layers": 62,
83
+ "num_key_value_heads": 8,
84
+ "num_local_experts": 192,
85
+ "num_mtp_modules": 3,
86
+ "qk_norm_type": "per_layer",
87
+ "quantization_config": {
88
+ "activation_scheme": "dynamic",
89
+ "fmt": "float8_e4m3fn",
90
+ "quant_method": "fp8",
91
+ "weight_block_size": [
92
+ 128,
93
+ 128
94
+ ],
95
+ "modules_to_not_convert": [
96
+ "gate",
97
+ "e_score_correction_bias",
98
+ "lm_head"
99
+ ]
100
+ },
101
+ "rms_norm_eps": 1e-06,
102
+ "rope_theta": 5000000,
103
+ "rotary_dim": 64,
104
+ "scoring_func": "sigmoid",
105
+ "shared_intermediate_size": 0,
106
+ "tie_word_embeddings": false,
107
+ "transformers_version": "4.46.1",
108
+ "use_cache": true,
109
+ "use_mtp": true,
110
+ "use_qk_norm": true,
111
+ "use_routing_bias": true,
112
+ "vocab_size": 200064
113
+ }
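A few derived quantities follow directly from the values in this `config.json`. The sketch below is a back-of-the-envelope estimate, not an exact count: it assumes each expert is a gated MLP with three `hidden_size × intermediate_size` weight matrices (suggested by the `w1`/`w2`/`w3` entries in the tensor-parallel plan of `configuration_minimax_m2.py`) and ignores attention, embeddings, router weights, and quantization scales.

```python
# Values copied from the config.json shown above.
config = {
    "head_dim": 128,
    "hidden_size": 3072,
    "intermediate_size": 1536,
    "num_attention_heads": 48,
    "num_experts_per_tok": 8,
    "num_hidden_layers": 62,
    "num_key_value_heads": 8,
    "num_local_experts": 192,
    "rotary_dim": 64,
}

# Grouped-query attention: how many query heads share one key/value head.
queries_per_kv_head = config["num_attention_heads"] // config["num_key_value_heads"]

# Fraction of each head dimension that receives rotary position embeddings.
partial_rotary_factor = config["rotary_dim"] / config["head_dim"]

# Assumption: three weight matrices per expert (w1, w2, w3).
params_per_expert = 3 * config["hidden_size"] * config["intermediate_size"]
expert_params_total = params_per_expert * config["num_local_experts"] * config["num_hidden_layers"]
expert_params_active = params_per_expert * config["num_experts_per_tok"] * config["num_hidden_layers"]

print(queries_per_kv_head)                  # 6
print(partial_rotary_factor)                # 0.5
print(params_per_expert)                    # 14155776
print(round(expert_params_total / 1e9, 1))  # 168.5
print(round(expert_params_active / 1e9, 1)) # 7.0
```

Under these assumptions the experts alone account for about 168.5B of the 172B total parameters, and about 7B of the roughly 10B parameters activated per token, which is consistent with the figures in the README.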
configuration_minimax_m2.py ADDED
@@ -0,0 +1,200 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/minimax_m2/modular_minimax_m2.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_minimax_m2.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 the HuggingFace Team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+
23
+ from transformers.configuration_utils import PretrainedConfig
24
+
25
+
26
+ class MiniMaxM2Config(PretrainedConfig):
27
+ r"""
28
+ This is the configuration class to store the configuration of a [`MiniMaxM2Model`]. It is used to instantiate an
29
+ MiniMaxM2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
30
+ with the defaults will yield a similar configuration to that of the MiniMaxM2-7B-v0.1 or MiniMaxM2-7B-Instruct-v0.1.
31
+
32
+ [minimax_m2ai/MiniMaxM2-8x7B](https://huggingface.co/minimax_m2ai/MiniMaxM2-8x7B)
33
+ [minimax_m2ai/MiniMaxM2-7B-Instruct-v0.1](https://huggingface.co/minimax_m2ai/MiniMaxM2-7B-Instruct-v0.1)
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 32000):
41
+ Vocabulary size of the MiniMaxM2 model. Defines the number of different tokens that can be represented by the
42
+ `inputs_ids` passed when calling [`MiniMaxM2Model`]
43
+ hidden_size (`int`, *optional*, defaults to 4096):
44
+ Dimension of the hidden representations.
45
+ intermediate_size (`int`, *optional*, defaults to 14336):
46
+ Dimension of the MLP representations.
47
+ num_hidden_layers (`int`, *optional*, defaults to 32):
48
+ Number of hidden layers in the Transformer encoder.
49
+ num_attention_heads (`int`, *optional*, defaults to 32):
50
+ Number of attention heads for each attention layer in the Transformer encoder.
51
+ num_key_value_heads (`int`, *optional*, defaults to 8):
52
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
53
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
54
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
55
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
56
+ by meanpooling all the original heads within that group. For more details, check out [this
57
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
58
+ head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
59
+ The attention head dimension.
60
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
61
+ The non-linear activation function (function or string) in the decoder.
62
+ max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
63
+ The maximum sequence length that this model might ever be used with. MiniMaxM2's sliding window attention
64
+ allows sequence of up to 4096*32 tokens.
65
+ initializer_range (`float`, *optional*, defaults to 0.02):
66
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
67
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
68
+ The epsilon used by the rms normalization layers.
69
+ use_cache (`bool`, *optional*, defaults to `True`):
70
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
71
+ relevant if `config.is_decoder=True`.
72
+ pad_token_id (`int`, *optional*):
73
+ The id of the padding token.
74
+ bos_token_id (`int`, *optional*, defaults to 1):
75
+ The id of the "beginning-of-sequence" token.
76
+ eos_token_id (`int`, *optional*, defaults to 2):
77
+ The id of the "end-of-sequence" token.
78
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
79
+ Whether the model's input and output word embeddings should be tied.
80
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
81
+ The base period of the RoPE embeddings.
82
+ sliding_window (`int`, *optional*):
83
+ Sliding window attention window size. If not specified, will default to `4096`.
84
+ attention_dropout (`float`, *optional*, defaults to 0.0):
85
+ The dropout ratio for the attention probabilities.
86
+ num_experts_per_tok (`int`, *optional*, defaults to 2):
87
+ The number of experts to route per-token, can be also interpreted as the `top-k` routing
88
+ parameter
89
+ num_local_experts (`int`, *optional*, defaults to 8):
90
+ Number of experts per Sparse MLP layer.
91
+ output_router_logits (`bool`, *optional*, defaults to `False`):
92
+ Whether or not the router logits should be returned by the model. Enabling this will also
93
+ allow the model to output the auxiliary loss. See [here]() for more details
94
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
95
+ The aux loss factor for the total loss.
96
+ router_jitter_noise (`float`, *optional*, defaults to 0.0):
97
+ Amount of noise to add to the router.
98
+
99
+ ```python
100
+ >>> from transformers import MiniMaxM2Model, MiniMaxM2Config
101
+
102
+ >>> # Initializing a MiniMaxM2 7B style configuration
103
+ >>> configuration = MiniMaxM2Config()
104
+
105
+ >>> # Initializing a model from the MiniMaxM2 7B style configuration
106
+ >>> model = MiniMaxM2Model(configuration)
107
+
108
+ >>> # Accessing the model configuration
109
+ >>> configuration = model.config
110
+ ```"""
111
+
112
+ model_type = "minimax_m2"
113
+ keys_to_ignore_at_inference = ["past_key_values"]
114
+ base_model_tp_plan = {
115
+ "layers.*.self_attn.q_proj": "colwise",
116
+ "layers.*.self_attn.k_proj": "colwise",
117
+ "layers.*.self_attn.v_proj": "colwise",
118
+ "layers.*.self_attn.o_proj": "rowwise",
119
+ "layers.*.block_sparse_moe.gate": "colwise_rep", # we need to replicate here to correctly route experts
120
+ "layers.*.block_sparse_moe.experts.*.w1": "colwise",
121
+ "layers.*.block_sparse_moe.experts.*.w2": "rowwise",
122
+ "layers.*.block_sparse_moe.experts.*.w3": "colwise",
123
+ }
124
+ base_model_pp_plan = {
125
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
126
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
127
+ "norm": (["hidden_states"], ["hidden_states"]),
128
+ }
129
+
130
+ def __init__(
131
+ self,
132
+ vocab_size=32000,
133
+ hidden_size=4096,
134
+ intermediate_size=14336,
135
+ num_hidden_layers=32,
136
+ num_attention_heads=32,
137
+ num_key_value_heads=8,
138
+ head_dim=None,
139
+ hidden_act="silu",
140
+ max_position_embeddings=4096 * 32,
141
+ initializer_range=0.02,
142
+ rms_norm_eps=1e-5,
143
+ use_cache=True,
144
+ pad_token_id=None,
145
+ bos_token_id=1,
146
+ eos_token_id=2,
147
+ tie_word_embeddings=False,
148
+ rope_theta=1e6,
149
+ sliding_window=None,
150
+ attention_dropout=0.0,
151
+ num_experts_per_tok=2,
152
+ num_local_experts=8,
153
+ output_router_logits=False,
154
+ router_aux_loss_coef=0.001,
155
+ router_jitter_noise=0.0,
156
+ **kwargs,
157
+ ):
158
+ self.vocab_size = vocab_size
159
+ self.max_position_embeddings = max_position_embeddings
160
+ self.hidden_size = hidden_size
161
+ self.intermediate_size = intermediate_size
162
+ self.num_hidden_layers = num_hidden_layers
163
+ self.num_attention_heads = num_attention_heads
164
+ self.sliding_window = sliding_window
165
+
166
+ # for backward compatibility
167
+ if num_key_value_heads is None:
168
+ num_key_value_heads = num_attention_heads
169
+
170
+ self.num_key_value_heads = num_key_value_heads
171
+ self.hidden_act = hidden_act
172
+ self.initializer_range = initializer_range
173
+ self.rms_norm_eps = rms_norm_eps
174
+ self.use_cache = use_cache
175
+ self.rope_theta = rope_theta
176
+ self.attention_dropout = attention_dropout
177
+ self.head_dim = head_dim
178
+
179
+ self.num_experts_per_tok = num_experts_per_tok
180
+ self.num_local_experts = num_local_experts
181
+ self.output_router_logits = output_router_logits
182
+ self.router_aux_loss_coef = router_aux_loss_coef
183
+ self.router_jitter_noise = router_jitter_noise
184
+
185
+ self.use_qk_norm = kwargs.pop("use_qk_norm", False)
186
+ self.rotary_dim = kwargs.pop("rotary_dim", self.head_dim)
187
+ self.partial_rotary_factor = kwargs.pop("partial_rotary_factor", 1)
188
+ if self.head_dim is not None:
189
+ self.partial_rotary_factor = self.rotary_dim / self.head_dim
190
+
191
+ super().__init__(
192
+ pad_token_id=pad_token_id,
193
+ bos_token_id=bos_token_id,
194
+ eos_token_id=eos_token_id,
195
+ tie_word_embeddings=tie_word_embeddings,
196
+ **kwargs,
197
+ )
198
+
199
+
200
+ __all__ = ["MiniMaxM2Config"]
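The way `__init__` derives `partial_rotary_factor` from `rotary_dim` and `head_dim` is easy to miss, so here is a standalone sketch that mirrors just that part of the logic. `resolve_rotary` is a hypothetical helper written for illustration; it is not part of the config class.

```python
from typing import Optional, Tuple


def resolve_rotary(head_dim: Optional[int], **kwargs) -> Tuple[Optional[int], float]:
    """Mirror the rotary-related kwargs handling in MiniMaxM2Config.__init__."""
    # rotary_dim falls back to head_dim when it is not passed explicitly.
    rotary_dim = kwargs.pop("rotary_dim", head_dim)
    partial_rotary_factor = kwargs.pop("partial_rotary_factor", 1)
    # When head_dim is known, the factor is always recomputed from the two sizes,
    # so an explicitly passed partial_rotary_factor is overwritten.
    if head_dim is not None:
        partial_rotary_factor = rotary_dim / head_dim
    return rotary_dim, partial_rotary_factor


print(resolve_rotary(128, rotary_dim=64))  # (64, 0.5)
print(resolve_rotary(128))                 # (128, 1.0)
print(resolve_rotary(None))                # (None, 1)
```

With the values from this checkpoint's `config.json` (`head_dim=128`, `rotary_dim=64`), rotary embeddings cover half of each head dimension. Note that, as written, `__init__` stores `head_dim=None` unchanged rather than resolving it to `hidden_size // num_attention_heads` as the docstring suggests.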
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model-00001-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d336e742d563031ded218fec1d2589f8a1f12b31033d5516490074f67d5f5d6
3
+ size 5364248288
model-00002-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:493b31d51bf14a380a669b03908a6b7dce82cbf86dc621f1aae0218e1cdc6b5a
3
+ size 5366832192
model-00003-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:207b3a125458e367f8eca36b63288d55b26d3d90d1bbd6910a4eeff9e31a61ec
3
+ size 5366832128
model-00004-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4d27d7c2ef694ea5369ac3f65998cf9f389e15a5737cc9d274c6b42ec4b1d57
3
+ size 5366832056
model-00005-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c27aa92d6fb3676cdf7ab9efd40d05e549fe414bdafc8fc36b1e3ad81d2d25b5
3
+ size 5366831992
model-00006-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c8a3c86e994a8d1b3a126f99a1449636ee5b05dc5f2a71c3df761f3eb681ed08
3
+ size 5366834184
model-00007-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dcff3d0cc0cb265a81dcf115e002cf1150269716417a66e4b1d55e541e9d54c4
3
+ size 5366834120
model-00008-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:842c6896cb9817277d02254f9bcac64ff1705115d88cf2edcead836c83e3cec1
3
+ size 5366834048
model-00009-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e3ba166f56121a1bff51bf09c900138f1e4ad42baacbaa5d0574fba43235fab
3
+ size 5366834024
model-00010-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2bd2b3cfdc8bf6c872763ab8aefae68511d1772d30d2b85122b07463b92771b
3
+ size 5368774432
model-00011-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:81c0496fc10a6cdf83f1cc3c41c56a5154f1ffe8175c3a0ae71fe825cb395d3e
3
+ size 5366834992
model-00012-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dbb91aabfa535e345288c86cc668a6673601aed55bfca3ad6c7b30ed7484fffa
3
+ size 5366834920
model-00013-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b03a14a136502d818dbac3303070b674ddc68fd240f4d02ff92a8ead708c23e
3
+ size 5366834856
model-00014-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fce8d2242d6649f508f2be917bece30b040634c28af778b199d1b2a4703dd27f
3
+ size 5366834784
model-00015-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc7b76d2d81bb259d078f1e13a5a3e0589715b3c46a9608e36907d113cd97a15
3
+ size 5366834720
model-00016-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c011637a18e9f78fec442418ff6936082d5ef68994a30597300f49c1492eb87
3
+ size 5366834608
model-00017-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f961a296f106c0e8d54fd96d2caf3df49acc7297692593c6011ed1b8b03fda9c
3
+ size 5366834544
model-00018-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9a95786b575b85a6d5257a45193c532935a4e55f03f0ba607d1309527f904e0
3
+ size 5366834496
model-00019-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7b5ab2c1b78cc69f1dc42ebe83584552a3774bd7114b2da7888732027511497
3
+ size 5366834472
model-00020-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cba5af715c8f9a198a7758b295ab3dab1e226ae952b983077a72bbe3971689c1
3
+ size 5366834408
model-00021-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64096614cc10347dfaebedb589bb10dce0cb41a75598cc98c36fa6d7d4b43a7f
3
+ size 5366834336
model-00022-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a4d2e3a98b6d089d863c913b8e4d354876f7320ccde86d8aa05b5daa4fcdb94
3
+ size 5366834272
model-00023-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e71b62b827b5f6f76e97edd1b36c3559ea201c3cf0e245da1f233229736a0f87
3
+ size 5366834200
model-00024-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:41b47f096dd586d88839cfb259feda0069ddc969cd5eb81aa040c82c66a96fa2
3
+ size 5366834136
model-00025-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5250405d37efc4604ebabb23d976496d33df7120f8c3b98d20d4671b670dde11
3
+ size 5366834064
model-00026-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8fe5631e92394ad063e10627c92fae60605e91a0832608d83e0aefa87773fe3f
3
+ size 5366834024
model-00027-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a4ed636c585c752afb5edfdb365b15fb8ac90ecacdb5d56f7bca9d24820172a
3
+ size 5368774464
model-00028-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f3272fb4e462fb68068fd9b0805fbc9e2e61bec7dab43ae7d50cacf792c3c104
3
+ size 5366835024
model-00029-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d56b452427b918f2bca4d1961917b9a728b2fb10c97958154b5e0f64dcd47b90
3
+ size 5366834936
model-00030-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b355236d28606ef749531090d91eff57cad5be89074a5d200e343ec140ca7cb
3
+ size 5366834872
model-00031-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6805383c4d36ddf29315cc1b74168c8cd3a3bab7e06376edf8ad249ddc82c92
3
+ size 5366834800
model-00032-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c72c0bee0fa9988b84e030c952ab9ed4ae0cd157b8f40eab63ef339d099ff92
3
+ size 5366834736
model-00033-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:993efb909e925aaa2dc91f7bf2c08a47b715e5897f30dfea499e5b09fe359a5d
3
+ size 2064553512
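The shard entries in this commit are Git LFS pointer files, not the weights themselves: each one records a spec version, a SHA-256 object id, and a byte size. Below is a minimal sketch of how such a pointer could be parsed and checked against a downloaded shard. `parse_lfs_pointer` and `verify_shard` are hypothetical helper names written for illustration; they are not part of Git LFS or any Hugging Face library.

```python
import hashlib
from pathlib import Path


def parse_lfs_pointer(text: str) -> dict:
    """Parse the `key value` lines of a Git LFS pointer file (hypothetical helper)."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    # The oid field has the form "<algorithm>:<hex digest>".
    algo, _, digest = fields["oid"].partition(":")
    return {
        "version": fields["version"],
        "algo": algo,
        "digest": digest,
        "size": int(fields["size"]),
    }


def verify_shard(path: Path, pointer: dict, chunk_size: int = 1 << 20) -> bool:
    """Return True if the file at `path` matches the pointer's size and SHA-256 digest."""
    if path.stat().st_size != pointer["size"]:
        return False
    hasher = hashlib.sha256()
    with path.open("rb") as fh:
        # Hash in chunks so multi-gigabyte shards are never held in memory at once.
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            hasher.update(chunk)
    return hasher.hexdigest() == pointer["digest"]


# Pointer text for model-00033-of-00033.safetensors, copied from this commit.
pointer = parse_lfs_pointer(
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:993efb909e925aaa2dc91f7bf2c08a47b715e5897f30dfea499e5b09fe359a5d\n"
    "size 2064553512\n"
)
```

Checking the size first is a cheap way to reject a truncated download before paying for a full hash of a 5 GB file.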
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_minimax_m2.py ADDED
@@ -0,0 +1,697 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/minimax_m2/modular_minimax_m2.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_minimax_m2.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 the HuggingFace Team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+
23
+ from collections.abc import Callable
24
+ from typing import Optional, Union, Unpack
25
+
26
+ import torch
27
+ from torch import nn
28
+
29
+ from transformers.activations import ACT2FN
30
+ from transformers.cache_utils import Cache, DynamicCache
31
+ from transformers.generation import GenerationMixin
32
+ from transformers.integrations import use_kernel_forward_from_hub
33
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
34
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
35
+ from transformers.modeling_layers import (
36
+ GenericForQuestionAnswering,
37
+ GenericForSequenceClassification,
38
+ GenericForTokenClassification,
39
+ GradientCheckpointingLayer,
40
+ )
41
+ from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
42
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
43
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
44
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
45
+ from transformers.utils.deprecation import deprecate_kwarg
46
+ from transformers.utils.generic import OutputRecorder, check_model_inputs
47
+ from .configuration_minimax_m2 import MiniMaxM2Config
48
+
49
+
50
+ class MiniMaxM2MLP(nn.Module):
51
+ def __init__(self, config: MiniMaxM2Config):
52
+ super().__init__()
53
+ self.ffn_dim = config.intermediate_size
54
+ self.hidden_dim = config.hidden_size
55
+
56
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
57
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
58
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
59
+
60
+ self.act_fn = ACT2FN[config.hidden_act]
61
+
62
+ def forward(self, hidden_states):
63
+ current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
64
+ current_hidden_states = self.w2(current_hidden_states)
65
+ return current_hidden_states
66
+
67
+
68
+ class MiniMaxM2Experts(nn.ModuleList):
69
+ """
70
+ ModuleList of experts.
71
+ """
72
+
73
+ def __init__(self, config: MiniMaxM2Config):
74
+ super().__init__()
75
+ self.top_k = config.num_experts_per_tok
76
+ self.num_experts = config.num_local_experts
77
+ for _ in range(self.num_experts):
78
+ self.append(MiniMaxM2MLP(config))
79
+
80
+ def forward(
81
+ self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
82
+ ) -> torch.Tensor:
83
+ """
84
+ Args:
85
+ hidden_states: (batch_size * sequence_length, hidden_dim)
86
+ selected_experts: (batch_size * sequence_length, top_k)
87
+ routing_weights: (batch_size * sequence_length, top_k)
88
+ Returns:
89
+ (batch_size * sequence_length, hidden_dim)
90
+ """
91
+ final_hidden_states = torch.zeros_like(hidden_states)
92
+ expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
93
+
94
+ expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
95
+ for expert_idx in expert_hit:
96
+ idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
97
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
98
+ current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
99
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
100
+ return final_hidden_states
101
+
102
+
103
+ class MiniMaxM2SparseMoeBlock(nn.Module):
104
+ def __init__(self, config):
105
+ super().__init__()
106
+ self.top_k = config.num_experts_per_tok
107
+ self.jitter_noise = config.router_jitter_noise
108
+ self.gate = nn.Linear(config.hidden_size, config.num_local_experts, bias=False)
109
+ self.experts = MiniMaxM2Experts(config)
110
+ self.register_buffer("e_score_correction_bias", torch.zeros(config.num_local_experts))
111
+
112
+ def route_tokens_to_experts(self, router_logits):
113
+ routing_weights = torch.nn.functional.sigmoid(router_logits.float())
114
+ scores_for_choice = routing_weights + self.e_score_correction_bias
115
+ _, top_k_index = torch.topk(scores_for_choice, self.top_k, dim=-1, sorted=False)
116
+ top_k_weights = routing_weights.gather(1, top_k_index)
117
+ top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
118
+ return top_k_index, top_k_weights.to(router_logits.dtype)
119
+
120
+ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
121
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
122
+ if self.training and self.jitter_noise > 0:
123
+ hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
124
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
125
+ router_logits = self.gate(hidden_states)
126
+ top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits)
127
+ hidden_states = self.experts(hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype))
128
+ hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
129
+ return hidden_states, router_logits
130
+
131
+
132
+ @use_kernel_forward_from_hub("RMSNorm")
133
+ class MiniMaxM2RMSNorm(nn.Module):
134
+ def __init__(self, hidden_size, eps=1e-6):
135
+ """
136
+ MiniMaxM2RMSNorm is equivalent to T5LayerNorm
137
+ """
138
+ super().__init__()
139
+ self.weight = nn.Parameter(torch.ones(hidden_size))
140
+ self.variance_epsilon = eps
141
+
142
+ def forward(self, hidden_states):
143
+ input_dtype = hidden_states.dtype
144
+ hidden_states = hidden_states.to(torch.float32)
145
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
146
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
147
+ return self.weight * hidden_states.to(input_dtype)
148
+
149
+ def extra_repr(self):
150
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
151
+
152
+
153
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
154
+ """
155
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
156
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
157
+ """
158
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
159
+ if n_rep == 1:
160
+ return hidden_states
161
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
162
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
163
+
164
+
165
+ def eager_attention_forward(
166
+ module: nn.Module,
167
+ query: torch.Tensor,
168
+ key: torch.Tensor,
169
+ value: torch.Tensor,
170
+ attention_mask: Optional[torch.Tensor],
171
+ scaling: float,
172
+ dropout: float = 0.0,
173
+ **kwargs: Unpack[TransformersKwargs],
174
+ ):
175
+ key_states = repeat_kv(key, module.num_key_value_groups)
176
+ value_states = repeat_kv(value, module.num_key_value_groups)
177
+
178
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
179
+ if attention_mask is not None:
180
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
181
+ attn_weights = attn_weights + causal_mask
182
+
183
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
184
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
185
+ attn_output = torch.matmul(attn_weights, value_states)
186
+ attn_output = attn_output.transpose(1, 2).contiguous()
187
+
188
+ return attn_output, attn_weights
189
+
190
+
191
+ def rotate_half(x):
192
+ """Rotates half the hidden dims of the input."""
193
+ x1 = x[..., : x.shape[-1] // 2]
194
+ x2 = x[..., x.shape[-1] // 2 :]
195
+ return torch.cat((-x2, x1), dim=-1)
196
+
197
+
198
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
199
+ """Applies Rotary Position Embedding to the query and key tensors.
200
+ Args:
201
+ q (`torch.Tensor`): The query tensor.
202
+ k (`torch.Tensor`): The key tensor.
203
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
204
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
205
+ position_ids (`torch.Tensor`, *optional*):
206
+ Deprecated and unused.
207
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
208
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
209
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
210
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
211
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
212
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
213
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
214
+ Returns:
215
+ `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
216
+ """
217
+ cos = cos.unsqueeze(unsqueeze_dim)
218
+ sin = sin.unsqueeze(unsqueeze_dim)
219
+
220
+ # Keep half or full tensor for later concatenation
221
+ rotary_dim = cos.shape[-1]
222
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
223
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
224
+
225
+ # Apply rotary embeddings on the first half or full tensor
226
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
227
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
228
+
229
+ # Concatenate back to full shape
230
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
231
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
232
+ return q_embed, k_embed
233
+
234
+
235
+ class MiniMaxM2Attention(nn.Module):
236
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
237
+
238
+ def __init__(self, config: MiniMaxM2Config, layer_idx: int):
239
+ super().__init__()
240
+ self.config = config
241
+ self.layer_idx = layer_idx
242
+ self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
243
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
244
+ self.scaling = self.head_dim**-0.5
245
+ self.attention_dropout = config.attention_dropout
246
+ self.is_causal = True
247
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
248
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
249
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
250
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
251
+
252
+ self.use_qk_norm = config.use_qk_norm
253
+ if self.use_qk_norm:
254
+ self.q_norm = MiniMaxM2RMSNorm(self.head_dim * config.num_attention_heads, eps=config.rms_norm_eps)
255
+ self.k_norm = MiniMaxM2RMSNorm(self.head_dim * config.num_key_value_heads, eps=config.rms_norm_eps)
256
+
257
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
258
+ def forward(
259
+ self,
260
+ hidden_states: torch.Tensor,
261
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
262
+ attention_mask: Optional[torch.Tensor],
263
+ past_key_values: Optional[Cache] = None,
264
+ cache_position: Optional[torch.LongTensor] = None,
265
+ **kwargs: Unpack[FlashAttentionKwargs],
266
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
267
+ input_shape = hidden_states.shape[:-1]
268
+ hidden_shape = (*input_shape, -1, self.head_dim)
269
+
270
+ query_states = self.q_proj(hidden_states)
271
+ key_states = self.k_proj(hidden_states)
272
+ value_states = self.v_proj(hidden_states)
273
+
274
+ if self.use_qk_norm: # main diff from Llama
275
+ query_states = self.q_norm(query_states)
276
+ key_states = self.k_norm(key_states)
277
+
278
+ key_states = key_states.view(hidden_shape)
279
+ query_states = query_states.view(hidden_shape)
280
+ value_states = value_states.view(hidden_shape)
281
+
282
+ query_states = query_states.transpose(1, 2)
283
+ key_states = key_states.transpose(1, 2)
284
+ value_states = value_states.transpose(1, 2)
285
+
286
+ cos, sin = position_embeddings
287
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
288
+
289
+ if past_key_values is not None:
290
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
291
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
292
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
293
+
294
+ attention_interface: Callable = eager_attention_forward
295
+ if self.config._attn_implementation != "eager":
296
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
297
+
298
+ attn_output, attn_weights = attention_interface(
299
+ self,
300
+ query_states,
301
+ key_states,
302
+ value_states,
303
+ attention_mask,
304
+ dropout=0.0 if not self.training else self.attention_dropout,
305
+ scaling=self.scaling,
306
+ **kwargs,
307
+ )
308
+
309
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
310
+ attn_output = self.o_proj(attn_output)
311
+ return attn_output, attn_weights
312
+
313
+
314
+ class MiniMaxM2DecoderLayer(GradientCheckpointingLayer):
315
+ def __init__(self, config: MiniMaxM2Config, layer_idx: int):
316
+ super().__init__()
317
+ self.hidden_size = config.hidden_size
318
+
319
+ self.self_attn = MiniMaxM2Attention(config, layer_idx)
320
+
321
+ self.block_sparse_moe = MiniMaxM2SparseMoeBlock(config)
322
+ self.input_layernorm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
323
+ self.post_attention_layernorm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
324
+
325
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
326
+ def forward(
327
+ self,
328
+ hidden_states: torch.Tensor,
329
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
330
+ attention_mask: Optional[torch.Tensor] = None,
331
+ position_ids: Optional[torch.LongTensor] = None,
332
+ past_key_values: Optional[Cache] = None,
333
+ cache_position: Optional[torch.LongTensor] = None,
334
+ **kwargs: Unpack[TransformersKwargs],
335
+ ) -> torch.FloatTensor:
336
+ residual = hidden_states
337
+
338
+ hidden_states = self.input_layernorm(hidden_states)
339
+
340
+ # Self Attention
341
+ hidden_states, _ = self.self_attn(
342
+ hidden_states=hidden_states,
343
+ position_embeddings=position_embeddings,
344
+ attention_mask=attention_mask,
345
+ position_ids=position_ids,
346
+ past_key_values=past_key_values,
347
+ cache_position=cache_position,
348
+ **kwargs,
349
+ )
350
+ hidden_states = residual + hidden_states
351
+
352
+ # Fully Connected
353
+ residual = hidden_states
354
+ hidden_states = self.post_attention_layernorm(hidden_states)
355
+ hidden_states, _ = self.block_sparse_moe(hidden_states)
356
+ hidden_states = residual + hidden_states
357
+
358
+ return hidden_states
359
+
360
+
361
+ class MiniMaxM2RotaryEmbedding(nn.Module):
362
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
363
+
364
+ def __init__(self, config: MiniMaxM2Config, device=None):
365
+ super().__init__()
366
+ # BC: "rope_type" was originally "type"
367
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
368
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
369
+ else:
370
+ self.rope_type = "default"
371
+ self.max_seq_len_cached = config.max_position_embeddings
372
+ self.original_max_seq_len = config.max_position_embeddings
373
+
374
+ self.config = config
375
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
376
+
377
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
378
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
379
+ self.original_inv_freq = self.inv_freq
380
+
381
+ @torch.no_grad()
382
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
383
+ def forward(self, x, position_ids):
384
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
385
+ position_ids_expanded = position_ids[:, None, :].float()
386
+
387
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
388
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
389
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
390
+ emb = torch.cat((freqs, freqs), dim=-1)
391
+ cos = emb.cos() * self.attention_scaling
392
+ sin = emb.sin() * self.attention_scaling
393
+
394
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
395
+
396
+
397
+ @auto_docstring
398
+ class MiniMaxM2PreTrainedModel(PreTrainedModel):
399
+ config: MiniMaxM2Config
400
+ base_model_prefix = "model"
401
+ supports_gradient_checkpointing = True
402
+ _no_split_modules = ["MiniMaxM2DecoderLayer"]
403
+ _skip_keys_device_placement = ["past_key_values"]
404
+ _supports_flash_attn = True
405
+ _supports_sdpa = True
406
+ _supports_flex_attn = True
407
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
408
+ _supports_attention_backend = True
409
+ _can_record_outputs = {
410
+ "router_logits": OutputRecorder(MiniMaxM2SparseMoeBlock, index=1),
411
+ "hidden_states": MiniMaxM2DecoderLayer,
412
+ "attentions": MiniMaxM2Attention,
413
+ }
414
+
415
+
416
+ @auto_docstring
417
+ class MiniMaxM2Model(MiniMaxM2PreTrainedModel):
418
+ def __init__(self, config: MiniMaxM2Config):
419
+ super().__init__(config)
420
+ self.padding_idx = config.pad_token_id
421
+ self.vocab_size = config.vocab_size
422
+
423
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
424
+ self.layers = nn.ModuleList(
425
+ [MiniMaxM2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
426
+ )
427
+ self.norm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
428
+ self.rotary_emb = MiniMaxM2RotaryEmbedding(config=config)
429
+ self.gradient_checkpointing = False
430
+
431
+ # Initialize weights and apply final processing
432
+ self.post_init()
433
+
434
+ @check_model_inputs
435
+ @auto_docstring
436
+ def forward(
437
+ self,
438
+ input_ids: Optional[torch.LongTensor] = None,
439
+ attention_mask: Optional[torch.Tensor] = None,
440
+ position_ids: Optional[torch.LongTensor] = None,
441
+ past_key_values: Optional[Cache] = None,
442
+ inputs_embeds: Optional[torch.FloatTensor] = None,
443
+ use_cache: Optional[bool] = None,
444
+ cache_position: Optional[torch.LongTensor] = None,
445
+ **kwargs: Unpack[TransformersKwargs],
446
+ ) -> MoeModelOutputWithPast:
447
+ if (input_ids is None) ^ (inputs_embeds is not None):
448
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
449
+
450
+ if use_cache and past_key_values is None:
451
+ past_key_values = DynamicCache(config=self.config)
452
+
453
+ if inputs_embeds is None:
454
+ inputs_embeds = self.embed_tokens(input_ids)
455
+
456
+ if cache_position is None:
457
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
458
+ cache_position = torch.arange(
459
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
460
+ )
461
+ if position_ids is None:
462
+ position_ids = cache_position.unsqueeze(0)
463
+
464
+ mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
465
+ causal_mask = mask_function(
466
+ config=self.config,
467
+ input_embeds=inputs_embeds,
468
+ attention_mask=attention_mask,
469
+ cache_position=cache_position,
470
+ past_key_values=past_key_values,
471
+ position_ids=position_ids,
472
+ )
473
+
474
+ hidden_states = inputs_embeds
475
+
476
+ # create position embeddings to be shared across the decoder layers
477
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
478
+
479
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
480
+ hidden_states = decoder_layer(
481
+ hidden_states,
482
+ position_embeddings=position_embeddings,
483
+ attention_mask=causal_mask,
484
+ position_ids=position_ids,
485
+ past_key_values=past_key_values,
486
+ use_cache=use_cache,
487
+ cache_position=cache_position,
488
+ **kwargs,
489
+ )
490
+
491
+ hidden_states = self.norm(hidden_states)
492
+
493
+ return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE
494
+ last_hidden_state=hidden_states,
495
+ past_key_values=past_key_values,
496
+ )
497
+
498
+
499
+ def load_balancing_loss_func(
500
+ gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
501
+ num_experts: Optional[int] = None,
502
+ top_k=2,
503
+ attention_mask: Optional[torch.Tensor] = None,
504
+ ) -> Union[torch.Tensor, int]:
505
+ r"""
506
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in PyTorch.
507
+ See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
508
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
509
+ experts is too unbalanced.
510
+ Args:
511
+ gate_logits:
512
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
513
+ shape [batch_size X sequence_length, num_experts].
514
+ num_experts:
515
+ Number of experts
516
+ top_k:
517
+ The number of experts to route per-token, can also be interpreted as the `top-k` routing
518
+ parameter.
519
+ attention_mask (`torch.Tensor`, *optional*):
520
+ The attention_mask used in forward function
521
+ shape [batch_size X sequence_length] if not None.
522
+ Returns:
523
+ The auxiliary loss.
524
+ """
525
+ if gate_logits is None or not isinstance(gate_logits, tuple):
526
+ return 0
527
+
528
+ if isinstance(gate_logits, tuple):
529
+ compute_device = gate_logits[0].device
530
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
531
+
532
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
533
+
534
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
535
+
536
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
537
+
538
+ if attention_mask is None:
539
+ # Compute the percentage of tokens routed to each expert
540
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
541
+
542
+ # Compute the average probability of routing to these experts
543
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
544
+ else:
545
+ batch_size, sequence_length = attention_mask.shape
546
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
547
+
548
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
549
+ expert_attention_mask = (
550
+ attention_mask[None, :, :, None, None]
551
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
552
+ .reshape(-1, top_k, num_experts)
553
+ .to(compute_device)
554
+ )
555
+
556
+ # Compute the percentage of tokens routed to each expert
557
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
558
+ expert_attention_mask, dim=0
559
+ )
560
+
561
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
562
+ router_per_expert_attention_mask = (
563
+ attention_mask[None, :, :, None]
564
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
565
+ .reshape(-1, num_experts)
566
+ .to(compute_device)
567
+ )
568
+
569
+ # Compute the average probability of routing to these experts
570
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
571
+ router_per_expert_attention_mask, dim=0
572
+ )
573
+
574
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
575
+ return overall_loss * num_experts
576
+
577
+
578
+ @auto_docstring
579
+ class MiniMaxM2ForCausalLM(MiniMaxM2PreTrainedModel, GenerationMixin):
580
+ _tied_weights_keys = ["lm_head.weight"]
581
+ _tp_plan = {"lm_head": "colwise_rep"}
582
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
583
+
584
+ def __init__(self, config):
585
+ super().__init__(config)
586
+ self.model = MiniMaxM2Model(config)
587
+ self.vocab_size = config.vocab_size
588
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
589
+ self.router_aux_loss_coef = config.router_aux_loss_coef
590
+ self.num_experts = config.num_local_experts
591
+ self.num_experts_per_tok = config.num_experts_per_tok
592
+
593
+ # Initialize weights and apply final processing
594
+ self.post_init()
595
+
596
+ @can_return_tuple
597
+ @auto_docstring
598
+ def forward(
599
+ self,
600
+ input_ids: Optional[torch.LongTensor] = None,
601
+ attention_mask: Optional[torch.Tensor] = None,
602
+ position_ids: Optional[torch.LongTensor] = None,
603
+ past_key_values: Optional[Cache] = None,
604
+ inputs_embeds: Optional[torch.FloatTensor] = None,
605
+ labels: Optional[torch.LongTensor] = None,
606
+ use_cache: Optional[bool] = None,
607
+ output_router_logits: Optional[bool] = None,
608
+ cache_position: Optional[torch.LongTensor] = None,
609
+ logits_to_keep: Union[int, torch.Tensor] = 0,
610
+ **kwargs: Unpack[TransformersKwargs],
611
+ ) -> MoeCausalLMOutputWithPast:
612
+ r"""
613
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
614
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
615
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
616
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
617
+ Example:
618
+ ```python
619
+ >>> from transformers import AutoTokenizer, MiniMaxM2ForCausalLM
620
+ >>> model = MiniMaxM2ForCausalLM.from_pretrained("MiniMaxAI/MiniMax-M2")
620
+ >>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-M2")
622
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
623
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
624
+ >>> # Generate
625
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
626
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
627
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
628
+ ```"""
629
+
630
+ output_router_logits = (
631
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
632
+ )
633
+
634
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
635
+ outputs: MoeModelOutputWithPast = self.model(
636
+ input_ids=input_ids,
637
+ attention_mask=attention_mask,
638
+ position_ids=position_ids,
639
+ past_key_values=past_key_values,
640
+ inputs_embeds=inputs_embeds,
641
+ use_cache=use_cache,
642
+ output_router_logits=output_router_logits,
643
+ cache_position=cache_position,
644
+ **kwargs,
645
+ )
646
+
647
+ hidden_states = outputs.last_hidden_state
648
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
649
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
650
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
651
+
652
+ loss = None
653
+ if labels is not None:
654
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
655
+
656
+ aux_loss = None
657
+ if output_router_logits:
658
+ aux_loss = load_balancing_loss_func(
659
+ outputs.router_logits,
660
+ self.num_experts,
661
+ self.num_experts_per_tok,
662
+ attention_mask,
663
+ )
664
+ if labels is not None:
665
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
666
+
667
+ return MoeCausalLMOutputWithPast(
668
+ loss=loss,
669
+ aux_loss=aux_loss,
670
+ logits=logits,
671
+ past_key_values=outputs.past_key_values,
672
+ hidden_states=outputs.hidden_states,
673
+ attentions=outputs.attentions,
674
+ router_logits=outputs.router_logits,
675
+ )
+
+
+class MiniMaxM2ForSequenceClassification(GenericForSequenceClassification, MiniMaxM2PreTrainedModel):
+    pass
+
+
+class MiniMaxM2ForTokenClassification(GenericForTokenClassification, MiniMaxM2PreTrainedModel):
+    pass
+
+
+class MiniMaxM2ForQuestionAnswering(GenericForQuestionAnswering, MiniMaxM2PreTrainedModel):
+    pass
+
+
+__all__ = [
+    "MiniMaxM2ForCausalLM",
+    "MiniMaxM2ForQuestionAnswering",
+    "MiniMaxM2Model",
+    "MiniMaxM2PreTrainedModel",
+    "MiniMaxM2ForSequenceClassification",
+    "MiniMaxM2ForTokenClassification",
+]
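Two steps in the `forward` method of `MiniMaxM2ForCausalLM` are easy to misread in diff form: the `logits_to_keep` slicing, and the way the router auxiliary loss is added to the language-modeling loss. The sketch below reproduces both on plain Python lists. It assumes a Switch-Transformer-style balancing term and is not the repository's `load_balancing_loss_func`; the names `select_positions` and `balancing_loss`, and the values `2.0` and `0.001`, are hypothetical and chosen only for illustration.

```python
from typing import List, Sequence, Union


def select_positions(hidden: Sequence[str], logits_to_keep: Union[int, List[int]] = 0) -> List[str]:
    """Mimic `hidden_states[:, slice_indices, :]` along the sequence axis."""
    if isinstance(logits_to_keep, int):
        # slice(-0, None) equals slice(0, None), so 0 keeps every position.
        return list(hidden[slice(-logits_to_keep, None)])
    # A list of indices stands in for the tensor form of `logits_to_keep`.
    return [hidden[i] for i in logits_to_keep]


def balancing_loss(router_probs: List[List[float]], top_k: int) -> float:
    """Assumed Switch-style term: num_experts * sum_e(load_e * mean_prob_e)."""
    num_tokens, num_experts = len(router_probs), len(router_probs[0])
    load = [0.0] * num_experts
    for probs in router_probs:
        ranked = sorted(range(num_experts), key=lambda e: probs[e], reverse=True)
        for e in ranked[:top_k]:
            load[e] += 1.0 / num_tokens  # share of tokens routed to expert e
    mean_prob = [sum(p[e] for p in router_probs) / num_tokens for e in range(num_experts)]
    return num_experts * sum(l * p for l, p in zip(load, mean_prob))


positions = ["t0", "t1", "t2", "t3"]
print(select_positions(positions, 0))       # every position
print(select_positions(positions, 1))       # last position only
print(select_positions(positions, [0, 2]))  # explicit indices

balanced = balancing_loss([[0.9, 0.1], [0.1, 0.9]], top_k=1)
collapsed = balancing_loss([[0.9, 0.1], [0.9, 0.1]], top_k=1)
lm_loss, aux_coef = 2.0, 0.001  # hypothetical values
print(balanced, collapsed, lm_loss + aux_coef * collapsed)
```

In this example the router that spreads tokens evenly scores lower than the one that sends every token to the same expert, so adding the scaled term to the main loss penalizes routing collapse.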
special_tokens_map.json ADDED
@@ -0,0 +1,75 @@
+{
+  "additional_special_tokens": [
+    "<code_interpreter>",
+    "<commit_after>",
+    "<commit_before>",
+    "<commit_msg>",
+    "<empty_output>",
+    "<filename>",
+    "<fim_middle>",
+    "<fim_pad>",
+    "<fim_prefix>",
+    "<fim_suffix>",
+    "<function_call>",
+    "<gh_stars>",
+    "]<]speech[>[",
+    "]<]image[>[",
+    "]<]video[>[",
+    "]<]start of speech[>[",
+    "]<]end of speech[>[",
+    "]<]start of image[>[",
+    "]<]end of image[>[",
+    "]<]start of video[>[",
+    "]<]end of video[>[",
+    "]<]vision pad[>[",
+    "]~!b[",
+    "<issue_closed>",
+    "<issue_comment>",
+    "<issue_start>",
+    "<jupyter_code>",
+    "<jupyter_output>",
+    "<jupyter_start>",
+    "<jupyter_text>",
+    "<reponame>",
+    "[e~[",
+    "]!d~[",
+    "]!p~[",
+    "]~b]",
+    "<jupyter_error>",
+    "<add_file>",
+    "<delete_file>",
+    "<rename_file>",
+    "<edit_file>",
+    "<commit_message>",
+    "<empty_source_file>",
+    "<repo_struct>",
+    "<code_context>",
+    "<file_content>",
+    "<source_files>",
+    "<pr_start>",
+    "<review_comment>",
+    "<filepath>",
+    "<file_sep>"
+  ],
+  "bos_token": {
+    "content": "]~!b[",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "eos_token": {
+    "content": "[e~[",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "unk_token": {
+    "content": "]!d~[",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  }
+}
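In `special_tokens_map.json` the named tokens (`bos_token`, `eos_token`, `unk_token`) are objects with a `content` field, whereas `tokenizer_config.json` in the same commit stores them as plain strings. A minimal sketch of reading either form with the standard `json` module is shown below. The helper `token_text` is hypothetical, and the JSON is a trimmed excerpt rather than the full file.

```python
import json

# Trimmed excerpt of special_tokens_map.json: named tokens are objects here.
special_tokens_map = json.loads("""
{
  "additional_special_tokens": ["<fim_prefix>", "<fim_middle>", "<fim_suffix>"],
  "bos_token": {"content": "]~!b[", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false},
  "eos_token": {"content": "[e~[", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false},
  "unk_token": {"content": "]!d~[", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false}
}
""")


def token_text(entry):
    """Return the literal token string for either the object or the string form."""
    return entry["content"] if isinstance(entry, dict) else entry


named_tokens = {
    name: token_text(entry)
    for name, entry in special_tokens_map.items()
    if name.endswith("_token")  # skips the plural "additional_special_tokens" list
}
print(named_tokens)
```

Accepting both shapes in one helper lets the same lookup code work against either file.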
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e7b90ed7f55d905175bc26771d6d7d33b40b46742f073675bc816fedaf482ea1
+size 15522763
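The three lines added for `tokenizer.json` are a Git LFS pointer, not the tokenizer itself: the real file is stored elsewhere and identified by its SHA-256 digest and byte size. As a rough sketch, a pointer of this shape can be parsed as below. The function name `parse_lfs_pointer` and the returned dictionary keys are hypothetical.

```python
def parse_lfs_pointer(text: str) -> dict:
    """Parse the `key value` lines of a Git LFS pointer file."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    # The oid has the form "<hash algorithm>:<hex digest>".
    algo, _, digest = fields["oid"].partition(":")
    return {
        "version": fields["version"],
        "hash_algo": algo,
        "digest": digest,
        "size": int(fields["size"]),  # size of the real file in bytes
    }


pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:e7b90ed7f55d905175bc26771d6d7d33b40b46742f073675bc816fedaf482ea1
size 15522763
"""
info = parse_lfs_pointer(pointer)
print(info["hash_algo"], info["size"])
```

After downloading the real `tokenizer.json`, the digest and size can be compared against the file to confirm the download is complete.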
tokenizer_config.json ADDED
@@ -0,0 +1,496 @@
+{
+  "add_prefix_space": false,
+  "added_tokens_decoder": {
+    "200000": {
+      "content": "]!p~[",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200001": {
+      "content": "<fim_prefix>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200002": {
+      "content": "<fim_middle>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200003": {
+      "content": "<fim_suffix>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200004": {
+      "content": "<fim_pad>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200005": {
+      "content": "<reponame>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200006": {
+      "content": "<filename>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200007": {
+      "content": "<gh_stars>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200008": {
+      "content": "<issue_start>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200009": {
+      "content": "<issue_comment>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200010": {
+      "content": "<issue_closed>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200011": {
+      "content": "<jupyter_start>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200012": {
+      "content": "<jupyter_text>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200013": {
+      "content": "<jupyter_code>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200014": {
+      "content": "<jupyter_output>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200015": {
+      "content": "<empty_output>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200016": {
+      "content": "<commit_before>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200017": {
+      "content": "<commit_msg>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200018": {
+      "content": "<commit_after>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200019": {
+      "content": "]~b]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200020": {
+      "content": "[e~[",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200021": {
+      "content": "]!d~[",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200022": {
+      "content": "<function_call>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200023": {
+      "content": "<code_interpreter>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200024": {
+      "content": "]<]speech[>[",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200025": {
+      "content": "]<]image[>[",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200026": {
+      "content": "]<]video[>[",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200027": {
+      "content": "]<]start of speech[>[",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200028": {
+      "content": "]<]end of speech[>[",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200029": {
+      "content": "]<]start of image[>[",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200030": {
+      "content": "]<]end of image[>[",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200031": {
+      "content": "]<]start of video[>[",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200032": {
+      "content": "]<]end of video[>[",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200033": {
+      "content": "]<]vision pad[>[",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200034": {
+      "content": "]~!b[",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200035": {
+      "content": "<jupyter_error>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200036": {
+      "content": "<add_file>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200037": {
+      "content": "<delete_file>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200038": {
+      "content": "<rename_file>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200039": {
+      "content": "<edit_file>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200040": {
+      "content": "<commit_message>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200041": {
+      "content": "<empty_source_file>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200042": {
+      "content": "<repo_struct>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200043": {
+      "content": "<code_context>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200044": {
+      "content": "<file_content>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200045": {
+      "content": "<source_files>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200046": {
+      "content": "<pr_start>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200047": {
+      "content": "<review_comment>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200048": {
+      "content": "<filepath>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200049": {
+      "content": "<file_sep>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "200050": {
+      "content": "<think>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "200051": {
+      "content": "</think>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "200052": {
+      "content": "<minimax:tool_call>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "200053": {
+      "content": "</minimax:tool_call>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    }
+  },
+ "additional_special_tokens": [
438
+ "<code_interpreter>",
439
+ "<commit_after>",
440
+ "<commit_before>",
441
+ "<commit_msg>",
442
+ "<empty_output>",
443
+ "<filename>",
444
+ "<fim_middle>",
445
+ "<fim_pad>",
446
+ "<fim_prefix>",
447
+ "<fim_suffix>",
448
+ "<function_call>",
449
+ "<gh_stars>",
450
+ "]<]speech[>[",
451
+ "]<]image[>[",
452
+ "]<]video[>[",
453
+ "]<]start of speech[>[",
454
+ "]<]end of speech[>[",
455
+ "]<]start of image[>[",
456
+ "]<]end of image[>[",
457
+ "]<]start of video[>[",
458
+ "]<]end of video[>[",
459
+ "]<]vision pad[>[",
460
+ "]~!b[",
461
+ "<issue_closed>",
462
+ "<issue_comment>",
463
+ "<issue_start>",
464
+ "<jupyter_code>",
465
+ "<jupyter_output>",
466
+ "<jupyter_start>",
467
+ "<jupyter_text>",
468
+ "<reponame>",
469
+ "[e~[",
470
+ "]!d~[",
471
+ "]!p~[",
472
+ "]~b]",
473
+ "<jupyter_error>",
474
+ "<add_file>",
475
+ "<delete_file>",
476
+ "<rename_file>",
477
+ "<edit_file>",
478
+ "<commit_message>",
479
+ "<empty_source_file>",
480
+ "<repo_struct>",
481
+ "<code_context>",
482
+ "<file_content>",
483
+ "<source_files>",
484
+ "<pr_start>",
485
+ "<review_comment>",
486
+ "<filepath>",
487
+ "<file_sep>"
488
+ ],
489
+ "bos_token": "]~!b[",
490
+ "clean_up_tokenization_spaces": false,
491
+ "eos_token": "[e~[",
492
+ "extra_special_tokens": {},
493
+ "model_max_length": 40960000,
494
+ "tokenizer_class": "GPT2Tokenizer",
495
+ "unk_token": "]!d~["
496
+ }
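The `added_tokens_decoder` table in `tokenizer_config.json` maps token IDs to token strings and marks each entry as special or not. The last four entries (`<think>`, `</think>`, and the two `minimax:tool_call` tags) carry `"special": false`, unlike the rest. The sketch below, which uses a trimmed excerpt and hypothetical variable names, inverts the table to look up an ID by token string and lists the non-special entries.

```python
import json

# Trimmed excerpt of tokenizer_config.json; the full table has 54 entries.
config = json.loads("""
{
  "added_tokens_decoder": {
    "200019": {"content": "]~b]", "special": true},
    "200020": {"content": "[e~[", "special": true},
    "200050": {"content": "<think>", "special": false},
    "200051": {"content": "</think>", "special": false}
  },
  "eos_token": "[e~["
}
""")

decoder = config["added_tokens_decoder"]

# Invert the ID -> token table so a token string can be resolved to its ID.
token_to_id = {entry["content"]: int(token_id) for token_id, entry in decoder.items()}

# Entries flagged "special": false, in file order.
non_special = [entry["content"] for entry in decoder.values() if not entry["special"]]

eos_id = token_to_id[config["eos_token"]]
print(eos_id, non_special)
```

One plausible reason for the flag, stated here as an assumption rather than something the commit documents, is to keep the reasoning and tool-call tags visible in decoded text when special tokens are skipped.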
vocab.json ADDED
The diff for this file is too large to render. See raw diff