# coding=utf-8
import warnings
import copy
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.distributions as dists
from torch.nn import functional as F
from transformers import __version__
from transformers.generation.configuration_utils import GenerationConfig
from transformers.utils import ModelOutput, is_torchdynamo_compiling, logging

logger = logging.get_logger(__name__)


def _apply_top_p_k_temp(logits, temperature=0.0, top_p=None, top_k=None):
    """Apply temperature scaling, then nucleus (top-p) and top-k filtering.

    Filtered-out entries are set to the dtype's most negative finite value, so a
    following softmax gives them (numerically) zero probability.

    Args:
        logits: Logits whose last dimension is the vocabulary.
        temperature: Divides the logits when strictly positive; ``0``/``None``
            leaves them unscaled.
        top_p: Nucleus threshold; ignored when ``None`` or ``>= 1``.
        top_k: Number of highest logits to keep; ignored when ``None`` or ``<= 0``.

    Returns:
        The filtered logits, same shape as the input.
    """
    if temperature and temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the token that crosses the threshold is kept, and
        # always keep the single most likely token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
        mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
    if top_k is not None:
        top_k = int(min(top_k, logits.size(-1)))
        if top_k > 0:
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
    return logits


def _confidence_from_probs(
    probs: torch.Tensor,  # [..., V]
    chosen_ids: Optional[torch.Tensor],  # [...]
    mode: str,  # 'entropy' | 'maskgit_plus' | 'topk_margin'
) -> torch.Tensor:
    """Return a per-position confidence score where larger means more confident.

    The same definition is used both for choosing which positions to decode and
    for the RCR running-confidence history, so the two stay consistent.

    Args:
        probs: Probability distributions; the last dimension is the vocabulary.
        chosen_ids: Token chosen at each position (shape ``probs.shape[:-1]``);
            required for ``'maskgit_plus'``, ignored otherwise.
        mode: ``'entropy'`` -> negative entropy ``-H(p)`` (always <= 0);
            ``'maskgit_plus'`` -> probability of the chosen token;
            ``'topk_margin'`` -> top-1 probability minus top-2 probability.

    Raises:
        ValueError: If ``mode`` is not one of the modes above.
    """
    if mode == "entropy":
        eps = 1e-10
        logp = torch.log(probs + eps)
        # sum(p * log p) == -H(p): peaked (confident) distributions score higher.
        # (The previous `-(p * log p).sum()` returned +H(p), which inverted the
        # "larger is more confident" contract.)
        return (probs * logp).sum(dim=-1)
    elif mode == "maskgit_plus":
        assert chosen_ids is not None, "maskgit_plus 需要 chosen_ids"
        return torch.gather(probs, -1, chosen_ids.unsqueeze(-1)).squeeze(-1)  # p(x0)
    elif mode == "topk_margin":
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        return sorted_probs[..., 0] - sorted_probs[..., 1]  # top1 - top2
    else:
        raise ValueError(f"Unknown conf mode: {mode}")


@dataclass
class DreamModelOutput(ModelOutput):
    # Final token ids.
    sequences: torch.LongTensor = None
    # Snapshot of the token ids after every step (only when requested).
    history: Optional[Tuple[torch.FloatTensor]] = None


