import numpy as np import torch import torch.nn as nn from .tools.t5 import T5EncoderModel from .tools.wan_model import WanModel class DiffForcingWanModel(nn.Module): def __init__( self, checkpoint_path="deps/t5_umt5-xxl-enc-bf16/models_t5_umt5-xxl-enc-bf16.pth", tokenizer_path="deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl", input_dim=256, hidden_dim=1024, ffn_dim=2048, freq_dim=256, num_heads=8, num_layers=8, time_embedding_scale=1.0, chunk_size=5, noise_steps=10, use_text_cond=True, text_len=128, drop_out=0.1, cfg_scale=5.0, prediction_type="vel", # "vel", "x0", "noise" causal=False, ): super().__init__() self.input_dim = input_dim self.hidden_dim = hidden_dim self.ffn_dim = ffn_dim self.freq_dim = freq_dim self.num_heads = num_heads self.num_layers = num_layers self.time_embedding_scale = time_embedding_scale self.chunk_size = chunk_size self.noise_steps = noise_steps self.use_text_cond = use_text_cond self.drop_out = drop_out self.cfg_scale = cfg_scale self.prediction_type = prediction_type self.causal = causal self.text_dim = 4096 self.text_len = text_len self.text_encoder = T5EncoderModel( text_len=self.text_len, dtype=torch.bfloat16, device=torch.device("cpu"), checkpoint_path=checkpoint_path, tokenizer_path=tokenizer_path, shard_fn=None, ) # Text encoding cache self.text_cache = {} self.model = WanModel( model_type="t2v", patch_size=(1, 1, 1), text_len=self.text_len, in_dim=self.input_dim, dim=self.hidden_dim, ffn_dim=self.ffn_dim, freq_dim=self.freq_dim, text_dim=self.text_dim, out_dim=self.input_dim, num_heads=self.num_heads, num_layers=self.num_layers, window_size=(-1, -1), qk_norm=True, cross_attn_norm=True, eps=1e-6, causal=self.causal, ) self.param_dtype = torch.float32 def encode_text_with_cache(self, text_list, device): """Encode text using cache Args: text_list: List[str], list of texts device: torch.device Returns: List[Tensor]: List of encoded text features """ text_features = [] indices_to_encode = [] texts_to_encode = [] # Check cache for i, text in enumerate(text_list): if text in self.text_cache: # Get from cache and move to correct device cached_feature = self.text_cache[text].to(device) text_features.append(cached_feature) else: # Need to encode text_features.append(None) indices_to_encode.append(i) texts_to_encode.append(text) # Batch encode uncached texts if texts_to_encode: self.text_encoder.model.to(device) encoded = self.text_encoder(texts_to_encode, device) # Store in cache and update results for idx, text, feature in zip(indices_to_encode, texts_to_encode, encoded): # Cache to CPU to save GPU memory self.text_cache[text] = feature.cpu() text_features[idx] = feature return text_features def preprocess(self, x): # (bs, T, C) -> (bs, C, T, 1, 1) x = x.permute(0, 2, 1)[:, :, :, None, None] return x def postprocess(self, x): # (bs, C, T, 1, 1) -> (bs, T, C) x = x.permute(0, 2, 1, 3, 4).contiguous().view(x.size(0), x.size(2), -1) return x def _get_noise_levels(self, device, seq_len, time_steps): """Get noise levels""" # noise_level[i] = clip(1 + i / chunk_size - time_steps, 0, 1) noise_level = torch.clamp( 1 + torch.arange(seq_len, device=device) / self.chunk_size - time_steps.unsqueeze(1), min=0.0, max=1.0, ) return noise_level def add_noise(self, x, noise_level): """Add noise Args: x: (B, T, D) noise_level: (B, T) """ noise = torch.randn_like(x) # noise_level: (B, T) -> (B, T, 1) noise_level = noise_level.unsqueeze(-1) noisy_x = x * (1 - noise_level) + noise_level * noise return noisy_x, noise def forward(self, x): feature = x["feature"] # (B, T, C) feature_length = x["feature_length"] # (B,) batch_size, seq_len, _ = feature.shape device = feature.device # Randomly use a time step time_steps = [] for i in range(batch_size): valid_len = feature_length[i].item() # Random float from 0 to valid_len/chunk_size, not an integer max_time = valid_len / self.chunk_size # max_time = valid_len / self.chunk_size + 1 time_steps.append(torch.FloatTensor(1).uniform_(0, max_time).item()) time_steps = torch.tensor(time_steps, device=device) # (B,) noise_level = self._get_noise_levels(device, seq_len, time_steps) # (B, T) # # Debug: Print noise levels # print("Time steps and corresponding noise levels:") # for i in range(batch_size): # t = time_steps[i].item() # # Get noise level at each position # start_idx = int(self.chunk_size * (t - 1)) # end_idx = int(self.chunk_size * t) + 2 # # Limit to valid range # start_idx = max(0, start_idx) # end_idx = min(seq_len, end_idx) # print(time_steps[i]) # print(noise_level[i, start_idx:end_idx]) # Add noise to entire sequence noisy_feature, noise = self.add_noise(feature, noise_level) # (B, T, D) # Debug: Print noise addition information # print("Added noise levels at chunk positions:") # for i in range(batch_size): # t = time_steps[i].item() # start_idx = int(self.chunk_size * (t - 1)) # end_idx = int(self.chunk_size * t) + 2 # # Limit to valid range # start_idx = max(0, start_idx) # end_idx = min(seq_len, end_idx) # test1 = ( # feature[i, start_idx:end_idx, :] - noisy_feature[i, start_idx:end_idx, :] # ) # test2 = ( # noise[i, start_idx:end_idx, :] - noisy_feature[i, start_idx:end_idx, :] # ) # # Compute length on last dimension # print(test1.norm(dim=-1)) # print(test2.norm(dim=-1)) feature = self.preprocess(feature) # (B, C, T, 1, 1) noisy_feature = self.preprocess(noisy_feature) # (B, C, T, 1, 1) noise = self.preprocess(noise) # (B, C, T, 1, 1) feature_ref = [] noise_ref = [] noisy_feature_input = [] for i in range(batch_size): t = time_steps[i].item() end_index = int(self.chunk_size * t) + 1 valid_len = feature_length[i].item() end_index = min(valid_len, end_index) feature_ref.append(feature[i, :, :end_index, ...]) noise_ref.append(noise[i, :, :end_index, ...]) noisy_feature_input.append(noisy_feature[i, :, :end_index, ...]) # Encode text condition (using cache) if self.use_text_cond and "text" in x: text_list = x["text"] # List[str] or List[List[str]] if isinstance(text_list[0], list): text_end_list = x["feature_text_end"] all_text_context = [] for single_text_list, single_text_end_list in zip( text_list, text_end_list ): if np.random.rand() > self.drop_out: single_text_list = [""] single_text_end_list = [0, seq_len] else: single_text_end_list = [0] + [ min(t, seq_len) for t in single_text_end_list ] single_text_length_list = [ t - b for t, b in zip( single_text_end_list[1:], single_text_end_list[:-1] ) ] single_text_context = self.encode_text_with_cache( single_text_list, device ) single_text_context = [ u.to(self.param_dtype) for u in single_text_context ] for u, duration in zip( single_text_context, single_text_length_list ): all_text_context.extend([u for _ in range(duration)]) all_text_context.extend( [ single_text_context[-1] for _ in range(seq_len - single_text_end_list[-1]) ] ) else: all_text_context = [ (u if np.random.rand() > self.drop_out else "") for u in text_list ] all_text_context = self.encode_text_with_cache(all_text_context, device) all_text_context = [u.to(self.param_dtype) for u in all_text_context] else: all_text_context = [""] * batch_size all_text_context = self.encode_text_with_cache(all_text_context, device) all_text_context = [u.to(self.param_dtype) for u in all_text_context] # Through WanModel predicted_result = self.model( noisy_feature_input, noise_level * self.time_embedding_scale, all_text_context, seq_len, y=None, ) # (B, C, T, 1, 1) loss = 0.0 for b in range(batch_size): if self.prediction_type == "vel": vel = feature_ref[b] - noise_ref[b] # (C, input_length, 1, 1) squared_error = ( predicted_result[b][:, -self.chunk_size :, ...] - vel[:, -self.chunk_size :, ...] ) ** 2 elif self.prediction_type == "x0": squared_error = ( predicted_result[b][:, -self.chunk_size :, ...] - feature_ref[b][:, -self.chunk_size :, ...] ) ** 2 elif self.prediction_type == "noise": squared_error = ( predicted_result[b][:, -self.chunk_size :, ...] - noise_ref[b][:, -self.chunk_size :, ...] ) ** 2 sample_loss = squared_error.sum().mean() loss += sample_loss loss = loss / batch_size loss_dict = {"total": loss, "mse": loss} return loss_dict def generate(self, x, num_denoise_steps=None): """ Generation - Diffusion Forcing inference Uses triangular noise schedule, progressively generating from left to right Generation process: 1. Start from t=0, gradually increase t 2. Each t corresponds to a noise schedule: clean on left, noisy on right, gradient in middle 3. After each denoising step, t increases slightly and continues """ feature_length = x["feature_length"] batch_size = len(feature_length) seq_len = max(feature_length).item() # # debug # x["text"] = [["walk forward.", "sit down.", "stand up."] for _ in range(batch_size)] # x["feature_text_end"] = [[1, 2, 3] for _ in range(batch_size)] # text = x["text"] # text_end = x["feature_text_end"] # print(text) # print(text_end) # print(batch_size, seq_len, self.chunk_size) if num_denoise_steps is None: num_denoise_steps = self.noise_steps assert num_denoise_steps % self.chunk_size == 0 device = next(self.parameters()).device # Initialize entire sequence as pure noise generated = torch.randn( batch_size, seq_len + self.chunk_size, self.input_dim, device=device ) generated = self.preprocess(generated) # (B, C, T, 1, 1) # Calculate total number of time steps needed max_t = 1 + (seq_len - 1) / self.chunk_size # Step size for each advancement dt = 1 / num_denoise_steps total_steps = int(max_t / dt) # Encode text condition (using cache) if self.use_text_cond and "text" in x: text_list = x["text"] # List[str] or List[List[str]] if isinstance(text_list[0], list): generated_length = [] text_end_list = x["feature_text_end"] full_text = [] all_text_context = [] for single_text_list, single_text_end_list in zip( text_list, text_end_list ): single_text_end_list = [0] + [ min(t, seq_len) for t in single_text_end_list ] generated_length.append(single_text_end_list[-1]) single_text_length_list = [ t - b for t, b in zip( single_text_end_list[1:], single_text_end_list[:-1] ) ] full_text.append( " ////////// ".join( [ f"{u} //dur:{t}" for u, t in zip( single_text_list, single_text_length_list ) ] ) ) single_text_context = self.encode_text_with_cache( single_text_list, device ) single_text_context = [ u.to(self.param_dtype) for u in single_text_context ] for u, duration in zip( single_text_context, single_text_length_list ): all_text_context.extend([u for _ in range(duration)]) all_text_context.extend( [ single_text_context[-1] for _ in range( seq_len + self.chunk_size - single_text_end_list[-1] ) ] ) else: generated_length = feature_length full_text = text_list all_text_context = self.encode_text_with_cache(text_list, device) all_text_context = [u.to(self.param_dtype) for u in all_text_context] else: generated_length = feature_length full_text = [""] * batch_size all_text_context = [""] * batch_size all_text_context = self.encode_text_with_cache(all_text_context, device) all_text_context = [u.to(self.param_dtype) for u in all_text_context] # Get empty text condition encoding (for CFG) text_null_list = [""] * batch_size text_null_context = self.encode_text_with_cache(text_null_list, device) text_null_context = [u.to(self.param_dtype) for u in text_null_context] # print(len(all_text_context), len(text_null_context)) # Progressively advance from t=0 to t=max_t for step in range(total_steps): # Current time step t = step * dt start_index = max(0, int(self.chunk_size * (t - 1)) + 1) end_index = int(self.chunk_size * t) + 1 time_steps = torch.full((batch_size,), t, device=device) # Calculate current noise schedule noise_level = self._get_noise_levels( device, seq_len + self.chunk_size, time_steps ) # (B, T) # Predict noise through WanModel noisy_input = [] for i in range(batch_size): noisy_input.append(generated[i, :, :end_index, ...]) predicted_result = self.model( noisy_input, noise_level * self.time_embedding_scale, all_text_context, seq_len + self.chunk_size, y=None, ) # (B, C, T, 1, 1) # Adjust using CFG if self.cfg_scale != 1.0: predicted_result_null = self.model( noisy_input, noise_level * self.time_embedding_scale, text_null_context, seq_len + self.chunk_size, y=None, ) # (B, C, T, 1, 1) predicted_result = [ self.cfg_scale * pv - (self.cfg_scale - 1) * pvn for pv, pvn in zip(predicted_result, predicted_result_null) ] for i in range(batch_size): predicted_result_i = predicted_result[i] # (C, input_length, 1, 1) if self.prediction_type == "vel": predicted_vel = predicted_result_i[:, start_index:end_index, ...] generated[i, :, start_index:end_index, ...] += predicted_vel * dt elif self.prediction_type == "x0": predicted_vel = ( predicted_result_i[:, start_index:end_index, ...] - generated[i, :, start_index:end_index, ...] ) / ( noise_level[i, start_index:end_index] .unsqueeze(0) .unsqueeze(-1) .unsqueeze(-1) ) generated[i, :, start_index:end_index, ...] += predicted_vel * dt elif self.prediction_type == "noise": predicted_vel = ( generated[i, :, start_index:end_index, ...] - predicted_result_i[:, start_index:end_index, ...] ) / ( 1 + dt - noise_level[i, start_index:end_index] .unsqueeze(0) .unsqueeze(-1) .unsqueeze(-1) ) generated[i, :, start_index:end_index, ...] += predicted_vel * dt generated = self.postprocess(generated) # (B, T, C) y_hat_out = [] for i in range(batch_size): # cut off the padding single_generated = generated[i, : generated_length[i], :] y_hat_out.append(single_generated) out = {} out["generated"] = y_hat_out out["text"] = full_text return out @torch.no_grad() def stream_generate(self, x, num_denoise_steps=None): """ Streaming generation - Diffusion Forcing inference Uses triangular noise schedule, progressively generating from left to right Generation process: 1. Start from t=0, gradually increase t 2. Each t corresponds to a noise schedule: clean on left, noisy on right, gradient in middle 3. After each denoising step, t increases slightly and continues """ feature_length = x["feature_length"] batch_size = len(feature_length) seq_len = max(feature_length).item() # # debug # x["text"] = [["walk forward.", "sit down.", "stand up."] for _ in range(batch_size)] # x["feature_text_end"] = [[1, 2, 3] for _ in range(batch_size)] # text = x["text"] # text_end = x["feature_text_end"] # print(text) # print(text_end) # print(batch_size, seq_len, self.chunk_size) if num_denoise_steps is None: num_denoise_steps = self.noise_steps assert num_denoise_steps % self.chunk_size == 0 device = next(self.parameters()).device # Initialize entire sequence as pure noise generated = torch.randn( batch_size, seq_len + self.chunk_size, self.input_dim, device=device ) generated = self.preprocess(generated) # (B, C, T, 1, 1) # Calculate total number of time steps needed max_t = 1 + (seq_len - 1) / self.chunk_size # Step size for each advancement dt = 1 / num_denoise_steps total_steps = int(max_t / dt) # Encode text condition (using cache) if self.use_text_cond and "text" in x: text_list = x["text"] # List[str] or List[List[str]] if isinstance(text_list[0], list): generated_length = [] text_end_list = x["feature_text_end"] full_text = [] all_text_context = [] for single_text_list, single_text_end_list in zip( text_list, text_end_list ): single_text_end_list = [0] + [ min(t, seq_len) for t in single_text_end_list ] generated_length.append(single_text_end_list[-1]) single_text_length_list = [ t - b for t, b in zip( single_text_end_list[1:], single_text_end_list[:-1] ) ] full_text.append( " ////////// ".join( [ f"{u} //dur:{t}" for u, t in zip( single_text_list, single_text_length_list ) ] ) ) single_text_context = self.encode_text_with_cache( single_text_list, device ) single_text_context = [ u.to(self.param_dtype) for u in single_text_context ] for u, duration in zip( single_text_context, single_text_length_list ): all_text_context.extend([u for _ in range(duration)]) all_text_context.extend( [ single_text_context[-1] for _ in range( seq_len + self.chunk_size - single_text_end_list[-1] ) ] ) else: generated_length = feature_length full_text = text_list all_text_context = self.encode_text_with_cache(text_list, device) all_text_context = [u.to(self.param_dtype) for u in all_text_context] else: generated_length = feature_length full_text = [""] * batch_size all_text_context = [""] * batch_size all_text_context = self.encode_text_with_cache(all_text_context, device) all_text_context = [u.to(self.param_dtype) for u in all_text_context] # Get empty text condition encoding (for CFG) text_null_list = [""] * batch_size text_null_context = self.encode_text_with_cache(text_null_list, device) text_null_context = [u.to(self.param_dtype) for u in text_null_context] # print(len(all_text_context), len(text_null_context)) commit_index = 0 # Progressively advance from t=0 to t=max_t for step in range(total_steps): # Current time step t = step * dt start_index = max(0, int(self.chunk_size * (t - 1)) + 1) end_index = int(self.chunk_size * t) + 1 time_steps = torch.full((batch_size,), t, device=device) # Calculate current noise schedule noise_level = self._get_noise_levels( device, seq_len + self.chunk_size, time_steps ) # (B, T) # Predict noise through WanModel noisy_input = [] for i in range(batch_size): noisy_input.append(generated[i, :, :end_index, ...]) predicted_result = self.model( noisy_input, noise_level * self.time_embedding_scale, all_text_context, seq_len + self.chunk_size, y=None, ) # (B, C, T, 1, 1) # Adjust using CFG if self.cfg_scale != 1.0: predicted_result_null = self.model( noisy_input, noise_level * self.time_embedding_scale, text_null_context, seq_len + self.chunk_size, y=None, ) # (B, C, T, 1, 1) predicted_result = [ self.cfg_scale * pv - (self.cfg_scale - 1) * pvn for pv, pvn in zip(predicted_result, predicted_result_null) ] for i in range(batch_size): predicted_result_i = predicted_result[i] # (C, input_length, 1, 1) if self.prediction_type == "vel": predicted_vel = predicted_result_i[:, start_index:end_index, ...] generated[i, :, start_index:end_index, ...] += predicted_vel * dt elif self.prediction_type == "x0": predicted_vel = ( predicted_result_i[:, start_index:end_index, ...] - generated[i, :, start_index:end_index, ...] ) / ( noise_level[i, start_index:end_index] .unsqueeze(0) .unsqueeze(-1) .unsqueeze(-1) ) generated[i, :, start_index:end_index, ...] += predicted_vel * dt elif self.prediction_type == "noise": predicted_vel = ( generated[i, :, start_index:end_index, ...] - predicted_result_i[:, start_index:end_index, ...] ) / ( 1 + dt - noise_level[i, start_index:end_index] .unsqueeze(0) .unsqueeze(-1) .unsqueeze(-1) ) generated[i, :, start_index:end_index, ...] += predicted_vel * dt if commit_index < start_index: output = generated[:, :, commit_index:start_index, ...] output = self.postprocess(output) # (B, T, C) y_hat_out = [] for i in range(batch_size): if commit_index < generated_length[i]: y_hat_out.append( output[i, : generated_length[i] - commit_index, ...] ) else: y_hat_out.append(None) out = {} out["generated"] = y_hat_out yield out commit_index = start_index output = generated[:, :, commit_index:, ...] output = self.postprocess(output) # (B, T_remain, C) y_hat_out = [] for i in range(batch_size): if commit_index < generated_length[i]: y_hat_out.append(output[i, : generated_length[i] - commit_index, ...]) else: y_hat_out.append(None) out = {} out["generated"] = y_hat_out yield out def init_generated(self, seq_len, batch_size=1, num_denoise_steps=None): self.seq_len = seq_len self.batch_size = batch_size if num_denoise_steps is None: self.num_denoise_steps = self.noise_steps else: self.num_denoise_steps = num_denoise_steps assert self.num_denoise_steps % self.chunk_size == 0 self.dt = 1 / self.num_denoise_steps self.current_step = 0 self.text_condition_list = [[] for _ in range(self.batch_size)] self.generated = torch.randn( self.batch_size, self.seq_len * 2 + self.chunk_size, self.input_dim ) self.generated = self.preprocess(self.generated) # (B, C, T, 1, 1) self.commit_index = 0 @torch.no_grad() def stream_generate_step(self, x, first_chunk=True): """ Streaming generation step - Diffusion Forcing inference Uses triangular noise schedule, progressively generating from left to right Generation process: 1. Start from t=0, gradually increase t 2. Each t corresponds to a noise schedule: clean on left, noisy on right, gradient in middle 3. After each denoising step, t increases slightly and continues """ device = next(self.parameters()).device if first_chunk: self.generated = self.generated.to(device) # Encode text condition (using cache) if self.use_text_cond and "text" in x: text_list = x["text"] # List[str] new_text_context = self.encode_text_with_cache(text_list, device) new_text_context = [u.to(self.param_dtype) for u in new_text_context] else: new_text_context = [""] * self.batch_size new_text_context = self.encode_text_with_cache(new_text_context, device) new_text_context = [u.to(self.param_dtype) for u in new_text_context] # Get empty text condition encoding (for CFG) text_null_list = [""] * self.batch_size text_null_context = self.encode_text_with_cache(text_null_list, device) text_null_context = [u.to(self.param_dtype) for u in text_null_context] for i in range(self.batch_size): if first_chunk: self.text_condition_list[i].extend( [new_text_context[i]] * self.chunk_size ) else: self.text_condition_list[i].extend([new_text_context[i]]) end_step = ( (self.commit_index + self.chunk_size) * self.num_denoise_steps / self.chunk_size ) while self.current_step < end_step: current_time = self.current_step * self.dt start_index = max(0, int(self.chunk_size * (current_time - 1)) + 1) end_index = int(self.chunk_size * current_time) + 1 time_steps = torch.full((self.batch_size,), current_time, device=device) noise_level = self._get_noise_levels(device, end_index, time_steps)[ :, -self.seq_len : ] # (B, T) # Predict noise through WanModel noisy_input = [] for i in range(self.batch_size): noisy_input.append( self.generated[i, :, :end_index, ...][:, -self.seq_len :] ) # (C, T, 1, 1) text_condition = [] for i in range(self.batch_size): text_condition.extend( self.text_condition_list[i][:end_index][-self.seq_len :] ) # (T, D, 4096) # print("////////////////////") # print("current step: ", self.current_step) # print("chunk size: ", self.chunk_size) # print("start_index: ", start_index) # print("end_index: ", end_index) # print("noisy_input shape: ", noisy_input[0].shape) # print("noise_level: ", noise_level[0, start_index:end_index]) # print("text_condition shape: ", len(text_condition)) # print("commit_index: ", self.commit_index) # print("////////////////////") predicted_result = self.model( noisy_input, noise_level * self.time_embedding_scale, text_condition, min(end_index, self.seq_len), y=None, ) # (B, C, T, 1, 1) # Adjust using CFG if self.cfg_scale != 1.0: predicted_result_null = self.model( noisy_input, noise_level * self.time_embedding_scale, text_null_context, min(end_index, self.seq_len), y=None, ) # (B, C, T, 1, 1) predicted_result = [ self.cfg_scale * pv - (self.cfg_scale - 1) * pvn for pv, pvn in zip(predicted_result, predicted_result_null) ] for i in range(self.batch_size): predicted_result_i = predicted_result[i] # (C, input_length, 1, 1) if end_index > self.seq_len: predicted_result_i = torch.cat( [ torch.zeros( predicted_result_i.shape[0], end_index - self.seq_len, predicted_result_i.shape[2], predicted_result_i.shape[3], device=device, ), predicted_result_i, ], dim=1, ) if self.prediction_type == "vel": predicted_vel = predicted_result_i[:, start_index:end_index, ...] self.generated[i, :, start_index:end_index, ...] += ( predicted_vel * self.dt ) elif self.prediction_type == "x0": predicted_vel = ( predicted_result_i[:, start_index:end_index, ...] - self.generated[i, :, start_index:end_index, ...] ) / ( noise_level[i, start_index:end_index] .unsqueeze(0) .unsqueeze(-1) .unsqueeze(-1) ) self.generated[i, :, start_index:end_index, ...] += ( predicted_vel * self.dt ) elif self.prediction_type == "noise": predicted_vel = ( self.generated[i, :, start_index:end_index, ...] - predicted_result_i[:, start_index:end_index, ...] ) / ( 1 + self.dt - noise_level[i, start_index:end_index] .unsqueeze(0) .unsqueeze(-1) .unsqueeze(-1) ) self.generated[i, :, start_index:end_index, ...] += ( predicted_vel * self.dt ) self.current_step += 1 output = self.generated[:, :, self.commit_index : self.commit_index + 1, ...] output = self.postprocess(output) # (B, 1, C) out = {} out["generated"] = output self.commit_index += 1 if self.commit_index == self.seq_len * 2: self.generated = torch.cat( [ self.generated[:, :, self.seq_len :, ...], torch.randn( self.batch_size, self.input_dim, self.seq_len, 1, 1, device=device, ), ], dim=2, ) self.current_step -= self.seq_len * self.num_denoise_steps / self.chunk_size self.commit_index -= self.seq_len for i in range(self.batch_size): self.text_condition_list[i] = self.text_condition_list[i][ self.seq_len : ] return out