Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,401 Bytes
d2c9b66 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import numpy as np
import torch
def cfg_skip():
    """Build a decorator that runs only the second half of a batch late in sampling.

    The decorated method must have the shape ``method(self, x, *args, **kwargs)``,
    and its instance must expose ``cfg_skip_ratio``, ``current_steps`` and
    ``num_inference_steps``.

    The skip path is taken when all of the following hold:

    * ``len(x) >= 2`` and ``len(x)`` is even;
    * ``self.cfg_skip_ratio is not None``;
    * ``self.current_steps >= self.num_inference_steps * (1 - self.cfg_skip_ratio)``.

    On the skip path:

    * Only ``x[len(x) // 2:]`` is forwarded to the wrapped method.
    * Every positional or keyword argument that is a ``torch.Tensor``, ``list``,
      ``tuple`` or ``np.ndarray`` is sliced the same way. All other arguments pass
      through unchanged.
    * The result is duplicated along dim 0 with ``torch.cat``, so the caller still
      sees a full-size batch.

    Presumably the first half of the batch is the unconditional branch of a
    classifier-free-guidance batch. NOTE(review): confirm the batch layout with the
    callers.

    Outside the skip path the method is called with its original arguments.

    Returns:
        A decorator for methods of the form described above.
    """
    import functools

    def _slice_batch(value, start):
        # Only sequence/array-like values are treated as batched; scalars, None,
        # dicts, etc. are forwarded unchanged.
        if isinstance(value, (torch.Tensor, list, tuple, np.ndarray)):
            return value[start:]
        return value

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, x, *args, **kwargs):
            bs = len(x)
            # Decide once, before calling ``func``. The wrapped method may mutate
            # ``current_steps`` or the ratio, and re-checking afterwards could
            # duplicate an un-halved result or return a halved one.
            # Odd batch sizes are excluded because halving and then duplicating
            # would not reproduce the input batch size.
            skip = (
                bs >= 2
                and bs % 2 == 0
                and self.cfg_skip_ratio is not None
                and self.current_steps
                >= self.num_inference_steps * (1 - self.cfg_skip_ratio)
            )
            if not skip:
                return func(self, x, *args, **kwargs)

            bs_half = bs // 2
            new_args = [_slice_batch(arg, bs_half) for arg in args]
            new_kwargs = {
                key: _slice_batch(content, bs_half)
                for key, content in kwargs.items()
            }
            result = func(self, x[bs_half:], *new_args, **new_kwargs)
            # Reuse the computed half for both halves of the output batch.
            return torch.cat([result, result], dim=0)

        return wrapper

    return decorator