class DreamGenerationConfig(GenerationConfig):
    """Generation settings for diffusion-style decoding, including RCR options."""

    def __init__(self, **kwargs):
        # sampling
        self.temperature: float = kwargs.pop("temperature", 0.0)
        self.top_p: Optional[float] = kwargs.pop("top_p", None)
        self.top_k: Optional[int] = kwargs.pop("top_k", None)
        # length
        self.max_length = kwargs.pop("max_length", 20)
        self.max_new_tokens = kwargs.pop("max_new_tokens", None)
        # diffusion
        self.eps: float = kwargs.pop("eps", 1e-3)
        self.steps: int = kwargs.pop("steps", 512)
        # Scoring algorithm for the vanilla path (used when rcr=False).
        self.alg: str = kwargs.pop("alg", 'maskgit_plus')  # 'origin' | 'maskgit_plus' | 'topk_margin' | 'entropy'
        # Temperature for stochastic position selection on the vanilla path;
        # None or <= 0 means deterministic top-k selection.
        self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)

        # === RCR ===
        self.rcr: bool = kwargs.pop("rcr", False)
        # With rcr=True: the confidence definition shared by decoding and the history score.
        self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')  # 'maskgit_plus' | 'topk_margin' | 'entropy'
        # NOTE: _sample ignores the next two values and always uses the window
        # [steps // 4, 3 * steps // 4).
        self.rcr_start_step: int = kwargs.pop("rcr_start_step", 0)
        self.rcr_end_step: int = kwargs.pop("rcr_end_step", None) or self.steps
        # Whether tokens written in the current step are protected from re-masking.
        self.rcr_protect_current_step: bool = kwargs.pop("rcr_protect_current_step", False)

        # outputs
        self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
        self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
        self.output_history: bool = kwargs.pop("output_history", False)

        # special tokens
        self.mask_token_id = kwargs.pop("mask_token_id", None)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)

        # misc
        self.generation_kwargs = kwargs.pop("generation_kwargs", {})

        # bookkeeping
        self._from_model_config = kwargs.pop("_from_model_config", False)
        self._commit_hash = kwargs.pop("_commit_hash", None)
        self.transformers_version = kwargs.pop("transformers_version", __version__)

        # Remaining kwargs become plain attributes unless built from a model config.
        if not self._from_model_config:
            for key, value in kwargs.items():
                try:
                    setattr(self, key, value)
                except AttributeError as err:
                    logger.error(f"Can't set {key} with value {value} for {self}")
                    raise err

        self.validate(is_init=True)

    def validate(self, is_init=False, **kwargs):
        """Clamp the RCR step bounds to ``0 <= rcr_start_step <= rcr_end_step``.

        ``None`` bounds fall back to ``0`` / ``steps``. Extra keyword arguments
        are accepted and ignored. Presumably newer ``transformers`` versions may
        pass flags such as ``strict``; this is not verified here.
        """
        if self.rcr_end_step is None:
            self.rcr_end_step = self.steps
        self.rcr_start_step = max(0, int(self.rcr_start_step or 0))
        self.rcr_end_step = max(self.rcr_start_step, int(self.rcr_end_step))


