File size: 19,797 Bytes
78f65a4
6be7e05
 
 
 
 
 
 
 
 
d0ecd79
 
6be7e05
 
 
 
fd52594
34c9b0b
8c62663
 
fd52594
 
 
 
 
 
 
 
8c62663
fd52594
 
 
 
 
6be7e05
 
78f65a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6be7e05
 
 
 
 
 
 
 
3d20c97
6be7e05
 
 
3d20c97
 
6be7e05
 
34c9b0b
78f65a4
6be7e05
 
78f65a4
 
 
6be7e05
 
78f65a4
d0ecd79
78f65a4
 
 
 
 
 
 
d0ecd79
9384b5d
6be7e05
 
 
 
631ce9b
6be7e05
 
 
 
 
3d20c97
6be7e05
 
3d20c97
6be7e05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78f65a4
 
 
6be7e05
d0ecd79
6be7e05
 
 
 
 
3d20c97
fd52594
6be7e05
 
 
 
 
 
 
 
d0ecd79
6be7e05
 
 
 
78f65a4
6be7e05
 
 
 
fd52594
 
6be7e05
 
d0ecd79
6be7e05
 
78f65a4
6be7e05
 
 
 
fd52594
 
 
6be7e05
 
3d20c97
6be7e05
 
 
 
 
 
 
34c9b0b
6be7e05
 
 
 
 
 
 
 
 
3d20c97
6be7e05
 
fd52594
6be7e05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd52594
6be7e05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34c9b0b
8c62663
d0ecd79
6be7e05
 
78f65a4
34c9b0b
6be7e05
8c62663
d0ecd79
 
 
8c62663
6be7e05
fd52594
34c9b0b
6be7e05
 
 
 
 
d0ecd79
6be7e05
 
fd52594
6be7e05
 
 
 
d0ecd79
6be7e05
 
 
 
 
 
 
 
3d20c97
fd52594
6be7e05
 
 
 
 
 
 
 
 
 
 
 
78f65a4
 
 
 
 
 
 
 
 
6be7e05
 
3d20c97
6be7e05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78f65a4
3d20c97
78f65a4
 
 
 
 
d0ecd79
6be7e05
34c9b0b
6be7e05
8c62663
fd52594
78f65a4
6be7e05
d0ecd79
6be7e05
 
78f65a4
6be7e05
 
d0ecd79
78f65a4
 
 
 
 
 
 
34c9b0b
78f65a4
 
fd52594
78f65a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd52594
6be7e05
 
 
d0ecd79
6be7e05
9384b5d
6be7e05
d0ecd79
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
# coding=utf-8
import warnings
import copy
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.distributions as dists
from torch.nn import functional as F
from transformers import __version__
from transformers.generation.configuration_utils import GenerationConfig
from transformers.utils import ModelOutput, is_torchdynamo_compiling, logging

logger = logging.get_logger(__name__)


def _apply_top_p_k_temp(logits, temperature=0.0, top_p=None, top_k=None):
    if temperature and temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
        mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
    if top_k is not None:
        top_k = int(min(top_k, logits.size(-1)))
        if top_k > 0:
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
    return logits


def _confidence_from_probs(
    probs: torch.Tensor,            # [..., V]
    chosen_ids: Optional[torch.Tensor],  # [...]
    mode: str                       # 'entropy' | 'maskgit_plus' | 'topk_margin'
) -> torch.Tensor:
    """返回“越大越自信”的标量分数,与解码一致。"""
    if mode == "entropy":
        eps = 1e-10
        logp = torch.log(probs + eps)
        return -(probs * logp).sum(dim=-1)          # -H(p)
    elif mode == "maskgit_plus":
        assert chosen_ids is not None, "maskgit_plus 需要 chosen_ids"
        return torch.gather(probs, -1, chosen_ids.unsqueeze(-1)).squeeze(-1)  # p(x0)
    elif mode == "topk_margin":
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        return sorted_probs[..., 0] - sorted_probs[..., 1]                    # top1 - top2
    else:
        raise ValueError(f"Unknown conf mode: {mode}")


@dataclass
class DreamModelOutput(ModelOutput):
    """Return type of ``diffusion_generate`` when ``return_dict_in_generate`` is True."""

    # Final token ids after the last denoising step.
    sequences: torch.LongTensor = None
    # Snapshot of the token tensor after every step when `output_history` is enabled, else None.
    # NOTE(review): `_sample` stores a list of token-id tensors here, not a tuple of float
    # tensors as the annotation suggests — confirm and adjust the annotation.
    history: Optional[Tuple[torch.FloatTensor]] = None


