# coding=utf-8
import warnings
import copy
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.distributions as dists
from torch.nn import functional as F
from transformers import __version__
from transformers.generation.configuration_utils import GenerationConfig
from transformers.utils import ModelOutput, is_torchdynamo_compiling, logging
logger = logging.get_logger(__name__)
def _apply_top_p_k_temp(logits, temperature=0.0, top_p=None, top_k=None):
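    """Scale `logits` by `temperature`, then apply nucleus (top-p) and top-k filtering.

    Filtered entries are set to the dtype minimum so they receive zero probability after
    softmax. A `temperature` of 0 (or None) skips scaling; `top_p`/`top_k` of None disable
    the corresponding filter.
    """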
if temperature and temperature > 0:
logits = logits / temperature
if top_p is not None and top_p < 1:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
if top_k is not None:
top_k = int(min(top_k, logits.size(-1)))
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
return logits
def _confidence_from_probs(
probs: torch.Tensor, # [..., V]
chosen_ids: Optional[torch.Tensor], # [...]
mode: str # 'entropy' | 'maskgit_plus' | 'topk_margin'
) -> torch.Tensor:
"""返回“越大越自信”的标量分数,与解码一致。"""
if mode == "entropy":
eps = 1e-10
logp = torch.log(probs + eps)
return -(probs * logp).sum(dim=-1) # -H(p)
elif mode == "maskgit_plus":
        assert chosen_ids is not None, "maskgit_plus requires chosen_ids"
return torch.gather(probs, -1, chosen_ids.unsqueeze(-1)).squeeze(-1) # p(x0)
elif mode == "topk_margin":
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
return sorted_probs[..., 0] - sorted_probs[..., 1] # top1 - top2
else:
raise ValueError(f"Unknown conf mode: {mode}")
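# Illustrative example (toy numbers, not from any model): for probs = [0.70, 0.20, 0.10]
# and a chosen token id of 0, the three modes give
#   entropy      -> -H(p) ≈ -0.80
#   maskgit_plus -> p(x0) =  0.70
#   topk_margin  -> 0.70 - 0.20 = 0.50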
@dataclass
class DreamModelOutput(ModelOutput):
sequences: torch.LongTensor = None
history: Optional[Tuple[torch.FloatTensor]] = None
class DreamGenerationConfig(GenerationConfig):
def __init__(self, **kwargs):
# sampling
self.temperature: float = kwargs.pop("temperature", 0.0)
self.top_p: Optional[float] = kwargs.pop("top_p", None)
self.top_k: Optional[int] = kwargs.pop("top_k", None)
# length
self.max_length = kwargs.pop("max_length", 20)
self.max_new_tokens = kwargs.pop("max_new_tokens", None)
# diffusion
self.eps: float = kwargs.pop("eps", 1e-3)
self.steps: int = kwargs.pop("steps", 512)
        # scoring algorithm for vanilla decoding (used when rcr=False)
self.alg: str = kwargs.pop("alg", 'maskgit_plus') # 'origin' | 'maskgit_plus' | 'topk_margin' | 'entropy'
self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
# === RCR ===
self.rcr: bool = kwargs.pop("rcr", False)
        # confidence definition used consistently for decoding and the history scores when rcr=True
self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus') # 'maskgit_plus' | 'topk_margin' | 'entropy'
        # Note: the two fields below are always overridden inside _sample, which hard-codes the window to 1/4..3/4 of the total steps
self.rcr_start_step: int = kwargs.pop("rcr_start_step", 0)
self.rcr_end_step: int = kwargs.pop("rcr_end_step", None) or self.steps
        # whether to protect tokens written in the current step from being re-masked
self.rcr_protect_current_step: bool = kwargs.pop("rcr_protect_current_step", False)
# outputs
self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
self.output_history: bool = kwargs.pop("output_history", False)
# special tokens
self.mask_token_id = kwargs.pop("mask_token_id", None)
self.pad_token_id = kwargs.pop("pad_token_id", None)
self.bos_token_id = kwargs.pop("bos_token_id", None)
self.eos_token_id = kwargs.pop("eos_token_id", None)
# misc
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
# bookkeeping
self._from_model_config = kwargs.pop("_from_model_config", False)
self._commit_hash = kwargs.pop("_commit_hash", None)
self.transformers_version = kwargs.pop("transformers_version", __version__)
if not self._from_model_config:
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
logger.error(f"Can't set {key} with value {value} for {self}")
raise err
self.validate(is_init=True)
def validate(self, is_init=False):
        # simple bounds checking
self.rcr_start_step = max(0, int(self.rcr_start_step))
self.rcr_end_step = max(self.rcr_start_step, int(self.rcr_end_step))
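# Example configuration (illustrative values only; special-token ids normally come from the
# model's own generation_config via from_model_config rather than being passed by hand):
#   cfg = DreamGenerationConfig(
#       max_new_tokens=128, steps=128,
#       temperature=0.2, top_p=0.95,
#       alg="maskgit_plus",
#       rcr=True, conf_alg="maskgit_plus",
#   )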
class DreamGenerationMixin:
@staticmethod
def _expand_inputs_for_generation(
expand_size: int = 1,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None
):
if expand_size == 1:
return input_ids, attention_mask
if input_ids is not None:
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
if attention_mask is not None:
attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
return input_ids, attention_mask
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
if is_torchdynamo_compiling():
return
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
warnings.warn(
f"Using default `max_length` (={generation_config.max_length}). Prefer `max_new_tokens`.",
UserWarning,
)
if input_ids_length >= generation_config.max_length:
raise ValueError(
f"Input length is {input_ids_length}, but `max_length` is {generation_config.max_length}. "
"Increase `max_length` or set `max_new_tokens`."
)
def _prepare_generated_length(self, generation_config, has_default_max_length, input_ids_length):
if generation_config.max_new_tokens is not None:
if not has_default_max_length and generation_config.max_length is not None:
logger.warning("Both `max_new_tokens` and `max_length` are set. `max_new_tokens` takes precedence.")
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
elif has_default_max_length:
if generation_config.max_length == DreamGenerationConfig().max_length:
generation_config.max_length = generation_config.max_length + input_ids_length
mpe = getattr(self.config, "max_position_embeddings", None)
if mpe is not None:
generation_config.max_length = min(generation_config.max_length, mpe)
return generation_config
def _prepare_generation_config(self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict) -> DreamGenerationConfig:
using_model_generation_config = False
if generation_config is None:
generation_config = DreamGenerationConfig.from_model_config(self.config)
using_model_generation_config = True
if not is_torchdynamo_compiling():
generation_config = copy.deepcopy(generation_config)
_ = generation_config.update(**kwargs)
if not using_model_generation_config:
if generation_config.bos_token_id is None:
generation_config.bos_token_id = self.generation_config.bos_token_id
if generation_config.eos_token_id is None:
generation_config.eos_token_id = self.generation_config.eos_token_id
if generation_config.pad_token_id is None:
generation_config.pad_token_id = self.generation_config.pad_token_id
if generation_config.mask_token_id is None:
generation_config.mask_token_id = self.generation_config.mask_token_id
return generation_config
def _prepare_special_tokens(self, generation_config: DreamGenerationConfig, device=None):
def _tensor_or_none(token, device=None):
if token is None:
return token
device = device if device is not None else self.device
if isinstance(token, torch.Tensor):
return token.to(device)
return torch.tensor(token, device=device, dtype=torch.long)
bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)
if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
eos_token_tensor = eos_token_tensor.unsqueeze(0)
if pad_token_tensor is None and eos_token_tensor is not None:
pad_token_tensor = eos_token_tensor[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
generation_config._bos_token_tensor = bos_token_tensor
generation_config._eos_token_tensor = eos_token_tensor
generation_config._pad_token_tensor = pad_token_tensor
generation_config._mask_token_tensor = mask_token_tensor
@torch.no_grad()
def diffusion_generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[DreamGenerationConfig] = None,
**kwargs,
):
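        """Diffusion-style iterative decoding entry point.

        `inputs` is the prompt as a LongTensor of token ids; extra keyword arguments update the
        generation config (e.g. `max_new_tokens`, `steps`, `temperature`, `rcr`). The optional
        `generation_tokens_hook_func` / `generation_logits_hook_func` kwargs are called at every
        step with `(step, x, logits)` and may modify the running sequence or the logits.
        Returns a `DreamModelOutput` when `return_dict_in_generate=True`, otherwise the final
        sequence tensor.
        """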
generation_config = self._prepare_generation_config(generation_config, **kwargs)
generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
assert inputs is not None
input_ids = inputs
device = input_ids.device
attention_mask = kwargs.pop("attention_mask", None)
self._prepare_special_tokens(generation_config, device=device)
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
generation_config = self._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
input_ids_length=input_ids_length,
)
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
warnings.warn(
"You are calling .generate() with `input_ids` on a device different from the model.",
UserWarning,
)
        if (
            generation_config.pad_token_id is not None
            and torch.any(input_ids == generation_config.pad_token_id)
            and attention_mask is None
        ):
warnings.warn(
"Padding detected but no attention mask was passed. Set `attention_mask` for correct generation.",
UserWarning,
)
input_ids, attention_mask = self._expand_inputs_for_generation(
expand_size=generation_config.num_return_sequences,
input_ids=input_ids,
attention_mask=attention_mask,
)
return self._sample(
input_ids,
attention_mask=attention_mask,
generation_config=generation_config,
generation_tokens_hook_func=generation_tokens_hook_func,
generation_logits_hook_func=generation_logits_hook_func,
)
def _sample(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.LongTensor],
generation_config: DreamGenerationConfig,
generation_tokens_hook_func,
generation_logits_hook_func
):
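        """Core masked-diffusion sampling loop.

        The prompt is padded to `max_length` with `mask_token_id`. At each of `steps` iterations
        the model predicts all masked positions, the most confident predictions are written, and,
        when `rcr` is enabled and the step falls inside the hard-coded window, previously written
        tokens with low historical confidence are re-masked so the cumulative number of committed
        tokens tracks the schedule. The token and logit hook functions are applied every step.
        """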
output_history = generation_config.output_history
return_dict_in_generate = generation_config.return_dict_in_generate
max_length = generation_config.max_length
mask_token_id = generation_config.mask_token_id
steps = generation_config.steps
eps = generation_config.eps
alg = generation_config.alg
alg_temp = generation_config.alg_temp
temperature = generation_config.temperature
top_p = generation_config.top_p
top_k = generation_config.top_k
rcr = generation_config.rcr
conf_alg = generation_config.conf_alg if rcr else generation_config.alg
        # === Hard-coded RCR window: from 1/4 to 3/4 of the total steps (half-open [start, end)) ===
rcr_start = max(0, steps // 4)
rcr_end = max(rcr_start, min(steps, (3 * steps) // 4))
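        # e.g. steps=16 -> rcr_start=4, rcr_end=12, so re-masking is only applied on steps 4..11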
protect_cur = bool(generation_config.rcr_protect_current_step)
histories = [] if (return_dict_in_generate and output_history) else None
# pad input_ids to max_length
x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
if attention_mask is not None and torch.any(attention_mask == 0.0):
attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
tok_idx = attention_mask.long().cumsum(-1) - 1
tok_idx.masked_fill_(attention_mask == 0, 1)
attention_mask = torch.logical_and(
attention_mask.unsqueeze(1).unsqueeze(-2),
attention_mask.unsqueeze(1).unsqueeze(-1),
)
else:
tok_idx = None
attention_mask = "full"
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
        # ==== RCR state ====
if rcr:
            init_mask_bool = (x == mask_token_id)  # initial generation region
            init_mask_count = init_mask_bool.sum(dim=1)  # [B]
            hist_conf = torch.zeros_like(x, dtype=torch.float32, device=x.device)  # running max confidence per position
            gen_mask = torch.zeros_like(x, dtype=torch.bool, device=x.device)  # positions committed so far
written_step = torch.full_like(x, -1, dtype=torch.int32, device=x.device)
x = generation_tokens_hook_func(None, x, None)
for i in range(steps):
mask_index = (x == mask_token_id)
            # forward pass + Dream's right-shifted logit alignment
logits = self(x, attention_mask, tok_idx).logits
logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
logits = generation_logits_hook_func(i, x, logits)
            # current and next timesteps (t -> s)
t = timesteps[i]
s = timesteps[i + 1]
            # -- gather logits at masked positions only and apply filtering --
mask_logits = logits[mask_index]
if mask_logits.numel() == 0:
x = generation_tokens_hook_func(i, x, logits)
if histories is not None:
histories.append(x.clone())
continue
mask_logits = _apply_top_p_k_temp(mask_logits, temperature, top_p, top_k)
probs = torch.softmax(mask_logits, dim=-1)
            # sample (or take the argmax) to get x0
if temperature and temperature > 0:
try:
x0 = dists.Categorical(probs=probs).sample()
except Exception:
x0 = probs.argmax(dim=-1)
else:
x0 = probs.argmax(dim=-1)
            # unified confidence score (same definition as used for decoding)
conf_now = _confidence_from_probs(
probs=probs,
chosen_ids=x0 if conf_alg == "maskgit_plus" else None,
mode=conf_alg
).to(torch.float32) # [M]
            # ====== compute this step's write quota k_t (same as vanilla decoding) ======
Mt = mask_index.sum().item()
ratio = (1.0 - (s.item() / t.item())) if i < steps - 1 else 1.0
k_t = int(Mt * ratio)
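            # e.g. with Mt=100 masked tokens, t=0.6 and s=0.5 (illustrative values),
            # ratio = 1 - 0.5/0.6 ≈ 0.167, so k_t = 16 tokens are written this step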
            # -- write the top-k_t most confident predictions (done whether or not we are inside the RCR window) --
full_conf_now = torch.full((x.size(0), x.size(1)), -1e9, dtype=torch.float32, device=x.device)
full_x0 = torch.full_like(x, mask_token_id, dtype=torch.long)
full_conf_now[mask_index] = conf_now
full_x0[mask_index] = x0
for b in range(x.size(0)):
masked_b = int(mask_index[b].sum().item())
if masked_b == 0 or k_t <= 0:
continue
k_b = min(k_t, masked_b)
_, sel_idx = torch.topk(full_conf_now[b], k=k_b, largest=True)
x[b, sel_idx] = full_x0[b, sel_idx]
if rcr:
gen_mask[b, sel_idx] = True
written_step[b, sel_idx] = i
                    # update the running max confidence (same definition as the decoding confidence)
hist_conf[b, sel_idx] = torch.maximum(hist_conf[b, sel_idx], full_conf_now[b, sel_idx])
            # -- outside the RCR window: no re-masking, only track history; inside the window: re-mask down to the target cumulative count --
if rcr and (rcr_start <= i < rcr_end):
for b in range(x.size(0)):
M0 = int(init_mask_count[b].item())
target_cum = M0 if i >= steps - 1 else int(M0 * (1.0 - (s.item() / t.item())))
                    # current cumulative commits: committed positions inside the initial generation region
C_t = int((gen_mask[b] & init_mask_bool[b]).sum().item())
over = max(0, C_t - target_cum)
if over <= 0:
continue
                    # candidates: initial region AND committed (optionally excluding tokens written this step)
cand = torch.where(gen_mask[b] & init_mask_bool[b])[0]
if cand.numel() == 0:
continue
if protect_cur:
mask_old = (written_step[b, cand] < i)
cand = cand[mask_old]
if cand.numel() == 0:
                            # all candidates were written this step and protection is on, so skip re-masking
continue
over = min(over, int(cand.numel()))
                    scores = hist_conf[b, cand]  # larger = more confident
_, low_local = torch.topk(scores, k=over, largest=False)
low_global = cand[low_local]
                    # re-mask the least confident positions
                    x[b, low_global] = mask_token_id
                    gen_mask[b, low_global] = False
                    # hist_conf scores and written_step are intentionally kept
x = generation_tokens_hook_func(i, x, logits)
if histories is not None:
histories.append(x.clone())
if return_dict_in_generate:
return DreamModelOutput(sequences=x, history=histories)
else:
return x
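# Minimal usage sketch. Assumptions: the checkpoint id and prompt below are illustrative
# placeholders, and a real Dream-style checkpoint must expose this mixin via
# trust_remote_code and define mask/eos token ids in its generation_config.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    model_name = "Dream-org/Dream-v0-Instruct-7B"  # hypothetical checkpoint id
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True).eval()

    input_ids = tokenizer("Write a haiku about diffusion.", return_tensors="pt").input_ids
    out = model.diffusion_generate(
        input_ids,
        max_new_tokens=64,
        steps=64,
        temperature=0.2,
        top_p=0.95,
        alg="maskgit_plus",
        rcr=True,
        conf_alg="maskgit_plus",
        return_dict_in_generate=True,
    )
    print(tokenizer.decode(out.sequences[0], skip_special_tokens=True))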