Spaces:
Running
on
Zero
Running
on
Zero
| import numpy as np | |
| import torch | |
def cfg_skip():
    """Build a decorator that halves a CFG batch late in sampling and re-duplicates the result.

    The decorated method must be called as ``method(self, x, *args, **kwargs)``
    on an object exposing ``cfg_skip_ratio``, ``current_steps`` and
    ``num_inference_steps``. When ``len(x) >= 2``, ``cfg_skip_ratio`` is not
    ``None`` and
    ``current_steps >= num_inference_steps * (1 - cfg_skip_ratio)``, the wrapper:

    * slices ``x`` and every positional/keyword argument that is a
      ``torch.Tensor``, ``list``, ``tuple`` or ``np.ndarray`` down to its
      trailing ``len(x) - len(x) // 2`` items (other values pass through),
    * calls the method once on that reduced batch, and
    * returns ``torch.cat([result, result], dim=0)``.

    Otherwise the method is called with its arguments unchanged.

    NOTE(review): the kept (second) half is presumably the conditional branch
    of an ``[uncond, cond]`` batch — confirm against the caller. For an odd
    batch size, and a method that preserves batch size, the returned batch
    is one item larger than the input.

    Returns:
        A decorator for methods with the signature described above.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, x, *args, **kwargs):
            bs = len(x)
            # Evaluate the skip condition exactly once: ``func`` may advance
            # ``self.current_steps``, and re-checking afterwards could double a
            # full-batch result (or fail to restore a halved one).
            skip = (
                bs >= 2
                and self.cfg_skip_ratio is not None
                and self.current_steps >= self.num_inference_steps * (1 - self.cfg_skip_ratio)
            )
            if not skip:
                return func(self, x, *args, **kwargs)

            bs_half = bs // 2
            sliceable = (torch.Tensor, list, tuple, np.ndarray)

            def _take(value):
                # Slice batch-like containers; pass scalars/other objects through.
                return value[bs_half:] if isinstance(value, sliceable) else value

            new_args = [_take(arg) for arg in args]
            new_kwargs = {key: _take(content) for key, content in kwargs.items()}
            result = func(self, x[bs_half:], *new_args, **new_kwargs)
            # Duplicate along the batch dimension so callers see the full batch again.
            return torch.cat([result, result], dim=0)

        return wrapper

    return decorator