class DreamGenerationConfig(GenerationConfig):
    """Generation settings for Dream masked-diffusion decoding (with optional RCR re-masking).

    ``super().__init__`` is intentionally not called: only the fields below are
    set, and any remaining keyword arguments are assigned as plain attributes
    unless the config is being built from a model config.
    """

    def __init__(self, **kwargs):
        # sampling
        self.temperature: float = kwargs.pop("temperature", 0.0)
        self.top_p: Optional[float] = kwargs.pop("top_p", None)
        self.top_k: Optional[int] = kwargs.pop("top_k", None)

        # length
        self.max_length = kwargs.pop("max_length", 20)
        self.max_new_tokens = kwargs.pop("max_new_tokens", None)

        # diffusion: the sampler uses a linear timestep schedule from 1 down to `eps`
        # over `steps` iterations.
        self.eps: float = kwargs.pop("eps", 1e-3)
        self.steps: int = kwargs.pop("steps", 512)

        # Scoring algorithm for vanilla decoding (used when rcr=False):
        # 'origin' | 'maskgit_plus' | 'topk_margin' | 'entropy'.
        # NOTE(review): the confidence helper used by the sampler has no 'origin'
        # branch and raises ValueError for it — confirm before relying on 'origin'.
        self.alg: str = kwargs.pop("alg", 'maskgit_plus')
        # NOTE(review): read by the sampler but not otherwise used in this file.
        self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)

        # === RCR ===
        self.rcr: bool = kwargs.pop("rcr", False)
        # Confidence definition used when rcr=True, for both decoding and the
        # per-token confidence history: 'maskgit_plus' | 'topk_margin' | 'entropy'.
        self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')
        # NOTE: the sampler ignores these two and always uses the hard-coded window
        # [steps // 4, 3 * steps // 4). A falsy rcr_end_step (None or 0) defaults to `steps`.
        self.rcr_start_step: int = kwargs.pop("rcr_start_step", 0)
        self.rcr_end_step: int = kwargs.pop("rcr_end_step", None) or self.steps
        # If True, tokens written in the current step cannot be re-masked in that step.
        self.rcr_protect_current_step: bool = kwargs.pop("rcr_protect_current_step", False)

        # outputs
        self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
        self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
        self.output_history: bool = kwargs.pop("output_history", False)

        # special tokens
        self.mask_token_id = kwargs.pop("mask_token_id", None)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)

        # misc
        self.generation_kwargs = kwargs.pop("generation_kwargs", {})

        # bookkeeping
        self._from_model_config = kwargs.pop("_from_model_config", False)
        self._commit_hash = kwargs.pop("_commit_hash", None)
        self.transformers_version = kwargs.pop("transformers_version", __version__)

        # Leftover kwargs become attributes, unless we are mirroring a model config.
        if not self._from_model_config:
            for key, value in kwargs.items():
                try:
                    setattr(self, key, value)
                except AttributeError as err:
                    logger.error(f"Can't set {key} with value {value} for {self}")
                    raise err

        self.validate(is_init=True)

    def validate(self, is_init=False, **kwargs):
        """Clamp the RCR window so that ``0 <= rcr_start_step <= rcr_end_step``.

        ``None`` bounds fall back to ``0`` / ``steps`` instead of raising
        TypeError. Extra keyword arguments (options that the parent
        ``GenerationConfig`` machinery may pass in some releases) are accepted
        and ignored. No other field is validated.
        """
        start = self.rcr_start_step
        self.rcr_start_step = max(0, int(start if start is not None else 0))
        end = self.rcr_end_step
        if end is None:
            end = self.steps
        self.rcr_end_step = max(self.rcr_start_step, int(end))


