# Math utilities for timestep sampling schedules
import math
from typing import Literal

import torch


class TimestepDistUtils:
    @staticmethod
    def t_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
        r"""
        Timestep shift, eq. (12) of https://arxiv.org/abs/2506.15742,
        Black Forest Labs (2025):

            t' = \frac{e^{\mu}}{e^{\mu} + (1/t - 1)^{\sigma}}
        """
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
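
    # Worked example (for orientation): with sigma = 1.0 the shift reduces to
    # t' = sigmoid(mu + logit(t)), so t_shift(0.9, 1.0, t=0.5) equals
    # sigmoid(0.9) ~= 0.711 -- positive mu pushes timesteps toward 1.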

    @staticmethod
    def lerp_mu(
        seq_len: int,
        min_seq_len: int = 256,
        max_seq_len: int = 8192,
        min_mu: float = 0.5,
        max_mu: float = 0.9,
    ) -> float:
        """
        Resolution-dependent shifting of timestep schedules from
        Esser et al., https://arxiv.org/abs/2403.03206, updated with
        default parameters for Qwen: mu is interpolated linearly between
        (min_seq_len, min_mu) and (max_seq_len, max_mu).
        """
        m = (max_mu - min_mu) / (max_seq_len - min_seq_len)
        b = min_mu - m * min_seq_len
        mu = seq_len * m + b
        return mu
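
    # Worked example: with the defaults, a 128x128 latent patchified 2x2
    # gives seq_len = 64 * 64 = 4096, hence
    # mu = 4096 * (0.4 / 7936) + (0.5 - 256 * 0.4 / 7936) ~= 0.694.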

    @staticmethod
    def logit_normal_pdf(t: torch.Tensor, mu: float = 0.0, sigma: float = 1.0) -> torch.Tensor:
        """
        Logit-normal PDF, i.e. the density of sigmoid(normal(mu, sigma)).
        """
        z = (torch.logit(t) - mu) / sigma
        coef = 1.0 / (sigma * math.sqrt(2.0 * math.pi))
        pdf = coef * torch.exp(-0.5 * z**2) / (t * (1.0 - t))
        return pdf
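
    # Worked example: at t = 0.5 with mu = 0, sigma = 1, logit(0.5) = 0, so
    # the density is (1 / sqrt(2 * pi)) / 0.25 ~= 1.596.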

    @staticmethod
    def scaled_clipped_gaussian(t: torch.Tensor) -> torch.Tensor:
        """
        Heuristic weighting: a Gaussian with mu = 0.5 and sigma = 0.5,
        clipped to [0, 1] and rescaled so that int_0^1 w(t) dt = 1.
        """
        y = torch.exp(-2 * (t - 0.5) ** 2)
        # exp(-0.5) ~= 0.606 is the value at t = 0 and t = 1, so subtracting
        # it zeroes the boundary; 4.02 renormalizes the integral to 1.
        y = (y - 0.606) * 4.02
        return y
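
    # Worked example: w(0.5) = (1 - 0.606) * 4.02 ~= 1.584 at the center,
    # while w(0) = w(1) = (exp(-0.5) - 0.606) * 4.02 ~= 0 at the boundaries.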

    @staticmethod
    def get_seq_len(latents: torch.Tensor) -> int:
        if latents.dim() == 4 or latents.dim() == 5:
            # [B, C, H, W] or [B, C, T, H, W]: token count after a 2x2
            # spatial patchification of the last two dims.
            h, w = latents.shape[-2:]
            seq_len = (h // 2) * (w // 2)
        elif latents.dim() == 3:
            seq_len = latents.shape[1]  # [B, L=h*w, C]
        else:
            raise ValueError(f"{latents.dim()=} not in (3, 4, 5)")
        return seq_len
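
    # Worked example: a [1, 16, 64, 64] latent yields (64 // 2) * (64 // 2)
    # = 1024 tokens, which lerp_mu then maps to mu ~= 0.539.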

    def __init__(
        self,
        min_seq_len: int = 256,
        max_seq_len: int = 8192,
        min_mu: float = 0.5,
        max_mu: float = 0.9,
        train_dist: Literal["logit-normal", "linear"] = "linear",
        train_shift: bool = True,
        inference_dist: Literal["logit-normal", "linear"] = "linear",
        inference_shift: bool = True,
        static_mu: float | None = None,
        loss_weight_dist: Literal["scaled_clipped_gaussian", "logit-normal"] | None = None,
    ):
        self.min_seq_len = min_seq_len
        self.max_seq_len = max_seq_len
        self.min_mu = min_mu
        self.max_mu = max_mu
        self.train_dist = train_dist
        self.train_shift = train_shift
        self.inference_dist = inference_dist
        self.inference_shift = inference_shift
        self.static_mu = static_mu
        self.loss_weight_dist = loss_weight_dist

    def lin_t_to_dist(self, t: torch.Tensor, seq_len: int | None = None, dist: str | None = None, shift: bool | None = None) -> torch.Tensor:
        # Map linearly spaced / uniform t onto the configured distribution,
        # defaulting to the training-time settings.
        dist = self.train_dist if dist is None else dist
        shift = self.train_shift if shift is None else shift
        if dist == "logit-normal":
            # Inverse-CDF transform: uniform t -> logit-normal(0, 1) samples.
            t = torch.sigmoid(torch.special.ndtri(t))
        elif dist == "linear":
            pass
        else:
            raise ValueError(f"unknown dist {dist!r}")
        if shift:
            if self.static_mu is not None:
                mu = self.static_mu
            elif seq_len is not None:
                mu = self.lerp_mu(seq_len, self.min_seq_len, self.max_seq_len, self.min_mu, self.max_mu)
            else:
                raise ValueError("shifting requires seq_len or static_mu")
            t = self.t_shift(mu, 1.0, t)
        return t
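
    # Note: lin_t_to_dist is shared by training and inference; with
    # dist="logit-normal" it concentrates t around 0.5, and shifting with a
    # larger mu (longer sequences) then pushes the schedule toward t = 1.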

    def get_train_t(self, size, seq_len: int | None = None) -> torch.Tensor:
        # Uniform samples mapped through the training distribution and shift.
        t = torch.rand(size)
        t = self.lin_t_to_dist(t, seq_len=seq_len)
        return t

    def get_loss_weighting(self, t: torch.Tensor) -> torch.Tensor:
        if self.loss_weight_dist == "scaled_clipped_gaussian":
            w = self.scaled_clipped_gaussian(t)
        elif self.loss_weight_dist == "logit-normal":
            w = self.logit_normal_pdf(t)
        elif self.loss_weight_dist is None:
            w = torch.ones_like(t)
        else:
            raise ValueError(f"unknown loss_weight_dist {self.loss_weight_dist!r}")
        return w

    def get_inference_t(self, steps: int, strength: float = 1.0, seq_len: int | None = None, clip_by_strength: bool = True) -> torch.Tensor:
        # The schedule carries one more entry than the number of ODE steps so
        # that t_schedule[index + 1] is always valid in inference_ode_step.
        if clip_by_strength:
            true_steps = max(1, int(strength * steps)) + 1
        else:
            true_steps = max(1, steps) + 1
        t = torch.linspace(strength, 0.0, true_steps)
        t = self.lin_t_to_dist(t, seq_len=seq_len, dist=self.inference_dist, shift=self.inference_shift)
        return t
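
    # Worked example: steps=4, strength=1.0 gives a 5-entry schedule from
    # t=1 down to t=0 (then shifted), i.e. exactly 4 Euler update steps.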

    def inference_ode_step(self, noise_pred: torch.Tensor, latents: torch.Tensor, index: int, t_schedule: torch.Tensor) -> torch.Tensor:
        # One explicit Euler step of the flow-matching ODE:
        # x_{t_next} = x_t + (t_next - t) * v(x_t, t).
        t = t_schedule[index]
        t_next = t_schedule[index + 1]
        d_t = t_next - t
        latents = latents + d_t * noise_pred
        return latents
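

# Minimal usage sketch. The velocity prediction below is a zero stand-in
# (not part of this module); shapes and hyperparameters are illustrative only.
if __name__ == "__main__":
    utils = TimestepDistUtils(loss_weight_dist="scaled_clipped_gaussian")

    # Training side: sample shifted timesteps and matching loss weights.
    latents = torch.randn(4, 16, 64, 64)  # [B, C, H, W]
    seq_len = TimestepDistUtils.get_seq_len(latents)  # 32 * 32 = 1024
    t = utils.get_train_t((4,), seq_len=seq_len)
    weights = utils.get_loss_weighting(t)

    # Inference side: Euler-integrate along the shifted schedule.
    t_schedule = utils.get_inference_t(steps=20, seq_len=seq_len)
    x = torch.randn(1, 16, 64, 64)
    for i in range(len(t_schedule) - 1):
        noise_pred = torch.zeros_like(x)  # stand-in for a real velocity prediction
        x = utils.inference_ode_step(noise_pred, x, i, t_schedule)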