import functools

import numpy as np
import torch

# Argument types treated as batched: when the skip is active they are sliced
# from the batch midpoint onward, exactly like ``x``.
_BATCHED_TYPES = (torch.Tensor, list, tuple, np.ndarray)


def _cfg_skip_active(module, batch_size):
    """Return True when the wrapped call should run on only the second half of the batch.

    Reads ``cfg_skip_ratio``, ``current_steps`` and ``num_inference_steps``
    from *module*. The checks short-circuit in the same order as before:
    batch size first, then whether a ratio is configured, then the step test.
    """
    if batch_size < 2 or module.cfg_skip_ratio is None:
        return False
    return module.current_steps >= module.num_inference_steps * (1 - module.cfg_skip_ratio)


def _slice_from(value, start):
    """Return ``value[start:]`` if *value* is a batched type, else *value* unchanged."""
    return value[start:] if isinstance(value, _BATCHED_TYPES) else value


def cfg_skip():
    """Build a decorator that runs a method on only half of a doubled batch late in sampling.

    The decorated callable must be a method with signature
    ``(self, x, *args, **kwargs)``. On each call the wrapper checks
    ``len(x) >= 2``, ``self.cfg_skip_ratio is not None`` and
    ``self.current_steps >= self.num_inference_steps * (1 - self.cfg_skip_ratio)``.

    When all three hold:

    - ``x`` is sliced to ``[len(x) // 2:]``.
    - Every positional or keyword argument that is a ``torch.Tensor``,
      ``list``, ``tuple`` or ``np.ndarray`` is sliced the same way.
    - The method is called on these halves.
    - The result is duplicated along dim 0 with ``torch.cat``.

    Otherwise the call is forwarded unchanged.

    Presumably the batch is ``[unconditional; conditional]`` stacked for
    classifier-free guidance, so the last ``cfg_skip_ratio`` fraction of
    steps evaluates only the conditional half. TODO: confirm against callers.

    NOTE(review): every list/tuple/array argument is sliced, even one that is
    not per-sample (e.g. a shape tuple). For an odd ``len(x)`` the output has
    ``len(x) + 1`` rows.

    Returns:
        A decorator for methods of the shape described above.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, x, *args, **kwargs):
            bs = len(x)
            # Decide exactly once, before calling func. func may advance
            # self.current_steps; re-checking afterwards could duplicate a
            # full-batch result, or leave a half-batch result un-restored.
            skip = _cfg_skip_active(self, bs)
            if skip:
                half = bs // 2
                x = x[half:]
                args = tuple(_slice_from(arg, half) for arg in args)
                kwargs = {key: _slice_from(value, half) for key, value in kwargs.items()}
            result = func(self, x, *args, **kwargs)
            if skip:
                # Restore the original batch size by reusing the computed half
                # for the skipped half (assumes result is a tensor).
                result = torch.cat([result, result], dim=0)
            return result
        return wrapper
    return decorator