---
license: apache-2.0
---
<h2 align="center" style="line-height: 25px;">
Unlocking Aha Moments via Reinforcement Learning: Advancing Collaborative Visual Comprehension and Generation
</h2>

<p align="center">
  <a href="https://arxiv.org/abs/2506.01480" style="display: inline-block; margin: 0 5px;">
    <img src="https://img.shields.io/badge/Paper-red?style=flat&logo=arxiv" style="height: 15px;">
  </a>
  <a href="https://janus-pro-r1.github.io/" style="display: inline-block; margin: 0 5px;">
    <img src="https://img.shields.io/badge/Project Page-white?style=flat&logo=google-docs" style="height: 15px;">
  </a>
  <a href="https://github.com/wendell0218/Janus-Pro-R1" style="display: inline-block; margin: 0 5px;">
    <img src="https://img.shields.io/badge/Code-black?style=flat&logo=github" style="height: 15px;">
  </a>
  <a href="https://huggingface.co/midbee/Janus-Pro-R1-7B" style="display: inline-block; margin: 0 5px;">
    <img src="https://img.shields.io/badge/-%F0%9F%A4%97%20Checkpoint-orange?style=flat" style="height: 15px;"/>
  </a>
</p>

<div align="center">
  <span style="font-size: smaller;">
    Kaihang Pan<sup>1*</sup>, Wendong Bu<sup>1*</sup>, Yuruo Wu<sup>1*</sup>, Yang Wu<sup>2</sup>, Kai Shen<sup>1</sup>, Yunfei Li<sup>2</sup>,
    <br>Hang Zhao<sup>2</sup>, Juncheng Li<sup>1&dagger;</sup>, Siliang Tang<sup>1</sup>, Yueting Zhuang<sup>1</sup>
    <br><sup>1</sup>Zhejiang University, <sup>2</sup>Ant Group
    <br>*Equal Contribution, <sup>&dagger;</sup>Corresponding Authors
  </span>
</div>

![Janus-Pro-R1 introduction](https://raw.githubusercontent.com/wendell0218/Janus-Pro-R1/refs/heads/main/assets/intro.png)

## 🚀 Overview

We propose a **two-stage training paradigm** that enables introspective text-to-image generation via genuine reasoning chains (CoT), unlocking what we call **Aha Moments** in visual generation:

- **Stage 1 – Supervised Fine-Tuning (SFT):**
  The model learns structured visual reasoning through three subtasks:
  - Text-to-image generation
  - Image-text consistency self-evaluation
  - Image regeneration through reflection

- **Stage 2 – Reinforcement Learning (RL):**
  The model is trained under a token-level Markov decision process with bi-level QA-based rewards that encourage spontaneous reasoning and correction, and is optimized with GRPO (Group Relative Policy Optimization); an illustrative sketch of the group-relative advantage appears below.

With these self-reflective capabilities, the approach bridges the gap between text-to-image generation and image editing, enabling a unified and coherent visual reasoning process.
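
To make the Stage 2 objective concrete, the following is a minimal, illustrative sketch of how a GRPO-style group-relative advantage can be computed. It is not the training code of this repository; the group size, reward values, and function name are placeholders chosen only for illustration.

```python
import torch

def group_relative_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Illustrative GRPO-style advantage: each rollout's reward is normalized
    by the mean and std of its own group (rewards has shape [num_groups, group_size])."""
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)
    return (rewards - mean) / (std + eps)

# Toy example: one prompt with four sampled generations, each scored by a
# (hypothetical) bi-level QA-based reward in [0, 1]; above-average rollouts
# receive positive advantages that are applied to every token they contain.
rewards = torch.tensor([[0.2, 0.9, 0.5, 0.4]])
print(group_relative_advantages(rewards))
```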

<div style="text-align: center;">
  <img src="https://janus-pro-r1.github.io/static/images/method.png" width="100%" />
</div>

## ✨️ Quickstart

**1. Prepare Environment**

The Python environment for inference is the same as the one used for SFT. Clone our repo and set up the environment; we recommend Python >= 3.10.

```bash
git clone https://github.com/wendell0218/Janus-Pro-R1.git
cd Janus-Pro-R1

conda create -n janus-pro-r1-sft python=3.11
conda activate janus-pro-r1-sft
pip install -r requirements-sft.txt
```
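
As an optional sanity check (not part of the original instructions), you can confirm that PyTorch sees the GPU before moving on:

```python
# Optional environment check; expects a CUDA-enabled PyTorch build.
import torch

print(torch.__version__, torch.cuda.is_available())
```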

**2. Prepare Pretrained Model**

Janus-Pro-R1-7B uses `Janus-Pro-7B` as the pretrained base model for subsequent training. You can download it with the following commands:

```bash
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/deepseek-ai/Janus-Pro-7B
cd Janus-Pro-7B
git lfs pull
```
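
If you prefer not to use git-lfs, the same weights can also be fetched with the `huggingface_hub` Python API. This is an optional alternative, not part of the original instructions, and the `local_dir` value below is only an example:

```python
# Optional alternative to the git-lfs clone above.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="deepseek-ai/Janus-Pro-7B",
    local_dir="Janus-Pro-7B",  # example destination directory
)
```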

**3. Start Generating!**

The script below generates images for a text prompt, asks the model to check whether each image matches the prompt, and regenerates an image whenever that self-check fails:

```python
import os
import json
import torch
import PIL.Image
import numpy as np
from typing import List
from torchvision import transforms
from transformers import AutoModelForCausalLM
from models import MultiModalityCausalLM, VLChatProcessor
from tqdm import tqdm
import math


def center_crop_arr(pil_image, image_size):
    # Downscale until the short side is below 2 * image_size, rescale so the
    # short side equals image_size, then center-crop to a square.
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=PIL.Image.BOX
        )

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=PIL.Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return PIL.Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])


@torch.no_grad()
def generate_with_refine(
    mmgpt: MultiModalityCausalLM,
    vl_chat_processor: VLChatProcessor,
    input_ids,
    attention_mask,
    temperature: float = 1,
    parallel_size: int = 4,
    cfg_weight: float = 5,
    image_token_num_per_image: int = 576,
    img_size: int = 384,
    patch_size: int = 16,
    img_top_k: int = None,
    img_top_p: float = None,
    txt_top_k: int = None,
    txt_top_p: float = None,
    max_reflect_len: int = 80,
    task_list: List[int] = [1, 2, 3],
):
    prompt = [
        '<end_of_image>\nLet me think Does this image match the prompt...',
        '<|end▁of▁sentence|>\nNext, I will draw a new image<begin_of_image>'
    ]
    all_imgs_1, embeds_1, attention_mask_1 = [], [], []
    output_text_ids, selfcheck, attention_mask_txt = [], [], []
    all_imgs_2 = []
    # The effective batch size follows the number of input prompts.
    parallel_size = input_ids.shape[0]
    # ---- Task 1: first-pass text-to-image generation (with classifier-free guidance) ----
    if 1 <= task_list[-1]:
        # Duplicate each prompt into a (conditional, unconditional) pair for CFG;
        # the odd rows keep only padding so they act as the unconditional branch.
        tokens = torch.repeat_interleave(input_ids, 2, dim=0)
        for i in range(tokens.size(0)):
            if i % 2 != 0:
                pad_list = torch.where(tokens[i] == vl_chat_processor.pad_id)[0]
                if pad_list.shape[0] == 0:
                    st = 1
                else:
                    st = pad_list[-1].item() + 2
                tokens[i, st:-1] = vl_chat_processor.pad_id
        inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)
        embeds_1 = inputs_embeds
        attention_mask_1 = torch.repeat_interleave(attention_mask, 2, dim=0)
        cur_atten_mask = attention_mask_1
        generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
        for i in tqdm(range(image_token_num_per_image)):
            outputs = mmgpt.language_model.model(inputs_embeds=inputs_embeds, attention_mask=cur_atten_mask, use_cache=True, past_key_values=outputs.past_key_values if i != 0 else None)
            hidden_states = outputs.last_hidden_state
            logits = mmgpt.gen_head(hidden_states[:, -1, :])
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            if img_top_k:
                v, _ = torch.topk(logits, min(img_top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")
            probs = torch.softmax(logits / temperature, dim=-1)
            if img_top_p:
                probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
                probs_sum = torch.cumsum(probs_sort, dim=-1)
                mask = probs_sum - probs_sort > img_top_p
                probs_sort[mask] = 0.0
                probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
                next_token = torch.multinomial(probs_sort, num_samples=1)
                next_token = torch.gather(probs_idx, -1, next_token)
            else:
                next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)
            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
            img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)
            cur_atten_mask = torch.cat([cur_atten_mask, torch.ones(cur_atten_mask.size(0), 1).to(attention_mask)], dim=1)
        # Decode the sampled image tokens back into pixel space.
        dec = mmgpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int), shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size])
        dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
        dec = np.clip((dec + 1) / 2 * 255, 0, 255)
        visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
        visual_img[:, :, :] = dec
        for i in range(parallel_size):
            all_imgs_1.append(PIL.Image.fromarray(visual_img[i]))

    # ---- Task 2: self-evaluation ("Does this image match the prompt?") with a short reflection ----
    if 2 <= task_list[-1]:
        inputs_embeds = embeds_1[::2, :, :]
        under_embeds = torch.zeros((parallel_size, image_token_num_per_image, 4096), dtype=torch.bfloat16).cuda()
        for i in range(parallel_size):
            img_prompt = "<image_placeholder>"
            prepare_inputs = vl_chat_processor(
                prompt=img_prompt, images=[all_imgs_1[i]], force_batchify=True
            ).to(input_ids.device)
            img_embeds = mmgpt.prepare_inputs_embeds(**prepare_inputs)
            img_embeds = img_embeds[:, 2:-1, :]
            under_embeds[i, :, :] = img_embeds
        inputs_embeds = torch.cat((inputs_embeds, under_embeds), dim=1)
        selfcheck_ids = vl_chat_processor.tokenizer.encode(prompt[0])[1:]
        selfcheck_ids = torch.LongTensor(selfcheck_ids)
        selfcheck_tokens = torch.zeros((parallel_size, len(selfcheck_ids)), dtype=torch.int).cuda()
        for i in range(parallel_size):
            selfcheck_tokens[i, :] = selfcheck_ids
        selfcheck_embeds = mmgpt.language_model.get_input_embeddings()(selfcheck_tokens)
        inputs_embeds = torch.cat((inputs_embeds, selfcheck_embeds), dim=1)
        reflect_tokens = torch.zeros((parallel_size, max_reflect_len), dtype=torch.int).cuda()
        reflect_len = 0
        eos_list = torch.zeros((parallel_size, 1), dtype=torch.int).cuda()
        add_padding = torch.zeros((parallel_size, 1), dtype=torch.int).cuda()
        eos_token = vl_chat_processor.tokenizer.encode("<|end▁of▁sentence|>")[-1]
        padding_token = vl_chat_processor.tokenizer.encode("<|▁pad▁|>")[-1]
        yes_token = vl_chat_processor.tokenizer.encode("Yes")[-1]
        no_token = vl_chat_processor.tokenizer.encode("No")[-1]
        attn_mask = torch.ones((parallel_size, inputs_embeds.shape[1]), dtype=torch.int).cuda()
        yes_list = torch.zeros((parallel_size), dtype=torch.int).cuda()
        for i in range(max_reflect_len):
            outputs = mmgpt.language_model(inputs_embeds=inputs_embeds, attention_mask=attn_mask, use_cache=True, past_key_values=outputs.past_key_values if i != 0 else None)
            logits = outputs.logits
            logits = logits[:, -1, :]
            if i == 0:
                # The first reflection token is constrained to "Yes"/"No" (the self-check verdict).
                allowed_tokens = [yes_token, no_token]
                allowed_tokens_logits = logits[:, allowed_tokens]
                logits[:, :] = -math.inf
                logits[:, allowed_tokens] = allowed_tokens_logits

            if txt_top_k:
                v, _ = torch.topk(logits, min(txt_top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")
            probs = torch.softmax(logits / temperature, dim=-1)
            if txt_top_p:
                probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
                probs_sum = torch.cumsum(probs_sort, dim=-1)
                mask = probs_sum - probs_sort > txt_top_p
                probs_sort[mask] = 0.0
                probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
                next_token = torch.multinomial(probs_sort, num_samples=1)
                next_token = torch.gather(probs_idx, -1, next_token)
            else:
                next_token = torch.multinomial(probs, num_samples=1)
            if i >= 1:
                # Once a sequence has emitted EOS (or padding), keep padding it.
                add_padding = ((reflect_tokens[:, i-1] == eos_token) | (reflect_tokens[:, i-1] == padding_token)).unsqueeze(1).to(torch.int)
                next_token = add_padding * padding_token + (1 - add_padding) * next_token
            if i == 0:
                yes_list = (next_token == yes_token)
            reflect_tokens[:, i] = next_token.squeeze(dim=-1)
            is_eos = (next_token == eos_token)
            eos_list = eos_list | is_eos.to(torch.int)
            new_attn = 1 - add_padding
            new_attn = new_attn & (~is_eos)
            attn_mask = torch.cat((attn_mask, new_attn), dim=1)
            inputs_embeds = mmgpt.language_model.get_input_embeddings()(next_token)
            reflect_len = i
            if eos_list.all():
                break
        reflect_tokens = reflect_tokens[:, :reflect_len + 1]
        max_reflect_len = reflect_len + 1
        output_text_ids = reflect_tokens
        attention_mask_txt = torch.ones_like(output_text_ids).cuda()
        attention_mask_txt[output_text_ids == padding_token] = 0
        attention_mask_txt[output_text_ids == eos_token] = 0
        selfcheck = yes_list.bool()

    # ---- Task 3: regenerate the image, conditioned on the first image and the reflection text ----
    if 3 <= task_list[-1]:
        tokens = torch.repeat_interleave(input_ids, 2, dim=0)
        for i in range(tokens.size(0)):
            if i % 2 != 0:
                pad_list = torch.where(tokens[i] == vl_chat_processor.pad_id)[0]
                if pad_list.shape[0] == 0:
                    st = 1
                else:
                    st = pad_list[-1].item() + 2
                tokens[i, st:-1] = vl_chat_processor.pad_id
        inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)
        gen_transform = transforms.Compose([
            transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, 384)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
        ])
        gen_embeds_list = []
        for i in range(len(all_imgs_1)):
            img = gen_transform(all_imgs_1[i])
            img = img.unsqueeze(0).to(torch.bfloat16).cuda()
            _, _, all_image_ids = mmgpt.gen_vision_model.encode(img)
            image_ids = all_image_ids[2]
            embed = mmgpt.gen_aligner(mmgpt.gen_embed(image_ids))
            # Append twice: one copy for the conditional row and one for the unconditional row.
            gen_embeds_list.append(embed)
            gen_embeds_list.append(embed)
        gen_embeds = torch.cat(gen_embeds_list, dim=0)
        inputs_embeds = torch.cat((inputs_embeds, gen_embeds), dim=1)
        selfcheck_ids = vl_chat_processor.tokenizer.encode(prompt[0])[1:]
        selfcheck_ids = torch.LongTensor(selfcheck_ids)
        selfcheck_tokens = torch.zeros((2 * parallel_size, len(selfcheck_ids)), dtype=torch.int).cuda()
        for i in range(2 * parallel_size):
            selfcheck_tokens[i, :] = selfcheck_ids
        selfcheck_embeds = mmgpt.language_model.get_input_embeddings()(selfcheck_tokens)
        inputs_embeds = torch.cat((inputs_embeds, selfcheck_embeds), dim=1)
        attn_mask = torch.ones((2 * parallel_size, inputs_embeds.shape[1]), dtype=torch.int).cuda()
        reflect_embeds = torch.ones((2 * parallel_size, max_reflect_len), dtype=torch.int).cuda()
        for i in range(2 * parallel_size):
            reflect_embeds[i] = output_text_ids[i // 2]
        new_attn = torch.ones((2 * parallel_size, max_reflect_len), dtype=torch.int).cuda()
        for i in range(2 * parallel_size):
            new_attn[i] = attention_mask_txt[i // 2]
        reflect_embeds = mmgpt.language_model.get_input_embeddings()(reflect_embeds)
        inputs_embeds = torch.cat((inputs_embeds, reflect_embeds), dim=1)
        attn_mask = torch.cat((attn_mask, new_attn), dim=1)
        regen_ids = vl_chat_processor.tokenizer.encode(prompt[1])[1:]
        regen_ids = torch.LongTensor(regen_ids)
        regen_tokens = torch.zeros((2 * parallel_size, len(regen_ids)), dtype=torch.int).cuda()
        for i in range(2 * parallel_size):
            regen_tokens[i, :] = regen_ids
        regen_embeds = mmgpt.language_model.get_input_embeddings()(regen_tokens)
        inputs_embeds = torch.cat((inputs_embeds, regen_embeds), dim=1)
        new_attn = torch.ones((2 * parallel_size, regen_ids.shape[0]), dtype=torch.int).cuda()
        attn_mask = torch.cat((attn_mask, new_attn), dim=1)

        new_generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
        for i in tqdm(range(image_token_num_per_image)):
            outputs = mmgpt.language_model.model(inputs_embeds=inputs_embeds, attention_mask=attn_mask, use_cache=True, past_key_values=outputs.past_key_values if i != 0 else None)
            hidden_states = outputs.last_hidden_state
            new_attn = torch.ones((2 * parallel_size, 1), dtype=torch.int).cuda()
            attn_mask = torch.cat((attn_mask, new_attn), dim=1)
            logits = mmgpt.gen_head(hidden_states[:, -1, :])
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            if img_top_k:
                v, _ = torch.topk(logits, min(img_top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")
            probs = torch.softmax(logits / temperature, dim=-1)
            if img_top_p:
                probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
                probs_sum = torch.cumsum(probs_sort, dim=-1)
                mask = probs_sum - probs_sort > img_top_p
                probs_sort[mask] = 0.0
                probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
                next_token = torch.multinomial(probs_sort, num_samples=1)
                next_token = torch.gather(probs_idx, -1, next_token)
            else:
                next_token = torch.multinomial(probs, num_samples=1)
            new_generated_tokens[:, i] = next_token.squeeze(dim=-1)
            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
            img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)
        new_dec = mmgpt.gen_vision_model.decode_code(new_generated_tokens.to(dtype=torch.int), shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size])
        new_dec = new_dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
        new_dec = np.clip((new_dec + 1) / 2 * 255, 0, 255)
        new_visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
        new_visual_img[:, :, :] = new_dec
        for i in range(parallel_size):
            all_imgs_2.append(PIL.Image.fromarray(new_visual_img[i]))

    return all_imgs_1, all_imgs_2, (output_text_ids.cpu(), selfcheck.squeeze().cpu())


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", type=str, default="deepseek-ai/Janus-Pro-7B")
    parser.add_argument("--ckpt_path", type=str, default=None)
    parser.add_argument("--caption", type=str, default="a brown giraffe and a white stop sign")
    parser.add_argument("--gen_path", type=str, default="results/samples")
    parser.add_argument("--reason_path", type=str, default="results/reason.jsonl")
    parser.add_argument("--regen_path", type=str, default="results/regen_samples")
    parser.add_argument("--cfg", type=float, default=5.0)
    parser.add_argument("--parallel_size", type=int, default=4)

    args = parser.parse_args()
    vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(args.model_path)
    vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True)
    if args.ckpt_path is not None:
        state_dict = torch.load(f"{args.ckpt_path}", map_location="cpu")
        vl_gpt.load_state_dict(state_dict)

    vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()

    # You can flexibly modify the code here to perform batched inference.
    allprompts = []
    # prompt = f'<|User|>: {args.caption}\n\n<|Assistant|>:<begin_of_image>'
    conversation = [
        {
            "role": "<|User|>",
            "content": args.caption,
        },
        {"role": "<|Assistant|>", "content": ""},
    ]
    sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=conversation,
        sft_format=vl_chat_processor.sft_format,
        system_prompt="",
    )
    prompt = sft_format + vl_chat_processor.image_start_tag
    # Repeat the prompt so that `parallel_size` samples are generated for the same caption.
    allprompts.extend([prompt] * args.parallel_size)

    tokenized_input = vl_chat_processor.tokenizer(
        allprompts,
        return_tensors="pt",
        padding="longest",
        max_length=200, truncation=True
    ).to("cuda")

    prompt_ids = tokenized_input["input_ids"]
    prompt_mask = tokenized_input["attention_mask"]

    images, regen_images, (output_text_ids, selfcheck) = generate_with_refine(
        vl_gpt,
        vl_chat_processor,
        input_ids=prompt_ids, attention_mask=prompt_mask,
        parallel_size=args.parallel_size,
        cfg_weight=args.cfg,
    )
    os.makedirs(args.gen_path, exist_ok=True)
    # reason_path is a file, so create its parent directory rather than the path itself.
    os.makedirs(os.path.dirname(args.reason_path) or ".", exist_ok=True)
    os.makedirs(args.regen_path, exist_ok=True)

    # Save the first-pass generations.
    for i in range(args.parallel_size):
        img_name = str(i).zfill(4) + ".png"
        save_path = os.path.join(args.gen_path, img_name)
        images[i].save(save_path)

    # Save the self-check verdicts and reflection texts.
    with open(args.reason_path, "w") as f:
        for i in range(args.parallel_size):
            reason_data = {"prompt": args.caption}
            img_name = str(i).zfill(4)
            reason_data["filename"] = os.path.join(args.gen_path, f"{img_name}.png")
            reason_data["correct"] = bool(selfcheck[i])
            reason_data["reason"] = vl_chat_processor.tokenizer.decode(output_text_ids[i].cpu().tolist(), skip_special_tokens=True)
            reason_data = json.dumps(reason_data, ensure_ascii=False)
            f.write(reason_data + "\n")

    # Keep the original image when the self-check passed; otherwise keep the regenerated one.
    for i in range(args.parallel_size):
        img_name = str(i).zfill(4) + ".png"
        save_path = os.path.join(args.regen_path, img_name)
        if selfcheck[i]:
            images[i].save(save_path)
        else:
            regen_images[i].save(save_path)
```

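If the snippet above is saved as a standalone script (for example, `generate_with_refine.py` in the repository root; the filename is only an illustration), a typical run looks like the following. The `--ckpt_path` value is a placeholder for wherever you stored the Janus-Pro-R1-7B weights:

```bash
# Example invocation; the script name and checkpoint path are placeholders.
python generate_with_refine.py \
  --model_path deepseek-ai/Janus-Pro-7B \
  --ckpt_path path/to/janus-pro-r1-7b.pt \
  --caption "a brown giraffe and a white stop sign" \
  --parallel_size 4 \
  --cfg 5.0
```

With the default arguments, first-pass images go to `results/samples`, the self-check verdicts and reflection texts to `results/reason.jsonl`, and the final (possibly regenerated) images to `results/regen_samples`.
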
## 🤝 Acknowledgment

Our project is developed based on the following repositories:

- [Janus-Series](https://github.com/deepseek-ai/Janus): Unified Multimodal Understanding and Generation Models
- [Open-R1](https://github.com/huggingface/open-r1): Fully open reproduction of DeepSeek-R1

## 📜 Citation

If you find this work useful for your research, please cite our paper and star our GitHub repo:

```bibtex
@article{pan2025unlocking,
  title={Unlocking Aha Moments via Reinforcement Learning: Advancing Collaborative Visual Comprehension and Generation},
  author={Pan, Kaihang and Wu, Yang and Bu, Wendong and Shen, Kai and Li, Juncheng and Wang, Yingting and Li, Yunfei and Tang, Siliang and Xiao, Jun and Wu, Fei and others},
  journal={arXiv preprint arXiv:2506.01480},
  year={2025}
}
```