autoprogrammer committed on
Commit fd52594 · verified · 1 Parent(s): 24e35c2

Update generation_utils.py

Files changed (1)
  1. generation_utils.py +158 -215
generation_utils.py CHANGED
@@ -1,5 +1,3 @@
- # coding=utf-8
- # Copyright 2024 The Dream team, HKUNLP Group and...
  import warnings
  import copy
  from dataclasses import dataclass
@@ -15,66 +13,26 @@ from transformers.utils import ModelOutput, is_torchdynamo_compiling, logging
  logger = logging.get_logger(__name__)


- def top_p_logits(logits, top_p=None):
-     if top_p is None or top_p >= 1:
-         return logits
-     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-     cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-     sorted_indices_to_remove = cumulative_probs > top_p
-     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-     sorted_indices_to_remove[..., 0] = 0
-     mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
-     mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
-     logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
-     return logits
-
-
- def top_k_logits(logits, top_k=None):
-     if top_k is None:
-         return logits
-     top_k = int(min(top_k, logits.size(-1)))
-     if top_k <= 0:
-         return logits
-     indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-     logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
-     return logits
-
-
- def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
-     # logits: [N, V]
      if temperature and temperature > 0:
          logits = logits / temperature
      if top_p is not None and top_p < 1:
-         logits = top_p_logits(logits, top_p)
      if top_k is not None:
-         logits = top_k_logits(logits, top_k)
-
-     probs = torch.softmax(logits, dim=-1)
-
-     # sample / greedy
-     if temperature and temperature > 0:
-         try:
-             x0 = dists.Categorical(probs=probs).sample()
-             confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
-         except Exception:
-             confidence, x0 = probs.max(dim=-1)
-     else:
-         confidence, x0 = probs.max(dim=-1)
-
-     # switch the confidence definition
-     if margin_confidence:
-         sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
-         top1_probs = sorted_probs[:, 0]
-         top2_probs = sorted_probs[:, 1]
-         confidence = top1_probs - top2_probs
-
-     if neg_entropy:
-         # negative entropy (≤ 0; the closer to 0, the more "certain")
-         epsilon = 1e-10
-         log_probs = torch.log(probs + epsilon)
-         confidence = torch.sum(probs * log_probs, dim=-1)
-
-     return confidence, x0

  @dataclass
@@ -97,11 +55,12 @@ class DreamGenerationConfig(GenerationConfig):
          # diffusion specific params
          self.eps: float = kwargs.pop("eps", 1e-3)
          self.steps: int = kwargs.pop("steps", 512)
-         self.alg: str = kwargs.pop("alg", 'origin')  # 'origin' | 'maskgit_plus' | 'topk_margin' | 'entropy'
          self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)

          # RCR
          self.rcr: bool = kwargs.pop("rcr", False)
          self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')

          # outputs
@@ -143,7 +102,7 @@ class DreamGenerationMixin:
          expand_size: int = 1,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.LongTensor] = None
-     ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
          if expand_size == 1:
              return input_ids, attention_mask
          if input_ids is not None:
@@ -152,91 +111,61 @@ class DreamGenerationMixin:
              attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
          return input_ids, attention_mask

-     # =========================
-     # History-confidence RCR (kept close to vanilla)
-     # =========================
-     def _apply_rcr_logic(
          self,
-         x: torch.Tensor,
-         x0: torch.Tensor,
-         conf_now: torch.Tensor,         # [M] confidence at the masked positions (already float32)
-         mask_index: torch.Tensor,       # [B, L] bool
-         fixed_conf: torch.Tensor,       # [B, L] float32 (historical max)
-         ema_conf: torch.Tensor,         # [B, L] float32 (EMA)
-         gen_mask: torch.Tensor,         # [B, L] bool (set of confirmed positions)
-         written_step: torch.Tensor,     # [B, L] int32 (step at which the token was written, -1 = not written)
-         init_mask_count: torch.Tensor,  # [B] initial number of masks
          mask_token_id: int,
          step: int,
          total_steps: int,
          s: torch.Tensor,
          t: torch.Tensor,
-         ema_beta: float = 0.95          # EMA smoothing factor (larger = more stable)
      ):
          """
-         Key points of the strategy (close to vanilla):
-         1) Per-step confirmation: use the vanilla quota and write in the top-k positions by conf_now (neg-entropy / margin / prob);
-         2) History: fixed_conf keeps the historical max, ema_conf keeps a moving average, and the write step is recorded;
-         3) Over-quota remasking: if the number of confirmed tokens exceeds the target cumulative quota, remask the `over`
-            positions with the lowest EMA, restricted to gen_mask and excluding positions written this step (lightweight, stable).
          """
-         device = x.device
          B, L = x.shape
-
-         # 1) quota (same as vanilla)
-         avg_mask_now = (mask_index.sum().item() / max(1, mask_index.shape[0]))
-         ratio = (1.0 - (s.item() / t.item())) if step < total_steps - 1 else 1.0
-         number_transfer_tokens = int(avg_mask_now * ratio)
-
-         # scatter this step's local confidences / candidates back to full length
-         full_conf_now = torch.full((B, L), -1e9, dtype=torch.float32, device=device)  # -1e9 is the safer fill value
-         full_x0 = torch.full((B, L), mask_token_id, dtype=torch.long, device=device)
-         full_conf_now[mask_index] = conf_now
-         full_x0[mask_index] = x0
-
-         # 2) per-sample top-k selection for this step
          for j in range(B):
-             masked_j = int(mask_index[j].sum().item())
-             k_j = min(number_transfer_tokens, masked_j)
-             if k_j > 0:
-                 conf_row = full_conf_now[j]  # float32
-                 _, sel_idx = torch.topk(conf_row, k=k_j, largest=True)
-
-                 # write the selected tokens
-                 x[j, sel_idx] = full_x0[j, sel_idx]
-                 gen_mask[j, sel_idx] = True
-
-                 # historical max & EMA (updated only for positions written this step)
-                 fixed_conf[j, sel_idx] = torch.maximum(fixed_conf[j, sel_idx], conf_row[sel_idx])
-                 ema_conf[j, sel_idx] = ema_beta * ema_conf[j, sel_idx] + (1 - ema_beta) * conf_row[sel_idx]
-                 written_step[j, sel_idx] = step
-
-             # 3) target cumulative quota (same accounting as vanilla)
-             init_m = int(init_mask_count[j].item())
-             target_cum = init_m if step >= total_steps - 1 else int(init_m * (1.0 - (s.item() / t.item())))
-
-             current_gen = int(gen_mask[j].sum().item())
-             over = max(0, current_gen - target_cum)
-             if over > 0:
-                 # only remask confirmed positions that were not written this step, to avoid jitter
-                 gen_idx = torch.where(gen_mask[j])[0]
-                 if gen_idx.numel() > 0:
-                     # exclude positions that were just written
-                     not_just_written = written_step[j, gen_idx] < step
-                     candidates = gen_idx[not_just_written]
-                     if candidates.numel() > 0:
-                         over = min(over, int(candidates.numel()))
-                         cand_ema = ema_conf[j, candidates]  # float32
-                         _, low_local = torch.topk(cand_ema, k=over, largest=False)
-                         low_global = candidates[low_local]
-
-                         # remask
-                         x[j, low_global] = mask_token_id
-                         gen_mask[j, low_global] = False
-                         # lightly reset the EMA; keeping the max helps later stability
-                         ema_conf[j, low_global] = 0.0
-                         written_step[j, low_global] = -1  # reset the write step
-                     # fixed_conf is not cleared; the historical peak is kept as an "anchor"

      def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
          if is_torchdynamo_compiling():
@@ -249,29 +178,24 @@ class DreamGenerationMixin:
                  UserWarning,
              )
          if input_ids_length >= generation_config.max_length:
-             input_ids_string = "input_ids"
              raise ValueError(
-                 f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
-                 f" {generation_config.max_length}. You should consider increasing `max_length` or, better yet,"
-                 " setting `max_new_tokens`."
              )

      def _prepare_generated_length(self, generation_config, has_default_max_length, input_ids_length):
          if generation_config.max_new_tokens is not None:
              if not has_default_max_length and generation_config.max_length is not None:
                  logger.warning(
-                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
-                     f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
-                     "Please refer to the documentation for more information."
                  )
              generation_config.max_length = generation_config.max_new_tokens + input_ids_length
-
          elif has_default_max_length:
              if generation_config.max_length == DreamGenerationConfig().max_length:
                  generation_config.max_length = generation_config.max_length + input_ids_length
-         max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
-         if max_position_embeddings is not None:
-             generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
          return generation_config

      def _prepare_generation_config(self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict) -> DreamGenerationConfig:
@@ -295,7 +219,7 @@ class DreamGenerationMixin:
          return generation_config

-     def _prepare_special_tokens(self, generation_config: DreamGenerationConfig, device: Optional[Union[torch.device, str]] = None):
          def _tensor_or_none(token, device=None):
              if token is None:
                  return token
@@ -311,7 +235,6 @@ class DreamGenerationMixin:
          if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
              eos_token_tensor = eos_token_tensor.unsqueeze(0)
-
          if pad_token_tensor is None and eos_token_tensor is not None:
              pad_token_tensor = eos_token_tensor[0]
              logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
@@ -327,7 +250,7 @@ class DreamGenerationMixin:
          inputs: Optional[torch.Tensor] = None,
          generation_config: Optional[DreamGenerationConfig] = None,
          **kwargs,
-     ) -> Union[DreamModelOutput, torch.LongTensor]:
          generation_config = self._prepare_generation_config(generation_config, **kwargs)
          generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
          generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
@@ -350,9 +273,7 @@ class DreamGenerationMixin:
          if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
              warnings.warn(
-                 "You are calling .generate() with the `input_ids` being on a device type different"
-                 f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
-                 f" is on {self.device.type}. You may experience unexpected behaviors or slower generation.",
                  UserWarning,
              )
          if (
@@ -361,8 +282,7 @@ class DreamGenerationMixin:
              and attention_mask is None
          ):
              warnings.warn(
-                 "Padding was detected but no attention mask is passed here. For correct "
-                 "generation results, please set `attention_mask` when batch-padding inputs.",
                  UserWarning,
              )

@@ -372,14 +292,13 @@ class DreamGenerationMixin:
              attention_mask=attention_mask,
          )

-         result = self._sample(
              input_ids,
              attention_mask=attention_mask,
              generation_config=generation_config,
              generation_tokens_hook_func=generation_tokens_hook_func,
              generation_logits_hook_func=generation_logits_hook_func,
          )
-         return result

      def _sample(
          self,
@@ -388,7 +307,7 @@ class DreamGenerationMixin:
          generation_config: DreamGenerationConfig,
          generation_tokens_hook_func,
          generation_logits_hook_func
-     ) -> Union[DreamModelOutput, torch.LongTensor]:
          output_history = generation_config.output_history
          return_dict_in_generate = generation_config.return_dict_in_generate
@@ -401,10 +320,7 @@ class DreamGenerationMixin:
          top_p = generation_config.top_p
          top_k = generation_config.top_k

-         # RCR
-         rcr = generation_config.rcr
-         conf_alg = generation_config.conf_alg
-
          histories = [] if (return_dict_in_generate and output_history) else None

          # pad input_ids to max_length
@@ -424,75 +340,60 @@ class DreamGenerationMixin:
          timesteps = torch.linspace(1, eps, steps + 1, device=x.device)

-         # ===== RCR buffer initialization (key point: float32, to avoid dtype clashes) =====
          if rcr:
-             init_mask_count = (x == mask_token_id).sum(dim=1)  # [B]
-             fixed_conf = torch.full(x.shape, -1e9, dtype=torch.float32, device=x.device)  # historical max
-             ema_conf = torch.zeros_like(fixed_conf, dtype=torch.float32)  # EMA
-             gen_mask = torch.zeros_like(x, dtype=torch.bool)  # set of confirmed positions
-             written_step = torch.full(x.shape, -1, dtype=torch.int32, device=x.device)  # write step
-         else:
-             init_mask_count = None
-             fixed_conf = None
-             ema_conf = None
-             gen_mask = None
-             written_step = None

          x = generation_tokens_hook_func(None, x, None)

          for i in range(steps):
              mask_index = (x == mask_token_id)
              logits = self(x, attention_mask, tok_idx).logits
              logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
              logits = generation_logits_hook_func(i, x, logits)

-             mask_logits = logits[mask_index]
              t = timesteps[i]
              s = timesteps[i + 1]

-             if alg == 'origin':
-                 p_transfer = 1 - s / t if i < steps - 1 else 1
-                 x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
-                 transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
-                 _, x0[transfer_index_t_s] = sample_tokens(
-                     mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k
-                 )
-                 x[mask_index] = x0.clone()
-             else:
-                 use_alg = conf_alg if rcr else alg
-                 if use_alg == 'maskgit_plus':
-                     confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
-                 elif use_alg == 'topk_margin':
-                     confidence, x0 = sample_tokens(
-                         mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True
-                     )
-                 elif use_alg == 'entropy':
-                     confidence, x0 = sample_tokens(
-                         mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, neg_entropy=True
-                     )
                  else:
-                     raise RuntimeError(f"Unknown alg/conf_alg: {use_alg}")
-
-                 if rcr:
-                     # —— history-confidence RCR, kept close to vanilla ——
-                     self._apply_rcr_logic(
-                         x=x,
-                         x0=x0,
-                         conf_now=confidence.to(torch.float32),
-                         mask_index=mask_index,
-                         fixed_conf=fixed_conf,
-                         ema_conf=ema_conf,
-                         gen_mask=gen_mask,
-                         written_step=written_step,
-                         init_mask_count=init_mask_count,
-                         mask_token_id=mask_token_id,
-                         step=i,
-                         total_steps=steps,
-                         s=s, t=t,
-                         ema_beta=0.8,
-                     )
-                 else:
-                     # —— vanilla: this step's top-k is confirmed permanently ——
                      avg_mask_now = (mask_index.sum().item() / max(1, mask_index.shape[0]))
                      ratio = (1.0 - (s.item() / t.item())) if i < steps - 1 else 1.0
                      number_transfer_tokens = int(avg_mask_now * ratio)
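Note (not part of the diff): the per-step transfer quota above follows the linear 1 - s/t schedule over the linspace timesteps. A minimal standalone sketch with made-up numbers, assuming the same schedule as in this file:

import torch

steps, eps = 4, 1e-3
timesteps = torch.linspace(1, eps, steps + 1)
remaining = 16  # pretend 16 masked positions remain at the first step

for i in range(steps):
    t, s = timesteps[i], timesteps[i + 1]
    ratio = (1.0 - (s / t).item()) if i < steps - 1 else 1.0  # the last step flushes everything
    quota = int(remaining * ratio)                            # number_transfer_tokens for this step
    print(f"step {i}: confirm {quota} of {remaining} masked tokens")
    remaining -= quota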
@@ -512,6 +413,48 @@ class DreamGenerationMixin:
                      row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
                      x[row_indices, transfer_index] = x_[row_indices, transfer_index]

              x = generation_tokens_hook_func(i, x, logits)
              if histories is not None:
                  histories.append(x.clone())
 
 
 
  import warnings
  import copy
  from dataclasses import dataclass

  logger = logging.get_logger(__name__)


+ def _apply_top_p_k_temp(logits, temperature=0.0, top_p=None, top_k=None):
      if temperature and temperature > 0:
          logits = logits / temperature
      if top_p is not None and top_p < 1:
+         # top-p
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+         cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+         sorted_indices_to_remove = cumulative_probs > top_p
+         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+         sorted_indices_to_remove[..., 0] = 0
+         mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
+         mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
+         logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
      if top_k is not None:
+         # top-k
+         top_k = int(min(top_k, logits.size(-1)))
+         if top_k > 0:
+             indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+             logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
+     return logits
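Quick sanity check (not part of the commit): what the new `_apply_top_p_k_temp` helper does to a toy logit row, assuming it is imported from, or pasted next to, this module:

import torch
import torch.nn.functional as F

toy_logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
filtered = _apply_top_p_k_temp(toy_logits.clone(), temperature=1.0, top_p=0.9, top_k=3)
# Tokens outside the nucleus / top-k are pushed to the dtype minimum,
# so they get (near-)zero probability after the softmax.
print(F.softmax(filtered, dim=-1))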
  @dataclass
 
          # diffusion specific params
          self.eps: float = kwargs.pop("eps", 1e-3)
          self.steps: int = kwargs.pop("steps", 512)
+         self.alg: str = kwargs.pop("alg", 'origin')  # used by the vanilla path
          self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)

          # RCR
          self.rcr: bool = kwargs.pop("rcr", False)
+         # Note: the paper-version RCR ignores conf_alg here and always uses the selected-token probability as the running-max confidence
          self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')

          # outputs
 
          expand_size: int = 1,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.LongTensor] = None
+     ):
          if expand_size == 1:
              return input_ids, attention_mask
          if input_ids is not None:

              attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
          return input_ids, attention_mask

+     # =============== Paper-version RCR: running-max confidence + directly choose n_t to remask ===============
+     def _apply_rcr_logic_paper(
          self,
+         x: torch.Tensor,                # [B, L]
+         rmax_conf: torch.Tensor,        # [B, L], float32, running max of selected-token prob
+         init_mask_bool: torch.Tensor,   # [B, L], initial generation region (positions that started as MASK)
+         init_mask_count: torch.Tensor,  # [B], initial number of MASK tokens, M0
          mask_token_id: int,
          step: int,
          total_steps: int,
          s: torch.Tensor,
          t: torch.Tensor,
      ):
          """
+         Goal: inside the "initial generation region" (init_mask_bool), keep the *number* of confirmed tokens on the
+         same linear schedule as vanilla, but choose *which* positions stay confirmed by the running-max confidence
+         rmax_conf: at each step keep the high-rmax_conf positions and remask the low-rmax_conf ones.
+
+         Procedure:
+             target_cum = floor(M0 * (1 - s/t))   # equals M0 at the last step
+             within init_mask_bool[j], keep the target_cum positions with the largest rmax_conf[j] (leave them confirmed)
+             set every other position back to mask_token_id
          """
          B, L = x.shape
          for j in range(B):
+             M0 = int(init_mask_count[j].item())
+             if step < total_steps - 1:
+                 target_cum = int(M0 * (1.0 - (s.item() / t.item())))
+             else:
+                 target_cum = M0
+
+             # rank positions inside the initial generation region
+             region_idx = torch.where(init_mask_bool[j])[0]
+             if region_idx.numel() == 0:
+                 continue
+
+             # larger rmax_conf = more stable; keep the top target_cum positions
+             scores = rmax_conf[j, region_idx]  # float32
+             # Defensive: positions never updated still carry the initial rmax_conf of 0.0, so they are remasked first
+             # (matches the intuition that they have "never been confident")
+             target_cum = min(target_cum, int(region_idx.numel()))
+             if target_cum <= 0:
+                 # keep everything masked
+                 x[j, region_idx] = mask_token_id
+                 continue
+
+             _, keep_local = torch.topk(scores, k=target_cum, largest=True)
+             keep_global = region_idx[keep_local]
+
+             # remask the rest
+             mask_global = torch.ones_like(region_idx, dtype=torch.bool, device=x.device)
+             mask_global[keep_local] = False
+             remask_idx = region_idx[mask_global]
+
+             if remask_idx.numel() > 0:
+                 x[j, remask_idx] = mask_token_id
+             # positions in keep_global keep their currently written tokens untouched
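Illustration (not part of the commit): the per-row keep/remask decision above, on a toy example. The target_cum positions with the highest running-max confidence keep their tokens; everything else is reset to the mask id:

import torch

mask_token_id = 0
x_row      = torch.tensor([11, 12, 13, 14, 15, 16])        # current tokens in the generation region
rmax_conf  = torch.tensor([0.9, 0.2, 0.7, 0.1, 0.8, 0.0])  # running max of selected-token probabilities
target_cum = 3                                             # int(M0 * (1 - s/t)) for this step

_, keep_idx = torch.topk(rmax_conf, k=target_cum, largest=True)
remask = torch.ones_like(x_row, dtype=torch.bool)
remask[keep_idx] = False
x_row[remask] = mask_token_id
print(x_row)  # only the three most confident positions keep their tokens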
      def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
          if is_torchdynamo_compiling():

                  UserWarning,
              )
          if input_ids_length >= generation_config.max_length:
              raise ValueError(
+                 f"Input length is {input_ids_length}, but `max_length` is {generation_config.max_length}. "
+                 "Increase `max_length` or set `max_new_tokens`."
              )

      def _prepare_generated_length(self, generation_config, has_default_max_length, input_ids_length):
          if generation_config.max_new_tokens is not None:
              if not has_default_max_length and generation_config.max_length is not None:
                  logger.warning(
+                     f"Both `max_new_tokens` and `max_length` are set. `max_new_tokens` takes precedence."
                  )
              generation_config.max_length = generation_config.max_new_tokens + input_ids_length
          elif has_default_max_length:
              if generation_config.max_length == DreamGenerationConfig().max_length:
                  generation_config.max_length = generation_config.max_length + input_ids_length
+         mpe = getattr(self.config, "max_position_embeddings", None)
+         if mpe is not None:
+             generation_config.max_length = min(generation_config.max_length, mpe)
          return generation_config

      def _prepare_generation_config(self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict) -> DreamGenerationConfig:
          return generation_config

+     def _prepare_special_tokens(self, generation_config: DreamGenerationConfig, device=None):
          def _tensor_or_none(token, device=None):
              if token is None:
                  return token

          if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
              eos_token_tensor = eos_token_tensor.unsqueeze(0)
          if pad_token_tensor is None and eos_token_tensor is not None:
              pad_token_tensor = eos_token_tensor[0]
              logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")

          inputs: Optional[torch.Tensor] = None,
          generation_config: Optional[DreamGenerationConfig] = None,
          **kwargs,
+     ):
          generation_config = self._prepare_generation_config(generation_config, **kwargs)
          generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
          generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
 
          if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
              warnings.warn(
+                 "You are calling .generate() with `input_ids` on a different device than the model.",
                  UserWarning,
              )
          if (
              and attention_mask is None
          ):
              warnings.warn(
+                 "Padding detected but no attention mask was passed. Set `attention_mask` for correct generation.",
                  UserWarning,
              )

              attention_mask=attention_mask,
          )

+         return self._sample(
              input_ids,
              attention_mask=attention_mask,
              generation_config=generation_config,
              generation_tokens_hook_func=generation_tokens_hook_func,
              generation_logits_hook_func=generation_logits_hook_func,
          )

      def _sample(
          self,
          generation_config: DreamGenerationConfig,
          generation_tokens_hook_func,
          generation_logits_hook_func
+     ):
          output_history = generation_config.output_history
          return_dict_in_generate = generation_config.return_dict_in_generate
          max_length = generation_config.max_length

          top_p = generation_config.top_p
          top_k = generation_config.top_k

+         rcr = generation_config.rcr  # when enabled, use the paper-version RCR (running max of the historical top-1 probability)
          histories = [] if (return_dict_in_generate and output_history) else None

          # pad input_ids to max_length
 
          timesteps = torch.linspace(1, eps, steps + 1, device=x.device)

          if rcr:
+             # initial generation region (the stretch extended beyond the prompt)
+             init_mask_bool = (x == mask_token_id)        # [B, L]
+             init_mask_count = init_mask_bool.sum(dim=1)  # [B]
+             # running max of the selected-token probability (float32)
+             rmax_conf = torch.zeros_like(x, dtype=torch.float32, device=x.device)
+             logger.warning(
+                 "[RCR] Using PAPER version: running-max of SELECTED-TOKEN PROB; "
+                 "this overrides `conf_alg` (e.g., entropy) for remasking decisions."
+             )
          x = generation_tokens_hook_func(None, x, None)

          for i in range(steps):
              mask_index = (x == mask_token_id)
+
+             # forward pass
              logits = self(x, attention_mask, tok_idx).logits
              logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
              logits = generation_logits_hook_func(i, x, logits)

              t = timesteps[i]
              s = timesteps[i + 1]

+             if not rcr:
+                 # ===== vanilla path (keeps the original implementation) =====
+                 mask_logits = logits[mask_index]
+                 if alg == 'origin':
+                     p_transfer = 1 - s / t if i < steps - 1 else 1
+                     x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
+                     transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
+                     if transfer_index_t_s.any():
+                         logits_sub = mask_logits[transfer_index_t_s]
+                         logits_sub = _apply_top_p_k_temp(logits_sub, temperature, top_p, top_k)
+                         probs_sub = torch.softmax(logits_sub, dim=-1)
+                         try:
+                             x0_sel = dists.Categorical(probs=probs_sub).sample()
+                         except Exception:
+                             x0_sel = probs_sub.argmax(dim=-1)
+                         x0[transfer_index_t_s] = x0_sel
+                     x[mask_index] = x0.clone()
                  else:
+                     # follow the vanilla top-k / alg_temp logic
+                     mask_logits = _apply_top_p_k_temp(logits[mask_index], temperature, top_p, top_k)
+                     probs = torch.softmax(mask_logits, dim=-1)
+                     if temperature and temperature > 0:
+                         try:
+                             x0 = dists.Categorical(probs=probs).sample()
+                             confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
+                         except Exception:
+                             confidence, x0 = probs.max(dim=-1)
+                     else:
+                         confidence, x0 = probs.max(dim=-1)
+
                      avg_mask_now = (mask_index.sum().item() / max(1, mask_index.shape[0]))
                      ratio = (1.0 - (s.item() / t.item())) if i < steps - 1 else 1.0
                      number_transfer_tokens = int(avg_mask_now * ratio)

                      row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
                      x[row_indices, transfer_index] = x_[row_indices, transfer_index]

+             else:
+                 # ===== paper-version RCR =====
+                 # 1) only for currently masked positions: apply top_p / top_k / temperature filtering, then sample (or take the argmax)
+                 mask_logits = logits[mask_index]
+                 mask_logits = _apply_top_p_k_temp(mask_logits, temperature, top_p, top_k)
+                 probs = torch.softmax(mask_logits, dim=-1)
+
+                 # sample / greedy
+                 if temperature and temperature > 0:
+                     try:
+                         x0 = dists.Categorical(probs=probs).sample()
+                     except Exception:
+                         x0 = probs.argmax(dim=-1)
+                 else:
+                     x0 = probs.argmax(dim=-1)
+
+                 # probability of the selected token, p_sel (the paper uses this as the "historical confidence")
+                 p_sel = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)  # [M], float32
+
+                 # write the selected tokens
+                 x_maskwrite = torch.full_like(x, mask_token_id, dtype=torch.long)
+                 x_maskwrite[mask_index] = x0
+                 x = torch.where(mask_index, x_maskwrite, x)
+
+                 # update the running-max confidence (float32)
+                 # first scatter back to full length
+                 full_p_sel = torch.zeros_like(x, dtype=torch.float32)
+                 full_p_sel[mask_index] = p_sel.to(torch.float32)
+                 rmax_conf = torch.maximum(rmax_conf, full_p_sel)
+
+                 # 2) from rmax_conf, directly determine how many confirmed tokens to keep for the next step; remask everything else
+                 self._apply_rcr_logic_paper(
+                     x=x,
+                     rmax_conf=rmax_conf,
+                     init_mask_bool=init_mask_bool,
+                     init_mask_count=init_mask_count,
+                     mask_token_id=mask_token_id,
+                     step=i,
+                     total_steps=steps,
+                     s=s, t=t,
+                 )
+
              x = generation_tokens_hook_func(i, x, logits)
              if histories is not None:
                  histories.append(x.clone())
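Usage sketch (not part of the commit; the checkpoint name, the entry-point method, and the decoding call are assumptions — adjust to however this mixin is exposed on the actual model class):

import torch
from transformers import AutoModel, AutoTokenizer

model_name = "Dream-org/Dream-v0-Instruct-7B"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.bfloat16).eval()

inputs = tokenizer("Explain remasking in one sentence.", return_tensors="pt")
out = model.diffusion_generate(   # assumed entry point; use whichever generate method the mixin provides
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=64,
    steps=64,          # DreamGenerationConfig.steps
    temperature=0.2,
    top_p=0.95,
    rcr=True,          # enable the paper-version RCR remasking added in this commit
)
print(tokenizer.decode(out[0], skip_special_tokens=True))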