Jackmin108 commited on
Commit
45a6ce4
·
1 Parent(s): 98b2028

modeling stuff

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright [yyyy] [name of copyright owner]
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ pipeline_tag: text-generation
4
+ library_name: transformers
5
+ tags:
6
+ - vllm
7
+ ---
8
+
9
+ <p align="center">
10
+ <img alt="gpt-oss-20b" src="https://raw.githubusercontent.com/openai/gpt-oss/main/docs/gpt-oss-20b.svg">
11
+ </p>
12
+
13
+ <p align="center">
14
+ <a href="https://gpt-oss.com"><strong>Try gpt-oss</strong></a> ·
15
+ <a href="https://cookbook.openai.com/topic/gpt-oss"><strong>Guides</strong></a> ·
16
+ <a href="https://arxiv.org/abs/2508.10925"><strong>Model card</strong></a> ·
17
+ <a href="https://openai.com/index/introducing-gpt-oss/"><strong>OpenAI blog</strong></a>
18
+ </p>
19
+
20
+ <br>
21
+
22
+ Welcome to the gpt-oss series, [OpenAI’s open-weight models](https://openai.com/open-models) designed for powerful reasoning, agentic tasks, and versatile developer use cases.
23
+
24
+ We’re releasing two flavors of these open models:
25
+ - `gpt-oss-120b` — for production, general purpose, high reasoning use cases that fit into a single 80GB GPU (like NVIDIA H100 or AMD MI300X) (117B parameters with 5.1B active parameters)
26
+ - `gpt-oss-20b` — for lower latency, and local or specialized use cases (21B parameters with 3.6B active parameters)
27
+
28
+ Both models were trained on our [harmony response format](https://github.com/openai/harmony) and should only be used with the harmony format as it will not work correctly otherwise.
29
+
30
+
31
+ > [!NOTE]
32
+ > This model card is dedicated to the smaller `gpt-oss-20b` model. Check out [`gpt-oss-120b`](https://huggingface.co/openai/gpt-oss-120b) for the larger model.
33
+
34
+ # Highlights
35
+
36
+ * **Permissive Apache 2.0 license:** Build freely without copyleft restrictions or patent risk—ideal for experimentation, customization, and commercial deployment.
37
+ * **Configurable reasoning effort:** Easily adjust the reasoning effort (low, medium, high) based on your specific use case and latency needs.
38
+ * **Full chain-of-thought:** Gain complete access to the model’s reasoning process, facilitating easier debugging and increased trust in outputs. It’s not intended to be shown to end users.
39
+ * **Fine-tunable:** Fully customize models to your specific use case through parameter fine-tuning.
40
+ * **Agentic capabilities:** Use the models’ native capabilities for function calling, [web browsing](https://github.com/openai/gpt-oss/tree/main?tab=readme-ov-file#browser), [Python code execution](https://github.com/openai/gpt-oss/tree/main?tab=readme-ov-file#python), and Structured Outputs.
41
+ * **MXFP4 quantization:** The models were post-trained with MXFP4 quantization of the MoE weights, making `gpt-oss-120b` run on a single 80GB GPU (like NVIDIA H100 or AMD MI300X) and the `gpt-oss-20b` model run within 16GB of memory. All evals were performed with the same MXFP4 quantization.
42
+
43
+ ---
44
+
45
+ # Inference examples
46
+
47
+ ## Transformers
48
+
49
+ You can use `gpt-oss-120b` and `gpt-oss-20b` with Transformers. If you use the Transformers chat template, it will automatically apply the [harmony response format](https://github.com/openai/harmony). If you use `model.generate` directly, you need to apply the harmony format manually using the chat template or use our [openai-harmony](https://github.com/openai/harmony) package.
50
+
51
+ To get started, install the necessary dependencies to setup your environment:
52
+
53
+ ```
54
+ pip install -U transformers kernels torch
55
+ ```
56
+
57
+ Once, setup you can proceed to run the model by running the snippet below:
58
+
59
+ ```py
60
+ from transformers import pipeline
61
+ import torch
62
+
63
+ model_id = "openai/gpt-oss-20b"
64
+
65
+ pipe = pipeline(
66
+ "text-generation",
67
+ model=model_id,
68
+ torch_dtype="auto",
69
+ device_map="auto",
70
+ )
71
+
72
+ messages = [
73
+ {"role": "user", "content": "Explain quantum mechanics clearly and concisely."},
74
+ ]
75
+
76
+ outputs = pipe(
77
+ messages,
78
+ max_new_tokens=256,
79
+ )
80
+ print(outputs[0]["generated_text"][-1])
81
+ ```
82
+
83
+ Alternatively, you can run the model via [`Transformers Serve`](https://huggingface.co/docs/transformers/main/serving) to spin up a OpenAI-compatible webserver:
84
+
85
+ ```
86
+ transformers serve
87
+ transformers chat localhost:8000 --model-name-or-path openai/gpt-oss-20b
88
+ ```
89
+
90
+ [Learn more about how to use gpt-oss with Transformers.](https://cookbook.openai.com/articles/gpt-oss/run-transformers)
91
+
92
+ ## vLLM
93
+
94
+ vLLM recommends using [uv](https://docs.astral.sh/uv/) for Python dependency management. You can use vLLM to spin up an OpenAI-compatible webserver. The following command will automatically download the model and start the server.
95
+
96
+ ```bash
97
+ uv pip install --pre vllm==0.10.1+gptoss \
98
+ --extra-index-url https://wheels.vllm.ai/gpt-oss/ \
99
+ --extra-index-url https://download.pytorch.org/whl/nightly/cu128 \
100
+ --index-strategy unsafe-best-match
101
+
102
+ vllm serve openai/gpt-oss-20b
103
+ ```
104
+
105
+ [Learn more about how to use gpt-oss with vLLM.](https://cookbook.openai.com/articles/gpt-oss/run-vllm)
106
+
107
+ ## PyTorch / Triton
108
+
109
+ To learn about how to use this model with PyTorch and Triton, check out our [reference implementations in the gpt-oss repository](https://github.com/openai/gpt-oss?tab=readme-ov-file#reference-pytorch-implementation).
110
+
111
+ ## Ollama
112
+
113
+ If you are trying to run gpt-oss on consumer hardware, you can use Ollama by running the following commands after [installing Ollama](https://ollama.com/download).
114
+
115
+ ```bash
116
+ # gpt-oss-20b
117
+ ollama pull gpt-oss:20b
118
+ ollama run gpt-oss:20b
119
+ ```
120
+
121
+ [Learn more about how to use gpt-oss with Ollama.](https://cookbook.openai.com/articles/gpt-oss/run-locally-ollama)
122
+
123
+ #### LM Studio
124
+
125
+ If you are using [LM Studio](https://lmstudio.ai/) you can use the following commands to download.
126
+
127
+ ```bash
128
+ # gpt-oss-20b
129
+ lms get openai/gpt-oss-20b
130
+ ```
131
+
132
+ Check out our [awesome list](https://github.com/openai/gpt-oss/blob/main/awesome-gpt-oss.md) for a broader collection of gpt-oss resources and inference partners.
133
+
134
+ ---
135
+
136
+ # Download the model
137
+
138
+ You can download the model weights from the [Hugging Face Hub](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4) directly from Hugging Face CLI:
139
+
140
+ ```shell
141
+ # gpt-oss-20b
142
+ huggingface-cli download openai/gpt-oss-20b --include "original/*" --local-dir gpt-oss-20b/
143
+ pip install gpt-oss
144
+ python -m gpt_oss.chat model/
145
+ ```
146
+
147
+ # Reasoning levels
148
+
149
+ You can adjust the reasoning level that suits your task across three levels:
150
+
151
+ * **Low:** Fast responses for general dialogue.
152
+ * **Medium:** Balanced speed and detail.
153
+ * **High:** Deep and detailed analysis.
154
+
155
+ The reasoning level can be set in the system prompts, e.g., "Reasoning: high".
156
+
157
+ # Tool use
158
+
159
+ The gpt-oss models are excellent for:
160
+ * Web browsing (using built-in browsing tools)
161
+ * Function calling with defined schemas
162
+ * Agentic operations like browser tasks
163
+
164
+ # Fine-tuning
165
+
166
+ Both gpt-oss models can be fine-tuned for a variety of specialized use cases.
167
+
168
+ This smaller model `gpt-oss-20b` can be fine-tuned on consumer hardware, whereas the larger [`gpt-oss-120b`](https://huggingface.co/openai/gpt-oss-120b) can be fine-tuned on a single H100 node.
169
+
170
+ # Citation
171
+
172
+ ```bibtex
173
+ @misc{openai2025gptoss120bgptoss20bmodel,
174
+ title={gpt-oss-120b & gpt-oss-20b Model Card},
175
+ author={OpenAI},
176
+ year={2025},
177
+ eprint={2508.10925},
178
+ archivePrefix={arXiv},
179
+ primaryClass={cs.CL},
180
+ url={https://arxiv.org/abs/2508.10925},
181
+ }
182
+ ```
USAGE_POLICY ADDED
@@ -0,0 +1 @@
 
 
1
+ We aim for our tools to be used safely, responsibly, and democratically, while maximizing your control over how you use them. By using OpenAI gpt-oss-20b, you agree to comply with all applicable law.
chat_template.jinja ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {#-
2
+ In addition to the normal inputs of `messages` and `tools`, this template also accepts the
3
+ following kwargs:
4
+ - "builtin_tools": A list, can contain "browser" and/or "python".
5
+ - "model_identity": A string that optionally describes the model identity.
6
+ - "reasoning_effort": A string that describes the reasoning effort, defaults to "medium".
7
+ #}
8
+
9
+ {#- Tool Definition Rendering ============================================== #}
10
+ {%- macro render_typescript_type(param_spec, required_params, is_nullable=false) -%}
11
+ {%- if param_spec.type == "array" -%}
12
+ {%- if param_spec['items'] -%}
13
+ {%- if param_spec['items']['type'] == "string" -%}
14
+ {{- "string[]" }}
15
+ {%- elif param_spec['items']['type'] == "number" -%}
16
+ {{- "number[]" }}
17
+ {%- elif param_spec['items']['type'] == "integer" -%}
18
+ {{- "number[]" }}
19
+ {%- elif param_spec['items']['type'] == "boolean" -%}
20
+ {{- "boolean[]" }}
21
+ {%- else -%}
22
+ {%- set inner_type = render_typescript_type(param_spec['items'], required_params) -%}
23
+ {%- if inner_type == "object | object" or inner_type|length > 50 -%}
24
+ {{- "any[]" }}
25
+ {%- else -%}
26
+ {{- inner_type + "[]" }}
27
+ {%- endif -%}
28
+ {%- endif -%}
29
+ {%- if param_spec.nullable -%}
30
+ {{- " | null" }}
31
+ {%- endif -%}
32
+ {%- else -%}
33
+ {{- "any[]" }}
34
+ {%- if param_spec.nullable -%}
35
+ {{- " | null" }}
36
+ {%- endif -%}
37
+ {%- endif -%}
38
+ {%- elif param_spec.type is defined and param_spec.type is iterable and param_spec.type is not string and param_spec.type is not mapping and param_spec.type[0] is defined -%}
39
+ {#- Handle array of types like ["object", "object"] from Union[dict, list] #}
40
+ {%- if param_spec.type | length > 1 -%}
41
+ {{- param_spec.type | join(" | ") }}
42
+ {%- else -%}
43
+ {{- param_spec.type[0] }}
44
+ {%- endif -%}
45
+ {%- elif param_spec.oneOf -%}
46
+ {#- Handle oneOf schemas - check for complex unions and fallback to any #}
47
+ {%- set has_object_variants = false -%}
48
+ {%- for variant in param_spec.oneOf -%}
49
+ {%- if variant.type == "object" -%}
50
+ {%- set has_object_variants = true -%}
51
+ {%- endif -%}
52
+ {%- endfor -%}
53
+ {%- if has_object_variants and param_spec.oneOf|length > 1 -%}
54
+ {{- "any" }}
55
+ {%- else -%}
56
+ {%- for variant in param_spec.oneOf -%}
57
+ {{- render_typescript_type(variant, required_params) -}}
58
+ {%- if variant.description %}
59
+ {{- "// " + variant.description }}
60
+ {%- endif -%}
61
+ {%- if variant.default is defined %}
62
+ {{ "// default: " + variant.default|tojson }}
63
+ {%- endif -%}
64
+ {%- if not loop.last %}
65
+ {{- " | " }}
66
+ {% endif -%}
67
+ {%- endfor -%}
68
+ {%- endif -%}
69
+ {%- elif param_spec.type == "string" -%}
70
+ {%- if param_spec.enum -%}
71
+ {{- '"' + param_spec.enum|join('" | "') + '"' -}}
72
+ {%- else -%}
73
+ {{- "string" }}
74
+ {%- if param_spec.nullable %}
75
+ {{- " | null" }}
76
+ {%- endif -%}
77
+ {%- endif -%}
78
+ {%- elif param_spec.type == "number" -%}
79
+ {{- "number" }}
80
+ {%- elif param_spec.type == "integer" -%}
81
+ {{- "number" }}
82
+ {%- elif param_spec.type == "boolean" -%}
83
+ {{- "boolean" }}
84
+
85
+ {%- elif param_spec.type == "object" -%}
86
+ {%- if param_spec.properties -%}
87
+ {{- "{\n" }}
88
+ {%- for prop_name, prop_spec in param_spec.properties.items() -%}
89
+ {{- prop_name -}}
90
+ {%- if prop_name not in (param_spec.required or []) -%}
91
+ {{- "?" }}
92
+ {%- endif -%}
93
+ {{- ": " }}
94
+ {{ render_typescript_type(prop_spec, param_spec.required or []) }}
95
+ {%- if not loop.last -%}
96
+ {{-", " }}
97
+ {%- endif -%}
98
+ {%- endfor -%}
99
+ {{- "}" }}
100
+ {%- else -%}
101
+ {{- "object" }}
102
+ {%- endif -%}
103
+ {%- else -%}
104
+ {{- "any" }}
105
+ {%- endif -%}
106
+ {%- endmacro -%}
107
+
108
+ {%- macro render_tool_namespace(namespace_name, tools) -%}
109
+ {{- "## " + namespace_name + "\n\n" }}
110
+ {{- "namespace " + namespace_name + " {\n\n" }}
111
+ {%- for tool in tools %}
112
+ {%- set tool = tool.function %}
113
+ {{- "// " + tool.description + "\n" }}
114
+ {{- "type "+ tool.name + " = " }}
115
+ {%- if tool.parameters and tool.parameters.properties %}
116
+ {{- "(_: {\n" }}
117
+ {%- for param_name, param_spec in tool.parameters.properties.items() %}
118
+ {%- if param_spec.description %}
119
+ {{- "// " + param_spec.description + "\n" }}
120
+ {%- endif %}
121
+ {{- param_name }}
122
+ {%- if param_name not in (tool.parameters.required or []) -%}
123
+ {{- "?" }}
124
+ {%- endif -%}
125
+ {{- ": " }}
126
+ {{- render_typescript_type(param_spec, tool.parameters.required or []) }}
127
+ {%- if param_spec.default is defined -%}
128
+ {%- if param_spec.enum %}
129
+ {{- ", // default: " + param_spec.default }}
130
+ {%- elif param_spec.oneOf %}
131
+ {{- "// default: " + param_spec.default }}
132
+ {%- else %}
133
+ {{- ", // default: " + param_spec.default|tojson }}
134
+ {%- endif -%}
135
+ {%- endif -%}
136
+ {%- if not loop.last %}
137
+ {{- ",\n" }}
138
+ {%- else %}
139
+ {{- ",\n" }}
140
+ {%- endif -%}
141
+ {%- endfor %}
142
+ {{- "}) => any;\n\n" }}
143
+ {%- else -%}
144
+ {{- "() => any;\n\n" }}
145
+ {%- endif -%}
146
+ {%- endfor %}
147
+ {{- "} // namespace " + namespace_name }}
148
+ {%- endmacro -%}
149
+
150
+ {%- macro render_builtin_tools(browser_tool, python_tool) -%}
151
+ {%- if browser_tool %}
152
+ {{- "## browser\n\n" }}
153
+ {{- "// Tool for browsing.\n" }}
154
+ {{- "// The `cursor` appears in brackets before each browsing display: `[{cursor}]`.\n" }}
155
+ {{- "// Cite information from the tool using the following format:\n" }}
156
+ {{- "// `【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.\n" }}
157
+ {{- "// Do not quote more than 10 words directly from the tool output.\n" }}
158
+ {{- "// sources=web (default: web)\n" }}
159
+ {{- "namespace browser {\n\n" }}
160
+ {{- "// Searches for information related to `query` and displays `topn` results.\n" }}
161
+ {{- "type search = (_: {\n" }}
162
+ {{- "query: string,\n" }}
163
+ {{- "topn?: number, // default: 10\n" }}
164
+ {{- "source?: string,\n" }}
165
+ {{- "}) => any;\n\n" }}
166
+ {{- "// Opens the link `id` from the page indicated by `cursor` starting at line number `loc`, showing `num_lines` lines.\n" }}
167
+ {{- "// Valid link ids are displayed with the formatting: `【{id}†.*】`.\n" }}
168
+ {{- "// If `cursor` is not provided, the most recent page is implied.\n" }}
169
+ {{- "// If `id` is a string, it is treated as a fully qualified URL associated with `source`.\n" }}
170
+ {{- "// If `loc` is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available.\n" }}
171
+ {{- "// Use this function without `id` to scroll to a new location of an opened page.\n" }}
172
+ {{- "type open = (_: {\n" }}
173
+ {{- "id?: number | string, // default: -1\n" }}
174
+ {{- "cursor?: number, // default: -1\n" }}
175
+ {{- "loc?: number, // default: -1\n" }}
176
+ {{- "num_lines?: number, // default: -1\n" }}
177
+ {{- "view_source?: boolean, // default: false\n" }}
178
+ {{- "source?: string,\n" }}
179
+ {{- "}) => any;\n\n" }}
180
+ {{- "// Finds exact matches of `pattern` in the current page, or the page given by `cursor`.\n" }}
181
+ {{- "type find = (_: {\n" }}
182
+ {{- "pattern: string,\n" }}
183
+ {{- "cursor?: number, // default: -1\n" }}
184
+ {{- "}) => any;\n\n" }}
185
+ {{- "} // namespace browser\n\n" }}
186
+ {%- endif -%}
187
+
188
+ {%- if python_tool %}
189
+ {{- "## python\n\n" }}
190
+ {{- "Use this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).\n\n" }}
191
+ {{- "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 120.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is UNKNOWN. Depends on the cluster.\n\n" }}
192
+ {%- endif -%}
193
+ {%- endmacro -%}
194
+
195
+ {#- System Message Construction ============================================ #}
196
+ {%- macro build_system_message() -%}
197
+ {%- if model_identity is not defined %}
198
+ {%- set model_identity = "You are ChatGPT, a large language model trained by OpenAI." %}
199
+ {%- endif %}
200
+ {{- model_identity + "\n" }}
201
+ {{- "Knowledge cutoff: 2024-06\n" }}
202
+ {{- "Current date: " + strftime_now("%Y-%m-%d") + "\n\n" }}
203
+ {%- if reasoning_effort is not defined %}
204
+ {%- set reasoning_effort = "medium" %}
205
+ {%- endif %}
206
+ {{- "Reasoning: " + reasoning_effort + "\n\n" }}
207
+ {%- if builtin_tools %}
208
+ {{- "# Tools\n\n" }}
209
+ {%- set available_builtin_tools = namespace(browser=false, python=false) %}
210
+ {%- for tool in builtin_tools %}
211
+ {%- if tool == "browser" %}
212
+ {%- set available_builtin_tools.browser = true %}
213
+ {%- elif tool == "python" %}
214
+ {%- set available_builtin_tools.python = true %}
215
+ {%- endif %}
216
+ {%- endfor %}
217
+ {{- render_builtin_tools(available_builtin_tools.browser, available_builtin_tools.python) }}
218
+ {%- endif -%}
219
+ {{- "# Valid channels: analysis, commentary, final. Channel must be included for every message." }}
220
+ {%- if tools -%}
221
+ {{- "\nCalls to these tools must go to the commentary channel: 'functions'." }}
222
+ {%- endif -%}
223
+ {%- endmacro -%}
224
+
225
+ {#- Main Template Logic ================================================= #}
226
+ {#- Set defaults #}
227
+
228
+ {#- Render system message #}
229
+ {{- "<|start|>system<|message|>" }}
230
+ {{- build_system_message() }}
231
+ {{- "<|end|>" }}
232
+
233
+ {#- Extract developer message #}
234
+ {%- if messages[0].role == "developer" or messages[0].role == "system" %}
235
+ {%- set developer_message = messages[0].content %}
236
+ {%- set loop_messages = messages[1:] %}
237
+ {%- else %}
238
+ {%- set developer_message = "" %}
239
+ {%- set loop_messages = messages %}
240
+ {%- endif %}
241
+
242
+ {#- Render developer message #}
243
+ {%- if developer_message or tools %}
244
+ {{- "<|start|>developer<|message|>" }}
245
+ {%- if developer_message %}
246
+ {{- "# Instructions\n\n" }}
247
+ {{- developer_message }}
248
+ {{- "\n\n" }}
249
+ {%- endif %}
250
+ {%- if tools -%}
251
+ {{- "# Tools\n\n" }}
252
+ {{- render_tool_namespace("functions", tools) }}
253
+ {%- endif -%}
254
+ {{- "<|end|>" }}
255
+ {%- endif %}
256
+
257
+ {#- Render messages #}
258
+ {%- set last_tool_call = namespace(name=none) %}
259
+ {%- for message in loop_messages -%}
260
+ {#- At this point only assistant/user/tool messages should remain #}
261
+ {%- if message.role == 'assistant' -%}
262
+ {#- Checks to ensure the messages are being passed in the format we expect #}
263
+ {%- if "content" in message %}
264
+ {%- if "<|channel|>analysis<|message|>" in message.content or "<|channel|>final<|message|>" in message.content %}
265
+ {{- raise_exception("You have passed a message containing <|channel|> tags in the content field. Instead of doing this, you should pass analysis messages (the string between '<|message|>' and '<|end|>') in the 'thinking' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.") }}
266
+ {%- endif %}
267
+ {%- endif %}
268
+ {%- if "thinking" in message %}
269
+ {%- if "<|channel|>analysis<|message|>" in message.thinking or "<|channel|>final<|message|>" in message.thinking %}
270
+ {{- raise_exception("You have passed a message containing <|channel|> tags in the thinking field. Instead of doing this, you should pass analysis messages (the string between '<|message|>' and '<|end|>') in the 'thinking' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.") }}
271
+ {%- endif %}
272
+ {%- endif %}
273
+ {%- if "tool_calls" in message %}
274
+ {#- We need very careful handling here - we want to drop the tool call analysis message if the model #}
275
+ {#- has output a later <|final|> message, but otherwise we want to retain it. This is the only case #}
276
+ {#- when we render CoT/analysis messages in inference. #}
277
+ {%- set future_final_message = namespace(found=false) %}
278
+ {%- for future_message in loop_messages[loop.index:] %}
279
+ {%- if future_message.role == 'assistant' and "tool_calls" not in future_message %}
280
+ {%- set future_final_message.found = true %}
281
+ {%- endif %}
282
+ {%- endfor %}
283
+ {#- We assume max 1 tool call per message, and so we infer the tool call name #}
284
+ {#- in "tool" messages from the most recent assistant tool call name #}
285
+ {%- set tool_call = message.tool_calls[0] %}
286
+ {%- if tool_call.function %}
287
+ {%- set tool_call = tool_call.function %}
288
+ {%- endif %}
289
+ {%- if message.content and message.thinking %}
290
+ {{- raise_exception("Cannot pass both content and thinking in an assistant message with tool calls! Put the analysis message in one or the other, but not both.") }}
291
+ {%- elif message.content and not future_final_message.found %}
292
+ {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.content + "<|end|>" }}
293
+ {%- elif message.thinking and not future_final_message.found %}
294
+ {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }}
295
+ {%- endif %}
296
+ {{- "<|start|>assistant to=" }}
297
+ {{- "functions." + tool_call.name + "<|channel|>commentary " }}
298
+ {{- (tool_call.content_type if tool_call.content_type is defined else "json") + "<|message|>" }}
299
+ {{- tool_call.arguments|tojson }}
300
+ {{- "<|call|>" }}
301
+ {%- set last_tool_call.name = tool_call.name %}
302
+ {%- elif loop.last and not add_generation_prompt %}
303
+ {#- Only render the CoT if the final turn is an assistant turn and add_generation_prompt is false #}
304
+ {#- This is a situation that should only occur in training, never in inference. #}
305
+ {%- if "thinking" in message %}
306
+ {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }}
307
+ {%- endif %}
308
+ {#- <|return|> indicates the end of generation, but <|end|> does not #}
309
+ {#- <|return|> should never be an input to the model, but we include it as the final token #}
310
+ {#- when training, so the model learns to emit it. #}
311
+ {{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|return|>" }}
312
+ {%- else %}
313
+ {#- CoT is dropped during all previous turns, so we never render it for inference #}
314
+ {{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|end|>" }}
315
+ {%- set last_tool_call.name = none %}
316
+ {%- endif %}
317
+ {%- elif message.role == 'tool' -%}
318
+ {%- if last_tool_call.name is none %}
319
+ {{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }}
320
+ {%- endif %}
321
+ {{- "<|start|>functions." + last_tool_call.name }}
322
+ {{- " to=assistant<|channel|>commentary<|message|>" + message.content|tojson + "<|end|>" }}
323
+ {%- elif message.role == 'user' -%}
324
+ {{- "<|start|>user<|message|>" + message.content + "<|end|>" }}
325
+ {%- endif -%}
326
+ {%- endfor -%}
327
+
328
+ {#- Generation prompt #}
329
+ {%- if add_generation_prompt -%}
330
+ <|start|>assistant
331
+ {%- endif -%}
config.json ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "GptOssForCausalLM"
4
+ ],
5
+ "attention_bias": true,
6
+ "attention_dropout": 0.0,
7
+ "eos_token_id": 200002,
8
+ "experts_per_token": 4,
9
+ "head_dim": 64,
10
+ "hidden_act": "silu",
11
+ "hidden_size": 2880,
12
+ "initial_context_length": 4096,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 2880,
15
+ "layer_types": [
16
+ "sliding_attention",
17
+ "full_attention",
18
+ "sliding_attention",
19
+ "full_attention",
20
+ "sliding_attention",
21
+ "full_attention",
22
+ "sliding_attention",
23
+ "full_attention",
24
+ "sliding_attention",
25
+ "full_attention",
26
+ "sliding_attention",
27
+ "full_attention",
28
+ "sliding_attention",
29
+ "full_attention",
30
+ "sliding_attention",
31
+ "full_attention",
32
+ "sliding_attention",
33
+ "full_attention",
34
+ "sliding_attention",
35
+ "full_attention",
36
+ "sliding_attention",
37
+ "full_attention",
38
+ "sliding_attention",
39
+ "full_attention"
40
+ ],
41
+ "max_position_embeddings": 131072,
42
+ "model_type": "gpt_oss",
43
+ "num_attention_heads": 64,
44
+ "num_experts_per_tok": 4,
45
+ "num_hidden_layers": 24,
46
+ "num_key_value_heads": 8,
47
+ "num_local_experts": 32,
48
+ "output_router_logits": false,
49
+ "pad_token_id": 199999,
50
+ "quantization_config": {
51
+ "modules_to_not_convert": [
52
+ "model.layers.*.self_attn",
53
+ "model.layers.*.mlp.router",
54
+ "model.embed_tokens",
55
+ "lm_head"
56
+ ],
57
+ "quant_method": "mxfp4"
58
+ },
59
+ "rms_norm_eps": 1e-05,
60
+ "rope_scaling": {
61
+ "beta_fast": 32.0,
62
+ "beta_slow": 1.0,
63
+ "factor": 32.0,
64
+ "original_max_position_embeddings": 4096,
65
+ "rope_type": "yarn",
66
+ "truncate": false
67
+ },
68
+ "rope_theta": 150000,
69
+ "router_aux_loss_coef": 0.9,
70
+ "sliding_window": 128,
71
+ "swiglu_limit": 7.0,
72
+ "tie_word_embeddings": false,
73
+ "transformers_version": "4.55.0.dev0",
74
+ "use_cache": true,
75
+ "vocab_size": 201088
76
+ }
configuration_gpt_oss.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """openai model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig, layer_type_validation
18
+ from transformers.modeling_rope_utils import rope_config_validation
19
+
20
+
21
+ class GptOssConfig(PretrainedConfig):
22
+ r"""
23
+ This will yield a configuration to that of the BERT
24
+ [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) architecture.
25
+
26
+ """
27
+
28
+ model_type = "gpt_oss"
29
+ base_model_pp_plan = {
30
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
31
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
32
+ "norm": (["hidden_states"], ["hidden_states"]),
33
+ }
34
+ base_model_tp_plan = {
35
+ "layers.*.self_attn.q_proj": "colwise",
36
+ "layers.*.self_attn.k_proj": "colwise",
37
+ "layers.*.self_attn.v_proj": "colwise",
38
+ "layers.*.self_attn.o_proj": "rowwise",
39
+ "layers.*.self_attn.sinks": "local_rowwise",
40
+ "layers.*.mlp.experts": "gather",
41
+ "layers.*.mlp.router": "ep_router",
42
+ "layers.*.mlp.experts.gate_up_proj": "grouped_gemm",
43
+ "layers.*.mlp.experts.gate_up_proj_bias": "grouped_gemm",
44
+ "layers.*.mlp.experts.down_proj": "grouped_gemm",
45
+ "layers.*.mlp.experts.down_proj_bias": "grouped_gemm",
46
+ }
47
+
48
+ def __init__(
49
+ self,
50
+ num_hidden_layers: int = 36,
51
+ num_local_experts: int = 128,
52
+ vocab_size: int = 201088,
53
+ hidden_size: int = 2880,
54
+ intermediate_size: int = 2880,
55
+ head_dim: int = 64,
56
+ num_attention_heads: int = 64,
57
+ num_key_value_heads: int = 8,
58
+ sliding_window: int = 128,
59
+ rope_theta: float = 150000.0,
60
+ tie_word_embeddings=False,
61
+ hidden_act: str = "silu",
62
+ initializer_range: float = 0.02,
63
+ max_position_embeddings=131072,
64
+ rms_norm_eps: float = 1e-5,
65
+ rope_scaling={
66
+ "rope_type": "yarn",
67
+ "factor": 32.0,
68
+ "beta_fast": 32.0,
69
+ "beta_slow": 1.0,
70
+ "truncate": False,
71
+ "original_max_position_embeddings": 4096,
72
+ },
73
+ attention_dropout: float = 0.0,
74
+ num_experts_per_tok=4,
75
+ router_aux_loss_coef: float = 0.9,
76
+ output_router_logits=False,
77
+ use_cache=True,
78
+ layer_types=None,
79
+ **kwargs,
80
+ ):
81
+ self.vocab_size = vocab_size
82
+ self.hidden_size = hidden_size
83
+ self.intermediate_size = intermediate_size
84
+ self.num_hidden_layers = num_hidden_layers
85
+ self.num_attention_heads = num_attention_heads
86
+ self.num_local_experts = num_local_experts
87
+ self.sliding_window = sliding_window
88
+ self.num_experts_per_tok = num_experts_per_tok
89
+ # for backward compatibility
90
+ if num_key_value_heads is None:
91
+ num_key_value_heads = num_attention_heads
92
+
93
+ self.num_key_value_heads = num_key_value_heads
94
+ self.hidden_act = hidden_act
95
+ self.initializer_range = initializer_range
96
+ self.rms_norm_eps = rms_norm_eps
97
+ self.rope_theta = rope_theta
98
+ self.rope_scaling = rope_scaling
99
+ self.attention_dropout = attention_dropout
100
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
101
+ self.layer_types = layer_types
102
+ if self.layer_types is None:
103
+ self.layer_types = [
104
+ "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
105
+ ]
106
+ layer_type_validation(self.layer_types)
107
+
108
+ self.attention_bias = True
109
+ self.max_position_embeddings = max_position_embeddings
110
+ self.router_aux_loss_coef = router_aux_loss_coef
111
+ self.output_router_logits = output_router_logits
112
+ self.use_cache = use_cache
113
+
114
+ # Validate the correctness of rotary position embeddings parameters
115
+ # BC: if there is a 'type' field, copy it it to 'rope_type'.
116
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
117
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
118
+ rope_config_validation(self)
119
+
120
+ super().__init__(
121
+ tie_word_embeddings=tie_word_embeddings,
122
+ **kwargs,
123
+ )
124
+
125
+
126
+ __all__ = ["GptOssConfig"]
127
+
generation_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 199998,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 200002,
6
+ 199999,
7
+ 200012
8
+ ],
9
+ "pad_token_id": 199999,
10
+ "transformers_version": "4.55.0.dev0"
11
+ }
model.safetensors.index.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"metadata": {"total_size": 13761264768}, "weight_map": {"model.layers.0.input_layernorm.weight": "model-00000-of-00002.safetensors", "model.layers.0.self_attn.o_proj.bias": "model-00000-of-00002.safetensors", "model.layers.0.self_attn.o_proj.weight": "model-00000-of-00002.safetensors", "model.layers.0.self_attn.q_proj.bias": "model-00000-of-00002.safetensors", "model.layers.0.self_attn.k_proj.bias": "model-00000-of-00002.safetensors", "model.layers.0.self_attn.v_proj.bias": "model-00000-of-00002.safetensors", "model.layers.0.self_attn.q_proj.weight": "model-00000-of-00002.safetensors", "model.layers.0.self_attn.k_proj.weight": "model-00000-of-00002.safetensors", "model.layers.0.self_attn.v_proj.weight": "model-00000-of-00002.safetensors", "model.layers.0.self_attn.sinks": "model-00000-of-00002.safetensors", "model.layers.0.mlp.router.bias": "model-00000-of-00002.safetensors", "model.layers.0.mlp.router.weight": "model-00000-of-00002.safetensors", "model.layers.0.mlp.experts.gate_up_proj_bias": "model-00000-of-00002.safetensors", "model.layers.0.mlp.experts.gate_up_proj_blocks": "model-00000-of-00002.safetensors", "model.layers.0.mlp.experts.gate_up_proj_scales": "model-00000-of-00002.safetensors", "model.layers.0.mlp.experts.down_proj_bias": "model-00000-of-00002.safetensors", "model.layers.0.mlp.experts.down_proj_blocks": "model-00000-of-00002.safetensors", "model.layers.0.mlp.experts.down_proj_scales": "model-00000-of-00002.safetensors", "model.layers.0.post_attention_layernorm.weight": "model-00000-of-00002.safetensors", "model.layers.1.input_layernorm.weight": "model-00000-of-00002.safetensors", "model.layers.1.self_attn.o_proj.bias": "model-00000-of-00002.safetensors", "model.layers.1.self_attn.o_proj.weight": "model-00000-of-00002.safetensors", "model.layers.1.self_attn.q_proj.bias": "model-00000-of-00002.safetensors", "model.layers.1.self_attn.k_proj.bias": "model-00000-of-00002.safetensors", "model.layers.1.self_attn.v_proj.bias": "model-00000-of-00002.safetensors", "model.layers.1.self_attn.q_proj.weight": "model-00000-of-00002.safetensors", "model.layers.1.self_attn.k_proj.weight": "model-00000-of-00002.safetensors", "model.layers.1.self_attn.v_proj.weight": "model-00000-of-00002.safetensors", "model.layers.1.self_attn.sinks": "model-00000-of-00002.safetensors", "model.layers.1.mlp.router.bias": "model-00000-of-00002.safetensors", "model.layers.1.mlp.router.weight": "model-00000-of-00002.safetensors", "model.layers.1.mlp.experts.gate_up_proj_bias": "model-00000-of-00002.safetensors", "model.layers.1.mlp.experts.gate_up_proj_blocks": "model-00000-of-00002.safetensors", "model.layers.1.mlp.experts.gate_up_proj_scales": "model-00000-of-00002.safetensors", "model.layers.1.mlp.experts.down_proj_bias": "model-00000-of-00002.safetensors", "model.layers.1.mlp.experts.down_proj_blocks": "model-00000-of-00002.safetensors", "model.layers.1.mlp.experts.down_proj_scales": "model-00000-of-00002.safetensors", "model.layers.1.post_attention_layernorm.weight": "model-00000-of-00002.safetensors", "model.layers.10.input_layernorm.weight": "model-00000-of-00002.safetensors", "model.layers.10.self_attn.o_proj.bias": "model-00000-of-00002.safetensors", "model.layers.10.self_attn.o_proj.weight": "model-00000-of-00002.safetensors", "model.layers.10.self_attn.q_proj.bias": "model-00000-of-00002.safetensors", "model.layers.10.self_attn.k_proj.bias": "model-00000-of-00002.safetensors", "model.layers.10.self_attn.v_proj.bias": "model-00000-of-00002.safetensors", "model.layers.10.self_attn.q_proj.weight": "model-00000-of-00002.safetensors", "model.layers.10.self_attn.k_proj.weight": "model-00000-of-00002.safetensors", "model.layers.10.self_attn.v_proj.weight": "model-00000-of-00002.safetensors", "model.layers.10.self_attn.sinks": "model-00000-of-00002.safetensors", "model.layers.10.mlp.router.bias": "model-00000-of-00002.safetensors", "model.layers.10.mlp.router.weight": "model-00000-of-00002.safetensors", "model.layers.10.mlp.experts.gate_up_proj_bias": "model-00000-of-00002.safetensors", "model.layers.10.mlp.experts.gate_up_proj_blocks": "model-00000-of-00002.safetensors", "model.layers.10.mlp.experts.gate_up_proj_scales": "model-00000-of-00002.safetensors", "model.layers.10.mlp.experts.down_proj_bias": "model-00000-of-00002.safetensors", "model.layers.10.mlp.experts.down_proj_blocks": "model-00000-of-00002.safetensors", "model.layers.10.mlp.experts.down_proj_scales": "model-00000-of-00002.safetensors", "model.layers.10.post_attention_layernorm.weight": "model-00000-of-00002.safetensors", "model.layers.11.input_layernorm.weight": "model-00000-of-00002.safetensors", "model.layers.11.self_attn.o_proj.bias": "model-00000-of-00002.safetensors", "model.layers.11.self_attn.o_proj.weight": "model-00000-of-00002.safetensors", "model.layers.11.self_attn.q_proj.bias": "model-00000-of-00002.safetensors", "model.layers.11.self_attn.k_proj.bias": "model-00000-of-00002.safetensors", "model.layers.11.self_attn.v_proj.bias": "model-00000-of-00002.safetensors", "model.layers.11.self_attn.q_proj.weight": "model-00000-of-00002.safetensors", "model.layers.11.self_attn.k_proj.weight": "model-00000-of-00002.safetensors", "model.layers.11.self_attn.v_proj.weight": "model-00000-of-00002.safetensors", "model.layers.11.self_attn.sinks": "model-00000-of-00002.safetensors", "model.layers.11.mlp.router.bias": "model-00000-of-00002.safetensors", "model.layers.11.mlp.router.weight": "model-00000-of-00002.safetensors", "model.layers.11.mlp.experts.gate_up_proj_bias": "model-00000-of-00002.safetensors", "model.layers.11.mlp.experts.gate_up_proj_blocks": "model-00000-of-00002.safetensors", "model.layers.11.mlp.experts.gate_up_proj_scales": "model-00000-of-00002.safetensors", "model.layers.11.mlp.experts.down_proj_bias": "model-00000-of-00002.safetensors", "model.layers.11.mlp.experts.down_proj_blocks": "model-00000-of-00002.safetensors", "model.layers.11.mlp.experts.down_proj_scales": "model-00000-of-00002.safetensors", "model.layers.11.post_attention_layernorm.weight": "model-00000-of-00002.safetensors", "model.layers.12.input_layernorm.weight": "model-00000-of-00002.safetensors", "model.layers.12.self_attn.o_proj.bias": "model-00000-of-00002.safetensors", "model.layers.12.self_attn.o_proj.weight": "model-00000-of-00002.safetensors", "model.layers.12.self_attn.q_proj.bias": "model-00000-of-00002.safetensors", "model.layers.12.self_attn.k_proj.bias": "model-00000-of-00002.safetensors", "model.layers.12.self_attn.v_proj.bias": "model-00000-of-00002.safetensors", "model.layers.12.self_attn.q_proj.weight": "model-00000-of-00002.safetensors", "model.layers.12.self_attn.k_proj.weight": "model-00000-of-00002.safetensors", "model.layers.12.self_attn.v_proj.weight": "model-00000-of-00002.safetensors", "model.layers.12.self_attn.sinks": "model-00000-of-00002.safetensors", "model.layers.12.mlp.router.bias": "model-00000-of-00002.safetensors", "model.layers.12.mlp.router.weight": "model-00000-of-00002.safetensors", "model.layers.12.mlp.experts.gate_up_proj_bias": "model-00000-of-00002.safetensors", "model.layers.12.mlp.experts.gate_up_proj_blocks": "model-00000-of-00002.safetensors", "model.layers.12.mlp.experts.gate_up_proj_scales": "model-00000-of-00002.safetensors", "model.layers.12.mlp.experts.down_proj_bias": "model-00000-of-00002.safetensors", "model.layers.12.mlp.experts.down_proj_blocks": "model-00000-of-00002.safetensors", "model.layers.12.mlp.experts.down_proj_scales": "model-00000-of-00002.safetensors", "model.layers.12.post_attention_layernorm.weight": "model-00000-of-00002.safetensors", "model.layers.13.input_layernorm.weight": "model-00000-of-00002.safetensors", "model.layers.13.self_attn.o_proj.bias": "model-00000-of-00002.safetensors", "model.layers.13.self_attn.o_proj.weight": "model-00000-of-00002.safetensors", "model.layers.13.self_attn.q_proj.bias": "model-00000-of-00002.safetensors", "model.layers.13.self_attn.k_proj.bias": "model-00000-of-00002.safetensors", "model.layers.13.self_attn.v_proj.bias": "model-00000-of-00002.safetensors", "model.layers.13.self_attn.q_proj.weight": "model-00000-of-00002.safetensors", "model.layers.13.self_attn.k_proj.weight": "model-00000-of-00002.safetensors", "model.layers.13.self_attn.v_proj.weight": "model-00000-of-00002.safetensors", "model.layers.13.self_attn.sinks": "model-00000-of-00002.safetensors", "model.layers.13.mlp.router.bias": "model-00000-of-00002.safetensors", "model.layers.13.mlp.router.weight": "model-00000-of-00002.safetensors", "model.layers.13.mlp.experts.gate_up_proj_bias": "model-00000-of-00002.safetensors", "model.layers.13.mlp.experts.gate_up_proj_blocks": "model-00000-of-00002.safetensors", "model.layers.13.mlp.experts.gate_up_proj_scales": "model-00000-of-00002.safetensors", "model.layers.13.mlp.experts.down_proj_bias": "model-00000-of-00002.safetensors", "model.layers.13.mlp.experts.down_proj_blocks": "model-00000-of-00002.safetensors", "model.layers.13.mlp.experts.down_proj_scales": "model-00000-of-00002.safetensors", "model.layers.13.post_attention_layernorm.weight": "model-00000-of-00002.safetensors", "model.layers.14.input_layernorm.weight": "model-00000-of-00002.safetensors", "model.layers.14.self_attn.o_proj.bias": "model-00000-of-00002.safetensors", "model.layers.14.self_attn.o_proj.weight": "model-00000-of-00002.safetensors", "model.layers.14.self_attn.q_proj.bias": "model-00000-of-00002.safetensors", "model.layers.14.self_attn.k_proj.bias": "model-00000-of-00002.safetensors", "model.layers.14.self_attn.v_proj.bias": "model-00000-of-00002.safetensors", "model.layers.14.self_attn.q_proj.weight": "model-00000-of-00002.safetensors", "model.layers.14.self_attn.k_proj.weight": "model-00000-of-00002.safetensors", "model.layers.14.self_attn.v_proj.weight": "model-00000-of-00002.safetensors", "model.layers.14.self_attn.sinks": "model-00000-of-00002.safetensors", "model.layers.14.mlp.router.bias": "model-00000-of-00002.safetensors", "model.layers.14.mlp.router.weight": "model-00000-of-00002.safetensors", "model.layers.14.mlp.experts.gate_up_proj_bias": "model-00000-of-00002.safetensors", "model.layers.14.mlp.experts.gate_up_proj_blocks": "model-00000-of-00002.safetensors", "model.layers.14.mlp.experts.gate_up_proj_scales": "model-00000-of-00002.safetensors", "model.layers.14.mlp.experts.down_proj_bias": "model-00000-of-00002.safetensors", "model.layers.14.mlp.experts.down_proj_blocks": "model-00000-of-00002.safetensors", "model.layers.14.mlp.experts.down_proj_scales": "model-00000-of-00002.safetensors", "model.layers.14.post_attention_layernorm.weight": "model-00000-of-00002.safetensors", "model.layers.15.input_layernorm.weight": "model-00000-of-00002.safetensors", "model.layers.15.self_attn.o_proj.bias": "model-00000-of-00002.safetensors", "model.layers.15.self_attn.o_proj.weight": "model-00000-of-00002.safetensors", "model.layers.15.self_attn.q_proj.bias": "model-00000-of-00002.safetensors", "model.layers.15.self_attn.k_proj.bias": "model-00000-of-00002.safetensors", "model.layers.15.self_attn.v_proj.bias": "model-00000-of-00002.safetensors", "model.layers.15.self_attn.q_proj.weight": "model-00000-of-00002.safetensors", "model.layers.15.self_attn.k_proj.weight": "model-00000-of-00002.safetensors", "model.layers.15.self_attn.v_proj.weight": "model-00000-of-00002.safetensors", "model.layers.15.self_attn.sinks": "model-00000-of-00002.safetensors", "model.layers.15.mlp.router.bias": "model-00000-of-00002.safetensors", "model.layers.15.mlp.router.weight": "model-00000-of-00002.safetensors", "model.layers.15.mlp.experts.gate_up_proj_bias": "model-00000-of-00002.safetensors", "model.layers.15.mlp.experts.gate_up_proj_blocks": "model-00000-of-00002.safetensors", "model.layers.15.mlp.experts.gate_up_proj_scales": "model-00000-of-00002.safetensors", "model.layers.15.mlp.experts.down_proj_bias": "model-00000-of-00002.safetensors", "model.layers.15.mlp.experts.down_proj_blocks": "model-00000-of-00002.safetensors", "model.layers.15.mlp.experts.down_proj_scales": "model-00000-of-00002.safetensors", "model.layers.15.post_attention_layernorm.weight": "model-00000-of-00002.safetensors", "model.layers.16.input_layernorm.weight": "model-00000-of-00002.safetensors", "model.layers.16.self_attn.o_proj.bias": "model-00000-of-00002.safetensors", "model.layers.16.self_attn.o_proj.weight": "model-00000-of-00002.safetensors", "model.layers.16.self_attn.q_proj.bias": "model-00000-of-00002.safetensors", "model.layers.16.self_attn.k_proj.bias": "model-00000-of-00002.safetensors", "model.layers.16.self_attn.v_proj.bias": "model-00000-of-00002.safetensors", "model.layers.16.self_attn.q_proj.weight": "model-00000-of-00002.safetensors", "model.layers.16.self_attn.k_proj.weight": "model-00000-of-00002.safetensors", "model.layers.16.self_attn.v_proj.weight": "model-00000-of-00002.safetensors", "model.layers.16.self_attn.sinks": "model-00000-of-00002.safetensors", "model.layers.16.mlp.router.bias": "model-00000-of-00002.safetensors", "model.layers.16.mlp.router.weight": "model-00000-of-00002.safetensors", "model.layers.16.mlp.experts.gate_up_proj_bias": "model-00000-of-00002.safetensors", "model.layers.16.mlp.experts.gate_up_proj_blocks": "model-00000-of-00002.safetensors", "model.layers.16.mlp.experts.gate_up_proj_scales": "model-00000-of-00002.safetensors", "model.layers.16.mlp.experts.down_proj_bias": "model-00000-of-00002.safetensors", "model.layers.16.mlp.experts.down_proj_blocks": "model-00000-of-00002.safetensors", "model.layers.16.mlp.experts.down_proj_scales": "model-00000-of-00002.safetensors", "model.layers.16.post_attention_layernorm.weight": "model-00000-of-00002.safetensors", "model.layers.17.input_layernorm.weight": "model-00000-of-00002.safetensors", "model.layers.17.self_attn.o_proj.bias": "model-00000-of-00002.safetensors", "model.layers.17.self_attn.o_proj.weight": "model-00000-of-00002.safetensors", "model.layers.17.self_attn.q_proj.bias": "model-00000-of-00002.safetensors", "model.layers.17.self_attn.k_proj.bias": "model-00000-of-00002.safetensors", "model.layers.17.self_attn.v_proj.bias": "model-00000-of-00002.safetensors", "model.layers.17.self_attn.q_proj.weight": "model-00000-of-00002.safetensors", "model.layers.17.self_attn.k_proj.weight": "model-00000-of-00002.safetensors", "model.layers.17.self_attn.v_proj.weight": "model-00000-of-00002.safetensors", "model.layers.17.self_attn.sinks": "model-00000-of-00002.safetensors", "model.layers.17.mlp.router.bias": "model-00000-of-00002.safetensors", "model.layers.17.mlp.router.weight": "model-00000-of-00002.safetensors", "model.layers.17.mlp.experts.gate_up_proj_bias": "model-00000-of-00002.safetensors", "model.layers.17.mlp.experts.gate_up_proj_blocks": "model-00000-of-00002.safetensors", "model.layers.17.mlp.experts.gate_up_proj_scales": "model-00000-of-00002.safetensors", "model.layers.17.mlp.experts.down_proj_bias": "model-00000-of-00002.safetensors", "model.layers.17.mlp.experts.down_proj_blocks": "model-00000-of-00002.safetensors", "model.layers.17.mlp.experts.down_proj_scales": "model-00000-of-00002.safetensors", "model.layers.17.post_attention_layernorm.weight": "model-00000-of-00002.safetensors", "model.layers.18.input_layernorm.weight": "model-00000-of-00002.safetensors", "model.layers.18.self_attn.o_proj.bias": "model-00000-of-00002.safetensors", "model.layers.18.self_attn.o_proj.weight": "model-00000-of-00002.safetensors", "model.layers.18.self_attn.q_proj.bias": "model-00000-of-00002.safetensors", "model.layers.18.self_attn.k_proj.bias": "model-00000-of-00002.safetensors", "model.layers.18.self_attn.v_proj.bias": "model-00000-of-00002.safetensors", "model.layers.18.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.18.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.18.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.18.self_attn.sinks": "model-00001-of-00002.safetensors", "model.layers.18.mlp.router.bias": "model-00001-of-00002.safetensors", "model.layers.18.mlp.router.weight": "model-00001-of-00002.safetensors", "model.layers.18.mlp.experts.gate_up_proj_bias": "model-00001-of-00002.safetensors", "model.layers.18.mlp.experts.gate_up_proj_blocks": "model-00001-of-00002.safetensors", "model.layers.18.mlp.experts.gate_up_proj_scales": "model-00001-of-00002.safetensors", "model.layers.18.mlp.experts.down_proj_bias": "model-00001-of-00002.safetensors", "model.layers.18.mlp.experts.down_proj_blocks": "model-00001-of-00002.safetensors", "model.layers.18.mlp.experts.down_proj_scales": "model-00001-of-00002.safetensors", "model.layers.18.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.19.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.19.self_attn.o_proj.bias": "model-00001-of-00002.safetensors", "model.layers.19.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.19.self_attn.q_proj.bias": "model-00001-of-00002.safetensors", "model.layers.19.self_attn.k_proj.bias": "model-00001-of-00002.safetensors", "model.layers.19.self_attn.v_proj.bias": "model-00001-of-00002.safetensors", "model.layers.19.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.19.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.19.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.19.self_attn.sinks": "model-00001-of-00002.safetensors", "model.layers.19.mlp.router.bias": "model-00001-of-00002.safetensors", "model.layers.19.mlp.router.weight": "model-00001-of-00002.safetensors", "model.layers.19.mlp.experts.gate_up_proj_bias": "model-00001-of-00002.safetensors", "model.layers.19.mlp.experts.gate_up_proj_blocks": "model-00001-of-00002.safetensors", "model.layers.19.mlp.experts.gate_up_proj_scales": "model-00001-of-00002.safetensors", "model.layers.19.mlp.experts.down_proj_bias": "model-00001-of-00002.safetensors", "model.layers.19.mlp.experts.down_proj_blocks": "model-00001-of-00002.safetensors", "model.layers.19.mlp.experts.down_proj_scales": "model-00001-of-00002.safetensors", "model.layers.19.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.2.self_attn.o_proj.bias": "model-00001-of-00002.safetensors", "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.2.self_attn.q_proj.bias": "model-00001-of-00002.safetensors", "model.layers.2.self_attn.k_proj.bias": "model-00001-of-00002.safetensors", "model.layers.2.self_attn.v_proj.bias": "model-00001-of-00002.safetensors", "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.2.self_attn.sinks": "model-00001-of-00002.safetensors", "model.layers.2.mlp.router.bias": "model-00001-of-00002.safetensors", "model.layers.2.mlp.router.weight": "model-00001-of-00002.safetensors", "model.layers.2.mlp.experts.gate_up_proj_bias": "model-00001-of-00002.safetensors", "model.layers.2.mlp.experts.gate_up_proj_blocks": "model-00001-of-00002.safetensors", "model.layers.2.mlp.experts.gate_up_proj_scales": "model-00001-of-00002.safetensors", "model.layers.2.mlp.experts.down_proj_bias": "model-00001-of-00002.safetensors", "model.layers.2.mlp.experts.down_proj_blocks": "model-00001-of-00002.safetensors", "model.layers.2.mlp.experts.down_proj_scales": "model-00001-of-00002.safetensors", "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.20.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.20.self_attn.o_proj.bias": "model-00001-of-00002.safetensors", "model.layers.20.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.20.self_attn.q_proj.bias": "model-00001-of-00002.safetensors", "model.layers.20.self_attn.k_proj.bias": "model-00001-of-00002.safetensors", "model.layers.20.self_attn.v_proj.bias": "model-00001-of-00002.safetensors", "model.layers.20.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.20.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.20.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.20.self_attn.sinks": "model-00001-of-00002.safetensors", "model.layers.20.mlp.router.bias": "model-00001-of-00002.safetensors", "model.layers.20.mlp.router.weight": "model-00001-of-00002.safetensors", "model.layers.20.mlp.experts.gate_up_proj_bias": "model-00001-of-00002.safetensors", "model.layers.20.mlp.experts.gate_up_proj_blocks": "model-00001-of-00002.safetensors", "model.layers.20.mlp.experts.gate_up_proj_scales": "model-00001-of-00002.safetensors", "model.layers.20.mlp.experts.down_proj_bias": "model-00001-of-00002.safetensors", "model.layers.20.mlp.experts.down_proj_blocks": "model-00001-of-00002.safetensors", "model.layers.20.mlp.experts.down_proj_scales": "model-00001-of-00002.safetensors", "model.layers.20.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.21.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.21.self_attn.o_proj.bias": "model-00001-of-00002.safetensors", "model.layers.21.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.21.self_attn.q_proj.bias": "model-00001-of-00002.safetensors", "model.layers.21.self_attn.k_proj.bias": "model-00001-of-00002.safetensors", "model.layers.21.self_attn.v_proj.bias": "model-00001-of-00002.safetensors", "model.layers.21.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.21.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.21.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.21.self_attn.sinks": "model-00001-of-00002.safetensors", "model.layers.21.mlp.router.bias": "model-00001-of-00002.safetensors", "model.layers.21.mlp.router.weight": "model-00001-of-00002.safetensors", "model.layers.21.mlp.experts.gate_up_proj_bias": "model-00001-of-00002.safetensors", "model.layers.21.mlp.experts.gate_up_proj_blocks": "model-00001-of-00002.safetensors", "model.layers.21.mlp.experts.gate_up_proj_scales": "model-00001-of-00002.safetensors", "model.layers.21.mlp.experts.down_proj_bias": "model-00001-of-00002.safetensors", "model.layers.21.mlp.experts.down_proj_blocks": "model-00001-of-00002.safetensors", "model.layers.21.mlp.experts.down_proj_scales": "model-00001-of-00002.safetensors", "model.layers.21.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.22.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.22.self_attn.o_proj.bias": "model-00001-of-00002.safetensors", "model.layers.22.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.22.self_attn.q_proj.bias": "model-00001-of-00002.safetensors", "model.layers.22.self_attn.k_proj.bias": "model-00001-of-00002.safetensors", "model.layers.22.self_attn.v_proj.bias": "model-00001-of-00002.safetensors", "model.layers.22.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.22.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.22.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.22.self_attn.sinks": "model-00001-of-00002.safetensors", "model.layers.22.mlp.router.bias": "model-00001-of-00002.safetensors", "model.layers.22.mlp.router.weight": "model-00001-of-00002.safetensors", "model.layers.22.mlp.experts.gate_up_proj_bias": "model-00001-of-00002.safetensors", "model.layers.22.mlp.experts.gate_up_proj_blocks": "model-00001-of-00002.safetensors", "model.layers.22.mlp.experts.gate_up_proj_scales": "model-00001-of-00002.safetensors", "model.layers.22.mlp.experts.down_proj_bias": "model-00001-of-00002.safetensors", "model.layers.22.mlp.experts.down_proj_blocks": "model-00001-of-00002.safetensors", "model.layers.22.mlp.experts.down_proj_scales": "model-00001-of-00002.safetensors", "model.layers.22.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.23.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.23.self_attn.o_proj.bias": "model-00001-of-00002.safetensors", "model.layers.23.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.23.self_attn.q_proj.bias": "model-00001-of-00002.safetensors", "model.layers.23.self_attn.k_proj.bias": "model-00001-of-00002.safetensors", "model.layers.23.self_attn.v_proj.bias": "model-00001-of-00002.safetensors", "model.layers.23.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.23.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.23.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.23.self_attn.sinks": "model-00001-of-00002.safetensors", "model.layers.23.mlp.router.bias": "model-00001-of-00002.safetensors", "model.layers.23.mlp.router.weight": "model-00001-of-00002.safetensors", "model.layers.23.mlp.experts.gate_up_proj_bias": "model-00001-of-00002.safetensors", "model.layers.23.mlp.experts.gate_up_proj_blocks": "model-00001-of-00002.safetensors", "model.layers.23.mlp.experts.gate_up_proj_scales": "model-00001-of-00002.safetensors", "model.layers.23.mlp.experts.down_proj_bias": "model-00001-of-00002.safetensors", "model.layers.23.mlp.experts.down_proj_blocks": "model-00001-of-00002.safetensors", "model.layers.23.mlp.experts.down_proj_scales": "model-00001-of-00002.safetensors", "model.layers.23.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.3.self_attn.o_proj.bias": "model-00001-of-00002.safetensors", "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.3.self_attn.q_proj.bias": "model-00001-of-00002.safetensors", "model.layers.3.self_attn.k_proj.bias": "model-00001-of-00002.safetensors", "model.layers.3.self_attn.v_proj.bias": "model-00001-of-00002.safetensors", "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.3.self_attn.sinks": "model-00001-of-00002.safetensors", "model.layers.3.mlp.router.bias": "model-00001-of-00002.safetensors", "model.layers.3.mlp.router.weight": "model-00001-of-00002.safetensors", "model.layers.3.mlp.experts.gate_up_proj_bias": "model-00001-of-00002.safetensors", "model.layers.3.mlp.experts.gate_up_proj_blocks": "model-00001-of-00002.safetensors", "model.layers.3.mlp.experts.gate_up_proj_scales": "model-00001-of-00002.safetensors", "model.layers.3.mlp.experts.down_proj_bias": "model-00001-of-00002.safetensors", "model.layers.3.mlp.experts.down_proj_blocks": "model-00001-of-00002.safetensors", "model.layers.3.mlp.experts.down_proj_scales": "model-00001-of-00002.safetensors", "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.4.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.4.self_attn.o_proj.bias": "model-00001-of-00002.safetensors", "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.4.self_attn.q_proj.bias": "model-00001-of-00002.safetensors", "model.layers.4.self_attn.k_proj.bias": "model-00001-of-00002.safetensors", "model.layers.4.self_attn.v_proj.bias": "model-00001-of-00002.safetensors", "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.4.self_attn.sinks": "model-00001-of-00002.safetensors", "model.layers.4.mlp.router.bias": "model-00001-of-00002.safetensors", "model.layers.4.mlp.router.weight": "model-00001-of-00002.safetensors", "model.layers.4.mlp.experts.gate_up_proj_bias": "model-00001-of-00002.safetensors", "model.layers.4.mlp.experts.gate_up_proj_blocks": "model-00001-of-00002.safetensors", "model.layers.4.mlp.experts.gate_up_proj_scales": "model-00001-of-00002.safetensors", "model.layers.4.mlp.experts.down_proj_bias": "model-00001-of-00002.safetensors", "model.layers.4.mlp.experts.down_proj_blocks": "model-00001-of-00002.safetensors", "model.layers.4.mlp.experts.down_proj_scales": "model-00001-of-00002.safetensors", "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.5.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.5.self_attn.o_proj.bias": "model-00001-of-00002.safetensors", "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.5.self_attn.q_proj.bias": "model-00001-of-00002.safetensors", "model.layers.5.self_attn.k_proj.bias": "model-00001-of-00002.safetensors", "model.layers.5.self_attn.v_proj.bias": "model-00001-of-00002.safetensors", "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.5.self_attn.sinks": "model-00001-of-00002.safetensors", "model.layers.5.mlp.router.bias": "model-00001-of-00002.safetensors", "model.layers.5.mlp.router.weight": "model-00001-of-00002.safetensors", "model.layers.5.mlp.experts.gate_up_proj_bias": "model-00001-of-00002.safetensors", "model.layers.5.mlp.experts.gate_up_proj_blocks": "model-00001-of-00002.safetensors", "model.layers.5.mlp.experts.gate_up_proj_scales": "model-00001-of-00002.safetensors", "model.layers.5.mlp.experts.down_proj_bias": "model-00001-of-00002.safetensors", "model.layers.5.mlp.experts.down_proj_blocks": "model-00001-of-00002.safetensors", "model.layers.5.mlp.experts.down_proj_scales": "model-00001-of-00002.safetensors", "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.6.input_layernorm.weight": "model-00001-of-00002.safetensors", "model.layers.6.self_attn.o_proj.bias": "model-00001-of-00002.safetensors", "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00002.safetensors", "model.layers.6.self_attn.q_proj.bias": "model-00001-of-00002.safetensors", "model.layers.6.self_attn.k_proj.bias": "model-00001-of-00002.safetensors", "model.layers.6.self_attn.v_proj.bias": "model-00001-of-00002.safetensors", "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00002.safetensors", "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00002.safetensors", "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00002.safetensors", "model.layers.6.self_attn.sinks": "model-00001-of-00002.safetensors", "model.layers.6.mlp.router.bias": "model-00001-of-00002.safetensors", "model.layers.6.mlp.router.weight": "model-00001-of-00002.safetensors", "model.layers.6.mlp.experts.gate_up_proj_bias": "model-00001-of-00002.safetensors", "model.layers.6.mlp.experts.gate_up_proj_blocks": "model-00002-of-00002.safetensors", "model.layers.6.mlp.experts.gate_up_proj_scales": "model-00002-of-00002.safetensors", "model.layers.6.mlp.experts.down_proj_bias": "model-00002-of-00002.safetensors", "model.layers.6.mlp.experts.down_proj_blocks": "model-00002-of-00002.safetensors", "model.layers.6.mlp.experts.down_proj_scales": "model-00002-of-00002.safetensors", "model.layers.6.post_attention_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.7.input_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.7.self_attn.o_proj.bias": "model-00002-of-00002.safetensors", "model.layers.7.self_attn.o_proj.weight": "model-00002-of-00002.safetensors", "model.layers.7.self_attn.q_proj.bias": "model-00002-of-00002.safetensors", "model.layers.7.self_attn.k_proj.bias": "model-00002-of-00002.safetensors", "model.layers.7.self_attn.v_proj.bias": "model-00002-of-00002.safetensors", "model.layers.7.self_attn.q_proj.weight": "model-00002-of-00002.safetensors", "model.layers.7.self_attn.k_proj.weight": "model-00002-of-00002.safetensors", "model.layers.7.self_attn.v_proj.weight": "model-00002-of-00002.safetensors", "model.layers.7.self_attn.sinks": "model-00002-of-00002.safetensors", "model.layers.7.mlp.router.bias": "model-00002-of-00002.safetensors", "model.layers.7.mlp.router.weight": "model-00002-of-00002.safetensors", "model.layers.7.mlp.experts.gate_up_proj_bias": "model-00002-of-00002.safetensors", "model.layers.7.mlp.experts.gate_up_proj_blocks": "model-00002-of-00002.safetensors", "model.layers.7.mlp.experts.gate_up_proj_scales": "model-00002-of-00002.safetensors", "model.layers.7.mlp.experts.down_proj_bias": "model-00002-of-00002.safetensors", "model.layers.7.mlp.experts.down_proj_blocks": "model-00002-of-00002.safetensors", "model.layers.7.mlp.experts.down_proj_scales": "model-00002-of-00002.safetensors", "model.layers.7.post_attention_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.8.input_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.8.self_attn.o_proj.bias": "model-00002-of-00002.safetensors", "model.layers.8.self_attn.o_proj.weight": "model-00002-of-00002.safetensors", "model.layers.8.self_attn.q_proj.bias": "model-00002-of-00002.safetensors", "model.layers.8.self_attn.k_proj.bias": "model-00002-of-00002.safetensors", "model.layers.8.self_attn.v_proj.bias": "model-00002-of-00002.safetensors", "model.layers.8.self_attn.q_proj.weight": "model-00002-of-00002.safetensors", "model.layers.8.self_attn.k_proj.weight": "model-00002-of-00002.safetensors", "model.layers.8.self_attn.v_proj.weight": "model-00002-of-00002.safetensors", "model.layers.8.self_attn.sinks": "model-00002-of-00002.safetensors", "model.layers.8.mlp.router.bias": "model-00002-of-00002.safetensors", "model.layers.8.mlp.router.weight": "model-00002-of-00002.safetensors", "model.layers.8.mlp.experts.gate_up_proj_bias": "model-00002-of-00002.safetensors", "model.layers.8.mlp.experts.gate_up_proj_blocks": "model-00002-of-00002.safetensors", "model.layers.8.mlp.experts.gate_up_proj_scales": "model-00002-of-00002.safetensors", "model.layers.8.mlp.experts.down_proj_bias": "model-00002-of-00002.safetensors", "model.layers.8.mlp.experts.down_proj_blocks": "model-00002-of-00002.safetensors", "model.layers.8.mlp.experts.down_proj_scales": "model-00002-of-00002.safetensors", "model.layers.8.post_attention_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.9.input_layernorm.weight": "model-00002-of-00002.safetensors", "model.layers.9.self_attn.o_proj.bias": "model-00002-of-00002.safetensors", "model.layers.9.self_attn.o_proj.weight": "model-00002-of-00002.safetensors", "model.layers.9.self_attn.q_proj.bias": "model-00002-of-00002.safetensors", "model.layers.9.self_attn.k_proj.bias": "model-00002-of-00002.safetensors", "model.layers.9.self_attn.v_proj.bias": "model-00002-of-00002.safetensors", "model.layers.9.self_attn.q_proj.weight": "model-00002-of-00002.safetensors", "model.layers.9.self_attn.k_proj.weight": "model-00002-of-00002.safetensors", "model.layers.9.self_attn.v_proj.weight": "model-00002-of-00002.safetensors", "model.layers.9.self_attn.sinks": "model-00002-of-00002.safetensors", "model.layers.9.mlp.router.bias": "model-00002-of-00002.safetensors", "model.layers.9.mlp.router.weight": "model-00002-of-00002.safetensors", "model.layers.9.mlp.experts.gate_up_proj_bias": "model-00002-of-00002.safetensors", "model.layers.9.mlp.experts.gate_up_proj_blocks": "model-00002-of-00002.safetensors", "model.layers.9.mlp.experts.gate_up_proj_scales": "model-00002-of-00002.safetensors", "model.layers.9.mlp.experts.down_proj_bias": "model-00002-of-00002.safetensors", "model.layers.9.mlp.experts.down_proj_blocks": "model-00002-of-00002.safetensors", "model.layers.9.mlp.experts.down_proj_scales": "model-00002-of-00002.safetensors", "model.layers.9.post_attention_layernorm.weight": "model-00002-of-00002.safetensors", "model.embed_tokens.weight": "model-00002-of-00002.safetensors", "model.norm.weight": "model-00002-of-00002.safetensors", "lm_head.weight": "model-00002-of-00002.safetensors"}}
modeling_gpt_oss.py ADDED
@@ -0,0 +1,721 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨��🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨��🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/gpt_oss/modular_gpt_oss.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_gpt_oss.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨��🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨��🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ from typing import Callable, Optional, Union
22
+
23
+ import torch
24
+ from torch import nn
25
+ from torch.nn import functional as F
26
+
27
+ from transformers.cache_utils import Cache, DynamicCache
28
+ from transformers.generation import GenerationMixin
29
+ from transformers.integrations.hub_kernels import use_kernel_forward_from_hub
30
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
31
+ from transformers.modeling_layers import (
32
+ GenericForSequenceClassification,
33
+ GenericForTokenClassification,
34
+ GradientCheckpointingLayer,
35
+ )
36
+ from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
37
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
38
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
39
+ from transformers.processing_utils import Unpack
40
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
41
+ from transformers.utils.deprecation import deprecate_kwarg
42
+ from transformers.utils.generic import OutputRecorder, check_model_inputs
43
+ from .configuration_gpt_oss import GptOssConfig
44
+
45
+
46
+ @use_kernel_forward_from_hub("RMSNorm")
47
+ class GptOssRMSNorm(nn.Module):
48
+ def __init__(self, hidden_size, eps=1e-6):
49
+ """
50
+ GptOssRMSNorm is equivalent to T5LayerNorm
51
+ """
52
+ super().__init__()
53
+ self.weight = nn.Parameter(torch.ones(hidden_size))
54
+ self.variance_epsilon = eps
55
+
56
+ def forward(self, hidden_states):
57
+ input_dtype = hidden_states.dtype
58
+ hidden_states = hidden_states.to(torch.float32)
59
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
60
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
61
+ return (self.weight * hidden_states).to(input_dtype) # main diff with Llama
62
+
63
+ def extra_repr(self):
64
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
65
+
66
+
67
+ class GptOssExperts(nn.Module):
68
+ def __init__(self, config):
69
+ super().__init__()
70
+ self.intermediate_size = config.intermediate_size
71
+ self.num_experts = config.num_local_experts
72
+ self.hidden_size = config.hidden_size
73
+ self.expert_dim = self.intermediate_size
74
+ self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
75
+ self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim))
76
+ self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
77
+ self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size))
78
+ self.alpha = 1.702
79
+ self.limit = 7.0
80
+
81
+ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
82
+ """
83
+ When training it is more efficient to just loop over the experts and compute the output for each expert
84
+ as otherwise the memory would explode.
85
+
86
+ For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs.
87
+
88
+ Args:
89
+ hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
90
+ selected_experts (torch.Tensor): (batch_size * token_num, top_k)
91
+ routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
92
+ Returns:
93
+ torch.Tensor
94
+ """
95
+ batch_size = hidden_states.shape[0]
96
+ hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size)
97
+ num_experts = routing_weights.shape[1]
98
+ if hidden_states.device.type == "cpu" or self.training:
99
+ next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
100
+ with torch.no_grad():
101
+ expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
102
+ expert_mask = expert_mask.permute(2, 1, 0)
103
+ # we sum on the top_k and on the sequence length to get which experts
104
+ # are hit this time around
105
+ expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
106
+ for expert_idx in expert_hit[:]:
107
+ # expert_idx only have 1 element, so we can use scale for fast indexing
108
+ expert_idx = expert_idx[0]
109
+ with torch.no_grad():
110
+ _, token_idx = torch.where(expert_mask[expert_idx])
111
+ current_state = hidden_states[token_idx]
112
+ gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
113
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
114
+ gate = gate.clamp(min=None, max=self.limit)
115
+ up = up.clamp(min=-self.limit, max=self.limit)
116
+ glu = gate * torch.sigmoid(gate * self.alpha)
117
+ gated_output = (up + 1) * glu
118
+ out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
119
+ weighted_output = out * routing_weights[token_idx, expert_idx, None]
120
+ next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
121
+ next_states = next_states.view(batch_size, -1, self.hidden_size)
122
+ else:
123
+ hidden_states = hidden_states.repeat(num_experts, 1)
124
+ hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
125
+ gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
126
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
127
+ gate = gate.clamp(min=None, max=self.limit)
128
+ up = up.clamp(min=-self.limit, max=self.limit)
129
+ glu = gate * torch.sigmoid(gate * self.alpha)
130
+ next_states = torch.bmm(((up + 1) * glu), self.down_proj)
131
+ next_states = next_states + self.down_proj_bias[..., None, :]
132
+ next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
133
+ next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
134
+ next_states = next_states.sum(dim=0)
135
+ return next_states
136
+
137
+
138
+ class GptOssTopKRouter(nn.Module):
139
+ def __init__(self, config):
140
+ super().__init__()
141
+ self.top_k = config.num_experts_per_tok
142
+ self.num_experts = config.num_local_experts
143
+ self.hidden_dim = config.hidden_size
144
+ self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
145
+ self.bias = nn.Parameter(torch.empty(self.num_experts))
146
+
147
+ def forward(self, hidden_states):
148
+ hidden_states = hidden_states.reshape(-1, self.hidden_dim)
149
+ router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts)
150
+ router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
151
+ router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
152
+ router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
153
+ return router_scores, router_indices
154
+
155
+
156
+ @use_kernel_forward_from_hub("MegaBlocksMoeMLP")
157
+ class GptOssMLP(nn.Module):
158
+ def __init__(self, config):
159
+ super().__init__()
160
+ self.router = GptOssTopKRouter(config)
161
+ self.experts = GptOssExperts(config)
162
+
163
+ def forward(self, hidden_states):
164
+ router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len)
165
+ routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
166
+ return routed_out, router_scores
167
+
168
+
169
+ class GptOssRotaryEmbedding(nn.Module):
170
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
171
+
172
+ def __init__(self, config: GptOssConfig, device=None):
173
+ super().__init__()
174
+ # BC: "rope_type" was originally "type"
175
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
176
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
177
+ else:
178
+ self.rope_type = "default"
179
+ self.max_seq_len_cached = config.max_position_embeddings
180
+ self.original_max_seq_len = config.max_position_embeddings
181
+
182
+ self.config = config
183
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
184
+
185
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
186
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
187
+ self.original_inv_freq = self.inv_freq
188
+
189
+ @torch.no_grad()
190
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
191
+ def forward(self, x, position_ids):
192
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
193
+ position_ids_expanded = position_ids[:, None, :].float()
194
+
195
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
196
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
197
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
198
+ emb = freqs
199
+ cos = emb.cos() * self.attention_scaling
200
+ sin = emb.sin() * self.attention_scaling
201
+
202
+ return cos.to(x.dtype), sin.to(x.dtype)
203
+
204
+
205
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
206
+ """
207
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
208
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
209
+ """
210
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
211
+ if n_rep == 1:
212
+ return hidden_states
213
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
214
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
215
+
216
+
217
+ def _apply_rotary_emb(
218
+ x: torch.Tensor,
219
+ cos: torch.Tensor,
220
+ sin: torch.Tensor,
221
+ ) -> torch.Tensor:
222
+ first_half, second_half = torch.chunk(x, 2, dim=-1)
223
+ first_ = first_half * cos - second_half * sin
224
+ second_ = second_half * cos + first_half * sin
225
+ return torch.cat((first_, second_), dim=-1)
226
+
227
+
228
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
229
+ cos = cos.unsqueeze(unsqueeze_dim)
230
+ sin = sin.unsqueeze(unsqueeze_dim)
231
+ q_embed = _apply_rotary_emb(q, cos, sin)
232
+ k_embed = _apply_rotary_emb(k, cos, sin)
233
+ return q_embed, k_embed
234
+
235
+
236
+ def eager_attention_forward(
237
+ module: nn.Module,
238
+ query: torch.Tensor,
239
+ key: torch.Tensor,
240
+ value: torch.Tensor,
241
+ attention_mask: Optional[torch.Tensor],
242
+ scaling: float,
243
+ dropout: float = 0.0,
244
+ **kwargs,
245
+ ):
246
+ key_states = repeat_kv(key, module.num_key_value_groups)
247
+ value_states = repeat_kv(value, module.num_key_value_groups)
248
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
249
+ if attention_mask is not None:
250
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
251
+ attn_weights = attn_weights + causal_mask
252
+
253
+ sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
254
+ combined_logits = torch.cat([attn_weights, sinks], dim=-1)
255
+
256
+ # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16
257
+ # when training with bsz>1 we clamp max values.
258
+
259
+ combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
260
+ probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
261
+ scores = probs[..., :-1] # we drop the sink here
262
+ attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
263
+ attn_output = torch.matmul(attn_weights, value_states)
264
+ attn_output = attn_output.transpose(1, 2).contiguous()
265
+ return attn_output, attn_weights
266
+
267
+
268
+ class GptOssAttention(nn.Module):
269
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
270
+
271
+ def __init__(self, config: GptOssConfig, layer_idx: int):
272
+ super().__init__()
273
+ self.config = config
274
+ self.layer_idx = layer_idx
275
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
276
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
277
+ self.scaling = self.head_dim**-0.5
278
+ self.attention_dropout = config.attention_dropout
279
+ self.is_causal = True
280
+ self.q_proj = nn.Linear(
281
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
282
+ )
283
+ self.k_proj = nn.Linear(
284
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
285
+ )
286
+ self.v_proj = nn.Linear(
287
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
288
+ )
289
+ self.o_proj = nn.Linear(
290
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
291
+ )
292
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
293
+ self.sinks = nn.Parameter(torch.empty(config.num_attention_heads))
294
+
295
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
296
+ def forward(
297
+ self,
298
+ hidden_states: torch.Tensor,
299
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
300
+ attention_mask: Optional[torch.Tensor],
301
+ past_key_values: Optional[Cache] = None,
302
+ cache_position: Optional[torch.LongTensor] = None,
303
+ **kwargs: Unpack[TransformersKwargs],
304
+ ) -> tuple[torch.Tensor, torch.Tensor]:
305
+ input_shape = hidden_states.shape[:-1]
306
+ hidden_shape = (*input_shape, -1, self.head_dim)
307
+
308
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
309
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
310
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
311
+
312
+ cos, sin = position_embeddings
313
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
314
+
315
+ if past_key_values is not None:
316
+ cache_kwargs = {"cache_position": cache_position}
317
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
318
+
319
+ attention_interface: Callable = eager_attention_forward
320
+ if self.config._attn_implementation != "eager":
321
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
322
+
323
+ attn_output, attn_weights = attention_interface(
324
+ self,
325
+ query_states,
326
+ key_states,
327
+ value_states,
328
+ attention_mask,
329
+ dropout=0.0 if not self.training else self.attention_dropout,
330
+ scaling=self.scaling,
331
+ sliding_window=self.sliding_window,
332
+ s_aux=self.sinks, # diff with Llama
333
+ **kwargs,
334
+ )
335
+
336
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
337
+ attn_output = self.o_proj(attn_output)
338
+ return attn_output, attn_weights
339
+
340
+
341
+ class GptOssDecoderLayer(GradientCheckpointingLayer):
342
+ def __init__(self, config: GptOssConfig, layer_idx: int):
343
+ super().__init__()
344
+ self.hidden_size = config.hidden_size
345
+ self.self_attn = GptOssAttention(config=config, layer_idx=layer_idx)
346
+ self.mlp = GptOssMLP(config)
347
+ self.input_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
348
+ self.post_attention_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
349
+ self.attention_type = config.layer_types[layer_idx]
350
+
351
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
352
+ def forward(
353
+ self,
354
+ hidden_states: torch.Tensor,
355
+ attention_mask: Optional[torch.Tensor] = None,
356
+ position_ids: Optional[torch.LongTensor] = None,
357
+ past_key_values: Optional[Cache] = None,
358
+ use_cache: Optional[bool] = False,
359
+ cache_position: Optional[torch.LongTensor] = None,
360
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
361
+ **kwargs: Unpack[TransformersKwargs],
362
+ ) -> torch.Tensor:
363
+ residual = hidden_states
364
+ hidden_states = self.input_layernorm(hidden_states)
365
+ # Self Attention
366
+ hidden_states, _ = self.self_attn(
367
+ hidden_states=hidden_states,
368
+ attention_mask=attention_mask,
369
+ position_ids=position_ids,
370
+ past_key_values=past_key_values,
371
+ use_cache=use_cache,
372
+ cache_position=cache_position,
373
+ position_embeddings=position_embeddings,
374
+ **kwargs,
375
+ )
376
+ hidden_states = residual + hidden_states
377
+
378
+ # Fully Connected
379
+ residual = hidden_states
380
+ hidden_states = self.post_attention_layernorm(hidden_states)
381
+ hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores
382
+ hidden_states = residual + hidden_states
383
+ return hidden_states
384
+
385
+
386
+ @auto_docstring
387
+ class GptOssPreTrainedModel(PreTrainedModel):
388
+ config: GptOssConfig
389
+ base_model_prefix = "model"
390
+ supports_gradient_checkpointing = True
391
+ _no_split_modules = ["GptOssDecoderLayer"]
392
+ _skip_keys_device_placement = ["past_key_values"]
393
+ _supports_flash_attn = True
394
+ _supports_sdpa = False
395
+ _supports_flex_attn = True
396
+
397
+ _can_compile_fullgraph = True
398
+ _supports_attention_backend = True
399
+ _can_record_outputs = {
400
+ "router_logits": OutputRecorder(GptOssTopKRouter, index=0),
401
+ "hidden_states": GptOssDecoderLayer,
402
+ "attentions": GptOssAttention,
403
+ }
404
+ _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"]
405
+ _supports_flash_attention = False
406
+ _supports_flex_attention = False
407
+
408
+ def _init_weights(self, module):
409
+ std = self.config.initializer_range
410
+ if isinstance(module, nn.Linear):
411
+ module.weight.data.normal_(mean=0.0, std=std)
412
+ if module.bias is not None:
413
+ module.bias.data.zero_()
414
+ elif isinstance(module, nn.Parameter):
415
+ module.data.normal_(mean=0.0, std=std)
416
+ elif isinstance(module, nn.Embedding):
417
+ module.weight.data.normal_(mean=0.0, std=std)
418
+ if module.padding_idx is not None:
419
+ module.weight.data[module.padding_idx].zero_()
420
+ elif isinstance(module, GptOssRMSNorm):
421
+ module.weight.data.fill_(1.0)
422
+ elif isinstance(module, GptOssExperts):
423
+ module.gate_up_proj.data.normal_(mean=0.0, std=std)
424
+ module.gate_up_proj_bias.data.zero_()
425
+ module.down_proj.data.normal_(mean=0.0, std=std)
426
+ module.down_proj_bias.data.zero_()
427
+ elif isinstance(module, GptOssAttention):
428
+ module.sinks.data.normal_(mean=0.0, std=std)
429
+ elif isinstance(module, GptOssTopKRouter):
430
+ module.weight.data.normal_(mean=0.0, std=std)
431
+ module.bias.data.normal_(mean=0.0, std=std)
432
+
433
+
434
+ @auto_docstring
435
+ class GptOssModel(GptOssPreTrainedModel):
436
+ _no_split_modules = ["GptOssDecoderLayer"]
437
+
438
+ def __init__(self, config: GptOssConfig):
439
+ super().__init__(config)
440
+ self.padding_idx = config.pad_token_id
441
+ self.vocab_size = config.vocab_size
442
+
443
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
444
+ self.layers = nn.ModuleList(
445
+ [GptOssDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
446
+ )
447
+ self.norm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
448
+ self.rotary_emb = GptOssRotaryEmbedding(config=config)
449
+ self.gradient_checkpointing = False
450
+
451
+ # Initialize weights and apply final processing
452
+ self.post_init()
453
+
454
+ @check_model_inputs
455
+ @auto_docstring
456
+ def forward(
457
+ self,
458
+ input_ids: Optional[torch.LongTensor] = None,
459
+ attention_mask: Optional[torch.Tensor] = None,
460
+ position_ids: Optional[torch.LongTensor] = None,
461
+ past_key_values: Optional[list[torch.FloatTensor]] = None,
462
+ inputs_embeds: Optional[torch.FloatTensor] = None,
463
+ use_cache: Optional[bool] = None,
464
+ cache_position: Optional[torch.LongTensor] = None,
465
+ **kwargs: Unpack[TransformersKwargs],
466
+ ) -> MoeModelOutputWithPast:
467
+ if (input_ids is None) ^ (inputs_embeds is not None):
468
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
469
+
470
+ if use_cache and past_key_values is None:
471
+ past_key_values = DynamicCache(config=self.config)
472
+
473
+ if inputs_embeds is None:
474
+ inputs_embeds = self.embed_tokens(input_ids)
475
+
476
+ if cache_position is None:
477
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
478
+ cache_position = torch.arange(
479
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
480
+ )
481
+ if position_ids is None:
482
+ position_ids = cache_position.unsqueeze(0)
483
+
484
+ # It may already have been prepared by e.g. `generate`
485
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
486
+ mask_kwargs = {
487
+ "config": self.config,
488
+ "input_embeds": inputs_embeds,
489
+ "attention_mask": attention_mask,
490
+ "cache_position": cache_position,
491
+ "past_key_values": past_key_values,
492
+ }
493
+ causal_mask_mapping = {
494
+ "full_attention": create_causal_mask(**mask_kwargs),
495
+ "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
496
+ }
497
+
498
+ hidden_states = inputs_embeds
499
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
500
+
501
+ for decoder_layer in self.layers:
502
+ hidden_states = decoder_layer(
503
+ hidden_states,
504
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
505
+ position_ids=position_ids,
506
+ past_key_values=past_key_values,
507
+ use_cache=use_cache,
508
+ cache_position=cache_position,
509
+ position_embeddings=position_embeddings,
510
+ **kwargs,
511
+ )
512
+ hidden_states = self.norm(hidden_states)
513
+ return MoeModelOutputWithPast(
514
+ last_hidden_state=hidden_states,
515
+ past_key_values=past_key_values,
516
+ )
517
+
518
+
519
+ def load_balancing_loss_func(
520
+ gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
521
+ num_experts: Optional[int] = None,
522
+ top_k=2,
523
+ attention_mask: Optional[torch.Tensor] = None,
524
+ ) -> Union[torch.Tensor, int]:
525
+ r"""
526
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
527
+
528
+ See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
529
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
530
+ experts is too unbalanced.
531
+
532
+ Args:
533
+ gate_logits:
534
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
535
+ shape [batch_size X sequence_length, num_experts].
536
+ num_experts:
537
+ Number of experts
538
+ top_k:
539
+ The number of experts to route per-token, can be also interpreted as the `top-k` routing
540
+ parameter.
541
+ attention_mask (`torch.Tensor`, *optional*):
542
+ The attention_mask used in forward function
543
+ shape [batch_size X sequence_length] if not None.
544
+
545
+ Returns:
546
+ The auxiliary loss.
547
+ """
548
+ if gate_logits is None or not isinstance(gate_logits, tuple):
549
+ return 0
550
+
551
+ if isinstance(gate_logits, tuple):
552
+ compute_device = gate_logits[0].device
553
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
554
+
555
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
556
+
557
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
558
+
559
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
560
+
561
+ if attention_mask is None:
562
+ # Compute the percentage of tokens routed to each experts
563
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
564
+
565
+ # Compute the average probability of routing to these experts
566
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
567
+ else:
568
+ batch_size, sequence_length = attention_mask.shape
569
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
570
+
571
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
572
+ expert_attention_mask = (
573
+ attention_mask[None, :, :, None, None]
574
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
575
+ .reshape(-1, top_k, num_experts)
576
+ .to(compute_device)
577
+ )
578
+
579
+ # Compute the percentage of tokens routed to each experts
580
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
581
+ expert_attention_mask, dim=0
582
+ )
583
+
584
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
585
+ router_per_expert_attention_mask = (
586
+ attention_mask[None, :, :, None]
587
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
588
+ .reshape(-1, num_experts)
589
+ .to(compute_device)
590
+ )
591
+
592
+ # Compute the average probability of routing to these experts
593
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
594
+ router_per_expert_attention_mask, dim=0
595
+ )
596
+
597
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
598
+ return overall_loss * num_experts
599
+
600
+
601
+ @auto_docstring
602
+ class GptOssForCausalLM(GptOssPreTrainedModel, GenerationMixin):
603
+ _tied_weights_keys = ["lm_head.weight"]
604
+ _tp_plan = {"lm_head": "colwise_rep"}
605
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
606
+
607
+ def __init__(self, config):
608
+ super().__init__(config)
609
+ self.model = GptOssModel(config)
610
+ self.vocab_size = config.vocab_size
611
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
612
+ self.router_aux_loss_coef = config.router_aux_loss_coef
613
+ self.num_experts = config.num_local_experts
614
+ self.num_experts_per_tok = config.num_experts_per_tok
615
+
616
+ # Initialize weights and apply final processing
617
+ self.post_init()
618
+
619
+ @can_return_tuple
620
+ @auto_docstring
621
+ def forward(
622
+ self,
623
+ input_ids: Optional[torch.LongTensor] = None,
624
+ attention_mask: Optional[torch.Tensor] = None,
625
+ position_ids: Optional[torch.LongTensor] = None,
626
+ past_key_values: Optional[Cache] = None,
627
+ inputs_embeds: Optional[torch.FloatTensor] = None,
628
+ labels: Optional[torch.LongTensor] = None,
629
+ use_cache: Optional[bool] = None,
630
+ output_router_logits: Optional[bool] = None,
631
+ cache_position: Optional[torch.LongTensor] = None,
632
+ logits_to_keep: Union[int, torch.Tensor] = 0,
633
+ **kwargs: Unpack[TransformersKwargs],
634
+ ) -> MoeCausalLMOutputWithPast:
635
+ r"""
636
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
637
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
638
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
639
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
640
+
641
+ Example:
642
+
643
+ ```python
644
+ >>> from transformers import AutoTokenizer, GptOssForCausalLM
645
+
646
+ >>> model = GptOssForCausalLM.from_pretrained("mistralai/GptOss-8x7B-v0.1")
647
+ >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/GptOss-8x7B-v0.1")
648
+
649
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
650
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
651
+
652
+ >>> # Generate
653
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
654
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
655
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
656
+ ```"""
657
+
658
+ output_router_logits = (
659
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
660
+ )
661
+
662
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
663
+ outputs: MoeModelOutputWithPast = self.model(
664
+ input_ids=input_ids,
665
+ attention_mask=attention_mask,
666
+ position_ids=position_ids,
667
+ past_key_values=past_key_values,
668
+ inputs_embeds=inputs_embeds,
669
+ use_cache=use_cache,
670
+ output_router_logits=output_router_logits,
671
+ cache_position=cache_position,
672
+ **kwargs,
673
+ )
674
+
675
+ hidden_states = outputs.last_hidden_state
676
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
677
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
678
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
679
+
680
+ loss = None
681
+ if labels is not None:
682
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
683
+
684
+ aux_loss = None
685
+ if output_router_logits:
686
+ aux_loss = load_balancing_loss_func(
687
+ outputs.router_logits,
688
+ self.num_experts,
689
+ self.num_experts_per_tok,
690
+ attention_mask,
691
+ )
692
+ if labels is not None:
693
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
694
+
695
+ return MoeCausalLMOutputWithPast(
696
+ loss=loss,
697
+ aux_loss=aux_loss,
698
+ logits=logits,
699
+ past_key_values=outputs.past_key_values,
700
+ hidden_states=outputs.hidden_states,
701
+ attentions=outputs.attentions,
702
+ router_logits=outputs.router_logits,
703
+ )
704
+
705
+
706
+ class GptOssForSequenceClassification(GenericForSequenceClassification, GptOssPreTrainedModel):
707
+ pass
708
+
709
+
710
+ class GptOssForTokenClassification(GenericForTokenClassification, GptOssPreTrainedModel):
711
+ pass
712
+
713
+
714
+ __all__ = [
715
+ "GptOssForCausalLM",
716
+ "GptOssForSequenceClassification",
717
+ "GptOssForTokenClassification",
718
+ "GptOssModel",
719
+ "GptOssPreTrainedModel",
720
+ ]
721
+
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<|startoftext|>",
3
+ "eos_token": "<|return|>",
4
+ "pad_token": "<|endoftext|>"
5
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0614fe83cadab421296e664e1f48f4261fa8fef6e03e63bb75c20f38e37d07d3
3
+ size 27868174
tokenizer_config.json ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "199998": {
4
+ "content": "<|startoftext|>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "199999": {
12
+ "content": "<|endoftext|>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "200000": {
20
+ "content": "<|reserved_200000|>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "200001": {
28
+ "content": "<|reserved_200001|>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "200002": {
36
+ "content": "<|return|>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "200003": {
44
+ "content": "<|constrain|>",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "200004": {
52
+ "content": "<|reserved_200004|>",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ },
59
+ "200005": {
60
+ "content": "<|channel|>",
61
+ "lstrip": false,
62
+ "normalized": false,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": true
66
+ },
67
+ "200006": {
68
+ "content": "<|start|>",
69
+ "lstrip": false,
70
+ "normalized": false,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": true
74
+ },
75
+ "200007": {
76
+ "content": "<|end|>",
77
+ "lstrip": false,
78
+ "normalized": false,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": true
82
+ },
83
+ "200008": {
84
+ "content": "<|message|>",
85
+ "lstrip": false,
86
+ "normalized": false,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": true
90
+ },
91
+ "200009": {
92
+ "content": "<|reserved_200009|>",
93
+ "lstrip": false,
94
+ "normalized": false,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": true
98
+ },
99
+ "200010": {
100
+ "content": "<|reserved_200010|>",
101
+ "lstrip": false,
102
+ "normalized": false,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": true
106
+ },
107
+ "200011": {
108
+ "content": "<|reserved_200011|>",
109
+ "lstrip": false,
110
+ "normalized": false,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": true
114
+ },
115
+ "200012": {
116
+ "content": "<|call|>",
117
+ "lstrip": false,
118
+ "normalized": false,
119
+ "rstrip": false,
120
+ "single_word": false,
121
+ "special": true
122
+ },
123
+ "200013": {
124
+ "content": "<|reserved_200013|>",
125
+ "lstrip": false,
126
+ "normalized": false,
127
+ "rstrip": false,
128
+ "single_word": false,
129
+ "special": true
130
+ },
131
+ "200014": {
132
+ "content": "<|reserved_200014|>",
133
+ "lstrip": false,
134
+ "normalized": false,
135
+ "rstrip": false,
136
+ "single_word": false,
137
+ "special": true
138
+ },
139
+ "200015": {
140
+ "content": "<|reserved_200015|>",
141
+ "lstrip": false,
142
+ "normalized": false,
143
+ "rstrip": false,
144
+ "single_word": false,
145
+ "special": true
146
+ },
147
+ "200016": {
148
+ "content": "<|reserved_200016|>",
149
+ "lstrip": false,
150
+ "normalized": false,
151
+ "rstrip": false,
152
+ "single_word": false,
153
+ "special": true
154
+ },
155
+ "200017": {
156
+ "content": "<|reserved_200017|>",
157
+ "lstrip": false,
158
+ "normalized": false,
159
+ "rstrip": false,
160
+ "single_word": false,
161
+ "special": true
162
+ },
163
+ "200018": {
164
+ "content": "<|endofprompt|>",
165
+ "lstrip": false,
166
+ "normalized": false,
167
+ "rstrip": false,
168
+ "single_word": false,
169
+ "special": true
170
+ }
171
+ },
172
+ "bos_token": "<|startoftext|>",
173
+ "clean_up_tokenization_spaces": false,
174
+ "eos_token": "<|return|>",
175
+ "extra_special_tokens": {},
176
+ "model_input_names": [
177
+ "input_ids",
178
+ "attention_mask"
179
+ ],
180
+ "model_max_length": 1000000000000000019884624838656,
181
+ "pad_token": "<|endoftext|>",
182
+ "tokenizer_class": "PreTrainedTokenizerFast"
183
+ }