autoprogrammer committed on
Commit
3d20c97
verified · 1 Parent(s): 34c9b0b

Update generation_utils.py

Files changed (1)
  1. generation_utils.py +177 -144
generation_utils.py CHANGED
@@ -1,6 +1,17 @@
1
  # coding=utf-8
2
- # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace
3
- # Licensed under the Apache License, Version 2.0
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  import warnings
6
  import copy
@@ -23,32 +34,30 @@ def top_p_logits(logits, top_p=None):
23
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
24
  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
25
  sorted_indices_to_remove = cumulative_probs > top_p
26
- # keep first token above threshold
27
  sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
28
  sorted_indices_to_remove[..., 0] = 0
 
29
  mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
30
  mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
31
- return logits.masked_fill(mask, torch.finfo(logits.dtype).min)
 
32
 
33
 
34
  def top_k_logits(logits, top_k=None):
35
  if top_k is None:
36
  return logits
37
- top_k = min(top_k, logits.size(-1))
38
- thresh = torch.topk(logits, top_k)[0][..., -1, None]
39
- indices_to_remove = logits < thresh
40
- return logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
41
-
42
-
43
- def sample_tokens(
44
- logits,
45
- temperature: float = 0.0,
46
- top_p: Optional[float] = None,
47
- top_k: Optional[int] = None,
48
- margin_confidence: bool = False,
49
- neg_entropy: bool = False,
50
- ):
51
- # keep dtype consistent with logits (covers bf16/fp16)
52
  if temperature and temperature > 0:
53
  logits = logits / temperature
54
  if top_p is not None and top_p < 1:
@@ -58,27 +67,27 @@ def sample_tokens(
58
 
59
  probs = torch.softmax(logits, dim=-1)
60
 
 
61
  if temperature and temperature > 0:
62
- # sampling
63
  try:
64
  x0 = dists.Categorical(probs=probs).sample()
65
  confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
66
  except Exception:
67
  confidence, x0 = probs.max(dim=-1)
68
  else:
69
- # greedy
70
  confidence, x0 = probs.max(dim=-1)
71
 
 
72
  if margin_confidence:
73
  sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
74
- top1_probs = sorted_probs[..., 0]
75
- top2_probs = sorted_probs[..., 1]
76
  confidence = top1_probs - top2_probs
77
 
78
  if neg_entropy:
79
- eps = probs.new_tensor(1e-10)
80
- log_probs = torch.log(probs + eps)
81
- # negative entropy (sums are negative); numerically larger (smaller magnitude) means more certain; used directly for ranking here
82
  confidence = torch.sum(probs * log_probs, dim=-1)
83
 
84
  return confidence, x0
@@ -92,23 +101,27 @@ class DreamModelOutput(ModelOutput):
92
 
93
  class DreamGenerationConfig(GenerationConfig):
94
  def __init__(self, **kwargs):
 
95
  self.temperature: float = kwargs.pop("temperature", 0.0)
96
  self.top_p: Optional[float] = kwargs.pop("top_p", None)
97
  self.top_k: Optional[int] = kwargs.pop("top_k", None)
 
 
98
  self.max_length = kwargs.pop("max_length", 20)
99
  self.max_new_tokens = kwargs.pop("max_new_tokens", None)
100
 
101
  # diffusion specific params
102
  self.eps: float = kwargs.pop("eps", 1e-3)
103
  self.steps: int = kwargs.pop("steps", 512)
104
- self.alg: str = kwargs.pop("alg", "origin")
105
  self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
106
 
107
- # RCR parameters (disabled by default)
108
  self.rcr: bool = kwargs.pop("rcr", False)
109
- self.conf_alg: str = kwargs.pop("conf_alg", "maskgit_plus")
 
110
 
111
- # generate output controls
112
  self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
113
  self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
114
  self.output_history: bool = kwargs.pop("output_history", False)
@@ -119,9 +132,10 @@ class DreamGenerationConfig(GenerationConfig):
119
  self.bos_token_id = kwargs.pop("bos_token_id", None)
120
  self.eos_token_id = kwargs.pop("eos_token_id", None)
121
 
 
122
  self.generation_kwargs = kwargs.pop("generation_kwargs", {})
123
 
124
- # hub meta
125
  self._from_model_config = kwargs.pop("_from_model_config", False)
126
  self._commit_hash = kwargs.pop("_commit_hash", None)
127
  self.transformers_version = kwargs.pop("transformers_version", __version__)
@@ -137,7 +151,6 @@ class DreamGenerationConfig(GenerationConfig):
137
  self.validate(is_init=True)
138
 
139
  def validate(self, is_init=False):
140
- # keep an empty implementation for upstream compatibility
141
  pass
142
 
143
 
@@ -146,7 +159,7 @@ class DreamGenerationMixin:
146
  def _expand_inputs_for_generation(
147
  expand_size: int = 1,
148
  input_ids: Optional[torch.LongTensor] = None,
149
- attention_mask: Optional[torch.LongTensor] = None,
150
  ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
151
  if expand_size == 1:
152
  return input_ids, attention_mask
@@ -156,13 +169,16 @@ class DreamGenerationMixin:
156
  attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
157
  return input_ids, attention_mask
158
 
 
159
  def _apply_rcr_logic(
160
  self,
161
- x: torch.LongTensor,
162
- x0_sel: torch.LongTensor,
163
- conf_sel: torch.Tensor,
164
  mask_index: torch.Tensor,
165
- overtime_confidence: torch.Tensor,
 
 
166
  mask_token_id: int,
167
  step: int,
168
  total_steps: int,
@@ -170,56 +186,68 @@ class DreamGenerationMixin:
170
  t: torch.Tensor,
171
  ):
172
  """
173
- Running Confidence Remasking (RCR)
174
- - The original Dream schedule computes how many tokens should transfer at each step;
175
- - first, the highest-confidence positions of this step are moved from [MASK] to their predictions;
176
- - then, against the cumulative target up to this step, the lowest-confidence surplus is remasked back to [MASK].
177
- Only called when rcr=True.
 
 
 
 
178
  """
179
  device = x.device
180
- dtype = overtime_confidence.dtype # == logits.dtype
181
- B = x.shape[0]
 
 
 
 
182
 
183
- # average number of remaining masked tokens in the current batch
184
- num_mask_token = mask_index.sum() / mask_index.shape[0]
185
- # number of tokens to transfer this step (matches the Dream schedule)
186
- number_transfer_tokens = int(num_mask_token * (1 - s / t)) if step < total_steps - 1 else int(num_mask_token)
187
 
188
- # build full-length confidence and candidate-token buffers (non-mask positions set to -inf / mask_token_id respectively)
189
- full_conf = torch.full(x.shape, float("-inf"), device=device, dtype=dtype)
190
- x_temp = torch.full_like(x, fill_value=mask_token_id, dtype=torch.long, device=device)
191
- full_conf[mask_index] = conf_sel
192
- x_temp[mask_index] = x0_sel
193
 
 
194
  for j in range(B):
195
  masked_j = int(mask_index[j].sum().item())
196
- if masked_j == 0:
197
- continue
198
  k_j = min(number_transfer_tokens, masked_j)
199
-
200
  if k_j > 0:
201
- # select this step's top-k_j positions
202
- _, select_idx = torch.topk(full_conf[j], k=k_j, largest=True)
203
- x[j, select_idx] = x_temp[j, select_idx]
204
- # record the confidence at these positions for the cumulative count and the remasking decision
205
- overtime_confidence[j, select_idx] = full_conf[j, select_idx]
206
-
207
- # cumulative target (aligned with the original Dream linear progress)
 
 
 
 
208
  if step < total_steps - 1:
209
- target_cum = int(num_mask_token * (1 - s / t)) # cumulative target up to the current step
210
- gen_mask = overtime_confidence[j] > overtime_confidence.new_tensor(0)
211
- current_gen = int(gen_mask.sum().item())
212
- overflow = max(0, current_gen - target_cum)
213
- if overflow > 0:
214
- gen_indices = torch.where(gen_mask)[0]
215
- if gen_indices.numel() > 0:
216
- gen_conf = overtime_confidence[j, gen_indices]
217
- overflow = min(overflow, int(gen_indices.numel()))
218
- # remask the `overflow` lowest-confidence positions
219
- _, low_local = torch.topk(gen_conf, k=overflow, largest=False)
220
- low_global = gen_indices[low_local]
221
- x[j, low_global] = mask_token_id
222
- overtime_confidence[j, low_global] = overtime_confidence.new_zeros(low_global.shape)
 
 
 
 
223
 
224
  def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
225
  if is_torchdynamo_compiling():
@@ -232,9 +260,11 @@ class DreamGenerationMixin:
232
  UserWarning,
233
  )
234
  if input_ids_length >= generation_config.max_length:
 
235
  raise ValueError(
236
- f"Input length is {input_ids_length}, but `max_length` is {generation_config.max_length}. "
237
- "Consider increasing `max_length` or setting `max_new_tokens`."
 
238
  )
239
 
240
  def _prepare_generated_length(self, generation_config, has_default_max_length, input_ids_length):
@@ -242,20 +272,20 @@ class DreamGenerationMixin:
242
  if not has_default_max_length and generation_config.max_length is not None:
243
  logger.warning(
244
  f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
245
- f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence."
 
246
  )
247
  generation_config.max_length = generation_config.max_new_tokens + input_ids_length
 
248
  elif has_default_max_length:
249
  if generation_config.max_length == DreamGenerationConfig().max_length:
250
  generation_config.max_length = generation_config.max_length + input_ids_length
251
- mpe = getattr(self.config, "max_position_embeddings", None)
252
- if mpe is not None:
253
- generation_config.max_length = min(generation_config.max_length, mpe)
254
  return generation_config
255
 
256
- def _prepare_generation_config(
257
- self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict
258
- ) -> DreamGenerationConfig:
259
  using_model_generation_config = False
260
  if generation_config is None:
261
  generation_config = DreamGenerationConfig.from_model_config(self.config)
@@ -273,11 +303,10 @@ class DreamGenerationMixin:
273
  generation_config.pad_token_id = self.generation_config.pad_token_id
274
  if generation_config.mask_token_id is None:
275
  generation_config.mask_token_id = self.generation_config.mask_token_id
 
276
  return generation_config
277
 
278
- def _prepare_special_tokens(
279
- self, generation_config: DreamGenerationConfig, device: Optional[Union[torch.device, str]] = None
280
- ):
281
  def _tensor_or_none(token, device=None):
282
  if token is None:
283
  return token
@@ -332,18 +361,22 @@ class DreamGenerationMixin:
332
 
333
  if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
334
  warnings.warn(
335
- "You are calling .generate() with `input_ids` on a device type different than your model's device. "
336
- f"`input_ids` is on {input_ids.device.type}, model is on {self.device.type}.",
 
 
 
 
337
  UserWarning,
338
  )
339
-
340
  if (
341
  hasattr(generation_config, "pad_token_id")
342
  and torch.any(input_ids == generation_config.pad_token_id)
343
  and attention_mask is None
344
  ):
345
  warnings.warn(
346
- "Padding was detected but no attention mask is passed. For correct results, set `attention_mask` when batch-padding inputs.",
 
347
  UserWarning,
348
  )
349
 
@@ -368,9 +401,9 @@ class DreamGenerationMixin:
368
  attention_mask: Optional[torch.LongTensor],
369
  generation_config: DreamGenerationConfig,
370
  generation_tokens_hook_func,
371
- generation_logits_hook_func,
372
  ) -> Union[DreamModelOutput, torch.LongTensor]:
373
- # original variables
374
  output_history = generation_config.output_history
375
  return_dict_in_generate = generation_config.return_dict_in_generate
376
  max_length = generation_config.max_length
@@ -383,13 +416,13 @@ class DreamGenerationMixin:
383
  top_p = generation_config.top_p
384
  top_k = generation_config.top_k
385
 
386
- # RCR controls
387
  rcr = generation_config.rcr
388
  conf_alg = generation_config.conf_alg
389
 
390
  histories = [] if (return_dict_in_generate and output_history) else None
391
 
392
- # pad max_length
393
  x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
394
 
395
  if attention_mask is not None and torch.any(attention_mask == 0.0):
@@ -406,104 +439,104 @@ class DreamGenerationMixin:
406
 
407
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
408
 
409
- # running-confidence buffer; initialization is deferred until logits.dtype is known, to avoid dtype errors
410
- overtime_confidence = None # dtype = logits.dtype (set at initialization)
 
 
 
 
 
 
 
 
 
411
 
412
- # allow the user to control intermediate tokens
413
  x = generation_tokens_hook_func(None, x, None)
414
 
415
  for i in range(steps):
416
  mask_index = (x == mask_token_id)
417
-
418
  logits = self(x, attention_mask, tok_idx).logits
 
419
  logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
420
 
421
- # allow the user to control intermediate logits
422
  logits = generation_logits_hook_func(i, x, logits)
423
 
424
  mask_logits = logits[mask_index]
425
  t = timesteps[i]
426
  s = timesteps[i + 1]
427
 
428
- # initialize overtime_confidence from logits.dtype on first use (avoids Float/BFloat16 conflicts)
429
- if rcr and overtime_confidence is None:
430
- overtime_confidence = torch.zeros_like(x, dtype=logits.dtype, device=x.device)
431
-
432
- if alg == "origin":
433
- # original Dream logic (unchanged)
434
  p_transfer = 1 - s / t if i < steps - 1 else 1
435
- x0 = torch.full_like(x[mask_index], fill_value=mask_token_id, dtype=torch.long, device=self.device)
436
  transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
437
  _, x0[transfer_index_t_s] = sample_tokens(
438
  mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k
439
  )
440
  x[mask_index] = x0.clone()
441
-
442
  else:
443
- # choose the confidence algorithm
444
- use_alg = alg
445
- if rcr:
446
- # when rcr=True the confidence algorithm is chosen by conf_alg (does not affect the baseline)
447
- use_alg = conf_alg
448
-
449
- if use_alg == "maskgit_plus":
450
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
451
- elif use_alg == "topk_margin":
452
  confidence, x0 = sample_tokens(
453
  mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True
454
  )
455
- elif use_alg == "entropy":
456
  confidence, x0 = sample_tokens(
457
- mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True
458
  )
459
  else:
460
- raise RuntimeError(f"Unknown alg: {alg}")
461
-
462
- # keep full_confidence in dtype = logits.dtype (avoids mixing int/float)
463
- full_confidence = torch.full(
464
- x.shape, float("-inf"), device=self.device, dtype=logits.dtype
465
- )
466
- full_confidence[mask_index] = confidence
467
 
468
  if rcr:
469
- # === RCR branch: transfer the top-k first, then remask back down to the cumulative target ===
470
  self._apply_rcr_logic(
471
  x=x,
472
- x0_sel=x0,
473
- conf_sel=confidence,
474
  mask_index=mask_index,
475
- overtime_confidence=overtime_confidence,
 
 
476
  mask_token_id=mask_token_id,
477
  step=i,
478
  total_steps=steps,
479
- s=s,
480
- t=t,
481
  )
482
  else:
483
- # === baseline branch: keep the Dream logic unchanged ===
484
- num_mask_token = mask_index.sum() / mask_index.shape[0]
485
- number_transfer_tokens = (
486
- int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
487
- )
 
 
 
 
488
  if number_transfer_tokens > 0:
489
  if alg_temp is None or alg_temp == 0:
490
  _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
491
  else:
492
- fc = full_confidence / alg_temp
493
- fc = F.softmax(fc, dim=-1)
494
- transfer_index = torch.multinomial(fc, num_samples=number_transfer_tokens)
495
- x_ = torch.full_like(x, fill_value=mask_token_id, dtype=torch.long, device=self.device)
496
  x_[mask_index] = x0.clone()
497
  row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
498
  x[row_indices, transfer_index] = x_[row_indices, transfer_index]
499
 
500
- # allow the user to control intermediate tokens
501
  x = generation_tokens_hook_func(i, x, logits)
502
 
503
  if histories is not None:
504
  histories.append(x.clone())
505
 
506
  if return_dict_in_generate:
507
- return DreamModelOutput(sequences=x, history=histories)
 
 
 
508
  else:
509
  return x
 
1
  # coding=utf-8
2
+ # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # You may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
 
16
  import warnings
17
  import copy
 
34
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
35
  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
36
  sorted_indices_to_remove = cumulative_probs > top_p
37
+ # Shift the indices to the right to keep the first token above the threshold
38
  sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
39
  sorted_indices_to_remove[..., 0] = 0
40
+
41
  mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
42
  mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
43
+ logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
44
+ return logits
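For intuition, here is a minimal sketch (toy numbers, not from this repo) of what the nucleus filter above does to a single row of logits; the import path is an assumption about how this module is loaded:

import torch
import torch.nn.functional as F
from generation_utils import top_p_logits  # assumed import path for this file

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
print(F.softmax(logits, dim=-1))   # ≈ [0.61, 0.22, 0.14, 0.03]
kept = top_p_logits(logits, top_p=0.8)
print(kept)                        # the last two tokens fall outside the 0.8 nucleus and are masked to the dtype minimum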
45
 
46
 
47
  def top_k_logits(logits, top_k=None):
48
  if top_k is None:
49
  return logits
50
+ top_k = int(min(top_k, logits.size(-1))) # Safety check
51
+ if top_k <= 0:
52
+ return logits
53
+ # Remove all tokens with a probability less than the last token of the top-k
54
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
55
+ logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
56
+ return logits
57
+
58
+
59
+ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
60
+ # logits: [N, V]
 
 
 
 
61
  if temperature and temperature > 0:
62
  logits = logits / temperature
63
  if top_p is not None and top_p < 1:
 
67
 
68
  probs = torch.softmax(logits, dim=-1)
69
 
70
+ # sampling / greedy
71
  if temperature and temperature > 0:
 
72
  try:
73
  x0 = dists.Categorical(probs=probs).sample()
74
  confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
75
  except Exception:
76
  confidence, x0 = probs.max(dim=-1)
77
  else:
 
78
  confidence, x0 = probs.max(dim=-1)
79
 
80
+ # switch the confidence definition
81
  if margin_confidence:
82
  sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
83
+ top1_probs = sorted_probs[:, 0]
84
+ top2_probs = sorted_probs[:, 1]
85
  confidence = top1_probs - top2_probs
86
 
87
  if neg_entropy:
88
+ # negative entropy (values ≤ 0; closer to 0, i.e. larger, means more certain)
89
+ epsilon = 1e-10
90
+ log_probs = torch.log(probs + epsilon)
91
  confidence = torch.sum(probs * log_probs, dim=-1)
92
 
93
  return confidence, x0
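A minimal usage sketch of the three confidence modes defined above (random logits, assumed import path):

import torch
from generation_utils import sample_tokens  # assumed import path for this file

logits = torch.randn(4, 32000)                                 # [N, V]: 4 masked positions, toy vocab size
conf, x0 = sample_tokens(logits, temperature=0.7, top_p=0.9)   # confidence = p(sampled token)
margin, _ = sample_tokens(logits, margin_confidence=True)      # top-1 minus top-2 probability
negent, _ = sample_tokens(logits, neg_entropy=True)            # ≤ 0; closer to 0 means more certain
print(conf.shape, x0.shape)                                    # torch.Size([4]) torch.Size([4])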
 
101
 
102
  class DreamGenerationConfig(GenerationConfig):
103
  def __init__(self, **kwargs):
104
+ # sampling
105
  self.temperature: float = kwargs.pop("temperature", 0.0)
106
  self.top_p: Optional[float] = kwargs.pop("top_p", None)
107
  self.top_k: Optional[int] = kwargs.pop("top_k", None)
108
+
109
+ # length
110
  self.max_length = kwargs.pop("max_length", 20)
111
  self.max_new_tokens = kwargs.pop("max_new_tokens", None)
112
 
113
  # diffusion specific params
114
  self.eps: float = kwargs.pop("eps", 1e-3)
115
  self.steps: int = kwargs.pop("steps", 512)
116
+ self.alg: str = kwargs.pop("alg", 'origin') # 'origin' | 'maskgit_plus' | 'topk_margin' | 'entropy'
117
  self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
118
 
119
+ # === RCR parameters (new; off by default, does not affect the original logic) ===
120
  self.rcr: bool = kwargs.pop("rcr", False)
121
+ # only used to select the confidence algorithm when rcr=True; not read when rcr=False
122
+ self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')
123
 
124
+ # generate outputs
125
  self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
126
  self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
127
  self.output_history: bool = kwargs.pop("output_history", False)
 
132
  self.bos_token_id = kwargs.pop("bos_token_id", None)
133
  self.eos_token_id = kwargs.pop("eos_token_id", None)
134
 
135
+ # misc
136
  self.generation_kwargs = kwargs.pop("generation_kwargs", {})
137
 
138
+ # bookkeeping
139
  self._from_model_config = kwargs.pop("_from_model_config", False)
140
  self._commit_hash = kwargs.pop("_commit_hash", None)
141
  self.transformers_version = kwargs.pop("transformers_version", __version__)
 
151
  self.validate(is_init=True)
152
 
153
  def validate(self, is_init=False):
 
154
  pass
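The new fields above are plain kwargs on the config, so enabling RCR is just a matter of constructing it with rcr=True; a minimal sketch (assumed import path, illustrative values):

from generation_utils import DreamGenerationConfig  # assumed import path for this file

cfg = DreamGenerationConfig(
    max_new_tokens=256,
    steps=256,
    temperature=0.0,          # greedy token choice inside each step
    alg="entropy",            # confidence rule used when rcr=False
    rcr=True,                 # enable Running Confidence Remasking
    conf_alg="maskgit_plus",  # confidence rule used only when rcr=True
)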
155
 
156
 
 
159
  def _expand_inputs_for_generation(
160
  expand_size: int = 1,
161
  input_ids: Optional[torch.LongTensor] = None,
162
+ attention_mask: Optional[torch.LongTensor] = None
163
  ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
164
  if expand_size == 1:
165
  return input_ids, attention_mask
 
169
  attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
170
  return input_ids, attention_mask
171
 
172
+ # === new: RCR core (historical confidence) ===
173
  def _apply_rcr_logic(
174
  self,
175
+ x: torch.Tensor,
176
+ x0: torch.Tensor,
177
+ conf_now: torch.Tensor,
178
  mask_index: torch.Tensor,
179
+ fixed_conf: torch.Tensor,
180
+ gen_mask: torch.Tensor,
181
+ init_mask_count: torch.Tensor,
182
  mask_token_id: int,
183
  step: int,
184
  total_steps: int,
 
186
  t: torch.Tensor,
187
  ):
188
  """
189
+ Running Confidence Remasking (historical-confidence version):
190
+ 1) within the masked subset, pick the top-k_j positions by the current-step confidence conf_now and "confirm" them (write the token);
191
+ 2) update the historical confidence fixed_conf = max(fixed_conf, conf_now) (only for newly selected positions);
192
+ 3) against the cumulative confirmation quota target_cum = init_mask_count * (1 - s/t), if over budget,
193
+ remask the `over` lowest-fixed_conf positions within the confirmed set gen_mask.
194
+
195
+ Notes:
196
+ - conf_now is kept in float32 to avoid dtype errors from mixing with bfloat16;
197
+ - for entropy: conf_now = negative entropy (≤ 0, and closer to 0 / larger means more certain), which works with topk(largest=True).
198
  """
199
  device = x.device
200
+ B, L = x.shape
201
+
202
+ # compute this step's selection budget (same convention as vanilla: average remaining masks * (1 - s/t))
203
+ avg_mask_now = (mask_index.sum().item() / max(1, mask_index.shape[0]))
204
+ ratio = (1.0 - (s.item() / t.item())) if step < total_steps - 1 else 1.0
205
+ number_transfer_tokens = int(avg_mask_now * ratio)
206
 
207
+ # make sure the current-step confidence is float32
208
+ conf_now = conf_now.to(torch.float32)
 
 
209
 
210
+ # full-length views that are only valid at masked positions
211
+ full_conf_now = torch.full((B, L), float("-inf"), dtype=torch.float32, device=device)
212
+ full_x0 = torch.full((B, L), mask_token_id, dtype=torch.long, device=device)
213
+ full_conf_now[mask_index] = conf_now
214
+ full_x0[mask_index] = x0
215
 
216
+ # process each sample in the batch
217
  for j in range(B):
218
  masked_j = int(mask_index[j].sum().item())
 
 
219
  k_j = min(number_transfer_tokens, masked_j)
 
220
  if k_j > 0:
221
+ conf_row = full_conf_now[j] # float32
222
+ # select this step's top-k_j
223
+ _, sel_idx = torch.topk(conf_row, k=k_j, largest=True)
224
+ # write the tokens & mark them as confirmed
225
+ x[j, sel_idx] = full_x0[j, sel_idx]
226
+ gen_mask[j, sel_idx] = True
227
+ # the historical confidence keeps a running max
228
+ fixed_conf[j, sel_idx] = torch.maximum(fixed_conf[j, sel_idx], conf_row[sel_idx])
229
+
230
+ # cumulative confirmation quota (relative to the initial mask count)
231
+ init_m = int(init_mask_count[j].item())
232
  if step < total_steps - 1:
233
+ target_cum = int(init_m * (1.0 - (s.item() / t.item())))
234
+ else:
235
+ target_cum = init_m # the final step allows everything to be confirmed
236
+
237
+ current_gen = int(gen_mask[j].sum().item())
238
+ over = max(0, current_gen - target_cum)
239
+ if over > 0:
240
+ # remask the lowest historical-confidence positions within the confirmed set
241
+ gen_idx = torch.where(gen_mask[j])[0]
242
+ if gen_idx.numel() > 0:
243
+ hist_vals = fixed_conf[j, gen_idx] # float32
244
+ over = min(over, int(gen_idx.numel()))
245
+ _, low_local = torch.topk(hist_vals, k=over, largest=False)
246
+ low_global = gen_idx[low_local]
247
+ # remask: restore to MASK, clear the confirmation flag, and reset the historical confidence
248
+ x[j, low_global] = mask_token_id
249
+ gen_mask[j, low_global] = False
250
+ fixed_conf[j, low_global] = float("-inf")
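To make the confirm-then-remask bookkeeping concrete, here is a self-contained toy for one sequence that follows the same steps as _apply_rcr_logic (all values are made up):

import torch

L = 8
fixed_conf = torch.full((L,), float("-inf"))   # historical (running) confidence
gen_mask = torch.zeros(L, dtype=torch.bool)    # positions already confirmed
conf_now = torch.tensor([0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.05, 0.6])

k = 3                                          # confirm the top-k positions this step
_, sel = torch.topk(conf_now, k)
gen_mask[sel] = True
fixed_conf[sel] = torch.maximum(fixed_conf[sel], conf_now[sel])

target_cum = 2                                 # cumulative quota allowed so far
over = int(gen_mask.sum()) - target_cum
if over > 0:                                   # remask the lowest historical-confidence confirmations
    gen_idx = torch.where(gen_mask)[0]
    _, low = torch.topk(fixed_conf[gen_idx], k=over, largest=False)
    gen_mask[gen_idx[low]] = False
    fixed_conf[gen_idx[low]] = float("-inf")
print(gen_mask)                                # only the 2 highest-confidence positions stay confirmed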
251
 
252
  def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
253
  if is_torchdynamo_compiling():
 
260
  UserWarning,
261
  )
262
  if input_ids_length >= generation_config.max_length:
263
+ input_ids_string = "input_ids"
264
  raise ValueError(
265
+ f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
266
+ f" {generation_config.max_length}. You should consider increasing `max_length` or, better yet,"
267
+ " setting `max_new_tokens`."
268
  )
269
 
270
  def _prepare_generated_length(self, generation_config, has_default_max_length, input_ids_length):
 
272
  if not has_default_max_length and generation_config.max_length is not None:
273
  logger.warning(
274
  f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
275
+ f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
276
+ "Please refer to the documentation for more information."
277
  )
278
  generation_config.max_length = generation_config.max_new_tokens + input_ids_length
279
+
280
  elif has_default_max_length:
281
  if generation_config.max_length == DreamGenerationConfig().max_length:
282
  generation_config.max_length = generation_config.max_length + input_ids_length
283
+ max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
284
+ if max_position_embeddings is not None:
285
+ generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
286
  return generation_config
287
 
288
+ def _prepare_generation_config(self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict) -> DreamGenerationConfig:
 
 
289
  using_model_generation_config = False
290
  if generation_config is None:
291
  generation_config = DreamGenerationConfig.from_model_config(self.config)
 
303
  generation_config.pad_token_id = self.generation_config.pad_token_id
304
  if generation_config.mask_token_id is None:
305
  generation_config.mask_token_id = self.generation_config.mask_token_id
306
+
307
  return generation_config
308
 
309
+ def _prepare_special_tokens(self, generation_config: DreamGenerationConfig, device: Optional[Union[torch.device, str]] = None):
 
 
310
  def _tensor_or_none(token, device=None):
311
  if token is None:
312
  return token
 
361
 
362
  if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
363
  warnings.warn(
364
+ "You are calling .generate() with the `input_ids` being on a device type different"
365
+ f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
366
+ f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
367
+ " Please make sure that you have put `input_ids` to the"
368
+ f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
369
+ " running `.generate()`.",
370
  UserWarning,
371
  )
 
372
  if (
373
  hasattr(generation_config, "pad_token_id")
374
  and torch.any(input_ids == generation_config.pad_token_id)
375
  and attention_mask is None
376
  ):
377
  warnings.warn(
378
+ "Padding was detected but no attention mask is passed here. For correct "
379
+ "generation results, please set `attention_mask` when batch-padding inputs.",
380
  UserWarning,
381
  )
382
 
 
401
  attention_mask: Optional[torch.LongTensor],
402
  generation_config: DreamGenerationConfig,
403
  generation_tokens_hook_func,
404
+ generation_logits_hook_func
405
  ) -> Union[DreamModelOutput, torch.LongTensor]:
406
+ # === basic variables ===
407
  output_history = generation_config.output_history
408
  return_dict_in_generate = generation_config.return_dict_in_generate
409
  max_length = generation_config.max_length
 
416
  top_p = generation_config.top_p
417
  top_k = generation_config.top_k
418
 
419
+ # === RCR control variables ===
420
  rcr = generation_config.rcr
421
  conf_alg = generation_config.conf_alg
422
 
423
  histories = [] if (return_dict_in_generate and output_history) else None
424
 
425
+ # pad input_ids to max_length
426
  x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
427
 
428
  if attention_mask is not None and torch.any(attention_mask == 0.0):
 
439
 
440
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
441
 
442
+ # === RCR buffers (only used when rcr=True) ===
443
+ if rcr:
444
+ init_mask_count = (x == mask_token_id).sum(dim=1) # [B]
445
+ fixed_conf = torch.full(
446
+ x.shape, float("-inf"), dtype=torch.float32, device=x.device
447
+ ) # historical confidence
448
+ gen_mask = torch.zeros_like(x, dtype=torch.bool) # confirmed set
449
+ else:
450
+ init_mask_count = None
451
+ fixed_conf = None
452
+ gen_mask = None
453
 
454
+ # hooks: allow the user to control intermediate state
455
  x = generation_tokens_hook_func(None, x, None)
456
 
457
  for i in range(steps):
458
  mask_index = (x == mask_token_id)
 
459
  logits = self(x, attention_mask, tok_idx).logits
460
+ # shift right by one position (as in the original Dream implementation)
461
  logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
462
 
 
463
  logits = generation_logits_hook_func(i, x, logits)
464
 
465
  mask_logits = logits[mask_index]
466
  t = timesteps[i]
467
  s = timesteps[i + 1]
468
 
469
+ if alg == 'origin':
470
+ # === original 'origin' algorithm: random proportional transfer (no confidence involved) ===
 
 
 
 
471
  p_transfer = 1 - s / t if i < steps - 1 else 1
472
+ x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
473
  transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
474
  _, x0[transfer_index_t_s] = sample_tokens(
475
  mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k
476
  )
477
  x[mask_index] = x0.clone()
 
478
  else:
479
+ # === confidence-algorithm selection (shared by vanilla and RCR) ===
480
+ use_alg = conf_alg if rcr else alg
481
+ if use_alg == 'maskgit_plus':
 
 
 
 
482
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
483
+ elif use_alg == 'topk_margin':
484
  confidence, x0 = sample_tokens(
485
  mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True
486
  )
487
+ elif use_alg == 'entropy':
488
  confidence, x0 = sample_tokens(
489
+ mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, neg_entropy=True
490
  )
491
  else:
492
+ raise RuntimeError(f"Unknown alg/conf_alg: {use_alg}")
 
 
 
 
 
 
493
 
494
  if rcr:
495
+ # === historical-confidence RCR ===
496
  self._apply_rcr_logic(
497
  x=x,
498
+ x0=x0,
499
+ conf_now=confidence,
500
  mask_index=mask_index,
501
+ fixed_conf=fixed_conf,
502
+ gen_mask=gen_mask,
503
+ init_mask_count=init_mask_count,
504
  mask_token_id=mask_token_id,
505
  step=i,
506
  total_steps=steps,
507
+ s=s, t=t,
 
508
  )
509
  else:
510
+ # === original Dream (vanilla): per-step top-k, permanently confirmed, never remasked ===
511
+ # number_transfer_tokens is based on 'current average remaining masks * (1 - s/t)'
512
+ avg_mask_now = (mask_index.sum().item() / max(1, mask_index.shape[0]))
513
+ ratio = (1.0 - (s.item() / t.item())) if i < steps - 1 else 1.0
514
+ number_transfer_tokens = int(avg_mask_now * ratio)
515
+
516
+ full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
517
+ full_confidence[mask_index] = confidence
518
+
519
  if number_transfer_tokens > 0:
520
  if alg_temp is None or alg_temp == 0:
521
  _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
522
  else:
523
+ full_confidence = full_confidence / alg_temp
524
+ full_confidence = F.softmax(full_confidence, dim=-1)
525
+ transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
526
+ x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
527
  x_[mask_index] = x0.clone()
528
  row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
529
  x[row_indices, transfer_index] = x_[row_indices, transfer_index]
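The per-step budget used in both branches follows the same linear schedule; a toy sketch of the numbers it produces (illustrative only; the real loop recomputes the remaining mask count from x each step):

import torch

steps, eps, remaining = 4, 1e-3, 100
timesteps = torch.linspace(1, eps, steps + 1)
for i in range(steps):
    t, s = timesteps[i], timesteps[i + 1]
    ratio = (1.0 - (s / t).item()) if i < steps - 1 else 1.0
    budget = int(remaining * ratio)
    print(f"step {i}: transfer {budget} of {remaining}")   # roughly equal splits, everything on the last step
    remaining -= budget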
530
 
 
531
  x = generation_tokens_hook_func(i, x, logits)
532
 
533
  if histories is not None:
534
  histories.append(x.clone())
535
 
536
  if return_dict_in_generate:
537
+ return DreamModelOutput(
538
+ sequences=x,
539
+ history=histories,
540
+ )
541
  else:
542
  return x
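End to end, the intended call path looks roughly like the sketch below. This is a hedged example: diffusion_generate is the entry point used by upstream Dream checkpoints, the checkpoint name is a placeholder, and it assumes extra kwargs such as rcr/conf_alg are forwarded into DreamGenerationConfig.

import torch
from transformers import AutoModel, AutoTokenizer

model_name = "Dream-org/Dream-v0-Instruct-7B"   # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, torch_dtype=torch.bfloat16, trust_remote_code=True).eval()

inputs = tokenizer("Write a haiku about diffusion.", return_tensors="pt")
out = model.diffusion_generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=128,
    steps=128,
    temperature=0.0,
    alg="entropy",        # baseline confidence rule
    rcr=True,             # enable the RCR path added in this commit
    conf_alg="entropy",   # confidence rule consumed by RCR
)
print(tokenizer.decode(out[0, inputs.input_ids.shape[1]:]))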