from typing import Callable

import torch
import torch.nn.functional as F
from copy import deepcopy

from .transports import Transport
from tim.models.utils.funcs import expand_t_like_x


def mean_flat(x):
    """Take the mean over all non-batch dimensions."""
    return torch.mean(x, dim=list(range(1, len(x.size()))))


def sum_flat(x):
    """Take the sum over all non-batch dimensions."""
    # NOTE: docstring previously said "mean" (copy-paste); this computes a sum.
    return torch.sum(x, dim=list(range(1, len(x.size()))))


class TransitionSchedule:
    """Loss schedule and sampler built on top of a :class:`Transport`.

    Each training batch mixes three regimes along the batch dimension:
      * transition samples (t > r)                — the generic case,
      * diffusion samples (r = t)                 — the first ``n_diffusion`` rows,
      * consistency samples (r = transport.T_min) — the last ``n_consistency`` rows.

    The time derivative of the model output needed for the transition target is
    computed either with forward-mode autodiff (``derivative_type="jvp"``) or a
    central finite difference (``derivative_type="dde"``).
    """

    def __init__(
        self,
        transport: Transport,
        diffusion_ratio: float = 0.0,
        consistency_ratio: float = 0.0,
        derivative_type: str = "dde",
        differential_epsilon: float = 0.005,
        weight_t_and_r: bool = True,
        weight_time_type: str = "constant",
        weight_time_tangent: bool = False,
        weight_time_sigmoid: bool = False,
    ):
        """Store the transport and all scheduling hyper-parameters.

        Args:
            transport: interpolant/target provider (``interpolant``, ``c_noise``,
                ``sample_t``, ``target``, ``from_x_t_to_x_r`` are used here).
            diffusion_ratio: fraction of each batch trained with r == t.
            consistency_ratio: fraction of each batch trained with r == T_min.
            derivative_type: "jvp" or "dde" (finite difference).
            differential_epsilon: step size for the "dde" central difference.
            weight_t_and_r: weight by (t - r) instead of t alone.
            weight_time_type: one of "constant", "reciprocal", "sqrt",
                "square", "Soft-Min-SNR".
            weight_time_tangent / weight_time_sigmoid: optional time warps
                applied before weighting (tan(t) resp. t/(1-t)).
        """
        self.transport = transport
        self.diffusion_ratio = diffusion_ratio
        self.consistency_ratio = consistency_ratio
        self.derivative_type = derivative_type
        self.differential_epsilon = differential_epsilon
        self.weight_t_and_r = weight_t_and_r
        self.weight_time_type = weight_time_type
        self.weight_time_tangent = weight_time_tangent
        self.weight_time_sigmoid = weight_time_sigmoid

    def sample_t_and_r(self, batch_size, dtype, device):
        """Draw per-sample time pairs (t, r) with t >= r.

        Returns:
            (t, r, n_diffusion): 1-D tensors of shape (batch_size,) and the
            number of leading rows forced to r == t (diffusion training).
            When ``consistency_ratio > 0`` the trailing rows are forced to
            r == transport.T_min (consistency training).
        """
        t_1 = self.transport.sample_t(batch_size=batch_size, dtype=dtype, device=device)
        t_2 = self.transport.sample_t(batch_size=batch_size, dtype=dtype, device=device)
        # t is the larger one, and r is the smaller one
        t = torch.maximum(t_1, t_2)
        r = torch.minimum(t_1, t_2)
        # some samples with t == r, corresponding to diffusion training
        n_diffusion = round(self.diffusion_ratio * len(t))
        r[:n_diffusion] = t[:n_diffusion]
        # some samples with r == T_min, corresponding to consistency training
        n_consistency = round(self.consistency_ratio * len(t))
        if n_consistency != 0:
            r[-n_consistency:] = self.transport.T_min
        return t, r, n_diffusion

    def prepare_input(self, batch_size, x, z):
        """Build the interpolated input x_t and its time derivative v_t.

        Args:
            x: clean data batch; z: noise batch of the same shape.

        Returns:
            (x_t, v_t, t, r, n_diffusion) with t, r expanded to broadcast
            against x (e.g. (B,) -> (B, 1, 1, 1)).
        """
        # sample timestep according to log-normal distribution of sigmas following EDM
        t, r, n_diffusion = self.sample_t_and_r(
            batch_size=batch_size, dtype=x.dtype, device=x.device
        )
        # reshape (B,) -> (B, 1, 1, 1)
        t, r = expand_t_like_x(t, x), expand_t_like_x(r, x)
        # prepare inputs: x_t follows the interpolant, v_t its time derivative
        alpha_t, sigma_t, d_alpha_t, d_sigma_t = self.transport.interpolant(t)
        x_t = alpha_t * x + sigma_t * z
        v_t = d_alpha_t * x + d_sigma_t * z
        return x_t, v_t, t, r, n_diffusion

    def model_forward(self, model, x_t, t, r, model_kwargs, rng_state):
        """Run the network on (x_t, t, r).

        NOTE: ``transport.c_noise`` is applied to the flattened t and r here —
        callers must pass *raw* times, never pre-noised ones.
        The CUDA RNG state is restored first so repeated calls (e.g. for
        finite differences) see identical dropout/noise.
        """
        t_input = self.transport.c_noise(t.flatten())
        r_input = self.transport.c_noise(r.flatten())
        torch.cuda.set_rng_state(rng_state)
        return model(x_t, t_input, r_input, **model_kwargs)

    @torch.no_grad()
    def jvp_derivative(
        self, model, x_t, v_t, t, r, model_kwargs, rng_state, n_diffusion
    ):
        """Total time derivative dF/dt along the trajectory via forward-mode JVP.

        Only the transition rows (index >= n_diffusion, where t != r) are
        computed; diffusion rows are left at zero. Returns 0 when the whole
        batch is diffusion.
        """
        if n_diffusion == x_t.size(0):
            return 0
        _dF_dv_dt = torch.zeros_like(x_t)
        # only calculate dF/dt where t != r
        x_t, v_t, t, r = (
            x_t[n_diffusion:],
            v_t[n_diffusion:],
            t[n_diffusion:],
            r[n_diffusion:],
        )
        # work on a shallow copy so the caller's kwargs are not sliced/polluted
        model_kwargs = dict(model_kwargs)
        for k, v in model_kwargs.items():
            if isinstance(v, torch.Tensor):
                model_kwargs[k] = v[n_diffusion:]
        model_kwargs["return_zs"] = False
        # presumably torch.func.jvp requires the non-fused attention path — keep as in original
        model_kwargs["attn_type"] = "vanilla_attn"
        model_kwargs["jvp"] = True

        def model_jvp(x_t, t, r):
            t_input = self.transport.c_noise(t.flatten())
            r_input = self.transport.c_noise(r.flatten())
            return model(x_t, t_input, r_input, **model_kwargs)

        torch.cuda.set_rng_state(rng_state)
        # tangent (v_t, 1, 0): d/dt F(x_t(t), t, r) with r held fixed
        _, dF_dv_dt = torch.func.jvp(
            model_jvp,
            (x_t, t, r),
            (v_t, torch.ones_like(t), torch.zeros_like(r)),
        )
        _dF_dv_dt[n_diffusion:] = dF_dv_dt
        return _dF_dv_dt

    @torch.no_grad()
    def dde_derivative(self, model, x, z, t, r, model_kwargs, rng_state, n_diffusion):
        """Total time derivative dF/dt via a central finite difference in t.

        Re-interpolates x_t at t ± epsilon and differences the model outputs;
        the shared RNG state inside ``model_forward`` keeps both evaluations
        stochastically identical. Returns 0 when the whole batch is diffusion.
        """
        if n_diffusion == x.size(0):
            return 0
        _dF_dv_dt = torch.zeros_like(x)
        # only calculate dF/dt where t != r
        x, z, t, r = x[n_diffusion:], z[n_diffusion:], t[n_diffusion:], r[n_diffusion:]
        # work on a shallow copy so the caller's kwargs are not sliced/polluted
        model_kwargs = dict(model_kwargs)
        for k, v in model_kwargs.items():
            if isinstance(v, torch.Tensor):
                model_kwargs[k] = v[n_diffusion:]
        model_kwargs["return_zs"] = False
        model_kwargs["jvp"] = True

        def xfunc(t_query):
            alpha_t, sigma_t, _, _ = self.transport.interpolant(t_query)
            x_t = alpha_t * x + sigma_t * z
            return self.model_forward(model, x_t, t_query, r, model_kwargs, rng_state)

        epsilon = self.differential_epsilon
        # central difference: (F(t+eps) - F(t-eps)) / (2*eps)
        dF_dv_dt = (xfunc(t + epsilon) - xfunc(t - epsilon)) / (2 * epsilon)
        _dF_dv_dt[n_diffusion:] = dF_dv_dt
        return _dF_dv_dt

    def get_enhanced_target(self, model, x_t, t, model_kwargs, null_kwargs, rng_state):
        """Evaluate the (EMA) model at r == t for target enhancement.

        Returns:
            (F_t_cond, F_t_uncond); F_t_cond is 0 when transport.w_cond <= 0.
        """
        with torch.no_grad():
            # BUGFIX: model_forward applies transport.c_noise itself, so pass the
            # raw t; the original pre-applied c_noise, noising the times twice.
            if self.transport.w_cond > 0:
                F_t_cond = self.model_forward(
                    model, x_t, t, t, model_kwargs, rng_state
                )
            else:
                F_t_cond = 0
            F_t_uncond = self.model_forward(
                model, x_t, t, t, null_kwargs, rng_state
            )
            return F_t_cond, F_t_uncond

    def time_weighting(self, t, r, n_diffusion):
        """Per-sample loss weight as a function of (t, r).

        Diffusion rows (the first ``n_diffusion``) always get weight 1.
        Raises NotImplementedError for unknown ``weight_time_type``.
        """
        # optional time warps before weighting
        if self.weight_time_tangent:
            t, r = torch.tan(t), torch.tan(r)
        elif self.weight_time_sigmoid:
            t, r = t / (1 - t), r / (1 - r)
        # weight by the gap (t - r) or by t alone
        if self.weight_t_and_r:
            delta_t = (t - r).flatten()
        else:
            delta_t = t.flatten()
        if self.weight_time_type == "constant":
            weight = torch.ones_like(delta_t)
        elif self.weight_time_type == "reciprocal":
            weight = 1 / (delta_t + self.transport.sigma_d)
        elif self.weight_time_type == "sqrt":
            weight = 1 / (delta_t + self.transport.sigma_d).sqrt()
        elif self.weight_time_type == "square":
            weight = 1 / (delta_t + self.transport.sigma_d) ** 2
        elif self.weight_time_type == "Soft-Min-SNR":
            weight = 1 / (delta_t**2 + self.transport.sigma_d**2)
        else:
            raise NotImplementedError
        weight[:n_diffusion] = 1.0
        return weight

    def adaptive_weighting(self, loss, eps=10e-6):
        """Reciprocal per-sample weight 1/(loss + eps); detached so the
        weights carry no gradient.

        NOTE(review): the default eps is 10e-6 == 1e-5 — possibly 1e-6 was
        intended; kept as-is to preserve behavior.
        """
        return 1 / (loss.detach() + eps)

    def __call__(
        self,
        model,
        ema_model,
        unwrapped_model,
        batch_size,
        x,
        z,
        model_kwargs,
        use_dir_loss=False,
        h_target=None,
        ema_kwargs=None,
        null_kwargs=None,
    ):
        """Compute the training losses for one batch.

        Args:
            model: trainable network returning (F_pred, h_proj).
            ema_model: EMA copy used for target enhancement.
            unwrapped_model: raw module used for the derivative passes.
            x, z: clean data and noise batches of identical shape.
            h_target: projection target for the representation loss; when None
                the projection loss is reported as 0.

        Returns:
            (weighted_loss, proj_loss, loss_dict) — two scalar tensors plus a
            dict of detached floats for logging.
        """
        # BUGFIX: defaults were mutable dicts ({}); use None sentinels instead.
        ema_kwargs = {} if ema_kwargs is None else ema_kwargs
        null_kwargs = {} if null_kwargs is None else null_kwargs
        # prepare model input
        x_t, v_t, t, r, n_diffusion = self.prepare_input(batch_size, x, z)
        # snapshot the RNG so every forward pass below sees identical noise
        rng_state = torch.cuda.get_rng_state()
        # get prediction
        F_pred, h_proj = self.model_forward(model, x_t, t, r, model_kwargs, rng_state)
        # get the time derivative of the model output for the target
        if self.derivative_type == "jvp":
            dF_dv_dt = self.jvp_derivative(
                unwrapped_model, x_t, v_t, t, r, model_kwargs, rng_state, n_diffusion
            )
        else:
            dF_dv_dt = self.dde_derivative(
                unwrapped_model, x, z, t, r, model_kwargs, rng_state, n_diffusion
            )
        if self.transport.enhance_target:
            F_t_cond, F_t_uncond = self.get_enhanced_target(
                ema_model, x_t, t, ema_kwargs, null_kwargs, rng_state
            )
            enhance_target = True
        else:
            F_t_cond, F_t_uncond, enhance_target = 0, 0, False
        F_target = self.transport.target(
            x_t, v_t, x, z, t, r, dF_dv_dt, F_t_cond, F_t_uncond, enhance_target
        )
        denoising_loss = mean_flat((F_pred - F_target) ** 2)
        denoising_loss = torch.nan_to_num(
            denoising_loss, nan=0, posinf=1e5, neginf=-1e5
        )
        if use_dir_loss:
            # cosine-direction penalty on top of the MSE
            directional_loss = mean_flat(
                1 - F.cosine_similarity(F_pred, F_target, dim=1)
            )
            directional_loss = torch.nan_to_num(
                directional_loss, nan=0, posinf=1e5, neginf=-1e5
            )
            denoising_loss += directional_loss
        weight = self.time_weighting(t, r, n_diffusion) * self.adaptive_weighting(
            denoising_loss
        )
        weighted_loss = (weight * denoising_loss).mean()
        # representation-alignment (projection) loss
        if h_target is None or h_proj is None:
            # BUGFIX: previously crashed on the documented default h_target=None
            proj_loss = weighted_loss.new_zeros(())
        else:
            proj_loss = mean_flat(
                1 - torch.cosine_similarity(h_proj, h_target, dim=-1)
            )
            proj_loss = torch.nan_to_num(proj_loss, nan=0, posinf=1e5, neginf=-1e5)
            proj_loss = proj_loss.mean()
        loss_dict = dict(
            weighted_loss=weighted_loss.detach().item(),
            denoising_loss=denoising_loss.mean().detach().item(),
            proj_loss=proj_loss.detach().item(),
        )
        return weighted_loss, proj_loss, loss_dict

    def forward_with_cfg(
        self, model, x_t, t, r, y, y_null, cfg_scale, cfg_low, cfg_high
    ):
        """Single model evaluation with optional classifier-free guidance.

        CFG is applied only when cfg_scale > 1 and t lies strictly inside
        (cfg_low, cfg_high); the conditional/unconditional passes are batched.
        """
        apply_cfg = cfg_scale > 1.0 and t > cfg_low and t < cfg_high
        if apply_cfg:
            x_cur = torch.cat([x_t] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            x_cur = x_t
            y_cur = y
        t_cur = torch.ones(x_cur.size(0)).to(x_cur) * self.transport.c_noise(t)
        r_cur = torch.ones(x_cur.size(0)).to(x_cur) * self.transport.c_noise(r)
        F_pred = model(x_cur, t_cur, r_cur, y_cur)
        if apply_cfg:
            F_cond, F_uncond = F_pred.chunk(2)
            F_pred = F_uncond + cfg_scale * (F_cond - F_uncond)
        return F_pred

    @torch.no_grad()
    def sample(
        self,
        model,
        y,
        y_null,
        z,
        T_max,
        T_min=0.0,
        num_steps=4,
        cfg_scale=1.0,
        cfg_low=0.0,
        cfg_high=1.0,
        stochasticity_ratio=0.0,
        sample_type: str = "transition",  # 'transition' or 'diffusion'
        step_callback: Callable[[int], None] | None = None,
    ):
        """Iteratively denoise z from T_max down to T_min.

        Returns:
            Tensor of shape (num_steps + 1, *z.shape) in float32 holding the
            full trajectory, starting with z itself.
        """
        _dtype = z.dtype
        # BUGFIX: `.to(z)` also cast the float64 grid to z.dtype; keep float64
        # precision and only move to z's device.
        t_steps = torch.linspace(
            T_max, T_min, num_steps + 1, dtype=torch.float64
        ).to(device=z.device)
        cfg_low = cfg_low * T_max
        cfg_high = cfg_high * T_max
        x_cur = deepcopy(z).to(torch.float64)
        samples = [z]
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
            # x_{N} -> x_{N-1} -> ... -> x_{n+1} -> x_{n} -> ... -> x_{1} -> x_{0}
            if sample_type == "transition":
                _t_next = t_next
            elif sample_type in ("diffusion", "ddiffusion"):
                # "ddiffusion" kept for backward compatibility with the old typo
                _t_next = t_cur
            else:
                # BUGFIX: was a bare `raise` (RuntimeError: no active exception)
                raise ValueError(f"unknown sample_type: {sample_type!r}")
            F_pred = self.forward_with_cfg(
                model,
                x_cur.to(_dtype),
                t_cur,
                _t_next,
                y,
                y_null,
                cfg_scale,
                cfg_low,
                cfg_high,
            ).to(torch.float64)
            # no stochasticity on the first and last steps
            if stochasticity_ratio > 0.0 and t_cur < T_max and _t_next > T_min:
                s_ratio = stochasticity_ratio
            else:
                s_ratio = 0.0
            x_next = self.transport.from_x_t_to_x_r(
                x_cur, t_cur, t_next, F_pred, s_ratio
            )
            samples.append(x_next)
            x_cur = x_next
            if step_callback is not None:
                step_callback(i)
        return torch.stack(samples, dim=0).to(torch.float32)