autoprogrammer committed
Commit e350bc1 · verified · 1 Parent(s): 8c62663

Update generation_utils.py

Files changed (1)
  1. generation_utils.py  +21 -94
generation_utils.py CHANGED
@@ -1,5 +1,5 @@
  # coding=utf-8
- # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team.
+ # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -85,8 +85,7 @@ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confid
     if neg_entropy:
         epsilon = 1e-10
         log_probs = torch.log(probs + epsilon)
-        # Note: this returns the opposite of the "negative entropy" (larger = more confident)
-        confidence = -(probs * log_probs).sum(dim=-1)
+        confidence = torch.sum(probs * log_probs, dim=-1)
 
     return confidence, x0
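A note on the sign convention in this hunk: the restored line computes negative entropy, which is largest (closest to zero) for peaked distributions, so the downstream `torch.topk(..., largest=True)` favors confident positions. The deleted variant computed positive entropy, which would invert that ranking despite its comment. A minimal sketch, assuming `probs` is the softmax of the masked-position logits:

```python
import torch

probs = torch.softmax(torch.randn(4, 100), dim=-1)  # hypothetical [num_masked, vocab] distributions
epsilon = 1e-10                                     # guards against log(0)
log_probs = torch.log(probs + epsilon)

neg_entropy = torch.sum(probs * log_probs, dim=-1)  # restored rule: higher (closer to 0) = more confident
pos_entropy = -(probs * log_probs).sum(dim=-1)      # deleted rule: higher = LESS confident
assert torch.allclose(neg_entropy, -pos_entropy)    # the two differ only by sign
```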
@@ -110,10 +109,6 @@ class DreamGenerationConfig(GenerationConfig):
         self.alg: str = kwargs.pop("alg", 'origin')
         self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
 
-        # RCR specific parameters
-        self.rcr: bool = kwargs.pop("rcr", False)
-        self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')
-
         # Parameters that define the output variables of `generate`
         self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
         self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
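With `rcr` and `conf_alg` removed from the config, the confidence rule is selected by `alg` alone. A hypothetical usage sketch (the import path is an assumption; adjust to wherever this module lives):

```python
# Hypothetical usage; assumes generation_utils.py is importable as a local module.
from generation_utils import DreamGenerationConfig

config = DreamGenerationConfig(
    alg="entropy",  # one of 'origin', 'maskgit_plus', 'topk_margin', 'entropy'
    alg_temp=0.2,   # optional: >0 samples transfer positions from softmax(confidence / alg_temp)
)
```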
@@ -169,58 +164,6 @@ class DreamGenerationMixin:
         attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
         return input_ids, attention_mask
 
-    def _apply_rcr_logic(self, x, x0, confidence, mask_index, overtime_confidence,
-                         mask_token_id, step, total_steps, s, t):
-        """
-        RCR: a minimally invasive change on top of Dream's original logic, made so it actually takes effect.
-        - Keeps Dream's schedule: this step's global k = num_mask_token * (1 - s/t)
-        - Clamps per sample, so the batch-mean k cannot exceed a sample's masked count
-        - Cumulative target constraint: by this step, a total of target_cum = num_mask_token * (1 - s/t)
-          tokens should have been committed. If the current total exceeds the target, remask the
-          lowest-confidence tokens back to [MASK].
-        """
-        device = x.device
-        B, L = x.shape
-
-        # Consistent with Dream: use the batch-mean num_mask_token and the (1 - s/t) schedule definition
-        num_mask_token = (mask_index.sum() / mask_index.shape[0]).item()
-        k_global = int(num_mask_token * (1 - (s / t).item())) if step < total_steps - 1 else int(num_mask_token)
-
-        # Build full-length confidences and a temporary candidate buffer
-        # (non-mask positions are set to -inf / mask_token respectively)
-        full_conf = torch.full_like(x, -torch.inf, device=device, dtype=confidence.dtype)
-        x_temp = torch.zeros_like(x, device=device, dtype=torch.long) + mask_token_id
-        full_conf[mask_index] = confidence
-        x_temp[mask_index] = x0.clone()
-
-        for j in range(B):
-            # Per-sample clamp
-            masked_count_j = int(mask_index[j].sum().item())
-            k_j = min(k_global, masked_count_j)
-            if k_j > 0:
-                # Top-k only within masked positions (non-mask positions have full_conf == -inf, never selected)
-                _, select_idx = torch.topk(full_conf[j], k_j, largest=True)
-                x[j, select_idx] = x_temp[j, select_idx]
-                overtime_confidence[j, select_idx] = full_conf[j, select_idx].clone().float()
-
-            # ===== Cumulative target constraint + remasking =====
-            if step < total_steps - 1:
-                # Dream's target for how many tokens should have been committed by this step
-                target_cum = int(num_mask_token * (1 - (s / t).item()))
-                # How many are committed so far (positions with overtime_confidence > 0 count as committed)
-                gen_mask = overtime_confidence[j] > 0
-                current_gen = int(gen_mask.sum().item())
-
-                # If over target, remask the lowest-confidence tokens so the running total ≈ the target
-                to_remask = max(0, current_gen - target_cum)
-                if to_remask > 0:
-                    gen_indices = torch.where(gen_mask)[0]
-                    if gen_indices.numel() > 0:
-                        gen_conf = overtime_confidence[j, gen_indices]
-                        to_remask = min(to_remask, int(gen_indices.numel()))
-                        _, local_low = torch.topk(gen_conf, k=to_remask, largest=False)
-                        low_global = gen_indices[local_low]
-                        x[j, low_global] = mask_token_id
-                        overtime_confidence[j, low_global] = 0.0
-
     def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
         """Performs validation related to the resulting generated length"""
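For reference, the per-sample bookkeeping the removed `_apply_rcr_logic` performed reduces to one comparison: Dream's schedule says `target_cum = num_mask_token * (1 - s/t)` tokens should be committed by the current step, and any excess is remasked, lowest confidence first. A minimal sketch of that remask count (names are illustrative, not part of the remaining file):

```python
import torch

def rcr_excess(overtime_confidence_row: torch.Tensor, num_mask_token: float,
               s: float, t: float) -> int:
    """How many committed tokens exceed Dream's cumulative schedule for one sample.

    Positions with overtime_confidence > 0 count as committed; the removed hook
    remasked this many of them back to [MASK], lowest confidence first.
    """
    target_cum = int(num_mask_token * (1 - s / t))
    current_gen = int((overtime_confidence_row > 0).sum().item())
    return max(0, current_gen - target_cum)
```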
 
@@ -439,10 +382,6 @@ class DreamGenerationMixin:
         top_p = generation_config.top_p
         top_k = generation_config.top_k
 
-        # RCR specific values
-        rcr = generation_config.rcr
-        conf_alg = generation_config.conf_alg
-
         histories = [] if (return_dict_in_generate and output_history) else None
 
         # pad input_ids to max_length
@@ -465,9 +404,6 @@ class DreamGenerationMixin:
 
         timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
 
-        # RCR tracking - initialize overtime confidence tracking
-        overtime_confidence = torch.zeros_like(x, dtype=torch.float32) if rcr else None
-
         # this allows user-defined token control of the intermediate steps
         x = generation_tokens_hook_func(None, x, None)
         for i in range(steps):
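The `timesteps` schedule above supplies the `(s, t)` pair consumed by the transfer-count formula in the next hunk; presumably `t = timesteps[i]` and `s = timesteps[i + 1]` inside the loop (that assignment is not shown in this diff). A tiny sketch of the resulting per-step commit fractions under that assumption:

```python
import torch

steps, eps = 4, 1e-3
timesteps = torch.linspace(1, eps, steps + 1)  # linear schedule from 1 down to eps
for i in range(steps):
    t, s = timesteps[i], timesteps[i + 1]      # assumed pairing, not shown in the diff
    print(f"step {i}: commit fraction 1 - s/t = {(1 - s / t).item():.3f}")
```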
@@ -489,38 +425,29 @@ class DreamGenerationMixin:
                 _, x0[transfer_index_t_s] = sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
                 x[mask_index] = x0.clone()
             else:
-                if alg == 'maskgit_plus' or (rcr and conf_alg == 'maskgit_plus'):
+                if alg == 'maskgit_plus':
                     confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
-                elif alg == 'topk_margin' or (rcr and conf_alg == 'topk_margin'):
+                elif alg == 'topk_margin':
                     confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
-                elif alg == 'entropy' or (rcr and conf_alg == 'entropy'):
+                elif alg == 'entropy':
                     confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
                 else:
                     raise RuntimeError(f"Unknown alg: {alg}")
-
-                # Apply RCR logic if enabled
-                if rcr:
-                    print(f"[RCR EXEC] Step {i}: RCR logic executed")
-                    self._apply_rcr_logic(x, x0, confidence, mask_index, overtime_confidence,
-                                          mask_token_id, i, steps, s, t)
-                else:
-                    # Original Dream sampling logic
-                    num_mask_token = mask_index.sum() / mask_index.shape[0]
-                    number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
-                    # --------- Only small fix here: use x.device to avoid crossing devices ----------
-                    full_confidence = torch.full_like(x, -torch.inf, device=x.device, dtype=logits.dtype)
-                    full_confidence[mask_index] = confidence
-                    if number_transfer_tokens > 0:
-                        if alg_temp is None or alg_temp == 0:
-                            _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
-                        else:
-                            full_confidence = full_confidence / alg_temp
-                            full_confidence = F.softmax(full_confidence, dim=-1)
-                            transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
-                        x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
-                        x_[mask_index] = x0.clone()
-                        row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
-                        x[row_indices, transfer_index] = x_[row_indices, transfer_index]
+                num_mask_token = mask_index.sum() / mask_index.shape[0]
+                number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
+                full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
+                full_confidence[mask_index] = confidence
+                if number_transfer_tokens > 0:
+                    if alg_temp is None or alg_temp == 0:
+                        _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
+                    else:
+                        full_confidence = full_confidence / alg_temp
+                        full_confidence = F.softmax(full_confidence, dim=-1)
+                        transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
+                    x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
+                    x_[mask_index] = x0.clone()
+                    row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
+                    x[row_indices, transfer_index] = x_[row_indices, transfer_index]
 
             # this allows user-defined token control of the intermediate steps
             x = generation_tokens_hook_func(i, x, logits)
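The restored block is Dream's confidence-ordered unmasking: derive this step's commit count from the `(1 - s/t)` schedule, then take the top-k most confident masked positions (or, when `alg_temp` is set, sample positions from a softmax over confidences). A self-contained sketch of one greedy step with hypothetical shapes:

```python
import torch

# One greedy transfer step (alg_temp unset), with hypothetical shapes.
B, L, mask_token_id = 2, 16, 0
x = torch.full((B, L), mask_token_id)                    # current sequences, fully masked here
mask_index = x == mask_token_id                          # [B, L] bool mask of undecided positions
n_masked = int(mask_index.sum())
confidence = torch.rand(n_masked)                        # stand-in for sample_tokens() confidences
x0 = torch.randint(1, 100, (n_masked,))                  # stand-in for sampled candidate tokens

s, t, i, steps = 0.5, 1.0, 0, 8                          # hypothetical schedule values
num_mask_token = mask_index.sum() / mask_index.shape[0]  # batch-mean masked count, as in the file
number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)

full_confidence = torch.full((B, L), -torch.inf)         # non-mask positions never win top-k
full_confidence[mask_index] = confidence
if number_transfer_tokens > 0:
    _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)  # [B, k] per-row top-k
    x_ = torch.full_like(x, mask_token_id)
    x_[mask_index] = x0
    rows = torch.arange(B).unsqueeze(1).expand_as(transfer_index)
    x[rows, transfer_index] = x_[rows, transfer_index]   # commit the most confident candidates
```

Note that `number_transfer_tokens` uses the batch-mean masked count, so it can differ from an individual sample's masked count; the removed RCR path clamped per sample instead.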
@@ -534,4 +461,4 @@
                 history=histories,
             )
         else:
-            return x
+            return x
 