autoprogrammer committed
Commit d0ecd79 · verified · 1 Parent(s): e350bc1

Update generation_utils.py

Files changed (1)
  1. generation_utils.py +147 -116
generation_utils.py CHANGED
@@ -1,5 +1,5 @@
  # coding=utf-8
- # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
+ # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -22,14 +22,8 @@ import torch
  import torch.distributions as dists
  from torch.nn import functional as F
  from transformers import __version__
- from transformers.generation.configuration_utils import (
-     GenerationConfig
- )
- from transformers.utils import (
-     ModelOutput,
-     is_torchdynamo_compiling,
-     logging,
- )
+ from transformers.generation.configuration_utils import GenerationConfig
+ from transformers.utils import ModelOutput, is_torchdynamo_compiling, logging

  logger = logging.get_logger(__name__)

@@ -47,6 +41,7 @@ def top_p_logits(logits, top_p=None):
      logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
      return logits

+
  def top_k_logits(logits, top_k=None):
      top_k = min(top_k, logits.size(-1)) # Safety check
      # Remove all tokens with a probability less than the last token of the top-k
@@ -56,7 +51,6 @@ def top_k_logits(logits, top_k=None):


  def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
-
      if temperature > 0:
          logits = logits / temperature
      if top_p is not None and top_p < 1:
@@ -69,24 +63,22 @@ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
          try:
              x0 = dists.Categorical(probs=probs).sample()
              confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
-         except:
+         except Exception:
              confidence, x0 = probs.max(dim=-1)
      else:
          confidence, x0 = probs.max(dim=-1)
-
+
      if margin_confidence:
          sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
-         # Extract top1 and top2 probabilities
-         top1_probs = sorted_probs[:, 0]
-         top2_probs = sorted_probs[:, 1]
-         # Calculate confidence as top1 - top2
-         confidence = top1_probs - top2_probs
-
+         top1_probs = sorted_probs[:, 0]
+         top2_probs = sorted_probs[:, 1]
+         confidence = top1_probs - top2_probs
+
      if neg_entropy:
          epsilon = 1e-10
          log_probs = torch.log(probs + epsilon)
          confidence = torch.sum(probs * log_probs, dim=-1)
-
+
      return confidence, x0


@@ -109,6 +101,11 @@ class DreamGenerationConfig(GenerationConfig):
          self.alg: str = kwargs.pop("alg", 'origin')
          self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)

+         # === RCR-related parameters (new; the defaults leave the original logic untouched) ===
+         self.rcr: bool = kwargs.pop("rcr", False)
+         # Used to select the confidence algorithm only when rcr=True; not read when rcr=False
+         self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')
+
          # Parameters that define the output variables of `generate`
          self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
          self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
@@ -123,16 +120,12 @@ class DreamGenerationConfig(GenerationConfig):
          # Wild card
          self.generation_kwargs = kwargs.pop("generation_kwargs", {})

-         # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
-         # interface.
+         # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub interface.
          self._from_model_config = kwargs.pop("_from_model_config", False)
          self._commit_hash = kwargs.pop("_commit_hash", None)
          self.transformers_version = kwargs.pop("transformers_version", __version__)

-         # Additional attributes without default values
          if not self._from_model_config:
-             # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
-             # model's default configuration file
              for key, value in kwargs.items():
                  try:
                      setattr(self, key, value)
@@ -140,12 +133,12 @@ class DreamGenerationConfig(GenerationConfig):
                      logger.error(f"Can't set {key} with value {value} for {self}")
                      raise err

-         # Validate the values of the attributes
          self.validate(is_init=True)

      def validate(self, is_init=False):
          pass

+
  class DreamGenerationMixin:
      @staticmethod
      def _expand_inputs_for_generation(
@@ -153,9 +146,6 @@ class DreamGenerationMixin:
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.LongTensor] = None
      ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
-         """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
-         # Do not call torch.repeat_interleave if expand_size is 1 because it clones
-         # the input tensor and thus requires more memory although no change is applied
          if expand_size == 1:
              return input_ids, attention_mask
          if input_ids is not None:
@@ -164,16 +154,63 @@ class DreamGenerationMixin:
              attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
          return input_ids, attention_mask

-     def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
-         """Performs validation related to the resulting generated length"""
-
-         # Can't throw warnings/exceptions during compilation
+     # === New: RCR logic, called only when rcr=True; the non-RCR branches are untouched ===
+     def _apply_rcr_logic(self, x, x0, confidence, mask_index, overtime_confidence,
+                          mask_token_id, step, total_steps, s, t):
+         """
+         Run Running Confidence Remasking (RCR) on top of Dream's "maskgit" sampling skeleton:
+         - The step keeps Dream's original schedule: global_k = num_mask_token * (1 - s/t)
+         - First unmask the top-k tokens by current confidence ([MASK] -> predicted token) and accumulate their confidence
+         - Then enforce the target cumulative budget target_cum = num_mask_token * (1 - s/t) for this step;
+           if the current cumulative count exceeds it, remask the lowest-confidence tokens back to [MASK]
+         Note: only the rcr=True path is affected; this function is never called when rcr=False.
+         """
+         device = x.device
+         B = x.shape[0]
+
+         # num_mask_token computed as in Dream (averaged over the batch)
+         num_mask_token = mask_index.sum() / mask_index.shape[0]
+         # Number of tokens to transfer at this step (Dream's schedule)
+         number_transfer_tokens = int(num_mask_token * (1 - s / t)) if step < total_steps - 1 else int(num_mask_token)
+
+         # Build full-length confidence and candidate tensors (non-mask positions get -inf / mask_token_id)
+         full_conf = torch.full_like(x, -torch.inf, device=device, dtype=confidence.dtype)
+         x_temp = torch.zeros_like(x, device=device, dtype=torch.long) + mask_token_id
+         full_conf[mask_index] = confidence
+         x_temp[mask_index] = x0.clone()
+
+         for j in range(B):
+             # Clamp per sample to avoid going out of range with the batch-averaged count
+             masked_j = int(mask_index[j].sum().item())
+             k_j = min(number_transfer_tokens, masked_j)
+
+             # First select this step's top-k_j positions by confidence
+             if k_j > 0:
+                 _, select_idx = torch.topk(full_conf[j], k=k_j, largest=True)
+                 x[j, select_idx] = x_temp[j, select_idx]
+                 overtime_confidence[j, select_idx] = full_conf[j, select_idx].clone().float()
+
+             # Target cumulative constraint: how many tokens should have been generated by this step
+             if step < total_steps - 1:
+                 target_cum = int(num_mask_token * (1 - s / t))
+                 gen_mask = overtime_confidence[j] > 0
+                 current_gen = int(gen_mask.sum().item())
+                 # If over budget, remask the lowest-confidence generated tokens
+                 to_remask = max(0, current_gen - target_cum)
+                 if to_remask > 0:
+                     gen_indices = torch.where(gen_mask)[0]
+                     if gen_indices.numel() > 0:
+                         gen_conf = overtime_confidence[j, gen_indices]
+                         to_remask = min(to_remask, int(gen_indices.numel()))
+                         _, local_low = torch.topk(gen_conf, k=to_remask, largest=False)
+                         low_global = gen_indices[local_low]
+                         x[j, low_global] = mask_token_id
+                         overtime_confidence[j, low_global] = 0.0
+
+     def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
          if is_torchdynamo_compiling():
              return
-
-         # 1. Max length warnings related to poor parameterization
          if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
-             # 20 is the default max_length of the generation config
              warnings.warn(
                  f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
                  "generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
@@ -188,14 +225,7 @@ class DreamGenerationMixin:
                  " increasing `max_length` or, better yet, setting `max_new_tokens`."
              )

-     def _prepare_generated_length(
-         self,
-         generation_config,
-         has_default_max_length,
-         input_ids_length,
-     ):
-         """Prepared max and min length in generation configs to avoid clashes between similar attributes"""
-
+     def _prepare_generated_length(self, generation_config, has_default_max_length, input_ids_length):
          if generation_config.max_new_tokens is not None:
              if not has_default_max_length and generation_config.max_length is not None:
                  logger.warning(
@@ -212,29 +242,17 @@ class DreamGenerationMixin:
          max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
          if max_position_embeddings is not None:
              generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
-
          return generation_config

-     def _prepare_generation_config(
-         self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict
-     ) -> DreamGenerationConfig:
-         """
-         Prepares the base generation config, then applies any generation configuration options from kwargs. This
-         function handles retrocompatibility with respect to configuration files.
-         """
-         # priority: `generation_config` argument > `model.generation_config` (the default generation config)
+     def _prepare_generation_config(self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict) -> DreamGenerationConfig:
          using_model_generation_config = False
          if generation_config is None:
              generation_config = DreamGenerationConfig.from_model_config(self.config)
              using_model_generation_config = True

-         # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
-         # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an
-         # exception will be raised in `_validate_model_kwargs`
          if not is_torchdynamo_compiling():
              generation_config = copy.deepcopy(generation_config)
              _kwargs = generation_config.update(**kwargs)
-             # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
              if not using_model_generation_config:
                  if generation_config.bos_token_id is None:
                      generation_config.bos_token_id = self.generation_config.bos_token_id
@@ -247,25 +265,10 @@ class DreamGenerationMixin:

          return generation_config

-     def _prepare_special_tokens(
-         self,
-         generation_config: DreamGenerationConfig,
-         device: Optional[Union[torch.device, str]] = None,
-     ):
-         """
-         Prepares the special tokens for generation, overwriting the generation config with their processed versions
-         converted to tensor.
-
-         Note that `generation_config` is changed in place and stops being serializable after this method is called.
-         That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the
-         function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
-         """
-
-         # Convert special tokens to tensors
+     def _prepare_special_tokens(self, generation_config: DreamGenerationConfig, device: Optional[Union[torch.device, str]] = None):
          def _tensor_or_none(token, device=None):
              if token is None:
                  return token
-
              device = device if device is not None else self.device
              if isinstance(token, torch.Tensor):
                  return token.to(device)
@@ -276,19 +279,13 @@ class DreamGenerationMixin:
          pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
          mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)

-         # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
          if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
              eos_token_tensor = eos_token_tensor.unsqueeze(0)

-         # Set pad token if unset (and there are conditions to do so)
          if pad_token_tensor is None and eos_token_tensor is not None:
              pad_token_tensor = eos_token_tensor[0]
              logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")

-         # Update generation config with the updated special tokens tensors
-         # NOTE: this must be written into a different attribute name than the one holding the original special tokens
-         # (in their non-tensor form), in order to enable end-to-end compilation. See
-         # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
          generation_config._bos_token_tensor = bos_token_tensor
          generation_config._eos_token_tensor = eos_token_tensor
          generation_config._pad_token_tensor = pad_token_tensor
@@ -301,19 +298,16 @@ class DreamGenerationMixin:
          generation_config: Optional[DreamGenerationConfig] = None,
          **kwargs,
      ) -> Union[DreamModelOutput, torch.LongTensor]:
-         # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
          generation_config = self._prepare_generation_config(generation_config, **kwargs)
          generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
          generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)

-         # 2. Define model inputs
          assert inputs is not None
          input_ids = inputs
          device = input_ids.device
          attention_mask = kwargs.pop("attention_mask", None)
          self._prepare_special_tokens(generation_config, device=device)

-         # 3. Prepare `max_length`.
          input_ids_length = input_ids.shape[-1]
          has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
          generation_config = self._prepare_generated_length(
@@ -323,8 +317,7 @@ class DreamGenerationMixin:
          )

          self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
-
-         # 4. Check input_ids
+
          if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
              warnings.warn(
                  "You are calling .generate() with the `input_ids` being on a device type different"
@@ -336,9 +329,9 @@ class DreamGenerationMixin:
                  UserWarning,
              )
          if (
-             hasattr(generation_config, "pad_token_id") and
-             torch.any(input_ids == generation_config.pad_token_id) and
-             attention_mask is None
+             hasattr(generation_config, "pad_token_id")
+             and torch.any(input_ids == generation_config.pad_token_id)
+             and attention_mask is None
          ):
              warnings.warn(
                  "Padding was detected but no attention mask is passed here. For correct "
@@ -349,7 +342,7 @@ class DreamGenerationMixin:
          input_ids, attention_mask = self._expand_inputs_for_generation(
              expand_size=generation_config.num_return_sequences,
              input_ids=input_ids,
-             attention_mask=attention_mask
+             attention_mask=attention_mask,
          )

          result = self._sample(
@@ -357,7 +350,7 @@ class DreamGenerationMixin:
              attention_mask=attention_mask,
              generation_config=generation_config,
              generation_tokens_hook_func=generation_tokens_hook_func,
-             generation_logits_hook_func=generation_logits_hook_func
+             generation_logits_hook_func=generation_logits_hook_func,
          )
          return result

@@ -369,7 +362,7 @@ class DreamGenerationMixin:
          generation_tokens_hook_func,
          generation_logits_hook_func
      ) -> Union[DreamModelOutput, torch.LongTensor]:
-         # init values
+         # === Original variables ===
          output_history = generation_config.output_history
          return_dict_in_generate = generation_config.return_dict_in_generate
          max_length = generation_config.max_length
@@ -382,18 +375,19 @@ class DreamGenerationMixin:
          top_p = generation_config.top_p
          top_k = generation_config.top_k

+         # === New: RCR control variables (they do not affect the rcr=False path) ===
+         rcr = generation_config.rcr
+         conf_alg = generation_config.conf_alg
+
          histories = [] if (return_dict_in_generate and output_history) else None

          # pad input_ids to max_length
          x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)

          if attention_mask is not None and torch.any(attention_mask == 0.0):
-             # we do not mask the [MASK] tokens so value = 1.0
              attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
              tok_idx = attention_mask.long().cumsum(-1) - 1
              tok_idx.masked_fill_(attention_mask == 0, 1)
-             # attention_mask is of shape [B, N]
-             # broadcast to [B, 1, N, N]
              attention_mask = torch.logical_and(
                  attention_mask.unsqueeze(1).unsqueeze(-2),
                  attention_mask.unsqueeze(1).unsqueeze(-1),
@@ -404,12 +398,15 @@ class DreamGenerationMixin:

          timesteps = torch.linspace(1, eps, steps + 1, device=x.device)

+         # === Allocate the overtime-confidence buffer only when rcr=True (baseline unaffected) ===
+         overtime_confidence = torch.zeros_like(x, dtype=torch.float32) if rcr else None
+
          # this allows user-defined token control of the intermediate steps
          x = generation_tokens_hook_func(None, x, None)
          for i in range(steps):
              mask_index = (x == mask_token_id)
              logits = self(x, attention_mask, tok_idx).logits
-             logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
+             logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)

              # this allows user-defined logits control of the intermediate steps
              logits = generation_logits_hook_func(i, x, logits)
@@ -417,48 +414,82 @@ class DreamGenerationMixin:
              mask_logits = logits[mask_index]
              t = timesteps[i]
              s = timesteps[i + 1]
-
+
              if alg == 'origin':
+                 # === Original 'origin' branch: unchanged ===
                  p_transfer = 1 - s / t if i < steps - 1 else 1
                  x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
                  transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
-                 _, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
+                 _, x0[transfer_index_t_s] = sample_tokens(
+                     mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k
+                 )
                  x[mask_index] = x0.clone()
              else:
-                 if alg == 'maskgit_plus':
+                 # === Confidence-based (non-origin) branch ===
+                 # rcr=False: keep the original confidence algorithm selected by alg
+                 # rcr=True : use the confidence algorithm given by conf_alg (rcr=False behaviour is unchanged)
+                 if (not rcr and alg == 'maskgit_plus') or (rcr and conf_alg == 'maskgit_plus'):
                      confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
-                 elif alg == 'topk_margin':
-                     confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
-                 elif alg == 'entropy':
-                     confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
+                 elif (not rcr and alg == 'topk_margin') or (rcr and conf_alg == 'topk_margin'):
+                     confidence, x0 = sample_tokens(
+                         mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True
+                     )
+                 elif (not rcr and alg == 'entropy') or (rcr and conf_alg == 'entropy'):
+                     confidence, x0 = sample_tokens(
+                         mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True
+                     )
                  else:
-                     raise RuntimeError(f"Unknown alg: {alg}")
-                 num_mask_token = mask_index.sum() / mask_index.shape[0]
-                 number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
-                 full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
-                 full_confidence[mask_index] = confidence
-                 if number_transfer_tokens > 0:
-                     if alg_temp is None or alg_temp == 0:
-                         _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
+                     # Compatibility: if rcr=True but conf_alg is none of the above, fall back to the algorithm given by alg
+                     if rcr:
+                         if alg == 'maskgit_plus':
+                             confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
+                         elif alg == 'topk_margin':
+                             confidence, x0 = sample_tokens(
+                                 mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True
+                             )
+                         elif alg == 'entropy':
+                             confidence, x0 = sample_tokens(
+                                 mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True
+                             )
+                         else:
+                             raise RuntimeError(f"Unknown alg: {alg}")
                      else:
-                         full_confidence = full_confidence / alg_temp
-                         full_confidence = F.softmax(full_confidence, dim=-1)
-                         transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
-                     x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
-                     x_[mask_index] = x0.clone()
-                     row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
-                     x[row_indices,transfer_index] = x_[row_indices,transfer_index]
+                         raise RuntimeError(f"Unknown alg: {alg}")
+
+                 if rcr:
+                     # === Only when rcr=True: apply RCR; the baseline branch implementation is untouched ===
+                     self._apply_rcr_logic(
+                         x, x0, confidence, mask_index, overtime_confidence,
+                         mask_token_id, i, steps, s, t
+                     )
+                 else:
+                     # === Original Dream logic: unchanged (including details such as device=self.device) ===
+                     num_mask_token = mask_index.sum() / mask_index.shape[0]
+                     number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
+                     full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
+                     full_confidence[mask_index] = confidence
+                     if number_transfer_tokens > 0:
+                         if alg_temp is None or alg_temp == 0:
+                             _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
+                         else:
+                             full_confidence = full_confidence / alg_temp
+                             full_confidence = F.softmax(full_confidence, dim=-1)
+                             transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
+                         x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
+                         x_[mask_index] = x0.clone()
+                         row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
+                         x[row_indices, transfer_index] = x_[row_indices, transfer_index]

              # this allows user-defined token control of the intermediate steps
              x = generation_tokens_hook_func(i, x, logits)

              if histories is not None:
                  histories.append(x.clone())
-
+
              if return_dict_in_generate:
                  return DreamModelOutput(
                      sequences=x,
                      history=histories,
                  )
              else:
-                 return x
+                 return x
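
Usage note (not part of the commit): the new `rcr` and `conf_alg` options flow from `DreamGenerationConfig` through `_prepare_generation_config` into `_sample`, so they can be passed as ordinary generation kwargs; when `rcr` is left at its default of False the sampler behaves exactly as before. The sketch below illustrates this under stated assumptions: the checkpoint id and the `diffusion_generate` entry-point name are not shown in this diff and are assumptions, while `rcr`, `conf_alg`, `alg`, `steps`, `temperature`, and `max_new_tokens` come from the code above.

    # Hedged usage sketch; the model id and diffusion_generate() are assumptions,
    # only the generation kwargs are defined by this file.
    import torch
    from transformers import AutoModel, AutoTokenizer

    model_id = "Dream-org/Dream-v0-Instruct-7B"  # assumed checkpoint name
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_id, torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda").eval()

    input_ids = tokenizer("Explain diffusion language models in one paragraph.", return_tensors="pt").input_ids.to("cuda")

    # Baseline path: rcr defaults to False, so behaviour matches the original Dream sampler.
    baseline = model.diffusion_generate(input_ids, max_new_tokens=128, steps=128, temperature=0.2, alg="entropy")

    # RCR path: track running confidence and remask low-confidence tokens back to [MASK].
    rcr_out = model.diffusion_generate(
        input_ids,
        max_new_tokens=128,
        steps=128,
        temperature=0.2,
        alg="entropy",            # still consulted by the fallback branch
        rcr=True,                 # new flag added by this commit
        conf_alg="maskgit_plus",  # confidence algorithm used only when rcr=True
    )
    print(tokenizer.decode(rcr_out[0, input_ids.shape[1]:], skip_special_tokens=True))

Because `_prepare_generation_config` folds unknown kwargs into the config via `generation_config.update(**kwargs)`, the same flags can instead be set on a `DreamGenerationConfig` instance and passed as `generation_config=`.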