class DreamGenerationMixin:
    """Adds ``diffusion_generate`` (iterative unmasking) to a model class."""

    @staticmethod
    def _expand_inputs_for_generation(
        expand_size: int = 1,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """Repeat every row ``expand_size`` times along dim 0 (for num_return_sequences)."""
        if expand_size == 1:
            return input_ids, attention_mask
        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
        if attention_mask is not None:
            attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
        return input_ids, attention_mask

    def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
        """Warn on the default ``max_length`` and reject prompts that already fill it.

        Raises:
            ValueError: If ``input_ids_length >= generation_config.max_length``.
        """
        if is_torchdynamo_compiling():
            return
        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
            warnings.warn(
                f"Using default `max_length` (={generation_config.max_length}). Prefer `max_new_tokens`.",
                UserWarning,
            )
        if input_ids_length >= generation_config.max_length:
            raise ValueError(
                f"Input length is {input_ids_length}, but `max_length` is {generation_config.max_length}. "
                "Increase `max_length` or set `max_new_tokens`."
            )

    def _prepare_generated_length(self, generation_config, has_default_max_length, input_ids_length):
        """Resolve the final ``max_length`` (prompt + new tokens), capped by position embeddings."""
        if generation_config.max_new_tokens is not None:
            if not has_default_max_length and generation_config.max_length is not None:
                logger.warning("Both `max_new_tokens` and `max_length` are set. `max_new_tokens` takes precedence.")
            generation_config.max_length = generation_config.max_new_tokens + input_ids_length
        elif has_default_max_length:
            # An untouched default max_length is interpreted as "new tokens".
            if generation_config.max_length == DreamGenerationConfig().max_length:
                generation_config.max_length = generation_config.max_length + input_ids_length
                mpe = getattr(self.config, "max_position_embeddings", None)
                if mpe is not None:
                    generation_config.max_length = min(generation_config.max_length, mpe)
        return generation_config

    def _prepare_generation_config(self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict) -> DreamGenerationConfig:
        """Build a per-call config copy, apply kwargs overrides, and fill missing special tokens."""
        using_model_generation_config = False
        if generation_config is None:
            generation_config = DreamGenerationConfig.from_model_config(self.config)
            using_model_generation_config = True

        if not is_torchdynamo_compiling():
            # Copy so per-call overrides do not leak into the caller's config.
            generation_config = copy.deepcopy(generation_config)
            _ = generation_config.update(**kwargs)
            if not using_model_generation_config:
                if generation_config.bos_token_id is None:
                    generation_config.bos_token_id = self.generation_config.bos_token_id
                if generation_config.eos_token_id is None:
                    generation_config.eos_token_id = self.generation_config.eos_token_id
                if generation_config.pad_token_id is None:
                    generation_config.pad_token_id = self.generation_config.pad_token_id
                if generation_config.mask_token_id is None:
                    generation_config.mask_token_id = self.generation_config.mask_token_id
        return generation_config

    def _prepare_special_tokens(self, generation_config: DreamGenerationConfig, device=None):
        """Attach tensor versions of the special token ids to ``generation_config``.

        The pad token falls back to the first eos token when unset.
        """
        def _tensor_or_none(token, device=None):
            if token is None:
                return token
            device = device if device is not None else self.device
            if isinstance(token, torch.Tensor):
                return token.to(device)
            return torch.tensor(token, device=device, dtype=torch.long)

        bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
        eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
        pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
        mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)

        if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
            eos_token_tensor = eos_token_tensor.unsqueeze(0)
        if pad_token_tensor is None and eos_token_tensor is not None:
            pad_token_tensor = eos_token_tensor[0]
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")

        generation_config._bos_token_tensor = bos_token_tensor
        generation_config._eos_token_tensor = eos_token_tensor
        generation_config._pad_token_tensor = pad_token_tensor
        generation_config._mask_token_tensor = mask_token_tensor

    @torch.no_grad()
    def diffusion_generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[DreamGenerationConfig] = None,
        **kwargs,
    ):
        """Generate by iterative unmasking.

        Args:
            inputs: Prompt token ids (required).
            generation_config: Optional config; ``kwargs`` override its fields.
            **kwargs: Config overrides plus optional ``attention_mask``,
                ``generation_tokens_hook_func(step, x, logits) -> x`` and
                ``generation_logits_hook_func(step, x, logits) -> logits``.

        Returns:
            Whatever ``_sample`` returns (tensor or ``DreamModelOutput``).
        """
        generation_config = self._prepare_generation_config(generation_config, **kwargs)
        generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
        generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)

        assert inputs is not None
        input_ids = inputs
        device = input_ids.device
        attention_mask = kwargs.pop("attention_mask", None)
        self._prepare_special_tokens(generation_config, device=device)

        input_ids_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            has_default_max_length=has_default_max_length,
            input_ids_length=input_ids_length,
        )
        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

        if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
            warnings.warn(
                "You are calling .generate() with `input_ids` on a device different from the model.",
                UserWarning,
            )
        # Guard on None: `tensor == None` yields a Python bool, which torch.any rejects.
        pad_token_id = getattr(generation_config, "pad_token_id", None)
        if (
            pad_token_id is not None
            and attention_mask is None
            and torch.any(input_ids == pad_token_id)
        ):
            warnings.warn(
                "Padding detected but no attention mask was passed. Set `attention_mask` for correct generation.",
                UserWarning,
            )

        input_ids, attention_mask = self._expand_inputs_for_generation(
            expand_size=generation_config.num_return_sequences,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        return self._sample(
            input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            generation_tokens_hook_func=generation_tokens_hook_func,
            generation_logits_hook_func=generation_logits_hook_func,
        )

    def _sample(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor],
        generation_config: DreamGenerationConfig,
        generation_tokens_hook_func,
        generation_logits_hook_func,
    ):
        """Iteratively unmask the padded sequence over ``steps`` denoising steps.

        The prompt is right-padded with ``mask_token_id`` up to ``max_length``.
        At every step the model predicts all positions. Each masked position
        gets a candidate token (sampled when ``temperature > 0``, else argmax)
        and a confidence score. Each sequence then writes its
        ``int(masked * (1 - s/t))`` best candidates, and all remaining ones on
        the final step.

        With ``rcr=True``, the running maximum confidence of every written
        position is tracked. Inside the hard-coded window
        ``[steps // 4, 3 * steps // 4)``, confirmed tokens with the lowest
        running confidence are re-masked whenever the confirmed count exceeds
        the step's target.

        Returns:
            ``DreamModelOutput(sequences, history)`` when
            ``return_dict_in_generate``, otherwise the final token tensor.
        """
        output_history = generation_config.output_history
        return_dict_in_generate = generation_config.return_dict_in_generate
        max_length = generation_config.max_length
        mask_token_id = generation_config.mask_token_id
        steps = generation_config.steps
        eps = generation_config.eps
        alg = generation_config.alg
        alg_temp = generation_config.alg_temp
        temperature = generation_config.temperature
        top_p = generation_config.top_p
        top_k = generation_config.top_k

        rcr = generation_config.rcr
        conf_alg = generation_config.conf_alg if rcr else alg
        # alg_temp belongs to the vanilla scoring settings, so RCR stays deterministic.
        use_alg_temp = (not rcr) and alg_temp is not None and alg_temp > 0

        # Hard-coded RCR window: [steps // 4, 3 * steps // 4). The config's
        # rcr_start_step / rcr_end_step are intentionally not used here.
        rcr_start = max(0, steps // 4)
        rcr_end = max(rcr_start, min(steps, (3 * steps) // 4))
        protect_cur = bool(generation_config.rcr_protect_current_step)

        histories = [] if (return_dict_in_generate and output_history) else None

        # Right-pad input_ids to max_length with the mask token.
        x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)

        if attention_mask is not None and torch.any(attention_mask == 0.0):
            attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
            # Position ids that skip padding; padded slots get a dummy id of 1.
            tok_idx = attention_mask.long().cumsum(-1) - 1
            tok_idx.masked_fill_(attention_mask == 0, 1)
            # Pairwise [B, 1, L, L] mask: both query and key must be non-padding.
            attention_mask = torch.logical_and(
                attention_mask.unsqueeze(1).unsqueeze(-2),
                attention_mask.unsqueeze(1).unsqueeze(-1),
            )
        else:
            tok_idx = None
            attention_mask = "full"

        timesteps = torch.linspace(1, eps, steps + 1, device=x.device)

        # ==== RCR state ====
        if rcr:
            init_mask_bool = (x == mask_token_id)  # positions that start masked (generation region)
            init_mask_count = init_mask_bool.sum(dim=1)  # [B]
            # Running max confidence. Start at -inf so negative scores (e.g. the
            # entropy mode's -H(p)) are recorded rather than clipped to 0.
            hist_conf = torch.full(x.shape, float("-inf"), dtype=torch.float32, device=x.device)
            gen_mask = torch.zeros_like(x, dtype=torch.bool)  # currently confirmed positions
            written_step = torch.full_like(x, -1, dtype=torch.int32)  # last step each position was written

        x = generation_tokens_hook_func(None, x, None)
        for i in range(steps):
            mask_index = (x == mask_token_id)

            # Forward pass, then shift logits right by one so position j uses the
            # prediction made at j - 1 (position 0 reuses its own logits).
            logits = self(x, attention_mask, tok_idx).logits
            logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
            logits = generation_logits_hook_func(i, x, logits)

            t = timesteps[i].item()
            s = timesteps[i + 1].item()
            is_last_step = i >= steps - 1
            # Fraction of the currently masked tokens to reveal this step.
            ratio = 1.0 if is_last_step else 1.0 - s / t

            # Only the logits at masked positions are filtered and sampled.
            mask_logits = logits[mask_index]
            if mask_logits.numel() == 0:
                x = generation_tokens_hook_func(i, x, logits)
                if histories is not None:
                    histories.append(x.clone())
                continue

            mask_logits = _apply_top_p_k_temp(mask_logits, temperature, top_p, top_k)
            probs = torch.softmax(mask_logits, dim=-1)

            # Candidate token per masked position: sampled or greedy.
            if temperature and temperature > 0:
                try:
                    x0 = dists.Categorical(probs=probs).sample()
                except (ValueError, RuntimeError):
                    # Invalid probabilities (e.g. NaN/inf): fall back to greedy.
                    x0 = probs.argmax(dim=-1)
            else:
                x0 = probs.argmax(dim=-1)

            # Confidence with the same definition as decoding ([M], float32).
            if conf_alg == "origin":
                # 'origin' reveals random positions: use uniform random scores.
                conf_now = torch.rand(x0.shape, dtype=torch.float32, device=x0.device)
            else:
                conf_now = _confidence_from_probs(
                    probs=probs,
                    chosen_ids=x0 if conf_alg == "maskgit_plus" else None,
                    mode=conf_alg,
                ).to(torch.float32)

            # Scatter per-masked-position values back to [B, L].
            full_conf_now = torch.full(x.shape, float("-inf"), dtype=torch.float32, device=x.device)
            full_conf_now[mask_index] = conf_now
            full_x0 = torch.full_like(x, mask_token_id)
            full_x0[mask_index] = x0

            if use_alg_temp:
                # Gumbel-top-k: equivalent to sampling positions without
                # replacement from softmax(conf / alg_temp).
                uniform = torch.rand_like(full_conf_now).clamp_(min=1e-20)
                sel_scores = full_conf_now / alg_temp - torch.log(-torch.log(uniform))
            else:
                sel_scores = full_conf_now

            # Write: each sequence reveals its own quota of best candidates (this
            # always happens, inside or outside the RCR window).
            for b in range(x.size(0)):
                pos_b = torch.nonzero(mask_index[b], as_tuple=True)[0]
                masked_b = int(pos_b.numel())
                if masked_b == 0:
                    continue
                # Per-sequence quota (previously the batch-wide mask count was
                # used, over-writing every sequence when batch size > 1).
                k_b = masked_b if is_last_step else int(masked_b * ratio)
                if k_b <= 0:
                    continue
                _, local = torch.topk(sel_scores[b, pos_b], k=k_b, largest=True)
                sel_idx = pos_b[local]
                x[b, sel_idx] = full_x0[b, sel_idx]
                if rcr:
                    gen_mask[b, sel_idx] = True
                    written_step[b, sel_idx] = i
                    # Running max of the (un-noised) confidence.
                    hist_conf[b, sel_idx] = torch.maximum(hist_conf[b, sel_idx], full_conf_now[b, sel_idx])

            # Outside the RCR window only the history is tracked; inside it,
            # re-mask the lowest-history tokens down to the target count.
            if rcr and (rcr_start <= i < rcr_end):
                for b in range(x.size(0)):
                    m0 = int(init_mask_count[b].item())
                    # NOTE(review): despite its name, target_cum is M0 times the
                    # *per-step* fraction (1 - s/t), not the cumulative fraction
                    # (1 - s) the linear schedule reaches after this step. Early in
                    # the window it is far below the number already confirmed, so
                    # most confirmed tokens get re-masked. Kept as-is because using
                    # (1 - s) would make this branch effectively a no-op with the
                    # current write quota; confirm the intended RCR target.
                    target_cum = m0 if is_last_step else int(m0 * ratio)
                    # Currently confirmed tokens inside the initial generation region.
                    confirmed_b = gen_mask[b] & init_mask_bool[b]
                    over = int(confirmed_b.sum().item()) - target_cum
                    if over <= 0:
                        continue

                    cand = torch.nonzero(confirmed_b, as_tuple=True)[0]
                    if protect_cur:
                        # Optionally exclude tokens written in this very step.
                        cand = cand[written_step[b, cand] < i]
                    if cand.numel() == 0:
                        continue

                    over = min(over, int(cand.numel()))
                    # Re-mask the candidates with the lowest running confidence.
                    _, low_local = torch.topk(hist_conf[b, cand], k=over, largest=False)
                    low_global = cand[low_local]
                    x[b, low_global] = mask_token_id
                    gen_mask[b, low_global] = False
                    # hist_conf and written_step are intentionally kept.

            x = generation_tokens_hook_func(i, x, logits)
            if histories is not None:
                histories.append(x.clone())

        if return_dict_in_generate:
            return DreamModelOutput(sequences=x, history=histories)
        return x