class DreamGenerationMixin:
    """Mixin that adds masked-diffusion generation (``diffusion_generate``) to a model.

    The host class must provide ``self.config``, ``self.device``,
    ``self.generation_config``, and be callable as ``self(x, attention_mask, tok_idx)``
    returning an object with a ``.logits`` attribute.
    """

    @staticmethod
    def _expand_inputs_for_generation(
        expand_size: int = 1,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None
    ):
        """Repeat each batch row ``expand_size`` times (used for ``num_return_sequences``)."""
        if expand_size == 1:
            return input_ids, attention_mask
        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
        if attention_mask is not None:
            attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
        return input_ids, attention_mask

    def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
        """Warn about an implicit default ``max_length``; reject prompts that leave no room.

        Skipped entirely while torchdynamo is compiling.

        Raises:
            ValueError: if ``input_ids_length >= generation_config.max_length``.
        """
        if is_torchdynamo_compiling():
            return
        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
            warnings.warn(
                f"Using default `max_length` (={generation_config.max_length}). Prefer `max_new_tokens`.",
                UserWarning,
            )
        if input_ids_length >= generation_config.max_length:
            raise ValueError(
                f"Input length is {input_ids_length}, but `max_length` is {generation_config.max_length}. "
                "Increase `max_length` or set `max_new_tokens`."
            )

    def _prepare_generated_length(self, generation_config, has_default_max_length, input_ids_length):
        """Resolve ``generation_config.max_length`` as a total length (prompt included).

        ``max_new_tokens`` wins when set. Otherwise, if ``max_length`` was not
        given explicitly and still equals the class default, it is treated as a
        budget on top of the prompt and capped at ``config.max_position_embeddings``
        when that attribute exists.
        """
        if generation_config.max_new_tokens is not None:
            if not has_default_max_length and generation_config.max_length is not None:
                logger.warning("Both `max_new_tokens` and `max_length` are set. `max_new_tokens` takes precedence.")
            generation_config.max_length = generation_config.max_new_tokens + input_ids_length
        elif has_default_max_length:
            if generation_config.max_length == DreamGenerationConfig().max_length:
                generation_config.max_length = generation_config.max_length + input_ids_length
                mpe = getattr(self.config, "max_position_embeddings", None)
                if mpe is not None:
                    generation_config.max_length = min(generation_config.max_length, mpe)
        return generation_config

    def _prepare_generation_config(self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict) -> DreamGenerationConfig:
        """Return the config for this call.

        Built from ``self.config`` when ``generation_config`` is None. Outside
        torchdynamo compilation it is deep-copied and updated with ``kwargs``.
        When a config was passed in, missing special-token ids are filled from
        ``self.generation_config``.
        """
        using_model_generation_config = False
        if generation_config is None:
            generation_config = DreamGenerationConfig.from_model_config(self.config)
            using_model_generation_config = True

        if not is_torchdynamo_compiling():
            generation_config = copy.deepcopy(generation_config)
            _ = generation_config.update(**kwargs)
            if not using_model_generation_config:
                if generation_config.bos_token_id is None:
                    generation_config.bos_token_id = self.generation_config.bos_token_id
                if generation_config.eos_token_id is None:
                    generation_config.eos_token_id = self.generation_config.eos_token_id
                if generation_config.pad_token_id is None:
                    generation_config.pad_token_id = self.generation_config.pad_token_id
                if generation_config.mask_token_id is None:
                    generation_config.mask_token_id = self.generation_config.mask_token_id

        return generation_config

    def _prepare_special_tokens(self, generation_config: DreamGenerationConfig, device=None):
        """Store special-token ids as tensors in ``generation_config._*_token_tensor``.

        Tensors are placed on ``device`` (default ``self.device``). A scalar eos
        id becomes a 1-element tensor, and a missing pad id falls back to the
        first eos id.
        """
        def _tensor_or_none(token, device=None):
            if token is None:
                return token
            device = device if device is not None else self.device
            if isinstance(token, torch.Tensor):
                return token.to(device)
            return torch.tensor(token, device=device, dtype=torch.long)

        bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
        eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
        pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
        mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)

        if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
            eos_token_tensor = eos_token_tensor.unsqueeze(0)
        if pad_token_tensor is None and eos_token_tensor is not None:
            pad_token_tensor = eos_token_tensor[0]
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")

        generation_config._bos_token_tensor = bos_token_tensor
        generation_config._eos_token_tensor = eos_token_tensor
        generation_config._pad_token_tensor = pad_token_tensor
        generation_config._mask_token_tensor = mask_token_tensor

    @torch.no_grad()
    def diffusion_generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[DreamGenerationConfig] = None,
        **kwargs,
    ):
        """Generate text by iterative unmasking.

        Args:
            inputs: Prompt token ids (required despite the ``None`` default).
            generation_config: Base config; fields in ``kwargs`` override it for this call.
            **kwargs: Config overrides, plus optional ``attention_mask``,
                ``generation_tokens_hook_func(step, x, logits) -> x`` and
                ``generation_logits_hook_func(step, x, logits) -> logits``.

        Returns:
            ``DreamModelOutput`` if ``return_dict_in_generate`` else the token tensor.
        """
        generation_config = self._prepare_generation_config(generation_config, **kwargs)
        generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
        generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)

        assert inputs is not None
        input_ids = inputs
        device = input_ids.device
        attention_mask = kwargs.pop("attention_mask", None)
        self._prepare_special_tokens(generation_config, device=device)

        input_ids_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            has_default_max_length=has_default_max_length,
            input_ids_length=input_ids_length,
        )

        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

        if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
            warnings.warn(
                "You are calling .generate() with `input_ids` on a device different from the model.",
                UserWarning,
            )
        # Only look for padding when a pad id is actually configured.
        pad_token_id = getattr(generation_config, "pad_token_id", None)
        if (
            attention_mask is None
            and pad_token_id is not None
            and torch.any(input_ids == pad_token_id)
        ):
            warnings.warn(
                "Padding detected but no attention mask was passed. Set `attention_mask` for correct generation.",
                UserWarning,
            )

        input_ids, attention_mask = self._expand_inputs_for_generation(
            expand_size=generation_config.num_return_sequences,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        return self._sample(
            input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            generation_tokens_hook_func=generation_tokens_hook_func,
            generation_logits_hook_func=generation_logits_hook_func,
        )

    def _sample(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor],
        generation_config: DreamGenerationConfig,
        generation_tokens_hook_func,
        generation_logits_hook_func
    ):
        """Run the iterative unmasking loop.

        The prompt is right-padded with ``mask_token_id`` up to ``max_length``.
        Each of the ``steps`` iterations does the following:

        - scores every masked position;
        - commits, per row, the top-scoring ``int(masked_b * (1 - s / t))``
          predictions (all remaining ones on the last step);
        - when ``rcr`` is on and the step lies in ``[steps // 4, 3 * steps // 4)``,
          re-masks the confirmed tokens with the lowest historical confidence
          until each row is back at its target count.

        Returns:
            ``DreamModelOutput`` if ``return_dict_in_generate`` else the token tensor.
        """
        output_history = generation_config.output_history
        return_dict_in_generate = generation_config.return_dict_in_generate
        max_length = generation_config.max_length
        mask_token_id = generation_config.mask_token_id
        steps = generation_config.steps
        eps = generation_config.eps
        temperature = generation_config.temperature
        top_p = generation_config.top_p
        top_k = generation_config.top_k

        rcr = generation_config.rcr
        # With RCR, `conf_alg` defines confidence for decoding and history; otherwise `alg` does.
        conf_alg = generation_config.conf_alg if rcr else generation_config.alg

        # The RCR window is hard-coded to [steps // 4, 3 * steps // 4) (half-open);
        # the config's rcr_start_step / rcr_end_step are intentionally not used here.
        rcr_start = max(0, steps // 4)
        rcr_end = max(rcr_start, min(steps, (3 * steps) // 4))

        protect_cur = bool(generation_config.rcr_protect_current_step)

        histories = [] if (return_dict_in_generate and output_history) else None

        # Right-pad the prompt with mask tokens up to max_length.
        x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)

        if attention_mask is not None and torch.any(attention_mask == 0.0):
            # Padding present: derive position ids (pad slots get a dummy id of 1) and a
            # pairwise mask — [B, 1, L, L] assuming a 2-D [B, L] attention mask.
            attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
            tok_idx = attention_mask.long().cumsum(-1) - 1
            tok_idx.masked_fill_(attention_mask == 0, 1)
            attention_mask = torch.logical_and(
                attention_mask.unsqueeze(1).unsqueeze(-2),
                attention_mask.unsqueeze(1).unsqueeze(-1),
            )
        else:
            tok_idx = None
            attention_mask = "full"

        timesteps = torch.linspace(1, eps, steps + 1, device=x.device)

        # ==== RCR state ====
        if rcr:
            init_mask_bool = (x == mask_token_id)            # slots that start masked (generation region)
            init_mask_count = init_mask_bool.sum(dim=1)      # [B]
            # Running max confidence per written slot. Starting at -inf (rather than 0) keeps
            # the max meaningful for scores that can be negative (e.g. negative entropy);
            # for non-negative scores the result is identical.
            hist_conf = torch.full_like(x, float("-inf"), dtype=torch.float32)
            gen_mask = torch.zeros_like(x, dtype=torch.bool)             # currently confirmed slots
            written_step = torch.full_like(x, -1, dtype=torch.int32)     # step each slot was last written

        x = generation_tokens_hook_func(None, x, None)

        for i in range(steps):
            mask_index = (x == mask_token_id)

            # Forward pass, then Dream's right-shift alignment: position j takes the
            # logits computed at j - 1 (position 0 keeps its own).
            logits = self(x, attention_mask, tok_idx).logits
            logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
            logits = generation_logits_hook_func(i, x, logits)

            t = timesteps[i]
            s = timesteps[i + 1]

            # Only masked positions are filtered, sampled and scored.
            mask_logits = logits[mask_index]
            if mask_logits.numel() == 0:
                x = generation_tokens_hook_func(i, x, logits)
                if histories is not None:
                    histories.append(x.clone())
                continue

            # Pull the timestep scalars to the host once per step.
            t_val = t.item()
            s_val = s.item()

            mask_logits = _apply_top_p_k_temp(mask_logits, temperature, top_p, top_k)
            probs = torch.softmax(mask_logits, dim=-1)

            # x0: sampled when temperature > 0, otherwise greedy.
            if temperature and temperature > 0:
                try:
                    x0 = dists.Categorical(probs=probs).sample()
                except Exception:
                    # Deliberate best-effort: fall back to greedy if sampling fails.
                    x0 = probs.argmax(dim=-1)
            else:
                x0 = probs.argmax(dim=-1)

            # Confidence uses the same definition as decoding.
            conf_now = _confidence_from_probs(
                probs=probs,
                chosen_ids=x0 if conf_alg == "maskgit_plus" else None,
                mode=conf_alg
            ).to(torch.float32)  # [M]

            # Fraction of each row's masked slots to commit this step (all of them on the last step).
            ratio = (1.0 - (s_val / t_val)) if i < steps - 1 else 1.0

            # Scatter scores / predictions back to [B, L]; unmasked slots get -1e9 so they rank last.
            full_conf_now = torch.full((x.size(0), x.size(1)), -1e9, dtype=torch.float32, device=x.device)
            full_x0 = torch.full_like(x, mask_token_id, dtype=torch.long)
            full_conf_now[mask_index] = conf_now
            full_x0[mask_index] = x0

            # Commit the top-k per row, whether or not we are inside the RCR window. The quota is
            # derived from each row's own masked count; a batch-wide total would inflate it ~B-fold.
            for b, masked_b in enumerate(mask_index.sum(dim=1).tolist()):
                k_b = min(int(masked_b * ratio), masked_b)
                if masked_b == 0 or k_b <= 0:
                    continue
                _, sel_idx = torch.topk(full_conf_now[b], k=k_b, largest=True)
                x[b, sel_idx] = full_x0[b, sel_idx]
                if rcr:
                    gen_mask[b, sel_idx] = True
                    written_step[b, sel_idx] = i
                    # Track the historical max confidence (same definition as decoding).
                    hist_conf[b, sel_idx] = torch.maximum(hist_conf[b, sel_idx], full_conf_now[b, sel_idx])

            # Outside the RCR window only the history is tracked. Inside it, confirmed
            # tokens with the lowest historical confidence are re-masked down to target_cum.
            if rcr and (rcr_start <= i < rcr_end):
                for b in range(x.size(0)):
                    M0 = int(init_mask_count[b].item())
                    # NOTE(review): despite its name, target_cum uses the single-step fraction
                    # (1 - s/t) rather than the cumulative fraction (1 - s) of the 1 -> eps
                    # schedule, so every in-window step trims the confirmed count to a small
                    # fraction of M0. Confirm this is the intended RCR behaviour.
                    target_cum = M0 if i >= steps - 1 else int(M0 * (1.0 - (s_val / t_val)))
                    # Confirmed slots inside the initial generation region.
                    confirmed = gen_mask[b] & init_mask_bool[b]
                    C_t = int(confirmed.sum().item())
                    over = max(0, C_t - target_cum)
                    if over <= 0:
                        continue

                    cand = torch.where(confirmed)[0]
                    if cand.numel() == 0:
                        continue
                    if protect_cur:
                        # Exclude tokens written in this very step; skip if nothing else is left.
                        cand = cand[written_step[b, cand] < i]
                        if cand.numel() == 0:
                            continue

                    over = min(over, int(cand.numel()))
                    scores = hist_conf[b, cand]  # larger = more confident
                    _, low_local = torch.topk(scores, k=over, largest=False)
                    low_global = cand[low_local]

                    # Re-mask; hist_conf and written_step are intentionally kept.
                    x[b, low_global] = mask_token_id
                    gen_mask[b, low_global] = False

            x = generation_tokens_hook_func(i, x, logits)
            if histories is not None:
                histories.append(x.clone())

        if return_dict_in_generate:
            return DreamModelOutput(sequences=x, history=histories)
        return x