Add pipeline tag, library name, and additional tags
#1 opened by nielsr (HF Staff)

README.md CHANGED
```diff
@@ -1,6 +1,15 @@
 ---
 license: apache-2.0
+pipeline_tag: any-to-any
+library_name: transformers
+tags:
+- multimodal
+- text-to-image
+- image-to-image
+- image-to-text
+- vqa
 ---
+
 <h2 align="center" style="line-height: 25px;">
 Unlocking Aha Moments via Reinforcement Learning: Advancing Collaborative Visual Comprehension and Generation
 </h2>
```
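The new front matter is what drives the Hub's task filter and widget. As a quick post-merge sanity check, the metadata can be read back through the `huggingface_hub` API; this is a minimal sketch, and `org/model` is a placeholder repo id, not this model's actual id:

```python
# Hypothetical sanity check: read back the metadata added in this PR.
# "org/model" is a placeholder repo id -- substitute the real one.
from huggingface_hub import HfApi

info = HfApi().model_info("org/model")
print(info.pipeline_tag)   # expected: "any-to-any"
print(info.library_name)   # expected: "transformers"
print(info.tags)           # should include "multimodal", "vqa", ...
```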
```diff
@@ -128,8 +137,12 @@ def generate_with_refine(
     task_list: List[int] = [1,2,3],
 ):
     prompt = [
-        '
-
+        '
+
+
+Let me think Does this image match the prompt...',
+        '<|end▁of▁sentence|>
+Next, I will draw a new image<begin_of_image>'
     ]
     all_imgs_1,embeds_1,attention_mask_1 = [],[],[]
     output_text_ids,selfcheck,attention_mask_txt = [],[],[]
```
```diff
@@ -209,8 +222,8 @@ def generate_with_refine(
     reflect_len = 0
     eos_list = torch.zeros((parallel_size, 1), dtype=torch.int).cuda()
     add_padding = torch.zeros((parallel_size, 1), dtype=torch.int).cuda()
-    eos_token = vl_chat_processor.tokenizer.encode("<|end▁of▁sentence|>")[-1]
-    padding_token = vl_chat_processor.tokenizer.encode("<|▁pad▁|>")[-1]
+    eos_token = vl_chat_processor.tokenizer.encode("<|end▁of▁sentence|>")[-1]
+    padding_token = vl_chat_processor.tokenizer.encode("<|▁pad▁|>")[-1]
     yes_token = vl_chat_processor.tokenizer.encode("Yes")[-1]
     no_token = vl_chat_processor.tokenizer.encode("No")[-1]
     attn_mask = torch.ones((parallel_size, inputs_embeds.shape[1]), dtype=torch.int).cuda()
```
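The `encode(...)[-1]` idiom in this hunk grabs a single token id: the tokenizer may prepend special tokens such as BOS, so the last id is the one for the string itself ("Yes"/"No" are later used for the self-check). A minimal sketch of the pattern, with a stand-in tokenizer since the model's own processor is assumed rather than shown here:

```python
# Minimal sketch of the encode(...)[-1] idiom, assuming an HF-style
# tokenizer (stand-in for vl_chat_processor.tokenizer).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer

# encode() may prepend special tokens (e.g. BOS), so [-1] keeps the id
# of the last -- here the only -- token of the string.
yes_token = tokenizer.encode("Yes")[-1]
no_token = tokenizer.encode("No")[-1]

# The idiom is only safe when the string maps to exactly one token:
assert len(tokenizer.encode("Yes", add_special_tokens=False)) == 1
```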
```diff
@@ -234,7 +247,7 @@ def generate_with_refine(
         dim=-1,
         descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
-    mask = probs_sum - probs_sort > img_top_p
+    mask = probs_sum - probs_sort > img_top_p
     probs_sort[mask] = 0.0
     probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
     next_token = torch.multinomial(probs_sort, num_samples=1)
```
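The restored `mask` line is standard nucleus (top-p) filtering: after sorting the probabilities, a token is zeroed out once the cumulative mass of the tokens ranked above it already exceeds `img_top_p`. A self-contained sketch of the same step (`sample_top_p` is an illustrative helper, not a function from this repository):

```python
import torch

def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    """Sample one id per row from the smallest set of tokens whose
    cumulative probability exceeds top_p (nucleus sampling)."""
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # probs_sum - probs_sort is the exclusive cumulative sum: drop a
    # token when the mass ranked strictly above it already exceeds top_p.
    mask = probs_sum - probs_sort > top_p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))  # renormalize
    next_sorted = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, -1, next_sorted)  # back to vocab ids

probs = torch.softmax(torch.randn(2, 16), dim=-1)
print(sample_top_p(probs, top_p=0.9))  # one sampled id per row
```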
```diff
@@ -384,7 +397,9 @@ if __name__ == "__main__":
 
     # You can flexibly modify the code here to perform batched inference.
     allprompts = []
-    # prompt = f'<|User|>: {args.caption}\n\n<|Assistant|>:<begin_of_image>'
+    # prompt = f'<|User|>: {args.caption}
+
+<|Assistant|>:<begin_of_image>'
     conversation = [
         {
             "role": "<|User|>",
```
```diff
@@ -434,7 +449,8 @@ if __name__ == "__main__":
     reason_data["correct"] = bool(selfcheck[i])
     reason_data["reason"] = vl_chat_processor.tokenizer.decode(output_text_ids[i].cpu().tolist(), skip_special_tokens=True)
     reason_data = json.dumps(reason_data, ensure_ascii=False)
-    f.write(reason_data+'\n')
+    f.write(reason_data+'
+')
 
 
     for i in range(args.parallel_size):
```
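The restored `f.write(reason_data+'\n')` writes one JSON object per line (JSON Lines), so the self-check results can be re-read with a plain line-by-line loop. A small round-trip sketch with illustrative data:

```python
# Round-trip sketch of the JSON-Lines format used for the self-check log.
import json

records = [{"correct": True, "reason": "example reasoning"}]  # illustrative data
with open("reasons.jsonl", "w", encoding="utf-8") as f:
    for rec in records:
        f.write(json.dumps(rec, ensure_ascii=False) + "\n")

with open("reasons.jsonl", encoding="utf-8") as f:
    loaded = [json.loads(line) for line in f]
```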