DeepBeepMeep
committed on
Commit
·
03085c8
1
Parent(s):
e420cd0
optimization for i2v with CausVid
Browse files- hyvideo/modules/models.py +1 -2
- wan/image2video.py +23 -24
- wan/modules/model.py +3 -4
hyvideo/modules/models.py
CHANGED
|
@@ -492,8 +492,7 @@ class MMSingleStreamBlock(nn.Module):
|
|
| 492 |
return img, txt
|
| 493 |
|
| 494 |
class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
|
| 495 |
-
|
| 496 |
-
def preprocess_loras(model_filename, sd):
|
| 497 |
if not "i2v" in model_filename:
|
| 498 |
return sd
|
| 499 |
new_sd = {}
|
|
|
|
| 492 |
return img, txt
|
| 493 |
|
| 494 |
class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
|
| 495 |
+
def preprocess_loras(self, model_filename, sd):
|
|
|
|
| 496 |
if not "i2v" in model_filename:
|
| 497 |
return sd
|
| 498 |
new_sd = {}
|
wan/image2video.py
CHANGED
|
@@ -330,8 +330,11 @@ class WanI2V:
|
|
| 330 |
'current_step' :i,
|
| 331 |
})
|
| 332 |
|
| 333 |
-
|
| 334 |
-
|
|
|
|
|
|
|
|
|
|
| 335 |
if audio_proj == None:
|
| 336 |
noise_pred_cond, noise_pred_uncond = self.model(
|
| 337 |
[latent_model_input, latent_model_input],
|
|
@@ -347,13 +350,7 @@ class WanI2V:
|
|
| 347 |
if self._interrupt:
|
| 348 |
return None
|
| 349 |
else:
|
| 350 |
-
noise_pred_cond = self.model(
|
| 351 |
-
[latent_model_input],
|
| 352 |
-
context=[context],
|
| 353 |
-
audio_scale = None if audio_scale == None else [audio_scale],
|
| 354 |
-
x_id=0,
|
| 355 |
-
**kwargs,
|
| 356 |
-
)[0]
|
| 357 |
if self._interrupt:
|
| 358 |
return None
|
| 359 |
|
|
@@ -377,22 +374,24 @@ class WanI2V:
|
|
| 377 |
return None
|
| 378 |
del latent_model_input
|
| 379 |
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
else:
|
| 391 |
-
noise_pred_uncond
|
| 392 |
-
|
| 393 |
-
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
|
| 394 |
-
else:
|
| 395 |
-
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
|
| 396 |
noise_pred_uncond, noise_pred_noaudio = None, None
|
| 397 |
temp_x0 = sample_scheduler.step(
|
| 398 |
noise_pred.unsqueeze(0),
|
|
|
|
| 330 |
'current_step' :i,
|
| 331 |
})
|
| 332 |
|
| 333 |
+
if guide_scale == 1:
|
| 334 |
+
noise_pred = self.model( [latent_model_input], context=[context], audio_scale = None if audio_scale == None else [audio_scale], x_id=0, **kwargs, )[0]
|
| 335 |
+
if self._interrupt:
|
| 336 |
+
return None
|
| 337 |
+
elif joint_pass:
|
| 338 |
if audio_proj == None:
|
| 339 |
noise_pred_cond, noise_pred_uncond = self.model(
|
| 340 |
[latent_model_input, latent_model_input],
|
|
|
|
| 350 |
if self._interrupt:
|
| 351 |
return None
|
| 352 |
else:
|
| 353 |
+
noise_pred_cond = self.model( [latent_model_input], context=[context], audio_scale = None if audio_scale == None else [audio_scale], x_id=0, **kwargs, )[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
if self._interrupt:
|
| 355 |
return None
|
| 356 |
|
|
|
|
| 374 |
return None
|
| 375 |
del latent_model_input
|
| 376 |
|
| 377 |
+
if guide_scale > 1:
|
| 378 |
+
# CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
|
| 379 |
+
if cfg_star_switch:
|
| 380 |
+
positive_flat = noise_pred_cond.view(batch_size, -1)
|
| 381 |
+
negative_flat = noise_pred_uncond.view(batch_size, -1)
|
| 382 |
+
|
| 383 |
+
alpha = optimized_scale(positive_flat,negative_flat)
|
| 384 |
+
alpha = alpha.view(batch_size, 1, 1, 1)
|
| 385 |
+
|
| 386 |
+
if (i <= cfg_zero_step):
|
| 387 |
+
noise_pred = noise_pred_cond*0. # it would be faster not to compute noise_pred...
|
| 388 |
+
else:
|
| 389 |
+
noise_pred_uncond *= alpha
|
| 390 |
+
if audio_scale == None:
|
| 391 |
+
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
|
| 392 |
else:
|
| 393 |
+
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
|
| 394 |
+
|
|
|
|
|
|
|
|
|
|
| 395 |
noise_pred_uncond, noise_pred_noaudio = None, None
|
| 396 |
temp_x0 = sample_scheduler.step(
|
| 397 |
noise_pred.unsqueeze(0),
|
wan/modules/model.py
CHANGED
|
@@ -589,8 +589,7 @@ class MLPProj(torch.nn.Module):
|
|
| 589 |
|
| 590 |
|
| 591 |
class WanModel(ModelMixin, ConfigMixin):
|
| 592 |
-
|
| 593 |
-
def preprocess_loras(model_filename, sd):
|
| 594 |
|
| 595 |
first = next(iter(sd), None)
|
| 596 |
if first == None:
|
|
@@ -634,8 +633,8 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
| 634 |
print(f"Lora alpha'{alpha_key}' is missing")
|
| 635 |
new_sd.update(new_alphas)
|
| 636 |
sd = new_sd
|
| 637 |
-
|
| 638 |
-
if
|
| 639 |
new_sd = {}
|
| 640 |
# convert loras for i2v to t2v
|
| 641 |
for k,v in sd.items():
|
|
|
|
| 589 |
|
| 590 |
|
| 591 |
class WanModel(ModelMixin, ConfigMixin):
|
| 592 |
+
def preprocess_loras(self, model_filename, sd):
|
|
|
|
| 593 |
|
| 594 |
first = next(iter(sd), None)
|
| 595 |
if first == None:
|
|
|
|
| 633 |
print(f"Lora alpha'{alpha_key}' is missing")
|
| 634 |
new_sd.update(new_alphas)
|
| 635 |
sd = new_sd
|
| 636 |
+
from wgp import test_class_i2v
|
| 637 |
+
if not test_class_i2v(model_filename):
|
| 638 |
new_sd = {}
|
| 639 |
# convert loras for i2v to t2v
|
| 640 |
for k,v in sd.items():
|