File size: 1,401 Bytes
d2c9b66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import functools

import numpy as np
import torch


def cfg_skip():
    """Build a method decorator that runs late steps on half of the batch.

    The decorated method must have the signature ``func(self, x, *args,
    **kwargs)`` and ``self`` must expose ``cfg_skip_ratio``,
    ``current_steps`` and ``num_inference_steps``.

    Skipping is active when all of the following hold at call time:

    * ``len(x) >= 2``
    * ``self.cfg_skip_ratio is not None``
    * ``self.current_steps >= self.num_inference_steps * (1 - self.cfg_skip_ratio)``

    When active, only the second half of ``x`` (``x[len(x) // 2:]``) is
    passed to the wrapped method. Every positional and keyword argument that
    is a ``torch.Tensor``, ``list``, ``tuple`` or ``np.ndarray`` is sliced the
    same way; all other arguments are forwarded untouched. The result is then
    concatenated with itself along dim 0 so the caller gets a full-size batch.

    When inactive, the wrapped method is called with its arguments unchanged.

    NOTE(review): presumably the batch is laid out as two equal halves of a
    classifier-free-guidance pair -- confirm with callers. With an odd
    ``len(x)`` the duplicated result has ``len(x) + 1`` rows.

    Returns:
        A decorator to apply to the method.
    """
    # Argument types that are sliced along their first axis with the batch.
    batched_types = (torch.Tensor, list, tuple, np.ndarray)

    def keep_second_half(value, start):
        # Non-batched values (scalars, strings, None, ...) pass through as-is.
        return value[start:] if isinstance(value, batched_types) else value

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, x, *args, **kwargs):
            bs = len(x)
            # Decide exactly once, before calling ``func``: the wrapped method
            # may mutate ``self.current_steps``, and re-checking afterwards
            # could disagree with the choice made for the inputs.
            skip = (
                bs >= 2
                and self.cfg_skip_ratio is not None
                and self.current_steps
                >= self.num_inference_steps * (1 - self.cfg_skip_ratio)
            )
            if not skip:
                return func(self, x, *args, **kwargs)

            bs_half = bs // 2
            new_args = [keep_second_half(arg, bs_half) for arg in args]
            new_kwargs = {
                key: keep_second_half(content, bs_half)
                for key, content in kwargs.items()
            }
            result = func(self, x[bs_half:], *new_args, **new_kwargs)

            # Duplicate the half-batch output to restore the batch dimension.
            return torch.cat([result, result], dim=0)

        return wrapper

    return decorator