import numpy as np
import torch
import torch.nn as nn

from .tools.t5 import T5EncoderModel
from .tools.wan_model import WanModel


class DiffForcingWanModel(nn.Module):
    """Diffusion-forcing model built on WanModel, with T5 text conditioning.

    Training and inference use a per-frame triangular noise schedule: earlier
    frames are cleaner, later frames noisier, so sequences can be generated
    chunk by chunk from left to right.
    """

    def __init__(
        self,
        checkpoint_path="deps/t5_umt5-xxl-enc-bf16/models_t5_umt5-xxl-enc-bf16.pth",
        tokenizer_path="deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl",
        input_dim=256,
        hidden_dim=1024,
        ffn_dim=2048,
        freq_dim=256,
        num_heads=8,
        num_layers=8,
        time_embedding_scale=1.0,
        chunk_size=5,
        noise_steps=10,
        use_text_cond=True,
        text_len=128,
        drop_out=0.1,
        cfg_scale=5.0,
        prediction_type="vel",
        causal=False,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.time_embedding_scale = time_embedding_scale
        self.chunk_size = chunk_size
        self.noise_steps = noise_steps
        self.use_text_cond = use_text_cond
        self.drop_out = drop_out
        self.cfg_scale = cfg_scale
        self.prediction_type = prediction_type
        self.causal = causal

        self.text_dim = 4096
        self.text_len = text_len
        self.text_encoder = T5EncoderModel(
            text_len=self.text_len,
            dtype=torch.bfloat16,
            device=torch.device("cpu"),
            checkpoint_path=checkpoint_path,
            tokenizer_path=tokenizer_path,
            shard_fn=None,
        )

        self.text_cache = {}
        self.model = WanModel(
            model_type="t2v",
            patch_size=(1, 1, 1),
            text_len=self.text_len,
            in_dim=self.input_dim,
            dim=self.hidden_dim,
            ffn_dim=self.ffn_dim,
            freq_dim=self.freq_dim,
            text_dim=self.text_dim,
            out_dim=self.input_dim,
            num_heads=self.num_heads,
            num_layers=self.num_layers,
            window_size=(-1, -1),
            qk_norm=True,
            cross_attn_norm=True,
            eps=1e-6,
            causal=self.causal,
        )
        self.param_dtype = torch.float32

    def encode_text_with_cache(self, text_list, device):
        """Encode a list of texts, reusing cached features when possible.

        Args:
            text_list: List[str], texts to encode
            device: torch.device

        Returns:
            List[Tensor]: encoded text features, one per input text
        """
        text_features = []
        indices_to_encode = []
        texts_to_encode = []

        # Serve cache hits directly; remember which texts still need encoding
        for i, text in enumerate(text_list):
            if text in self.text_cache:
                cached_feature = self.text_cache[text].to(device)
                text_features.append(cached_feature)
            else:
                text_features.append(None)
                indices_to_encode.append(i)
                texts_to_encode.append(text)

        # Encode the cache misses in one batch and store them on CPU
        if texts_to_encode:
            self.text_encoder.model.to(device)
            encoded = self.text_encoder(texts_to_encode, device)
            for idx, text, feature in zip(indices_to_encode, texts_to_encode, encoded):
                self.text_cache[text] = feature.cpu()
                text_features[idx] = feature

        return text_features

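    # Cache behavior sketch (hypothetical call, for exposition): features are
    # keyed by the raw prompt string and stored on CPU, so repeated prompts
    # skip the T5 forward pass entirely.
    #
    #   feats = model.encode_text_with_cache(["hello", "hello"], device)
    #   # the second "hello" is served from model.text_cache
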
    def preprocess(self, x):
        # (B, T, D) -> (B, D, T, 1, 1): WanModel expects a video-like layout
        x = x.permute(0, 2, 1)[:, :, :, None, None]
        return x

    def postprocess(self, x):
        # (B, D, T, 1, 1) -> (B, T, D): inverse of preprocess
        x = x.permute(0, 2, 1, 3, 4).contiguous().view(x.size(0), x.size(2), -1)
        return x

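    # Round-trip sanity sketch (shapes assumed for illustration):
    #
    #   x = torch.randn(2, 7, 4)                      # (B, T, D)
    #   model.preprocess(x).shape                     # (2, 4, 7, 1, 1)
    #   model.postprocess(model.preprocess(x)).shape  # (2, 7, 4)
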
    def _get_noise_levels(self, device, seq_len, time_steps):
        """Triangular noise schedule: per-frame noise level in [0, 1].

        Frame i at diffusion time t gets clamp(1 + i / chunk_size - t, 0, 1),
        so frames left of the current chunk are clean (0), frames to the
        right are pure noise (1), with a linear ramp in between.
        """
        noise_level = torch.clamp(
            1
            + torch.arange(seq_len, device=device) / self.chunk_size
            - time_steps.unsqueeze(1),
            min=0.0,
            max=1.0,
        )
        return noise_level

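    # Numeric illustration of the schedule above (toy values, not part of the
    # training path): with chunk_size=5 and t=1.2 over 8 frames,
    #
    #   torch.clamp(1 + torch.arange(8) / 5 - 1.2, min=0.0, max=1.0)
    #   # tensor([0.0, 0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.0])
    #
    # i.e. the first frames are fully denoised while the tail is still noise.
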
    def add_noise(self, x, noise_level):
        """Linearly interpolate between data and Gaussian noise.

        Args:
            x: (B, T, D) clean features
            noise_level: (B, T) per-frame noise level in [0, 1]

        Returns:
            (noisy_x, noise) where noisy_x = (1 - level) * x + level * noise
        """
        noise = torch.randn_like(x)
        noise_level = noise_level.unsqueeze(-1)
        noisy_x = x * (1 - noise_level) + noise_level * noise
        return noisy_x, noise

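    # Endpoint check (assumed toy values): noise_level 0 returns x unchanged
    # and noise_level 1 returns pure noise, matching the triangular schedule's
    # clean-left / noisy-right picture.
    #
    #   x = torch.zeros(1, 2, 3)
    #   noisy, eps = model.add_noise(x, torch.tensor([[0.0, 1.0]]))
    #   # noisy[:, 0] == x[:, 0] everywhere; noisy[:, 1] == eps[:, 1]
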
    def forward(self, x):
        feature = x["feature"]
        feature_length = x["feature_length"]
        batch_size, seq_len, _ = feature.shape
        device = feature.device

        # Sample a diffusion time per sample, bounded by its valid length
        time_steps = []
        for i in range(batch_size):
            valid_len = feature_length[i].item()
            max_time = valid_len / self.chunk_size
            time_steps.append(torch.FloatTensor(1).uniform_(0, max_time).item())
        time_steps = torch.tensor(time_steps, device=device)
        noise_level = self._get_noise_levels(device, seq_len, time_steps)

        noisy_feature, noise = self.add_noise(feature, noise_level)

        feature = self.preprocess(feature)
        noisy_feature = self.preprocess(noisy_feature)
        noise = self.preprocess(noise)

        # Truncate each sample to the frames visible at its sampled time
        feature_ref = []
        noise_ref = []
        noisy_feature_input = []
        for i in range(batch_size):
            t = time_steps[i].item()
            end_index = int(self.chunk_size * t) + 1
            valid_len = feature_length[i].item()
            end_index = min(valid_len, end_index)
            feature_ref.append(feature[i, :, :end_index, ...])
            noise_ref.append(noise[i, :, :end_index, ...])
            noisy_feature_input.append(noisy_feature[i, :, :end_index, ...])

        if self.use_text_cond and "text" in x:
            text_list = x["text"]
            if isinstance(text_list[0], list):
                # Per-segment text: expand each caption over its frame span
                text_end_list = x["feature_text_end"]
                all_text_context = []
                for single_text_list, single_text_end_list in zip(
                    text_list, text_end_list
                ):
                    # Classifier-free guidance dropout: with probability
                    # drop_out, replace the captions with the null prompt
                    # (the original comparison was inverted relative to the
                    # single-caption branch below)
                    if np.random.rand() <= self.drop_out:
                        single_text_list = [""]
                        single_text_end_list = [0, seq_len]
                    else:
                        single_text_end_list = [0] + [
                            min(t, seq_len) for t in single_text_end_list
                        ]
                    single_text_length_list = [
                        t - b
                        for t, b in zip(
                            single_text_end_list[1:], single_text_end_list[:-1]
                        )
                    ]
                    single_text_context = self.encode_text_with_cache(
                        single_text_list, device
                    )
                    single_text_context = [
                        u.to(self.param_dtype) for u in single_text_context
                    ]
                    for u, duration in zip(
                        single_text_context, single_text_length_list
                    ):
                        all_text_context.extend([u for _ in range(duration)])
                    # Pad the tail with the last caption's features
                    all_text_context.extend(
                        [
                            single_text_context[-1]
                            for _ in range(seq_len - single_text_end_list[-1])
                        ]
                    )
            else:
                # One caption per sample, with classifier-free guidance dropout
                all_text_context = [
                    (u if np.random.rand() > self.drop_out else "") for u in text_list
                ]
                all_text_context = self.encode_text_with_cache(all_text_context, device)
                all_text_context = [u.to(self.param_dtype) for u in all_text_context]
        else:
            all_text_context = [""] * batch_size
            all_text_context = self.encode_text_with_cache(all_text_context, device)
            all_text_context = [u.to(self.param_dtype) for u in all_text_context]

        predicted_result = self.model(
            noisy_feature_input,
            noise_level * self.time_embedding_scale,
            all_text_context,
            seq_len,
            y=None,
        )

        # Supervise only the last chunk_size frames of each truncated sample
        loss = 0.0
        for b in range(batch_size):
            if self.prediction_type == "vel":
                vel = feature_ref[b] - noise_ref[b]
                squared_error = (
                    predicted_result[b][:, -self.chunk_size :, ...]
                    - vel[:, -self.chunk_size :, ...]
                ) ** 2
            elif self.prediction_type == "x0":
                squared_error = (
                    predicted_result[b][:, -self.chunk_size :, ...]
                    - feature_ref[b][:, -self.chunk_size :, ...]
                ) ** 2
            elif self.prediction_type == "noise":
                squared_error = (
                    predicted_result[b][:, -self.chunk_size :, ...]
                    - noise_ref[b][:, -self.chunk_size :, ...]
                ) ** 2
            else:
                raise ValueError(f"Unknown prediction_type: {self.prediction_type}")
            sample_loss = squared_error.sum()
            loss += sample_loss
        loss = loss / batch_size

        loss_dict = {"total": loss, "mse": loss}
        return loss_dict

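    # Flow-matching bookkeeping behind the three prediction types (derived
    # from add_noise above, where x_t = (1 - sigma) * x0 + sigma * eps): the
    # denoising velocity is x0 - eps, so during generation
    #   "vel"   predictions are used directly,
    #   "x0"    predictions recover it as (x0_pred - x_t) / sigma,
    #   "noise" predictions recover it as (x_t - eps_pred) / (1 - sigma),
    #           shifted by dt in the code to avoid dividing by zero at sigma = 1.
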
    def generate(self, x, num_denoise_steps=None):
        """Diffusion Forcing inference.

        Uses the triangular noise schedule to generate progressively from
        left to right:
        1. Start from t=0 and gradually increase t.
        2. Each t induces a schedule: clean on the left, noisy on the right,
           with a linear ramp in between.
        3. After each denoising step, t increases slightly and the process
           continues.
        """
        feature_length = x["feature_length"]
        batch_size = len(feature_length)
        seq_len = max(feature_length).item()

        if num_denoise_steps is None:
            num_denoise_steps = self.noise_steps
        assert num_denoise_steps % self.chunk_size == 0

        device = next(self.parameters()).device

        # Start from pure noise, padded by one extra chunk at the tail
        generated = torch.randn(
            batch_size, seq_len + self.chunk_size, self.input_dim, device=device
        )
        generated = self.preprocess(generated)

        # Largest diffusion time needed to fully denoise the final frame
        max_t = 1 + (seq_len - 1) / self.chunk_size
        dt = 1 / num_denoise_steps
        total_steps = int(max_t / dt)

        if self.use_text_cond and "text" in x:
            text_list = x["text"]
            if isinstance(text_list[0], list):
                # Per-segment text: expand each caption over its frame span
                generated_length = []
                text_end_list = x["feature_text_end"]
                full_text = []
                all_text_context = []
                for single_text_list, single_text_end_list in zip(
                    text_list, text_end_list
                ):
                    single_text_end_list = [0] + [
                        min(t, seq_len) for t in single_text_end_list
                    ]
                    generated_length.append(single_text_end_list[-1])
                    single_text_length_list = [
                        t - b
                        for t, b in zip(
                            single_text_end_list[1:], single_text_end_list[:-1]
                        )
                    ]
                    full_text.append(
                        " ////////// ".join(
                            [
                                f"{u} //dur:{t}"
                                for u, t in zip(
                                    single_text_list, single_text_length_list
                                )
                            ]
                        )
                    )
                    single_text_context = self.encode_text_with_cache(
                        single_text_list, device
                    )
                    single_text_context = [
                        u.to(self.param_dtype) for u in single_text_context
                    ]
                    for u, duration in zip(
                        single_text_context, single_text_length_list
                    ):
                        all_text_context.extend([u for _ in range(duration)])
                    # Pad the tail with the last caption's features
                    all_text_context.extend(
                        [
                            single_text_context[-1]
                            for _ in range(
                                seq_len + self.chunk_size - single_text_end_list[-1]
                            )
                        ]
                    )
            else:
                generated_length = feature_length
                full_text = text_list
                all_text_context = self.encode_text_with_cache(text_list, device)
                all_text_context = [u.to(self.param_dtype) for u in all_text_context]
        else:
            generated_length = feature_length
            full_text = [""] * batch_size
            all_text_context = [""] * batch_size
            all_text_context = self.encode_text_with_cache(all_text_context, device)
            all_text_context = [u.to(self.param_dtype) for u in all_text_context]

        # Null-prompt context for classifier-free guidance
        text_null_list = [""] * batch_size
        text_null_context = self.encode_text_with_cache(text_null_list, device)
        text_null_context = [u.to(self.param_dtype) for u in text_null_context]

        for step in range(total_steps):
            # Frames in [start_index, end_index) are currently being denoised
            t = step * dt
            start_index = max(0, int(self.chunk_size * (t - 1)) + 1)
            end_index = int(self.chunk_size * t) + 1
            time_steps = torch.full((batch_size,), t, device=device)

            noise_level = self._get_noise_levels(
                device, seq_len + self.chunk_size, time_steps
            )

            noisy_input = []
            for i in range(batch_size):
                noisy_input.append(generated[i, :, :end_index, ...])

            predicted_result = self.model(
                noisy_input,
                noise_level * self.time_embedding_scale,
                all_text_context,
                seq_len + self.chunk_size,
                y=None,
            )

            # Classifier-free guidance: extrapolate away from the null prompt
            if self.cfg_scale != 1.0:
                predicted_result_null = self.model(
                    noisy_input,
                    noise_level * self.time_embedding_scale,
                    text_null_context,
                    seq_len + self.chunk_size,
                    y=None,
                )
                predicted_result = [
                    self.cfg_scale * pv - (self.cfg_scale - 1) * pvn
                    for pv, pvn in zip(predicted_result, predicted_result_null)
                ]

            # Euler step: convert each prediction type to a velocity update
            for i in range(batch_size):
                predicted_result_i = predicted_result[i]
                if self.prediction_type == "vel":
                    predicted_vel = predicted_result_i[:, start_index:end_index, ...]
                    generated[i, :, start_index:end_index, ...] += predicted_vel * dt
                elif self.prediction_type == "x0":
                    predicted_vel = (
                        predicted_result_i[:, start_index:end_index, ...]
                        - generated[i, :, start_index:end_index, ...]
                    ) / (
                        noise_level[i, start_index:end_index]
                        .unsqueeze(0)
                        .unsqueeze(-1)
                        .unsqueeze(-1)
                    )
                    generated[i, :, start_index:end_index, ...] += predicted_vel * dt
                elif self.prediction_type == "noise":
                    predicted_vel = (
                        generated[i, :, start_index:end_index, ...]
                        - predicted_result_i[:, start_index:end_index, ...]
                    ) / (
                        1
                        + dt
                        - noise_level[i, start_index:end_index]
                        .unsqueeze(0)
                        .unsqueeze(-1)
                        .unsqueeze(-1)
                    )
                    generated[i, :, start_index:end_index, ...] += predicted_vel * dt

        generated = self.postprocess(generated)
        y_hat_out = []
        for i in range(batch_size):
            single_generated = generated[i, : generated_length[i], :]
            y_hat_out.append(single_generated)
        out = {}
        out["generated"] = y_hat_out
        out["text"] = full_text

        return out

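    # Hedged usage sketch (batch keys follow forward()'s convention; the
    # concrete values are assumptions for illustration):
    #
    #   batch = {"feature_length": torch.tensor([40]), "text": ["a prompt"]}
    #   out = model.generate(batch)
    #   out["generated"][0].shape   # (40, input_dim)
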
    @torch.no_grad()
    def stream_generate(self, x, num_denoise_steps=None):
        """Streaming Diffusion Forcing inference.

        Same triangular-schedule procedure as generate(), but fully denoised
        frames are yielded as they are committed instead of all at once.
        """
        feature_length = x["feature_length"]
        batch_size = len(feature_length)
        seq_len = max(feature_length).item()

        if num_denoise_steps is None:
            num_denoise_steps = self.noise_steps
        assert num_denoise_steps % self.chunk_size == 0

        device = next(self.parameters()).device

        # Start from pure noise, padded by one extra chunk at the tail
        generated = torch.randn(
            batch_size, seq_len + self.chunk_size, self.input_dim, device=device
        )
        generated = self.preprocess(generated)

        max_t = 1 + (seq_len - 1) / self.chunk_size
        dt = 1 / num_denoise_steps
        total_steps = int(max_t / dt)

        if self.use_text_cond and "text" in x:
            text_list = x["text"]
            if isinstance(text_list[0], list):
                # Per-segment text: expand each caption over its frame span
                generated_length = []
                text_end_list = x["feature_text_end"]
                full_text = []
                all_text_context = []
                for single_text_list, single_text_end_list in zip(
                    text_list, text_end_list
                ):
                    single_text_end_list = [0] + [
                        min(t, seq_len) for t in single_text_end_list
                    ]
                    generated_length.append(single_text_end_list[-1])
                    single_text_length_list = [
                        t - b
                        for t, b in zip(
                            single_text_end_list[1:], single_text_end_list[:-1]
                        )
                    ]
                    full_text.append(
                        " ////////// ".join(
                            [
                                f"{u} //dur:{t}"
                                for u, t in zip(
                                    single_text_list, single_text_length_list
                                )
                            ]
                        )
                    )
                    single_text_context = self.encode_text_with_cache(
                        single_text_list, device
                    )
                    single_text_context = [
                        u.to(self.param_dtype) for u in single_text_context
                    ]
                    for u, duration in zip(
                        single_text_context, single_text_length_list
                    ):
                        all_text_context.extend([u for _ in range(duration)])
                    # Pad the tail with the last caption's features
                    all_text_context.extend(
                        [
                            single_text_context[-1]
                            for _ in range(
                                seq_len + self.chunk_size - single_text_end_list[-1]
                            )
                        ]
                    )
            else:
                generated_length = feature_length
                full_text = text_list
                all_text_context = self.encode_text_with_cache(text_list, device)
                all_text_context = [u.to(self.param_dtype) for u in all_text_context]
        else:
            generated_length = feature_length
            full_text = [""] * batch_size
            all_text_context = [""] * batch_size
            all_text_context = self.encode_text_with_cache(all_text_context, device)
            all_text_context = [u.to(self.param_dtype) for u in all_text_context]

        # Null-prompt context for classifier-free guidance
        text_null_list = [""] * batch_size
        text_null_context = self.encode_text_with_cache(text_null_list, device)
        text_null_context = [u.to(self.param_dtype) for u in text_null_context]

        # Frames left of start_index are fully denoised; commit_index tracks
        # how many of them have already been yielded
        commit_index = 0

        for step in range(total_steps):
            t = step * dt
            start_index = max(0, int(self.chunk_size * (t - 1)) + 1)
            end_index = int(self.chunk_size * t) + 1
            time_steps = torch.full((batch_size,), t, device=device)

            noise_level = self._get_noise_levels(
                device, seq_len + self.chunk_size, time_steps
            )

            noisy_input = []
            for i in range(batch_size):
                noisy_input.append(generated[i, :, :end_index, ...])

            predicted_result = self.model(
                noisy_input,
                noise_level * self.time_embedding_scale,
                all_text_context,
                seq_len + self.chunk_size,
                y=None,
            )

            # Classifier-free guidance: extrapolate away from the null prompt
            if self.cfg_scale != 1.0:
                predicted_result_null = self.model(
                    noisy_input,
                    noise_level * self.time_embedding_scale,
                    text_null_context,
                    seq_len + self.chunk_size,
                    y=None,
                )
                predicted_result = [
                    self.cfg_scale * pv - (self.cfg_scale - 1) * pvn
                    for pv, pvn in zip(predicted_result, predicted_result_null)
                ]

            # Euler step: convert each prediction type to a velocity update
            for i in range(batch_size):
                predicted_result_i = predicted_result[i]
                if self.prediction_type == "vel":
                    predicted_vel = predicted_result_i[:, start_index:end_index, ...]
                    generated[i, :, start_index:end_index, ...] += predicted_vel * dt
                elif self.prediction_type == "x0":
                    predicted_vel = (
                        predicted_result_i[:, start_index:end_index, ...]
                        - generated[i, :, start_index:end_index, ...]
                    ) / (
                        noise_level[i, start_index:end_index]
                        .unsqueeze(0)
                        .unsqueeze(-1)
                        .unsqueeze(-1)
                    )
                    generated[i, :, start_index:end_index, ...] += predicted_vel * dt
                elif self.prediction_type == "noise":
                    predicted_vel = (
                        generated[i, :, start_index:end_index, ...]
                        - predicted_result_i[:, start_index:end_index, ...]
                    ) / (
                        1
                        + dt
                        - noise_level[i, start_index:end_index]
                        .unsqueeze(0)
                        .unsqueeze(-1)
                        .unsqueeze(-1)
                    )
                    generated[i, :, start_index:end_index, ...] += predicted_vel * dt

            # Yield every frame that became fully denoised this step
            if commit_index < start_index:
                output = generated[:, :, commit_index:start_index, ...]
                output = self.postprocess(output)
                y_hat_out = []
                for i in range(batch_size):
                    if commit_index < generated_length[i]:
                        y_hat_out.append(
                            output[i, : generated_length[i] - commit_index, ...]
                        )
                    else:
                        y_hat_out.append(None)

                out = {}
                out["generated"] = y_hat_out
                yield out
                commit_index = start_index

        # Flush whatever remains after the final step
        output = generated[:, :, commit_index:, ...]
        output = self.postprocess(output)
        y_hat_out = []
        for i in range(batch_size):
            if commit_index < generated_length[i]:
                y_hat_out.append(output[i, : generated_length[i] - commit_index, ...])
            else:
                y_hat_out.append(None)
        out = {}
        out["generated"] = y_hat_out
        yield out

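    # Hedged streaming sketch (assumed batch layout, mirroring generate()):
    # each yielded dict carries only the newly committed frames, or None for
    # samples that have already finished.
    #
    #   for out in model.stream_generate(batch):
    #       chunk = out["generated"][0]   # (n_new_frames, input_dim) or None
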
    def init_generated(self, seq_len, batch_size=1, num_denoise_steps=None):
        """Reset the streaming state used by stream_generate_step()."""
        self.seq_len = seq_len
        self.batch_size = batch_size
        if num_denoise_steps is None:
            self.num_denoise_steps = self.noise_steps
        else:
            self.num_denoise_steps = num_denoise_steps
        assert self.num_denoise_steps % self.chunk_size == 0
        self.dt = 1 / self.num_denoise_steps
        self.current_step = 0
        self.text_condition_list = [[] for _ in range(self.batch_size)]
        # Noise buffer created on CPU; stream_generate_step() moves it to the
        # model's device on the first chunk
        self.generated = torch.randn(
            self.batch_size, self.seq_len * 2 + self.chunk_size, self.input_dim
        )
        self.generated = self.preprocess(self.generated)
        self.commit_index = 0

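    # Buffer layout sketch: the streaming buffer holds 2 * seq_len +
    # chunk_size frames. Once 2 * seq_len frames have been committed,
    # stream_generate_step() below shifts the window left by seq_len and
    # appends fresh noise, so streaming can continue indefinitely.
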
    @torch.no_grad()
    def stream_generate_step(self, x, first_chunk=True):
        """One step of streaming Diffusion Forcing inference.

        Same triangular-schedule procedure as generate(), driven one
        committed frame at a time over a sliding window; call
        init_generated() first.
        """
        device = next(self.parameters()).device
        if first_chunk:
            self.generated = self.generated.to(device)

        if self.use_text_cond and "text" in x:
            text_list = x["text"]
            new_text_context = self.encode_text_with_cache(text_list, device)
            new_text_context = [u.to(self.param_dtype) for u in new_text_context]
        else:
            new_text_context = [""] * self.batch_size
            new_text_context = self.encode_text_with_cache(new_text_context, device)
            new_text_context = [u.to(self.param_dtype) for u in new_text_context]

        # Null-prompt context for classifier-free guidance
        text_null_list = [""] * self.batch_size
        text_null_context = self.encode_text_with_cache(text_null_list, device)
        text_null_context = [u.to(self.param_dtype) for u in text_null_context]

        # Extend the per-frame text condition: a full chunk's worth on the
        # first call, one frame per call afterwards
        for i in range(self.batch_size):
            if first_chunk:
                self.text_condition_list[i].extend(
                    [new_text_context[i]] * self.chunk_size
                )
            else:
                self.text_condition_list[i].extend([new_text_context[i]])

        # Run denoising until the next frame is fully denoised and committable
        end_step = (
            (self.commit_index + self.chunk_size)
            * self.num_denoise_steps
            / self.chunk_size
        )
        while self.current_step < end_step:
            current_time = self.current_step * self.dt
            start_index = max(0, int(self.chunk_size * (current_time - 1)) + 1)
            end_index = int(self.chunk_size * current_time) + 1
            time_steps = torch.full((self.batch_size,), current_time, device=device)

            # Keep at most the trailing seq_len frames in view
            noise_level = self._get_noise_levels(device, end_index, time_steps)[
                :, -self.seq_len :
            ]

            noisy_input = []
            for i in range(self.batch_size):
                noisy_input.append(
                    self.generated[i, :, :end_index, ...][:, -self.seq_len :]
                )

            text_condition = []
            for i in range(self.batch_size):
                text_condition.extend(
                    self.text_condition_list[i][:end_index][-self.seq_len :]
                )

            predicted_result = self.model(
                noisy_input,
                noise_level * self.time_embedding_scale,
                text_condition,
                min(end_index, self.seq_len),
                y=None,
            )

            # Classifier-free guidance: extrapolate away from the null prompt
            if self.cfg_scale != 1.0:
                predicted_result_null = self.model(
                    noisy_input,
                    noise_level * self.time_embedding_scale,
                    text_null_context,
                    min(end_index, self.seq_len),
                    y=None,
                )
                predicted_result = [
                    self.cfg_scale * pv - (self.cfg_scale - 1) * pvn
                    for pv, pvn in zip(predicted_result, predicted_result_null)
                ]

            for i in range(self.batch_size):
                predicted_result_i = predicted_result[i]
                # Left-pad with zeros so absolute frame indices line up once
                # the sliding window has dropped early frames
                if end_index > self.seq_len:
                    predicted_result_i = torch.cat(
                        [
                            torch.zeros(
                                predicted_result_i.shape[0],
                                end_index - self.seq_len,
                                predicted_result_i.shape[2],
                                predicted_result_i.shape[3],
                                device=device,
                            ),
                            predicted_result_i,
                        ],
                        dim=1,
                    )
                if self.prediction_type == "vel":
                    predicted_vel = predicted_result_i[:, start_index:end_index, ...]
                    self.generated[i, :, start_index:end_index, ...] += (
                        predicted_vel * self.dt
                    )
                elif self.prediction_type == "x0":
                    predicted_vel = (
                        predicted_result_i[:, start_index:end_index, ...]
                        - self.generated[i, :, start_index:end_index, ...]
                    ) / (
                        noise_level[i, start_index:end_index]
                        .unsqueeze(0)
                        .unsqueeze(-1)
                        .unsqueeze(-1)
                    )
                    self.generated[i, :, start_index:end_index, ...] += (
                        predicted_vel * self.dt
                    )
                elif self.prediction_type == "noise":
                    predicted_vel = (
                        self.generated[i, :, start_index:end_index, ...]
                        - predicted_result_i[:, start_index:end_index, ...]
                    ) / (
                        1
                        + self.dt
                        - noise_level[i, start_index:end_index]
                        .unsqueeze(0)
                        .unsqueeze(-1)
                        .unsqueeze(-1)
                    )
                    self.generated[i, :, start_index:end_index, ...] += (
                        predicted_vel * self.dt
                    )
            self.current_step += 1

        # Commit exactly one fully denoised frame
        output = self.generated[:, :, self.commit_index : self.commit_index + 1, ...]
        output = self.postprocess(output)
        out = {}
        out["generated"] = output
        self.commit_index += 1

        # Ring-buffer rollover: drop the oldest seq_len frames, append fresh
        # noise, and shift all indices back accordingly
        if self.commit_index == self.seq_len * 2:
            self.generated = torch.cat(
                [
                    self.generated[:, :, self.seq_len :, ...],
                    torch.randn(
                        self.batch_size,
                        self.input_dim,
                        self.seq_len,
                        1,
                        1,
                        device=device,
                    ),
                ],
                dim=2,
            )
            self.current_step -= self.seq_len * self.num_denoise_steps / self.chunk_size
            self.commit_index -= self.seq_len
            for i in range(self.batch_size):
                self.text_condition_list[i] = self.text_condition_list[i][
                    self.seq_len :
                ]
        return out
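
# Hedged end-to-end streaming sketch (prompt and lengths are assumptions;
# instantiation requires the T5 checkpoint paths from __init__):
#
#   model = DiffForcingWanModel().eval()
#   model.init_generated(seq_len=40, batch_size=1)
#   out = model.stream_generate_step({"text": ["a prompt"]}, first_chunk=True)
#   out["generated"].shape   # (1, 1, input_dim): one committed frame
#   out = model.stream_generate_step({"text": ["a prompt"]}, first_chunk=False)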