ZIT-Controlnet / videox_fun /utils /cfg_optimization.py
Alexander Bagus
initial commit
d2c9b66
raw
history blame
1.4 kB
import functools

import numpy as np
import torch
def cfg_skip():
    """Build a decorator that runs only half of a CFG batch late in sampling.

    Intended for a model method with signature ``(self, x, *args, **kwargs)``
    whose first argument ``x`` stacks two equally sized halves along dim 0.
    When ``self.cfg_skip_ratio`` is set and
    ``self.current_steps >= self.num_inference_steps * (1 - self.cfg_skip_ratio)``,
    only the second half of the batch (``x[bs // 2:]``) is passed to the
    wrapped method. Its output is then concatenated with itself along dim 0,
    so the caller still receives a result with the original batch size.
    Presumably the halves are ``[uncond, cond]``, which means the conditional
    prediction is reused for both (TODO: confirm against the pipeline that
    builds the batch).

    Every extra positional or keyword argument that is a ``torch.Tensor``,
    ``list``, ``tuple`` or ``np.ndarray`` is sliced the same way. Other values,
    and 0-d tensors/arrays (which cannot be sliced), are passed through
    unchanged.

    The skip is applied only when the batch size is even and at least 2.
    A missing ``cfg_skip_ratio`` attribute is treated as ``None`` (never skip).

    Returns:
        A decorator. The wrapped method's return value must be something
        ``torch.cat`` accepts, i.e. a tensor whose batch is along dim 0.
    """

    def _second_half(value, start):
        # Slice only sequence-like containers; scalars, flags and 0-d
        # tensors/arrays are forwarded as-is.
        if not isinstance(value, (torch.Tensor, list, tuple, np.ndarray)):
            return value
        if getattr(value, "ndim", 1) == 0:
            return value
        return value[start:]

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, x, *args, **kwargs):
            bs = len(x)
            ratio = getattr(self, "cfg_skip_ratio", None)
            # Decide exactly once: func may mutate self (e.g. step counters),
            # and the output must be duplicated if and only if the input
            # was halved.
            skip = (
                bs >= 2
                and bs % 2 == 0  # odd batches cannot be split into two halves
                and ratio is not None
                and self.current_steps >= self.num_inference_steps * (1 - ratio)
            )
            if not skip:
                return func(self, x, *args, **kwargs)

            half = bs // 2
            new_args = [_second_half(arg, half) for arg in args]
            new_kwargs = {key: _second_half(value, half) for key, value in kwargs.items()}
            result = func(self, x[half:], *new_args, **new_kwargs)
            # Restore the full batch size by reusing the single computed half.
            return torch.cat([result, result], dim=0)

        return wrapper

    return decorator