---
base_model:
- Qwen/QwQ-32B
tags:
- code
---

# Model Summary

KernelCoder is built on Qwen/QwQ-32B and trained on a curated dataset of paired reasoning traces and CUDA kernels.

See the [paper](https://lkongam.github.io/ConCuR/) for details.

# Usage

```python
from vllm import LLM, SamplingParams
import torch
import re
from string import Template

# The prompt template is left empty on this card; fill in your own
# instruction text and reference $code where the input program goes.
PROMPT_TEMPLATE = Template('''
''')


class KernelCoder:

    def __init__(self, model_name="lkongam/KernelCoder", tensor_parallel_size=1, gpu_memory_utilization=0.9):
        self.model_name = model_name
        self.llm = LLM(
            model=model_name,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            trust_remote_code=True,
            dtype="auto",
        )
        self.tokenizer = self.llm.get_tokenizer()
        self.device = torch.device("cuda")

    def generate_raw(self, prompt, temperature=1.0, max_tokens=32768):
        messages = [
            {"role": "user", "content": prompt}
        ]
        # Format the conversation with the chat template, keeping the
        # reasoning ("thinking") mode enabled.
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        )
        # Generate a completion; max_tokens is a generous budget for
        # long reasoning traces, adjust as needed.
        sampling_params = SamplingParams(temperature=temperature, max_tokens=max_tokens)
        outputs = self.llm.generate([text], sampling_params)
        return outputs[0].outputs[0].text


def extract_last_code_block(text):
    # Prefer the last fenced code block in the completion.
    code_blocks = re.findall(r"```(?:python)?\n(.*?)```", text, re.DOTALL)
    if code_blocks:
        return code_blocks[-1].strip()
    # Otherwise fall back to everything after the reasoning trace,
    # starting at the first `import` if one is present.
    match = re.search(r"</think>(.*)", text, re.S)
    after_think = match.group(1).strip() if match else text
    if not after_think:
        return None
    import_match = re.search(r"\bimport\b", after_think)
    if import_match:
        return after_think[import_match.start():].strip()
    return after_think.strip()


# Paste the program you want rewritten as a CUDA kernel here.
origin_code = """
"""

model = KernelCoder(model_name="lkongam/KernelCoder")

prompt = PROMPT_TEMPLATE.substitute(code=origin_code)
code_output = model.generate_raw(prompt)
code = extract_last_code_block(code_output)
print(code)
```
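
For illustration, `origin_code` might hold a small PyTorch module that you want KernelCoder to reimplement as a custom CUDA kernel. The module below is a hypothetical example; the card leaves both the prompt template and the input empty and does not specify the exact input format, so treat this as an assumption and consult the paper for the format used in training.

```python
# Hypothetical input: a small PyTorch module for KernelCoder to rewrite.
# The expected input format is an assumption here, not taken from the card.
origin_code = """
import torch
import torch.nn as nn

class Model(nn.Module):
    def forward(self, x, y):
        # Fused SiLU activation plus residual add, a natural fusion target.
        return x * torch.sigmoid(x) + y
"""
```

With a populated prompt template, `PROMPT_TEMPLATE.substitute(code=origin_code)` injects this source into the instruction before generation.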