---
library_name: transformers
pipeline_tag: text-generation
base_model: Qwen/Qwen3-14B-Base
tags:
- qwen3
- triton
- kernel-generation
- reinforcement-learning
- code
datasets:
- hkust-nlp/drkernel-coldstart-8k
- hkust-nlp/drkernel-rl-data
- hkust-nlp/drkernel-validation-data
---

# DR.Kernel-14B

[![Model](https://img.shields.io/badge/🤗%20Model-hkust--nlp/drkernel--14b-yellow)](https://huggingface.co/hkust-nlp/drkernel-14b)
[![Paper](https://img.shields.io/badge/arXiv-2602.05885-b31b1b)](https://arxiv.org/abs/2602.05885)

`hkust-nlp/drkernel-14b` is a Qwen3-14B-based model specialized for GPU kernel generation and optimization (especially Triton) in the DR.Kernel framework.

It is trained for iterative optimization with execution feedback from KernelGYM, rather than for single-shot code generation alone.

## Model Summary

- Model type: `Qwen3ForCausalLM`
- Parameter count: `14,768,307,200` (from `model.safetensors.index.json`)
- Weight dtype: BF16
- Base model family: Qwen3-14B
- Main capability: generate and iteratively refine optimized `ModelNew` kernel implementations from PyTorch reference tasks

## Model Specs

From `config.json`:

- `num_hidden_layers`: 40
- `hidden_size`: 5120
- `intermediate_size`: 17408
- `num_attention_heads`: 40
- `num_key_value_heads`: 8
- `max_position_embeddings`: 32768
- `vocab_size`: 151936
- `transformers_version`: 4.56.0

From `model.safetensors.index.json`:

- `total_parameters`: 14,768,307,200
- `total_size`: 29,536,614,400 bytes
- Expected checkpoint shards: `model-00001-of-00007.safetensors` ... `model-00007-of-00007.safetensors`

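These values can be checked without downloading the weight shards, since `AutoConfig` only fetches `config.json`. A minimal sketch using the standard Transformers API; the printed fields are the ones listed above.

```python
from transformers import AutoConfig

# Downloads only config.json, not the ~30 GB of weight shards
config = AutoConfig.from_pretrained("hkust-nlp/drkernel-14b")

print(config.num_hidden_layers)        # 40
print(config.hidden_size)              # 5120
print(config.num_key_value_heads)      # 8
print(config.max_position_embeddings)  # 32768
print(config.vocab_size)               # 151936
```
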
## Training Recipe (DR.Kernel)

The 14B model is trained with the two-stage DR.Kernel pipeline:

1. Cold-start SFT
   - Dataset: `hkust-nlp/drkernel-coldstart-8k`
   - Multi-turn trajectory warm-up for kernel generation/refinement
2. Multi-turn RL
   - Train dataset: `hkust-nlp/drkernel-rl-data`
   - Validation dataset: `hkust-nlp/drkernel-validation-data` (KernelBench Level 2 validation split)
   - Core methods: TRLOO + MRS + PR + PRS
   - Execution environment: KernelGYM with compilation/correctness/performance/profiling feedback

Related training scripts in this repo:

- `drkernel/kernel/scripts/sft/14b-coldstart.sh`
- `drkernel/kernel/scripts/rl/14b_trloo_mrs_pr_prs.sh`

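The training and validation datasets are hosted on the Hugging Face Hub and can be inspected with the `datasets` library. This is a minimal loading sketch; the `train` split name is an assumption, so check the dataset cards for the actual splits and column schema.

```python
from datasets import load_dataset

# Cold-start SFT trajectories (multi-turn kernel generation/refinement)
coldstart = load_dataset("hkust-nlp/drkernel-coldstart-8k", split="train")

# Multi-turn RL training tasks
rl_data = load_dataset("hkust-nlp/drkernel-rl-data", split="train")

print(coldstart)            # number of rows and column names
print(coldstart[0].keys())  # inspect one record's schema
```
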
## Intended Use

- Kernel generation research and benchmarking
- Triton kernel optimization with iterative feedback
- Multi-turn agentic code refinement under execution-based reward

## Out-of-Scope Use

- Safety-critical production deployment without additional verification
- General-purpose coding assistant use where kernel-evaluation feedback is unavailable

## Prompting Format

The model is trained with kernel-optimization prompts that:

- Provide a PyTorch reference architecture (`Model`, `get_inputs`, `get_init_inputs`)
- Require returning an optimized `ModelNew`
- In multi-turn settings, append server feedback and request iterative improvement

For best behavior, keep the same task style as the DR.Kernel datasets and use chat-format messages.

## Quick Start (Transformers)

Use the same fixed 1-shot first-turn prompt template as DR.Kernel data (recommended):

````python
import textwrap
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "hkust-nlp/drkernel-14b"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

ref_code = textwrap.dedent(
    """
    import torch
    import torch.nn as nn

    class Model(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            x = torch.abs(x)
            x = x - 1.0
            return x

    def get_inputs():
        return [torch.randn(64, 128)]

    def get_init_inputs():
        return []
    """
).strip()

example_ref_code = textwrap.dedent(
    """
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, a, b):
            return a + b

    def get_inputs():
        # randomly generate input tensors based on the model architecture
        a = torch.randn(1, 128).cuda()
        b = torch.randn(1, 128).cuda()
        return [a, b]

    def get_init_inputs():
        # randomly generate tensors required for initialization based on the model architecture
        return []
    """
).strip()

example_kernel_code = textwrap.dedent(
    '''
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import triton
    import triton.language as tl

    @triton.jit
    def add_kernel(
        x_ptr,  # Pointer to first input
        y_ptr,  # Pointer to second input
        out_ptr,  # Pointer to output
        n_elements,  # Total number of elements in input/output
        BLOCK_SIZE: tl.constexpr,
    ):
        # Each program handles a contiguous block of data of size BLOCK_SIZE
        block_start = tl.program_id(0) * BLOCK_SIZE
        # Create a range of offsets [0..BLOCK_SIZE-1]
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        # Mask to ensure we don't go out of bounds
        mask = offsets < n_elements
        # Load input values
        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
        y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
        # Perform the elementwise addition
        out = x + y
        # Store the result
        tl.store(out_ptr + offsets, out, mask=mask)

    def triton_add(x: torch.Tensor, y: torch.Tensor):
        """
        This function wraps the Triton kernel call. It:
          1. Ensures the inputs are contiguous on GPU.
          2. Calculates the grid (blocks) needed.
          3. Launches the Triton kernel.
        """
        assert x.is_cuda and y.is_cuda, "Tensors must be on CUDA."
        x = x.contiguous()
        y = y.contiguous()

        # Prepare output tensor
        out = torch.empty_like(x)

        # Number of elements in the tensor
        n_elements = x.numel()
        BLOCK_SIZE = 128  # Tunable parameter for block size

        # Determine the number of blocks needed
        grid = lambda meta: ((n_elements + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],)

        # Launch the Triton kernel
        add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=BLOCK_SIZE)
        return out

    class ModelNew(nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, a, b):
            # Instead of "return a + b", call our Triton-based addition
            return triton_add(a, b)
    '''
).strip()

prompt_template = textwrap.dedent(
    """\
    You write custom Triton kernels to replace the pytorch operators in the given architecture to get speedups.

    You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom Triton kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.

    Here's an example to show you the syntax of inline embedding custom Triton kernels in torch: The example given architecture is:

    ```python
    {example_ref_code}
    ```

    The example new arch with custom Triton kernels looks like this:

    ```python
    {example_kernel_code}
    ```

    You are given the following architecture:
    ```python
    {ref_code}
    ```

    Optimize the architecture named Model with custom Triton operators! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Let's think step by step.
    """
).strip()

prompt = prompt_template.format(
    example_ref_code=example_ref_code,
    example_kernel_code=example_kernel_code,
    ref_code=ref_code,
)
messages = [{"role": "user", "content": prompt}]

inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

with torch.no_grad():
    outputs = model.generate(
        inputs,
        max_new_tokens=2048,
        do_sample=True,
        temperature=1.0,
        top_p=1.0,
    )

# Only print newly generated tokens
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=False))
````

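DR.Kernel is trained for multi-turn refinement: after the first completion, execution feedback from KernelGYM is appended and the model is asked to improve the kernel. Continuing from the variables above, the sketch below shows one way to drive a second turn by hand; the code-extraction helper and the feedback string are illustrative assumptions, not the exact KernelGYM feedback format.

````python
import re

completion = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)

def extract_last_code_block(text: str) -> str:
    """Hypothetical helper: grab the last fenced Python block from the reply."""
    blocks = re.findall(r"```python\n(.*?)```", text, flags=re.DOTALL)
    return blocks[-1].strip() if blocks else ""

kernel_code = extract_last_code_block(completion)  # candidate ModelNew implementation

# Placeholder feedback; in DR.Kernel this comes from KernelGYM
# (compilation / correctness / performance / profiling results).
feedback = "Compilation: OK. Correctness: passed. Speedup vs. reference: 1.00x. Please optimize further."

messages += [
    {"role": "assistant", "content": completion},
    {"role": "user", "content": feedback},
]

next_inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
# ... then call model.generate(next_inputs, ...) as in the first turn
````
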
## Evaluation

Use the KernelGYM-based evaluation scripts in this repo:

- `drkernel/kernel/scripts/eval/drkernel-14b-maxturns3.sh`
- `drkernel/kernel/scripts/eval/grading_common.sh` for custom evaluation runs

Validation data:

- `hkust-nlp/drkernel-validation-data` (KernelBench Level 2 validation tasks)

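For a quick local sanity check of a generated `ModelNew` against its reference `Model` outside of KernelGYM, a rough correctness-and-timing comparison along the following lines can help. This is a simplified sketch, not the DR.Kernel grading pipeline; `local_check` is a hypothetical helper and assumes `Model`, `ModelNew`, and `get_inputs` from the task file are already defined in scope.

```python
import torch
from triton.testing import do_bench

def local_check(Model, ModelNew, get_inputs, atol=1e-3, rtol=1e-3):
    """Rough correctness + timing comparison on one set of random inputs."""
    inputs = [x.cuda() if torch.is_tensor(x) else x for x in get_inputs()]
    ref_model, new_model = Model().cuda(), ModelNew().cuda()

    # Correctness: compare outputs on identical inputs
    with torch.no_grad():
        ok = torch.allclose(ref_model(*inputs), new_model(*inputs), atol=atol, rtol=rtol)

    # Timing in milliseconds (Triton's benchmarking helper)
    ref_ms = do_bench(lambda: ref_model(*inputs))
    new_ms = do_bench(lambda: new_model(*inputs))

    print(f"correct={ok}  ref={ref_ms:.3f} ms  new={new_ms:.3f} ms  speedup={ref_ms / new_ms:.2f}x")
```
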
## Data and Attribution

- Query/task source includes:
  - [ByteDance-Seed/cudaLLM-data](https://huggingface.co/datasets/ByteDance-Seed/cudaLLM-data)
- SFT cold-start trajectories:
  - [hkust-nlp/drkernel-coldstart-8k](https://huggingface.co/datasets/hkust-nlp/drkernel-coldstart-8k)
- RL train data:
  - [hkust-nlp/drkernel-rl-data](https://huggingface.co/datasets/hkust-nlp/drkernel-rl-data)
- Validation/eval data:
  - [hkust-nlp/drkernel-validation-data](https://huggingface.co/datasets/hkust-nlp/drkernel-validation-data)
- Benchmark source:
  - [KernelBench](https://github.com/ScalingIntelligence/KernelBench)

Please acknowledge the original dataset/benchmark authors when using this model.

## Related Resources

- Paper: [Dr.Kernel: Reinforcement Learning Done Right for Triton Kernel Generations](https://arxiv.org/abs/2602.05885)
- Codebase: [KernelGYM](https://github.com/hkust-nlp/KernelGYM)
- Training docs: `drkernel/README.md`

## Citation

```bibtex
@article{liuetal2026,
  title={Dr.Kernel: Reinforcement Learning Done Right for Triton Kernel Generations},
  author={Wei Liu and Jiawei Xu and Yingru Li and Longtao Zheng and Tianjian Li and Qian Liu and Junxian He},
  journal={arXiv preprint arXiv:2602.05885},
  year={2026}
}
```