Update generation_utils.py
Browse files- generation_utils.py +177 -144
generation_utils.py
CHANGED
|
@@ -1,6 +1,17 @@
|
|
| 1 |
# coding=utf-8
|
| 2 |
-
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace
|
| 3 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
import warnings
|
| 6 |
import copy
|
|
@@ -23,32 +34,30 @@ def top_p_logits(logits, top_p=None):
|
|
| 23 |
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 24 |
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 25 |
sorted_indices_to_remove = cumulative_probs > top_p
|
| 26 |
-
# keep first token above threshold
|
| 27 |
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 28 |
sorted_indices_to_remove[..., 0] = 0
|
|
|
|
| 29 |
mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
|
| 30 |
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
|
| 31 |
-
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
def top_k_logits(logits, top_k=None):
|
| 35 |
if top_k is None:
|
| 36 |
return logits
|
| 37 |
-
top_k = min(top_k, logits.size(-1))
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
margin_confidence: bool = False,
|
| 49 |
-
neg_entropy: bool = False,
|
| 50 |
-
):
|
| 51 |
-
# 保持 dtype 与 logits 一致(包含 bf16/fp16)
|
| 52 |
if temperature and temperature > 0:
|
| 53 |
logits = logits / temperature
|
| 54 |
if top_p is not None and top_p < 1:
|
|
@@ -58,27 +67,27 @@ def sample_tokens(
|
|
| 58 |
|
| 59 |
probs = torch.softmax(logits, dim=-1)
|
| 60 |
|
|
|
|
| 61 |
if temperature and temperature > 0:
|
| 62 |
-
# 采样
|
| 63 |
try:
|
| 64 |
x0 = dists.Categorical(probs=probs).sample()
|
| 65 |
confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
|
| 66 |
except Exception:
|
| 67 |
confidence, x0 = probs.max(dim=-1)
|
| 68 |
else:
|
| 69 |
-
# 贪心
|
| 70 |
confidence, x0 = probs.max(dim=-1)
|
| 71 |
|
|
|
|
| 72 |
if margin_confidence:
|
| 73 |
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
|
| 74 |
-
top1_probs = sorted_probs[
|
| 75 |
-
top2_probs = sorted_probs[
|
| 76 |
confidence = top1_probs - top2_probs
|
| 77 |
|
| 78 |
if neg_entropy:
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
confidence = torch.sum(probs * log_probs, dim=-1)
|
| 83 |
|
| 84 |
return confidence, x0
|
|
@@ -92,23 +101,27 @@ class DreamModelOutput(ModelOutput):
|
|
| 92 |
|
| 93 |
class DreamGenerationConfig(GenerationConfig):
|
| 94 |
def __init__(self, **kwargs):
|
|
|
|
| 95 |
self.temperature: float = kwargs.pop("temperature", 0.0)
|
| 96 |
self.top_p: Optional[float] = kwargs.pop("top_p", None)
|
| 97 |
self.top_k: Optional[int] = kwargs.pop("top_k", None)
|
|
|
|
|
|
|
| 98 |
self.max_length = kwargs.pop("max_length", 20)
|
| 99 |
self.max_new_tokens = kwargs.pop("max_new_tokens", None)
|
| 100 |
|
| 101 |
# diffusion specific params
|
| 102 |
self.eps: float = kwargs.pop("eps", 1e-3)
|
| 103 |
self.steps: int = kwargs.pop("steps", 512)
|
| 104 |
-
self.alg: str = kwargs.pop("alg",
|
| 105 |
self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
|
| 106 |
|
| 107 |
-
# RCR
|
| 108 |
self.rcr: bool = kwargs.pop("rcr", False)
|
| 109 |
-
|
|
|
|
| 110 |
|
| 111 |
-
# generate
|
| 112 |
self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
|
| 113 |
self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
|
| 114 |
self.output_history: bool = kwargs.pop("output_history", False)
|
|
@@ -119,9 +132,10 @@ class DreamGenerationConfig(GenerationConfig):
|
|
| 119 |
self.bos_token_id = kwargs.pop("bos_token_id", None)
|
| 120 |
self.eos_token_id = kwargs.pop("eos_token_id", None)
|
| 121 |
|
|
|
|
| 122 |
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
|
| 123 |
|
| 124 |
-
#
|
| 125 |
self._from_model_config = kwargs.pop("_from_model_config", False)
|
| 126 |
self._commit_hash = kwargs.pop("_commit_hash", None)
|
| 127 |
self.transformers_version = kwargs.pop("transformers_version", __version__)
|
|
@@ -137,7 +151,6 @@ class DreamGenerationConfig(GenerationConfig):
|
|
| 137 |
self.validate(is_init=True)
|
| 138 |
|
| 139 |
def validate(self, is_init=False):
|
| 140 |
-
# 保留空实现,兼容 upstream
|
| 141 |
pass
|
| 142 |
|
| 143 |
|
|
@@ -146,7 +159,7 @@ class DreamGenerationMixin:
|
|
| 146 |
def _expand_inputs_for_generation(
|
| 147 |
expand_size: int = 1,
|
| 148 |
input_ids: Optional[torch.LongTensor] = None,
|
| 149 |
-
attention_mask: Optional[torch.LongTensor] = None
|
| 150 |
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
|
| 151 |
if expand_size == 1:
|
| 152 |
return input_ids, attention_mask
|
|
@@ -156,13 +169,16 @@ class DreamGenerationMixin:
|
|
| 156 |
attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
|
| 157 |
return input_ids, attention_mask
|
| 158 |
|
|
|
|
| 159 |
def _apply_rcr_logic(
|
| 160 |
self,
|
| 161 |
-
x: torch.
|
| 162 |
-
|
| 163 |
-
|
| 164 |
mask_index: torch.Tensor,
|
| 165 |
-
|
|
|
|
|
|
|
| 166 |
mask_token_id: int,
|
| 167 |
step: int,
|
| 168 |
total_steps: int,
|
|
@@ -170,56 +186,68 @@ class DreamGenerationMixin:
|
|
| 170 |
t: torch.Tensor,
|
| 171 |
):
|
| 172 |
"""
|
| 173 |
-
Running Confidence Remasking
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
"""
|
| 179 |
device = x.device
|
| 180 |
-
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
-
#
|
| 184 |
-
|
| 185 |
-
# 本步的转移数量(与 Dream 调度一致)
|
| 186 |
-
number_transfer_tokens = int(num_mask_token * (1 - s / t)) if step < total_steps - 1 else int(num_mask_token)
|
| 187 |
|
| 188 |
-
#
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
|
|
|
|
| 194 |
for j in range(B):
|
| 195 |
masked_j = int(mask_index[j].sum().item())
|
| 196 |
-
if masked_j == 0:
|
| 197 |
-
continue
|
| 198 |
k_j = min(number_transfer_tokens, masked_j)
|
| 199 |
-
|
| 200 |
if k_j > 0:
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
#
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
if step < total_steps - 1:
|
| 209 |
-
target_cum = int(
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
|
| 224 |
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
|
| 225 |
if is_torchdynamo_compiling():
|
|
@@ -232,9 +260,11 @@ class DreamGenerationMixin:
|
|
| 232 |
UserWarning,
|
| 233 |
)
|
| 234 |
if input_ids_length >= generation_config.max_length:
|
|
|
|
| 235 |
raise ValueError(
|
| 236 |
-
f"Input length is {input_ids_length}, but `max_length` is
|
| 237 |
-
"
|
|
|
|
| 238 |
)
|
| 239 |
|
| 240 |
def _prepare_generated_length(self, generation_config, has_default_max_length, input_ids_length):
|
|
@@ -242,20 +272,20 @@ class DreamGenerationMixin:
|
|
| 242 |
if not has_default_max_length and generation_config.max_length is not None:
|
| 243 |
logger.warning(
|
| 244 |
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
| 245 |
-
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence."
|
|
|
|
| 246 |
)
|
| 247 |
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
|
|
|
|
| 248 |
elif has_default_max_length:
|
| 249 |
if generation_config.max_length == DreamGenerationConfig().max_length:
|
| 250 |
generation_config.max_length = generation_config.max_length + input_ids_length
|
| 251 |
-
|
| 252 |
-
if
|
| 253 |
-
generation_config.max_length = min(generation_config.max_length,
|
| 254 |
return generation_config
|
| 255 |
|
| 256 |
-
def _prepare_generation_config(
|
| 257 |
-
self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict
|
| 258 |
-
) -> DreamGenerationConfig:
|
| 259 |
using_model_generation_config = False
|
| 260 |
if generation_config is None:
|
| 261 |
generation_config = DreamGenerationConfig.from_model_config(self.config)
|
|
@@ -273,11 +303,10 @@ class DreamGenerationMixin:
|
|
| 273 |
generation_config.pad_token_id = self.generation_config.pad_token_id
|
| 274 |
if generation_config.mask_token_id is None:
|
| 275 |
generation_config.mask_token_id = self.generation_config.mask_token_id
|
|
|
|
| 276 |
return generation_config
|
| 277 |
|
| 278 |
-
def _prepare_special_tokens(
|
| 279 |
-
self, generation_config: DreamGenerationConfig, device: Optional[Union[torch.device, str]] = None
|
| 280 |
-
):
|
| 281 |
def _tensor_or_none(token, device=None):
|
| 282 |
if token is None:
|
| 283 |
return token
|
|
@@ -332,18 +361,22 @@ class DreamGenerationMixin:
|
|
| 332 |
|
| 333 |
if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
|
| 334 |
warnings.warn(
|
| 335 |
-
"You are calling .generate() with `input_ids` on a device type different
|
| 336 |
-
f"`input_ids` is on {input_ids.device.type},
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
UserWarning,
|
| 338 |
)
|
| 339 |
-
|
| 340 |
if (
|
| 341 |
hasattr(generation_config, "pad_token_id")
|
| 342 |
and torch.any(input_ids == generation_config.pad_token_id)
|
| 343 |
and attention_mask is None
|
| 344 |
):
|
| 345 |
warnings.warn(
|
| 346 |
-
"Padding was detected but no attention mask is passed. For correct
|
|
|
|
| 347 |
UserWarning,
|
| 348 |
)
|
| 349 |
|
|
@@ -368,9 +401,9 @@ class DreamGenerationMixin:
|
|
| 368 |
attention_mask: Optional[torch.LongTensor],
|
| 369 |
generation_config: DreamGenerationConfig,
|
| 370 |
generation_tokens_hook_func,
|
| 371 |
-
generation_logits_hook_func
|
| 372 |
) -> Union[DreamModelOutput, torch.LongTensor]:
|
| 373 |
-
#
|
| 374 |
output_history = generation_config.output_history
|
| 375 |
return_dict_in_generate = generation_config.return_dict_in_generate
|
| 376 |
max_length = generation_config.max_length
|
|
@@ -383,13 +416,13 @@ class DreamGenerationMixin:
|
|
| 383 |
top_p = generation_config.top_p
|
| 384 |
top_k = generation_config.top_k
|
| 385 |
|
| 386 |
-
# RCR
|
| 387 |
rcr = generation_config.rcr
|
| 388 |
conf_alg = generation_config.conf_alg
|
| 389 |
|
| 390 |
histories = [] if (return_dict_in_generate and output_history) else None
|
| 391 |
|
| 392 |
-
# pad
|
| 393 |
x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
|
| 394 |
|
| 395 |
if attention_mask is not None and torch.any(attention_mask == 0.0):
|
|
@@ -406,104 +439,104 @@ class DreamGenerationMixin:
|
|
| 406 |
|
| 407 |
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
|
| 408 |
|
| 409 |
-
#
|
| 410 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 411 |
|
| 412 |
-
#
|
| 413 |
x = generation_tokens_hook_func(None, x, None)
|
| 414 |
|
| 415 |
for i in range(steps):
|
| 416 |
mask_index = (x == mask_token_id)
|
| 417 |
-
|
| 418 |
logits = self(x, attention_mask, tok_idx).logits
|
|
|
|
| 419 |
logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
|
| 420 |
|
| 421 |
-
# 允许用户控制中间 logits
|
| 422 |
logits = generation_logits_hook_func(i, x, logits)
|
| 423 |
|
| 424 |
mask_logits = logits[mask_index]
|
| 425 |
t = timesteps[i]
|
| 426 |
s = timesteps[i + 1]
|
| 427 |
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
overtime_confidence = torch.zeros_like(x, dtype=logits.dtype, device=x.device)
|
| 431 |
-
|
| 432 |
-
if alg == "origin":
|
| 433 |
-
# 原始 Dream 逻辑(不动)
|
| 434 |
p_transfer = 1 - s / t if i < steps - 1 else 1
|
| 435 |
-
x0 = torch.
|
| 436 |
transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
|
| 437 |
_, x0[transfer_index_t_s] = sample_tokens(
|
| 438 |
mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k
|
| 439 |
)
|
| 440 |
x[mask_index] = x0.clone()
|
| 441 |
-
|
| 442 |
else:
|
| 443 |
-
#
|
| 444 |
-
use_alg = alg
|
| 445 |
-
if
|
| 446 |
-
# rcr=True 时,置信度算法由 conf_alg 决定(不影响 baseline)
|
| 447 |
-
use_alg = conf_alg
|
| 448 |
-
|
| 449 |
-
if use_alg == "maskgit_plus":
|
| 450 |
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
|
| 451 |
-
elif use_alg ==
|
| 452 |
confidence, x0 = sample_tokens(
|
| 453 |
mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True
|
| 454 |
)
|
| 455 |
-
elif use_alg ==
|
| 456 |
confidence, x0 = sample_tokens(
|
| 457 |
-
mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True
|
| 458 |
)
|
| 459 |
else:
|
| 460 |
-
raise RuntimeError(f"Unknown alg: {
|
| 461 |
-
|
| 462 |
-
# 统一 full_confidence 的 dtype = logits.dtype(避免 int/float 混合)
|
| 463 |
-
full_confidence = torch.full(
|
| 464 |
-
x.shape, float("-inf"), device=self.device, dtype=logits.dtype
|
| 465 |
-
)
|
| 466 |
-
full_confidence[mask_index] = confidence
|
| 467 |
|
| 468 |
if rcr:
|
| 469 |
-
# === RCR
|
| 470 |
self._apply_rcr_logic(
|
| 471 |
x=x,
|
| 472 |
-
|
| 473 |
-
|
| 474 |
mask_index=mask_index,
|
| 475 |
-
|
|
|
|
|
|
|
| 476 |
mask_token_id=mask_token_id,
|
| 477 |
step=i,
|
| 478 |
total_steps=steps,
|
| 479 |
-
s=s,
|
| 480 |
-
t=t,
|
| 481 |
)
|
| 482 |
else:
|
| 483 |
-
# ===
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 488 |
if number_transfer_tokens > 0:
|
| 489 |
if alg_temp is None or alg_temp == 0:
|
| 490 |
_, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
|
| 491 |
else:
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
transfer_index = torch.multinomial(
|
| 495 |
-
x_ = torch.
|
| 496 |
x_[mask_index] = x0.clone()
|
| 497 |
row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
|
| 498 |
x[row_indices, transfer_index] = x_[row_indices, transfer_index]
|
| 499 |
|
| 500 |
-
# 允许用户控制中间 tokens
|
| 501 |
x = generation_tokens_hook_func(i, x, logits)
|
| 502 |
|
| 503 |
if histories is not None:
|
| 504 |
histories.append(x.clone())
|
| 505 |
|
| 506 |
if return_dict_in_generate:
|
| 507 |
-
return DreamModelOutput(
|
|
|
|
|
|
|
|
|
|
| 508 |
else:
|
| 509 |
return x
|
|
|
|
| 1 |
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# You may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
|
| 16 |
import warnings
|
| 17 |
import copy
|
|
|
|
| 34 |
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 35 |
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 36 |
sorted_indices_to_remove = cumulative_probs > top_p
|
| 37 |
+
# Shift the indices to the right to keep the first token above the threshold
|
| 38 |
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 39 |
sorted_indices_to_remove[..., 0] = 0
|
| 40 |
+
|
| 41 |
mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
|
| 42 |
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
|
| 43 |
+
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
|
| 44 |
+
return logits
|
| 45 |
|
| 46 |
|
| 47 |
def top_k_logits(logits, top_k=None):
|
| 48 |
if top_k is None:
|
| 49 |
return logits
|
| 50 |
+
top_k = int(min(top_k, logits.size(-1))) # Safety check
|
| 51 |
+
if top_k <= 0:
|
| 52 |
+
return logits
|
| 53 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
| 54 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
| 55 |
+
logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
|
| 56 |
+
return logits
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
|
| 60 |
+
# logits: [N, V]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
if temperature and temperature > 0:
|
| 62 |
logits = logits / temperature
|
| 63 |
if top_p is not None and top_p < 1:
|
|
|
|
| 67 |
|
| 68 |
probs = torch.softmax(logits, dim=-1)
|
| 69 |
|
| 70 |
+
# 采样/贪心
|
| 71 |
if temperature and temperature > 0:
|
|
|
|
| 72 |
try:
|
| 73 |
x0 = dists.Categorical(probs=probs).sample()
|
| 74 |
confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
|
| 75 |
except Exception:
|
| 76 |
confidence, x0 = probs.max(dim=-1)
|
| 77 |
else:
|
|
|
|
| 78 |
confidence, x0 = probs.max(dim=-1)
|
| 79 |
|
| 80 |
+
# 置信度定义切换
|
| 81 |
if margin_confidence:
|
| 82 |
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
|
| 83 |
+
top1_probs = sorted_probs[:, 0]
|
| 84 |
+
top2_probs = sorted_probs[:, 1]
|
| 85 |
confidence = top1_probs - top2_probs
|
| 86 |
|
| 87 |
if neg_entropy:
|
| 88 |
+
# 负熵(数值 ≤ 0;越接近 0 越大代表越确定)
|
| 89 |
+
epsilon = 1e-10
|
| 90 |
+
log_probs = torch.log(probs + epsilon)
|
| 91 |
confidence = torch.sum(probs * log_probs, dim=-1)
|
| 92 |
|
| 93 |
return confidence, x0
|
|
|
|
| 101 |
|
| 102 |
class DreamGenerationConfig(GenerationConfig):
|
| 103 |
def __init__(self, **kwargs):
|
| 104 |
+
# sampling
|
| 105 |
self.temperature: float = kwargs.pop("temperature", 0.0)
|
| 106 |
self.top_p: Optional[float] = kwargs.pop("top_p", None)
|
| 107 |
self.top_k: Optional[int] = kwargs.pop("top_k", None)
|
| 108 |
+
|
| 109 |
+
# length
|
| 110 |
self.max_length = kwargs.pop("max_length", 20)
|
| 111 |
self.max_new_tokens = kwargs.pop("max_new_tokens", None)
|
| 112 |
|
| 113 |
# diffusion specific params
|
| 114 |
self.eps: float = kwargs.pop("eps", 1e-3)
|
| 115 |
self.steps: int = kwargs.pop("steps", 512)
|
| 116 |
+
self.alg: str = kwargs.pop("alg", 'origin') # 'origin' | 'maskgit_plus' | 'topk_margin' | 'entropy'
|
| 117 |
self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
|
| 118 |
|
| 119 |
+
# === RCR 参数(新增,默认关闭,不影响原逻辑) ===
|
| 120 |
self.rcr: bool = kwargs.pop("rcr", False)
|
| 121 |
+
# 仅在 rcr=True 时用于选择置信度算法;rcr=False 不读取它
|
| 122 |
+
self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')
|
| 123 |
|
| 124 |
+
# generate outputs
|
| 125 |
self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
|
| 126 |
self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
|
| 127 |
self.output_history: bool = kwargs.pop("output_history", False)
|
|
|
|
| 132 |
self.bos_token_id = kwargs.pop("bos_token_id", None)
|
| 133 |
self.eos_token_id = kwargs.pop("eos_token_id", None)
|
| 134 |
|
| 135 |
+
# misc
|
| 136 |
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
|
| 137 |
|
| 138 |
+
# bookkeeping
|
| 139 |
self._from_model_config = kwargs.pop("_from_model_config", False)
|
| 140 |
self._commit_hash = kwargs.pop("_commit_hash", None)
|
| 141 |
self.transformers_version = kwargs.pop("transformers_version", __version__)
|
|
|
|
| 151 |
self.validate(is_init=True)
|
| 152 |
|
| 153 |
def validate(self, is_init=False):
|
|
|
|
| 154 |
pass
|
| 155 |
|
| 156 |
|
|
|
|
| 159 |
def _expand_inputs_for_generation(
|
| 160 |
expand_size: int = 1,
|
| 161 |
input_ids: Optional[torch.LongTensor] = None,
|
| 162 |
+
attention_mask: Optional[torch.LongTensor] = None
|
| 163 |
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
|
| 164 |
if expand_size == 1:
|
| 165 |
return input_ids, attention_mask
|
|
|
|
| 169 |
attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
|
| 170 |
return input_ids, attention_mask
|
| 171 |
|
| 172 |
+
# === 新版:RCR 核心(历史置信度) ===
|
| 173 |
def _apply_rcr_logic(
|
| 174 |
self,
|
| 175 |
+
x: torch.Tensor,
|
| 176 |
+
x0: torch.Tensor,
|
| 177 |
+
conf_now: torch.Tensor,
|
| 178 |
mask_index: torch.Tensor,
|
| 179 |
+
fixed_conf: torch.Tensor,
|
| 180 |
+
gen_mask: torch.Tensor,
|
| 181 |
+
init_mask_count: torch.Tensor,
|
| 182 |
mask_token_id: int,
|
| 183 |
step: int,
|
| 184 |
total_steps: int,
|
|
|
|
| 186 |
t: torch.Tensor,
|
| 187 |
):
|
| 188 |
"""
|
| 189 |
+
Running Confidence Remasking(历史置信度版):
|
| 190 |
+
1) 在 mask 子集内以当步置信度 conf_now 选择 top-k_j 个位置“确认”(写 token);
|
| 191 |
+
2) 更新历史置信度 fixed_conf = max(fixed_conf, conf_now)(仅对新选入位置);
|
| 192 |
+
3) 按“累计允许确认配额” target_cum = init_mask_count * (1 - s/t) 若超额,
|
| 193 |
+
在已确认集合 gen_mask 内按 fixed_conf 最低回遮 over 个位置。
|
| 194 |
+
|
| 195 |
+
说明:
|
| 196 |
+
- conf_now 用 float32 维护,避免与 bfloat16 混写导致 dtype 报错;
|
| 197 |
+
- 对 entropy:conf_now = 负熵(≤0 且越接近 0 越大代表越确定),配合 topk(largest=True) 没问题。
|
| 198 |
"""
|
| 199 |
device = x.device
|
| 200 |
+
B, L = x.shape
|
| 201 |
+
|
| 202 |
+
# 计算“当步”选入规模(与 vanilla 同口径:平均剩余 mask * (1 - s/t))
|
| 203 |
+
avg_mask_now = (mask_index.sum().item() / max(1, mask_index.shape[0]))
|
| 204 |
+
ratio = (1.0 - (s.item() / t.item())) if step < total_steps - 1 else 1.0
|
| 205 |
+
number_transfer_tokens = int(avg_mask_now * ratio)
|
| 206 |
|
| 207 |
+
# 确保当步置信度是 float32
|
| 208 |
+
conf_now = conf_now.to(torch.float32)
|
|
|
|
|
|
|
| 209 |
|
| 210 |
+
# 仅在 mask 处有效的“全长”视图
|
| 211 |
+
full_conf_now = torch.full((B, L), float("-inf"), dtype=torch.float32, device=device)
|
| 212 |
+
full_x0 = torch.full((B, L), mask_token_id, dtype=torch.long, device=device)
|
| 213 |
+
full_conf_now[mask_index] = conf_now
|
| 214 |
+
full_x0[mask_index] = x0
|
| 215 |
|
| 216 |
+
# 逐样本处理
|
| 217 |
for j in range(B):
|
| 218 |
masked_j = int(mask_index[j].sum().item())
|
|
|
|
|
|
|
| 219 |
k_j = min(number_transfer_tokens, masked_j)
|
|
|
|
| 220 |
if k_j > 0:
|
| 221 |
+
conf_row = full_conf_now[j] # float32
|
| 222 |
+
# 选当步 top-k_j
|
| 223 |
+
_, sel_idx = torch.topk(conf_row, k=k_j, largest=True)
|
| 224 |
+
# 写 token & 标记确认
|
| 225 |
+
x[j, sel_idx] = full_x0[j, sel_idx]
|
| 226 |
+
gen_mask[j, sel_idx] = True
|
| 227 |
+
# 历史置信度取 running max
|
| 228 |
+
fixed_conf[j, sel_idx] = torch.maximum(fixed_conf[j, sel_idx], conf_row[sel_idx])
|
| 229 |
+
|
| 230 |
+
# 累计允许确认配额(以初始 mask 为基数)
|
| 231 |
+
init_m = int(init_mask_count[j].item())
|
| 232 |
if step < total_steps - 1:
|
| 233 |
+
target_cum = int(init_m * (1.0 - (s.item() / t.item())))
|
| 234 |
+
else:
|
| 235 |
+
target_cum = init_m # 最后一步允许全确认
|
| 236 |
+
|
| 237 |
+
current_gen = int(gen_mask[j].sum().item())
|
| 238 |
+
over = max(0, current_gen - target_cum)
|
| 239 |
+
if over > 0:
|
| 240 |
+
# 在已确认集合里按历史置信度最低回遮
|
| 241 |
+
gen_idx = torch.where(gen_mask[j])[0]
|
| 242 |
+
if gen_idx.numel() > 0:
|
| 243 |
+
hist_vals = fixed_conf[j, gen_idx] # float32
|
| 244 |
+
over = min(over, int(gen_idx.numel()))
|
| 245 |
+
_, low_local = torch.topk(hist_vals, k=over, largest=False)
|
| 246 |
+
low_global = gen_idx[low_local]
|
| 247 |
+
# 回遮:恢复为 MASK,并撤销确认标记 & 清空历史置信度
|
| 248 |
+
x[j, low_global] = mask_token_id
|
| 249 |
+
gen_mask[j, low_global] = False
|
| 250 |
+
fixed_conf[j, low_global] = float("-inf")
|
| 251 |
|
| 252 |
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
|
| 253 |
if is_torchdynamo_compiling():
|
|
|
|
| 260 |
UserWarning,
|
| 261 |
)
|
| 262 |
if input_ids_length >= generation_config.max_length:
|
| 263 |
+
input_ids_string = "input_ids"
|
| 264 |
raise ValueError(
|
| 265 |
+
f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
|
| 266 |
+
f" {generation_config.max_length}. You should consider increasing `max_length` or, better yet,"
|
| 267 |
+
" setting `max_new_tokens`."
|
| 268 |
)
|
| 269 |
|
| 270 |
def _prepare_generated_length(self, generation_config, has_default_max_length, input_ids_length):
|
|
|
|
| 272 |
if not has_default_max_length and generation_config.max_length is not None:
|
| 273 |
logger.warning(
|
| 274 |
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
| 275 |
+
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
| 276 |
+
"Please refer to the documentation for more information."
|
| 277 |
)
|
| 278 |
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
|
| 279 |
+
|
| 280 |
elif has_default_max_length:
|
| 281 |
if generation_config.max_length == DreamGenerationConfig().max_length:
|
| 282 |
generation_config.max_length = generation_config.max_length + input_ids_length
|
| 283 |
+
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
|
| 284 |
+
if max_position_embeddings is not None:
|
| 285 |
+
generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
|
| 286 |
return generation_config
|
| 287 |
|
| 288 |
+
def _prepare_generation_config(self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict) -> DreamGenerationConfig:
|
|
|
|
|
|
|
| 289 |
using_model_generation_config = False
|
| 290 |
if generation_config is None:
|
| 291 |
generation_config = DreamGenerationConfig.from_model_config(self.config)
|
|
|
|
| 303 |
generation_config.pad_token_id = self.generation_config.pad_token_id
|
| 304 |
if generation_config.mask_token_id is None:
|
| 305 |
generation_config.mask_token_id = self.generation_config.mask_token_id
|
| 306 |
+
|
| 307 |
return generation_config
|
| 308 |
|
| 309 |
+
def _prepare_special_tokens(self, generation_config: DreamGenerationConfig, device: Optional[Union[torch.device, str]] = None):
|
|
|
|
|
|
|
| 310 |
def _tensor_or_none(token, device=None):
|
| 311 |
if token is None:
|
| 312 |
return token
|
|
|
|
| 361 |
|
| 362 |
if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
|
| 363 |
warnings.warn(
|
| 364 |
+
"You are calling .generate() with the `input_ids` being on a device type different"
|
| 365 |
+
f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
|
| 366 |
+
f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
|
| 367 |
+
" Please make sure that you have put `input_ids` to the"
|
| 368 |
+
f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
|
| 369 |
+
" running `.generate()`.",
|
| 370 |
UserWarning,
|
| 371 |
)
|
|
|
|
| 372 |
if (
|
| 373 |
hasattr(generation_config, "pad_token_id")
|
| 374 |
and torch.any(input_ids == generation_config.pad_token_id)
|
| 375 |
and attention_mask is None
|
| 376 |
):
|
| 377 |
warnings.warn(
|
| 378 |
+
"Padding was detected but no attention mask is passed here. For correct "
|
| 379 |
+
"generation results, please set `attention_mask` when batch-padding inputs.",
|
| 380 |
UserWarning,
|
| 381 |
)
|
| 382 |
|
|
|
|
| 401 |
attention_mask: Optional[torch.LongTensor],
|
| 402 |
generation_config: DreamGenerationConfig,
|
| 403 |
generation_tokens_hook_func,
|
| 404 |
+
generation_logits_hook_func
|
| 405 |
) -> Union[DreamModelOutput, torch.LongTensor]:
|
| 406 |
+
# === 基本变量 ===
|
| 407 |
output_history = generation_config.output_history
|
| 408 |
return_dict_in_generate = generation_config.return_dict_in_generate
|
| 409 |
max_length = generation_config.max_length
|
|
|
|
| 416 |
top_p = generation_config.top_p
|
| 417 |
top_k = generation_config.top_k
|
| 418 |
|
| 419 |
+
# === RCR 控制变量 ===
|
| 420 |
rcr = generation_config.rcr
|
| 421 |
conf_alg = generation_config.conf_alg
|
| 422 |
|
| 423 |
histories = [] if (return_dict_in_generate and output_history) else None
|
| 424 |
|
| 425 |
+
# pad input_ids to max_length
|
| 426 |
x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
|
| 427 |
|
| 428 |
if attention_mask is not None and torch.any(attention_mask == 0.0):
|
|
|
|
| 439 |
|
| 440 |
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
|
| 441 |
|
| 442 |
+
# === RCR 缓冲(仅 rcr=True 时启用) ===
|
| 443 |
+
if rcr:
|
| 444 |
+
init_mask_count = (x == mask_token_id).sum(dim=1) # [B]
|
| 445 |
+
fixed_conf = torch.full(
|
| 446 |
+
x.shape, float("-inf"), dtype=torch.float32, device=x.device
|
| 447 |
+
) # 历史置信度
|
| 448 |
+
gen_mask = torch.zeros_like(x, dtype=torch.bool) # 已确认集合
|
| 449 |
+
else:
|
| 450 |
+
init_mask_count = None
|
| 451 |
+
fixed_conf = None
|
| 452 |
+
gen_mask = None
|
| 453 |
|
| 454 |
+
# hooks:允许用户中间控制
|
| 455 |
x = generation_tokens_hook_func(None, x, None)
|
| 456 |
|
| 457 |
for i in range(steps):
|
| 458 |
mask_index = (x == mask_token_id)
|
|
|
|
| 459 |
logits = self(x, attention_mask, tok_idx).logits
|
| 460 |
+
# 右移一位(Dream 原实现)
|
| 461 |
logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
|
| 462 |
|
|
|
|
| 463 |
logits = generation_logits_hook_func(i, x, logits)
|
| 464 |
|
| 465 |
mask_logits = logits[mask_index]
|
| 466 |
t = timesteps[i]
|
| 467 |
s = timesteps[i + 1]
|
| 468 |
|
| 469 |
+
if alg == 'origin':
|
| 470 |
+
# === 原版 origin:随机按比例转移(不涉及置信度) ===
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
p_transfer = 1 - s / t if i < steps - 1 else 1
|
| 472 |
+
x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
|
| 473 |
transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
|
| 474 |
_, x0[transfer_index_t_s] = sample_tokens(
|
| 475 |
mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k
|
| 476 |
)
|
| 477 |
x[mask_index] = x0.clone()
|
|
|
|
| 478 |
else:
|
| 479 |
+
# === 置信度算法选择(vanilla 与 RCR 复用此处) ===
|
| 480 |
+
use_alg = conf_alg if rcr else alg
|
| 481 |
+
if use_alg == 'maskgit_plus':
|
|
|
|
|
|
|
|
|
|
|
|
|
| 482 |
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
|
| 483 |
+
elif use_alg == 'topk_margin':
|
| 484 |
confidence, x0 = sample_tokens(
|
| 485 |
mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True
|
| 486 |
)
|
| 487 |
+
elif use_alg == 'entropy':
|
| 488 |
confidence, x0 = sample_tokens(
|
| 489 |
+
mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, neg_entropy=True
|
| 490 |
)
|
| 491 |
else:
|
| 492 |
+
raise RuntimeError(f"Unknown alg/conf_alg: {use_alg}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
|
| 494 |
if rcr:
|
| 495 |
+
# === 历史置信度版 RCR ===
|
| 496 |
self._apply_rcr_logic(
|
| 497 |
x=x,
|
| 498 |
+
x0=x0,
|
| 499 |
+
conf_now=confidence,
|
| 500 |
mask_index=mask_index,
|
| 501 |
+
fixed_conf=fixed_conf,
|
| 502 |
+
gen_mask=gen_mask,
|
| 503 |
+
init_mask_count=init_mask_count,
|
| 504 |
mask_token_id=mask_token_id,
|
| 505 |
step=i,
|
| 506 |
total_steps=steps,
|
| 507 |
+
s=s, t=t,
|
|
|
|
| 508 |
)
|
| 509 |
else:
|
| 510 |
+
# === 原版 Dream(vanilla):本步 top-k,永久确认,不回遮 ===
|
| 511 |
+
# number_transfer_tokens 基于“当前平均剩余 mask * (1 - s/t)”
|
| 512 |
+
avg_mask_now = (mask_index.sum().item() / max(1, mask_index.shape[0]))
|
| 513 |
+
ratio = (1.0 - (s.item() / t.item())) if i < steps - 1 else 1.0
|
| 514 |
+
number_transfer_tokens = int(avg_mask_now * ratio)
|
| 515 |
+
|
| 516 |
+
full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
|
| 517 |
+
full_confidence[mask_index] = confidence
|
| 518 |
+
|
| 519 |
if number_transfer_tokens > 0:
|
| 520 |
if alg_temp is None or alg_temp == 0:
|
| 521 |
_, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
|
| 522 |
else:
|
| 523 |
+
full_confidence = full_confidence / alg_temp
|
| 524 |
+
full_confidence = F.softmax(full_confidence, dim=-1)
|
| 525 |
+
transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
|
| 526 |
+
x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
|
| 527 |
x_[mask_index] = x0.clone()
|
| 528 |
row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
|
| 529 |
x[row_indices, transfer_index] = x_[row_indices, transfer_index]
|
| 530 |
|
|
|
|
| 531 |
x = generation_tokens_hook_func(i, x, logits)
|
| 532 |
|
| 533 |
if histories is not None:
|
| 534 |
histories.append(x.clone())
|
| 535 |
|
| 536 |
if return_dict_in_generate:
|
| 537 |
+
return DreamModelOutput(
|
| 538 |
+
sequences=x,
|
| 539 |
+
history=histories,
|
| 540 |
+
)
|
| 541 |
else:
|
| 542 |
return x
|