---
base_model:
  - Qwen/QwQ-32B
tags:
  - code
---

# KernelCoder

## Model Summary

KernelCoder is a CUDA kernel generation model fine-tuned from Qwen/QwQ-32B on a curated dataset of paired reasoning traces and CUDA kernels.

See the paper for details.

## Usage

```python
from vllm import LLM, SamplingParams
import torch
import re
from string import Template

# The full prompt text is elided in the original card; the template is
# assumed to expose a $code placeholder for the substitute() call below.
PROMPT_TEMPLATE = Template('''
$code
''')


class KernelCoder:
    def __init__(self, model_name="lkongam/KernelCoder", tensor_parallel_size=1,
                 gpu_memory_utilization=0.9):
        self.model_name = model_name
        self.llm = LLM(
            model=model_name,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            trust_remote_code=True,
            dtype="auto",
        )
        self.tokenizer = self.llm.get_tokenizer()
        self.device = torch.device("cuda")

    def generate_raw(self, prompt, temperature=1.0, max_tokens=16384):
        # Wrap the prompt in the chat template, keeping the reasoning
        # ("thinking") segment enabled.
        messages = [{"role": "user", "content": prompt}]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        )
        # Sample a completion; a large max_tokens budget (an assumed default)
        # leaves room for the reasoning trace that precedes the kernel.
        sampling_params = SamplingParams(temperature=temperature, max_tokens=max_tokens)
        outputs = self.llm.generate([text], sampling_params)
        return outputs[0].outputs[0].text


def extract_last_code_block(text):
    # Prefer the last fenced code block in the completion.
    code_blocks = re.findall(r"```(?:python)?\n(.*?)```", text, re.DOTALL)
    if code_blocks:
        return code_blocks[-1].strip()
    # Otherwise keep only what follows the reasoning trace.
    match = re.search(r"</think>(.*)", text, re.S)
    after_think = match.group(1).strip() if match else text
    if not after_think:
        return None
    # Heuristic: generated code usually starts at the first import.
    import_match = re.search(r"\bimport\b", after_think)
    if import_match:
        return after_think[import_match.start():].strip()
    return after_think.strip()


# Paste the code you want rewritten as an optimized CUDA kernel here
# (elided in the original card).
origin_code = """
"""

model = KernelCoder(model_name="lkongam/KernelCoder")

prompt = PROMPT_TEMPLATE.substitute(code=origin_code)
code_output = model.generate_raw(prompt)
code = extract_last_code_block(code_output)
print(code)
```
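
For a quick sanity check of the parsing logic in `extract_last_code_block`, you can run it on a synthetic completion (the string below is illustrative, not real model output):

```python
sample = (
    "<think>sketch the tiling strategy...</think>\n"
    "```python\nimport torch\nprint('kernel goes here')\n```"
)
# The fenced block takes priority, so only the two code lines are printed.
print(extract_last_code_block(sample))
```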

## Evaluation

*Evaluation figure (left: Pass@1, right: Pass@10).*
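
For context, Pass@k here presumably follows the standard unbiased estimator of Chen et al. (2021): draw n samples per problem, count the c that pass the correctness checks, and compute pass@k = 1 - C(n-c, k) / C(n, k). A minimal sketch:

```python
from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k: probability that at least one of k draws
    from n samples (c of them correct) is correct."""
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)

print(pass_at_k(n=10, c=3, k=1))   # 0.3
print(pass_at_k(n=10, c=3, k=10))  # 1.0
```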