import math
import time
from functools import partial

import torch
from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor
from torch.nn import Module
import torch.nn.functional as F

import torchode as to
from torchdiffeq import odeint

from beartype import beartype
from beartype.typing import Tuple, Optional, List, Union

from einops.layers.torch import Rearrange
from einops import rearrange, repeat, reduce, pack, unpack

from modules.audio2motion.cfm.utils import *
from modules.audio2motion.cfm.icl_transformer import InContextTransformerAudio2Motion

# wrapper for the CNF

def is_probably_audio_from_shape(t):
    return exists(t) and (t.ndim == 2 or (t.ndim == 3 and t.shape[1] == 1))

class ConditionalFlowMatcherWrapper(Module):
    def __init__(
        self,
        icl_transformer_model: InContextTransformerAudio2Motion = None,
        sigma = 0.,
        ode_atol = 1e-5,
        ode_rtol = 1e-5,
        # ode_step_size = 0.0625,
        use_torchode = False,
        torchdiffeq_ode_method = 'midpoint',  # use midpoint for torchdiffeq, as in the paper
        torchode_method_klass = to.Tsit5,     # use tsit5 for torchode, as torchode does not have midpoint (recommended by Bryan @b-chiang)
        cond_drop_prob = 0.
    ):
        super().__init__()
        self.sigma = sigma

        if icl_transformer_model is None:
            icl_transformer_model = InContextTransformerAudio2Motion()
        self.icl_transformer_model = icl_transformer_model

        self.cond_drop_prob = cond_drop_prob

        self.use_torchode = use_torchode
        self.torchode_method_klass = torchode_method_klass

        self.odeint_kwargs = dict(
            atol = ode_atol,
            rtol = ode_rtol,
            method = torchdiffeq_ode_method,
            # options = dict(step_size = ode_step_size)
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def sample(
        self,
        *,
        cond_audio = None,     # [B, T, C]; T may be 2x the motion length and is interpolated to x1's length
        cond = None,           # reference landmarks (optional); zeros are used when not given
        cond_mask = None,
        steps = 3,             # number of flow steps; 3 and 10 both take ~0.56 s
        cond_scale = 1.,
        ret = None,
        self_attn_mask = None,
        temperature = 1.0,
    ):
        if ret is None:
            ret = {}

        cond_target_length = cond_audio.shape[1] // 2

        if exists(cond):
            cond = curtail_or_pad(cond, cond_target_length)
        else:
            # fall back to an all-zero reference; assumes the transformer exposes its conditioning width as `dim_cond_emb`
            cond = torch.zeros((cond_audio.shape[0], cond_target_length, self.icl_transformer_model.dim_cond_emb), device = self.device)

        shape = cond.shape
        batch = shape[0]

        # neural ode
        self.icl_transformer_model.eval()

        def fn(t, x, *, packed_shape = None):
            if exists(packed_shape):
                x = unpack_one(x, packed_shape, 'b *')

            out = self.icl_transformer_model.forward_with_cond_scale(
                x,                       # current position x_t along the flow (starts as gaussian noise)
                times = t,               # flow timestep
                cond_audio = cond_audio,
                cond = cond,             # reference landmarks (or zeros)
                cond_scale = cond_scale,
                cond_mask = cond_mask,
                self_attn_mask = self_attn_mask,
                ret = ret,
            )

            if exists(packed_shape):
                out = rearrange(out, 'b ... -> b (...)')

            return out
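
        # `fn` is the right-hand side of the ODE dx/dt = v_theta(x, t | cond): it queries the
        # transformer (with classifier-free guidance controlled by `cond_scale`) for the velocity
        # at the current position. The solver below integrates it from t = 0 (gaussian noise y0)
        # to t = 1 (the generated motion).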
        y0 = torch.randn_like(cond) * float(temperature)
        t = torch.linspace(0, 1, steps, device = self.device)

        timestamp_before_sampling = time.time()

        if not self.use_torchode:
            print(f'sampling based on torchdiffeq with flow total_steps={steps}')
            # start from y0; fn supplies the velocity at the current position, and the solver integrates it along t
            trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
            sampled = trajectory[-1]
        else:
            print(f'sampling based on torchode with flow total_steps={steps}')
            t = repeat(t, 'n -> b n', b = batch)
            y0, packed_shape = pack_one(y0, 'b *')

            fn = partial(fn, packed_shape = packed_shape)

            term = to.ODETerm(fn)
            step_method = self.torchode_method_klass(term = term)

            step_size_controller = to.IntegralController(
                atol = self.odeint_kwargs['atol'],
                rtol = self.odeint_kwargs['rtol'],
                term = term
            )

            solver = to.AutoDiffAdjoint(step_method, step_size_controller)
            jit_solver = torch.compile(solver)

            init_value = to.InitialValueProblem(y0 = y0, t_eval = t)
            sol = jit_solver.solve(init_value)

            sampled = sol.ys[:, -1]
            sampled = unpack_one(sampled, packed_shape, 'b *')

        print(f"Flow matching sampling process elapsed in {time.time() - timestamp_before_sampling:.4f} seconds")
        return sampled

    def forward(
        self,
        x1,                  # ground-truth sample (landmarks), [B, T, C]
        *,
        mask = None,         # mask over frames in the batch
        cond_audio = None,   # [B, T, C]; T may be 2x the motion length and is interpolated to x1's length
        cond = None,         # reference landmarks
        cond_mask = None,    # mask over the reference landmarks: reference frames are False, frames to be predicted are True
        ret = None,
    ):
        """
        Training step of the Continuous Normalizing Flow,
        following eq. (5) and (6) in https://arxiv.org/pdf/2306.15687.pdf
        """
        if ret is None:
            ret = {}

        batch, seq_len, dtype, sigma_ = *x1.shape[:2], x1.dtype, self.sigma

        # main conditional flow logic is below

        # x0 is gaussian noise
        x0 = torch.randn_like(x1)

        # batch-wise random times in [0, 1)
        times = torch.rand((batch,), dtype = dtype, device = self.device)
        t = rearrange(times, 'b -> b 1 1')

        # sample xt along x0 => xt => x1 (Sec 3.1 in the paper)
        # the conditional vector field is u_t(x | x1) = (x1 - (1 - sigma_min) * x) / (1 - (1 - sigma_min) * t)
        # and the conditional flow is     phi_t(x | x1) = (1 - (1 - sigma_min) * t) * x + t * x1
        current_position_in_flows = (1 - (1 - sigma_) * t) * x0 + t * x1   # transformer input: the noised sample, i.e. the conditional flow phi_t(x | x1)
        optimal_path = x1 - (1 - sigma_) * x0                              # regression target: the conditional vector field u_t(x | x1)
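
        # sanity check of the two formulas above: at t = 0 the flow gives x_t = x0 (pure noise),
        # at t = 1 it gives x_t = sigma * x0 + x1, which is approximately x1 for small sigma, and its
        # time derivative d/dt phi_t(x0 | x1) = x1 - (1 - sigma) * x0 is exactly the `optimal_path` target.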

        # predict
        self.icl_transformer_model.train()

        # the output of the transformer is the learnable vector field v_t(x; theta)
        loss = self.icl_transformer_model(
            current_position_in_flows,   # noised motion sample
            cond = cond,
            cond_mask = cond_mask,
            times = times,
            target = optimal_path,
            self_attn_mask = mask,
            cond_audio = cond_audio,
            cond_drop_prob = self.cond_drop_prob,
            ret = ret,
        )

        pred_x1_minus_x0 = ret['pred']   # predicted vector field, i.e. x1 - (1 - sigma) * x0
        pred_x1 = pred_x1_minus_x0 + (1 - sigma_) * x0
        ret['pred'] = pred_x1
        return loss
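

# Minimal training-step sketch (illustrative only): the optimizer and tensor shapes are up to the
# caller and are assumptions for demonstration, not part of this module. The wrapper's forward()
# already returns the flow-matching regression loss, so a step is just backward() plus an
# optimizer update.
def _example_training_step(model: ConditionalFlowMatcherWrapper, optimizer, x1, cond, cond_audio):
    optimizer.zero_grad()
    loss = model(x1, cond = cond, cond_audio = cond_audio)   # flow-matching loss from forward()
    loss.backward()
    optimizer.step()
    return loss.item()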


if __name__ == '__main__':
    icl_transformer = InContextTransformerAudio2Motion()
    model = ConditionalFlowMatcherWrapper(icl_transformer)

    x = torch.randn([2, 125, 64])
    cond = torch.randn([2, 125, 64])
    cond_audio = torch.randn([2, 250, 1024])

    y = model(x, cond=cond, cond_audio=cond_audio)
    y = model.sample(cond=cond, cond_audio=cond_audio)
    print(y.shape)
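    # expected: torch.Size([2, 125, 64]), since sample() returns a motion sequence shaped like `cond`
    # (cond_audio has 250 frames, giving a target length of 250 // 2 = 125)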