# Math utilities for timestep sampling schedules (training and inference).
import math
from typing import Literal

import torch


class TimestepDistUtils:
    @staticmethod
    def t_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
        r"""
        Shift a timestep schedule; see eq. (12) of
        https://arxiv.org/abs/2506.15742, Black Forest Labs (2025):
            t' = \frac{e^{\mu}}{e^{\mu} + (1/t - 1)^{\sigma}}
        """
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    @staticmethod
    def lerp_mu(  # default params tuned for Qwen
        seq_len: int,
        min_seq_len: int = 256,
        max_seq_len: int = 8192,
        min_mu: float = 0.5,
        max_mu: float = 0.9,
    ) -> float:
        """
        Resolution-dependent shifting of timestep schedules, from
        Esser et al., https://arxiv.org/abs/2403.03206, with default
        parameters updated for Qwen. Linearly interpolates mu as a
        function of sequence length.
        """
        m = (max_mu - min_mu) / (max_seq_len - min_seq_len)
        b = min_mu - m * min_seq_len
        return seq_len * m + b

    @staticmethod
    def logit_normal_pdf(t: torch.Tensor, mu: float = 0.0, sigma: float = 1.0) -> torch.Tensor:
        """
        PDF of the logit-normal distribution, i.e. the density of
        sigmoid(N(mu, sigma^2)) evaluated at t in (0, 1).
        """
        z = (torch.logit(t) - mu) / sigma
        coef = 1.0 / (sigma * math.sqrt(2.0 * math.pi))
        return coef * torch.exp(-0.5 * z**2) / (t * (1.0 - t))
    
    @staticmethod
    def scaled_clipped_gaussian(t: torch.Tensor) -> torch.Tensor:
        """
        Heuristic weighting: a Gaussian with mu = 0.5 and sigma = 0.5,
        restricted to [0, 1], shifted to ~0 at the endpoints, and rescaled
        so that int_0^1 w(t) dt ~= 1.
        """
        y = torch.exp(-2 * (t - 0.5) ** 2)
        y = (y - 0.606) * 4.02  # exp(-0.5) ~= 0.606; 4.02 normalizes the integral
        return y
    
    @staticmethod
    def get_seq_len(latents: torch.Tensor) -> int:
        if latents.dim() in (4, 5):  # [B, C, H, W] or [B, C, T, H, W]
            h, w = latents.shape[-2:]
            seq_len = (h // 2) * (w // 2)  # assumes 2x2 patchification
        elif latents.dim() == 3:
            seq_len = latents.shape[1]  # [B, L=h*w, C]
        else:
            raise ValueError(f"{latents.dim()=} not in (3, 4, 5)")
        return seq_len

    def __init__(
        self,
        min_seq_len: int = 256,
        max_seq_len: int = 8192,
        min_mu: float = 0.5,
        max_mu: float = 0.9,
        train_dist: Literal["logit-normal", "linear"] = "linear",
        train_shift: bool = True,
        inference_dist: Literal["logit-normal", "linear"] = "linear",
        inference_shift: bool = True,
        static_mu: float | None = None,
        loss_weight_dist: Literal["scaled_clipped_gaussian", "logit-normal"] | None = None,
    ):
        self.min_seq_len = min_seq_len
        self.max_seq_len = max_seq_len
        self.min_mu = min_mu
        self.max_mu = max_mu
        self.train_dist = train_dist
        self.train_shift = train_shift
        self.inference_dist = inference_dist
        self.inference_shift = inference_shift
        self.static_mu = static_mu
        self.loss_weight_dist = loss_weight_dist

    def lin_t_to_dist(self, t, seq_len=None, mode: Literal["train", "inference"] = "train"):
        dist = self.train_dist if mode == "train" else self.inference_dist
        shift = self.train_shift if mode == "train" else self.inference_shift

        if dist == "logit-normal":
            # Map uniform/linear t to logit-normal samples via the inverse CDF:
            # sigmoid(Phi^{-1}(t)), with Phi^{-1}(t) = sqrt(2) * erfinv(2t - 1).
            t = torch.sigmoid(math.sqrt(2.0) * torch.erfinv(2.0 * t - 1.0))
        elif dist == "linear":
            pass
        else:
            raise ValueError(f"unknown dist {dist!r}")

        if shift:
            if self.static_mu is not None:
                mu = self.static_mu
            elif seq_len is not None:
                mu = self.lerp_mu(seq_len, self.min_seq_len, self.max_seq_len, self.min_mu, self.max_mu)
            else:
                raise ValueError("shifting requires either static_mu or seq_len")
            t = self.t_shift(mu, 1.0, t)
        return t

    def get_train_t(self, size, seq_len=None):
        t = torch.rand(size)  # uniform in [0, 1)
        t = self.lin_t_to_dist(t, seq_len=seq_len, mode="train")
        return t

    def get_loss_weighting(self, t):
        if self.loss_weight_dist == "scaled_clipped_gaussian":
            w = self.scaled_clipped_gaussian(t)
        elif self.loss_weight_dist == "logit-normal":
            w = self.logit_normal_pdf(t)
        elif self.loss_weight_dist is None:
            w = torch.ones_like(t)
        else:
            raise ValueError(f"unknown loss_weight_dist {self.loss_weight_dist!r}")
        return w

    def get_inference_t(self, steps, strength=1.0, seq_len=None, clip_by_strength=True):
        # true_steps schedule points are the fenceposts for (true_steps - 1) ODE steps
        if clip_by_strength:
            true_steps = max(1, int(strength * steps)) + 1
        else:
            true_steps = max(1, steps) + 1
        t = torch.linspace(strength, 0.0, true_steps)
        t = self.lin_t_to_dist(t, seq_len=seq_len, mode="inference")
        return t

    def inference_ode_step(self, noise_pred: torch.Tensor, latents: torch.Tensor, index: int, t_schedule: torch.Tensor):
        # First-order (Euler) step of the flow-matching ODE; d_t is negative
        # because the schedule runs from high noise (t=1) toward clean data (t=0).
        t = t_schedule[index]
        t_next = t_schedule[index + 1]
        d_t = t_next - t
        latents = latents + d_t * noise_pred
        return latents
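

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the library API): the batch
# size, step count, and [B, C, H, W] latent shape below are assumptions chosen
# only to exercise the class end to end.
if __name__ == "__main__":
    utils = TimestepDistUtils(
        train_dist="logit-normal",
        loss_weight_dist="scaled_clipped_gaussian",
    )

    # Training: sample per-example timesteps and matching loss weights.
    latents = torch.randn(4, 16, 64, 64)  # hypothetical [B, C, H, W] latents
    seq_len = TimestepDistUtils.get_seq_len(latents)  # (64//2) * (64//2) = 1024
    t_train = utils.get_train_t((4,), seq_len=seq_len)
    w = utils.get_loss_weighting(t_train)
    print(f"{t_train=}")
    print(f"{w=}")

    # Inference: build a shifted schedule and run Euler steps with random
    # velocities standing in for a model's noise_pred.
    t_sched = utils.get_inference_t(steps=8, seq_len=seq_len)
    x = torch.randn_like(latents)
    for i in range(len(t_sched) - 1):
        noise_pred = torch.randn_like(x)  # stand-in for a model forward pass
        x = utils.inference_ode_step(noise_pred, x, i, t_sched)
    print(f"{t_sched=}")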