Update generation_utils.py
Browse files
- generation_utils.py +21 -94
generation_utils.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
# coding=utf-8
|
| 2 |
-
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team.
|
| 3 |
#
|
| 4 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
# you may not use this file except in compliance with the License.
|
|
@@ -85,8 +85,7 @@ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confid
|
|
| 85 |
if neg_entropy:
|
| 86 |
epsilon = 1e-10
|
| 87 |
log_probs = torch.log(probs + epsilon)
|
| 88 |
-
|
| 89 |
-
confidence = -(probs * log_probs).sum(dim=-1)
|
| 90 |
|
| 91 |
return confidence, x0
|
| 92 |
|
|
@@ -110,10 +109,6 @@ class DreamGenerationConfig(GenerationConfig):
|
|
| 110 |
self.alg: str = kwargs.pop("alg", 'origin')
|
| 111 |
self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
|
| 112 |
|
| 113 |
-
# RCR specific parameters
|
| 114 |
-
self.rcr: bool = kwargs.pop("rcr", False)
|
| 115 |
-
self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')
|
| 116 |
-
|
| 117 |
# Parameters that define the output variables of `generate`
|
| 118 |
self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
|
| 119 |
self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
|
|
@@ -169,58 +164,6 @@ class DreamGenerationMixin:
|
|
| 169 |
attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
|
| 170 |
return input_ids, attention_mask
|
| 171 |
|
| 172 |
-
def _apply_rcr_logic(self, x, x0, confidence, mask_index, overtime_confidence,
|
| 173 |
-
mask_token_id, step, total_steps, s, t):
|
| 174 |
-
"""
|
| 175 |
-
RCR:在 Dream 原逻辑上做“最小侵入”改动,使其真正生效。
|
| 176 |
-
- 仍采用 Dream 的调度:本步 global k = num_mask_token * (1 - s/t)
|
| 177 |
-
- 逐样本 clamp,避免批均值 k 在样本上越界
|
| 178 |
-
- 目标累计约束:到本步为止累计应已生成 target_cum = num_mask_token * (1 - s/t)。
|
| 179 |
-
若当前累计 > 目标,按最低置信度反遮盖回 [MASK]。
|
| 180 |
-
"""
|
| 181 |
-
device = x.device
|
| 182 |
-
B, L = x.shape
|
| 183 |
-
|
| 184 |
-
# 与 Dream 保持一致:使用“批均值”的 num_mask_token 与 (1 - s/t) 调度定义
|
| 185 |
-
num_mask_token = (mask_index.sum() / mask_index.shape[0]).item()
|
| 186 |
-
k_global = int(num_mask_token * (1 - (s / t).item())) if step < total_steps - 1 else int(num_mask_token)
|
| 187 |
-
|
| 188 |
-
# 构造全长置信度和临时候选(非 mask 位置分别置为 -inf / mask_token)
|
| 189 |
-
full_conf = torch.full_like(x, -torch.inf, device=device, dtype=confidence.dtype)
|
| 190 |
-
x_temp = torch.zeros_like(x, device=device, dtype=torch.long) + mask_token_id
|
| 191 |
-
full_conf[mask_index] = confidence
|
| 192 |
-
x_temp[mask_index] = x0.clone()
|
| 193 |
-
|
| 194 |
-
for j in range(B):
|
| 195 |
-
# 逐样本 clamp
|
| 196 |
-
masked_count_j = int(mask_index[j].sum().item())
|
| 197 |
-
k_j = min(k_global, masked_count_j)
|
| 198 |
-
if k_j > 0:
|
| 199 |
-
# 只在 mask 内选 topk(非 mask 位置 full_conf 为 -inf,不会被选中)
|
| 200 |
-
_, select_idx = torch.topk(full_conf[j], k_j, largest=True)
|
| 201 |
-
x[j, select_idx] = x_temp[j, select_idx]
|
| 202 |
-
overtime_confidence[j, select_idx] = full_conf[j, select_idx].clone().float()
|
| 203 |
-
|
| 204 |
-
# ===== 目标累计约束 + 反遮盖 =====
|
| 205 |
-
if step < total_steps - 1:
|
| 206 |
-
# Dream 的“到本步为止累计应已生成”的目标数量
|
| 207 |
-
target_cum = int(num_mask_token * (1 - (s / t).item()))
|
| 208 |
-
# 当前已生成的数量(overtime_confidence>0 的位置视为已确定)
|
| 209 |
-
gen_mask = overtime_confidence[j] > 0
|
| 210 |
-
current_gen = int(gen_mask.sum().item())
|
| 211 |
-
|
| 212 |
-
# 若超过目标,反遮盖(remask)最低置信度的那部分,使当前累计 ≈ 目标累计
|
| 213 |
-
to_remask = max(0, current_gen - target_cum)
|
| 214 |
-
if to_remask > 0:
|
| 215 |
-
gen_indices = torch.where(gen_mask)[0]
|
| 216 |
-
if gen_indices.numel() > 0:
|
| 217 |
-
gen_conf = overtime_confidence[j, gen_indices]
|
| 218 |
-
to_remask = min(to_remask, int(gen_indices.numel()))
|
| 219 |
-
_, local_low = torch.topk(gen_conf, k=to_remask, largest=False)
|
| 220 |
-
low_global = gen_indices[local_low]
|
| 221 |
-
x[j, low_global] = mask_token_id
|
| 222 |
-
overtime_confidence[j, low_global] = 0.0
|
| 223 |
-
|
| 224 |
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
|
| 225 |
"""Performs validation related to the resulting generated length"""
|
| 226 |
|
|
@@ -439,10 +382,6 @@ class DreamGenerationMixin:
|
|
| 439 |
top_p = generation_config.top_p
|
| 440 |
top_k = generation_config.top_k
|
| 441 |
|
| 442 |
-
# RCR specific values
|
| 443 |
-
rcr = generation_config.rcr
|
| 444 |
-
conf_alg = generation_config.conf_alg
|
| 445 |
-
|
| 446 |
histories = [] if (return_dict_in_generate and output_history) else None
|
| 447 |
|
| 448 |
# pad input_ids to max_length
|
|
@@ -465,9 +404,6 @@ class DreamGenerationMixin:
|
|
| 465 |
|
| 466 |
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
|
| 467 |
|
| 468 |
-
# RCR tracking - initialize overtime confidence tracking
|
| 469 |
-
overtime_confidence = torch.zeros_like(x, dtype=torch.float32) if rcr else None
|
| 470 |
-
|
| 471 |
# this allows user-defined token control of the intermediate steps
|
| 472 |
x = generation_tokens_hook_func(None, x, None)
|
| 473 |
for i in range(steps):
|
|
@@ -489,38 +425,29 @@ class DreamGenerationMixin:
|
|
| 489 |
_, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
|
| 490 |
x[mask_index] = x0.clone()
|
| 491 |
else:
|
| 492 |
-
if alg == 'maskgit_plus'
|
| 493 |
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
|
| 494 |
-
elif alg == 'topk_margin'
|
| 495 |
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
|
| 496 |
-
elif alg == 'entropy'
|
| 497 |
confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
|
| 498 |
else:
|
| 499 |
raise RuntimeError(f"Unknown alg: {alg}")
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
_, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
|
| 516 |
-
else:
|
| 517 |
-
full_confidence = full_confidence / alg_temp
|
| 518 |
-
full_confidence = F.softmax(full_confidence, dim=-1)
|
| 519 |
-
transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
|
| 520 |
-
x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
|
| 521 |
-
x_[mask_index] = x0.clone()
|
| 522 |
-
row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
|
| 523 |
-
x[row_indices,transfer_index] = x_[row_indices,transfer_index]
|
| 524 |
|
| 525 |
# this allows user-defined token control of the intermediate steps
|
| 526 |
x = generation_tokens_hook_func(i, x, logits)
|
|
@@ -534,4 +461,4 @@ class DreamGenerationMixin:
|
|
| 534 |
history=histories,
|
| 535 |
)
|
| 536 |
else:
|
| 537 |
-
return x
|
|
|
|
| 1 |
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
#
|
| 4 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
# you may not use this file except in compliance with the License.
|
|
|
|
| 85 |
if neg_entropy:
|
| 86 |
epsilon = 1e-10
|
| 87 |
log_probs = torch.log(probs + epsilon)
|
| 88 |
+
confidence = torch.sum(probs * log_probs, dim=-1)
|
|
|
|
| 89 |
|
| 90 |
return confidence, x0
|
| 91 |
|
|
|
|
| 109 |
self.alg: str = kwargs.pop("alg", 'origin')
|
| 110 |
self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
# Parameters that define the output variables of `generate`
|
| 113 |
self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
|
| 114 |
self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
|
|
|
|
| 164 |
attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
|
| 165 |
return input_ids, attention_mask
|
| 166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
|
| 168 |
"""Performs validation related to the resulting generated length"""
|
| 169 |
|
|
|
|
| 382 |
top_p = generation_config.top_p
|
| 383 |
top_k = generation_config.top_k
|
| 384 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
histories = [] if (return_dict_in_generate and output_history) else None
|
| 386 |
|
| 387 |
# pad input_ids to max_length
|
|
|
|
| 404 |
|
| 405 |
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
|
| 406 |
|
|
|
|
|
|
|
|
|
|
| 407 |
# this allows user-defined token control of the intermediate steps
|
| 408 |
x = generation_tokens_hook_func(None, x, None)
|
| 409 |
for i in range(steps):
|
|
|
|
| 425 |
_, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
|
| 426 |
x[mask_index] = x0.clone()
|
| 427 |
else:
|
| 428 |
+
if alg == 'maskgit_plus':
|
| 429 |
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
|
| 430 |
+
elif alg == 'topk_margin':
|
| 431 |
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
|
| 432 |
+
elif alg == 'entropy':
|
| 433 |
confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
|
| 434 |
else:
|
| 435 |
raise RuntimeError(f"Unknown alg: {alg}")
|
| 436 |
+
num_mask_token = mask_index.sum() / mask_index.shape[0]
|
| 437 |
+
number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
|
| 438 |
+
full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
|
| 439 |
+
full_confidence[mask_index] = confidence
|
| 440 |
+
if number_transfer_tokens > 0:
|
| 441 |
+
if alg_temp is None or alg_temp == 0:
|
| 442 |
+
_, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
|
| 443 |
+
else:
|
| 444 |
+
full_confidence = full_confidence / alg_temp
|
| 445 |
+
full_confidence = F.softmax(full_confidence, dim=-1)
|
| 446 |
+
transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
|
| 447 |
+
x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
|
| 448 |
+
x_[mask_index] = x0.clone()
|
| 449 |
+
row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
|
| 450 |
+
x[row_indices,transfer_index] = x_[row_indices,transfer_index]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
|
| 452 |
# this allows user-defined token control of the intermediate steps
|
| 453 |
x = generation_tokens_hook_func(i, x, logits)
|
|
|
|
| 461 |
history=histories,
|
| 462 |
)
|
| 463 |
else:
|
| 464 |
+
return x
|