# Modified from https://github.com/Fantasy-AMAP/fantasy-talking/blob/main/diffsynth/models
# Copyright Alibaba Inc. All Rights Reserved.
import math
import os
from typing import Any, Dict

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import register_to_config
from diffusers.utils import is_torch_version

from ..dist import sequence_parallel_all_gather, sequence_parallel_chunk
from ..utils import cfg_skip
from .attention_utils import attention
from .wan_transformer3d import (WanAttentionBlock, WanLayerNorm, WanRMSNorm,
                                WanSelfAttention, WanTransformer3DModel,
                                sinusoidal_embedding_1d)


class AudioProjModel(nn.Module):
    def __init__(self, audio_in_dim=1024, cross_attention_dim=1024):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.proj = torch.nn.Linear(audio_in_dim, cross_attention_dim, bias=False)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, audio_embeds):
        context_tokens = self.proj(audio_embeds)
        context_tokens = self.norm(context_tokens)
        return context_tokens  # [B, L, C]
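

# Note: AudioProjModel maps an audio feature sequence [B, L, audio_in_dim] (Wav2Vec
# hidden states, given the `audio_wav2vec_fea` input of the model's forward pass)
# into the 2048-dim context space consumed by the zero-initialized K/V projections
# of AudioCrossAttentionProcessor below; FantasyTalkingTransformer3DModel
# instantiates it as AudioProjModel(audio_in_dim, 2048).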
class AudioCrossAttentionProcessor(nn.Module):
    def __init__(self, context_dim, hidden_dim):
        super().__init__()
        self.context_dim = context_dim
        self.hidden_dim = hidden_dim
        self.k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
        self.v_proj = nn.Linear(context_dim, hidden_dim, bias=False)
        nn.init.zeros_(self.k_proj.weight)
        nn.init.zeros_(self.v_proj.weight)
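
        # k_proj / v_proj start at zero, so the audio branch is a no-op at
        # initialization and the pretrained backbone behaviour is preserved until
        # these adapter-style weights are trained. The sequence-parallel fields
        # below are overwritten by enable_multi_gpus_inference() when needed.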
        self.sp_world_size = 1
        self.sp_world_rank = 0
        self.all_gather = None

    def __call__(
        self,
        attn: nn.Module,
        x: torch.Tensor,
        context: torch.Tensor,
        context_lens: torch.Tensor,
        audio_proj: torch.Tensor,
        audio_context_lens: torch.Tensor,
        latents_num_frames: int = 21,
        audio_scale: float = 1.0,
    ) -> torch.Tensor:
        """
        x:                  [B, L1, C].
        context:            [B, L2, C].
        context_lens:       [B].
        audio_proj:         [B, 21, L3, C].
        audio_context_lens: [B * 21].
        """
        context_img = context[:, :257]
        context = context[:, 257:]
        b, n, d = x.size(0), attn.num_heads, attn.head_dim

        # Compute query, key, value
        q = attn.norm_q(attn.q(x)).view(b, -1, n, d)
        k = attn.norm_k(attn.k(context)).view(b, -1, n, d)
        v = attn.v(context).view(b, -1, n, d)
        k_img = attn.norm_k_img(attn.k_img(context_img)).view(b, -1, n, d)
        v_img = attn.v_img(context_img).view(b, -1, n, d)
        img_x = attention(q, k_img, v_img, k_lens=None)

        # Compute attention
        x = attention(q, k, v, k_lens=context_lens)
        x = x.flatten(2)
        img_x = img_x.flatten(2)
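
        # audio_proj arrives in one of two layouts: a 4-D tensor
        # [B, latents_num_frames, L3, C] of per-latent-frame audio windows (built by
        # split_tensor_with_padding), in which case the query tokens are regrouped so
        # that each latent frame attends only to its own audio window, or a 3-D tensor
        # [B, L3, C], in which case all query tokens attend to the same audio sequence.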
        if len(audio_proj.shape) == 4:
            if self.sp_world_size > 1:
                q = self.all_gather(q, dim=1)
            length = int(np.floor(q.size()[1] / latents_num_frames) * latents_num_frames)
            origin_length = q.size()[1]
            if origin_length > length:
                q_pad = q[:, length:]
                q = q[:, :length]
            audio_q = q.view(b * latents_num_frames, -1, n, d)  # [B * latents_num_frames, L1 // latents_num_frames, N, D]
            ip_key = self.k_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
            ip_value = self.v_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
            audio_x = attention(
                audio_q, ip_key, ip_value, k_lens=audio_context_lens, attention_type="NORMAL"
            )
            audio_x = audio_x.view(b, q.size(1), n, d)
            if self.sp_world_size > 1:
                if origin_length > length:
                    audio_x = torch.cat([audio_x, q_pad], dim=1)
                audio_x = torch.chunk(audio_x, self.sp_world_size, dim=1)[self.sp_world_rank]
            audio_x = audio_x.flatten(2)
        elif len(audio_proj.shape) == 3:
            ip_key = self.k_proj(audio_proj).view(b, -1, n, d)
            ip_value = self.v_proj(audio_proj).view(b, -1, n, d)
            audio_x = attention(q, ip_key, ip_value, k_lens=audio_context_lens, attention_type="NORMAL")
            audio_x = audio_x.flatten(2)

        # Output
        if isinstance(audio_scale, torch.Tensor):
            audio_scale = audio_scale[:, None, None]
        x = x + img_x + audio_x * audio_scale
        x = attn.o(x)
        return x


class AudioCrossAttention(WanSelfAttention):
    def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
        super().__init__(dim, num_heads, window_size, qk_norm, eps)

        self.k_img = nn.Linear(dim, dim)
        self.v_img = nn.Linear(dim, dim)
        self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.processor = AudioCrossAttentionProcessor(2048, dim)

    def forward(
        self,
        x,
        context,
        context_lens,
        audio_proj,
        audio_context_lens,
        latents_num_frames,
        audio_scale: float = 1.0,
        **kwargs,
    ):
        """
        x:            [B, L1, C].
        context:      [B, L2, C].
        context_lens: [B].
        """
        if audio_proj is None:
            return self.processor(self, x, context, context_lens)
        else:
            return self.processor(
                self,
                x,
                context,
                context_lens,
                audio_proj,
                audio_context_lens,
                latents_num_frames,
                audio_scale,
            )


class AudioAttentionBlock(nn.Module):
    def __init__(
        self,
        cross_attn_type,  # Unused; kept for interface compatibility with WanAttentionBlock
        dim,
        ffn_dim,
        num_heads,
        window_size=(-1, -1),
        qk_norm=True,
        cross_attn_norm=False,
        eps=1e-6,
    ):
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # Layers
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
        self.norm3 = (
            WanLayerNorm(dim, eps, elementwise_affine=True)
            if cross_attn_norm
            else nn.Identity()
        )
        self.cross_attn = AudioCrossAttention(
            dim, num_heads, (-1, -1), qk_norm, eps
        )
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(ffn_dim, dim),
        )

        # Modulation
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
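
        # The six modulation vectors are used in forward() as:
        #   e[0] / e[1] -> shift / scale before self-attention, e[2] -> its gate,
        #   e[3] / e[4] -> shift / scale before the FFN,        e[5] -> its gate.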

    def forward(
        self,
        x,
        e,
        seq_lens,
        grid_sizes,
        freqs,
        context,
        context_lens,
        audio_proj=None,
        audio_context_lens=None,
        audio_scale=1,
        dtype=torch.bfloat16,
        t=0,
    ):
        assert e.dtype == torch.float32
        with amp.autocast(dtype=torch.float32):
            e = (self.modulation.to(dtype=e.dtype, device=e.device) + e).chunk(6, dim=1)
        assert e[0].dtype == torch.float32

        # Self-attention
        y = self.self_attn(
            self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs, dtype, t=t
        )
        with amp.autocast(dtype=torch.float32):
            x = x + y * e[2]

        # Cross-attention & FFN
        def cross_attn_ffn(x, context, context_lens, e):
            x = x + self.cross_attn(
                self.norm3(x), context, context_lens, dtype=dtype, t=t,
                audio_proj=audio_proj, audio_context_lens=audio_context_lens, audio_scale=audio_scale,
                latents_num_frames=grid_sizes[0][0],
            )
            y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
            with amp.autocast(dtype=torch.float32):
                x = x + y * e[5]
            return x

        x = cross_attn_ffn(x, context, context_lens, e)
        return x


class FantasyTalkingTransformer3DModel(WanTransformer3DModel):
    def __init__(self,
                 model_type='i2v',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6,
                 cross_attn_type=None,
                 audio_in_dim=768):
        super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
                         num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)

        if cross_attn_type is None:
            cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
        self.blocks = nn.ModuleList([
            AudioAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
                                window_size, qk_norm, cross_attn_norm, eps)
            for _ in range(num_layers)
        ])
        for layer_idx, block in enumerate(self.blocks):
            block.self_attn.layer_idx = layer_idx
            block.self_attn.num_layers = self.num_layers

        self.proj_model = AudioProjModel(audio_in_dim, 2048)
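
        # The parent constructor already built standard attention blocks; they are
        # replaced above with audio-aware AudioAttentionBlocks, which add the
        # zero-initialized k_proj/v_proj audio projections on top of the usual
        # self-/cross-attention weights, while proj_model bridges audio_in_dim
        # features to the 2048-dim context expected by AudioCrossAttentionProcessor.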

    def split_audio_sequence(self, audio_proj_length, num_frames=81):
        """
        Map the audio feature sequence to corresponding latent frame slices.

        Args:
            audio_proj_length (int): The total length of the audio feature sequence
                (e.g., 173 in audio_proj[1, 173, 768]).
            num_frames (int): The number of video frames in the training data (default: 81).

        Returns:
            list: A list of [start_idx, end_idx] pairs. Each pair represents the index range
                (within the audio feature sequence) corresponding to a latent frame.
        """
        # Average number of tokens per original video frame
        tokens_per_frame = audio_proj_length / num_frames

        # Each latent frame covers 4 video frames, and we want the center
        tokens_per_latent_frame = tokens_per_frame * 4
        half_tokens = int(tokens_per_latent_frame / 2)

        pos_indices = []
        for i in range(int((num_frames - 1) / 4) + 1):
            if i == 0:
                pos_indices.append(0)
            else:
                start_token = tokens_per_frame * ((i - 1) * 4 + 1)
                end_token = tokens_per_frame * (i * 4 + 1)
                center_token = int((start_token + end_token) / 2) - 1
                pos_indices.append(center_token)

        # Build index ranges centered around each position
        pos_idx_ranges = [[idx - half_tokens, idx + half_tokens] for idx in pos_indices]

        # Adjust the first range to avoid negative start index
        pos_idx_ranges[0] = [
            -(half_tokens * 2 - pos_idx_ranges[1][0]),
            pos_idx_ranges[1][0],
        ]

        return pos_idx_ranges
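
    # Rough worked example for split_audio_sequence (matching the example in the
    # docstring below): with audio_proj_length=173 and num_frames=81 there are 21
    # latent frames, tokens_per_frame is about 2.14 and half_tokens = 4, so the
    # returned ranges begin [[-7, 1], [1, 9], [9, 17], ...], i.e. roughly one
    # 8-token window of audio features per latent frame.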

    def split_tensor_with_padding(self, input_tensor, pos_idx_ranges, expand_length=0):
        """
        Split the input tensor into subsequences based on index ranges, and apply right-side
        zero-padding if a range exceeds the input boundaries.

        Args:
            input_tensor (Tensor): Input audio tensor of shape [1, L, 768].
            pos_idx_ranges (list): A list of index ranges, e.g. [[-7, 1], [1, 9], ..., [165, 173]].
            expand_length (int): Number of tokens to expand on both sides of each subsequence.

        Returns:
            sub_sequences (Tensor): A tensor of shape [1, F, L, 768], where L is the length after
                padding. Each element is a padded subsequence.
            k_lens (Tensor): A tensor of shape [F], representing the actual (unpadded) length of
                each subsequence. Useful for ignoring padding tokens in attention masks.
        """
        pos_idx_ranges = [
            [idx[0] - expand_length, idx[1] + expand_length] for idx in pos_idx_ranges
        ]
        sub_sequences = []
        seq_len = input_tensor.size(1)  # e.g. 173
        max_valid_idx = seq_len - 1     # e.g. 172
        k_lens_list = []
        for start, end in pos_idx_ranges:
            # Compute how much padding is needed on each side
            pad_front = max(-start, 0)
            pad_back = max(end - max_valid_idx, 0)

            # Start and end indices of the valid (in-bounds) part
            valid_start = max(start, 0)
            valid_end = min(end, max_valid_idx)

            # Extract the valid part
            if valid_start <= valid_end:
                valid_part = input_tensor[:, valid_start : valid_end + 1, :]
            else:
                valid_part = input_tensor.new_zeros((1, 0, input_tensor.size(2)))

            # Pad along the sequence dimension (dim=1)
            padded_subseq = F.pad(
                valid_part,
                (0, 0, 0, pad_back + pad_front, 0, 0),
                mode="constant",
                value=0,
            )
            k_lens_list.append(padded_subseq.size(-2) - pad_back - pad_front)

            sub_sequences.append(padded_subseq)
        return torch.stack(sub_sequences, dim=1), torch.tensor(
            k_lens_list, dtype=torch.long
        )
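
    # In forward() these windows are built with expand_length=4, so each range is
    # widened by 4 tokens on both sides; out-of-range positions are zero-padded and
    # the returned k_lens travel through the blocks as audio_context_lens, ending up
    # as the k_lens argument of attention() so the padding can be ignored.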

    def enable_multi_gpus_inference(self):
        super().enable_multi_gpus_inference()
        for name, module in self.named_modules():
            if module.__class__.__name__ == 'AudioCrossAttentionProcessor':
                module.sp_world_size = self.sp_world_size
                module.sp_world_rank = self.sp_world_rank
                module.all_gather = self.all_gather

    def forward(
        self,
        x,
        t,
        context,
        seq_len,
        audio_wav2vec_fea=None,
        clip_fea=None,
        y=None,
        audio_scale=1,
        cond_flag=True,
    ):
        r"""
        Forward pass through the diffusion model.

        Args:
            x (List[Tensor]):
                List of input video tensors, each with shape [C_in, F, H, W]
            t (Tensor):
                Diffusion timesteps tensor of shape [B]
            context (List[Tensor]):
                List of text embeddings, each with shape [L, C]
            seq_len (`int`):
                Maximum sequence length for positional encoding
            audio_wav2vec_fea (Tensor, *optional*):
                Wav2Vec audio feature sequence of shape [B, L_audio, audio_in_dim],
                injected into every block through audio cross-attention
            clip_fea (Tensor, *optional*):
                CLIP image features for image-to-video mode
            y (List[Tensor], *optional*):
                Conditional video inputs for image-to-video mode, same shape as x

        Returns:
            Tensor:
                Denoised video tensor of shape [B, C_out, F, H / 8, W / 8]
        """
        # Wan2.2 does not require CLIP image features.
        # if self.model_type == 'i2v':
        #     assert clip_fea is not None and y is not None

        # params
        device = self.patch_embedding.weight.device
        dtype = x.dtype
        if self.freqs.device != device and torch.device(type="meta") != device:
            self.freqs = self.freqs.to(device)

        if y is not None:
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        if self.sp_world_size > 1:
            seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
        assert seq_lens.max() <= seq_len
        x = torch.cat([
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
                      dim=1) for u in x
        ])

        # time embeddings
        with amp.autocast(dtype=torch.float32):
            if t.dim() != 1:
                if t.size(1) < seq_len:
                    pad_size = seq_len - t.size(1)
                    last_elements = t[:, -1].unsqueeze(1)
                    padding = last_elements.repeat(1, pad_size)
                    t = torch.cat([t, padding], dim=1)
                bt = t.size(0)
                ft = t.flatten()
                e = self.time_embedding(
                    sinusoidal_embedding_1d(self.freq_dim,
                                            ft).unflatten(0, (bt, seq_len)).float())
                e0 = self.time_projection(e).unflatten(2, (6, self.dim))
            else:
                e = self.time_embedding(
                    sinusoidal_embedding_1d(self.freq_dim, t).float())
                e0 = self.time_projection(e).unflatten(1, (6, self.dim))

        # assert e.dtype == torch.float32 and e0.dtype == torch.float32
        # e0 = e0.to(dtype)
        # e = e.to(dtype)

        # context
        context_lens = None
        context = self.text_embedding(
            torch.stack([
                torch.cat(
                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ]))

        if clip_fea is not None:
            context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
            context = torch.concat([context_clip, context], dim=1)

        num_frames = (grid_sizes[0][0] - 1) * 4 + 1
        audio_proj_fea = self.proj_model(audio_wav2vec_fea)
        pos_idx_ranges = self.split_audio_sequence(audio_proj_fea.size(1), num_frames=num_frames)
        audio_proj, audio_context_lens = self.split_tensor_with_padding(
            audio_proj_fea, pos_idx_ranges, expand_length=4
        )

        # Context Parallel
        if self.sp_world_size > 1:
            x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
            if t.dim() != 1:
                e0 = torch.chunk(e0, self.sp_world_size, dim=1)[self.sp_world_rank]
                e = torch.chunk(e, self.sp_world_size, dim=1)[self.sp_world_rank]

        # TeaCache
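        # TeaCache tracks a rescaled relative-L1 distance between the current and the
        # previous modulated timestep embedding; while the accumulated distance stays
        # below rel_l1_thresh, the transformer blocks are skipped and the residual
        # cached from the last full pass is re-applied instead (second branch below).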
        if self.teacache is not None:
            if cond_flag:
                if t.dim() != 1:
                    modulated_inp = e0[:, -1, :]
                else:
                    modulated_inp = e0
                skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
                if skip_flag:
                    self.should_calc = True
                    self.teacache.accumulated_rel_l1_distance = 0
                else:
                    if cond_flag:
                        rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
                        self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
                    if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
                        self.should_calc = False
                    else:
                        self.should_calc = True
                        self.teacache.accumulated_rel_l1_distance = 0
                self.teacache.previous_modulated_input = modulated_inp
                self.teacache.should_calc = self.should_calc
            else:
                self.should_calc = self.teacache.should_calc

        # TeaCache
        if self.teacache is not None:
            if not self.should_calc:
                previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
                x = x + previous_residual.to(x.device)[-x.size()[0]:,]
            else:
                ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
                for block in self.blocks:
                    if torch.is_grad_enabled() and self.gradient_checkpointing:
                        def create_custom_forward(module):
                            def custom_forward(*inputs):
                                return module(*inputs)
                            return custom_forward

                        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                        x = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block),
                            x,
                            e0,
                            seq_lens,
                            grid_sizes,
                            self.freqs,
                            context,
                            context_lens,
                            audio_proj,
                            audio_context_lens,
                            audio_scale,
                            dtype,
                            t,
                            **ckpt_kwargs,
                        )
                    else:
                        # arguments
                        kwargs = dict(
                            e=e0,
                            seq_lens=seq_lens,
                            grid_sizes=grid_sizes,
                            freqs=self.freqs,
                            context=context,
                            context_lens=context_lens,
                            audio_proj=audio_proj,
                            audio_context_lens=audio_context_lens,
                            audio_scale=audio_scale,
                            dtype=dtype,
                            t=t,
                        )
                        x = block(x, **kwargs)

                if cond_flag:
                    self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
                else:
                    self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
        else:
            for block in self.blocks:
                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)
                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x,
                        e0,
                        seq_lens,
                        grid_sizes,
                        self.freqs,
                        context,
                        context_lens,
                        audio_proj,
                        audio_context_lens,
                        audio_scale,
                        dtype,
                        t,
                        **ckpt_kwargs,
                    )
                else:
                    # arguments
                    kwargs = dict(
                        e=e0,
                        seq_lens=seq_lens,
                        grid_sizes=grid_sizes,
                        freqs=self.freqs,
                        context=context,
                        context_lens=context_lens,
                        audio_proj=audio_proj,
                        audio_context_lens=audio_context_lens,
                        audio_scale=audio_scale,
                        dtype=dtype,
                        t=t,
                    )
                    x = block(x, **kwargs)

        # head
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)
                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs)
        else:
            x = self.head(x, e)

        if self.sp_world_size > 1:
            x = self.all_gather(x, dim=1)

        # Unpatchify
        x = self.unpatchify(x, grid_sizes)
        x = torch.stack(x)

        if self.teacache is not None and cond_flag:
            self.teacache.cnt += 1
            if self.teacache.cnt == self.teacache.num_steps:
                self.teacache.reset()
        return x