dikdimon commited on
Commit
cb4d7e0
verified
1 Parent(s): c10aebf

Upload sd-webui-smea using SD-Hub

Browse files
.gitattributes CHANGED
@@ -144,3 +144,5 @@ DWPose/resources/jay_pose.jpg filter=lfs diff=lfs merge=lfs -text
144
  DWPose/resources/lalaland.gif filter=lfs diff=lfs merge=lfs -text
145
  sd-webui-hires-i2i/img/off.jpg filter=lfs diff=lfs merge=lfs -text
146
  sd-webui-hires-i2i/img/on.jpg filter=lfs diff=lfs merge=lfs -text
 
 
 
144
  DWPose/resources/lalaland.gif filter=lfs diff=lfs merge=lfs -text
145
  sd-webui-hires-i2i/img/off.jpg filter=lfs diff=lfs merge=lfs -text
146
  sd-webui-hires-i2i/img/on.jpg filter=lfs diff=lfs merge=lfs -text
147
+ sd-webui-smea/sample.jpg filter=lfs diff=lfs merge=lfs -text
148
+ sd-webui-smea/sample2.jpg filter=lfs diff=lfs merge=lfs -text
sd-webui-smea/README.md ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # sd-webui-smea
2
+ smea sampler experiments for a1111 webui
3
+ These sampler has nothing to do with NAI's sampler or Euler sampler, I'm just suck at naming them.
4
+ (smea here stands for "Shovel More Extra Artifacts")
5
+ originally created by [Koishi-Star](https://github.com/Koishi-Star/Euler-Smea-Dyn-Sampler) and [ananosleep](https://github.com/ananosleep/advanced_euler_sampler_extension)
6
+ TCD sampler from [dfl](https://github.com/dfl/comfyui-tcd-scheduler)
7
+
8
+ ![sample2](https://github.com/AG-w/sd-webui-smea/blob/main/sample2.jpg?raw=true)
9
+ ![sample](https://github.com/AG-w/sd-webui-smea/blob/main/sample.jpg?raw=true)
10
+
11
+ **RECOMMEND: Use: Smea mbs2 (\#) / Smea mds2 (\#) / h max (#) / Max(2b/3c/4b), they add details (or artifacts) more reliably**
12
+
13
+ Also check [Dynamic Thresholding](https://github.com/mcmonkeyprojects/sd-dynamic-thresholding), you can add more details
14
+
15
+ Euler Dy: og Euler Dy with DPM2 tweak, toggle on/off every step
16
+ Euler Smea: og Euler Smea Dy with DPM2 tweak, use smea sampling only, toggle on/off every step
17
+ Euler Smea Dy: og Euler Smea Dy with DPM2 tweak, loopping scale up > folded to 1/2 size > normal >... every step
18
+ Euler Smea dyn a: Euler Smea with DPM2 tweak (less sigma), toggle on/off (scale up) every step every step
19
+ Euler Smea dyn b: Euler Smea with DPM2 tweak (less sigma), loopping scale down > up > normal >... every step
20
+ Euler Smea dyn c: Euler Smea with DPM2 tweak (less sigma), toggle on/off (scale down) every step every step
21
+ Euler Smea md: Euler Smea with DPM2 tweak (less sigma), start with Smea mc then toggle Smea mb on/off every step, ended with Smea ma
22
+ all sampler above stopped smea / dy sampling at 1/3 total steps
23
+
24
+ Euler Smea Max: Euler Smea with adjusted cosine wave scaling
25
+ Euler Smea Max s: Euler Smea Max with smoothed latent in process
26
+ Euler Smea ma: Euler Smea with DPM2 tweak (less sigma), combine scaled up latent image with normal one
27
+ Euler Smea mb: Euler Smea with DPM2 tweak (less sigma), combine scaled up and scaled down latent image with normal one
28
+ Euler Smea mc: Euler Smea with DPM2 tweak (less sigma), combine scaled down latent image with normal one
29
+ Euler Smea mas: Euler Smea ma tweaked
30
+ Euler Smea mbs: Euler Smea mb tweaked
31
+ Euler Smea mcs: Euler Smea mc tweaked
32
+ Euler Smea mds: Euler Smea md tweaked
33
+ Euler Smea mbs2: Euler Smea mds with tweaked sigma
34
+ Euler Smea mds2: Euler Smea mds with tweaked sigma
35
+ Euler Smea mbs2 s: Euler Smea mbs2 with smoothed latent in process
36
+ Euler Smea mds2 s: Euler Smea mds2 with smoothed latent in process
37
+ Euler Smea mds2 max: Euler Smea mds2 with adjusted cosine wave
38
+ Euler Smea mds2 s max: Euler Smea mds2 s with adjusted cosine wave
39
+ all sampler above stopped smea sampling at 1/6 total steps
40
+
41
+ Euler Max: from ananosleep's repo
42
+ Euler h max (\#): Euler Max with adjusted cosine wave
43
+ Euler Max(#): Euler Max with adjusted cosine wave
44
+ Euler Dy koishi-star: og Euler Dy made by koishi-star
45
+ Euler Smea Dy koishi-star: og Euler Smea Dy made by koishi-star
46
+ TCD / TCD Euler a: from dfl's repo
47
+
48
+ ### Explanation:
49
+ The reason of many experiments is due to og sampler tends to blurred the background or overfry the image,
50
+ so I checked DPM2 sampler and experiment if it's worth to tweak it
51
+ What Smea sampling do is scaling latent image > denoise > scale it back to original size
52
+ What dy sampling do is shrinking latent image to 1/2 size > denoise > extend it to original size
53
+ since what they did is bascially just scaling latent image, I use smea sampling only
54
+ What all these samplers do is bascailly trying to combine different scaled latent image to denoise image to generate better details (artifacts)
sd-webui-smea/__pycache__/sd_webui_smea.cpython-310.pyc ADDED
Binary file (45.7 kB). View file
 
sd-webui-smea/sample.jpg ADDED

Git LFS Details

  • SHA256: 640b932885e490d67494529fa136c59e10b6f63d50616fd2530378799594ddf4
  • Pointer size: 132 Bytes
  • Size of remote file: 2.57 MB
sd-webui-smea/sample2.jpg ADDED

Git LFS Details

  • SHA256: 3b414d8876c254956fdb6ab9031f9e83b712fd238b6611fecea3fba2a9ddce08
  • Pointer size: 132 Bytes
  • Size of remote file: 1.04 MB
sd-webui-smea/scripts/__pycache__/sd-webui-smea.cpython-310.pyc ADDED
Binary file (46 kB). View file
 
sd-webui-smea/scripts/sd-webui-smea.py ADDED
@@ -0,0 +1,1672 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ import k_diffusion.sampling
4
+
5
+ from k_diffusion.sampling import to_d, BrownianTreeNoiseSampler
6
+ from tqdm.auto import trange
7
+ from modules import scripts
8
+ from modules import sd_samplers_kdiffusion, sd_samplers_common, sd_samplers
9
+ from modules.sd_samplers_kdiffusion import KDiffusionSampler
10
+
11
+ class _Rescaler:
12
+ def __init__(self, model, x, mode, **extra_args):
13
+ self.model = model
14
+ self.x = x
15
+ self.mode = mode
16
+ self.extra_args = extra_args
17
+ self.init_latent, self.mask, self.nmask = model.init_latent, model.mask, model.nmask
18
+
19
+ def __enter__(self):
20
+ if self.init_latent is not None:
21
+ self.model.init_latent = torch.nn.functional.interpolate(input=self.init_latent, size=self.x.shape[2:4], mode=self.mode)
22
+ if self.mask is not None:
23
+ self.model.mask = torch.nn.functional.interpolate(input=self.mask.unsqueeze(0), size=self.x.shape[2:4], mode=self.mode).squeeze(0)
24
+ if self.nmask is not None:
25
+ self.model.nmask = torch.nn.functional.interpolate(input=self.nmask.unsqueeze(0), size=self.x.shape[2:4], mode=self.mode).squeeze(0)
26
+ return self
27
+
28
+ def __exit__(self, type, value, traceback):
29
+ del self.model.init_latent, self.model.mask, self.model.nmask
30
+ self.model.init_latent, self.model.mask, self.model.nmask = self.init_latent, self.mask, self.nmask
31
+
32
+ class Smea(scripts.Script):
33
+
34
+ def title(self):
35
+ init() # <- 袩械褉械薪芯褋 褋褞写邪
36
+ return "Euler Smea Dy sampler"
37
+
38
+ def show(self, is_img2img):
39
+ return scripts.AlwaysVisible
40
+
41
+
42
+ def init():
43
+ for i in sd_samplers.all_samplers:
44
+ if "Euler Max" in i.name:
45
+ return
46
+
47
+ samplers_smea = [
48
+ ('Euler Max', sample_euler_max, ['k_euler'], {}),
49
+ ('Euler Max1b', sample_euler_max1b, ['k_euler'], {}),
50
+ ('Euler Max1c', sample_euler_max1c, ['k_euler'], {}),
51
+ ('Euler Max1d', sample_euler_max1d, ['k_euler'], {}),
52
+ ('Euler Max2', sample_euler_max2, ['k_euler'], {}),
53
+ ('Euler Max2b', sample_euler_max2b, ['k_euler'], {}),
54
+ ('Euler Max2c', sample_euler_max2c, ['k_euler'], {}),
55
+ ('Euler Max2d', sample_euler_max2d, ['k_euler'], {}),
56
+ ('Euler Max3', sample_euler_max3, ['k_euler'], {}),
57
+ ('Euler Max3b', sample_euler_max3b, ['k_euler'], {}),
58
+ ('Euler Max3c', sample_euler_max3c, ['k_euler'], {}),
59
+ ('Euler Max4', sample_euler_max4, ['k_euler'], {}),
60
+ ('Euler Max4b', sample_euler_max4b, ['k_euler'], {}),
61
+ ('Euler Max4c', sample_euler_max4c, ['k_euler'], {}),
62
+ ('Euler Max4d', sample_euler_max4d, ['k_euler'], {}),
63
+ ('Euler Max4e', sample_euler_max4e, ['k_euler'], {}),
64
+ ('Euler Max4f', sample_euler_max4f, ['k_euler'], {}),
65
+ ('Euler Dy', sample_euler_dy, ['k_euler'], {}),
66
+ ('Euler Smea', sample_euler_smea, ['k_euler'], {}),
67
+ ('Euler Smea Dy', sample_euler_smea_dy, ['k_euler'], {}),
68
+ ('Euler Smea Max', sample_euler_smea_max, ['k_euler'], {}),
69
+ ('Euler Smea Max s', sample_euler_smea_max_s, ['k_euler'], {}),
70
+ ('Euler Smea dyn a', sample_euler_smea_dyn_a, ['k_euler'], {}),
71
+ ('Euler Smea dyn b', sample_euler_smea_dyn_b, ['k_euler'], {}),
72
+ ('Euler Smea dyn c', sample_euler_smea_dyn_c, ['k_euler'], {}),
73
+ ('Euler Smea ma', sample_euler_smea_multi_a, ['k_euler'], {}),
74
+ ('Euler Smea mb', sample_euler_smea_multi_b, ['k_euler'], {}),
75
+ ('Euler Smea mc', sample_euler_smea_multi_c, ['k_euler'], {}),
76
+ ('Euler Smea md', sample_euler_smea_multi_d, ['k_euler'], {}),
77
+ ('Euler Smea mas', sample_euler_smea_multi_as, ['k_euler'], {}),
78
+ ('Euler Smea mbs', sample_euler_smea_multi_bs, ['k_euler'], {}),
79
+ ('Euler Smea mcs', sample_euler_smea_multi_cs, ['k_euler'], {}),
80
+ ('Euler Smea mds', sample_euler_smea_multi_ds, ['k_euler'], {}),
81
+ ('Euler Smea mbs2', sample_euler_smea_multi_bs2, ['k_euler'], {}),
82
+ ('Euler Smea mds2', sample_euler_smea_multi_ds2, ['k_euler'], {}),
83
+ ('Euler Smea mds2 max', sample_euler_smea_multi_ds2_m, ['k_euler'], {}),
84
+ ('Euler Smea mds2 s max', sample_euler_smea_multi_ds2_s_m, ['k_euler'], {}),
85
+ ('Euler Smea mbs2 s', sample_euler_smea_multi_bs2_s, ['k_euler'], {}),
86
+ ('Euler Smea mds2 s', sample_euler_smea_multi_ds2_s, ['k_euler'], {}),
87
+ ('Euler h max', sample_euler_h_m, ['k_euler'], {"brownian_noise": True}),
88
+ ('Euler h max b', sample_euler_h_m_b, ['k_euler'], {"brownian_noise": True}),
89
+ ('Euler h max c', sample_euler_h_m_c, ['k_euler'], {"brownian_noise": True}),
90
+ ('Euler h max d', sample_euler_h_m_d, ['k_euler'], {"brownian_noise": True}),
91
+ ('Euler h max e', sample_euler_h_m_e, ['k_euler'], {"brownian_noise": True}),
92
+ ('Euler h max f', sample_euler_h_m_f, ['k_euler'], {"brownian_noise": True}),
93
+ ('Euler h max g', sample_euler_h_m_g, ['k_euler'], {"brownian_noise": True}),
94
+ ('Euler h max b c', sample_euler_h_m_b_c, ['k_euler'], {"brownian_noise": True}),
95
+ ('Euler h max b c CFG++', sample_euler_h_m_b_c_pp, ['k_euler'], {"brownian_noise": True, "cfgpp": True}),
96
+ ('Euler Dy koishi-star', sample_euler_dy_og, ['k_euler'], {}),
97
+ ('Euler Smea Dy koishi-star', sample_euler_smea_dy_og, ['k_euler'], {}),
98
+ ('TCD Euler a', sample_tcd_euler_a, ['tcd_euler_a'], {}),
99
+ ('TCD', sample_tcd, ['tcd'], {}),
100
+ ]
101
+
102
+ samplers_data_smea = [
103
+ sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
104
+ for label, funcname, aliases, options in samplers_smea
105
+ if callable(funcname)
106
+ ]
107
+
108
+ sampler_exparams_smea = {
109
+ sample_euler_max: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
110
+ sample_euler_max1b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
111
+ sample_euler_max1c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
112
+ sample_euler_max1d: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
113
+ sample_euler_max2: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
114
+ sample_euler_max2b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
115
+ sample_euler_max2c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
116
+ sample_euler_max2d: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
117
+ sample_euler_max3: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
118
+ sample_euler_max3b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
119
+ sample_euler_max3c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
120
+ sample_euler_max4: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
121
+ sample_euler_max4b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
122
+ sample_euler_max4c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
123
+ sample_euler_max4d: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
124
+ sample_euler_max4e: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
125
+ sample_euler_max4f: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
126
+ sample_euler_dy: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
127
+ sample_euler_smea: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
128
+ sample_euler_smea_dy: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
129
+ sample_euler_smea_max: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
130
+ sample_euler_smea_max_s: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
131
+ sample_euler_smea_dyn_a: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
132
+ sample_euler_smea_dyn_b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
133
+ sample_euler_smea_dyn_c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
134
+ sample_euler_smea_multi_a: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
135
+ sample_euler_smea_multi_b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
136
+ sample_euler_smea_multi_c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
137
+ sample_euler_smea_multi_d: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
138
+ sample_euler_smea_multi_as: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
139
+ sample_euler_smea_multi_bs: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
140
+ sample_euler_smea_multi_cs: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
141
+ sample_euler_smea_multi_ds: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
142
+ sample_euler_smea_multi_bs2: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
143
+ sample_euler_smea_multi_ds2: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
144
+ sample_euler_smea_multi_ds2_m: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
145
+ sample_euler_smea_multi_ds2_s_m: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
146
+ sample_euler_smea_multi_bs2_s: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
147
+ sample_euler_smea_multi_ds2_s: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
148
+ sample_euler_h_m: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
149
+ sample_euler_h_m_b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
150
+ sample_euler_h_m_c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
151
+ sample_euler_h_m_d: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
152
+ sample_euler_h_m_e: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
153
+ sample_euler_h_m_f: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
154
+ sample_euler_h_m_g: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
155
+ sample_euler_h_m_b_c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
156
+ sample_euler_h_m_b_c_pp: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
157
+ sample_euler_dy_og: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
158
+ sample_euler_smea_dy_og: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
159
+ }
160
+ sd_samplers_kdiffusion.sampler_extra_params = {**sd_samplers_kdiffusion.sampler_extra_params, **sampler_exparams_smea}
161
+
162
+ samplers_map_smea = {x.name: x for x in samplers_data_smea}
163
+ sd_samplers_kdiffusion.k_diffusion_samplers_map = {**sd_samplers_kdiffusion.k_diffusion_samplers_map, **samplers_map_smea}
164
+
165
+ for i, item in enumerate(sd_samplers.all_samplers):
166
+ if "Euler" in item.name:
167
+ sd_samplers.all_samplers = sd_samplers.all_samplers[:i + 1] + [*samplers_data_smea] + sd_samplers.all_samplers[i + 1:]
168
+ break
169
+ sd_samplers.all_samplers_map = {x.name: x for x in sd_samplers.all_samplers}
170
+ sd_samplers.set_samplers()
171
+
172
+ return
173
+
174
+ def default_noise_sampler(x):
175
+ return lambda sigma, sigma_next: k_diffusion.sampling.torch.randn_like(x)
176
+
177
+ @torch.no_grad()
178
+ def dy_sampling_step(x, model, dt, sigma_hat, **extra_args):
179
+ original_shape = x.shape
180
+ batch_size, channels, m, n = original_shape[0], original_shape[1], original_shape[2] // 2, original_shape[3] // 2
181
+ extra_row = x.shape[2] % 2 == 1
182
+ extra_col = x.shape[3] % 2 == 1
183
+
184
+ if extra_row:
185
+ extra_row_content = x[:, :, -1:, :]
186
+ x = x[:, :, :-1, :]
187
+ if extra_col:
188
+ extra_col_content = x[:, :, :, -1:]
189
+ x = x[:, :, :, :-1]
190
+
191
+ a_list = x.unfold(2, 2, 2).unfold(3, 2, 2).contiguous().view(batch_size, channels, m * n, 2, 2)
192
+ c = a_list[:, :, :, 1, 1].view(batch_size, channels, m, n)
193
+
194
+ with _Rescaler(model, c, 'nearest-exact', **extra_args) as rescaler:
195
+ denoised = model(c, sigma_hat * c.new_ones([c.shape[0]]), **rescaler.extra_args)
196
+ d = to_d(c, sigma_hat, denoised)
197
+ c = c + d * dt
198
+
199
+ d_list = c.view(batch_size, channels, m * n, 1, 1)
200
+ a_list[:, :, :, 1, 1] = d_list[:, :, :, 0, 0]
201
+ x = a_list.view(batch_size, channels, m, n, 2, 2).permute(0, 1, 2, 4, 3, 5).reshape(batch_size, channels, 2 * m, 2 * n)
202
+
203
+ if extra_row or extra_col:
204
+ x_expanded = torch.zeros(original_shape, dtype=x.dtype, device=x.device)
205
+ x_expanded[:, :, :2 * m, :2 * n] = x
206
+ if extra_row:
207
+ x_expanded[:, :, -1:, :2 * n + 1] = extra_row_content
208
+ if extra_col:
209
+ x_expanded[:, :, :2 * m, -1:] = extra_col_content
210
+ if extra_row and extra_col:
211
+ x_expanded[:, :, -1:, -1:] = extra_col_content[:, :, -1:, :]
212
+ x = x_expanded
213
+
214
+ return x
215
+
216
+ @torch.no_grad()
217
+ def smea_sampling_step(x, model, dt, sigma_hat, **extra_args):
218
+ m, n = x.shape[2], x.shape[3]
219
+ x = torch.nn.functional.interpolate(input=x, size=None, scale_factor=(1.25, 1.25), mode='nearest-exact', align_corners=None, recompute_scale_factor=None)
220
+ with _Rescaler(model, x, 'nearest-exact', **extra_args) as rescaler:
221
+ denoised = model(x, sigma_hat * x.new_ones([x.shape[0]]), **rescaler.extra_args)
222
+ d = to_d(x, sigma_hat, denoised)
223
+ x = x + d * dt
224
+ x = torch.nn.functional.interpolate(input=x, size=(m,n), scale_factor=None, mode='nearest-exact', align_corners=None, recompute_scale_factor=None)
225
+ return x
226
+
227
+ @torch.no_grad()
228
+ def smea_sampling_step_denoised(x, model, sigma_hat, scale=1.25, smooth=False, **extra_args):
229
+ m, n = x.shape[2], x.shape[3]
230
+ filter = 'nearest-exact' if not smooth else 'bilinear'
231
+ x = torch.nn.functional.interpolate(input=x, scale_factor=(scale, scale), mode=filter)
232
+ with _Rescaler(model, x, filter, **extra_args) as rescaler:
233
+ denoised = model(x, sigma_hat * x.new_ones([x.shape[0]]), **rescaler.extra_args)
234
+ x = denoised
235
+ x = torch.nn.functional.interpolate(input=x, size=(m,n), mode='nearest-exact')
236
+ return x
237
+
238
+ @torch.no_grad()
239
+ def sample_euler_max(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
240
+ extra_args = {} if extra_args is None else extra_args
241
+ s_in = x.new_ones([x.shape[0]])
242
+ for i in trange(len(sigmas) - 1, disable=disable):
243
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
244
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
245
+ sigma_hat = sigmas[i] * (gamma + 1)
246
+ if gamma > 0:
247
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
248
+ denoised = model(x, sigma_hat * s_in, **extra_args)
249
+ d = to_d(x, sigma_hat, denoised)
250
+ if callback is not None:
251
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
252
+ dt = sigmas[i + 1] - sigma_hat
253
+ # Euler method
254
+ x = x + (math.cos(i + 1)/(i + 1) + 1) * d * dt
255
+ return x
256
+
257
+
258
+ @torch.no_grad()
259
+ def sample_euler_max1b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
260
+ extra_args = {} if extra_args is None else extra_args
261
+ s_in = x.new_ones([x.shape[0]])
262
+ for i in trange(len(sigmas) - 1, disable=disable):
263
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
264
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
265
+ sigma_hat = sigmas[i] * (gamma + 1)
266
+ if gamma > 0:
267
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
268
+ denoised = model(x, sigma_hat * s_in, **extra_args)
269
+ d = to_d(x, sigma_hat, denoised)
270
+ if callback is not None:
271
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
272
+ dt = sigmas[i + 1] - sigma_hat
273
+ # Euler method
274
+ x = x + (math.cos(1.05 * i + 1)/(1.1 * i + 1.5) + 1) * d * dt
275
+ return x
276
+
277
+ @torch.no_grad()
278
+ def sample_euler_max1c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
279
+ extra_args = {} if extra_args is None else extra_args
280
+ s_in = x.new_ones([x.shape[0]])
281
+ for i in trange(len(sigmas) - 1, disable=disable):
282
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
283
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
284
+ sigma_hat = sigmas[i] * (gamma + 1)
285
+ if gamma > 0:
286
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
287
+ denoised = model(x, sigma_hat * s_in, **extra_args)
288
+ d = to_d(x, sigma_hat, denoised)
289
+ if callback is not None:
290
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
291
+ dt = sigmas[i + 1] - sigma_hat
292
+ # Euler method
293
+ x = x + (math.cos(1.05 * i + 1.1)/(1.25 * i + 1.5) + 1) * d * dt
294
+ return x
295
+
296
+ @torch.no_grad()
297
+ def sample_euler_max1d(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
298
+ extra_args = {} if extra_args is None else extra_args
299
+ s_in = x.new_ones([x.shape[0]])
300
+ for i in trange(len(sigmas) - 1, disable=disable):
301
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
302
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
303
+ sigma_hat = sigmas[i] * (gamma + 1)
304
+ if gamma > 0:
305
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
306
+ denoised = model(x, sigma_hat * s_in, **extra_args)
307
+ d = to_d(x, sigma_hat, denoised)
308
+ if callback is not None:
309
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
310
+ dt = sigmas[i + 1] - sigma_hat
311
+ # Euler method
312
+ x = x + (math.cos(math.pi * 0.333 * i + 0.9)/(0.5 * i + 1.5) + 1) * d * dt
313
+ return x
314
+
315
+ @torch.no_grad()
316
+ def sample_euler_max2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
317
+ extra_args = {} if extra_args is None else extra_args
318
+ s_in = x.new_ones([x.shape[0]])
319
+ for i in trange(len(sigmas) - 1, disable=disable):
320
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
321
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
322
+ sigma_hat = sigmas[i] * (gamma + 1)
323
+ if gamma > 0:
324
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
325
+ denoised = model(x, sigma_hat * s_in, **extra_args)
326
+ d = to_d(x, sigma_hat, denoised)
327
+ if callback is not None:
328
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
329
+ dt = sigmas[i + 1] - sigma_hat
330
+ # Euler method
331
+ x = x + (math.cos(math.pi * 0.333 * i - 0.1)/(0.5 * i + 1.5) + 1) * d * dt
332
+ return x
333
+
334
+ @torch.no_grad()
335
+ def sample_euler_max2b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
336
+ extra_args = {} if extra_args is None else extra_args
337
+ s_in = x.new_ones([x.shape[0]])
338
+ for i in trange(len(sigmas) - 1, disable=disable):
339
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
340
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
341
+ sigma_hat = sigmas[i] * (gamma + 1)
342
+ if gamma > 0:
343
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
344
+ denoised = model(x, sigma_hat * s_in, **extra_args)
345
+ d = to_d(x, sigma_hat, denoised)
346
+ if callback is not None:
347
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
348
+ dt = sigmas[i + 1] - sigma_hat
349
+ # Euler method
350
+ x = x + (math.cos(math.pi * 0.5 * i - 0.0)/(0.5 * i + 1.5) + 1) * d * dt
351
+ return x
352
+
353
+ @torch.no_grad()
354
+ def sample_euler_max2c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
355
+ extra_args = {} if extra_args is None else extra_args
356
+ s_in = x.new_ones([x.shape[0]])
357
+ for i in trange(len(sigmas) - 1, disable=disable):
358
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
359
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
360
+ sigma_hat = sigmas[i] * (gamma + 1)
361
+ if gamma > 0:
362
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
363
+ denoised = model(x, sigma_hat * s_in, **extra_args)
364
+ d = to_d(x, sigma_hat, denoised)
365
+ if callback is not None:
366
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
367
+ dt = sigmas[i + 1] - sigma_hat
368
+ # Euler method
369
+ x = x + (math.cos(math.pi * 0.5 * i)/(i + 2) + 1) * d * dt
370
+ return x
371
+
372
+ @torch.no_grad()
373
+ def sample_euler_max2d(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
374
+ extra_args = {} if extra_args is None else extra_args
375
+ s_in = x.new_ones([x.shape[0]])
376
+ for i in trange(len(sigmas) - 1, disable=disable):
377
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
378
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
379
+ sigma_hat = sigmas[i] * (gamma + 1)
380
+ if gamma > 0:
381
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
382
+ denoised = model(x, sigma_hat * s_in, **extra_args)
383
+ d = to_d(x, sigma_hat, denoised)
384
+ if callback is not None:
385
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
386
+ dt = sigmas[i + 1] - sigma_hat
387
+ # Euler method
388
+ x = x + (math.cos(math.pi * 0.5 * i)/(0.75 * i + 1.75) + 1) * d * dt
389
+ return x
390
+
391
+ @torch.no_grad()
392
+ def sample_euler_max3b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
393
+ extra_args = {} if extra_args is None else extra_args
394
+ s_in = x.new_ones([x.shape[0]])
395
+ for i in trange(len(sigmas) - 1, disable=disable):
396
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
397
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
398
+ sigma_hat = sigmas[i] * (gamma + 1)
399
+ if gamma > 0:
400
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
401
+ denoised = model(x, sigma_hat * s_in, **extra_args)
402
+ d = to_d(x, sigma_hat, denoised)
403
+ if callback is not None:
404
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
405
+ dt = sigmas[i + 1] - sigma_hat
406
+ # Euler method
407
+ x = x + (math.cos(2 * i + 0.5)/(2 * i + 1.5) + 1) * d * dt
408
+ return x
409
+
410
+ @torch.no_grad()
411
+ def sample_euler_max3c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
412
+ extra_args = {} if extra_args is None else extra_args
413
+ s_in = x.new_ones([x.shape[0]])
414
+ for i in trange(len(sigmas) - 1, disable=disable):
415
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
416
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
417
+ sigma_hat = sigmas[i] * (gamma + 1)
418
+ if gamma > 0:
419
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
420
+ denoised = model(x, sigma_hat * s_in, **extra_args)
421
+ d = to_d(x, sigma_hat, denoised)
422
+ if callback is not None:
423
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
424
+ dt = sigmas[i + 1] - sigma_hat
425
+ # Euler method
426
+ x = x + (math.cos(2 * i + 0.5)/(1.5 * i + 2.7) + 1) * d * dt
427
+ return x
428
+
429
+ @torch.no_grad()
430
+ def sample_euler_max3(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
431
+ extra_args = {} if extra_args is None else extra_args
432
+ s_in = x.new_ones([x.shape[0]])
433
+ for i in trange(len(sigmas) - 1, disable=disable):
434
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
435
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
436
+ sigma_hat = sigmas[i] * (gamma + 1)
437
+ if gamma > 0:
438
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
439
+ denoised = model(x, sigma_hat * s_in, **extra_args)
440
+ d = to_d(x, sigma_hat, denoised)
441
+ if callback is not None:
442
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
443
+ dt = sigmas[i + 1] - sigma_hat
444
+ # Euler method
445
+ x = x + (math.cos(2 * i + 1)/(2 * i + 1) + 1) * d * dt
446
+ return x
447
+
448
+ @torch.no_grad()
449
+ def sample_euler_max4b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
450
+ extra_args = {} if extra_args is None else extra_args
451
+ s_in = x.new_ones([x.shape[0]])
452
+ for i in trange(len(sigmas) - 1, disable=disable):
453
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
454
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
455
+ sigma_hat = sigmas[i] * (gamma + 1)
456
+ if gamma > 0:
457
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
458
+ denoised = model(x, sigma_hat * s_in, **extra_args)
459
+ d = to_d(x, sigma_hat, denoised)
460
+ if callback is not None:
461
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
462
+ dt = sigmas[i + 1] - sigma_hat
463
+ # Euler method
464
+ x = x + (math.cos(math.pi * i - 0.1)/(2 * i + 2) + 1) * d * dt
465
+ return x
466
+
467
+ @torch.no_grad()
468
+ def sample_euler_max4c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
469
+ extra_args = {} if extra_args is None else extra_args
470
+ s_in = x.new_ones([x.shape[0]])
471
+ for i in trange(len(sigmas) - 1, disable=disable):
472
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
473
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
474
+ sigma_hat = sigmas[i] * (gamma + 1)
475
+ if gamma > 0:
476
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
477
+ denoised = model(x, sigma_hat * s_in, **extra_args)
478
+ d = to_d(x, sigma_hat, denoised)
479
+ if callback is not None:
480
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
481
+ dt = sigmas[i + 1] - sigma_hat
482
+ # Euler method
483
+ x = x + (math.cos(math.pi * i - 0.1)/(2 * i + 1.5) + 1) * d * dt
484
+ return x
485
+
486
+ @torch.no_grad()
487
+ def sample_euler_max4d(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
488
+ extra_args = {} if extra_args is None else extra_args
489
+ s_in = x.new_ones([x.shape[0]])
490
+ for i in trange(len(sigmas) - 1, disable=disable):
491
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
492
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
493
+ sigma_hat = sigmas[i] * (gamma + 1)
494
+ if gamma > 0:
495
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
496
+ denoised = model(x, sigma_hat * s_in, **extra_args)
497
+ d = to_d(x, sigma_hat, denoised)
498
+ if callback is not None:
499
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
500
+ dt = sigmas[i + 1] - sigma_hat
501
+ # Euler method
502
+ x = x + (math.cos(math.pi * i - 0.1)/(i + 1.5) + 1) * d * dt
503
+ return x
504
+
505
+ @torch.no_grad()
506
+ def sample_euler_max4e(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
507
+ extra_args = {} if extra_args is None else extra_args
508
+ s_in = x.new_ones([x.shape[0]])
509
+ for i in trange(len(sigmas) - 1, disable=disable):
510
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
511
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
512
+ sigma_hat = sigmas[i] * (gamma + 1)
513
+ if gamma > 0:
514
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
515
+ denoised = model(x, sigma_hat * s_in, **extra_args)
516
+ d = to_d(x, sigma_hat, denoised)
517
+ if callback is not None:
518
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
519
+ dt = sigmas[i + 1] - sigma_hat
520
+ # Euler method
521
+ x = x + (math.cos(math.pi * i - 0.1)/(i + 1) + 1) * d * dt
522
+ return x
523
+
524
+ @torch.no_grad()
525
+ def sample_euler_max4f(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
526
+ extra_args = {} if extra_args is None else extra_args
527
+ s_in = x.new_ones([x.shape[0]])
528
+ for i in trange(len(sigmas) - 1, disable=disable):
529
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
530
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
531
+ sigma_hat = sigmas[i] * (gamma + 1)
532
+ if gamma > 0:
533
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
534
+ denoised = model(x, sigma_hat * s_in, **extra_args)
535
+ d = to_d(x, sigma_hat, denoised)
536
+ if callback is not None:
537
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
538
+ dt = sigmas[i + 1] - sigma_hat
539
+ # Euler method
540
+ x = x + (math.cos(math.pi * i - 0.1)/(i + 2) + 1) * d * dt
541
+ return x
542
+
543
+ @torch.no_grad()
544
+ def sample_euler_max4(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
545
+ extra_args = {} if extra_args is None else extra_args
546
+ s_in = x.new_ones([x.shape[0]])
547
+ for i in trange(len(sigmas) - 1, disable=disable):
548
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
549
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
550
+ sigma_hat = sigmas[i] * (gamma + 1)
551
+ if gamma > 0:
552
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
553
+ denoised = model(x, sigma_hat * s_in, **extra_args)
554
+ d = to_d(x, sigma_hat, denoised)
555
+ if callback is not None:
556
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
557
+ dt = sigmas[i + 1] - sigma_hat
558
+ # Euler method
559
+ x = x + (math.cos(math.pi * i - 0.1)/(math.pi * 0.5 * i + math.pi * 0.5) + 1) * d * dt
560
+ return x
561
+
562
+ @torch.no_grad()
563
+ def sample_euler_dy(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
564
+ extra_args = {} if extra_args is None else extra_args
565
+ s_in = x.new_ones([x.shape[0]])
566
+ for i in trange(len(sigmas) - 1, disable=disable):
567
+ # print(i)
568
+ # i绗竴姝ヤ负0
569
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
570
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
571
+ sigma_hat = sigmas[i] * (gamma + 1)
572
+ # print(sigma_hat)
573
+ dt = sigmas[i + 1] - sigma_hat
574
+ if gamma > 0:
575
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
576
+ denoised = model(x, sigma_hat * s_in, **extra_args)
577
+ d = to_d(x, sigma_hat, denoised)
578
+ if callback is not None:
579
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
580
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 2 and i % 2 == 0:
581
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
582
+ dt_1 = sigma_mid - sigmas[i]
583
+ dt_2 = sigmas[i + 1] - sigmas[i]
584
+ x_2 = x + d * dt_1
585
+ x_temp = dy_sampling_step(x_2, model, dt_2, sigma_mid, **extra_args)
586
+ x = x_temp - d * dt_1
587
+ # Euler method
588
+ x = x + d * dt
589
+ return x
590
+
591
+ @torch.no_grad()
592
+ def sample_euler_smea_dyn_a(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
593
+ extra_args = {} if extra_args is None else extra_args
594
+ s_in = x.new_ones([x.shape[0]])
595
+ for i in trange(len(sigmas) - 1, disable=disable):
596
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
597
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
598
+ sigma_hat = sigmas[i] * (gamma + 1)
599
+ if gamma > 0:
600
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
601
+ denoised = model(x, sigma_hat * s_in, **extra_args)
602
+ d = to_d(x, sigma_hat, denoised)
603
+ if callback is not None:
604
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
605
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 2:
606
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
607
+ dt_1 = sigma_mid - sigma_hat
608
+ dt_2 = sigmas[i + 1] - sigma_hat
609
+ x_2 = x + d * dt_1
610
+ #scale = (sigma_mid / sigmas[0]) * 0.25
611
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2 * 0.15
612
+ #scale = scale.item()
613
+ if i % 2 == 0:
614
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale, **extra_args)
615
+ #denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + sigma_mid.item() * 0.01, **extra_args)
616
+ else:
617
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
618
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
619
+ x = x + d_2 * dt_2
620
+ else:
621
+ dt = sigmas[i + 1] - sigma_hat
622
+ # Euler method
623
+ x = x + d * dt
624
+ return x
625
+
626
+ @torch.no_grad()
627
+ def sample_euler_smea_dyn_b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
628
+ extra_args = {} if extra_args is None else extra_args
629
+ s_in = x.new_ones([x.shape[0]])
630
+ for i in trange(len(sigmas) - 1, disable=disable):
631
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
632
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
633
+ sigma_hat = sigmas[i] * (gamma + 1)
634
+ if gamma > 0:
635
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
636
+ denoised = model(x, sigma_hat * s_in, **extra_args)
637
+ d = to_d(x, sigma_hat, denoised)
638
+ if callback is not None:
639
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
640
+ if sigmas[i + 1] > 0 and (i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 3 or i < 3):
641
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
642
+ dt_1 = sigma_mid - sigma_hat
643
+ dt_2 = sigmas[i + 1] - sigma_hat
644
+ x_2 = x + d * dt_1
645
+ #scale = (sigma_mid / sigmas[0]) * 0.25
646
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2 * 0.2
647
+ #scale = scale.item()
648
+ if i % 4 == 0:
649
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale, **extra_args)
650
+ #denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - sigma_mid.item() * 0.01, **extra_args)
651
+ elif i % 4 == 2:
652
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale, **extra_args)
653
+ #denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + sigma_mid.item() * 0.01, **extra_args)
654
+ else:
655
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
656
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
657
+ x = x + d_2 * dt_2
658
+ else:
659
+ dt = sigmas[i + 1] - sigma_hat
660
+ # Euler method
661
+ x = x + d * dt
662
+ return x
663
+
664
+ @torch.no_grad()
665
+ def sample_euler_smea_dyn_c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
666
+ extra_args = {} if extra_args is None else extra_args
667
+ s_in = x.new_ones([x.shape[0]])
668
+ for i in trange(len(sigmas) - 1, disable=disable):
669
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
670
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
671
+ sigma_hat = sigmas[i] * (gamma + 1)
672
+ if gamma > 0:
673
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
674
+ denoised = model(x, sigma_hat * s_in, **extra_args)
675
+ d = to_d(x, sigma_hat, denoised)
676
+ if callback is not None:
677
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
678
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 2:
679
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
680
+ dt_1 = sigma_mid - sigma_hat
681
+ dt_2 = sigmas[i + 1] - sigma_hat
682
+ x_2 = x + d * dt_1
683
+ #scale = (sigma_mid / sigmas[0]) * 0.25
684
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2 * 0.25
685
+ #scale = scale.item()
686
+ if i % 2 == 0:
687
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale, **extra_args)
688
+ #denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + sigma_mid.item() * 0.01, **extra_args)
689
+ else:
690
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
691
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
692
+ x = x + d_2 * dt_2
693
+ else:
694
+ dt = sigmas[i + 1] - sigma_hat
695
+ # Euler method
696
+ x = x + d * dt
697
+ return x
698
+
699
+ @torch.no_grad()
700
+ def sample_euler_smea(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
701
+ extra_args = {} if extra_args is None else extra_args
702
+ s_in = x.new_ones([x.shape[0]])
703
+ for i in trange(len(sigmas) - 1, disable=disable):
704
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
705
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
706
+ sigma_hat = sigmas[i] * (gamma + 1)
707
+ if gamma > 0:
708
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
709
+ denoised = model(x, sigma_hat * s_in, **extra_args)
710
+ d = to_d(x, sigma_hat, denoised)
711
+ if callback is not None:
712
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
713
+ dt = sigmas[i + 1] - sigma_hat
714
+ # Euler method
715
+ x = x + d * dt
716
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 2 and i % 2 == 0:
717
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
718
+ dt_1 = sigma_mid - sigmas[i]
719
+ dt_2 = sigmas[i + 1] - sigmas[i]
720
+ #print(dt_1, "#", dt_2, "#", dt_3, "#", dt_4)
721
+ x_2 = x + d * dt_1
722
+ x_temp = smea_sampling_step(x, model, dt_2, sigma_mid, **extra_args)
723
+ x = x_temp - d * dt_1
724
+ return x
725
+
726
+ @torch.no_grad()
727
+ def sample_euler_smea_dy(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
728
+ extra_args = {} if extra_args is None else extra_args
729
+ s_in = x.new_ones([x.shape[0]])
730
+ for i in trange(len(sigmas) - 1, disable=disable):
731
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
732
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
733
+ sigma_hat = sigmas[i] * (gamma + 1)
734
+ if gamma > 0:
735
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
736
+ denoised = model(x, sigma_hat * s_in, **extra_args)
737
+ d = to_d(x, sigma_hat, denoised)
738
+ if callback is not None:
739
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
740
+ dt = sigmas[i + 1] - sigma_hat
741
+ # Euler method
742
+ x = x + d * dt
743
+ if sigmas[i + 1] > 0 and (i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 2 or i < 3) and i % 3 != 2:
744
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
745
+ dt_1 = sigma_mid - sigmas[i]
746
+ dt_2 = sigmas[i + 1] - sigmas[i]
747
+ #print(dt_1, "#", dt_2, "#", dt_3, "#", dt_4)
748
+ x_2 = x + d * dt_1
749
+ if i % 3 == 1:
750
+ x_temp = dy_sampling_step(x, model, dt_2, sigma_mid, **extra_args)
751
+ elif i % 3 == 0:
752
+ x_temp = smea_sampling_step(x, model, dt_2, sigma_mid, **extra_args)
753
+ x = x_temp - d * dt_1
754
+ return x
755
+
756
+ @torch.no_grad()
757
+ def sample_euler_smea_multi_d(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
758
+ extra_args = {} if extra_args is None else extra_args
759
+ s_in = x.new_ones([x.shape[0]])
760
+ for i in trange(len(sigmas) - 1, disable=disable):
761
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
762
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
763
+ sigma_hat = sigmas[i] * (gamma + 1)
764
+ if gamma > 0:
765
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
766
+ denoised = model(x, sigma_hat * s_in, **extra_args)
767
+ d = to_d(x, sigma_hat, denoised)
768
+ if callback is not None:
769
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
770
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.334 + 2 and i % 2 == 0:
771
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
772
+ dt_1 = sigma_mid - sigma_hat
773
+ dt_2 = sigmas[i + 1] - sigma_hat
774
+ x_2 = x + d * dt_1
775
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
776
+ if i == 0:
777
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale * 0.15, **extra_args)
778
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
779
+ denoised_2 = (denoised_2a + denoised_2c) / 2
780
+ elif i < len(sigmas) * 0.334:
781
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale * 0.25, **extra_args)
782
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale * 0.15, **extra_args)
783
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
784
+ denoised_2 = (denoised_2a + denoised_2b + denoised_2c) / 3
785
+ else:
786
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale * 0.03, True, **extra_args)
787
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
788
+ denoised_2 = (denoised_2b + denoised_2c) / 2
789
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
790
+ x = x + d_2 * dt_2
791
+ else:
792
+ dt = sigmas[i + 1] - sigma_hat
793
+ # Euler method
794
+ x = x + d * dt
795
+ return x
796
+
797
+ @torch.no_grad()
798
+ def sample_euler_smea_multi_b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
799
+ extra_args = {} if extra_args is None else extra_args
800
+ s_in = x.new_ones([x.shape[0]])
801
+ for i in trange(len(sigmas) - 1, disable=disable):
802
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
803
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
804
+ sigma_hat = sigmas[i] * (gamma + 1)
805
+ if gamma > 0:
806
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
807
+ denoised = model(x, sigma_hat * s_in, **extra_args)
808
+ d = to_d(x, sigma_hat, denoised)
809
+ if callback is not None:
810
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
811
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
812
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
813
+ dt_1 = sigma_mid - sigma_hat
814
+ dt_2 = sigmas[i + 1] - sigma_hat
815
+ x_2 = x + d * dt_1
816
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
817
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale * 0.25, **extra_args)
818
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale * 0.15, **extra_args)
819
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
820
+ denoised_2 = (denoised_2a + denoised_2b + denoised_2c) / 3
821
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
822
+ x = x + d_2 * dt_2
823
+ else:
824
+ dt = sigmas[i + 1] - sigma_hat
825
+ # Euler method
826
+ x = x + d * dt
827
+ return x
828
+
829
+ @torch.no_grad()
830
+ def sample_euler_smea_multi_c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
831
+ extra_args = {} if extra_args is None else extra_args
832
+ s_in = x.new_ones([x.shape[0]])
833
+ for i in trange(len(sigmas) - 1, disable=disable):
834
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
835
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
836
+ sigma_hat = sigmas[i] * (gamma + 1)
837
+ if gamma > 0:
838
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
839
+ denoised = model(x, sigma_hat * s_in, **extra_args)
840
+ d = to_d(x, sigma_hat, denoised)
841
+ if callback is not None:
842
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
843
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
844
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
845
+ dt_1 = sigma_mid - sigma_hat
846
+ dt_2 = sigmas[i + 1] - sigma_hat
847
+ x_2 = x + d * dt_1
848
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
849
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale * 0.25, **extra_args)
850
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
851
+ denoised_2 = (denoised_2a + denoised_2c) / 2
852
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
853
+ x = x + d_2 * dt_2
854
+ else:
855
+ dt = sigmas[i + 1] - sigma_hat
856
+ # Euler method
857
+ x = x + d * dt
858
+ return x
859
+
860
+ @torch.no_grad()
861
+ def sample_euler_smea_multi_a(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
862
+ extra_args = {} if extra_args is None else extra_args
863
+ s_in = x.new_ones([x.shape[0]])
864
+ for i in trange(len(sigmas) - 1, disable=disable):
865
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
866
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
867
+ sigma_hat = sigmas[i] * (gamma + 1)
868
+ if gamma > 0:
869
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
870
+ denoised = model(x, sigma_hat * s_in, **extra_args)
871
+ d = to_d(x, sigma_hat, denoised)
872
+ if callback is not None:
873
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
874
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
875
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
876
+ dt_1 = sigma_mid - sigma_hat
877
+ dt_2 = sigmas[i + 1] - sigma_hat
878
+ x_2 = x + d * dt_1
879
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
880
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale * 0.15, **extra_args)
881
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
882
+ denoised_2 = (denoised_2b + denoised_2c) / 2
883
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
884
+ x = x + d_2 * dt_2
885
+ else:
886
+ dt = sigmas[i + 1] - sigma_hat
887
+ # Euler method
888
+ x = x + d * dt
889
+ return x
890
+
891
+
892
+ @torch.no_grad()
893
+ def sample_euler_smea_multi_ds(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
894
+ extra_args = {} if extra_args is None else extra_args
895
+ s_in = x.new_ones([x.shape[0]])
896
+ for i in trange(len(sigmas) - 1, disable=disable):
897
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
898
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
899
+ sigma_hat = sigmas[i] * (gamma + 1)
900
+ if gamma > 0:
901
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
902
+ denoised = model(x, sigma_hat * s_in, **extra_args)
903
+ d = to_d(x, sigma_hat, denoised)
904
+ if callback is not None:
905
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
906
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167 + 1: # and i % 2 == 0:
907
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
908
+ dt_1 = sigma_mid - sigma_hat
909
+ dt_2 = sigmas[i + 1] - sigma_hat
910
+ x_2 = x + d * dt_1
911
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
912
+ if i == 0:
913
+ sa = 1 - scale * 0.15
914
+ sb = 1 + scale * 0.09
915
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, sa, **extra_args)
916
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, sb, **extra_args)
917
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.625 + denoised_2b * (sb ** 2) * 0.375) / (0.97**2)
918
+ elif i < len(sigmas) * 0.167:
919
+ sa = 1 - scale * 0.25
920
+ sb = 1 + scale * 0.15
921
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, sa, **extra_args)
922
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, sb , **extra_args)
923
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.625 + denoised_2b * (sb ** 2) * 0.375) / (0.95**2)
924
+ else:
925
+ sb = 1 + scale * 0.06
926
+ sc = 1 - scale * 0.1
927
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, sb, True, **extra_args)
928
+ denoised_2c = smea_sampling_step_denoised(x_2, model, sigma_mid, sc, **extra_args)
929
+ denoised_2 = (denoised_2b * (sb ** 2) * 0.375 + denoised_2c * (sc ** 2) * 0.625) / (0.98**2)
930
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
931
+ x = x + d_2 * dt_2
932
+ else:
933
+ dt = sigmas[i + 1] - sigma_hat
934
+ # Euler method
935
+ x = x + d * dt
936
+ return x
937
+
938
+ @torch.no_grad()
939
+ def sample_euler_smea_multi_ds2_s(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
940
+ sample = sample_euler_smea_multi_ds2(model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise, smooth=True)
941
+ return sample
942
+
943
+ @torch.no_grad()
944
+ def sample_euler_smea_multi_ds2_s_m(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
945
+ sample = sample_euler_smea_multi_ds2_m(model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise, smooth=True)
946
+ return sample
947
+
948
+ @torch.no_grad()
949
+ def sample_euler_smea_multi_ds2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., smooth=False):
950
+ extra_args = {} if extra_args is None else extra_args
951
+ s_in = x.new_ones([x.shape[0]])
952
+ for i in trange(len(sigmas) - 1, disable=disable):
953
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
954
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
955
+ sigma_hat = sigmas[i] * (gamma + 1)
956
+ if gamma > 0:
957
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
958
+ denoised = model(x, sigma_hat * s_in, **extra_args)
959
+ d = to_d(x, sigma_hat, denoised)
960
+ if callback is not None:
961
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
962
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167 + 1: # and i % 2 == 0:
963
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
964
+ dt_1 = sigma_mid - sigma_hat
965
+ dt_2 = sigmas[i + 1] - sigma_hat
966
+ x_2 = x + d * dt_1
967
+ scale = (sigmas[i] / sigmas[0]) ** 2
968
+ scale = scale.item()
969
+ if i == 0:
970
+ sa = 1 - scale * 0.15
971
+ sb = 1 + scale * 0.09
972
+ sigA = sigma_mid / (sa ** 2)
973
+ sigB = sigma_mid / (sb ** 2)
974
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
975
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
976
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.5 * sb ** 2 + denoised_2b * (sb ** 2) * 0.5 * sa ** 2) #/ (0.97**2) # 1 - (sa * sb ) / 2 + 1
977
+ d_2 = to_d(x_2, sigA * 0.5 * sb ** 2 + sigB * 0.5 * sa ** 2, denoised_2)
978
+ elif i < len(sigmas) * 0.167:
979
+ sa = 1 - scale * 0.25
980
+ sb = 1 + scale * 0.15
981
+ sigA = sigma_mid / (sa ** 2)
982
+ sigB = sigma_mid / (sb ** 2)
983
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
984
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
985
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.5 * sb ** 2 + denoised_2b * (sb ** 2) * 0.5 * sa ** 2) #/ (0.95**2)
986
+ d_2 = to_d(x_2, sigA * 0.5 * sb ** 2 + sigB * 0.5 * sa ** 2, denoised_2)
987
+ else:
988
+ sb = 1 + scale * 0.06
989
+ sc = 1 - scale * 0.1
990
+ sigB = sigma_mid / (sb ** 2)
991
+ sigC = sigma_mid / (sc ** 2)
992
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
993
+ denoised_2c = smea_sampling_step_denoised(x_2, model, sigC, sc, smooth, **extra_args)
994
+ denoised_2 = (denoised_2b * (sb ** 2) * 0.5 * sc ** 2 + denoised_2c * (sc ** 2) * 0.5 * sb ** 2) #/ (0.98**2)
995
+ d_2 = to_d(x_2, sigB * 0.5 * sc ** 2 + sigC * 0.5 * sb ** 2, denoised_2)
996
+ x = x + d_2 * dt_2
997
+ else:
998
+ dt = sigmas[i + 1] - sigma_hat
999
+ # Euler method
1000
+ x = x + d * dt
1001
+ return x
1002
+
1003
+ @torch.no_grad()
1004
+ def sample_euler_smea_multi_ds2_m(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., smooth=False):
1005
+ extra_args = {} if extra_args is None else extra_args
1006
+ s_in = x.new_ones([x.shape[0]])
1007
+ for i in trange(len(sigmas) - 1, disable=disable):
1008
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1009
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
1010
+ sigma_hat = sigmas[i] * (gamma + 1)
1011
+ if gamma > 0:
1012
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1013
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1014
+ d = to_d(x, sigma_hat, denoised)
1015
+ if callback is not None:
1016
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1017
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167 + 1: # and i % 2 == 0:
1018
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1019
+ dt_1 = sigma_mid - sigma_hat
1020
+ dt_2 = sigmas[i + 1] - sigma_hat
1021
+ x_2 = x + d * dt_1
1022
+ scale = (sigmas[i] / sigmas[0]) ** 2
1023
+ #scale = dt_1 ** 2 * 0.01
1024
+ scale = scale.item()
1025
+ if i == 0:
1026
+ sa = 1 - scale * 0.15 #15
1027
+ sb = 1 + scale * 0.09 #09
1028
+ sigA = sigma_mid / (sa ** 2)
1029
+ sigB = sigma_mid / (sb ** 2)
1030
+ #delta = sa * sb
1031
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
1032
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
1033
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.5 * sb ** 2 + denoised_2b * (sb ** 2) * 0.5 * sa ** 2) #/ (0.97**2) # 1 - (sa * sb ) / 2 + 1
1034
+ d_2 = to_d(x_2, sigA * 0.5 * sb ** 2 + sigB * 0.5 * sa ** 2, denoised_2)
1035
+ elif i < len(sigmas) * 0.167:
1036
+ sa = 1 - scale * 0.25 #25
1037
+ sb = 1 + scale * 0.15 #15
1038
+ sigA = sigma_mid / (sa ** 2)
1039
+ sigB = sigma_mid / (sb ** 2)
1040
+ #delta = sa * sb
1041
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
1042
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
1043
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.5 * sb ** 2 + denoised_2b * (sb ** 2) * 0.5 * sa ** 2) #/ (0.95**2)
1044
+ d_2 = to_d(x_2, sigA * 0.5 * sb ** 2 + sigB * 0.5 * sa ** 2, denoised_2)
1045
+ else:
1046
+ sb = 1 + scale * 0.06
1047
+ sc = 1 - scale * 0.1
1048
+ sigB = sigma_mid / (sb ** 2)
1049
+ sigC = sigma_mid / (sc ** 2)
1050
+ #delta = sb * sc
1051
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
1052
+ denoised_2c = smea_sampling_step_denoised(x_2, model, sigC, sc, smooth, **extra_args)
1053
+ denoised_2 = (denoised_2b * (sb ** 2) * 0.5 * sc ** 2+ denoised_2c * (sc ** 2) * 0.5 * sb ** 2) #/ (0.98**2)
1054
+ d_2 = to_d(x_2, sigB * 0.5 * sc ** 2 + sigC * 0.5 * sb ** 2, denoised_2)
1055
+ x = x + (math.cos(1.05 * i + 1.1)/(1.25 * i + 1.5) + 1) * d_2 * dt_2
1056
+ else:
1057
+ dt = sigmas[i + 1] - sigma_hat
1058
+ # Euler method
1059
+ x = x + (math.cos(1.05 * i + 1.1)/(1.25 * i + 1.5) + 1) * d * dt
1060
+ return x
1061
+
1062
+ @torch.no_grad()
1063
+ def sample_euler_h_m(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1064
+ extra_args = {} if extra_args is None else extra_args
1065
+ s_in = x.new_ones([x.shape[0]])
1066
+ for i in trange(len(sigmas) - 1, disable=disable):
1067
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1068
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1069
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1070
+ gamma = min((2 ** 0.5 - 1) - wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1071
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler == None else noise_sampler
1072
+ sigma_hat = sigmas[i] * (gamma + 1)
1073
+ if gamma > 0:
1074
+ x = x - eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1075
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1076
+ d = to_d(x, sigma_hat, denoised)
1077
+ dt = sigmas[i + 1] - sigma_hat
1078
+ if callback is not None:
1079
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1080
+ if sigmas[i + 1] > 0:
1081
+ x_2 = x + d * dt
1082
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1083
+ d_prime = d * 0.5 + d_2 * 0.5
1084
+ x = x + d_prime * dt
1085
+ else:
1086
+ # Euler method
1087
+ x = x + d * dt
1088
+ return x
1089
+
1090
+ @torch.no_grad()
1091
+ def sample_euler_h_m_b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1092
+ extra_args = {} if extra_args is None else extra_args
1093
+ s_in = x.new_ones([x.shape[0]])
1094
+ for i in trange(len(sigmas) - 1, disable=disable):
1095
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1096
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1097
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1098
+ gamma = min(wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1099
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1100
+ sigma_hat = sigmas[i] * (gamma + 1)
1101
+ if gamma > 0:
1102
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1103
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1104
+ d = to_d(x, sigma_hat, denoised)
1105
+ dt = sigmas[i + 1] - sigma_hat
1106
+ if callback is not None:
1107
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1108
+ if sigmas[i + 1] > 0:
1109
+ x_2 = x + d * dt
1110
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1111
+ d_prime = d * 0.5 + d_2 * 0.5
1112
+ x = x + d_prime * dt
1113
+ else:
1114
+ # Euler method
1115
+ x = x + d * dt
1116
+ return x
1117
+
1118
+ @torch.no_grad()
1119
+ def sample_euler_h_m_c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1120
+ extra_args = {} if extra_args is None else extra_args
1121
+ s_in = x.new_ones([x.shape[0]])
1122
+ for i in trange(len(sigmas) - 1, disable=disable):
1123
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1124
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1125
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1126
+ gamma = max((2 ** 0.5 - 1) + wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1127
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1128
+ sigma_hat = sigmas[i] * (gamma + 1)
1129
+ if gamma > 0:
1130
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1131
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1132
+ d = to_d(x, sigma_hat, denoised)
1133
+ dt = sigmas[i + 1] - sigma_hat
1134
+ if callback is not None:
1135
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1136
+ if sigmas[i + 1] > 0:
1137
+ x_2 = x + d * dt
1138
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1139
+ d_prime = d * 0.5 + d_2 * 0.5
1140
+ x = x + d_prime * dt
1141
+ else:
1142
+ # Euler method
1143
+ x = x + d * dt
1144
+ return x
1145
+
1146
+ @torch.no_grad()
1147
+ def sample_euler_h_m_d(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1148
+ extra_args = {} if extra_args is None else extra_args
1149
+ s_in = x.new_ones([x.shape[0]])
1150
+ for i in trange(len(sigmas) - 1, disable=disable):
1151
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1152
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1153
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1154
+ gamma = min((2 ** 0.5 - 1) - wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1155
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1156
+ sigma_hat = sigmas[i] * (gamma + 1)
1157
+ if gamma > 0:
1158
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1159
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1160
+ d = to_d(x, sigma_hat, denoised)
1161
+ dt = sigmas[i + 1] - sigma_hat
1162
+ if callback is not None:
1163
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1164
+ if sigmas[i + 1] > 0:
1165
+ x_2 = x + d * dt
1166
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1167
+ d_prime = d * 0.5 + d_2 * 0.5
1168
+ x = x + d_prime * dt
1169
+ else:
1170
+ # Euler method
1171
+ x = x + d * dt
1172
+ return x
1173
+
1174
+ @torch.no_grad()
1175
+ def sample_euler_h_m_e(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1176
+ extra_args = {} if extra_args is None else extra_args
1177
+ s_in = x.new_ones([x.shape[0]])
1178
+ for i in trange(len(sigmas) - 1, disable=disable):
1179
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1180
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1181
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1182
+ gamma = max((2 ** 0.5 - 1) + wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1183
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1184
+ sigma_hat = sigmas[i] * (gamma + 1)
1185
+ if gamma > 0:
1186
+ x = x - eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1187
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1188
+ d = to_d(x, sigma_hat, denoised)
1189
+ dt = sigmas[i + 1] - sigma_hat
1190
+ if callback is not None:
1191
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1192
+ if sigmas[i + 1] > 0:
1193
+ x_2 = x + d * dt
1194
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1195
+ d_prime = d * 0.5 + d_2 * 0.5
1196
+ x = x + d_prime * dt
1197
+ else:
1198
+ # Euler method
1199
+ x = x + d * dt
1200
+ return x
1201
+
1202
+ @torch.no_grad()
1203
+ def sample_euler_h_m_f(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1204
+ extra_args = {} if extra_args is None else extra_args
1205
+ s_in = x.new_ones([x.shape[0]])
1206
+ for i in trange(len(sigmas) - 1, disable=disable):
1207
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1208
+ wave_max = math.cos(0)/1.5 + 1
1209
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1210
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1211
+ gamma = min((wave_max - wave) * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1212
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1213
+ sigma_hat = sigmas[i] * (gamma + 1)
1214
+ if gamma > 0:
1215
+ x = x - eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1216
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1217
+ d = to_d(x, sigma_hat, denoised)
1218
+ dt = sigmas[i + 1] - sigma_hat
1219
+ if callback is not None:
1220
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1221
+ if sigmas[i + 1] > 0:
1222
+ x_2 = x + d * dt
1223
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1224
+ d_prime = d * 0.5 + d_2 * 0.5
1225
+ x = x + d_prime * dt
1226
+ else:
1227
+ # Euler method
1228
+ x = x + d * dt
1229
+ return x
1230
+
1231
+ @torch.no_grad()
1232
+ def sample_euler_h_m_g(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1233
+ extra_args = {} if extra_args is None else extra_args
1234
+ s_in = x.new_ones([x.shape[0]])
1235
+ for i in trange(len(sigmas) - 1, disable=disable):
1236
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1237
+ wave_max = math.cos(0)/1.5 + 1
1238
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1239
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1240
+ gamma = min((wave_max - wave) * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1241
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1242
+ sigma_hat = sigmas[i] * (gamma + 1)
1243
+ if gamma > 0:
1244
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1245
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1246
+ d = to_d(x, sigma_hat, denoised)
1247
+ dt = sigmas[i + 1] - sigma_hat
1248
+ if callback is not None:
1249
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1250
+ if sigmas[i + 1] > 0:
1251
+ x_2 = x + d * dt
1252
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1253
+ d_prime = d * 0.5 + d_2 * 0.5
1254
+ x = x + d_prime * dt
1255
+ else:
1256
+ # Euler method
1257
+ x = x + d * dt
1258
+ return x
1259
+
1260
+ @torch.no_grad()
1261
+ def sample_euler_h_m_b_c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1262
+ extra_args = {} if extra_args is None else extra_args
1263
+ s_in = x.new_ones([x.shape[0]])
1264
+ for i in trange(len(sigmas) - 1, disable=disable):
1265
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1266
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1267
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1268
+ gamma = min(wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1269
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1270
+ gammaup = gamma + 1
1271
+ sigma_hat = sigmas[i] * gammaup
1272
+ if gamma > 0:
1273
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1274
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1275
+ last_noise_uncond = model.last_noise_uncond
1276
+ d = to_d(x, sigma_hat, denoised)
1277
+ dt = sigmas[i + 1] - sigma_hat
1278
+ if callback is not None:
1279
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1280
+ if i == 0:
1281
+ x = x + d * dt
1282
+ elif i <= len(sigmas) - 4:
1283
+ x_2 = x + d * dt
1284
+ d_2 = to_d(x_2, sigmas[i + 1] * gammaup, denoised)
1285
+ x_3 = x_2 + d_2 * dt
1286
+ d_3 = to_d(x_3, sigmas[i + 2] * gammaup, denoised)
1287
+ d_prime = d * 0.5 + d_2 * 0.375 + d_3 * 0.125
1288
+ x = x + d_prime * dt
1289
+ elif sigmas[i + 1] > 0:
1290
+ x_2 = x + d * dt
1291
+ d_2 = to_d(x_2, sigmas[i + 1] * gammaup, denoised)
1292
+ d_prime = d * 0.5 + d_2 * 0.5
1293
+ x = x + d_prime * dt
1294
+ else:
1295
+ # Euler method
1296
+ x = x + d * dt
1297
+ return x
1298
+
1299
+ @torch.no_grad()
1300
+ def sample_euler_h_m_b_c_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1301
+ extra_args = {} if extra_args is None else extra_args
1302
+ s_in = x.new_ones([x.shape[0]])
1303
+ for i in trange(len(sigmas) - 1, disable=disable):
1304
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1305
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1306
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1307
+ gamma = min(wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1308
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1309
+ gammaup = gamma + 1
1310
+ sigma_hat = sigmas[i] * gammaup
1311
+ if gamma > 0:
1312
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1313
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1314
+ last_noise_uncond = model.last_noise_uncond
1315
+ d = to_d(x, sigma_hat, denoised)
1316
+ dt = sigmas[i + 1] - sigma_hat
1317
+ if callback is not None:
1318
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1319
+ if i == 0:
1320
+ x = x + d * dt
1321
+ elif i <= len(sigmas) - 4:
1322
+ x_2 = x + d * dt
1323
+ d_2 = to_d(x_2, sigmas[i + 1] * gammaup, denoised)
1324
+ x_3 = x_2 + d_2 * dt
1325
+ d_3 = to_d(x_3, sigmas[i + 2] * gammaup, last_noise_uncond)
1326
+ d_prime = d * 0.5 + d_2 * 0.375 + d_3 * 0.125
1327
+ x = x + d_prime * dt
1328
+ elif sigmas[i + 1] > 0:
1329
+ x_2 = x + d * dt
1330
+ d_2 = to_d(x_2, sigmas[i + 1] * gammaup, denoised)
1331
+ d_prime = d * 0.5 + d_2 * 0.5
1332
+ x = x + d_prime * dt
1333
+ else:
1334
+ # Euler method
1335
+ x = x + d * dt
1336
+ return x
1337
+
1338
+ @torch.no_grad()
1339
+ def sample_euler_smea_max(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., smooth=False):
1340
+ extra_args = {} if extra_args is None else extra_args
1341
+ s_in = x.new_ones([x.shape[0]])
1342
+ for i in trange(len(sigmas) - 1, disable=disable):
1343
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1344
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
1345
+ sigma_hat = sigmas[i] * (gamma + 1)
1346
+ sa = math.cos(i + 1)/(1.5 * i + 1.75) + 1
1347
+ if gamma > 0:
1348
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1349
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1350
+ d = to_d(x, sigma_hat, denoised)
1351
+ if callback is not None:
1352
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1353
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167 + 1: # and i % 2 == 0:
1354
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1355
+ dt_1 = sigma_mid - sigma_hat
1356
+ dt_2 = sigmas[i + 1] - sigma_hat
1357
+ x_2 = x + d * dt_1
1358
+ sigA = sigma_mid / (sa ** 2)
1359
+ sigB = sigma_mid
1360
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
1361
+ denoised_2b = model(x_2, sigma_mid * s_in, **extra_args)
1362
+ denoised_2 = (denoised_2a * 0.5 * (sa ** 2) + denoised_2b * 0.5 / (sa ** 2))
1363
+ d_2 = to_d(x_2, sigA * 0.5 * (sa ** 2) + sigB * 0.5 / (sa ** 2), denoised_2)
1364
+ x = x + d_2 * dt_2
1365
+ else:
1366
+ dt = sigmas[i + 1] - sigma_hat
1367
+ # Euler method
1368
+ x = x + sa * d * dt
1369
+ return x
1370
+
1371
+
1372
+ @torch.no_grad()
1373
+ def sample_euler_smea_max_s(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1374
+ sample = sample_euler_smea_max(model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise, smooth=True)
1375
+ return sample
1376
+
1377
+ @torch.no_grad()
1378
+ def sample_euler_smea_multi_bs(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1379
+ extra_args = {} if extra_args is None else extra_args
1380
+ s_in = x.new_ones([x.shape[0]])
1381
+ for i in trange(len(sigmas) - 1, disable=disable):
1382
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1383
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
1384
+ sigma_hat = sigmas[i] * (gamma + 1)
1385
+ if gamma > 0:
1386
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1387
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1388
+ d = to_d(x, sigma_hat, denoised)
1389
+ if callback is not None:
1390
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1391
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
1392
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1393
+ dt_1 = sigma_mid - sigma_hat
1394
+ dt_2 = sigmas[i + 1] - sigma_hat
1395
+ x_2 = x + d * dt_1
1396
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
1397
+ sa = 1 - scale * 0.25
1398
+ sb = 1 + scale * 0.15
1399
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, sa, **extra_args)
1400
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, sb, **extra_args)
1401
+ denoised_2 = denoised_2a * (sa ** 2) * 0.625 + denoised_2b * (sb ** 2) * 0.375 / (0.95**2)
1402
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
1403
+ x = x + d_2 * dt_2
1404
+ else:
1405
+ dt = sigmas[i + 1] - sigma_hat
1406
+ # Euler method
1407
+ x = x + d * dt
1408
+ return x
1409
+
1410
+ @torch.no_grad()
1411
+ def sample_euler_smea_multi_bs2_s(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1412
+ sample = sample_euler_smea_multi_bs2(model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise, smooth=True)
1413
+ return sample
1414
+
1415
+ @torch.no_grad()
1416
+ def sample_euler_smea_multi_bs2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., smooth=False):
1417
+ extra_args = {} if extra_args is None else extra_args
1418
+ s_in = x.new_ones([x.shape[0]])
1419
+ for i in trange(len(sigmas) - 1, disable=disable):
1420
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1421
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
1422
+ sigma_hat = sigmas[i] * (gamma + 1)
1423
+ if gamma > 0:
1424
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1425
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1426
+ d = to_d(x, sigma_hat, denoised)
1427
+ if callback is not None:
1428
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1429
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
1430
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1431
+ dt_1 = sigma_mid - sigma_hat
1432
+ dt_2 = sigmas[i + 1] - sigma_hat
1433
+ x_2 = x + d * dt_1
1434
+ scale = (sigmas[i] / sigmas[0]) ** 2
1435
+ scale = scale.item()
1436
+ sa = 1 - scale * 0.25
1437
+ sb = 1 + scale * 0.15
1438
+ sigA = sigma_mid / (sa ** 2)
1439
+ sigB = sigma_mid / (sb ** 2)
1440
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
1441
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
1442
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.5 * sb ** 2 + denoised_2b * (sb ** 2) * 0.5 * sa ** 2)
1443
+ d_2 = to_d(x_2, sigA * 0.5 * sb ** 2 + sigB * 0.5 * sa ** 2, denoised_2)
1444
+ x = x + d_2 * dt_2
1445
+ else:
1446
+ dt = sigmas[i + 1] - sigma_hat
1447
+ # Euler method
1448
+ x = x + d * dt
1449
+ return x
1450
+
1451
+ @torch.no_grad()
1452
+ def sample_euler_smea_multi_cs(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1453
+ extra_args = {} if extra_args is None else extra_args
1454
+ s_in = x.new_ones([x.shape[0]])
1455
+ for i in trange(len(sigmas) - 1, disable=disable):
1456
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1457
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
1458
+ sigma_hat = sigmas[i] * (gamma + 1)
1459
+ if gamma > 0:
1460
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1461
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1462
+ d = to_d(x, sigma_hat, denoised)
1463
+ if callback is not None:
1464
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1465
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
1466
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1467
+ dt_1 = sigma_mid - sigma_hat
1468
+ dt_2 = sigmas[i + 1] - sigma_hat
1469
+ x_2 = x + d * dt_1
1470
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
1471
+ sa = 1 - scale * 0.25
1472
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, sa, **extra_args)
1473
+ d_2 = to_d(x_2, sigma_mid, denoised_2 * (sa ** 2) * 1.25)
1474
+ x = x + d_2 * dt_2
1475
+ else:
1476
+ dt = sigmas[i + 1] - sigma_hat
1477
+ # Euler method
1478
+ x = x + d * dt
1479
+ return x
1480
+
1481
+ @torch.no_grad()
1482
+ def sample_euler_smea_multi_as(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1483
+ extra_args = {} if extra_args is None else extra_args
1484
+ s_in = x.new_ones([x.shape[0]])
1485
+ for i in trange(len(sigmas) - 1, disable=disable):
1486
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1487
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
1488
+ sigma_hat = sigmas[i] * (gamma + 1)
1489
+ if gamma > 0:
1490
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1491
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1492
+ d = to_d(x, sigma_hat, denoised)
1493
+ if callback is not None:
1494
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1495
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
1496
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1497
+ dt_1 = sigma_mid - sigma_hat
1498
+ dt_2 = sigmas[i + 1] - sigma_hat
1499
+ x_2 = x + d * dt_1
1500
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
1501
+ sa = 1 + scale * 0.15
1502
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, sa, **extra_args)
1503
+ d_2 = to_d(x_2, sigma_mid, denoised_2 * (sa ** 2) * 0.75)
1504
+ x = x + d_2 * dt_2
1505
+ else:
1506
+ dt = sigmas[i + 1] - sigma_hat
1507
+ # Euler method
1508
+ x = x + d * dt
1509
+ return x
1510
+
1511
+ ## og sampler
1512
+ @torch.no_grad()
1513
+ def sample_euler_dy_og(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1514
+ extra_args = {} if extra_args is None else extra_args
1515
+ s_in = x.new_ones([x.shape[0]])
1516
+ for i in trange(len(sigmas) - 1, disable=disable):
1517
+ # print(i)
1518
+ # i绗竴姝ヤ负0
1519
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1520
+ eps = torch.randn_like(x) * s_noise
1521
+ sigma_hat = sigmas[i] * (gamma + 1)
1522
+ # print(sigma_hat)
1523
+ dt = sigmas[i + 1] - sigma_hat
1524
+ if gamma > 0:
1525
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1526
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1527
+ d = sampling.to_d(x, sigma_hat, denoised)
1528
+ if sigmas[i + 1] > 0:
1529
+ if i // 2 == 1:
1530
+ x = dy_sampling_step(x, model, dt, sigma_hat, **extra_args)
1531
+ if callback is not None:
1532
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1533
+ # Euler method
1534
+ x = x + d * dt
1535
+ return x
1536
+
1537
+ @torch.no_grad()
1538
+ def sample_euler_smea_dy_og(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1539
+ extra_args = {} if extra_args is None else extra_args
1540
+ s_in = x.new_ones([x.shape[0]])
1541
+ for i in trange(len(sigmas) - 1, disable=disable):
1542
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1543
+ eps = torch.randn_like(x) * s_noise
1544
+ sigma_hat = sigmas[i] * (gamma + 1)
1545
+ dt = sigmas[i + 1] - sigma_hat
1546
+ if gamma > 0:
1547
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1548
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1549
+ d = sampling.to_d(x, sigma_hat, denoised)
1550
+ # Euler method
1551
+ x = x + d * dt
1552
+ if sigmas[i + 1] > 0:
1553
+ if i + 1 // 2 == 1:
1554
+ x = dy_sampling_step(x, model, dt, sigma_hat, **extra_args)
1555
+ if i + 1 // 2 == 0:
1556
+ x = smea_sampling_step(x, model, dt, sigma_hat, **extra_args)
1557
+ if callback is not None:
1558
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1559
+ return x
1560
+
1561
+ ## TCD
1562
+
1563
+ def sample_tcd_euler_a(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, gamma=0.3):
1564
+ # TCD sampling using modified Euler Ancestral sampler. by @laksjdjf
1565
+ extra_args = {} if extra_args is None else extra_args
1566
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
1567
+ s_in = x.new_ones([x.shape[0]])
1568
+ for i in trange(len(sigmas) - 1, disable=disable):
1569
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
1570
+ if callback is not None:
1571
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
1572
+
1573
+ #d = to_d(x, sigmas[i], denoised)
1574
+ sigma_from = sigmas[i]
1575
+ sigma_to = sigmas[i + 1]
1576
+
1577
+ t = model.inner_model.sigma_to_t(sigma_from)
1578
+ down_t = (1 - gamma) * t
1579
+ sigma_down = model.inner_model.t_to_sigma(down_t)
1580
+
1581
+ if sigma_down > sigma_to:
1582
+ sigma_down = sigma_to
1583
+ sigma_up = (sigma_to ** 2 - sigma_down ** 2) ** 0.5
1584
+
1585
+ # same as euler ancestral
1586
+ d = to_d(x, sigma_from, denoised)
1587
+ dt = sigma_down - sigma_from
1588
+ x += d * dt
1589
+
1590
+ if sigma_to > 0 and gamma > 0:
1591
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigma_up
1592
+ return x
1593
+
1594
+ @torch.no_grad()
1595
+ def sample_tcd(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, gamma=0.3):
1596
+ # TCD sampling using modified DDPM.
1597
+ extra_args = {} if extra_args is None else extra_args
1598
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
1599
+ s_in = x.new_ones([x.shape[0]])
1600
+
1601
+ for i in trange(len(sigmas) - 1, disable=disable):
1602
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
1603
+ if callback is not None:
1604
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
1605
+
1606
+ sigma_from, sigma_to = sigmas[i], sigmas[i+1]
1607
+
1608
+ # TCD offset, based on gamma, and conversion between sigma and timestep
1609
+ t = model.inner_model.sigma_to_t(sigma_from)
1610
+ t_s = (1 - gamma) * t
1611
+ sigma_to_s = model.inner_model.t_to_sigma(t_s)
1612
+
1613
+ # if sigma_to_s > sigma_to:
1614
+ # sigma_to_s = sigma_to
1615
+ # if sigma_to_s < 0:
1616
+ # sigma_to_s = torch.tensor(1.0)
1617
+ #print(f"sigma_from: {sigma_from}, sigma_to: {sigma_to}, sigma_to_s: {sigma_to_s}")
1618
+
1619
+
1620
+ # The following is equivalent to the comfy DDPM implementation
1621
+ # x = DDPMSampler_step(x / torch.sqrt(1.0 + sigma_from ** 2.0), sigma_from, sigma_to, (x - denoised) / sigma_from, noise_sampler)
1622
+
1623
+ noise_est = (x - denoised) / sigma_from
1624
+ x /= torch.sqrt(1.0 + sigma_from ** 2.0)
1625
+
1626
+ alpha_cumprod = 1 / ((sigma_from * sigma_from) + 1) # _t
1627
+ alpha_cumprod_prev = 1 / ((sigma_to * sigma_to) + 1) # _t_prev
1628
+ alpha = (alpha_cumprod / alpha_cumprod_prev)
1629
+
1630
+ ## These values should approach 1.0?
1631
+ # print(f"alpha_cumprod: {alpha_cumprod}")
1632
+ # print(f"alpha_cumprod_prev: {alpha_cumprod_prev}")
1633
+ # print(f"alpha: {alpha}")
1634
+
1635
+
1636
+ # alpha_cumprod_down = 1 / ((sigma_to_s * sigma_to_s) + 1) # _s
1637
+ # alpha_d = (alpha_cumprod_prev / alpha_cumprod_down)
1638
+ # alpha2 = (alpha_cumprod / alpha_cumprod_down)
1639
+ # print(f"** alpha_cumprod_down: {alpha_cumprod_down}")
1640
+ # print(f"** alpha_d: {alpha_d}, alpha2: #{alpha2}")
1641
+
1642
+ # epsilon noise prediction from comfy DDPM implementation
1643
+ x = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise_est / (1 - alpha_cumprod).sqrt())
1644
+ # x = (1.0 / alpha_d).sqrt() * (x - (1 - alpha) * noise_est / (1 - alpha_cumprod).sqrt())
1645
+
1646
+ first_step = sigma_to == 0
1647
+ last_step = i == len(sigmas) - 2
1648
+
1649
+ if not first_step:
1650
+ if gamma > 0 and not last_step:
1651
+ noise = noise_sampler(sigma_from, sigma_to)
1652
+
1653
+ # x += ((1 - alpha_d) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise
1654
+ variance = ((1 - alpha_cumprod_prev) / (1 - alpha_cumprod)) * (1 - alpha_cumprod / alpha_cumprod_prev)
1655
+ x += variance.sqrt() * noise # scale noise by std deviation
1656
+
1657
+ # relevant diffusers code from scheduling_tcd.py
1658
+ # prev_sample = (alpha_prod_t_prev / alpha_prod_s).sqrt() * pred_noised_sample + (
1659
+ # 1 - alpha_prod_t_prev / alpha_prod_s
1660
+ # ).sqrt() * noise
1661
+
1662
+ x *= torch.sqrt(1.0 + sigma_to ** 2.0)
1663
+
1664
+ # beta_cumprod_t = 1 - alpha_cumprod
1665
+ # beta_cumprod_s = 1 - alpha_cumprod_down
1666
+
1667
+
1668
+ return x
1669
+
1670
+ # 袙 褋邪屑芯屑 泻芯薪褑械 sd-webui-smea.py
1671
+ from modules.script_callbacks import on_before_ui
1672
+ on_before_ui(init)
sd-webui-smea/sd-webui-smea (13).py ADDED
@@ -0,0 +1,1657 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ import k_diffusion.sampling
4
+
5
+ from k_diffusion.sampling import to_d, BrownianTreeNoiseSampler
6
+ from tqdm.auto import trange
7
+ from modules import scripts
8
+ from modules import sd_samplers_kdiffusion, sd_samplers_common, sd_samplers
9
+ from modules.sd_samplers_kdiffusion import KDiffusionSampler
10
+
11
+ class _Rescaler:
12
+ def __init__(self, model, x, mode, **extra_args):
13
+ self.model = model
14
+ self.x = x
15
+ self.mode = mode
16
+ self.extra_args = extra_args
17
+ self.init_latent, self.mask, self.nmask = model.init_latent, model.mask, model.nmask
18
+
19
+ def __enter__(self):
20
+ if self.init_latent is not None:
21
+ self.model.init_latent = torch.nn.functional.interpolate(input=self.init_latent, size=self.x.shape[2:4], mode=self.mode)
22
+ if self.mask is not None:
23
+ self.model.mask = torch.nn.functional.interpolate(input=self.mask.unsqueeze(0), size=self.x.shape[2:4], mode=self.mode).squeeze(0)
24
+ if self.nmask is not None:
25
+ self.model.nmask = torch.nn.functional.interpolate(input=self.nmask.unsqueeze(0), size=self.x.shape[2:4], mode=self.mode).squeeze(0)
26
+ return self
27
+
28
+ def __exit__(self, type, value, traceback):
29
+ del self.model.init_latent, self.model.mask, self.model.nmask
30
+ self.model.init_latent, self.model.mask, self.model.nmask = self.init_latent, self.mask, self.nmask
31
+
32
+ class Smea(scripts.Script):
33
+
34
+ def title(self):
35
+ return "Euler Smea Dy sampler"
36
+
37
+ def show(self, is_img2img):
38
+ return scripts.AlwaysVisible
39
+
40
+ def __init__(self):
41
+ init()
42
+ return
43
+
44
+ def init():
45
+ for i in sd_samplers.all_samplers:
46
+ if "Euler Max" in i.name:
47
+ return
48
+
49
+ samplers_smea = [
50
+ ('Euler Max', sample_euler_max, ['k_euler'], {}),
51
+ ('Euler Max1b', sample_euler_max1b, ['k_euler'], {}),
52
+ ('Euler Max1c', sample_euler_max1c, ['k_euler'], {}),
53
+ ('Euler Max1d', sample_euler_max1d, ['k_euler'], {}),
54
+ ('Euler Max2', sample_euler_max2, ['k_euler'], {}),
55
+ ('Euler Max2b', sample_euler_max2b, ['k_euler'], {}),
56
+ ('Euler Max2c', sample_euler_max2c, ['k_euler'], {}),
57
+ ('Euler Max2d', sample_euler_max2d, ['k_euler'], {}),
58
+ ('Euler Max3', sample_euler_max3, ['k_euler'], {}),
59
+ ('Euler Max3b', sample_euler_max3b, ['k_euler'], {}),
60
+ ('Euler Max3c', sample_euler_max3c, ['k_euler'], {}),
61
+ ('Euler Max4', sample_euler_max4, ['k_euler'], {}),
62
+ ('Euler Max4b', sample_euler_max4b, ['k_euler'], {}),
63
+ ('Euler Max4c', sample_euler_max4c, ['k_euler'], {}),
64
+ ('Euler Max4d', sample_euler_max4d, ['k_euler'], {}),
65
+ ('Euler Max4e', sample_euler_max4e, ['k_euler'], {}),
66
+ ('Euler Max4f', sample_euler_max4f, ['k_euler'], {}),
67
+ ('Euler Dy', sample_euler_dy, ['k_euler'], {}),
68
+ ('Euler Smea', sample_euler_smea, ['k_euler'], {}),
69
+ ('Euler Smea Dy', sample_euler_smea_dy, ['k_euler'], {}),
70
+ ('Euler Smea Max', sample_euler_smea_max, ['k_euler'], {}),
71
+ ('Euler Smea Max s', sample_euler_smea_max_s, ['k_euler'], {}),
72
+ ('Euler Smea dyn a', sample_euler_smea_dyn_a, ['k_euler'], {}),
73
+ ('Euler Smea dyn b', sample_euler_smea_dyn_b, ['k_euler'], {}),
74
+ ('Euler Smea dyn c', sample_euler_smea_dyn_c, ['k_euler'], {}),
75
+ ('Euler Smea ma', sample_euler_smea_multi_a, ['k_euler'], {}),
76
+ ('Euler Smea mb', sample_euler_smea_multi_b, ['k_euler'], {}),
77
+ ('Euler Smea mc', sample_euler_smea_multi_c, ['k_euler'], {}),
78
+ ('Euler Smea md', sample_euler_smea_multi_d, ['k_euler'], {}),
79
+ ('Euler Smea mas', sample_euler_smea_multi_as, ['k_euler'], {}),
80
+ ('Euler Smea mbs', sample_euler_smea_multi_bs, ['k_euler'], {}),
81
+ ('Euler Smea mcs', sample_euler_smea_multi_cs, ['k_euler'], {}),
82
+ ('Euler Smea mds', sample_euler_smea_multi_ds, ['k_euler'], {}),
83
+ ('Euler Smea mbs2', sample_euler_smea_multi_bs2, ['k_euler'], {}),
84
+ ('Euler Smea mds2', sample_euler_smea_multi_ds2, ['k_euler'], {}),
85
+ ('Euler Smea mds2 max', sample_euler_smea_multi_ds2_m, ['k_euler'], {}),
86
+ ('Euler Smea mds2 s max', sample_euler_smea_multi_ds2_s_m, ['k_euler'], {}),
87
+ ('Euler Smea mbs2 s', sample_euler_smea_multi_bs2_s, ['k_euler'], {}),
88
+ ('Euler Smea mds2 s', sample_euler_smea_multi_ds2_s, ['k_euler'], {}),
89
+ ('Euler h max', sample_euler_h_m, ['k_euler'], {"brownian_noise": True}),
90
+ ('Euler h max b', sample_euler_h_m_b, ['k_euler'], {"brownian_noise": True}),
91
+ ('Euler h max c', sample_euler_h_m_c, ['k_euler'], {"brownian_noise": True}),
92
+ ('Euler h max d', sample_euler_h_m_d, ['k_euler'], {"brownian_noise": True}),
93
+ ('Euler h max e', sample_euler_h_m_e, ['k_euler'], {"brownian_noise": True}),
94
+ ('Euler h max f', sample_euler_h_m_f, ['k_euler'], {"brownian_noise": True}),
95
+ ('Euler h max g', sample_euler_h_m_g, ['k_euler'], {"brownian_noise": True}),
96
+ ('Euler h max b c', sample_euler_h_m_b_c, ['k_euler'], {"brownian_noise": True}),
97
+ ('Euler h max b c CFG++', sample_euler_h_m_b_c_pp, ['k_euler'], {"brownian_noise": True, "cfgpp": True}),
98
+ ('Euler Dy koishi-star', sample_euler_dy_og, ['k_euler'], {}),
99
+ ('Euler Smea Dy koishi-star', sample_euler_smea_dy_og, ['k_euler'], {}),
100
+ ('TCD Euler a', sample_tcd_euler_a, ['tcd_euler_a'], {}),
101
+ ('TCD', sample_tcd, ['tcd'], {}),
102
+ ]
103
+
104
+ samplers_data_smea = [
105
+ sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
106
+ for label, funcname, aliases, options in samplers_smea
107
+ if callable(funcname)
108
+ ]
109
+
110
+ sampler_exparams_smea = {
111
+ sample_euler_max: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
112
+ sample_euler_max1b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
113
+ sample_euler_max1c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
114
+ sample_euler_max1d: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
115
+ sample_euler_max2: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
116
+ sample_euler_max2b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
117
+ sample_euler_max2c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
118
+ sample_euler_max2d: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
119
+ sample_euler_max3: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
120
+ sample_euler_max3b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
121
+ sample_euler_max3c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
122
+ sample_euler_max4: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
123
+ sample_euler_max4b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
124
+ sample_euler_max4c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
125
+ sample_euler_max4d: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
126
+ sample_euler_max4e: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
127
+ sample_euler_max4f: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
128
+ sample_euler_dy: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
129
+ sample_euler_smea: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
130
+ sample_euler_smea_dy: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
131
+ sample_euler_smea_max: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
132
+ sample_euler_smea_max_s: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
133
+ sample_euler_smea_dyn_a: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
134
+ sample_euler_smea_dyn_b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
135
+ sample_euler_smea_dyn_c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
136
+ sample_euler_smea_multi_a: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
137
+ sample_euler_smea_multi_b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
138
+ sample_euler_smea_multi_c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
139
+ sample_euler_smea_multi_d: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
140
+ sample_euler_smea_multi_as: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
141
+ sample_euler_smea_multi_bs: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
142
+ sample_euler_smea_multi_cs: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
143
+ sample_euler_smea_multi_ds: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
144
+ sample_euler_smea_multi_bs2: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
145
+ sample_euler_smea_multi_ds2: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
146
+ sample_euler_smea_multi_ds2_m: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
147
+ sample_euler_smea_multi_ds2_s_m: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
148
+ sample_euler_smea_multi_bs2_s: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
149
+ sample_euler_smea_multi_ds2_s: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
150
+ sample_euler_h_m: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
151
+ sample_euler_h_m_b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
152
+ sample_euler_h_m_c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
153
+ sample_euler_h_m_d: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
154
+ sample_euler_h_m_e: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
155
+ sample_euler_h_m_f: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
156
+ sample_euler_h_m_g: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
157
+ sample_euler_h_m_b_c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
158
+ sample_euler_h_m_b_c_pp: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
159
+ sample_euler_dy_og: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
160
+ sample_euler_smea_dy_og: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
161
+ }
162
+ sd_samplers_kdiffusion.sampler_extra_params = {**sd_samplers_kdiffusion.sampler_extra_params, **sampler_exparams_smea}
163
+
164
+ samplers_map_smea = {x.name: x for x in samplers_data_smea}
165
+ sd_samplers_kdiffusion.k_diffusion_samplers_map = {**sd_samplers_kdiffusion.k_diffusion_samplers_map, **samplers_map_smea}
166
+
167
+ for i, item in enumerate(sd_samplers.all_samplers):
168
+ if "Euler" in item.name:
169
+ sd_samplers.all_samplers = sd_samplers.all_samplers[:i + 1] + [*samplers_data_smea] + sd_samplers.all_samplers[i + 1:]
170
+ break
171
+ sd_samplers.all_samplers_map = {x.name: x for x in sd_samplers.all_samplers}
172
+ sd_samplers.set_samplers()
173
+
174
+ return
175
+
176
+ def default_noise_sampler(x):
177
+ return lambda sigma, sigma_next: k_diffusion.sampling.torch.randn_like(x)
178
+
179
+ @torch.no_grad()
180
+ def dy_sampling_step(x, model, dt, sigma_hat, **extra_args):
181
+ original_shape = x.shape
182
+ batch_size, channels, m, n = original_shape[0], original_shape[1], original_shape[2] // 2, original_shape[3] // 2
183
+ extra_row = x.shape[2] % 2 == 1
184
+ extra_col = x.shape[3] % 2 == 1
185
+
186
+ if extra_row:
187
+ extra_row_content = x[:, :, -1:, :]
188
+ x = x[:, :, :-1, :]
189
+ if extra_col:
190
+ extra_col_content = x[:, :, :, -1:]
191
+ x = x[:, :, :, :-1]
192
+
193
+ a_list = x.unfold(2, 2, 2).unfold(3, 2, 2).contiguous().view(batch_size, channels, m * n, 2, 2)
194
+ c = a_list[:, :, :, 1, 1].view(batch_size, channels, m, n)
195
+
196
+ with _Rescaler(model, c, 'nearest-exact', **extra_args) as rescaler:
197
+ denoised = model(c, sigma_hat * c.new_ones([c.shape[0]]), **rescaler.extra_args)
198
+ d = to_d(c, sigma_hat, denoised)
199
+ c = c + d * dt
200
+
201
+ d_list = c.view(batch_size, channels, m * n, 1, 1)
202
+ a_list[:, :, :, 1, 1] = d_list[:, :, :, 0, 0]
203
+ x = a_list.view(batch_size, channels, m, n, 2, 2).permute(0, 1, 2, 4, 3, 5).reshape(batch_size, channels, 2 * m, 2 * n)
204
+
205
+ if extra_row or extra_col:
206
+ x_expanded = torch.zeros(original_shape, dtype=x.dtype, device=x.device)
207
+ x_expanded[:, :, :2 * m, :2 * n] = x
208
+ if extra_row:
209
+ x_expanded[:, :, -1:, :2 * n + 1] = extra_row_content
210
+ if extra_col:
211
+ x_expanded[:, :, :2 * m, -1:] = extra_col_content
212
+ if extra_row and extra_col:
213
+ x_expanded[:, :, -1:, -1:] = extra_col_content[:, :, -1:, :]
214
+ x = x_expanded
215
+
216
+ return x
217
+
218
+ @torch.no_grad()
219
+ def smea_sampling_step(x, model, dt, sigma_hat, **extra_args):
220
+ m, n = x.shape[2], x.shape[3]
221
+ x = torch.nn.functional.interpolate(input=x, size=None, scale_factor=(1.25, 1.25), mode='nearest-exact', align_corners=None, recompute_scale_factor=None)
222
+ with _Rescaler(model, x, 'nearest-exact', **extra_args) as rescaler:
223
+ denoised = model(x, sigma_hat * x.new_ones([x.shape[0]]), **rescaler.extra_args)
224
+ d = to_d(x, sigma_hat, denoised)
225
+ x = x + d * dt
226
+ x = torch.nn.functional.interpolate(input=x, size=(m,n), scale_factor=None, mode='nearest-exact', align_corners=None, recompute_scale_factor=None)
227
+ return x
228
+
229
+ @torch.no_grad()
230
+ def smea_sampling_step_denoised(x, model, sigma_hat, scale=1.25, smooth=False, **extra_args):
231
+ m, n = x.shape[2], x.shape[3]
232
+ filter = 'nearest-exact' if not smooth else 'bilinear'
233
+ x = torch.nn.functional.interpolate(input=x, scale_factor=(scale, scale), mode=filter)
234
+ with _Rescaler(model, x, filter, **extra_args) as rescaler:
235
+ denoised = model(x, sigma_hat * x.new_ones([x.shape[0]]), **rescaler.extra_args)
236
+ x = denoised
237
+ x = torch.nn.functional.interpolate(input=x, size=(m,n), mode='nearest-exact')
238
+ return x
239
+
240
+ @torch.no_grad()
241
+ def sample_euler_max(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
242
+ extra_args = {} if extra_args is None else extra_args
243
+ s_in = x.new_ones([x.shape[0]])
244
+ for i in trange(len(sigmas) - 1, disable=disable):
245
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
246
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
247
+ sigma_hat = sigmas[i] * (gamma + 1)
248
+ if gamma > 0:
249
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
250
+ denoised = model(x, sigma_hat * s_in, **extra_args)
251
+ d = to_d(x, sigma_hat, denoised)
252
+ if callback is not None:
253
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
254
+ dt = sigmas[i + 1] - sigma_hat
255
+ # Euler method
256
+ x = x + (math.cos(i + 1)/(i + 1) + 1) * d * dt
257
+ return x
258
+
259
+
260
+ @torch.no_grad()
261
+ def sample_euler_max1b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
262
+ extra_args = {} if extra_args is None else extra_args
263
+ s_in = x.new_ones([x.shape[0]])
264
+ for i in trange(len(sigmas) - 1, disable=disable):
265
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
266
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
267
+ sigma_hat = sigmas[i] * (gamma + 1)
268
+ if gamma > 0:
269
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
270
+ denoised = model(x, sigma_hat * s_in, **extra_args)
271
+ d = to_d(x, sigma_hat, denoised)
272
+ if callback is not None:
273
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
274
+ dt = sigmas[i + 1] - sigma_hat
275
+ # Euler method
276
+ x = x + (math.cos(1.05 * i + 1)/(1.1 * i + 1.5) + 1) * d * dt
277
+ return x
278
+
279
+ @torch.no_grad()
280
+ def sample_euler_max1c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
281
+ extra_args = {} if extra_args is None else extra_args
282
+ s_in = x.new_ones([x.shape[0]])
283
+ for i in trange(len(sigmas) - 1, disable=disable):
284
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
285
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
286
+ sigma_hat = sigmas[i] * (gamma + 1)
287
+ if gamma > 0:
288
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
289
+ denoised = model(x, sigma_hat * s_in, **extra_args)
290
+ d = to_d(x, sigma_hat, denoised)
291
+ if callback is not None:
292
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
293
+ dt = sigmas[i + 1] - sigma_hat
294
+ # Euler method
295
+ x = x + (math.cos(1.05 * i + 1.1)/(1.25 * i + 1.5) + 1) * d * dt
296
+ return x
297
+
298
+ @torch.no_grad()
299
+ def sample_euler_max1d(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
300
+ extra_args = {} if extra_args is None else extra_args
301
+ s_in = x.new_ones([x.shape[0]])
302
+ for i in trange(len(sigmas) - 1, disable=disable):
303
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
304
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
305
+ sigma_hat = sigmas[i] * (gamma + 1)
306
+ if gamma > 0:
307
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
308
+ denoised = model(x, sigma_hat * s_in, **extra_args)
309
+ d = to_d(x, sigma_hat, denoised)
310
+ if callback is not None:
311
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
312
+ dt = sigmas[i + 1] - sigma_hat
313
+ # Euler method
314
+ x = x + (math.cos(math.pi * 0.333 * i + 0.9)/(0.5 * i + 1.5) + 1) * d * dt
315
+ return x
316
+
317
+ @torch.no_grad()
318
+ def sample_euler_max2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
319
+ extra_args = {} if extra_args is None else extra_args
320
+ s_in = x.new_ones([x.shape[0]])
321
+ for i in trange(len(sigmas) - 1, disable=disable):
322
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
323
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
324
+ sigma_hat = sigmas[i] * (gamma + 1)
325
+ if gamma > 0:
326
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
327
+ denoised = model(x, sigma_hat * s_in, **extra_args)
328
+ d = to_d(x, sigma_hat, denoised)
329
+ if callback is not None:
330
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
331
+ dt = sigmas[i + 1] - sigma_hat
332
+ # Euler method
333
+ x = x + (math.cos(math.pi * 0.333 * i - 0.1)/(0.5 * i + 1.5) + 1) * d * dt
334
+ return x
335
+
336
+ @torch.no_grad()
337
+ def sample_euler_max2b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
338
+ extra_args = {} if extra_args is None else extra_args
339
+ s_in = x.new_ones([x.shape[0]])
340
+ for i in trange(len(sigmas) - 1, disable=disable):
341
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
342
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
343
+ sigma_hat = sigmas[i] * (gamma + 1)
344
+ if gamma > 0:
345
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
346
+ denoised = model(x, sigma_hat * s_in, **extra_args)
347
+ d = to_d(x, sigma_hat, denoised)
348
+ if callback is not None:
349
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
350
+ dt = sigmas[i + 1] - sigma_hat
351
+ # Euler method
352
+ x = x + (math.cos(math.pi * 0.5 * i - 0.0)/(0.5 * i + 1.5) + 1) * d * dt
353
+ return x
354
+
355
+ @torch.no_grad()
356
+ def sample_euler_max2c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
357
+ extra_args = {} if extra_args is None else extra_args
358
+ s_in = x.new_ones([x.shape[0]])
359
+ for i in trange(len(sigmas) - 1, disable=disable):
360
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
361
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
362
+ sigma_hat = sigmas[i] * (gamma + 1)
363
+ if gamma > 0:
364
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
365
+ denoised = model(x, sigma_hat * s_in, **extra_args)
366
+ d = to_d(x, sigma_hat, denoised)
367
+ if callback is not None:
368
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
369
+ dt = sigmas[i + 1] - sigma_hat
370
+ # Euler method
371
+ x = x + (math.cos(math.pi * 0.5 * i)/(i + 2) + 1) * d * dt
372
+ return x
373
+
374
+ @torch.no_grad()
375
+ def sample_euler_max2d(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
376
+ extra_args = {} if extra_args is None else extra_args
377
+ s_in = x.new_ones([x.shape[0]])
378
+ for i in trange(len(sigmas) - 1, disable=disable):
379
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
380
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
381
+ sigma_hat = sigmas[i] * (gamma + 1)
382
+ if gamma > 0:
383
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
384
+ denoised = model(x, sigma_hat * s_in, **extra_args)
385
+ d = to_d(x, sigma_hat, denoised)
386
+ if callback is not None:
387
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
388
+ dt = sigmas[i + 1] - sigma_hat
389
+ # Euler method
390
+ x = x + (math.cos(math.pi * 0.5 * i)/(0.75 * i + 1.75) + 1) * d * dt
391
+ return x
392
+
393
+ @torch.no_grad()
394
+ def sample_euler_max3b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
395
+ extra_args = {} if extra_args is None else extra_args
396
+ s_in = x.new_ones([x.shape[0]])
397
+ for i in trange(len(sigmas) - 1, disable=disable):
398
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
399
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
400
+ sigma_hat = sigmas[i] * (gamma + 1)
401
+ if gamma > 0:
402
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
403
+ denoised = model(x, sigma_hat * s_in, **extra_args)
404
+ d = to_d(x, sigma_hat, denoised)
405
+ if callback is not None:
406
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
407
+ dt = sigmas[i + 1] - sigma_hat
408
+ # Euler method
409
+ x = x + (math.cos(2 * i + 0.5)/(2 * i + 1.5) + 1) * d * dt
410
+ return x
411
+
412
+ @torch.no_grad()
413
+ def sample_euler_max3c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
414
+ extra_args = {} if extra_args is None else extra_args
415
+ s_in = x.new_ones([x.shape[0]])
416
+ for i in trange(len(sigmas) - 1, disable=disable):
417
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
418
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
419
+ sigma_hat = sigmas[i] * (gamma + 1)
420
+ if gamma > 0:
421
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
422
+ denoised = model(x, sigma_hat * s_in, **extra_args)
423
+ d = to_d(x, sigma_hat, denoised)
424
+ if callback is not None:
425
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
426
+ dt = sigmas[i + 1] - sigma_hat
427
+ # Euler method
428
+ x = x + (math.cos(2 * i + 0.5)/(1.5 * i + 2.7) + 1) * d * dt
429
+ return x
430
+
431
+ @torch.no_grad()
432
+ def sample_euler_max3(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
433
+ extra_args = {} if extra_args is None else extra_args
434
+ s_in = x.new_ones([x.shape[0]])
435
+ for i in trange(len(sigmas) - 1, disable=disable):
436
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
437
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
438
+ sigma_hat = sigmas[i] * (gamma + 1)
439
+ if gamma > 0:
440
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
441
+ denoised = model(x, sigma_hat * s_in, **extra_args)
442
+ d = to_d(x, sigma_hat, denoised)
443
+ if callback is not None:
444
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
445
+ dt = sigmas[i + 1] - sigma_hat
446
+ # Euler method
447
+ x = x + (math.cos(2 * i + 1)/(2 * i + 1) + 1) * d * dt
448
+ return x
449
+
450
+ @torch.no_grad()
451
+ def sample_euler_max4b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
452
+ extra_args = {} if extra_args is None else extra_args
453
+ s_in = x.new_ones([x.shape[0]])
454
+ for i in trange(len(sigmas) - 1, disable=disable):
455
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
456
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
457
+ sigma_hat = sigmas[i] * (gamma + 1)
458
+ if gamma > 0:
459
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
460
+ denoised = model(x, sigma_hat * s_in, **extra_args)
461
+ d = to_d(x, sigma_hat, denoised)
462
+ if callback is not None:
463
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
464
+ dt = sigmas[i + 1] - sigma_hat
465
+ # Euler method
466
+ x = x + (math.cos(math.pi * i - 0.1)/(2 * i + 2) + 1) * d * dt
467
+ return x
468
+
469
+ @torch.no_grad()
470
+ def sample_euler_max4c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
471
+ extra_args = {} if extra_args is None else extra_args
472
+ s_in = x.new_ones([x.shape[0]])
473
+ for i in trange(len(sigmas) - 1, disable=disable):
474
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
475
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
476
+ sigma_hat = sigmas[i] * (gamma + 1)
477
+ if gamma > 0:
478
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
479
+ denoised = model(x, sigma_hat * s_in, **extra_args)
480
+ d = to_d(x, sigma_hat, denoised)
481
+ if callback is not None:
482
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
483
+ dt = sigmas[i + 1] - sigma_hat
484
+ # Euler method
485
+ x = x + (math.cos(math.pi * i - 0.1)/(2 * i + 1.5) + 1) * d * dt
486
+ return x
487
+
488
+ @torch.no_grad()
489
+ def sample_euler_max4d(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
490
+ extra_args = {} if extra_args is None else extra_args
491
+ s_in = x.new_ones([x.shape[0]])
492
+ for i in trange(len(sigmas) - 1, disable=disable):
493
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
494
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
495
+ sigma_hat = sigmas[i] * (gamma + 1)
496
+ if gamma > 0:
497
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
498
+ denoised = model(x, sigma_hat * s_in, **extra_args)
499
+ d = to_d(x, sigma_hat, denoised)
500
+ if callback is not None:
501
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
502
+ dt = sigmas[i + 1] - sigma_hat
503
+ # Euler method
504
+ x = x + (math.cos(math.pi * i - 0.1)/(i + 1.5) + 1) * d * dt
505
+ return x
506
+
507
+ @torch.no_grad()
508
+ def sample_euler_max4e(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
509
+ extra_args = {} if extra_args is None else extra_args
510
+ s_in = x.new_ones([x.shape[0]])
511
+ for i in trange(len(sigmas) - 1, disable=disable):
512
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
513
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
514
+ sigma_hat = sigmas[i] * (gamma + 1)
515
+ if gamma > 0:
516
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
517
+ denoised = model(x, sigma_hat * s_in, **extra_args)
518
+ d = to_d(x, sigma_hat, denoised)
519
+ if callback is not None:
520
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
521
+ dt = sigmas[i + 1] - sigma_hat
522
+ # Euler method
523
+ x = x + (math.cos(math.pi * i - 0.1)/(i + 1) + 1) * d * dt
524
+ return x
525
+
526
+ @torch.no_grad()
527
+ def sample_euler_max4f(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
528
+ extra_args = {} if extra_args is None else extra_args
529
+ s_in = x.new_ones([x.shape[0]])
530
+ for i in trange(len(sigmas) - 1, disable=disable):
531
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
532
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
533
+ sigma_hat = sigmas[i] * (gamma + 1)
534
+ if gamma > 0:
535
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
536
+ denoised = model(x, sigma_hat * s_in, **extra_args)
537
+ d = to_d(x, sigma_hat, denoised)
538
+ if callback is not None:
539
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
540
+ dt = sigmas[i + 1] - sigma_hat
541
+ # Euler method
542
+ x = x + (math.cos(math.pi * i - 0.1)/(i + 2) + 1) * d * dt
543
+ return x
544
+
545
+
546
+ @torch.no_grad()
547
+ def sample_euler_max4(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
548
+ # 袛芯斜邪胁褜褌械 蟹写械褋褜 褌械谢芯 褎褍薪泻褑懈懈 懈谢懈 褏芯褌褟 斜褘 pass, 褔褌芯斜褘 懈蟹斜械卸邪褌褜 IndentationError
549
+ pass
550
+
551
+ @torch.no_grad()
552
+ def sample_euler_dy(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
553
+ extra_args = {} if extra_args is None else extra_args
554
+ s_in = x.new_ones([x.shape[0]])
555
+ for i in trange(len(sigmas) - 1, disable=disable):
556
+ # print(i)
557
+ # i绗竴姝ヤ负0
558
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
559
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
560
+ sigma_hat = sigmas[i] * (gamma + 1)
561
+ # print(sigma_hat)
562
+ dt = sigmas[i + 1] - sigma_hat
563
+ if gamma > 0:
564
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
565
+ denoised = model(x, sigma_hat * s_in, **extra_args)
566
+ d = to_d(x, sigma_hat, denoised)
567
+ if callback is not None:
568
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
569
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 2 and i % 2 == 0:
570
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
571
+ dt_1 = sigma_mid - sigmas[i]
572
+ dt_2 = sigmas[i + 1] - sigmas[i]
573
+ x_2 = x + d * dt_1
574
+ x_temp = dy_sampling_step(x_2, model, dt_2, sigma_mid, **extra_args)
575
+ x = x_temp - d * dt_1
576
+ # Euler method
577
+ x = x + d * dt
578
+ return x
579
+
580
+ @torch.no_grad()
581
+ def sample_euler_smea_dyn_a(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
582
+ extra_args = {} if extra_args is None else extra_args
583
+ s_in = x.new_ones([x.shape[0]])
584
+ for i in trange(len(sigmas) - 1, disable=disable):
585
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
586
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
587
+ sigma_hat = sigmas[i] * (gamma + 1)
588
+ if gamma > 0:
589
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
590
+ denoised = model(x, sigma_hat * s_in, **extra_args)
591
+ d = to_d(x, sigma_hat, denoised)
592
+ if callback is not None:
593
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
594
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 2:
595
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
596
+ dt_1 = sigma_mid - sigma_hat
597
+ dt_2 = sigmas[i + 1] - sigma_hat
598
+ x_2 = x + d * dt_1
599
+ #scale = (sigma_mid / sigmas[0]) * 0.25
600
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2 * 0.15
601
+ #scale = scale.item()
602
+ if i % 2 == 0:
603
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale, **extra_args)
604
+ #denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + sigma_mid.item() * 0.01, **extra_args)
605
+ else:
606
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
607
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
608
+ x = x + d_2 * dt_2
609
+ else:
610
+ dt = sigmas[i + 1] - sigma_hat
611
+ # Euler method
612
+ x = x + d * dt
613
+ return x
614
+
615
+ @torch.no_grad()
616
+ def sample_euler_smea_dyn_b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
617
+ extra_args = {} if extra_args is None else extra_args
618
+ s_in = x.new_ones([x.shape[0]])
619
+ for i in trange(len(sigmas) - 1, disable=disable):
620
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
621
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
622
+ sigma_hat = sigmas[i] * (gamma + 1)
623
+ if gamma > 0:
624
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
625
+ denoised = model(x, sigma_hat * s_in, **extra_args)
626
+ d = to_d(x, sigma_hat, denoised)
627
+ if callback is not None:
628
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
629
+ if sigmas[i + 1] > 0 and (i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 3 or i < 3):
630
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
631
+ dt_1 = sigma_mid - sigma_hat
632
+ dt_2 = sigmas[i + 1] - sigma_hat
633
+ x_2 = x + d * dt_1
634
+ #scale = (sigma_mid / sigmas[0]) * 0.25
635
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2 * 0.2
636
+ #scale = scale.item()
637
+ if i % 4 == 0:
638
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale, **extra_args)
639
+ #denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - sigma_mid.item() * 0.01, **extra_args)
640
+ elif i % 4 == 2:
641
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale, **extra_args)
642
+ #denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + sigma_mid.item() * 0.01, **extra_args)
643
+ else:
644
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
645
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
646
+ x = x + d_2 * dt_2
647
+ else:
648
+ dt = sigmas[i + 1] - sigma_hat
649
+ # Euler method
650
+ x = x + d * dt
651
+ return x
652
+
653
+ @torch.no_grad()
654
+ def sample_euler_smea_dyn_c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
655
+ extra_args = {} if extra_args is None else extra_args
656
+ s_in = x.new_ones([x.shape[0]])
657
+ for i in trange(len(sigmas) - 1, disable=disable):
658
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
659
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
660
+ sigma_hat = sigmas[i] * (gamma + 1)
661
+ if gamma > 0:
662
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
663
+ denoised = model(x, sigma_hat * s_in, **extra_args)
664
+ d = to_d(x, sigma_hat, denoised)
665
+ if callback is not None:
666
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
667
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 2:
668
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
669
+ dt_1 = sigma_mid - sigma_hat
670
+ dt_2 = sigmas[i + 1] - sigma_hat
671
+ x_2 = x + d * dt_1
672
+ #scale = (sigma_mid / sigmas[0]) * 0.25
673
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2 * 0.25
674
+ #scale = scale.item()
675
+ if i % 2 == 0:
676
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale, **extra_args)
677
+ #denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + sigma_mid.item() * 0.01, **extra_args)
678
+ else:
679
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
680
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
681
+ x = x + d_2 * dt_2
682
+ else:
683
+ dt = sigmas[i + 1] - sigma_hat
684
+ # Euler method
685
+ x = x + d * dt
686
+ return x
687
+
688
+ @torch.no_grad()
689
+ def sample_euler_smea(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
690
+ extra_args = {} if extra_args is None else extra_args
691
+ s_in = x.new_ones([x.shape[0]])
692
+ for i in trange(len(sigmas) - 1, disable=disable):
693
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
694
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
695
+ sigma_hat = sigmas[i] * (gamma + 1)
696
+ if gamma > 0:
697
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
698
+ denoised = model(x, sigma_hat * s_in, **extra_args)
699
+ d = to_d(x, sigma_hat, denoised)
700
+ if callback is not None:
701
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
702
+ dt = sigmas[i + 1] - sigma_hat
703
+ # Euler method
704
+ x = x + d * dt
705
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 2 and i % 2 == 0:
706
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
707
+ dt_1 = sigma_mid - sigmas[i]
708
+ dt_2 = sigmas[i + 1] - sigmas[i]
709
+ #print(dt_1, "#", dt_2, "#", dt_3, "#", dt_4)
710
+ x_2 = x + d * dt_1
711
+ x_temp = smea_sampling_step(x, model, dt_2, sigma_mid, **extra_args)
712
+ x = x_temp - d * dt_1
713
+ return x
714
+
715
+ @torch.no_grad()
716
+ def sample_euler_smea_dy(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
717
+ extra_args = {} if extra_args is None else extra_args
718
+ s_in = x.new_ones([x.shape[0]])
719
+ for i in trange(len(sigmas) - 1, disable=disable):
720
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
721
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
722
+ sigma_hat = sigmas[i] * (gamma + 1)
723
+ if gamma > 0:
724
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
725
+ denoised = model(x, sigma_hat * s_in, **extra_args)
726
+ d = to_d(x, sigma_hat, denoised)
727
+ if callback is not None:
728
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
729
+ dt = sigmas[i + 1] - sigma_hat
730
+ # Euler method
731
+ x = x + d * dt
732
+ if sigmas[i + 1] > 0 and (i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 2 or i < 3) and i % 3 != 2:
733
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
734
+ dt_1 = sigma_mid - sigmas[i]
735
+ dt_2 = sigmas[i + 1] - sigmas[i]
736
+ #print(dt_1, "#", dt_2, "#", dt_3, "#", dt_4)
737
+ x_2 = x + d * dt_1
738
+ if i % 3 == 1:
739
+ x_temp = dy_sampling_step(x, model, dt_2, sigma_mid, **extra_args)
740
+ elif i % 3 == 0:
741
+ x_temp = smea_sampling_step(x, model, dt_2, sigma_mid, **extra_args)
742
+ x = x_temp - d * dt_1
743
+ return x
744
+
745
+ @torch.no_grad()
746
+ def sample_euler_smea_multi_d(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
747
+ extra_args = {} if extra_args is None else extra_args
748
+ s_in = x.new_ones([x.shape[0]])
749
+ for i in trange(len(sigmas) - 1, disable=disable):
750
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
751
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
752
+ sigma_hat = sigmas[i] * (gamma + 1)
753
+ if gamma > 0:
754
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
755
+ denoised = model(x, sigma_hat * s_in, **extra_args)
756
+ d = to_d(x, sigma_hat, denoised)
757
+ if callback is not None:
758
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
759
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.334 + 2 and i % 2 == 0:
760
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
761
+ dt_1 = sigma_mid - sigma_hat
762
+ dt_2 = sigmas[i + 1] - sigma_hat
763
+ x_2 = x + d * dt_1
764
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
765
+ if i == 0:
766
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale * 0.15, **extra_args)
767
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
768
+ denoised_2 = (denoised_2a + denoised_2c) / 2
769
+ elif i < len(sigmas) * 0.334:
770
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale * 0.25, **extra_args)
771
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale * 0.15, **extra_args)
772
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
773
+ denoised_2 = (denoised_2a + denoised_2b + denoised_2c) / 3
774
+ else:
775
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale * 0.03, True, **extra_args)
776
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
777
+ denoised_2 = (denoised_2b + denoised_2c) / 2
778
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
779
+ x = x + d_2 * dt_2
780
+ else:
781
+ dt = sigmas[i + 1] - sigma_hat
782
+ # Euler method
783
+ x = x + d * dt
784
+ return x
785
+
786
+ @torch.no_grad()
787
+ def sample_euler_smea_multi_b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
788
+ extra_args = {} if extra_args is None else extra_args
789
+ s_in = x.new_ones([x.shape[0]])
790
+ for i in trange(len(sigmas) - 1, disable=disable):
791
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
792
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
793
+ sigma_hat = sigmas[i] * (gamma + 1)
794
+ if gamma > 0:
795
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
796
+ denoised = model(x, sigma_hat * s_in, **extra_args)
797
+ d = to_d(x, sigma_hat, denoised)
798
+ if callback is not None:
799
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
800
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
801
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
802
+ dt_1 = sigma_mid - sigma_hat
803
+ dt_2 = sigmas[i + 1] - sigma_hat
804
+ x_2 = x + d * dt_1
805
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
806
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale * 0.25, **extra_args)
807
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale * 0.15, **extra_args)
808
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
809
+ denoised_2 = (denoised_2a + denoised_2b + denoised_2c) / 3
810
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
811
+ x = x + d_2 * dt_2
812
+ else:
813
+ dt = sigmas[i + 1] - sigma_hat
814
+ # Euler method
815
+ x = x + d * dt
816
+ return x
817
+
818
+ @torch.no_grad()
819
+ def sample_euler_smea_multi_c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
820
+ extra_args = {} if extra_args is None else extra_args
821
+ s_in = x.new_ones([x.shape[0]])
822
+ for i in trange(len(sigmas) - 1, disable=disable):
823
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
824
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
825
+ sigma_hat = sigmas[i] * (gamma + 1)
826
+ if gamma > 0:
827
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
828
+ denoised = model(x, sigma_hat * s_in, **extra_args)
829
+ d = to_d(x, sigma_hat, denoised)
830
+ if callback is not None:
831
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
832
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
833
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
834
+ dt_1 = sigma_mid - sigma_hat
835
+ dt_2 = sigmas[i + 1] - sigma_hat
836
+ x_2 = x + d * dt_1
837
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
838
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale * 0.25, **extra_args)
839
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
840
+ denoised_2 = (denoised_2a + denoised_2c) / 2
841
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
842
+ x = x + d_2 * dt_2
843
+ else:
844
+ dt = sigmas[i + 1] - sigma_hat
845
+ # Euler method
846
+ x = x + d * dt
847
+ return x
848
+
849
+ @torch.no_grad()
850
+ def sample_euler_smea_multi_a(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
851
+ extra_args = {} if extra_args is None else extra_args
852
+ s_in = x.new_ones([x.shape[0]])
853
+ for i in trange(len(sigmas) - 1, disable=disable):
854
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
855
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
856
+ sigma_hat = sigmas[i] * (gamma + 1)
857
+ if gamma > 0:
858
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
859
+ denoised = model(x, sigma_hat * s_in, **extra_args)
860
+ d = to_d(x, sigma_hat, denoised)
861
+ if callback is not None:
862
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
863
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
864
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
865
+ dt_1 = sigma_mid - sigma_hat
866
+ dt_2 = sigmas[i + 1] - sigma_hat
867
+ x_2 = x + d * dt_1
868
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
869
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale * 0.15, **extra_args)
870
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
871
+ denoised_2 = (denoised_2b + denoised_2c) / 2
872
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
873
+ x = x + d_2 * dt_2
874
+ else:
875
+ dt = sigmas[i + 1] - sigma_hat
876
+ # Euler method
877
+ x = x + d * dt
878
+ return x
879
+
880
+
881
+ @torch.no_grad()
882
+ def sample_euler_smea_multi_ds(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
883
+ extra_args = {} if extra_args is None else extra_args
884
+ s_in = x.new_ones([x.shape[0]])
885
+ for i in trange(len(sigmas) - 1, disable=disable):
886
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
887
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
888
+ sigma_hat = sigmas[i] * (gamma + 1)
889
+ if gamma > 0:
890
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
891
+ denoised = model(x, sigma_hat * s_in, **extra_args)
892
+ d = to_d(x, sigma_hat, denoised)
893
+ if callback is not None:
894
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
895
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167 + 1: # and i % 2 == 0:
896
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
897
+ dt_1 = sigma_mid - sigma_hat
898
+ dt_2 = sigmas[i + 1] - sigma_hat
899
+ x_2 = x + d * dt_1
900
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
901
+ if i == 0:
902
+ sa = 1 - scale * 0.15
903
+ sb = 1 + scale * 0.09
904
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, sa, **extra_args)
905
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, sb, **extra_args)
906
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.625 + denoised_2b * (sb ** 2) * 0.375) / (0.97**2)
907
+ elif i < len(sigmas) * 0.167:
908
+ sa = 1 - scale * 0.25
909
+ sb = 1 + scale * 0.15
910
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, sa, **extra_args)
911
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, sb , **extra_args)
912
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.625 + denoised_2b * (sb ** 2) * 0.375) / (0.95**2)
913
+ else:
914
+ sb = 1 + scale * 0.06
915
+ sc = 1 - scale * 0.1
916
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, sb, True, **extra_args)
917
+ denoised_2c = smea_sampling_step_denoised(x_2, model, sigma_mid, sc, **extra_args)
918
+ denoised_2 = (denoised_2b * (sb ** 2) * 0.375 + denoised_2c * (sc ** 2) * 0.625) / (0.98**2)
919
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
920
+ x = x + d_2 * dt_2
921
+ else:
922
+ dt = sigmas[i + 1] - sigma_hat
923
+ # Euler method
924
+ x = x + d * dt
925
+ return x
926
+
927
+ @torch.no_grad()
928
+ def sample_euler_smea_multi_ds2_s(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
929
+ sample = sample_euler_smea_multi_ds2(model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise, smooth=True)
930
+ return sample
931
+
932
+ @torch.no_grad()
933
+ def sample_euler_smea_multi_ds2_s_m(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
934
+ sample = sample_euler_smea_multi_ds2_m(model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise, smooth=True)
935
+ return sample
936
+
937
+ @torch.no_grad()
938
+ def sample_euler_smea_multi_ds2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., smooth=False):
939
+ extra_args = {} if extra_args is None else extra_args
940
+ s_in = x.new_ones([x.shape[0]])
941
+ for i in trange(len(sigmas) - 1, disable=disable):
942
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
943
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
944
+ sigma_hat = sigmas[i] * (gamma + 1)
945
+ if gamma > 0:
946
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
947
+ denoised = model(x, sigma_hat * s_in, **extra_args)
948
+ d = to_d(x, sigma_hat, denoised)
949
+ if callback is not None:
950
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
951
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167 + 1: # and i % 2 == 0:
952
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
953
+ dt_1 = sigma_mid - sigma_hat
954
+ dt_2 = sigmas[i + 1] - sigma_hat
955
+ x_2 = x + d * dt_1
956
+ scale = (sigmas[i] / sigmas[0]) ** 2
957
+ scale = scale.item()
958
+ if i == 0:
959
+ sa = 1 - scale * 0.15
960
+ sb = 1 + scale * 0.09
961
+ sigA = sigma_mid / (sa ** 2)
962
+ sigB = sigma_mid / (sb ** 2)
963
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
964
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
965
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.5 * sb ** 2 + denoised_2b * (sb ** 2) * 0.5 * sa ** 2) #/ (0.97**2) # 1 - (sa * sb ) / 2 + 1
966
+ d_2 = to_d(x_2, sigA * 0.5 * sb ** 2 + sigB * 0.5 * sa ** 2, denoised_2)
967
+ elif i < len(sigmas) * 0.167:
968
+ sa = 1 - scale * 0.25
969
+ sb = 1 + scale * 0.15
970
+ sigA = sigma_mid / (sa ** 2)
971
+ sigB = sigma_mid / (sb ** 2)
972
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
973
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
974
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.5 * sb ** 2 + denoised_2b * (sb ** 2) * 0.5 * sa ** 2) #/ (0.95**2)
975
+ d_2 = to_d(x_2, sigA * 0.5 * sb ** 2 + sigB * 0.5 * sa ** 2, denoised_2)
976
+ else:
977
+ sb = 1 + scale * 0.06
978
+ sc = 1 - scale * 0.1
979
+ sigB = sigma_mid / (sb ** 2)
980
+ sigC = sigma_mid / (sc ** 2)
981
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
982
+ denoised_2c = smea_sampling_step_denoised(x_2, model, sigC, sc, smooth, **extra_args)
983
+ denoised_2 = (denoised_2b * (sb ** 2) * 0.5 * sc ** 2 + denoised_2c * (sc ** 2) * 0.5 * sb ** 2) #/ (0.98**2)
984
+ d_2 = to_d(x_2, sigB * 0.5 * sc ** 2 + sigC * 0.5 * sb ** 2, denoised_2)
985
+ x = x + d_2 * dt_2
986
+ else:
987
+ dt = sigmas[i + 1] - sigma_hat
988
+ # Euler method
989
+ x = x + d * dt
990
+ return x
991
+
992
+ @torch.no_grad()
993
+ def sample_euler_smea_multi_ds2_m(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., smooth=False):
994
+ extra_args = {} if extra_args is None else extra_args
995
+ s_in = x.new_ones([x.shape[0]])
996
+ for i in trange(len(sigmas) - 1, disable=disable):
997
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
998
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
999
+ sigma_hat = sigmas[i] * (gamma + 1)
1000
+ if gamma > 0:
1001
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1002
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1003
+ d = to_d(x, sigma_hat, denoised)
1004
+ if callback is not None:
1005
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1006
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167 + 1: # and i % 2 == 0:
1007
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1008
+ dt_1 = sigma_mid - sigma_hat
1009
+ dt_2 = sigmas[i + 1] - sigma_hat
1010
+ x_2 = x + d * dt_1
1011
+ scale = (sigmas[i] / sigmas[0]) ** 2
1012
+ #scale = dt_1 ** 2 * 0.01
1013
+ scale = scale.item()
1014
+ if i == 0:
1015
+ sa = 1 - scale * 0.15 #15
1016
+ sb = 1 + scale * 0.09 #09
1017
+ sigA = sigma_mid / (sa ** 2)
1018
+ sigB = sigma_mid / (sb ** 2)
1019
+ #delta = sa * sb
1020
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
1021
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
1022
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.5 * sb ** 2 + denoised_2b * (sb ** 2) * 0.5 * sa ** 2) #/ (0.97**2) # 1 - (sa * sb ) / 2 + 1
1023
+ d_2 = to_d(x_2, sigA * 0.5 * sb ** 2 + sigB * 0.5 * sa ** 2, denoised_2)
1024
+ elif i < len(sigmas) * 0.167:
1025
+ sa = 1 - scale * 0.25 #25
1026
+ sb = 1 + scale * 0.15 #15
1027
+ sigA = sigma_mid / (sa ** 2)
1028
+ sigB = sigma_mid / (sb ** 2)
1029
+ #delta = sa * sb
1030
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
1031
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
1032
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.5 * sb ** 2 + denoised_2b * (sb ** 2) * 0.5 * sa ** 2) #/ (0.95**2)
1033
+ d_2 = to_d(x_2, sigA * 0.5 * sb ** 2 + sigB * 0.5 * sa ** 2, denoised_2)
1034
+ else:
1035
+ sb = 1 + scale * 0.06
1036
+ sc = 1 - scale * 0.1
1037
+ sigB = sigma_mid / (sb ** 2)
1038
+ sigC = sigma_mid / (sc ** 2)
1039
+ #delta = sb * sc
1040
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
1041
+ denoised_2c = smea_sampling_step_denoised(x_2, model, sigC, sc, smooth, **extra_args)
1042
+ denoised_2 = (denoised_2b * (sb ** 2) * 0.5 * sc ** 2+ denoised_2c * (sc ** 2) * 0.5 * sb ** 2) #/ (0.98**2)
1043
+ d_2 = to_d(x_2, sigB * 0.5 * sc ** 2 + sigC * 0.5 * sb ** 2, denoised_2)
1044
+ x = x + (math.cos(1.05 * i + 1.1)/(1.25 * i + 1.5) + 1) * d_2 * dt_2
1045
+ else:
1046
+ dt = sigmas[i + 1] - sigma_hat
1047
+ # Euler method
1048
+ x = x + (math.cos(1.05 * i + 1.1)/(1.25 * i + 1.5) + 1) * d * dt
1049
+ return x
1050
+
1051
+ @torch.no_grad()
1052
+ def sample_euler_h_m(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1053
+ extra_args = {} if extra_args is None else extra_args
1054
+ s_in = x.new_ones([x.shape[0]])
1055
+ for i in trange(len(sigmas) - 1, disable=disable):
1056
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1057
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1058
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1059
+ gamma = min((2 ** 0.5 - 1) - wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1060
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler == None else noise_sampler
1061
+ sigma_hat = sigmas[i] * (gamma + 1)
1062
+ if gamma > 0:
1063
+ x = x - eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1064
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1065
+ d = to_d(x, sigma_hat, denoised)
1066
+ dt = sigmas[i + 1] - sigma_hat
1067
+ if callback is not None:
1068
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1069
+ if sigmas[i + 1] > 0:
1070
+ x_2 = x + (gamma + 1) * d * dt
1071
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1072
+ d_prime = d * 0.5 + d_2 * 0.5
1073
+ x = x + d_prime * dt
1074
+ else:
1075
+ # Euler method
1076
+ x = x + (gamma + 1) * d * dt
1077
+ return x
1078
+
1079
+ @torch.no_grad()
1080
+ def sample_euler_h_m_b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1081
+ extra_args = {} if extra_args is None else extra_args
1082
+ s_in = x.new_ones([x.shape[0]])
1083
+ for i in trange(len(sigmas) - 1, disable=disable):
1084
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1085
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1086
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1087
+ gamma = min(wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1088
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1089
+ sigma_hat = sigmas[i] * (gamma + 1)
1090
+ if gamma > 0:
1091
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1092
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1093
+ d = to_d(x, sigma_hat, denoised)
1094
+ dt = sigmas[i + 1] - sigma_hat
1095
+ if callback is not None:
1096
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1097
+ if sigmas[i + 1] > 0:
1098
+ x_2 = x + (gamma + 1) * d * dt
1099
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1100
+ d_prime = d * 0.5 + d_2 * 0.5
1101
+ x = x + d_prime * dt
1102
+ else:
1103
+ # Euler method
1104
+ x = x + (gamma + 1) * d * dt
1105
+ return x
1106
+
1107
+ @torch.no_grad()
1108
+ def sample_euler_h_m_c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1109
+ extra_args = {} if extra_args is None else extra_args
1110
+ s_in = x.new_ones([x.shape[0]])
1111
+ for i in trange(len(sigmas) - 1, disable=disable):
1112
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1113
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1114
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1115
+ gamma = max((2 ** 0.5 - 1) + wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1116
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1117
+ sigma_hat = sigmas[i] * (gamma + 1)
1118
+ if gamma > 0:
1119
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1120
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1121
+ d = to_d(x, sigma_hat, denoised)
1122
+ dt = sigmas[i + 1] - sigma_hat
1123
+ if callback is not None:
1124
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1125
+ if sigmas[i + 1] > 0:
1126
+ x_2 = x + (gamma + 1) * d * dt
1127
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1128
+ d_prime = d * 0.5 + d_2 * 0.5
1129
+ x = x + d_prime * dt
1130
+ else:
1131
+ # Euler method
1132
+ x = x + (gamma + 1) * d * dt
1133
+ return x
1134
+
1135
+ @torch.no_grad()
1136
+ def sample_euler_h_m_d(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1137
+ extra_args = {} if extra_args is None else extra_args
1138
+ s_in = x.new_ones([x.shape[0]])
1139
+ for i in trange(len(sigmas) - 1, disable=disable):
1140
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1141
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1142
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1143
+ gamma = min((2 ** 0.5 - 1) - wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1144
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1145
+ sigma_hat = sigmas[i] * (gamma + 1)
1146
+ if gamma > 0:
1147
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1148
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1149
+ d = to_d(x, sigma_hat, denoised)
1150
+ dt = sigmas[i + 1] - sigma_hat
1151
+ if callback is not None:
1152
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1153
+ if sigmas[i + 1] > 0:
1154
+ x_2 = x + (gamma + 1) * d * dt
1155
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1156
+ d_prime = d * 0.5 + d_2 * 0.5
1157
+ x = x + d_prime * dt
1158
+ else:
1159
+ # Euler method
1160
+ x = x + (gamma + 1) * d * dt
1161
+ return x
1162
+
1163
+ @torch.no_grad()
1164
+ def sample_euler_h_m_e(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1165
+ extra_args = {} if extra_args is None else extra_args
1166
+ s_in = x.new_ones([x.shape[0]])
1167
+ for i in trange(len(sigmas) - 1, disable=disable):
1168
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1169
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1170
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1171
+ gamma = max((2 ** 0.5 - 1) + wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1172
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1173
+ sigma_hat = sigmas[i] * (gamma + 1)
1174
+ if gamma > 0:
1175
+ x = x - eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1176
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1177
+ d = to_d(x, sigma_hat, denoised)
1178
+ dt = sigmas[i + 1] - sigma_hat
1179
+ if callback is not None:
1180
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1181
+ if sigmas[i + 1] > 0:
1182
+ x_2 = x + (gamma + 1) * d * dt
1183
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1184
+ d_prime = d * 0.5 + d_2 * 0.5
1185
+ x = x + d_prime * dt
1186
+ else:
1187
+ # Euler method
1188
+ x = x + (gamma + 1) * d * dt
1189
+ return x
1190
+
1191
+ @torch.no_grad()
1192
+ def sample_euler_h_m_f(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1193
+ extra_args = {} if extra_args is None else extra_args
1194
+ s_in = x.new_ones([x.shape[0]])
1195
+ for i in trange(len(sigmas) - 1, disable=disable):
1196
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1197
+ wave_max = math.cos(0)/1.5 + 1
1198
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1199
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1200
+ gamma = min((wave_max - wave) * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1201
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1202
+ sigma_hat = sigmas[i] * (gamma + 1)
1203
+ if gamma > 0:
1204
+ x = x - eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1205
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1206
+ d = to_d(x, sigma_hat, denoised)
1207
+ dt = sigmas[i + 1] - sigma_hat
1208
+ if callback is not None:
1209
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1210
+ if sigmas[i + 1] > 0:
1211
+ x_2 = x + (gamma + 1) * d * dt
1212
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1213
+ d_prime = d * 0.5 + d_2 * 0.5
1214
+ x = x + d_prime * dt
1215
+ else:
1216
+ # Euler method
1217
+ x = x + (gamma + 1) * d * dt
1218
+ return x
1219
+
1220
+ @torch.no_grad()
1221
+ def sample_euler_h_m_g(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1222
+ extra_args = {} if extra_args is None else extra_args
1223
+ s_in = x.new_ones([x.shape[0]])
1224
+ for i in trange(len(sigmas) - 1, disable=disable):
1225
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1226
+ wave_max = math.cos(0)/1.5 + 1
1227
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1228
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1229
+ gamma = min((wave_max - wave) * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1230
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1231
+ sigma_hat = sigmas[i] * (gamma + 1)
1232
+ if gamma > 0:
1233
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1234
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1235
+ d = to_d(x, sigma_hat, denoised)
1236
+ dt = sigmas[i + 1] - sigma_hat
1237
+ if callback is not None:
1238
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1239
+ if sigmas[i + 1] > 0:
1240
+ x_2 = x + (gamma + 1) * d * dt
1241
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1242
+ d_prime = d * 0.5 + d_2 * 0.5
1243
+ x = x + d_prime * dt
1244
+ else:
1245
+ # Euler method
1246
+ x = x + (gamma + 1) * d * dt
1247
+ return x
1248
+
1249
+ @torch.no_grad()
1250
+ def sample_euler_h_m_b_c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1251
+ extra_args = {} if extra_args is None else extra_args
1252
+ s_in = x.new_ones([x.shape[0]])
1253
+ for i in trange(len(sigmas) - 1, disable=disable):
1254
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1255
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1256
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1257
+ gamma = min(wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1258
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1259
+ gammaup = gamma + 1
1260
+ sigma_hat = sigmas[i] * gammaup
1261
+ if gamma > 0:
1262
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1263
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1264
+ last_noise_uncond = model.last_noise_uncond
1265
+ d = to_d(x, sigma_hat, denoised)
1266
+ dt = sigmas[i + 1] - sigma_hat
1267
+ if callback is not None:
1268
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1269
+ if i == 0:
1270
+ x = x + gammaup * d * dt
1271
+ elif i <= len(sigmas) - 4:
1272
+ x_2 = x + gammaup * d * dt
1273
+ d_2 = to_d(x_2, sigmas[i + 1] * gammaup, denoised)
1274
+ x_3 = x_2 + gammaup * d_2 * dt
1275
+ d_3 = to_d(x_3, sigmas[i + 2] * gammaup, denoised)
1276
+ d_prime = d * 0.5 + d_2 * 0.375 + d_3 * 0.125
1277
+ x = x + d_prime * dt
1278
+ elif sigmas[i + 1] > 0:
1279
+ x_2 = x + gammaup * d * dt
1280
+ d_2 = to_d(x_2, sigmas[i + 1] * gammaup, denoised)
1281
+ d_prime = d * 0.5 + d_2 * 0.5
1282
+ x = x + d_prime * dt
1283
+ else:
1284
+ # Euler method
1285
+ x = x + gammaup * d * dt
1286
+ return x
1287
+
1288
+ @torch.no_grad()
1289
+ def sample_euler_h_m_b_c_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1290
+ extra_args = {} if extra_args is None else extra_args
1291
+ s_in = x.new_ones([x.shape[0]])
1292
+ for i in trange(len(sigmas) - 1, disable=disable):
1293
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1294
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1295
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1296
+ gamma = min(wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1297
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1298
+ gammaup = gamma + 1
1299
+ sigma_hat = sigmas[i] * gammaup
1300
+ if gamma > 0:
1301
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1302
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1303
+ last_noise_uncond = model.last_noise_uncond
1304
+ d = to_d(x, sigma_hat, denoised)
1305
+ dt = sigmas[i + 1] - sigma_hat
1306
+ if callback is not None:
1307
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1308
+ if i == 0:
1309
+ x = x + gammaup * d * dt
1310
+ elif i <= len(sigmas) - 4:
1311
+ x_2 = x + gammaup * d * dt
1312
+ d_2 = to_d(x_2, sigmas[i + 1] * gammaup, denoised)
1313
+ x_3 = x_2 + gammaup * d_2 * dt
1314
+ d_3 = to_d(x_3, sigmas[i + 2] * gammaup, last_noise_uncond)
1315
+ d_prime = d * 0.5 + d_2 * 0.375 + d_3 * 0.125
1316
+ x = x + d_prime * dt
1317
+ elif sigmas[i + 1] > 0:
1318
+ x_2 = x + gammaup * d * dt
1319
+ d_2 = to_d(x_2, sigmas[i + 1] * gammaup, denoised)
1320
+ d_prime = d * 0.5 + d_2 * 0.5
1321
+ x = x + d_prime * dt
1322
+ else:
1323
+ # Euler method
1324
+ x = x + gammaup * d * dt
1325
+ return x
1326
+
1327
+ @torch.no_grad()
1328
+ def sample_euler_smea_max(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., smooth=False):
1329
+ extra_args = {} if extra_args is None else extra_args
1330
+ s_in = x.new_ones([x.shape[0]])
1331
+ for i in trange(len(sigmas) - 1, disable=disable):
1332
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1333
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
1334
+ sigma_hat = sigmas[i] * (gamma + 1)
1335
+ sa = math.cos(i + 1)/(1.5 * i + 1.75) + 1
1336
+ if gamma > 0:
1337
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1338
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1339
+ d = to_d(x, sigma_hat, denoised)
1340
+ if callback is not None:
1341
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1342
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167 + 1: # and i % 2 == 0:
1343
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1344
+ dt_1 = sigma_mid - sigma_hat
1345
+ dt_2 = sigmas[i + 1] - sigma_hat
1346
+ x_2 = x + d * dt_1
1347
+ sigA = sigma_mid / (sa ** 2)
1348
+ sigB = sigma_mid
1349
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
1350
+ denoised_2b = model(x_2, sigma_mid * s_in, **extra_args)
1351
+ denoised_2 = (denoised_2a * 0.5 * (sa ** 2) + denoised_2b * 0.5 / (sa ** 2))
1352
+ d_2 = to_d(x_2, sigA * 0.5 * (sa ** 2) + sigB * 0.5 / (sa ** 2), denoised_2)
1353
+ x = x + d_2 * dt_2
1354
+ else:
1355
+ dt = sigmas[i + 1] - sigma_hat
1356
+ # Euler method
1357
+ x = x + sa * d * dt
1358
+ return x
1359
+
1360
+
1361
+ @torch.no_grad()
1362
+ def sample_euler_smea_max_s(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1363
+ sample = sample_euler_smea_max(model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise, smooth=True)
1364
+ return sample
1365
+
1366
+ @torch.no_grad()
1367
+ def sample_euler_smea_multi_bs(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1368
+ extra_args = {} if extra_args is None else extra_args
1369
+ s_in = x.new_ones([x.shape[0]])
1370
+ for i in trange(len(sigmas) - 1, disable=disable):
1371
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1372
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
1373
+ sigma_hat = sigmas[i] * (gamma + 1)
1374
+ if gamma > 0:
1375
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1376
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1377
+ d = to_d(x, sigma_hat, denoised)
1378
+ if callback is not None:
1379
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1380
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
1381
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1382
+ dt_1 = sigma_mid - sigma_hat
1383
+ dt_2 = sigmas[i + 1] - sigma_hat
1384
+ x_2 = x + d * dt_1
1385
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
1386
+ sa = 1 - scale * 0.25
1387
+ sb = 1 + scale * 0.15
1388
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, sa, **extra_args)
1389
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, sb, **extra_args)
1390
+ denoised_2 = denoised_2a * (sa ** 2) * 0.625 + denoised_2b * (sb ** 2) * 0.375 / (0.95**2)
1391
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
1392
+ x = x + d_2 * dt_2
1393
+ else:
1394
+ dt = sigmas[i + 1] - sigma_hat
1395
+ # Euler method
1396
+ x = x + d * dt
1397
+ return x
1398
+
1399
+ @torch.no_grad()
1400
+ def sample_euler_smea_multi_bs2_s(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1401
+ sample = sample_euler_smea_multi_bs2(model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise, smooth=True)
1402
+ return sample
1403
+
1404
+ @torch.no_grad()
1405
+ def sample_euler_smea_multi_bs2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., smooth=False):
1406
+ extra_args = {} if extra_args is None else extra_args
1407
+ s_in = x.new_ones([x.shape[0]])
1408
+ for i in trange(len(sigmas) - 1, disable=disable):
1409
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1410
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
1411
+ sigma_hat = sigmas[i] * (gamma + 1)
1412
+ if gamma > 0:
1413
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1414
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1415
+ d = to_d(x, sigma_hat, denoised)
1416
+ if callback is not None:
1417
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1418
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
1419
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1420
+ dt_1 = sigma_mid - sigma_hat
1421
+ dt_2 = sigmas[i + 1] - sigma_hat
1422
+ x_2 = x + d * dt_1
1423
+ scale = (sigmas[i] / sigmas[0]) ** 2
1424
+ scale = scale.item()
1425
+ sa = 1 - scale * 0.25
1426
+ sb = 1 + scale * 0.15
1427
+ sigA = sigma_mid / (sa ** 2)
1428
+ sigB = sigma_mid / (sb ** 2)
1429
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
1430
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
1431
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.5 * sb ** 2 + denoised_2b * (sb ** 2) * 0.5 * sa ** 2)
1432
+ d_2 = to_d(x_2, sigA * 0.5 * sb ** 2 + sigB * 0.5 * sa ** 2, denoised_2)
1433
+ x = x + d_2 * dt_2
1434
+ else:
1435
+ dt = sigmas[i + 1] - sigma_hat
1436
+ # Euler method
1437
+ x = x + d * dt
1438
+ return x
1439
+
1440
+ @torch.no_grad()
1441
+ def sample_euler_smea_multi_cs(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1442
+ extra_args = {} if extra_args is None else extra_args
1443
+ s_in = x.new_ones([x.shape[0]])
1444
+ for i in trange(len(sigmas) - 1, disable=disable):
1445
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1446
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
1447
+ sigma_hat = sigmas[i] * (gamma + 1)
1448
+ if gamma > 0:
1449
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1450
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1451
+ d = to_d(x, sigma_hat, denoised)
1452
+ if callback is not None:
1453
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1454
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
1455
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1456
+ dt_1 = sigma_mid - sigma_hat
1457
+ dt_2 = sigmas[i + 1] - sigma_hat
1458
+ x_2 = x + d * dt_1
1459
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
1460
+ sa = 1 - scale * 0.25
1461
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, sa, **extra_args)
1462
+ d_2 = to_d(x_2, sigma_mid, denoised_2 * (sa ** 2) * 1.25)
1463
+ x = x + d_2 * dt_2
1464
+ else:
1465
+ dt = sigmas[i + 1] - sigma_hat
1466
+ # Euler method
1467
+ x = x + d * dt
1468
+ return x
1469
+
1470
+ @torch.no_grad()
1471
+ def sample_euler_smea_multi_as(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1472
+ extra_args = {} if extra_args is None else extra_args
1473
+ s_in = x.new_ones([x.shape[0]])
1474
+ for i in trange(len(sigmas) - 1, disable=disable):
1475
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1476
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
1477
+ sigma_hat = sigmas[i] * (gamma + 1)
1478
+ if gamma > 0:
1479
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1480
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1481
+ d = to_d(x, sigma_hat, denoised)
1482
+ if callback is not None:
1483
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1484
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
1485
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1486
+ dt_1 = sigma_mid - sigma_hat
1487
+ dt_2 = sigmas[i + 1] - sigma_hat
1488
+ x_2 = x + d * dt_1
1489
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
1490
+ sa = 1 + scale * 0.15
1491
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, sa, **extra_args)
1492
+ d_2 = to_d(x_2, sigma_mid, denoised_2 * (sa ** 2) * 0.75)
1493
+ x = x + d_2 * dt_2
1494
+ else:
1495
+ dt = sigmas[i + 1] - sigma_hat
1496
+ # Euler method
1497
+ x = x + d * dt
1498
+ return x
1499
+
1500
+ ## og sampler
1501
+ @torch.no_grad()
1502
+ def sample_euler_dy_og(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1503
+ extra_args = {} if extra_args is None else extra_args
1504
+ s_in = x.new_ones([x.shape[0]])
1505
+ for i in trange(len(sigmas) - 1, disable=disable):
1506
+ # print(i)
1507
+ # i绗竴姝ヤ负0
1508
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1509
+ eps = torch.randn_like(x) * s_noise
1510
+ sigma_hat = sigmas[i] * (gamma + 1)
1511
+ # print(sigma_hat)
1512
+ dt = sigmas[i + 1] - sigma_hat
1513
+ if gamma > 0:
1514
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1515
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1516
+ d = sampling.to_d(x, sigma_hat, denoised)
1517
+ if sigmas[i + 1] > 0:
1518
+ if i // 2 == 1:
1519
+ x = dy_sampling_step(x, model, dt, sigma_hat, **extra_args)
1520
+ if callback is not None:
1521
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1522
+ # Euler method
1523
+ x = x + d * dt
1524
+ return x
1525
+
1526
+ @torch.no_grad()
1527
+ def sample_euler_smea_dy_og(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1528
+ extra_args = {} if extra_args is None else extra_args
1529
+ s_in = x.new_ones([x.shape[0]])
1530
+ for i in trange(len(sigmas) - 1, disable=disable):
1531
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1532
+ eps = torch.randn_like(x) * s_noise
1533
+ sigma_hat = sigmas[i] * (gamma + 1)
1534
+ dt = sigmas[i + 1] - sigma_hat
1535
+ if gamma > 0:
1536
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1537
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1538
+ d = sampling.to_d(x, sigma_hat, denoised)
1539
+ # Euler method
1540
+ x = x + d * dt
1541
+ if sigmas[i + 1] > 0:
1542
+ if i + 1 // 2 == 1:
1543
+ x = dy_sampling_step(x, model, dt, sigma_hat, **extra_args)
1544
+ if i + 1 // 2 == 0:
1545
+ x = smea_sampling_step(x, model, dt, sigma_hat, **extra_args)
1546
+ if callback is not None:
1547
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1548
+ return x
1549
+
1550
+ ## TCD
1551
+
1552
+ def sample_tcd_euler_a(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, gamma=0.3):
1553
+ # TCD sampling using modified Euler Ancestral sampler. by @laksjdjf
1554
+ extra_args = {} if extra_args is None else extra_args
1555
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
1556
+ s_in = x.new_ones([x.shape[0]])
1557
+ for i in trange(len(sigmas) - 1, disable=disable):
1558
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
1559
+ if callback is not None:
1560
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
1561
+
1562
+ #d = to_d(x, sigmas[i], denoised)
1563
+ sigma_from = sigmas[i]
1564
+ sigma_to = sigmas[i + 1]
1565
+
1566
+ t = model.inner_model.sigma_to_t(sigma_from)
1567
+ down_t = (1 - gamma) * t
1568
+ sigma_down = model.inner_model.t_to_sigma(down_t)
1569
+
1570
+ if sigma_down > sigma_to:
1571
+ sigma_down = sigma_to
1572
+ sigma_up = (sigma_to ** 2 - sigma_down ** 2) ** 0.5
1573
+
1574
+ # same as euler ancestral
1575
+ d = to_d(x, sigma_from, denoised)
1576
+ dt = sigma_down - sigma_from
1577
+ x += d * dt
1578
+
1579
+ if sigma_to > 0 and gamma > 0:
1580
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigma_up
1581
+ return x
1582
+
1583
+ @torch.no_grad()
1584
+ def sample_tcd(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, gamma=0.3):
1585
+ # TCD sampling using modified DDPM.
1586
+ extra_args = {} if extra_args is None else extra_args
1587
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
1588
+ s_in = x.new_ones([x.shape[0]])
1589
+
1590
+ for i in trange(len(sigmas) - 1, disable=disable):
1591
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
1592
+ if callback is not None:
1593
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
1594
+
1595
+ sigma_from, sigma_to = sigmas[i], sigmas[i+1]
1596
+
1597
+ # TCD offset, based on gamma, and conversion between sigma and timestep
1598
+ t = model.inner_model.sigma_to_t(sigma_from)
1599
+ t_s = (1 - gamma) * t
1600
+ sigma_to_s = model.inner_model.t_to_sigma(t_s)
1601
+
1602
+ # if sigma_to_s > sigma_to:
1603
+ # sigma_to_s = sigma_to
1604
+ # if sigma_to_s < 0:
1605
+ # sigma_to_s = torch.tensor(1.0)
1606
+ #print(f"sigma_from: {sigma_from}, sigma_to: {sigma_to}, sigma_to_s: {sigma_to_s}")
1607
+
1608
+
1609
+ # The following is equivalent to the comfy DDPM implementation
1610
+ # x = DDPMSampler_step(x / torch.sqrt(1.0 + sigma_from ** 2.0), sigma_from, sigma_to, (x - denoised) / sigma_from, noise_sampler)
1611
+
1612
+ noise_est = (x - denoised) / sigma_from
1613
+ x /= torch.sqrt(1.0 + sigma_from ** 2.0)
1614
+
1615
+ alpha_cumprod = 1 / ((sigma_from * sigma_from) + 1) # _t
1616
+ alpha_cumprod_prev = 1 / ((sigma_to * sigma_to) + 1) # _t_prev
1617
+ alpha = (alpha_cumprod / alpha_cumprod_prev)
1618
+
1619
+ ## These values should approach 1.0?
1620
+ # print(f"alpha_cumprod: {alpha_cumprod}")
1621
+ # print(f"alpha_cumprod_prev: {alpha_cumprod_prev}")
1622
+ # print(f"alpha: {alpha}")
1623
+
1624
+
1625
+ # alpha_cumprod_down = 1 / ((sigma_to_s * sigma_to_s) + 1) # _s
1626
+ # alpha_d = (alpha_cumprod_prev / alpha_cumprod_down)
1627
+ # alpha2 = (alpha_cumprod / alpha_cumprod_down)
1628
+ # print(f"** alpha_cumprod_down: {alpha_cumprod_down}")
1629
+ # print(f"** alpha_d: {alpha_d}, alpha2: #{alpha2}")
1630
+
1631
+ # epsilon noise prediction from comfy DDPM implementation
1632
+ x = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise_est / (1 - alpha_cumprod).sqrt())
1633
+ # x = (1.0 / alpha_d).sqrt() * (x - (1 - alpha) * noise_est / (1 - alpha_cumprod).sqrt())
1634
+
1635
+ first_step = sigma_to == 0
1636
+ last_step = i == len(sigmas) - 2
1637
+
1638
+ if not first_step:
1639
+ if gamma > 0 and not last_step:
1640
+ noise = noise_sampler(sigma_from, sigma_to)
1641
+
1642
+ # x += ((1 - alpha_d) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise
1643
+ variance = ((1 - alpha_cumprod_prev) / (1 - alpha_cumprod)) * (1 - alpha_cumprod / alpha_cumprod_prev)
1644
+ x += variance.sqrt() * noise # scale noise by std deviation
1645
+
1646
+ # relevant diffusers code from scheduling_tcd.py
1647
+ # prev_sample = (alpha_prod_t_prev / alpha_prod_s).sqrt() * pred_noised_sample + (
1648
+ # 1 - alpha_prod_t_prev / alpha_prod_s
1649
+ # ).sqrt() * noise
1650
+
1651
+ x *= torch.sqrt(1.0 + sigma_to ** 2.0)
1652
+
1653
+ # beta_cumprod_t = 1 - alpha_cumprod
1654
+ # beta_cumprod_s = 1 - alpha_cumprod_down
1655
+
1656
+
1657
+ return x
sd-webui-smea/sd-webui-smea-chanhe.py ADDED
@@ -0,0 +1,1657 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ import k_diffusion.sampling
4
+
5
+ from k_diffusion.sampling import to_d, BrownianTreeNoiseSampler
6
+ from tqdm.auto import trange
7
+ from modules import scripts
8
+ from modules import sd_samplers_kdiffusion, sd_samplers_common, sd_samplers
9
+ from modules.sd_samplers_kdiffusion import KDiffusionSampler
10
+
11
+ class _Rescaler:
12
+ def __init__(self, model, x, mode, **extra_args):
13
+ self.model = model
14
+ self.x = x
15
+ self.mode = mode
16
+ self.extra_args = extra_args
17
+ self.init_latent, self.mask, self.nmask = model.init_latent, model.mask, model.nmask
18
+
19
+ def __enter__(self):
20
+ if self.init_latent is not None:
21
+ self.model.init_latent = torch.nn.functional.interpolate(input=self.init_latent, size=self.x.shape[2:4], mode=self.mode)
22
+ if self.mask is not None:
23
+ self.model.mask = torch.nn.functional.interpolate(input=self.mask.unsqueeze(0), size=self.x.shape[2:4], mode=self.mode).squeeze(0)
24
+ if self.nmask is not None:
25
+ self.model.nmask = torch.nn.functional.interpolate(input=self.nmask.unsqueeze(0), size=self.x.shape[2:4], mode=self.mode).squeeze(0)
26
+ return self
27
+
28
+ def __exit__(self, type, value, traceback):
29
+ del self.model.init_latent, self.model.mask, self.model.nmask
30
+ self.model.init_latent, self.model.mask, self.model.nmask = self.init_latent, self.mask, self.nmask
31
+
32
+ class Smea(scripts.Script):
33
+
34
+ def title(self):
35
+ return "Euler Smea Dy sampler"
36
+
37
+ def show(self, is_img2img):
38
+ return scripts.AlwaysVisible
39
+
40
+ def __init__(self):
41
+ init()
42
+ return
43
+
44
+ def init():
45
+ for i in sd_samplers.all_samplers:
46
+ if "Euler Max" in i.name:
47
+ return
48
+
49
+ samplers_smea = [
50
+ ('Euler Max', sample_euler_max, ['k_euler'], {}),
51
+ ('Euler Max1b', sample_euler_max1b, ['k_euler'], {}),
52
+ ('Euler Max1c', sample_euler_max1c, ['k_euler'], {}),
53
+ ('Euler Max1d', sample_euler_max1d, ['k_euler'], {}),
54
+ ('Euler Max2', sample_euler_max2, ['k_euler'], {}),
55
+ ('Euler Max2b', sample_euler_max2b, ['k_euler'], {}),
56
+ ('Euler Max2c', sample_euler_max2c, ['k_euler'], {}),
57
+ ('Euler Max2d', sample_euler_max2d, ['k_euler'], {}),
58
+ ('Euler Max3', sample_euler_max3, ['k_euler'], {}),
59
+ ('Euler Max3b', sample_euler_max3b, ['k_euler'], {}),
60
+ ('Euler Max3c', sample_euler_max3c, ['k_euler'], {}),
61
+ ('Euler Max4', sample_euler_max4, ['k_euler'], {}),
62
+ ('Euler Max4b', sample_euler_max4b, ['k_euler'], {}),
63
+ ('Euler Max4c', sample_euler_max4c, ['k_euler'], {}),
64
+ ('Euler Max4d', sample_euler_max4d, ['k_euler'], {}),
65
+ ('Euler Max4e', sample_euler_max4e, ['k_euler'], {}),
66
+ ('Euler Max4f', sample_euler_max4f, ['k_euler'], {}),
67
+ ('Euler Dy', sample_euler_dy, ['k_euler'], {}),
68
+ ('Euler Smea', sample_euler_smea, ['k_euler'], {}),
69
+ ('Euler Smea Dy', sample_euler_smea_dy, ['k_euler'], {}),
70
+ ('Euler Smea Max', sample_euler_smea_max, ['k_euler'], {}),
71
+ ('Euler Smea Max s', sample_euler_smea_max_s, ['k_euler'], {}),
72
+ ('Euler Smea dyn a', sample_euler_smea_dyn_a, ['k_euler'], {}),
73
+ ('Euler Smea dyn b', sample_euler_smea_dyn_b, ['k_euler'], {}),
74
+ ('Euler Smea dyn c', sample_euler_smea_dyn_c, ['k_euler'], {}),
75
+ ('Euler Smea ma', sample_euler_smea_multi_a, ['k_euler'], {}),
76
+ ('Euler Smea mb', sample_euler_smea_multi_b, ['k_euler'], {}),
77
+ ('Euler Smea mc', sample_euler_smea_multi_c, ['k_euler'], {}),
78
+ ('Euler Smea md', sample_euler_smea_multi_d, ['k_euler'], {}),
79
+ ('Euler Smea mas', sample_euler_smea_multi_as, ['k_euler'], {}),
80
+ ('Euler Smea mbs', sample_euler_smea_multi_bs, ['k_euler'], {}),
81
+ ('Euler Smea mcs', sample_euler_smea_multi_cs, ['k_euler'], {}),
82
+ ('Euler Smea mds', sample_euler_smea_multi_ds, ['k_euler'], {}),
83
+ ('Euler Smea mbs2', sample_euler_smea_multi_bs2, ['k_euler'], {}),
84
+ ('Euler Smea mds2', sample_euler_smea_multi_ds2, ['k_euler'], {}),
85
+ ('Euler Smea mds2 max', sample_euler_smea_multi_ds2_m, ['k_euler'], {}),
86
+ ('Euler Smea mds2 s max', sample_euler_smea_multi_ds2_s_m, ['k_euler'], {}),
87
+ ('Euler Smea mbs2 s', sample_euler_smea_multi_bs2_s, ['k_euler'], {}),
88
+ ('Euler Smea mds2 s', sample_euler_smea_multi_ds2_s, ['k_euler'], {}),
89
+ ('Euler h max', sample_euler_h_m, ['k_euler'], {"brownian_noise": True}),
90
+ ('Euler h max b', sample_euler_h_m_b, ['k_euler'], {"brownian_noise": True}),
91
+ ('Euler h max c', sample_euler_h_m_c, ['k_euler'], {"brownian_noise": True}),
92
+ ('Euler h max d', sample_euler_h_m_d, ['k_euler'], {"brownian_noise": True}),
93
+ ('Euler h max e', sample_euler_h_m_e, ['k_euler'], {"brownian_noise": True}),
94
+ ('Euler h max f', sample_euler_h_m_f, ['k_euler'], {"brownian_noise": True}),
95
+ ('Euler h max g', sample_euler_h_m_g, ['k_euler'], {"brownian_noise": True}),
96
+ ('Euler h max b c', sample_euler_h_m_b_c, ['k_euler'], {"brownian_noise": True}),
97
+ ('Euler h max b c CFG++', sample_euler_h_m_b_c_pp, ['k_euler'], {"brownian_noise": True, "cfgpp": True}),
98
+ ('Euler Dy koishi-star', sample_euler_dy_og, ['k_euler'], {}),
99
+ ('Euler Smea Dy koishi-star', sample_euler_smea_dy_og, ['k_euler'], {}),
100
+ ('TCD Euler a', sample_tcd_euler_a, ['tcd_euler_a'], {}),
101
+ ('TCD', sample_tcd, ['tcd'], {}),
102
+ ]
103
+
104
+ samplers_data_smea = [
105
+ sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
106
+ for label, funcname, aliases, options in samplers_smea
107
+ if callable(funcname)
108
+ ]
109
+
110
+ sampler_exparams_smea = {
111
+ sample_euler_max: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
112
+ sample_euler_max1b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
113
+ sample_euler_max1c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
114
+ sample_euler_max1d: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
115
+ sample_euler_max2: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
116
+ sample_euler_max2b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
117
+ sample_euler_max2c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
118
+ sample_euler_max2d: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
119
+ sample_euler_max3: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
120
+ sample_euler_max3b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
121
+ sample_euler_max3c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
122
+ sample_euler_max4: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
123
+ sample_euler_max4b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
124
+ sample_euler_max4c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
125
+ sample_euler_max4d: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
126
+ sample_euler_max4e: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
127
+ sample_euler_max4f: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
128
+ sample_euler_dy: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
129
+ sample_euler_smea: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
130
+ sample_euler_smea_dy: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
131
+ sample_euler_smea_max: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
132
+ sample_euler_smea_max_s: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
133
+ sample_euler_smea_dyn_a: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
134
+ sample_euler_smea_dyn_b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
135
+ sample_euler_smea_dyn_c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
136
+ sample_euler_smea_multi_a: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
137
+ sample_euler_smea_multi_b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
138
+ sample_euler_smea_multi_c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
139
+ sample_euler_smea_multi_d: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
140
+ sample_euler_smea_multi_as: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
141
+ sample_euler_smea_multi_bs: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
142
+ sample_euler_smea_multi_cs: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
143
+ sample_euler_smea_multi_ds: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
144
+ sample_euler_smea_multi_bs2: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
145
+ sample_euler_smea_multi_ds2: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
146
+ sample_euler_smea_multi_ds2_m: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
147
+ sample_euler_smea_multi_ds2_s_m: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
148
+ sample_euler_smea_multi_bs2_s: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
149
+ sample_euler_smea_multi_ds2_s: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
150
+ sample_euler_h_m: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
151
+ sample_euler_h_m_b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
152
+ sample_euler_h_m_c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
153
+ sample_euler_h_m_d: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
154
+ sample_euler_h_m_e: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
155
+ sample_euler_h_m_f: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
156
+ sample_euler_h_m_g: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
157
+ sample_euler_h_m_b_c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
158
+ sample_euler_h_m_b_c_pp: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
159
+ sample_euler_dy_og: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
160
+ sample_euler_smea_dy_og: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
161
+ }
162
+ sd_samplers_kdiffusion.sampler_extra_params = {**sd_samplers_kdiffusion.sampler_extra_params, **sampler_exparams_smea}
163
+
164
+ samplers_map_smea = {x.name: x for x in samplers_data_smea}
165
+ sd_samplers_kdiffusion.k_diffusion_samplers_map = {**sd_samplers_kdiffusion.k_diffusion_samplers_map, **samplers_map_smea}
166
+
167
+ for i, item in enumerate(sd_samplers.all_samplers):
168
+ if "Euler" in item.name:
169
+ sd_samplers.all_samplers = sd_samplers.all_samplers[:i + 1] + [*samplers_data_smea] + sd_samplers.all_samplers[i + 1:]
170
+ break
171
+ sd_samplers.all_samplers_map = {x.name: x for x in sd_samplers.all_samplers}
172
+ sd_samplers.set_samplers()
173
+
174
+ return
175
+
176
+ def default_noise_sampler(x):
177
+ return lambda sigma, sigma_next: k_diffusion.sampling.torch.randn_like(x)
178
+
179
+ @torch.no_grad()
180
+ def dy_sampling_step(x, model, dt, sigma_hat, **extra_args):
181
+ original_shape = x.shape
182
+ batch_size, channels, m, n = original_shape[0], original_shape[1], original_shape[2] // 2, original_shape[3] // 2
183
+ extra_row = x.shape[2] % 2 == 1
184
+ extra_col = x.shape[3] % 2 == 1
185
+
186
+ if extra_row:
187
+ extra_row_content = x[:, :, -1:, :]
188
+ x = x[:, :, :-1, :]
189
+ if extra_col:
190
+ extra_col_content = x[:, :, :, -1:]
191
+ x = x[:, :, :, :-1]
192
+
193
+ a_list = x.unfold(2, 2, 2).unfold(3, 2, 2).contiguous().view(batch_size, channels, m * n, 2, 2)
194
+ c = a_list[:, :, :, 1, 1].view(batch_size, channels, m, n)
195
+
196
+ with _Rescaler(model, c, 'nearest-exact', **extra_args) as rescaler:
197
+ denoised = model(c, sigma_hat * c.new_ones([c.shape[0]]), **rescaler.extra_args)
198
+ d = to_d(c, sigma_hat, denoised)
199
+ c = c + d * dt
200
+
201
+ d_list = c.view(batch_size, channels, m * n, 1, 1)
202
+ a_list[:, :, :, 1, 1] = d_list[:, :, :, 0, 0]
203
+ x = a_list.view(batch_size, channels, m, n, 2, 2).permute(0, 1, 2, 4, 3, 5).reshape(batch_size, channels, 2 * m, 2 * n)
204
+
205
+ if extra_row or extra_col:
206
+ x_expanded = torch.zeros(original_shape, dtype=x.dtype, device=x.device)
207
+ x_expanded[:, :, :2 * m, :2 * n] = x
208
+ if extra_row:
209
+ x_expanded[:, :, -1:, :2 * n + 1] = extra_row_content
210
+ if extra_col:
211
+ x_expanded[:, :, :2 * m, -1:] = extra_col_content
212
+ if extra_row and extra_col:
213
+ x_expanded[:, :, -1:, -1:] = extra_col_content[:, :, -1:, :]
214
+ x = x_expanded
215
+
216
+ return x
217
+
218
+ @torch.no_grad()
219
+ def smea_sampling_step(x, model, dt, sigma_hat, **extra_args):
220
+ m, n = x.shape[2], x.shape[3]
221
+ x = torch.nn.functional.interpolate(input=x, size=None, scale_factor=(1.25, 1.25), mode='nearest-exact', align_corners=None, recompute_scale_factor=None)
222
+ with _Rescaler(model, x, 'nearest-exact', **extra_args) as rescaler:
223
+ denoised = model(x, sigma_hat * x.new_ones([x.shape[0]]), **rescaler.extra_args)
224
+ d = to_d(x, sigma_hat, denoised)
225
+ x = x + d * dt
226
+ x = torch.nn.functional.interpolate(input=x, size=(m,n), scale_factor=None, mode='nearest-exact', align_corners=None, recompute_scale_factor=None)
227
+ return x
228
+
229
+ @torch.no_grad()
230
+ def smea_sampling_step_denoised(x, model, sigma_hat, scale=1.25, smooth=False, **extra_args):
231
+ m, n = x.shape[2], x.shape[3]
232
+ filter = 'nearest-exact' if not smooth else 'bilinear'
233
+ x = torch.nn.functional.interpolate(input=x, scale_factor=(scale, scale), mode=filter)
234
+ with _Rescaler(model, x, filter, **extra_args) as rescaler:
235
+ denoised = model(x, sigma_hat * x.new_ones([x.shape[0]]), **rescaler.extra_args)
236
+ x = denoised
237
+ x = torch.nn.functional.interpolate(input=x, size=(m,n), mode='nearest-exact')
238
+ return x
239
+
240
+ @torch.no_grad()
241
+ def sample_euler_max(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
242
+ extra_args = {} if extra_args is None else extra_args
243
+ s_in = x.new_ones([x.shape[0]])
244
+ for i in trange(len(sigmas) - 1, disable=disable):
245
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
246
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
247
+ sigma_hat = sigmas[i] * (gamma + 1)
248
+ if gamma > 0:
249
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
250
+ denoised = model(x, sigma_hat * s_in, **extra_args)
251
+ d = to_d(x, sigma_hat, denoised)
252
+ if callback is not None:
253
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
254
+ dt = sigmas[i + 1] - sigma_hat
255
+ # Euler method
256
+ x = x + (math.cos(i + 1)/(i + 1) + 1) * d * dt
257
+ return x
258
+
259
+
260
+ @torch.no_grad()
261
+ def sample_euler_max1b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
262
+ extra_args = {} if extra_args is None else extra_args
263
+ s_in = x.new_ones([x.shape[0]])
264
+ for i in trange(len(sigmas) - 1, disable=disable):
265
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
266
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
267
+ sigma_hat = sigmas[i] * (gamma + 1)
268
+ if gamma > 0:
269
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
270
+ denoised = model(x, sigma_hat * s_in, **extra_args)
271
+ d = to_d(x, sigma_hat, denoised)
272
+ if callback is not None:
273
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
274
+ dt = sigmas[i + 1] - sigma_hat
275
+ # Euler method
276
+ x = x + (math.cos(1.05 * i + 1)/(1.1 * i + 1.5) + 1) * d * dt
277
+ return x
278
+
279
+ @torch.no_grad()
280
+ def sample_euler_max1c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
281
+ extra_args = {} if extra_args is None else extra_args
282
+ s_in = x.new_ones([x.shape[0]])
283
+ for i in trange(len(sigmas) - 1, disable=disable):
284
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
285
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
286
+ sigma_hat = sigmas[i] * (gamma + 1)
287
+ if gamma > 0:
288
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
289
+ denoised = model(x, sigma_hat * s_in, **extra_args)
290
+ d = to_d(x, sigma_hat, denoised)
291
+ if callback is not None:
292
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
293
+ dt = sigmas[i + 1] - sigma_hat
294
+ # Euler method
295
+ x = x + (math.cos(1.05 * i + 1.1)/(1.25 * i + 1.5) + 1) * d * dt
296
+ return x
297
+
298
+ @torch.no_grad()
299
+ def sample_euler_max1d(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
300
+ extra_args = {} if extra_args is None else extra_args
301
+ s_in = x.new_ones([x.shape[0]])
302
+ for i in trange(len(sigmas) - 1, disable=disable):
303
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
304
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
305
+ sigma_hat = sigmas[i] * (gamma + 1)
306
+ if gamma > 0:
307
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
308
+ denoised = model(x, sigma_hat * s_in, **extra_args)
309
+ d = to_d(x, sigma_hat, denoised)
310
+ if callback is not None:
311
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
312
+ dt = sigmas[i + 1] - sigma_hat
313
+ # Euler method
314
+ x = x + (math.cos(math.pi * 0.333 * i + 0.9)/(0.5 * i + 1.5) + 1) * d * dt
315
+ return x
316
+
317
+ @torch.no_grad()
318
+ def sample_euler_max2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
319
+ extra_args = {} if extra_args is None else extra_args
320
+ s_in = x.new_ones([x.shape[0]])
321
+ for i in trange(len(sigmas) - 1, disable=disable):
322
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
323
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
324
+ sigma_hat = sigmas[i] * (gamma + 1)
325
+ if gamma > 0:
326
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
327
+ denoised = model(x, sigma_hat * s_in, **extra_args)
328
+ d = to_d(x, sigma_hat, denoised)
329
+ if callback is not None:
330
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
331
+ dt = sigmas[i + 1] - sigma_hat
332
+ # Euler method
333
+ x = x + (math.cos(math.pi * 0.333 * i - 0.1)/(0.5 * i + 1.5) + 1) * d * dt
334
+ return x
335
+
336
+ @torch.no_grad()
337
+ def sample_euler_max2b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
338
+ extra_args = {} if extra_args is None else extra_args
339
+ s_in = x.new_ones([x.shape[0]])
340
+ for i in trange(len(sigmas) - 1, disable=disable):
341
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
342
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
343
+ sigma_hat = sigmas[i] * (gamma + 1)
344
+ if gamma > 0:
345
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
346
+ denoised = model(x, sigma_hat * s_in, **extra_args)
347
+ d = to_d(x, sigma_hat, denoised)
348
+ if callback is not None:
349
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
350
+ dt = sigmas[i + 1] - sigma_hat
351
+ # Euler method
352
+ x = x + (math.cos(math.pi * 0.5 * i - 0.0)/(0.5 * i + 1.5) + 1) * d * dt
353
+ return x
354
+
355
+ @torch.no_grad()
356
+ def sample_euler_max2c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
357
+ extra_args = {} if extra_args is None else extra_args
358
+ s_in = x.new_ones([x.shape[0]])
359
+ for i in trange(len(sigmas) - 1, disable=disable):
360
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
361
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
362
+ sigma_hat = sigmas[i] * (gamma + 1)
363
+ if gamma > 0:
364
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
365
+ denoised = model(x, sigma_hat * s_in, **extra_args)
366
+ d = to_d(x, sigma_hat, denoised)
367
+ if callback is not None:
368
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
369
+ dt = sigmas[i + 1] - sigma_hat
370
+ # Euler method
371
+ x = x + (math.cos(math.pi * 0.5 * i)/(i + 2) + 1) * d * dt
372
+ return x
373
+
374
+ @torch.no_grad()
375
+ def sample_euler_max2d(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
376
+ extra_args = {} if extra_args is None else extra_args
377
+ s_in = x.new_ones([x.shape[0]])
378
+ for i in trange(len(sigmas) - 1, disable=disable):
379
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
380
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
381
+ sigma_hat = sigmas[i] * (gamma + 1)
382
+ if gamma > 0:
383
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
384
+ denoised = model(x, sigma_hat * s_in, **extra_args)
385
+ d = to_d(x, sigma_hat, denoised)
386
+ if callback is not None:
387
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
388
+ dt = sigmas[i + 1] - sigma_hat
389
+ # Euler method
390
+ x = x + (math.cos(math.pi * 0.5 * i)/(0.75 * i + 1.75) + 1) * d * dt
391
+ return x
392
+
393
+ @torch.no_grad()
394
+ def sample_euler_max3b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
395
+ extra_args = {} if extra_args is None else extra_args
396
+ s_in = x.new_ones([x.shape[0]])
397
+ for i in trange(len(sigmas) - 1, disable=disable):
398
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
399
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
400
+ sigma_hat = sigmas[i] * (gamma + 1)
401
+ if gamma > 0:
402
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
403
+ denoised = model(x, sigma_hat * s_in, **extra_args)
404
+ d = to_d(x, sigma_hat, denoised)
405
+ if callback is not None:
406
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
407
+ dt = sigmas[i + 1] - sigma_hat
408
+ # Euler method
409
+ x = x + (math.cos(2 * i + 0.5)/(2 * i + 1.5) + 1) * d * dt
410
+ return x
411
+
412
+ @torch.no_grad()
413
+ def sample_euler_max3c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
414
+ extra_args = {} if extra_args is None else extra_args
415
+ s_in = x.new_ones([x.shape[0]])
416
+ for i in trange(len(sigmas) - 1, disable=disable):
417
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
418
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
419
+ sigma_hat = sigmas[i] * (gamma + 1)
420
+ if gamma > 0:
421
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
422
+ denoised = model(x, sigma_hat * s_in, **extra_args)
423
+ d = to_d(x, sigma_hat, denoised)
424
+ if callback is not None:
425
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
426
+ dt = sigmas[i + 1] - sigma_hat
427
+ # Euler method
428
+ x = x + (math.cos(2 * i + 0.5)/(1.5 * i + 2.7) + 1) * d * dt
429
+ return x
430
+
431
+ @torch.no_grad()
432
+ def sample_euler_max3(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
433
+ extra_args = {} if extra_args is None else extra_args
434
+ s_in = x.new_ones([x.shape[0]])
435
+ for i in trange(len(sigmas) - 1, disable=disable):
436
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
437
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
438
+ sigma_hat = sigmas[i] * (gamma + 1)
439
+ if gamma > 0:
440
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
441
+ denoised = model(x, sigma_hat * s_in, **extra_args)
442
+ d = to_d(x, sigma_hat, denoised)
443
+ if callback is not None:
444
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
445
+ dt = sigmas[i + 1] - sigma_hat
446
+ # Euler method
447
+ x = x + (math.cos(2 * i + 1)/(2 * i + 1) + 1) * d * dt
448
+ return x
449
+
450
+ @torch.no_grad()
451
+ def sample_euler_max4b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
452
+ extra_args = {} if extra_args is None else extra_args
453
+ s_in = x.new_ones([x.shape[0]])
454
+ for i in trange(len(sigmas) - 1, disable=disable):
455
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
456
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
457
+ sigma_hat = sigmas[i] * (gamma + 1)
458
+ if gamma > 0:
459
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
460
+ denoised = model(x, sigma_hat * s_in, **extra_args)
461
+ d = to_d(x, sigma_hat, denoised)
462
+ if callback is not None:
463
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
464
+ dt = sigmas[i + 1] - sigma_hat
465
+ # Euler method
466
+ x = x + (math.cos(math.pi * i - 0.1)/(2 * i + 2) + 1) * d * dt
467
+ return x
468
+
469
+ @torch.no_grad()
470
+ def sample_euler_max4c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
471
+ extra_args = {} if extra_args is None else extra_args
472
+ s_in = x.new_ones([x.shape[0]])
473
+ for i in trange(len(sigmas) - 1, disable=disable):
474
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
475
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
476
+ sigma_hat = sigmas[i] * (gamma + 1)
477
+ if gamma > 0:
478
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
479
+ denoised = model(x, sigma_hat * s_in, **extra_args)
480
+ d = to_d(x, sigma_hat, denoised)
481
+ if callback is not None:
482
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
483
+ dt = sigmas[i + 1] - sigma_hat
484
+ # Euler method
485
+ x = x + (math.cos(math.pi * i - 0.1)/(2 * i + 1.5) + 1) * d * dt
486
+ return x
487
+
488
+ @torch.no_grad()
489
+ def sample_euler_max4d(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
490
+ extra_args = {} if extra_args is None else extra_args
491
+ s_in = x.new_ones([x.shape[0]])
492
+ for i in trange(len(sigmas) - 1, disable=disable):
493
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
494
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
495
+ sigma_hat = sigmas[i] * (gamma + 1)
496
+ if gamma > 0:
497
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
498
+ denoised = model(x, sigma_hat * s_in, **extra_args)
499
+ d = to_d(x, sigma_hat, denoised)
500
+ if callback is not None:
501
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
502
+ dt = sigmas[i + 1] - sigma_hat
503
+ # Euler method
504
+ x = x + (math.cos(math.pi * i - 0.1)/(i + 1.5) + 1) * d * dt
505
+ return x
506
+
507
+ @torch.no_grad()
508
+ def sample_euler_max4e(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
509
+ extra_args = {} if extra_args is None else extra_args
510
+ s_in = x.new_ones([x.shape[0]])
511
+ for i in trange(len(sigmas) - 1, disable=disable):
512
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
513
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
514
+ sigma_hat = sigmas[i] * (gamma + 1)
515
+ if gamma > 0:
516
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
517
+ denoised = model(x, sigma_hat * s_in, **extra_args)
518
+ d = to_d(x, sigma_hat, denoised)
519
+ if callback is not None:
520
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
521
+ dt = sigmas[i + 1] - sigma_hat
522
+ # Euler method
523
+ x = x + (math.cos(math.pi * i - 0.1)/(i + 1) + 1) * d * dt
524
+ return x
525
+
526
+ @torch.no_grad()
527
+ def sample_euler_max4f(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
528
+ extra_args = {} if extra_args is None else extra_args
529
+ s_in = x.new_ones([x.shape[0]])
530
+ for i in trange(len(sigmas) - 1, disable=disable):
531
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
532
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
533
+ sigma_hat = sigmas[i] * (gamma + 1)
534
+ if gamma > 0:
535
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
536
+ denoised = model(x, sigma_hat * s_in, **extra_args)
537
+ d = to_d(x, sigma_hat, denoised)
538
+ if callback is not None:
539
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
540
+ dt = sigmas[i + 1] - sigma_hat
541
+ # Euler method
542
+ x = x + (math.cos(math.pi * i - 0.1)/(i + 2) + 1) * d * dt
543
+ return x
544
+
545
+
546
+ @torch.no_grad()
547
+ def sample_euler_max4(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
548
+ # 袛芯斜邪胁褜褌械 蟹写械褋褜 褌械谢芯 褎褍薪泻褑懈懈 懈谢懈 褏芯褌褟 斜褘 pass, 褔褌芯斜褘 懈蟹斜械卸邪褌褜 IndentationError
549
+ pass
550
+
551
+ @torch.no_grad()
552
+ def sample_euler_dy(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
553
+ extra_args = {} if extra_args is None else extra_args
554
+ s_in = x.new_ones([x.shape[0]])
555
+ for i in trange(len(sigmas) - 1, disable=disable):
556
+ # print(i)
557
+ # i绗竴姝ヤ负0
558
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
559
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
560
+ sigma_hat = sigmas[i] * (gamma + 1)
561
+ # print(sigma_hat)
562
+ dt = sigmas[i + 1] - sigma_hat
563
+ if gamma > 0:
564
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
565
+ denoised = model(x, sigma_hat * s_in, **extra_args)
566
+ d = to_d(x, sigma_hat, denoised)
567
+ if callback is not None:
568
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
569
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 2 and i % 2 == 0:
570
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
571
+ dt_1 = sigma_mid - sigmas[i]
572
+ dt_2 = sigmas[i + 1] - sigmas[i]
573
+ x_2 = x + d * dt_1
574
+ x_temp = dy_sampling_step(x_2, model, dt_2, sigma_mid, **extra_args)
575
+ x = x_temp - d * dt_1
576
+ # Euler method
577
+ x = x + d * dt
578
+ return x
579
+
580
+ @torch.no_grad()
581
+ def sample_euler_smea_dyn_a(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
582
+ extra_args = {} if extra_args is None else extra_args
583
+ s_in = x.new_ones([x.shape[0]])
584
+ for i in trange(len(sigmas) - 1, disable=disable):
585
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
586
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
587
+ sigma_hat = sigmas[i] * (gamma + 1)
588
+ if gamma > 0:
589
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
590
+ denoised = model(x, sigma_hat * s_in, **extra_args)
591
+ d = to_d(x, sigma_hat, denoised)
592
+ if callback is not None:
593
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
594
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 2:
595
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
596
+ dt_1 = sigma_mid - sigma_hat
597
+ dt_2 = sigmas[i + 1] - sigma_hat
598
+ x_2 = x + d * dt_1
599
+ #scale = (sigma_mid / sigmas[0]) * 0.25
600
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2 * 0.15
601
+ #scale = scale.item()
602
+ if i % 2 == 0:
603
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale, **extra_args)
604
+ #denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + sigma_mid.item() * 0.01, **extra_args)
605
+ else:
606
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
607
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
608
+ x = x + d_2 * dt_2
609
+ else:
610
+ dt = sigmas[i + 1] - sigma_hat
611
+ # Euler method
612
+ x = x + d * dt
613
+ return x
614
+
615
+ @torch.no_grad()
616
+ def sample_euler_smea_dyn_b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
617
+ extra_args = {} if extra_args is None else extra_args
618
+ s_in = x.new_ones([x.shape[0]])
619
+ for i in trange(len(sigmas) - 1, disable=disable):
620
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
621
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
622
+ sigma_hat = sigmas[i] * (gamma + 1)
623
+ if gamma > 0:
624
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
625
+ denoised = model(x, sigma_hat * s_in, **extra_args)
626
+ d = to_d(x, sigma_hat, denoised)
627
+ if callback is not None:
628
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
629
+ if sigmas[i + 1] > 0 and (i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 3 or i < 3):
630
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
631
+ dt_1 = sigma_mid - sigma_hat
632
+ dt_2 = sigmas[i + 1] - sigma_hat
633
+ x_2 = x + d * dt_1
634
+ #scale = (sigma_mid / sigmas[0]) * 0.25
635
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2 * 0.2
636
+ #scale = scale.item()
637
+ if i % 4 == 0:
638
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale, **extra_args)
639
+ #denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - sigma_mid.item() * 0.01, **extra_args)
640
+ elif i % 4 == 2:
641
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale, **extra_args)
642
+ #denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + sigma_mid.item() * 0.01, **extra_args)
643
+ else:
644
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
645
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
646
+ x = x + d_2 * dt_2
647
+ else:
648
+ dt = sigmas[i + 1] - sigma_hat
649
+ # Euler method
650
+ x = x + d * dt
651
+ return x
652
+
653
+ @torch.no_grad()
654
+ def sample_euler_smea_dyn_c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
655
+ extra_args = {} if extra_args is None else extra_args
656
+ s_in = x.new_ones([x.shape[0]])
657
+ for i in trange(len(sigmas) - 1, disable=disable):
658
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
659
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
660
+ sigma_hat = sigmas[i] * (gamma + 1)
661
+ if gamma > 0:
662
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
663
+ denoised = model(x, sigma_hat * s_in, **extra_args)
664
+ d = to_d(x, sigma_hat, denoised)
665
+ if callback is not None:
666
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
667
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 2:
668
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
669
+ dt_1 = sigma_mid - sigma_hat
670
+ dt_2 = sigmas[i + 1] - sigma_hat
671
+ x_2 = x + d * dt_1
672
+ #scale = (sigma_mid / sigmas[0]) * 0.25
673
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2 * 0.25
674
+ #scale = scale.item()
675
+ if i % 2 == 0:
676
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale, **extra_args)
677
+ #denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + sigma_mid.item() * 0.01, **extra_args)
678
+ else:
679
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
680
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
681
+ x = x + d_2 * dt_2
682
+ else:
683
+ dt = sigmas[i + 1] - sigma_hat
684
+ # Euler method
685
+ x = x + d * dt
686
+ return x
687
+
688
+ @torch.no_grad()
689
+ def sample_euler_smea(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
690
+ extra_args = {} if extra_args is None else extra_args
691
+ s_in = x.new_ones([x.shape[0]])
692
+ for i in trange(len(sigmas) - 1, disable=disable):
693
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
694
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
695
+ sigma_hat = sigmas[i] * (gamma + 1)
696
+ if gamma > 0:
697
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
698
+ denoised = model(x, sigma_hat * s_in, **extra_args)
699
+ d = to_d(x, sigma_hat, denoised)
700
+ if callback is not None:
701
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
702
+ dt = sigmas[i + 1] - sigma_hat
703
+ # Euler method
704
+ x = x + d * dt
705
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 2 and i % 2 == 0:
706
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
707
+ dt_1 = sigma_mid - sigmas[i]
708
+ dt_2 = sigmas[i + 1] - sigmas[i]
709
+ #print(dt_1, "#", dt_2, "#", dt_3, "#", dt_4)
710
+ x_2 = x + d * dt_1
711
+ x_temp = smea_sampling_step(x, model, dt_2, sigma_mid, **extra_args)
712
+ x = x_temp - d * dt_1
713
+ return x
714
+
715
+ @torch.no_grad()
716
+ def sample_euler_smea_dy(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
717
+ extra_args = {} if extra_args is None else extra_args
718
+ s_in = x.new_ones([x.shape[0]])
719
+ for i in trange(len(sigmas) - 1, disable=disable):
720
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
721
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
722
+ sigma_hat = sigmas[i] * (gamma + 1)
723
+ if gamma > 0:
724
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
725
+ denoised = model(x, sigma_hat * s_in, **extra_args)
726
+ d = to_d(x, sigma_hat, denoised)
727
+ if callback is not None:
728
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
729
+ dt = sigmas[i + 1] - sigma_hat
730
+ # Euler method
731
+ x = x + d * dt
732
+ if sigmas[i + 1] > 0 and (i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 2 or i < 3) and i % 3 != 2:
733
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
734
+ dt_1 = sigma_mid - sigmas[i]
735
+ dt_2 = sigmas[i + 1] - sigmas[i]
736
+ #print(dt_1, "#", dt_2, "#", dt_3, "#", dt_4)
737
+ x_2 = x + d * dt_1
738
+ if i % 3 == 1:
739
+ x_temp = dy_sampling_step(x, model, dt_2, sigma_mid, **extra_args)
740
+ elif i % 3 == 0:
741
+ x_temp = smea_sampling_step(x, model, dt_2, sigma_mid, **extra_args)
742
+ x = x_temp - d * dt_1
743
+ return x
744
+
745
+ @torch.no_grad()
746
+ def sample_euler_smea_multi_d(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
747
+ extra_args = {} if extra_args is None else extra_args
748
+ s_in = x.new_ones([x.shape[0]])
749
+ for i in trange(len(sigmas) - 1, disable=disable):
750
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
751
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
752
+ sigma_hat = sigmas[i] * (gamma + 1)
753
+ if gamma > 0:
754
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
755
+ denoised = model(x, sigma_hat * s_in, **extra_args)
756
+ d = to_d(x, sigma_hat, denoised)
757
+ if callback is not None:
758
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
759
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.334 + 2 and i % 2 == 0:
760
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
761
+ dt_1 = sigma_mid - sigma_hat
762
+ dt_2 = sigmas[i + 1] - sigma_hat
763
+ x_2 = x + d * dt_1
764
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
765
+ if i == 0:
766
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale * 0.15, **extra_args)
767
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
768
+ denoised_2 = (denoised_2a + denoised_2c) / 2
769
+ elif i < len(sigmas) * 0.334:
770
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale * 0.25, **extra_args)
771
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale * 0.15, **extra_args)
772
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
773
+ denoised_2 = (denoised_2a + denoised_2b + denoised_2c) / 3
774
+ else:
775
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale * 0.03, True, **extra_args)
776
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
777
+ denoised_2 = (denoised_2b + denoised_2c) / 2
778
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
779
+ x = x + d_2 * dt_2
780
+ else:
781
+ dt = sigmas[i + 1] - sigma_hat
782
+ # Euler method
783
+ x = x + d * dt
784
+ return x
785
+
786
+ @torch.no_grad()
787
+ def sample_euler_smea_multi_b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
788
+ extra_args = {} if extra_args is None else extra_args
789
+ s_in = x.new_ones([x.shape[0]])
790
+ for i in trange(len(sigmas) - 1, disable=disable):
791
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
792
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
793
+ sigma_hat = sigmas[i] * (gamma + 1)
794
+ if gamma > 0:
795
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
796
+ denoised = model(x, sigma_hat * s_in, **extra_args)
797
+ d = to_d(x, sigma_hat, denoised)
798
+ if callback is not None:
799
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
800
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
801
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
802
+ dt_1 = sigma_mid - sigma_hat
803
+ dt_2 = sigmas[i + 1] - sigma_hat
804
+ x_2 = x + d * dt_1
805
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
806
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale * 0.25, **extra_args)
807
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale * 0.15, **extra_args)
808
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
809
+ denoised_2 = (denoised_2a + denoised_2b + denoised_2c) / 3
810
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
811
+ x = x + d_2 * dt_2
812
+ else:
813
+ dt = sigmas[i + 1] - sigma_hat
814
+ # Euler method
815
+ x = x + d * dt
816
+ return x
817
+
818
+ @torch.no_grad()
819
+ def sample_euler_smea_multi_c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
820
+ extra_args = {} if extra_args is None else extra_args
821
+ s_in = x.new_ones([x.shape[0]])
822
+ for i in trange(len(sigmas) - 1, disable=disable):
823
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
824
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
825
+ sigma_hat = sigmas[i] * (gamma + 1)
826
+ if gamma > 0:
827
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
828
+ denoised = model(x, sigma_hat * s_in, **extra_args)
829
+ d = to_d(x, sigma_hat, denoised)
830
+ if callback is not None:
831
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
832
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
833
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
834
+ dt_1 = sigma_mid - sigma_hat
835
+ dt_2 = sigmas[i + 1] - sigma_hat
836
+ x_2 = x + d * dt_1
837
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
838
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale * 0.25, **extra_args)
839
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
840
+ denoised_2 = (denoised_2a + denoised_2c) / 2
841
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
842
+ x = x + d_2 * dt_2
843
+ else:
844
+ dt = sigmas[i + 1] - sigma_hat
845
+ # Euler method
846
+ x = x + d * dt
847
+ return x
848
+
849
+ @torch.no_grad()
850
+ def sample_euler_smea_multi_a(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
851
+ extra_args = {} if extra_args is None else extra_args
852
+ s_in = x.new_ones([x.shape[0]])
853
+ for i in trange(len(sigmas) - 1, disable=disable):
854
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
855
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
856
+ sigma_hat = sigmas[i] * (gamma + 1)
857
+ if gamma > 0:
858
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
859
+ denoised = model(x, sigma_hat * s_in, **extra_args)
860
+ d = to_d(x, sigma_hat, denoised)
861
+ if callback is not None:
862
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
863
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
864
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
865
+ dt_1 = sigma_mid - sigma_hat
866
+ dt_2 = sigmas[i + 1] - sigma_hat
867
+ x_2 = x + d * dt_1
868
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
869
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale * 0.15, **extra_args)
870
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
871
+ denoised_2 = (denoised_2b + denoised_2c) / 2
872
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
873
+ x = x + d_2 * dt_2
874
+ else:
875
+ dt = sigmas[i + 1] - sigma_hat
876
+ # Euler method
877
+ x = x + d * dt
878
+ return x
879
+
880
+
881
+ @torch.no_grad()
882
+ def sample_euler_smea_multi_ds(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
883
+ extra_args = {} if extra_args is None else extra_args
884
+ s_in = x.new_ones([x.shape[0]])
885
+ for i in trange(len(sigmas) - 1, disable=disable):
886
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
887
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
888
+ sigma_hat = sigmas[i] * (gamma + 1)
889
+ if gamma > 0:
890
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
891
+ denoised = model(x, sigma_hat * s_in, **extra_args)
892
+ d = to_d(x, sigma_hat, denoised)
893
+ if callback is not None:
894
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
895
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167 + 1: # and i % 2 == 0:
896
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
897
+ dt_1 = sigma_mid - sigma_hat
898
+ dt_2 = sigmas[i + 1] - sigma_hat
899
+ x_2 = x + d * dt_1
900
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
901
+ if i == 0:
902
+ sa = 1 - scale * 0.15
903
+ sb = 1 + scale * 0.09
904
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, sa, **extra_args)
905
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, sb, **extra_args)
906
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.625 + denoised_2b * (sb ** 2) * 0.375) / (0.97**2)
907
+ elif i < len(sigmas) * 0.167:
908
+ sa = 1 - scale * 0.25
909
+ sb = 1 + scale * 0.15
910
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, sa, **extra_args)
911
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, sb , **extra_args)
912
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.625 + denoised_2b * (sb ** 2) * 0.375) / (0.95**2)
913
+ else:
914
+ sb = 1 + scale * 0.06
915
+ sc = 1 - scale * 0.1
916
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, sb, True, **extra_args)
917
+ denoised_2c = smea_sampling_step_denoised(x_2, model, sigma_mid, sc, **extra_args)
918
+ denoised_2 = (denoised_2b * (sb ** 2) * 0.375 + denoised_2c * (sc ** 2) * 0.625) / (0.98**2)
919
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
920
+ x = x + d_2 * dt_2
921
+ else:
922
+ dt = sigmas[i + 1] - sigma_hat
923
+ # Euler method
924
+ x = x + d * dt
925
+ return x
926
+
927
+ @torch.no_grad()
928
+ def sample_euler_smea_multi_ds2_s(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
929
+ sample = sample_euler_smea_multi_ds2(model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise, smooth=True)
930
+ return sample
931
+
932
+ @torch.no_grad()
933
+ def sample_euler_smea_multi_ds2_s_m(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
934
+ sample = sample_euler_smea_multi_ds2_m(model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise, smooth=True)
935
+ return sample
936
+
937
+ @torch.no_grad()
938
+ def sample_euler_smea_multi_ds2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., smooth=False):
939
+ extra_args = {} if extra_args is None else extra_args
940
+ s_in = x.new_ones([x.shape[0]])
941
+ for i in trange(len(sigmas) - 1, disable=disable):
942
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
943
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
944
+ sigma_hat = sigmas[i] * (gamma + 1)
945
+ if gamma > 0:
946
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
947
+ denoised = model(x, sigma_hat * s_in, **extra_args)
948
+ d = to_d(x, sigma_hat, denoised)
949
+ if callback is not None:
950
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
951
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167 + 1: # and i % 2 == 0:
952
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
953
+ dt_1 = sigma_mid - sigma_hat
954
+ dt_2 = sigmas[i + 1] - sigma_hat
955
+ x_2 = x + d * dt_1
956
+ scale = (sigmas[i] / sigmas[0]) ** 2
957
+ scale = scale.item()
958
+ if i == 0:
959
+ sa = 1 - scale * 0.15
960
+ sb = 1 + scale * 0.09
961
+ sigA = sigma_mid / (sa ** 2)
962
+ sigB = sigma_mid / (sb ** 2)
963
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
964
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
965
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.5 * sb ** 2 + denoised_2b * (sb ** 2) * 0.5 * sa ** 2) #/ (0.97**2) # 1 - (sa * sb ) / 2 + 1
966
+ d_2 = to_d(x_2, sigA * 0.5 * sb ** 2 + sigB * 0.5 * sa ** 2, denoised_2)
967
+ elif i < len(sigmas) * 0.167:
968
+ sa = 1 - scale * 0.25
969
+ sb = 1 + scale * 0.15
970
+ sigA = sigma_mid / (sa ** 2)
971
+ sigB = sigma_mid / (sb ** 2)
972
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
973
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
974
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.5 * sb ** 2 + denoised_2b * (sb ** 2) * 0.5 * sa ** 2) #/ (0.95**2)
975
+ d_2 = to_d(x_2, sigA * 0.5 * sb ** 2 + sigB * 0.5 * sa ** 2, denoised_2)
976
+ else:
977
+ sb = 1 + scale * 0.06
978
+ sc = 1 - scale * 0.1
979
+ sigB = sigma_mid / (sb ** 2)
980
+ sigC = sigma_mid / (sc ** 2)
981
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
982
+ denoised_2c = smea_sampling_step_denoised(x_2, model, sigC, sc, smooth, **extra_args)
983
+ denoised_2 = (denoised_2b * (sb ** 2) * 0.5 * sc ** 2 + denoised_2c * (sc ** 2) * 0.5 * sb ** 2) #/ (0.98**2)
984
+ d_2 = to_d(x_2, sigB * 0.5 * sc ** 2 + sigC * 0.5 * sb ** 2, denoised_2)
985
+ x = x + d_2 * dt_2
986
+ else:
987
+ dt = sigmas[i + 1] - sigma_hat
988
+ # Euler method
989
+ x = x + d * dt
990
+ return x
991
+
992
+ @torch.no_grad()
993
+ def sample_euler_smea_multi_ds2_m(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., smooth=False):
994
+ extra_args = {} if extra_args is None else extra_args
995
+ s_in = x.new_ones([x.shape[0]])
996
+ for i in trange(len(sigmas) - 1, disable=disable):
997
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
998
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
999
+ sigma_hat = sigmas[i] * (gamma + 1)
1000
+ if gamma > 0:
1001
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1002
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1003
+ d = to_d(x, sigma_hat, denoised)
1004
+ if callback is not None:
1005
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1006
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167 + 1: # and i % 2 == 0:
1007
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1008
+ dt_1 = sigma_mid - sigma_hat
1009
+ dt_2 = sigmas[i + 1] - sigma_hat
1010
+ x_2 = x + d * dt_1
1011
+ scale = (sigmas[i] / sigmas[0]) ** 2
1012
+ #scale = dt_1 ** 2 * 0.01
1013
+ scale = scale.item()
1014
+ if i == 0:
1015
+ sa = 1 - scale * 0.15 #15
1016
+ sb = 1 + scale * 0.09 #09
1017
+ sigA = sigma_mid / (sa ** 2)
1018
+ sigB = sigma_mid / (sb ** 2)
1019
+ #delta = sa * sb
1020
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
1021
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
1022
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.5 * sb ** 2 + denoised_2b * (sb ** 2) * 0.5 * sa ** 2) #/ (0.97**2) # 1 - (sa * sb ) / 2 + 1
1023
+ d_2 = to_d(x_2, sigA * 0.5 * sb ** 2 + sigB * 0.5 * sa ** 2, denoised_2)
1024
+ elif i < len(sigmas) * 0.167:
1025
+ sa = 1 - scale * 0.25 #25
1026
+ sb = 1 + scale * 0.15 #15
1027
+ sigA = sigma_mid / (sa ** 2)
1028
+ sigB = sigma_mid / (sb ** 2)
1029
+ #delta = sa * sb
1030
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
1031
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
1032
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.5 * sb ** 2 + denoised_2b * (sb ** 2) * 0.5 * sa ** 2) #/ (0.95**2)
1033
+ d_2 = to_d(x_2, sigA * 0.5 * sb ** 2 + sigB * 0.5 * sa ** 2, denoised_2)
1034
+ else:
1035
+ sb = 1 + scale * 0.06
1036
+ sc = 1 - scale * 0.1
1037
+ sigB = sigma_mid / (sb ** 2)
1038
+ sigC = sigma_mid / (sc ** 2)
1039
+ #delta = sb * sc
1040
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
1041
+ denoised_2c = smea_sampling_step_denoised(x_2, model, sigC, sc, smooth, **extra_args)
1042
+ denoised_2 = (denoised_2b * (sb ** 2) * 0.5 * sc ** 2+ denoised_2c * (sc ** 2) * 0.5 * sb ** 2) #/ (0.98**2)
1043
+ d_2 = to_d(x_2, sigB * 0.5 * sc ** 2 + sigC * 0.5 * sb ** 2, denoised_2)
1044
+ x = x + (math.cos(1.05 * i + 1.1)/(1.25 * i + 1.5) + 1) * d_2 * dt_2
1045
+ else:
1046
+ dt = sigmas[i + 1] - sigma_hat
1047
+ # Euler method
1048
+ x = x + (math.cos(1.05 * i + 1.1)/(1.25 * i + 1.5) + 1) * d * dt
1049
+ return x
1050
+
1051
+ @torch.no_grad()
1052
+ def sample_euler_h_m(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1053
+ extra_args = {} if extra_args is None else extra_args
1054
+ s_in = x.new_ones([x.shape[0]])
1055
+ for i in trange(len(sigmas) - 1, disable=disable):
1056
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1057
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1058
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1059
+ gamma = min((2 ** 0.5 - 1) - wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1060
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler == None else noise_sampler
1061
+ sigma_hat = sigmas[i] * (gamma + 1)
1062
+ if gamma > 0:
1063
+ x = x - eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1064
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1065
+ d = to_d(x, sigma_hat, denoised)
1066
+ dt = sigmas[i + 1] - sigma_hat
1067
+ if callback is not None:
1068
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1069
+ if sigmas[i + 1] > 0:
1070
+ x_2 = x + d * dt
1071
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1072
+ d_prime = d * 0.5 + d_2 * 0.5
1073
+ x = x + d_prime * dt
1074
+ else:
1075
+ # Euler method
1076
+ x = x + d * dt
1077
+ return x
1078
+
1079
+ @torch.no_grad()
1080
+ def sample_euler_h_m_b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1081
+ extra_args = {} if extra_args is None else extra_args
1082
+ s_in = x.new_ones([x.shape[0]])
1083
+ for i in trange(len(sigmas) - 1, disable=disable):
1084
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1085
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1086
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1087
+ gamma = min(wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1088
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1089
+ sigma_hat = sigmas[i] * (gamma + 1)
1090
+ if gamma > 0:
1091
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1092
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1093
+ d = to_d(x, sigma_hat, denoised)
1094
+ dt = sigmas[i + 1] - sigma_hat
1095
+ if callback is not None:
1096
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1097
+ if sigmas[i + 1] > 0:
1098
+ x_2 = x + d * dt
1099
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1100
+ d_prime = d * 0.5 + d_2 * 0.5
1101
+ x = x + d_prime * dt
1102
+ else:
1103
+ # Euler method
1104
+ x = x + d * dt
1105
+ return x
1106
+
1107
+ @torch.no_grad()
1108
+ def sample_euler_h_m_c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1109
+ extra_args = {} if extra_args is None else extra_args
1110
+ s_in = x.new_ones([x.shape[0]])
1111
+ for i in trange(len(sigmas) - 1, disable=disable):
1112
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1113
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1114
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1115
+ gamma = max((2 ** 0.5 - 1) + wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1116
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1117
+ sigma_hat = sigmas[i] * (gamma + 1)
1118
+ if gamma > 0:
1119
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1120
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1121
+ d = to_d(x, sigma_hat, denoised)
1122
+ dt = sigmas[i + 1] - sigma_hat
1123
+ if callback is not None:
1124
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1125
+ if sigmas[i + 1] > 0:
1126
+ x_2 = x + d * dt
1127
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1128
+ d_prime = d * 0.5 + d_2 * 0.5
1129
+ x = x + d_prime * dt
1130
+ else:
1131
+ # Euler method
1132
+ x = x + d * dt
1133
+ return x
1134
+
1135
+ @torch.no_grad()
1136
+ def sample_euler_h_m_d(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1137
+ extra_args = {} if extra_args is None else extra_args
1138
+ s_in = x.new_ones([x.shape[0]])
1139
+ for i in trange(len(sigmas) - 1, disable=disable):
1140
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1141
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1142
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1143
+ gamma = min((2 ** 0.5 - 1) - wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1144
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1145
+ sigma_hat = sigmas[i] * (gamma + 1)
1146
+ if gamma > 0:
1147
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1148
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1149
+ d = to_d(x, sigma_hat, denoised)
1150
+ dt = sigmas[i + 1] - sigma_hat
1151
+ if callback is not None:
1152
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1153
+ if sigmas[i + 1] > 0:
1154
+ x_2 = x + d * dt
1155
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1156
+ d_prime = d * 0.5 + d_2 * 0.5
1157
+ x = x + d_prime * dt
1158
+ else:
1159
+ # Euler method
1160
+ x = x + d * dt
1161
+ return x
1162
+
1163
+ @torch.no_grad()
1164
+ def sample_euler_h_m_e(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1165
+ extra_args = {} if extra_args is None else extra_args
1166
+ s_in = x.new_ones([x.shape[0]])
1167
+ for i in trange(len(sigmas) - 1, disable=disable):
1168
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1169
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1170
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1171
+ gamma = max((2 ** 0.5 - 1) + wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1172
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1173
+ sigma_hat = sigmas[i] * (gamma + 1)
1174
+ if gamma > 0:
1175
+ x = x - eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1176
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1177
+ d = to_d(x, sigma_hat, denoised)
1178
+ dt = sigmas[i + 1] - sigma_hat
1179
+ if callback is not None:
1180
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1181
+ if sigmas[i + 1] > 0:
1182
+ x_2 = x + d * dt
1183
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1184
+ d_prime = d * 0.5 + d_2 * 0.5
1185
+ x = x + d_prime * dt
1186
+ else:
1187
+ # Euler method
1188
+ x = x + d * dt
1189
+ return x
1190
+
1191
+ @torch.no_grad()
1192
+ def sample_euler_h_m_f(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1193
+ extra_args = {} if extra_args is None else extra_args
1194
+ s_in = x.new_ones([x.shape[0]])
1195
+ for i in trange(len(sigmas) - 1, disable=disable):
1196
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1197
+ wave_max = math.cos(0)/1.5 + 1
1198
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1199
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1200
+ gamma = min((wave_max - wave) * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1201
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1202
+ sigma_hat = sigmas[i] * (gamma + 1)
1203
+ if gamma > 0:
1204
+ x = x - eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1205
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1206
+ d = to_d(x, sigma_hat, denoised)
1207
+ dt = sigmas[i + 1] - sigma_hat
1208
+ if callback is not None:
1209
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1210
+ if sigmas[i + 1] > 0:
1211
+ x_2 = x + d * dt
1212
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1213
+ d_prime = d * 0.5 + d_2 * 0.5
1214
+ x = x + d_prime * dt
1215
+ else:
1216
+ # Euler method
1217
+ x = x + d * dt
1218
+ return x
1219
+
1220
+ @torch.no_grad()
1221
+ def sample_euler_h_m_g(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1222
+ extra_args = {} if extra_args is None else extra_args
1223
+ s_in = x.new_ones([x.shape[0]])
1224
+ for i in trange(len(sigmas) - 1, disable=disable):
1225
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1226
+ wave_max = math.cos(0)/1.5 + 1
1227
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1228
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1229
+ gamma = min((wave_max - wave) * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1230
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1231
+ sigma_hat = sigmas[i] * (gamma + 1)
1232
+ if gamma > 0:
1233
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1234
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1235
+ d = to_d(x, sigma_hat, denoised)
1236
+ dt = sigmas[i + 1] - sigma_hat
1237
+ if callback is not None:
1238
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1239
+ if sigmas[i + 1] > 0:
1240
+ x_2 = x + d * dt
1241
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1242
+ d_prime = d * 0.5 + d_2 * 0.5
1243
+ x = x + d_prime * dt
1244
+ else:
1245
+ # Euler method
1246
+ x = x + d * dt
1247
+ return x
1248
+
1249
+ @torch.no_grad()
1250
+ def sample_euler_h_m_b_c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1251
+ extra_args = {} if extra_args is None else extra_args
1252
+ s_in = x.new_ones([x.shape[0]])
1253
+ for i in trange(len(sigmas) - 1, disable=disable):
1254
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1255
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1256
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1257
+ gamma = min(wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1258
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1259
+ gammaup = gamma + 1
1260
+ sigma_hat = sigmas[i] * gammaup
1261
+ if gamma > 0:
1262
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1263
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1264
+ last_noise_uncond = model.last_noise_uncond
1265
+ d = to_d(x, sigma_hat, denoised)
1266
+ dt = sigmas[i + 1] - sigma_hat
1267
+ if callback is not None:
1268
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1269
+ if i == 0:
1270
+ x = x + d * dt
1271
+ elif i <= len(sigmas) - 4:
1272
+ x_2 = x + d * dt
1273
+ d_2 = to_d(x_2, sigmas[i + 1] * gammaup, denoised)
1274
+ x_3 = x_2 + d_2 * dt
1275
+ d_3 = to_d(x_3, sigmas[i + 2] * gammaup, denoised)
1276
+ d_prime = d * 0.5 + d_2 * 0.375 + d_3 * 0.125
1277
+ x = x + d_prime * dt
1278
+ elif sigmas[i + 1] > 0:
1279
+ x_2 = x + d * dt
1280
+ d_2 = to_d(x_2, sigmas[i + 1] * gammaup, denoised)
1281
+ d_prime = d * 0.5 + d_2 * 0.5
1282
+ x = x + d_prime * dt
1283
+ else:
1284
+ # Euler method
1285
+ x = x + d * dt
1286
+ return x
1287
+
1288
+ @torch.no_grad()
1289
+ def sample_euler_h_m_b_c_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1290
+ extra_args = {} if extra_args is None else extra_args
1291
+ s_in = x.new_ones([x.shape[0]])
1292
+ for i in trange(len(sigmas) - 1, disable=disable):
1293
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1294
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1295
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1296
+ gamma = min(wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1297
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1298
+ gammaup = gamma + 1
1299
+ sigma_hat = sigmas[i] * gammaup
1300
+ if gamma > 0:
1301
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1302
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1303
+ last_noise_uncond = model.last_noise_uncond
1304
+ d = to_d(x, sigma_hat, denoised)
1305
+ dt = sigmas[i + 1] - sigma_hat
1306
+ if callback is not None:
1307
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1308
+ if i == 0:
1309
+ x = x + d * dt
1310
+ elif i <= len(sigmas) - 4:
1311
+ x_2 = x + d * dt
1312
+ d_2 = to_d(x_2, sigmas[i + 1] * gammaup, denoised)
1313
+ x_3 = x_2 + d_2 * dt
1314
+ d_3 = to_d(x_3, sigmas[i + 2] * gammaup, last_noise_uncond)
1315
+ d_prime = d * 0.5 + d_2 * 0.375 + d_3 * 0.125
1316
+ x = x + d_prime * dt
1317
+ elif sigmas[i + 1] > 0:
1318
+ x_2 = x + d * dt
1319
+ d_2 = to_d(x_2, sigmas[i + 1] * gammaup, denoised)
1320
+ d_prime = d * 0.5 + d_2 * 0.5
1321
+ x = x + d_prime * dt
1322
+ else:
1323
+ # Euler method
1324
+ x = x + d * dt
1325
+ return x
1326
+
1327
+ @torch.no_grad()
1328
+ def sample_euler_smea_max(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., smooth=False):
1329
+ extra_args = {} if extra_args is None else extra_args
1330
+ s_in = x.new_ones([x.shape[0]])
1331
+ for i in trange(len(sigmas) - 1, disable=disable):
1332
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1333
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
1334
+ sigma_hat = sigmas[i] * (gamma + 1)
1335
+ sa = math.cos(i + 1)/(1.5 * i + 1.75) + 1
1336
+ if gamma > 0:
1337
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1338
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1339
+ d = to_d(x, sigma_hat, denoised)
1340
+ if callback is not None:
1341
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1342
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167 + 1: # and i % 2 == 0:
1343
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1344
+ dt_1 = sigma_mid - sigma_hat
1345
+ dt_2 = sigmas[i + 1] - sigma_hat
1346
+ x_2 = x + d * dt_1
1347
+ sigA = sigma_mid / (sa ** 2)
1348
+ sigB = sigma_mid
1349
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
1350
+ denoised_2b = model(x_2, sigma_mid * s_in, **extra_args)
1351
+ denoised_2 = (denoised_2a * 0.5 * (sa ** 2) + denoised_2b * 0.5 / (sa ** 2))
1352
+ d_2 = to_d(x_2, sigA * 0.5 * (sa ** 2) + sigB * 0.5 / (sa ** 2), denoised_2)
1353
+ x = x + d_2 * dt_2
1354
+ else:
1355
+ dt = sigmas[i + 1] - sigma_hat
1356
+ # Euler method
1357
+ x = x + sa * d * dt
1358
+ return x
1359
+
1360
+
1361
+ @torch.no_grad()
1362
+ def sample_euler_smea_max_s(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1363
+ sample = sample_euler_smea_max(model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise, smooth=True)
1364
+ return sample
1365
+
1366
+ @torch.no_grad()
1367
+ def sample_euler_smea_multi_bs(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1368
+ extra_args = {} if extra_args is None else extra_args
1369
+ s_in = x.new_ones([x.shape[0]])
1370
+ for i in trange(len(sigmas) - 1, disable=disable):
1371
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1372
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
1373
+ sigma_hat = sigmas[i] * (gamma + 1)
1374
+ if gamma > 0:
1375
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1376
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1377
+ d = to_d(x, sigma_hat, denoised)
1378
+ if callback is not None:
1379
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1380
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
1381
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1382
+ dt_1 = sigma_mid - sigma_hat
1383
+ dt_2 = sigmas[i + 1] - sigma_hat
1384
+ x_2 = x + d * dt_1
1385
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
1386
+ sa = 1 - scale * 0.25
1387
+ sb = 1 + scale * 0.15
1388
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, sa, **extra_args)
1389
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, sb, **extra_args)
1390
+ denoised_2 = denoised_2a * (sa ** 2) * 0.625 + denoised_2b * (sb ** 2) * 0.375 / (0.95**2)
1391
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
1392
+ x = x + d_2 * dt_2
1393
+ else:
1394
+ dt = sigmas[i + 1] - sigma_hat
1395
+ # Euler method
1396
+ x = x + d * dt
1397
+ return x
1398
+
1399
+ @torch.no_grad()
1400
+ def sample_euler_smea_multi_bs2_s(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1401
+ sample = sample_euler_smea_multi_bs2(model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise, smooth=True)
1402
+ return sample
1403
+
1404
+ @torch.no_grad()
1405
+ def sample_euler_smea_multi_bs2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., smooth=False):
1406
+ extra_args = {} if extra_args is None else extra_args
1407
+ s_in = x.new_ones([x.shape[0]])
1408
+ for i in trange(len(sigmas) - 1, disable=disable):
1409
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1410
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
1411
+ sigma_hat = sigmas[i] * (gamma + 1)
1412
+ if gamma > 0:
1413
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1414
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1415
+ d = to_d(x, sigma_hat, denoised)
1416
+ if callback is not None:
1417
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1418
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
1419
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1420
+ dt_1 = sigma_mid - sigma_hat
1421
+ dt_2 = sigmas[i + 1] - sigma_hat
1422
+ x_2 = x + d * dt_1
1423
+ scale = (sigmas[i] / sigmas[0]) ** 2
1424
+ scale = scale.item()
1425
+ sa = 1 - scale * 0.25
1426
+ sb = 1 + scale * 0.15
1427
+ sigA = sigma_mid / (sa ** 2)
1428
+ sigB = sigma_mid / (sb ** 2)
1429
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
1430
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
1431
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.5 * sb ** 2 + denoised_2b * (sb ** 2) * 0.5 * sa ** 2)
1432
+ d_2 = to_d(x_2, sigA * 0.5 * sb ** 2 + sigB * 0.5 * sa ** 2, denoised_2)
1433
+ x = x + d_2 * dt_2
1434
+ else:
1435
+ dt = sigmas[i + 1] - sigma_hat
1436
+ # Euler method
1437
+ x = x + d * dt
1438
+ return x
1439
+
1440
+ @torch.no_grad()
1441
+ def sample_euler_smea_multi_cs(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1442
+ extra_args = {} if extra_args is None else extra_args
1443
+ s_in = x.new_ones([x.shape[0]])
1444
+ for i in trange(len(sigmas) - 1, disable=disable):
1445
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1446
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
1447
+ sigma_hat = sigmas[i] * (gamma + 1)
1448
+ if gamma > 0:
1449
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1450
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1451
+ d = to_d(x, sigma_hat, denoised)
1452
+ if callback is not None:
1453
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1454
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
1455
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1456
+ dt_1 = sigma_mid - sigma_hat
1457
+ dt_2 = sigmas[i + 1] - sigma_hat
1458
+ x_2 = x + d * dt_1
1459
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
1460
+ sa = 1 - scale * 0.25
1461
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, sa, **extra_args)
1462
+ d_2 = to_d(x_2, sigma_mid, denoised_2 * (sa ** 2) * 1.25)
1463
+ x = x + d_2 * dt_2
1464
+ else:
1465
+ dt = sigmas[i + 1] - sigma_hat
1466
+ # Euler method
1467
+ x = x + d * dt
1468
+ return x
1469
+
1470
+ @torch.no_grad()
1471
+ def sample_euler_smea_multi_as(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1472
+ extra_args = {} if extra_args is None else extra_args
1473
+ s_in = x.new_ones([x.shape[0]])
1474
+ for i in trange(len(sigmas) - 1, disable=disable):
1475
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1476
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
1477
+ sigma_hat = sigmas[i] * (gamma + 1)
1478
+ if gamma > 0:
1479
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1480
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1481
+ d = to_d(x, sigma_hat, denoised)
1482
+ if callback is not None:
1483
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1484
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
1485
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1486
+ dt_1 = sigma_mid - sigma_hat
1487
+ dt_2 = sigmas[i + 1] - sigma_hat
1488
+ x_2 = x + d * dt_1
1489
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
1490
+ sa = 1 + scale * 0.15
1491
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, sa, **extra_args)
1492
+ d_2 = to_d(x_2, sigma_mid, denoised_2 * (sa ** 2) * 0.75)
1493
+ x = x + d_2 * dt_2
1494
+ else:
1495
+ dt = sigmas[i + 1] - sigma_hat
1496
+ # Euler method
1497
+ x = x + d * dt
1498
+ return x
1499
+
1500
+ ## og sampler
1501
+ @torch.no_grad()
1502
+ def sample_euler_dy_og(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1503
+ extra_args = {} if extra_args is None else extra_args
1504
+ s_in = x.new_ones([x.shape[0]])
1505
+ for i in trange(len(sigmas) - 1, disable=disable):
1506
+ # print(i)
1507
+ # i绗竴姝ヤ负0
1508
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1509
+ eps = torch.randn_like(x) * s_noise
1510
+ sigma_hat = sigmas[i] * (gamma + 1)
1511
+ # print(sigma_hat)
1512
+ dt = sigmas[i + 1] - sigma_hat
1513
+ if gamma > 0:
1514
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1515
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1516
+ d = sampling.to_d(x, sigma_hat, denoised)
1517
+ if sigmas[i + 1] > 0:
1518
+ if i // 2 == 1:
1519
+ x = dy_sampling_step(x, model, dt, sigma_hat, **extra_args)
1520
+ if callback is not None:
1521
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1522
+ # Euler method
1523
+ x = x + d * dt
1524
+ return x
1525
+
1526
+ @torch.no_grad()
1527
+ def sample_euler_smea_dy_og(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1528
+ extra_args = {} if extra_args is None else extra_args
1529
+ s_in = x.new_ones([x.shape[0]])
1530
+ for i in trange(len(sigmas) - 1, disable=disable):
1531
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1532
+ eps = torch.randn_like(x) * s_noise
1533
+ sigma_hat = sigmas[i] * (gamma + 1)
1534
+ dt = sigmas[i + 1] - sigma_hat
1535
+ if gamma > 0:
1536
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1537
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1538
+ d = sampling.to_d(x, sigma_hat, denoised)
1539
+ # Euler method
1540
+ x = x + d * dt
1541
+ if sigmas[i + 1] > 0:
1542
+ if i + 1 // 2 == 1:
1543
+ x = dy_sampling_step(x, model, dt, sigma_hat, **extra_args)
1544
+ if i + 1 // 2 == 0:
1545
+ x = smea_sampling_step(x, model, dt, sigma_hat, **extra_args)
1546
+ if callback is not None:
1547
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1548
+ return x
1549
+
1550
+ ## TCD
1551
+
1552
+ def sample_tcd_euler_a(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, gamma=0.3):
1553
+ # TCD sampling using modified Euler Ancestral sampler. by @laksjdjf
1554
+ extra_args = {} if extra_args is None else extra_args
1555
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
1556
+ s_in = x.new_ones([x.shape[0]])
1557
+ for i in trange(len(sigmas) - 1, disable=disable):
1558
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
1559
+ if callback is not None:
1560
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
1561
+
1562
+ #d = to_d(x, sigmas[i], denoised)
1563
+ sigma_from = sigmas[i]
1564
+ sigma_to = sigmas[i + 1]
1565
+
1566
+ t = model.inner_model.sigma_to_t(sigma_from)
1567
+ down_t = (1 - gamma) * t
1568
+ sigma_down = model.inner_model.t_to_sigma(down_t)
1569
+
1570
+ if sigma_down > sigma_to:
1571
+ sigma_down = sigma_to
1572
+ sigma_up = (sigma_to ** 2 - sigma_down ** 2) ** 0.5
1573
+
1574
+ # same as euler ancestral
1575
+ d = to_d(x, sigma_from, denoised)
1576
+ dt = sigma_down - sigma_from
1577
+ x += d * dt
1578
+
1579
+ if sigma_to > 0 and gamma > 0:
1580
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigma_up
1581
+ return x
1582
+
1583
+ @torch.no_grad()
1584
+ def sample_tcd(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, gamma=0.3):
1585
+ # TCD sampling using modified DDPM.
1586
+ extra_args = {} if extra_args is None else extra_args
1587
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
1588
+ s_in = x.new_ones([x.shape[0]])
1589
+
1590
+ for i in trange(len(sigmas) - 1, disable=disable):
1591
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
1592
+ if callback is not None:
1593
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
1594
+
1595
+ sigma_from, sigma_to = sigmas[i], sigmas[i+1]
1596
+
1597
+ # TCD offset, based on gamma, and conversion between sigma and timestep
1598
+ t = model.inner_model.sigma_to_t(sigma_from)
1599
+ t_s = (1 - gamma) * t
1600
+ sigma_to_s = model.inner_model.t_to_sigma(t_s)
1601
+
1602
+ # if sigma_to_s > sigma_to:
1603
+ # sigma_to_s = sigma_to
1604
+ # if sigma_to_s < 0:
1605
+ # sigma_to_s = torch.tensor(1.0)
1606
+ #print(f"sigma_from: {sigma_from}, sigma_to: {sigma_to}, sigma_to_s: {sigma_to_s}")
1607
+
1608
+
1609
+ # The following is equivalent to the comfy DDPM implementation
1610
+ # x = DDPMSampler_step(x / torch.sqrt(1.0 + sigma_from ** 2.0), sigma_from, sigma_to, (x - denoised) / sigma_from, noise_sampler)
1611
+
1612
+ noise_est = (x - denoised) / sigma_from
1613
+ x /= torch.sqrt(1.0 + sigma_from ** 2.0)
1614
+
1615
+ alpha_cumprod = 1 / ((sigma_from * sigma_from) + 1) # _t
1616
+ alpha_cumprod_prev = 1 / ((sigma_to * sigma_to) + 1) # _t_prev
1617
+ alpha = (alpha_cumprod / alpha_cumprod_prev)
1618
+
1619
+ ## These values should approach 1.0?
1620
+ # print(f"alpha_cumprod: {alpha_cumprod}")
1621
+ # print(f"alpha_cumprod_prev: {alpha_cumprod_prev}")
1622
+ # print(f"alpha: {alpha}")
1623
+
1624
+
1625
+ # alpha_cumprod_down = 1 / ((sigma_to_s * sigma_to_s) + 1) # _s
1626
+ # alpha_d = (alpha_cumprod_prev / alpha_cumprod_down)
1627
+ # alpha2 = (alpha_cumprod / alpha_cumprod_down)
1628
+ # print(f"** alpha_cumprod_down: {alpha_cumprod_down}")
1629
+ # print(f"** alpha_d: {alpha_d}, alpha2: #{alpha2}")
1630
+
1631
+ # epsilon noise prediction from comfy DDPM implementation
1632
+ x = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise_est / (1 - alpha_cumprod).sqrt())
1633
+ # x = (1.0 / alpha_d).sqrt() * (x - (1 - alpha) * noise_est / (1 - alpha_cumprod).sqrt())
1634
+
1635
+ first_step = sigma_to == 0
1636
+ last_step = i == len(sigmas) - 2
1637
+
1638
+ if not first_step:
1639
+ if gamma > 0 and not last_step:
1640
+ noise = noise_sampler(sigma_from, sigma_to)
1641
+
1642
+ # x += ((1 - alpha_d) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise
1643
+ variance = ((1 - alpha_cumprod_prev) / (1 - alpha_cumprod)) * (1 - alpha_cumprod / alpha_cumprod_prev)
1644
+ x += variance.sqrt() * noise # scale noise by std deviation
1645
+
1646
+ # relevant diffusers code from scheduling_tcd.py
1647
+ # prev_sample = (alpha_prod_t_prev / alpha_prod_s).sqrt() * pred_noised_sample + (
1648
+ # 1 - alpha_prod_t_prev / alpha_prod_s
1649
+ # ).sqrt() * noise
1650
+
1651
+ x *= torch.sqrt(1.0 + sigma_to ** 2.0)
1652
+
1653
+ # beta_cumprod_t = 1 - alpha_cumprod
1654
+ # beta_cumprod_s = 1 - alpha_cumprod_down
1655
+
1656
+
1657
+ return x
sd-webui-smea/sd_webui_smea.py ADDED
@@ -0,0 +1,1657 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ import k_diffusion.sampling
4
+
5
+ from k_diffusion.sampling import to_d, BrownianTreeNoiseSampler
6
+ from tqdm.auto import trange
7
+ from modules import scripts
8
+ from modules import sd_samplers_kdiffusion, sd_samplers_common, sd_samplers
9
+ from modules.sd_samplers_kdiffusion import KDiffusionSampler
10
+
11
+ class _Rescaler:
12
+ def __init__(self, model, x, mode, **extra_args):
13
+ self.model = model
14
+ self.x = x
15
+ self.mode = mode
16
+ self.extra_args = extra_args
17
+ self.init_latent, self.mask, self.nmask = model.init_latent, model.mask, model.nmask
18
+
19
+ def __enter__(self):
20
+ if self.init_latent is not None:
21
+ self.model.init_latent = torch.nn.functional.interpolate(input=self.init_latent, size=self.x.shape[2:4], mode=self.mode)
22
+ if self.mask is not None:
23
+ self.model.mask = torch.nn.functional.interpolate(input=self.mask.unsqueeze(0), size=self.x.shape[2:4], mode=self.mode).squeeze(0)
24
+ if self.nmask is not None:
25
+ self.model.nmask = torch.nn.functional.interpolate(input=self.nmask.unsqueeze(0), size=self.x.shape[2:4], mode=self.mode).squeeze(0)
26
+ return self
27
+
28
+ def __exit__(self, type, value, traceback):
29
+ del self.model.init_latent, self.model.mask, self.model.nmask
30
+ self.model.init_latent, self.model.mask, self.model.nmask = self.init_latent, self.mask, self.nmask
31
+
32
+ class Smea(scripts.Script):
33
+
34
+ def title(self):
35
+ return "Euler Smea Dy sampler"
36
+
37
+ def show(self, is_img2img):
38
+ return scripts.AlwaysVisible
39
+
40
+ def __init__(self):
41
+ init()
42
+ return
43
+
44
+ def init():
45
+ for i in sd_samplers.all_samplers:
46
+ if "Euler Max" in i.name:
47
+ return
48
+
49
+ samplers_smea = [
50
+ ('Euler Max', sample_euler_max, ['k_euler'], {}),
51
+ ('Euler Max1b', sample_euler_max1b, ['k_euler'], {}),
52
+ ('Euler Max1c', sample_euler_max1c, ['k_euler'], {}),
53
+ ('Euler Max1d', sample_euler_max1d, ['k_euler'], {}),
54
+ ('Euler Max2', sample_euler_max2, ['k_euler'], {}),
55
+ ('Euler Max2b', sample_euler_max2b, ['k_euler'], {}),
56
+ ('Euler Max2c', sample_euler_max2c, ['k_euler'], {}),
57
+ ('Euler Max2d', sample_euler_max2d, ['k_euler'], {}),
58
+ ('Euler Max3', sample_euler_max3, ['k_euler'], {}),
59
+ ('Euler Max3b', sample_euler_max3b, ['k_euler'], {}),
60
+ ('Euler Max3c', sample_euler_max3c, ['k_euler'], {}),
61
+ ('Euler Max4', sample_euler_max4, ['k_euler'], {}),
62
+ ('Euler Max4b', sample_euler_max4b, ['k_euler'], {}),
63
+ ('Euler Max4c', sample_euler_max4c, ['k_euler'], {}),
64
+ ('Euler Max4d', sample_euler_max4d, ['k_euler'], {}),
65
+ ('Euler Max4e', sample_euler_max4e, ['k_euler'], {}),
66
+ ('Euler Max4f', sample_euler_max4f, ['k_euler'], {}),
67
+ ('Euler Dy', sample_euler_dy, ['k_euler'], {}),
68
+ ('Euler Smea', sample_euler_smea, ['k_euler'], {}),
69
+ ('Euler Smea Dy', sample_euler_smea_dy, ['k_euler'], {}),
70
+ ('Euler Smea Max', sample_euler_smea_max, ['k_euler'], {}),
71
+ ('Euler Smea Max s', sample_euler_smea_max_s, ['k_euler'], {}),
72
+ ('Euler Smea dyn a', sample_euler_smea_dyn_a, ['k_euler'], {}),
73
+ ('Euler Smea dyn b', sample_euler_smea_dyn_b, ['k_euler'], {}),
74
+ ('Euler Smea dyn c', sample_euler_smea_dyn_c, ['k_euler'], {}),
75
+ ('Euler Smea ma', sample_euler_smea_multi_a, ['k_euler'], {}),
76
+ ('Euler Smea mb', sample_euler_smea_multi_b, ['k_euler'], {}),
77
+ ('Euler Smea mc', sample_euler_smea_multi_c, ['k_euler'], {}),
78
+ ('Euler Smea md', sample_euler_smea_multi_d, ['k_euler'], {}),
79
+ ('Euler Smea mas', sample_euler_smea_multi_as, ['k_euler'], {}),
80
+ ('Euler Smea mbs', sample_euler_smea_multi_bs, ['k_euler'], {}),
81
+ ('Euler Smea mcs', sample_euler_smea_multi_cs, ['k_euler'], {}),
82
+ ('Euler Smea mds', sample_euler_smea_multi_ds, ['k_euler'], {}),
83
+ ('Euler Smea mbs2', sample_euler_smea_multi_bs2, ['k_euler'], {}),
84
+ ('Euler Smea mds2', sample_euler_smea_multi_ds2, ['k_euler'], {}),
85
+ ('Euler Smea mds2 max', sample_euler_smea_multi_ds2_m, ['k_euler'], {}),
86
+ ('Euler Smea mds2 s max', sample_euler_smea_multi_ds2_s_m, ['k_euler'], {}),
87
+ ('Euler Smea mbs2 s', sample_euler_smea_multi_bs2_s, ['k_euler'], {}),
88
+ ('Euler Smea mds2 s', sample_euler_smea_multi_ds2_s, ['k_euler'], {}),
89
+ ('Euler h max', sample_euler_h_m, ['k_euler'], {"brownian_noise": True}),
90
+ ('Euler h max b', sample_euler_h_m_b, ['k_euler'], {"brownian_noise": True}),
91
+ ('Euler h max c', sample_euler_h_m_c, ['k_euler'], {"brownian_noise": True}),
92
+ ('Euler h max d', sample_euler_h_m_d, ['k_euler'], {"brownian_noise": True}),
93
+ ('Euler h max e', sample_euler_h_m_e, ['k_euler'], {"brownian_noise": True}),
94
+ ('Euler h max f', sample_euler_h_m_f, ['k_euler'], {"brownian_noise": True}),
95
+ ('Euler h max g', sample_euler_h_m_g, ['k_euler'], {"brownian_noise": True}),
96
+ ('Euler h max b c', sample_euler_h_m_b_c, ['k_euler'], {"brownian_noise": True}),
97
+ ('Euler h max b c CFG++', sample_euler_h_m_b_c_pp, ['k_euler'], {"brownian_noise": True, "cfgpp": True}),
98
+ ('Euler Dy koishi-star', sample_euler_dy_og, ['k_euler'], {}),
99
+ ('Euler Smea Dy koishi-star', sample_euler_smea_dy_og, ['k_euler'], {}),
100
+ ('TCD Euler a', sample_tcd_euler_a, ['tcd_euler_a'], {}),
101
+ ('TCD', sample_tcd, ['tcd'], {}),
102
+ ]
103
+
104
+ samplers_data_smea = [
105
+ sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
106
+ for label, funcname, aliases, options in samplers_smea
107
+ if callable(funcname)
108
+ ]
109
+
110
+ sampler_exparams_smea = {
111
+ sample_euler_max: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
112
+ sample_euler_max1b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
113
+ sample_euler_max1c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
114
+ sample_euler_max1d: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
115
+ sample_euler_max2: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
116
+ sample_euler_max2b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
117
+ sample_euler_max2c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
118
+ sample_euler_max2d: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
119
+ sample_euler_max3: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
120
+ sample_euler_max3b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
121
+ sample_euler_max3c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
122
+ sample_euler_max4: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
123
+ sample_euler_max4b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
124
+ sample_euler_max4c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
125
+ sample_euler_max4d: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
126
+ sample_euler_max4e: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
127
+ sample_euler_max4f: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
128
+ sample_euler_dy: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
129
+ sample_euler_smea: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
130
+ sample_euler_smea_dy: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
131
+ sample_euler_smea_max: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
132
+ sample_euler_smea_max_s: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
133
+ sample_euler_smea_dyn_a: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
134
+ sample_euler_smea_dyn_b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
135
+ sample_euler_smea_dyn_c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
136
+ sample_euler_smea_multi_a: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
137
+ sample_euler_smea_multi_b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
138
+ sample_euler_smea_multi_c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
139
+ sample_euler_smea_multi_d: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
140
+ sample_euler_smea_multi_as: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
141
+ sample_euler_smea_multi_bs: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
142
+ sample_euler_smea_multi_cs: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
143
+ sample_euler_smea_multi_ds: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
144
+ sample_euler_smea_multi_bs2: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
145
+ sample_euler_smea_multi_ds2: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
146
+ sample_euler_smea_multi_ds2_m: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
147
+ sample_euler_smea_multi_ds2_s_m: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
148
+ sample_euler_smea_multi_bs2_s: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
149
+ sample_euler_smea_multi_ds2_s: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
150
+ sample_euler_h_m: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
151
+ sample_euler_h_m_b: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
152
+ sample_euler_h_m_c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
153
+ sample_euler_h_m_d: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
154
+ sample_euler_h_m_e: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
155
+ sample_euler_h_m_f: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
156
+ sample_euler_h_m_g: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
157
+ sample_euler_h_m_b_c: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
158
+ sample_euler_h_m_b_c_pp: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
159
+ sample_euler_dy_og: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
160
+ sample_euler_smea_dy_og: ['s_churn', 's_tmin', 's_tmax', 's_noise'],
161
+ }
162
+ sd_samplers_kdiffusion.sampler_extra_params = {**sd_samplers_kdiffusion.sampler_extra_params, **sampler_exparams_smea}
163
+
164
+ samplers_map_smea = {x.name: x for x in samplers_data_smea}
165
+ sd_samplers_kdiffusion.k_diffusion_samplers_map = {**sd_samplers_kdiffusion.k_diffusion_samplers_map, **samplers_map_smea}
166
+
167
+ for i, item in enumerate(sd_samplers.all_samplers):
168
+ if "Euler" in item.name:
169
+ sd_samplers.all_samplers = sd_samplers.all_samplers[:i + 1] + [*samplers_data_smea] + sd_samplers.all_samplers[i + 1:]
170
+ break
171
+ sd_samplers.all_samplers_map = {x.name: x for x in sd_samplers.all_samplers}
172
+ sd_samplers.set_samplers()
173
+
174
+ return
175
+
176
+ def default_noise_sampler(x):
177
+ return lambda sigma, sigma_next: k_diffusion.sampling.torch.randn_like(x)
178
+
179
+ @torch.no_grad()
180
+ def dy_sampling_step(x, model, dt, sigma_hat, **extra_args):
181
+ original_shape = x.shape
182
+ batch_size, channels, m, n = original_shape[0], original_shape[1], original_shape[2] // 2, original_shape[3] // 2
183
+ extra_row = x.shape[2] % 2 == 1
184
+ extra_col = x.shape[3] % 2 == 1
185
+
186
+ if extra_row:
187
+ extra_row_content = x[:, :, -1:, :]
188
+ x = x[:, :, :-1, :]
189
+ if extra_col:
190
+ extra_col_content = x[:, :, :, -1:]
191
+ x = x[:, :, :, :-1]
192
+
193
+ a_list = x.unfold(2, 2, 2).unfold(3, 2, 2).contiguous().view(batch_size, channels, m * n, 2, 2)
194
+ c = a_list[:, :, :, 1, 1].view(batch_size, channels, m, n)
195
+
196
+ with _Rescaler(model, c, 'nearest-exact', **extra_args) as rescaler:
197
+ denoised = model(c, sigma_hat * c.new_ones([c.shape[0]]), **rescaler.extra_args)
198
+ d = to_d(c, sigma_hat, denoised)
199
+ c = c + d * dt
200
+
201
+ d_list = c.view(batch_size, channels, m * n, 1, 1)
202
+ a_list[:, :, :, 1, 1] = d_list[:, :, :, 0, 0]
203
+ x = a_list.view(batch_size, channels, m, n, 2, 2).permute(0, 1, 2, 4, 3, 5).reshape(batch_size, channels, 2 * m, 2 * n)
204
+
205
+ if extra_row or extra_col:
206
+ x_expanded = torch.zeros(original_shape, dtype=x.dtype, device=x.device)
207
+ x_expanded[:, :, :2 * m, :2 * n] = x
208
+ if extra_row:
209
+ x_expanded[:, :, -1:, :2 * n + 1] = extra_row_content
210
+ if extra_col:
211
+ x_expanded[:, :, :2 * m, -1:] = extra_col_content
212
+ if extra_row and extra_col:
213
+ x_expanded[:, :, -1:, -1:] = extra_col_content[:, :, -1:, :]
214
+ x = x_expanded
215
+
216
+ return x
217
+
218
+ @torch.no_grad()
219
+ def smea_sampling_step(x, model, dt, sigma_hat, **extra_args):
220
+ m, n = x.shape[2], x.shape[3]
221
+ x = torch.nn.functional.interpolate(input=x, size=None, scale_factor=(1.25, 1.25), mode='nearest-exact', align_corners=None, recompute_scale_factor=None)
222
+ with _Rescaler(model, x, 'nearest-exact', **extra_args) as rescaler:
223
+ denoised = model(x, sigma_hat * x.new_ones([x.shape[0]]), **rescaler.extra_args)
224
+ d = to_d(x, sigma_hat, denoised)
225
+ x = x + d * dt
226
+ x = torch.nn.functional.interpolate(input=x, size=(m,n), scale_factor=None, mode='nearest-exact', align_corners=None, recompute_scale_factor=None)
227
+ return x
228
+
229
+ @torch.no_grad()
230
+ def smea_sampling_step_denoised(x, model, sigma_hat, scale=1.25, smooth=False, **extra_args):
231
+ m, n = x.shape[2], x.shape[3]
232
+ filter = 'nearest-exact' if not smooth else 'bilinear'
233
+ x = torch.nn.functional.interpolate(input=x, scale_factor=(scale, scale), mode=filter)
234
+ with _Rescaler(model, x, filter, **extra_args) as rescaler:
235
+ denoised = model(x, sigma_hat * x.new_ones([x.shape[0]]), **rescaler.extra_args)
236
+ x = denoised
237
+ x = torch.nn.functional.interpolate(input=x, size=(m,n), mode='nearest-exact')
238
+ return x
239
+
240
+ @torch.no_grad()
241
+ def sample_euler_max(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
242
+ extra_args = {} if extra_args is None else extra_args
243
+ s_in = x.new_ones([x.shape[0]])
244
+ for i in trange(len(sigmas) - 1, disable=disable):
245
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
246
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
247
+ sigma_hat = sigmas[i] * (gamma + 1)
248
+ if gamma > 0:
249
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
250
+ denoised = model(x, sigma_hat * s_in, **extra_args)
251
+ d = to_d(x, sigma_hat, denoised)
252
+ if callback is not None:
253
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
254
+ dt = sigmas[i + 1] - sigma_hat
255
+ # Euler method
256
+ x = x + (math.cos(i + 1)/(i + 1) + 1) * d * dt
257
+ return x
258
+
259
+
260
+ @torch.no_grad()
261
+ def sample_euler_max1b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
262
+ extra_args = {} if extra_args is None else extra_args
263
+ s_in = x.new_ones([x.shape[0]])
264
+ for i in trange(len(sigmas) - 1, disable=disable):
265
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
266
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
267
+ sigma_hat = sigmas[i] * (gamma + 1)
268
+ if gamma > 0:
269
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
270
+ denoised = model(x, sigma_hat * s_in, **extra_args)
271
+ d = to_d(x, sigma_hat, denoised)
272
+ if callback is not None:
273
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
274
+ dt = sigmas[i + 1] - sigma_hat
275
+ # Euler method
276
+ x = x + (math.cos(1.05 * i + 1)/(1.1 * i + 1.5) + 1) * d * dt
277
+ return x
278
+
279
+ @torch.no_grad()
280
+ def sample_euler_max1c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
281
+ extra_args = {} if extra_args is None else extra_args
282
+ s_in = x.new_ones([x.shape[0]])
283
+ for i in trange(len(sigmas) - 1, disable=disable):
284
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
285
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
286
+ sigma_hat = sigmas[i] * (gamma + 1)
287
+ if gamma > 0:
288
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
289
+ denoised = model(x, sigma_hat * s_in, **extra_args)
290
+ d = to_d(x, sigma_hat, denoised)
291
+ if callback is not None:
292
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
293
+ dt = sigmas[i + 1] - sigma_hat
294
+ # Euler method
295
+ x = x + (math.cos(1.05 * i + 1.1)/(1.25 * i + 1.5) + 1) * d * dt
296
+ return x
297
+
298
+ @torch.no_grad()
299
+ def sample_euler_max1d(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
300
+ extra_args = {} if extra_args is None else extra_args
301
+ s_in = x.new_ones([x.shape[0]])
302
+ for i in trange(len(sigmas) - 1, disable=disable):
303
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
304
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
305
+ sigma_hat = sigmas[i] * (gamma + 1)
306
+ if gamma > 0:
307
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
308
+ denoised = model(x, sigma_hat * s_in, **extra_args)
309
+ d = to_d(x, sigma_hat, denoised)
310
+ if callback is not None:
311
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
312
+ dt = sigmas[i + 1] - sigma_hat
313
+ # Euler method
314
+ x = x + (math.cos(math.pi * 0.333 * i + 0.9)/(0.5 * i + 1.5) + 1) * d * dt
315
+ return x
316
+
317
+ @torch.no_grad()
318
+ def sample_euler_max2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
319
+ extra_args = {} if extra_args is None else extra_args
320
+ s_in = x.new_ones([x.shape[0]])
321
+ for i in trange(len(sigmas) - 1, disable=disable):
322
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
323
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
324
+ sigma_hat = sigmas[i] * (gamma + 1)
325
+ if gamma > 0:
326
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
327
+ denoised = model(x, sigma_hat * s_in, **extra_args)
328
+ d = to_d(x, sigma_hat, denoised)
329
+ if callback is not None:
330
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
331
+ dt = sigmas[i + 1] - sigma_hat
332
+ # Euler method
333
+ x = x + (math.cos(math.pi * 0.333 * i - 0.1)/(0.5 * i + 1.5) + 1) * d * dt
334
+ return x
335
+
336
+ @torch.no_grad()
337
+ def sample_euler_max2b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
338
+ extra_args = {} if extra_args is None else extra_args
339
+ s_in = x.new_ones([x.shape[0]])
340
+ for i in trange(len(sigmas) - 1, disable=disable):
341
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
342
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
343
+ sigma_hat = sigmas[i] * (gamma + 1)
344
+ if gamma > 0:
345
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
346
+ denoised = model(x, sigma_hat * s_in, **extra_args)
347
+ d = to_d(x, sigma_hat, denoised)
348
+ if callback is not None:
349
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
350
+ dt = sigmas[i + 1] - sigma_hat
351
+ # Euler method
352
+ x = x + (math.cos(math.pi * 0.5 * i - 0.0)/(0.5 * i + 1.5) + 1) * d * dt
353
+ return x
354
+
355
+ @torch.no_grad()
356
+ def sample_euler_max2c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
357
+ extra_args = {} if extra_args is None else extra_args
358
+ s_in = x.new_ones([x.shape[0]])
359
+ for i in trange(len(sigmas) - 1, disable=disable):
360
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
361
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
362
+ sigma_hat = sigmas[i] * (gamma + 1)
363
+ if gamma > 0:
364
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
365
+ denoised = model(x, sigma_hat * s_in, **extra_args)
366
+ d = to_d(x, sigma_hat, denoised)
367
+ if callback is not None:
368
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
369
+ dt = sigmas[i + 1] - sigma_hat
370
+ # Euler method
371
+ x = x + (math.cos(math.pi * 0.5 * i)/(i + 2) + 1) * d * dt
372
+ return x
373
+
374
+ @torch.no_grad()
375
+ def sample_euler_max2d(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
376
+ extra_args = {} if extra_args is None else extra_args
377
+ s_in = x.new_ones([x.shape[0]])
378
+ for i in trange(len(sigmas) - 1, disable=disable):
379
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
380
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
381
+ sigma_hat = sigmas[i] * (gamma + 1)
382
+ if gamma > 0:
383
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
384
+ denoised = model(x, sigma_hat * s_in, **extra_args)
385
+ d = to_d(x, sigma_hat, denoised)
386
+ if callback is not None:
387
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
388
+ dt = sigmas[i + 1] - sigma_hat
389
+ # Euler method
390
+ x = x + (math.cos(math.pi * 0.5 * i)/(0.75 * i + 1.75) + 1) * d * dt
391
+ return x
392
+
393
+ @torch.no_grad()
394
+ def sample_euler_max3b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
395
+ extra_args = {} if extra_args is None else extra_args
396
+ s_in = x.new_ones([x.shape[0]])
397
+ for i in trange(len(sigmas) - 1, disable=disable):
398
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
399
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
400
+ sigma_hat = sigmas[i] * (gamma + 1)
401
+ if gamma > 0:
402
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
403
+ denoised = model(x, sigma_hat * s_in, **extra_args)
404
+ d = to_d(x, sigma_hat, denoised)
405
+ if callback is not None:
406
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
407
+ dt = sigmas[i + 1] - sigma_hat
408
+ # Euler method
409
+ x = x + (math.cos(2 * i + 0.5)/(2 * i + 1.5) + 1) * d * dt
410
+ return x
411
+
412
+ @torch.no_grad()
413
+ def sample_euler_max3c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
414
+ extra_args = {} if extra_args is None else extra_args
415
+ s_in = x.new_ones([x.shape[0]])
416
+ for i in trange(len(sigmas) - 1, disable=disable):
417
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
418
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
419
+ sigma_hat = sigmas[i] * (gamma + 1)
420
+ if gamma > 0:
421
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
422
+ denoised = model(x, sigma_hat * s_in, **extra_args)
423
+ d = to_d(x, sigma_hat, denoised)
424
+ if callback is not None:
425
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
426
+ dt = sigmas[i + 1] - sigma_hat
427
+ # Euler method
428
+ x = x + (math.cos(2 * i + 0.5)/(1.5 * i + 2.7) + 1) * d * dt
429
+ return x
430
+
431
+ @torch.no_grad()
432
+ def sample_euler_max3(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
433
+ extra_args = {} if extra_args is None else extra_args
434
+ s_in = x.new_ones([x.shape[0]])
435
+ for i in trange(len(sigmas) - 1, disable=disable):
436
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
437
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
438
+ sigma_hat = sigmas[i] * (gamma + 1)
439
+ if gamma > 0:
440
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
441
+ denoised = model(x, sigma_hat * s_in, **extra_args)
442
+ d = to_d(x, sigma_hat, denoised)
443
+ if callback is not None:
444
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
445
+ dt = sigmas[i + 1] - sigma_hat
446
+ # Euler method
447
+ x = x + (math.cos(2 * i + 1)/(2 * i + 1) + 1) * d * dt
448
+ return x
449
+
450
+ @torch.no_grad()
451
+ def sample_euler_max4b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
452
+ extra_args = {} if extra_args is None else extra_args
453
+ s_in = x.new_ones([x.shape[0]])
454
+ for i in trange(len(sigmas) - 1, disable=disable):
455
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
456
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
457
+ sigma_hat = sigmas[i] * (gamma + 1)
458
+ if gamma > 0:
459
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
460
+ denoised = model(x, sigma_hat * s_in, **extra_args)
461
+ d = to_d(x, sigma_hat, denoised)
462
+ if callback is not None:
463
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
464
+ dt = sigmas[i + 1] - sigma_hat
465
+ # Euler method
466
+ x = x + (math.cos(math.pi * i - 0.1)/(2 * i + 2) + 1) * d * dt
467
+ return x
468
+
469
+ @torch.no_grad()
470
+ def sample_euler_max4c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
471
+ extra_args = {} if extra_args is None else extra_args
472
+ s_in = x.new_ones([x.shape[0]])
473
+ for i in trange(len(sigmas) - 1, disable=disable):
474
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
475
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
476
+ sigma_hat = sigmas[i] * (gamma + 1)
477
+ if gamma > 0:
478
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
479
+ denoised = model(x, sigma_hat * s_in, **extra_args)
480
+ d = to_d(x, sigma_hat, denoised)
481
+ if callback is not None:
482
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
483
+ dt = sigmas[i + 1] - sigma_hat
484
+ # Euler method
485
+ x = x + (math.cos(math.pi * i - 0.1)/(2 * i + 1.5) + 1) * d * dt
486
+ return x
487
+
488
+ @torch.no_grad()
489
+ def sample_euler_max4d(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
490
+ extra_args = {} if extra_args is None else extra_args
491
+ s_in = x.new_ones([x.shape[0]])
492
+ for i in trange(len(sigmas) - 1, disable=disable):
493
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
494
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
495
+ sigma_hat = sigmas[i] * (gamma + 1)
496
+ if gamma > 0:
497
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
498
+ denoised = model(x, sigma_hat * s_in, **extra_args)
499
+ d = to_d(x, sigma_hat, denoised)
500
+ if callback is not None:
501
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
502
+ dt = sigmas[i + 1] - sigma_hat
503
+ # Euler method
504
+ x = x + (math.cos(math.pi * i - 0.1)/(i + 1.5) + 1) * d * dt
505
+ return x
506
+
507
+ @torch.no_grad()
508
+ def sample_euler_max4e(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
509
+ extra_args = {} if extra_args is None else extra_args
510
+ s_in = x.new_ones([x.shape[0]])
511
+ for i in trange(len(sigmas) - 1, disable=disable):
512
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
513
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
514
+ sigma_hat = sigmas[i] * (gamma + 1)
515
+ if gamma > 0:
516
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
517
+ denoised = model(x, sigma_hat * s_in, **extra_args)
518
+ d = to_d(x, sigma_hat, denoised)
519
+ if callback is not None:
520
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
521
+ dt = sigmas[i + 1] - sigma_hat
522
+ # Euler method
523
+ x = x + (math.cos(math.pi * i - 0.1)/(i + 1) + 1) * d * dt
524
+ return x
525
+
526
+ @torch.no_grad()
527
+ def sample_euler_max4f(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
528
+ extra_args = {} if extra_args is None else extra_args
529
+ s_in = x.new_ones([x.shape[0]])
530
+ for i in trange(len(sigmas) - 1, disable=disable):
531
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
532
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
533
+ sigma_hat = sigmas[i] * (gamma + 1)
534
+ if gamma > 0:
535
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
536
+ denoised = model(x, sigma_hat * s_in, **extra_args)
537
+ d = to_d(x, sigma_hat, denoised)
538
+ if callback is not None:
539
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
540
+ dt = sigmas[i + 1] - sigma_hat
541
+ # Euler method
542
+ x = x + (math.cos(math.pi * i - 0.1)/(i + 2) + 1) * d * dt
543
+ return x
544
+
545
+
546
+ @torch.no_grad()
547
+ def sample_euler_max4(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
548
+ # 袛芯斜邪胁褜褌械 蟹写械褋褜 褌械谢芯 褎褍薪泻褑懈懈 懈谢懈 褏芯褌褟 斜褘 pass, 褔褌芯斜褘 懈蟹斜械卸邪褌褜 IndentationError
549
+ pass
550
+
551
+ @torch.no_grad()
552
+ def sample_euler_dy(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
553
+ extra_args = {} if extra_args is None else extra_args
554
+ s_in = x.new_ones([x.shape[0]])
555
+ for i in trange(len(sigmas) - 1, disable=disable):
556
+ # print(i)
557
+ # i绗竴姝ヤ负0
558
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
559
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
560
+ sigma_hat = sigmas[i] * (gamma + 1)
561
+ # print(sigma_hat)
562
+ dt = sigmas[i + 1] - sigma_hat
563
+ if gamma > 0:
564
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
565
+ denoised = model(x, sigma_hat * s_in, **extra_args)
566
+ d = to_d(x, sigma_hat, denoised)
567
+ if callback is not None:
568
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
569
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 2 and i % 2 == 0:
570
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
571
+ dt_1 = sigma_mid - sigmas[i]
572
+ dt_2 = sigmas[i + 1] - sigmas[i]
573
+ x_2 = x + d * dt_1
574
+ x_temp = dy_sampling_step(x_2, model, dt_2, sigma_mid, **extra_args)
575
+ x = x_temp - d * dt_1
576
+ # Euler method
577
+ x = x + d * dt
578
+ return x
579
+
580
+ @torch.no_grad()
581
+ def sample_euler_smea_dyn_a(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
582
+ extra_args = {} if extra_args is None else extra_args
583
+ s_in = x.new_ones([x.shape[0]])
584
+ for i in trange(len(sigmas) - 1, disable=disable):
585
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
586
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
587
+ sigma_hat = sigmas[i] * (gamma + 1)
588
+ if gamma > 0:
589
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
590
+ denoised = model(x, sigma_hat * s_in, **extra_args)
591
+ d = to_d(x, sigma_hat, denoised)
592
+ if callback is not None:
593
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
594
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 2:
595
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
596
+ dt_1 = sigma_mid - sigma_hat
597
+ dt_2 = sigmas[i + 1] - sigma_hat
598
+ x_2 = x + d * dt_1
599
+ #scale = (sigma_mid / sigmas[0]) * 0.25
600
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2 * 0.15
601
+ #scale = scale.item()
602
+ if i % 2 == 0:
603
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale, **extra_args)
604
+ #denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + sigma_mid.item() * 0.01, **extra_args)
605
+ else:
606
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
607
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
608
+ x = x + d_2 * dt_2
609
+ else:
610
+ dt = sigmas[i + 1] - sigma_hat
611
+ # Euler method
612
+ x = x + d * dt
613
+ return x
614
+
615
+ @torch.no_grad()
616
+ def sample_euler_smea_dyn_b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
617
+ extra_args = {} if extra_args is None else extra_args
618
+ s_in = x.new_ones([x.shape[0]])
619
+ for i in trange(len(sigmas) - 1, disable=disable):
620
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
621
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
622
+ sigma_hat = sigmas[i] * (gamma + 1)
623
+ if gamma > 0:
624
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
625
+ denoised = model(x, sigma_hat * s_in, **extra_args)
626
+ d = to_d(x, sigma_hat, denoised)
627
+ if callback is not None:
628
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
629
+ if sigmas[i + 1] > 0 and (i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 3 or i < 3):
630
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
631
+ dt_1 = sigma_mid - sigma_hat
632
+ dt_2 = sigmas[i + 1] - sigma_hat
633
+ x_2 = x + d * dt_1
634
+ #scale = (sigma_mid / sigmas[0]) * 0.25
635
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2 * 0.2
636
+ #scale = scale.item()
637
+ if i % 4 == 0:
638
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale, **extra_args)
639
+ #denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - sigma_mid.item() * 0.01, **extra_args)
640
+ elif i % 4 == 2:
641
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale, **extra_args)
642
+ #denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + sigma_mid.item() * 0.01, **extra_args)
643
+ else:
644
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
645
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
646
+ x = x + d_2 * dt_2
647
+ else:
648
+ dt = sigmas[i + 1] - sigma_hat
649
+ # Euler method
650
+ x = x + d * dt
651
+ return x
652
+
653
+ @torch.no_grad()
654
+ def sample_euler_smea_dyn_c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
655
+ extra_args = {} if extra_args is None else extra_args
656
+ s_in = x.new_ones([x.shape[0]])
657
+ for i in trange(len(sigmas) - 1, disable=disable):
658
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
659
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
660
+ sigma_hat = sigmas[i] * (gamma + 1)
661
+ if gamma > 0:
662
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
663
+ denoised = model(x, sigma_hat * s_in, **extra_args)
664
+ d = to_d(x, sigma_hat, denoised)
665
+ if callback is not None:
666
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
667
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 2:
668
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
669
+ dt_1 = sigma_mid - sigma_hat
670
+ dt_2 = sigmas[i + 1] - sigma_hat
671
+ x_2 = x + d * dt_1
672
+ #scale = (sigma_mid / sigmas[0]) * 0.25
673
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2 * 0.25
674
+ #scale = scale.item()
675
+ if i % 2 == 0:
676
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale, **extra_args)
677
+ #denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + sigma_mid.item() * 0.01, **extra_args)
678
+ else:
679
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
680
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
681
+ x = x + d_2 * dt_2
682
+ else:
683
+ dt = sigmas[i + 1] - sigma_hat
684
+ # Euler method
685
+ x = x + d * dt
686
+ return x
687
+
688
+ @torch.no_grad()
689
+ def sample_euler_smea(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
690
+ extra_args = {} if extra_args is None else extra_args
691
+ s_in = x.new_ones([x.shape[0]])
692
+ for i in trange(len(sigmas) - 1, disable=disable):
693
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
694
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
695
+ sigma_hat = sigmas[i] * (gamma + 1)
696
+ if gamma > 0:
697
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
698
+ denoised = model(x, sigma_hat * s_in, **extra_args)
699
+ d = to_d(x, sigma_hat, denoised)
700
+ if callback is not None:
701
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
702
+ dt = sigmas[i + 1] - sigma_hat
703
+ # Euler method
704
+ x = x + d * dt
705
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 2 and i % 2 == 0:
706
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
707
+ dt_1 = sigma_mid - sigmas[i]
708
+ dt_2 = sigmas[i + 1] - sigmas[i]
709
+ #print(dt_1, "#", dt_2, "#", dt_3, "#", dt_4)
710
+ x_2 = x + d * dt_1
711
+ x_temp = smea_sampling_step(x, model, dt_2, sigma_mid, **extra_args)
712
+ x = x_temp - d * dt_1
713
+ return x
714
+
715
+ @torch.no_grad()
716
+ def sample_euler_smea_dy(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
717
+ extra_args = {} if extra_args is None else extra_args
718
+ s_in = x.new_ones([x.shape[0]])
719
+ for i in trange(len(sigmas) - 1, disable=disable):
720
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
721
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
722
+ sigma_hat = sigmas[i] * (gamma + 1)
723
+ if gamma > 0:
724
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
725
+ denoised = model(x, sigma_hat * s_in, **extra_args)
726
+ d = to_d(x, sigma_hat, denoised)
727
+ if callback is not None:
728
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
729
+ dt = sigmas[i + 1] - sigma_hat
730
+ # Euler method
731
+ x = x + d * dt
732
+ if sigmas[i + 1] > 0 and (i < len(sigmas) * 0.334 - len(sigmas) * 0.334 % 2 or i < 3) and i % 3 != 2:
733
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
734
+ dt_1 = sigma_mid - sigmas[i]
735
+ dt_2 = sigmas[i + 1] - sigmas[i]
736
+ #print(dt_1, "#", dt_2, "#", dt_3, "#", dt_4)
737
+ x_2 = x + d * dt_1
738
+ if i % 3 == 1:
739
+ x_temp = dy_sampling_step(x, model, dt_2, sigma_mid, **extra_args)
740
+ elif i % 3 == 0:
741
+ x_temp = smea_sampling_step(x, model, dt_2, sigma_mid, **extra_args)
742
+ x = x_temp - d * dt_1
743
+ return x
744
+
745
+ @torch.no_grad()
746
+ def sample_euler_smea_multi_d(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
747
+ extra_args = {} if extra_args is None else extra_args
748
+ s_in = x.new_ones([x.shape[0]])
749
+ for i in trange(len(sigmas) - 1, disable=disable):
750
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
751
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
752
+ sigma_hat = sigmas[i] * (gamma + 1)
753
+ if gamma > 0:
754
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
755
+ denoised = model(x, sigma_hat * s_in, **extra_args)
756
+ d = to_d(x, sigma_hat, denoised)
757
+ if callback is not None:
758
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
759
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.334 + 2 and i % 2 == 0:
760
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
761
+ dt_1 = sigma_mid - sigma_hat
762
+ dt_2 = sigmas[i + 1] - sigma_hat
763
+ x_2 = x + d * dt_1
764
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
765
+ if i == 0:
766
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale * 0.15, **extra_args)
767
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
768
+ denoised_2 = (denoised_2a + denoised_2c) / 2
769
+ elif i < len(sigmas) * 0.334:
770
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale * 0.25, **extra_args)
771
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale * 0.15, **extra_args)
772
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
773
+ denoised_2 = (denoised_2a + denoised_2b + denoised_2c) / 3
774
+ else:
775
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale * 0.03, True, **extra_args)
776
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
777
+ denoised_2 = (denoised_2b + denoised_2c) / 2
778
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
779
+ x = x + d_2 * dt_2
780
+ else:
781
+ dt = sigmas[i + 1] - sigma_hat
782
+ # Euler method
783
+ x = x + d * dt
784
+ return x
785
+
786
+ @torch.no_grad()
787
+ def sample_euler_smea_multi_b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
788
+ extra_args = {} if extra_args is None else extra_args
789
+ s_in = x.new_ones([x.shape[0]])
790
+ for i in trange(len(sigmas) - 1, disable=disable):
791
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
792
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
793
+ sigma_hat = sigmas[i] * (gamma + 1)
794
+ if gamma > 0:
795
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
796
+ denoised = model(x, sigma_hat * s_in, **extra_args)
797
+ d = to_d(x, sigma_hat, denoised)
798
+ if callback is not None:
799
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
800
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
801
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
802
+ dt_1 = sigma_mid - sigma_hat
803
+ dt_2 = sigmas[i + 1] - sigma_hat
804
+ x_2 = x + d * dt_1
805
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
806
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale * 0.25, **extra_args)
807
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale * 0.15, **extra_args)
808
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
809
+ denoised_2 = (denoised_2a + denoised_2b + denoised_2c) / 3
810
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
811
+ x = x + d_2 * dt_2
812
+ else:
813
+ dt = sigmas[i + 1] - sigma_hat
814
+ # Euler method
815
+ x = x + d * dt
816
+ return x
817
+
818
+ @torch.no_grad()
819
+ def sample_euler_smea_multi_c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
820
+ extra_args = {} if extra_args is None else extra_args
821
+ s_in = x.new_ones([x.shape[0]])
822
+ for i in trange(len(sigmas) - 1, disable=disable):
823
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
824
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
825
+ sigma_hat = sigmas[i] * (gamma + 1)
826
+ if gamma > 0:
827
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
828
+ denoised = model(x, sigma_hat * s_in, **extra_args)
829
+ d = to_d(x, sigma_hat, denoised)
830
+ if callback is not None:
831
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
832
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
833
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
834
+ dt_1 = sigma_mid - sigma_hat
835
+ dt_2 = sigmas[i + 1] - sigma_hat
836
+ x_2 = x + d * dt_1
837
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
838
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 - scale * 0.25, **extra_args)
839
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
840
+ denoised_2 = (denoised_2a + denoised_2c) / 2
841
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
842
+ x = x + d_2 * dt_2
843
+ else:
844
+ dt = sigmas[i + 1] - sigma_hat
845
+ # Euler method
846
+ x = x + d * dt
847
+ return x
848
+
849
+ @torch.no_grad()
850
+ def sample_euler_smea_multi_a(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
851
+ extra_args = {} if extra_args is None else extra_args
852
+ s_in = x.new_ones([x.shape[0]])
853
+ for i in trange(len(sigmas) - 1, disable=disable):
854
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
855
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
856
+ sigma_hat = sigmas[i] * (gamma + 1)
857
+ if gamma > 0:
858
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
859
+ denoised = model(x, sigma_hat * s_in, **extra_args)
860
+ d = to_d(x, sigma_hat, denoised)
861
+ if callback is not None:
862
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
863
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
864
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
865
+ dt_1 = sigma_mid - sigma_hat
866
+ dt_2 = sigmas[i + 1] - sigma_hat
867
+ x_2 = x + d * dt_1
868
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
869
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, 1 + scale * 0.15, **extra_args)
870
+ denoised_2c = model(x_2, sigma_mid * s_in, **extra_args)
871
+ denoised_2 = (denoised_2b + denoised_2c) / 2
872
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
873
+ x = x + d_2 * dt_2
874
+ else:
875
+ dt = sigmas[i + 1] - sigma_hat
876
+ # Euler method
877
+ x = x + d * dt
878
+ return x
879
+
880
+
881
+ @torch.no_grad()
882
+ def sample_euler_smea_multi_ds(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
883
+ extra_args = {} if extra_args is None else extra_args
884
+ s_in = x.new_ones([x.shape[0]])
885
+ for i in trange(len(sigmas) - 1, disable=disable):
886
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
887
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
888
+ sigma_hat = sigmas[i] * (gamma + 1)
889
+ if gamma > 0:
890
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
891
+ denoised = model(x, sigma_hat * s_in, **extra_args)
892
+ d = to_d(x, sigma_hat, denoised)
893
+ if callback is not None:
894
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
895
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167 + 1: # and i % 2 == 0:
896
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
897
+ dt_1 = sigma_mid - sigma_hat
898
+ dt_2 = sigmas[i + 1] - sigma_hat
899
+ x_2 = x + d * dt_1
900
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
901
+ if i == 0:
902
+ sa = 1 - scale * 0.15
903
+ sb = 1 + scale * 0.09
904
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, sa, **extra_args)
905
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, sb, **extra_args)
906
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.625 + denoised_2b * (sb ** 2) * 0.375) / (0.97**2)
907
+ elif i < len(sigmas) * 0.167:
908
+ sa = 1 - scale * 0.25
909
+ sb = 1 + scale * 0.15
910
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, sa, **extra_args)
911
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, sb , **extra_args)
912
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.625 + denoised_2b * (sb ** 2) * 0.375) / (0.95**2)
913
+ else:
914
+ sb = 1 + scale * 0.06
915
+ sc = 1 - scale * 0.1
916
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, sb, True, **extra_args)
917
+ denoised_2c = smea_sampling_step_denoised(x_2, model, sigma_mid, sc, **extra_args)
918
+ denoised_2 = (denoised_2b * (sb ** 2) * 0.375 + denoised_2c * (sc ** 2) * 0.625) / (0.98**2)
919
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
920
+ x = x + d_2 * dt_2
921
+ else:
922
+ dt = sigmas[i + 1] - sigma_hat
923
+ # Euler method
924
+ x = x + d * dt
925
+ return x
926
+
927
+ @torch.no_grad()
928
+ def sample_euler_smea_multi_ds2_s(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
929
+ sample = sample_euler_smea_multi_ds2(model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise, smooth=True)
930
+ return sample
931
+
932
+ @torch.no_grad()
933
+ def sample_euler_smea_multi_ds2_s_m(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
934
+ sample = sample_euler_smea_multi_ds2_m(model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise, smooth=True)
935
+ return sample
936
+
937
+ @torch.no_grad()
938
+ def sample_euler_smea_multi_ds2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., smooth=False):
939
+ extra_args = {} if extra_args is None else extra_args
940
+ s_in = x.new_ones([x.shape[0]])
941
+ for i in trange(len(sigmas) - 1, disable=disable):
942
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
943
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
944
+ sigma_hat = sigmas[i] * (gamma + 1)
945
+ if gamma > 0:
946
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
947
+ denoised = model(x, sigma_hat * s_in, **extra_args)
948
+ d = to_d(x, sigma_hat, denoised)
949
+ if callback is not None:
950
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
951
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167 + 1: # and i % 2 == 0:
952
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
953
+ dt_1 = sigma_mid - sigma_hat
954
+ dt_2 = sigmas[i + 1] - sigma_hat
955
+ x_2 = x + d * dt_1
956
+ scale = (sigmas[i] / sigmas[0]) ** 2
957
+ scale = scale.item()
958
+ if i == 0:
959
+ sa = 1 - scale * 0.15
960
+ sb = 1 + scale * 0.09
961
+ sigA = sigma_mid / (sa ** 2)
962
+ sigB = sigma_mid / (sb ** 2)
963
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
964
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
965
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.5 * sb ** 2 + denoised_2b * (sb ** 2) * 0.5 * sa ** 2) #/ (0.97**2) # 1 - (sa * sb ) / 2 + 1
966
+ d_2 = to_d(x_2, sigA * 0.5 * sb ** 2 + sigB * 0.5 * sa ** 2, denoised_2)
967
+ elif i < len(sigmas) * 0.167:
968
+ sa = 1 - scale * 0.25
969
+ sb = 1 + scale * 0.15
970
+ sigA = sigma_mid / (sa ** 2)
971
+ sigB = sigma_mid / (sb ** 2)
972
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
973
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
974
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.5 * sb ** 2 + denoised_2b * (sb ** 2) * 0.5 * sa ** 2) #/ (0.95**2)
975
+ d_2 = to_d(x_2, sigA * 0.5 * sb ** 2 + sigB * 0.5 * sa ** 2, denoised_2)
976
+ else:
977
+ sb = 1 + scale * 0.06
978
+ sc = 1 - scale * 0.1
979
+ sigB = sigma_mid / (sb ** 2)
980
+ sigC = sigma_mid / (sc ** 2)
981
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
982
+ denoised_2c = smea_sampling_step_denoised(x_2, model, sigC, sc, smooth, **extra_args)
983
+ denoised_2 = (denoised_2b * (sb ** 2) * 0.5 * sc ** 2 + denoised_2c * (sc ** 2) * 0.5 * sb ** 2) #/ (0.98**2)
984
+ d_2 = to_d(x_2, sigB * 0.5 * sc ** 2 + sigC * 0.5 * sb ** 2, denoised_2)
985
+ x = x + d_2 * dt_2
986
+ else:
987
+ dt = sigmas[i + 1] - sigma_hat
988
+ # Euler method
989
+ x = x + d * dt
990
+ return x
991
+
992
+ @torch.no_grad()
993
+ def sample_euler_smea_multi_ds2_m(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., smooth=False):
994
+ extra_args = {} if extra_args is None else extra_args
995
+ s_in = x.new_ones([x.shape[0]])
996
+ for i in trange(len(sigmas) - 1, disable=disable):
997
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
998
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
999
+ sigma_hat = sigmas[i] * (gamma + 1)
1000
+ if gamma > 0:
1001
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1002
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1003
+ d = to_d(x, sigma_hat, denoised)
1004
+ if callback is not None:
1005
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1006
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167 + 1: # and i % 2 == 0:
1007
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1008
+ dt_1 = sigma_mid - sigma_hat
1009
+ dt_2 = sigmas[i + 1] - sigma_hat
1010
+ x_2 = x + d * dt_1
1011
+ scale = (sigmas[i] / sigmas[0]) ** 2
1012
+ #scale = dt_1 ** 2 * 0.01
1013
+ scale = scale.item()
1014
+ if i == 0:
1015
+ sa = 1 - scale * 0.15 #15
1016
+ sb = 1 + scale * 0.09 #09
1017
+ sigA = sigma_mid / (sa ** 2)
1018
+ sigB = sigma_mid / (sb ** 2)
1019
+ #delta = sa * sb
1020
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
1021
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
1022
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.5 * sb ** 2 + denoised_2b * (sb ** 2) * 0.5 * sa ** 2) #/ (0.97**2) # 1 - (sa * sb ) / 2 + 1
1023
+ d_2 = to_d(x_2, sigA * 0.5 * sb ** 2 + sigB * 0.5 * sa ** 2, denoised_2)
1024
+ elif i < len(sigmas) * 0.167:
1025
+ sa = 1 - scale * 0.25 #25
1026
+ sb = 1 + scale * 0.15 #15
1027
+ sigA = sigma_mid / (sa ** 2)
1028
+ sigB = sigma_mid / (sb ** 2)
1029
+ #delta = sa * sb
1030
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
1031
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
1032
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.5 * sb ** 2 + denoised_2b * (sb ** 2) * 0.5 * sa ** 2) #/ (0.95**2)
1033
+ d_2 = to_d(x_2, sigA * 0.5 * sb ** 2 + sigB * 0.5 * sa ** 2, denoised_2)
1034
+ else:
1035
+ sb = 1 + scale * 0.06
1036
+ sc = 1 - scale * 0.1
1037
+ sigB = sigma_mid / (sb ** 2)
1038
+ sigC = sigma_mid / (sc ** 2)
1039
+ #delta = sb * sc
1040
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
1041
+ denoised_2c = smea_sampling_step_denoised(x_2, model, sigC, sc, smooth, **extra_args)
1042
+ denoised_2 = (denoised_2b * (sb ** 2) * 0.5 * sc ** 2+ denoised_2c * (sc ** 2) * 0.5 * sb ** 2) #/ (0.98**2)
1043
+ d_2 = to_d(x_2, sigB * 0.5 * sc ** 2 + sigC * 0.5 * sb ** 2, denoised_2)
1044
+ x = x + (math.cos(1.05 * i + 1.1)/(1.25 * i + 1.5) + 1) * d_2 * dt_2
1045
+ else:
1046
+ dt = sigmas[i + 1] - sigma_hat
1047
+ # Euler method
1048
+ x = x + (math.cos(1.05 * i + 1.1)/(1.25 * i + 1.5) + 1) * d * dt
1049
+ return x
1050
+
1051
+ @torch.no_grad()
1052
+ def sample_euler_h_m(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1053
+ extra_args = {} if extra_args is None else extra_args
1054
+ s_in = x.new_ones([x.shape[0]])
1055
+ for i in trange(len(sigmas) - 1, disable=disable):
1056
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1057
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1058
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1059
+ gamma = min((2 ** 0.5 - 1) - wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1060
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler == None else noise_sampler
1061
+ sigma_hat = sigmas[i] * (gamma + 1)
1062
+ if gamma > 0:
1063
+ x = x - eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1064
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1065
+ d = to_d(x, sigma_hat, denoised)
1066
+ dt = sigmas[i + 1] - sigma_hat
1067
+ if callback is not None:
1068
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1069
+ if sigmas[i + 1] > 0:
1070
+ x_2 = x + d * dt
1071
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1072
+ d_prime = d * 0.5 + d_2 * 0.5
1073
+ x = x + d_prime * dt
1074
+ else:
1075
+ # Euler method
1076
+ x = x + d * dt
1077
+ return x
1078
+
1079
+ @torch.no_grad()
1080
+ def sample_euler_h_m_b(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1081
+ extra_args = {} if extra_args is None else extra_args
1082
+ s_in = x.new_ones([x.shape[0]])
1083
+ for i in trange(len(sigmas) - 1, disable=disable):
1084
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1085
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1086
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1087
+ gamma = min(wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1088
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1089
+ sigma_hat = sigmas[i] * (gamma + 1)
1090
+ if gamma > 0:
1091
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1092
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1093
+ d = to_d(x, sigma_hat, denoised)
1094
+ dt = sigmas[i + 1] - sigma_hat
1095
+ if callback is not None:
1096
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1097
+ if sigmas[i + 1] > 0:
1098
+ x_2 = x + d * dt
1099
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1100
+ d_prime = d * 0.5 + d_2 * 0.5
1101
+ x = x + d_prime * dt
1102
+ else:
1103
+ # Euler method
1104
+ x = x + d * dt
1105
+ return x
1106
+
1107
+ @torch.no_grad()
1108
+ def sample_euler_h_m_c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1109
+ extra_args = {} if extra_args is None else extra_args
1110
+ s_in = x.new_ones([x.shape[0]])
1111
+ for i in trange(len(sigmas) - 1, disable=disable):
1112
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1113
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1114
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1115
+ gamma = max((2 ** 0.5 - 1) + wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1116
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1117
+ sigma_hat = sigmas[i] * (gamma + 1)
1118
+ if gamma > 0:
1119
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1120
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1121
+ d = to_d(x, sigma_hat, denoised)
1122
+ dt = sigmas[i + 1] - sigma_hat
1123
+ if callback is not None:
1124
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1125
+ if sigmas[i + 1] > 0:
1126
+ x_2 = x + d * dt
1127
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1128
+ d_prime = d * 0.5 + d_2 * 0.5
1129
+ x = x + d_prime * dt
1130
+ else:
1131
+ # Euler method
1132
+ x = x + d * dt
1133
+ return x
1134
+
1135
+ @torch.no_grad()
1136
+ def sample_euler_h_m_d(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1137
+ extra_args = {} if extra_args is None else extra_args
1138
+ s_in = x.new_ones([x.shape[0]])
1139
+ for i in trange(len(sigmas) - 1, disable=disable):
1140
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1141
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1142
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1143
+ gamma = min((2 ** 0.5 - 1) - wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1144
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1145
+ sigma_hat = sigmas[i] * (gamma + 1)
1146
+ if gamma > 0:
1147
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1148
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1149
+ d = to_d(x, sigma_hat, denoised)
1150
+ dt = sigmas[i + 1] - sigma_hat
1151
+ if callback is not None:
1152
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1153
+ if sigmas[i + 1] > 0:
1154
+ x_2 = x + d * dt
1155
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1156
+ d_prime = d * 0.5 + d_2 * 0.5
1157
+ x = x + d_prime * dt
1158
+ else:
1159
+ # Euler method
1160
+ x = x + d * dt
1161
+ return x
1162
+
1163
+ @torch.no_grad()
1164
+ def sample_euler_h_m_e(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1165
+ extra_args = {} if extra_args is None else extra_args
1166
+ s_in = x.new_ones([x.shape[0]])
1167
+ for i in trange(len(sigmas) - 1, disable=disable):
1168
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1169
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1170
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1171
+ gamma = max((2 ** 0.5 - 1) + wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1172
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1173
+ sigma_hat = sigmas[i] * (gamma + 1)
1174
+ if gamma > 0:
1175
+ x = x - eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1176
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1177
+ d = to_d(x, sigma_hat, denoised)
1178
+ dt = sigmas[i + 1] - sigma_hat
1179
+ if callback is not None:
1180
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1181
+ if sigmas[i + 1] > 0:
1182
+ x_2 = x + d * dt
1183
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1184
+ d_prime = d * 0.5 + d_2 * 0.5
1185
+ x = x + d_prime * dt
1186
+ else:
1187
+ # Euler method
1188
+ x = x + d * dt
1189
+ return x
1190
+
1191
+ @torch.no_grad()
1192
+ def sample_euler_h_m_f(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1193
+ extra_args = {} if extra_args is None else extra_args
1194
+ s_in = x.new_ones([x.shape[0]])
1195
+ for i in trange(len(sigmas) - 1, disable=disable):
1196
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1197
+ wave_max = math.cos(0)/1.5 + 1
1198
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1199
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1200
+ gamma = min((wave_max - wave) * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1201
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1202
+ sigma_hat = sigmas[i] * (gamma + 1)
1203
+ if gamma > 0:
1204
+ x = x - eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1205
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1206
+ d = to_d(x, sigma_hat, denoised)
1207
+ dt = sigmas[i + 1] - sigma_hat
1208
+ if callback is not None:
1209
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1210
+ if sigmas[i + 1] > 0:
1211
+ x_2 = x + d * dt
1212
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1213
+ d_prime = d * 0.5 + d_2 * 0.5
1214
+ x = x + d_prime * dt
1215
+ else:
1216
+ # Euler method
1217
+ x = x + d * dt
1218
+ return x
1219
+
1220
+ @torch.no_grad()
1221
+ def sample_euler_h_m_g(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1222
+ extra_args = {} if extra_args is None else extra_args
1223
+ s_in = x.new_ones([x.shape[0]])
1224
+ for i in trange(len(sigmas) - 1, disable=disable):
1225
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1226
+ wave_max = math.cos(0)/1.5 + 1
1227
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1228
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1229
+ gamma = min((wave_max - wave) * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1230
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1231
+ sigma_hat = sigmas[i] * (gamma + 1)
1232
+ if gamma > 0:
1233
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1234
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1235
+ d = to_d(x, sigma_hat, denoised)
1236
+ dt = sigmas[i + 1] - sigma_hat
1237
+ if callback is not None:
1238
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1239
+ if sigmas[i + 1] > 0:
1240
+ x_2 = x + d * dt
1241
+ d_2 = to_d(x_2, sigmas[i + 1] * (gamma + 1), denoised)
1242
+ d_prime = d * 0.5 + d_2 * 0.5
1243
+ x = x + d_prime * dt
1244
+ else:
1245
+ # Euler method
1246
+ x = x + d * dt
1247
+ return x
1248
+
1249
+ @torch.no_grad()
1250
+ def sample_euler_h_m_b_c(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1251
+ extra_args = {} if extra_args is None else extra_args
1252
+ s_in = x.new_ones([x.shape[0]])
1253
+ for i in trange(len(sigmas) - 1, disable=disable):
1254
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1255
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1256
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1257
+ gamma = min(wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1258
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1259
+ gammaup = gamma + 1
1260
+ sigma_hat = sigmas[i] * gammaup
1261
+ if gamma > 0:
1262
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1263
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1264
+ last_noise_uncond = model.last_noise_uncond
1265
+ d = to_d(x, sigma_hat, denoised)
1266
+ dt = sigmas[i + 1] - sigma_hat
1267
+ if callback is not None:
1268
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1269
+ if i == 0:
1270
+ x = x + d * dt
1271
+ elif i <= len(sigmas) - 4:
1272
+ x_2 = x + d * dt
1273
+ d_2 = to_d(x_2, sigmas[i + 1] * gammaup, denoised)
1274
+ x_3 = x_2 + d_2 * dt
1275
+ d_3 = to_d(x_3, sigmas[i + 2] * gammaup, denoised)
1276
+ d_prime = d * 0.5 + d_2 * 0.375 + d_3 * 0.125
1277
+ x = x + d_prime * dt
1278
+ elif sigmas[i + 1] > 0:
1279
+ x_2 = x + d * dt
1280
+ d_2 = to_d(x_2, sigmas[i + 1] * gammaup, denoised)
1281
+ d_prime = d * 0.5 + d_2 * 0.5
1282
+ x = x + d_prime * dt
1283
+ else:
1284
+ # Euler method
1285
+ x = x + d * dt
1286
+ return x
1287
+
1288
+ @torch.no_grad()
1289
+ def sample_euler_h_m_b_c_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
1290
+ extra_args = {} if extra_args is None else extra_args
1291
+ s_in = x.new_ones([x.shape[0]])
1292
+ for i in trange(len(sigmas) - 1, disable=disable):
1293
+ wave = math.cos(math.pi * 0.5 * i)/(0.5 * i + 1.5) + 1
1294
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
1295
+ s_tmin, s_tmax = sigma_min if s_tmin == 0. else s_tmin, sigma_max if s_tmax == float('inf') else s_tmax
1296
+ gamma = min(wave * ((2 ** 0.5 - 1) + s_churn) / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1297
+ eps = k_diffusion.sampling.BrownianTreeNoiseSampler(x, s_tmin, s_tmax, 0) if noise_sampler is None else noise_sampler
1298
+ gammaup = gamma + 1
1299
+ sigma_hat = sigmas[i] * gammaup
1300
+ if gamma > 0:
1301
+ x = x + eps(sigmas[i], sigmas[i + 1]) * s_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1302
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1303
+ last_noise_uncond = model.last_noise_uncond
1304
+ d = to_d(x, sigma_hat, denoised)
1305
+ dt = sigmas[i + 1] - sigma_hat
1306
+ if callback is not None:
1307
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1308
+ if i == 0:
1309
+ x = x + d * dt
1310
+ elif i <= len(sigmas) - 4:
1311
+ x_2 = x + d * dt
1312
+ d_2 = to_d(x_2, sigmas[i + 1] * gammaup, denoised)
1313
+ x_3 = x_2 + d_2 * dt
1314
+ d_3 = to_d(x_3, sigmas[i + 2] * gammaup, last_noise_uncond)
1315
+ d_prime = d * 0.5 + d_2 * 0.375 + d_3 * 0.125
1316
+ x = x + d_prime * dt
1317
+ elif sigmas[i + 1] > 0:
1318
+ x_2 = x + d * dt
1319
+ d_2 = to_d(x_2, sigmas[i + 1] * gammaup, denoised)
1320
+ d_prime = d * 0.5 + d_2 * 0.5
1321
+ x = x + d_prime * dt
1322
+ else:
1323
+ # Euler method
1324
+ x = x + d * dt
1325
+ return x
1326
+
1327
+ @torch.no_grad()
1328
+ def sample_euler_smea_max(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., smooth=False):
1329
+ extra_args = {} if extra_args is None else extra_args
1330
+ s_in = x.new_ones([x.shape[0]])
1331
+ for i in trange(len(sigmas) - 1, disable=disable):
1332
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1333
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
1334
+ sigma_hat = sigmas[i] * (gamma + 1)
1335
+ sa = math.cos(i + 1)/(1.5 * i + 1.75) + 1
1336
+ if gamma > 0:
1337
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1338
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1339
+ d = to_d(x, sigma_hat, denoised)
1340
+ if callback is not None:
1341
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1342
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167 + 1: # and i % 2 == 0:
1343
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1344
+ dt_1 = sigma_mid - sigma_hat
1345
+ dt_2 = sigmas[i + 1] - sigma_hat
1346
+ x_2 = x + d * dt_1
1347
+ sigA = sigma_mid / (sa ** 2)
1348
+ sigB = sigma_mid
1349
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
1350
+ denoised_2b = model(x_2, sigma_mid * s_in, **extra_args)
1351
+ denoised_2 = (denoised_2a * 0.5 * (sa ** 2) + denoised_2b * 0.5 / (sa ** 2))
1352
+ d_2 = to_d(x_2, sigA * 0.5 * (sa ** 2) + sigB * 0.5 / (sa ** 2), denoised_2)
1353
+ x = x + d_2 * dt_2
1354
+ else:
1355
+ dt = sigmas[i + 1] - sigma_hat
1356
+ # Euler method
1357
+ x = x + sa * d * dt
1358
+ return x
1359
+
1360
+
1361
+ @torch.no_grad()
1362
+ def sample_euler_smea_max_s(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1363
+ sample = sample_euler_smea_max(model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise, smooth=True)
1364
+ return sample
1365
+
1366
+ @torch.no_grad()
1367
+ def sample_euler_smea_multi_bs(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1368
+ extra_args = {} if extra_args is None else extra_args
1369
+ s_in = x.new_ones([x.shape[0]])
1370
+ for i in trange(len(sigmas) - 1, disable=disable):
1371
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1372
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
1373
+ sigma_hat = sigmas[i] * (gamma + 1)
1374
+ if gamma > 0:
1375
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1376
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1377
+ d = to_d(x, sigma_hat, denoised)
1378
+ if callback is not None:
1379
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1380
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
1381
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1382
+ dt_1 = sigma_mid - sigma_hat
1383
+ dt_2 = sigmas[i + 1] - sigma_hat
1384
+ x_2 = x + d * dt_1
1385
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
1386
+ sa = 1 - scale * 0.25
1387
+ sb = 1 + scale * 0.15
1388
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigma_mid, sa, **extra_args)
1389
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigma_mid, sb, **extra_args)
1390
+ denoised_2 = denoised_2a * (sa ** 2) * 0.625 + denoised_2b * (sb ** 2) * 0.375 / (0.95**2)
1391
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
1392
+ x = x + d_2 * dt_2
1393
+ else:
1394
+ dt = sigmas[i + 1] - sigma_hat
1395
+ # Euler method
1396
+ x = x + d * dt
1397
+ return x
1398
+
1399
+ @torch.no_grad()
1400
+ def sample_euler_smea_multi_bs2_s(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1401
+ sample = sample_euler_smea_multi_bs2(model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise, smooth=True)
1402
+ return sample
1403
+
1404
+ @torch.no_grad()
1405
+ def sample_euler_smea_multi_bs2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., smooth=False):
1406
+ extra_args = {} if extra_args is None else extra_args
1407
+ s_in = x.new_ones([x.shape[0]])
1408
+ for i in trange(len(sigmas) - 1, disable=disable):
1409
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1410
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
1411
+ sigma_hat = sigmas[i] * (gamma + 1)
1412
+ if gamma > 0:
1413
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1414
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1415
+ d = to_d(x, sigma_hat, denoised)
1416
+ if callback is not None:
1417
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1418
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
1419
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1420
+ dt_1 = sigma_mid - sigma_hat
1421
+ dt_2 = sigmas[i + 1] - sigma_hat
1422
+ x_2 = x + d * dt_1
1423
+ scale = (sigmas[i] / sigmas[0]) ** 2
1424
+ scale = scale.item()
1425
+ sa = 1 - scale * 0.25
1426
+ sb = 1 + scale * 0.15
1427
+ sigA = sigma_mid / (sa ** 2)
1428
+ sigB = sigma_mid / (sb ** 2)
1429
+ denoised_2a = smea_sampling_step_denoised(x_2, model, sigA, sa, smooth, **extra_args)
1430
+ denoised_2b = smea_sampling_step_denoised(x_2, model, sigB, sb, smooth, **extra_args)
1431
+ denoised_2 = (denoised_2a * (sa ** 2) * 0.5 * sb ** 2 + denoised_2b * (sb ** 2) * 0.5 * sa ** 2)
1432
+ d_2 = to_d(x_2, sigA * 0.5 * sb ** 2 + sigB * 0.5 * sa ** 2, denoised_2)
1433
+ x = x + d_2 * dt_2
1434
+ else:
1435
+ dt = sigmas[i + 1] - sigma_hat
1436
+ # Euler method
1437
+ x = x + d * dt
1438
+ return x
1439
+
1440
+ @torch.no_grad()
1441
+ def sample_euler_smea_multi_cs(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1442
+ extra_args = {} if extra_args is None else extra_args
1443
+ s_in = x.new_ones([x.shape[0]])
1444
+ for i in trange(len(sigmas) - 1, disable=disable):
1445
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1446
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
1447
+ sigma_hat = sigmas[i] * (gamma + 1)
1448
+ if gamma > 0:
1449
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1450
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1451
+ d = to_d(x, sigma_hat, denoised)
1452
+ if callback is not None:
1453
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1454
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
1455
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1456
+ dt_1 = sigma_mid - sigma_hat
1457
+ dt_2 = sigmas[i + 1] - sigma_hat
1458
+ x_2 = x + d * dt_1
1459
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
1460
+ sa = 1 - scale * 0.25
1461
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, sa, **extra_args)
1462
+ d_2 = to_d(x_2, sigma_mid, denoised_2 * (sa ** 2) * 1.25)
1463
+ x = x + d_2 * dt_2
1464
+ else:
1465
+ dt = sigmas[i + 1] - sigma_hat
1466
+ # Euler method
1467
+ x = x + d * dt
1468
+ return x
1469
+
1470
+ @torch.no_grad()
1471
+ def sample_euler_smea_multi_as(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1472
+ extra_args = {} if extra_args is None else extra_args
1473
+ s_in = x.new_ones([x.shape[0]])
1474
+ for i in trange(len(sigmas) - 1, disable=disable):
1475
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1476
+ eps = k_diffusion.sampling.torch.randn_like(x) * s_noise
1477
+ sigma_hat = sigmas[i] * (gamma + 1)
1478
+ if gamma > 0:
1479
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1480
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1481
+ d = to_d(x, sigma_hat, denoised)
1482
+ if callback is not None:
1483
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1484
+ if sigmas[i + 1] > 0 and i < len(sigmas) * 0.167:
1485
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
1486
+ dt_1 = sigma_mid - sigma_hat
1487
+ dt_2 = sigmas[i + 1] - sigma_hat
1488
+ x_2 = x + d * dt_1
1489
+ scale = ((len(sigmas) - i) / len(sigmas)) ** 2
1490
+ sa = 1 + scale * 0.15
1491
+ denoised_2 = smea_sampling_step_denoised(x_2, model, sigma_mid, sa, **extra_args)
1492
+ d_2 = to_d(x_2, sigma_mid, denoised_2 * (sa ** 2) * 0.75)
1493
+ x = x + d_2 * dt_2
1494
+ else:
1495
+ dt = sigmas[i + 1] - sigma_hat
1496
+ # Euler method
1497
+ x = x + d * dt
1498
+ return x
1499
+
1500
+ ## og sampler
1501
+ @torch.no_grad()
1502
+ def sample_euler_dy_og(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1503
+ extra_args = {} if extra_args is None else extra_args
1504
+ s_in = x.new_ones([x.shape[0]])
1505
+ for i in trange(len(sigmas) - 1, disable=disable):
1506
+ # print(i)
1507
+ # i绗竴姝ヤ负0
1508
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1509
+ eps = torch.randn_like(x) * s_noise
1510
+ sigma_hat = sigmas[i] * (gamma + 1)
1511
+ # print(sigma_hat)
1512
+ dt = sigmas[i + 1] - sigma_hat
1513
+ if gamma > 0:
1514
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1515
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1516
+ d = sampling.to_d(x, sigma_hat, denoised)
1517
+ if sigmas[i + 1] > 0:
1518
+ if i // 2 == 1:
1519
+ x = dy_sampling_step(x, model, dt, sigma_hat, **extra_args)
1520
+ if callback is not None:
1521
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1522
+ # Euler method
1523
+ x = x + d * dt
1524
+ return x
1525
+
1526
+ @torch.no_grad()
1527
+ def sample_euler_smea_dy_og(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
1528
+ extra_args = {} if extra_args is None else extra_args
1529
+ s_in = x.new_ones([x.shape[0]])
1530
+ for i in trange(len(sigmas) - 1, disable=disable):
1531
+ gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
1532
+ eps = torch.randn_like(x) * s_noise
1533
+ sigma_hat = sigmas[i] * (gamma + 1)
1534
+ dt = sigmas[i + 1] - sigma_hat
1535
+ if gamma > 0:
1536
+ x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
1537
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1538
+ d = sampling.to_d(x, sigma_hat, denoised)
1539
+ # Euler method
1540
+ x = x + d * dt
1541
+ if sigmas[i + 1] > 0:
1542
+ if i + 1 // 2 == 1:
1543
+ x = dy_sampling_step(x, model, dt, sigma_hat, **extra_args)
1544
+ if i + 1 // 2 == 0:
1545
+ x = smea_sampling_step(x, model, dt, sigma_hat, **extra_args)
1546
+ if callback is not None:
1547
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1548
+ return x
1549
+
1550
+ ## TCD
1551
+
1552
+ def sample_tcd_euler_a(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, gamma=0.3):
1553
+ # TCD sampling using modified Euler Ancestral sampler. by @laksjdjf
1554
+ extra_args = {} if extra_args is None else extra_args
1555
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
1556
+ s_in = x.new_ones([x.shape[0]])
1557
+ for i in trange(len(sigmas) - 1, disable=disable):
1558
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
1559
+ if callback is not None:
1560
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
1561
+
1562
+ #d = to_d(x, sigmas[i], denoised)
1563
+ sigma_from = sigmas[i]
1564
+ sigma_to = sigmas[i + 1]
1565
+
1566
+ t = model.inner_model.sigma_to_t(sigma_from)
1567
+ down_t = (1 - gamma) * t
1568
+ sigma_down = model.inner_model.t_to_sigma(down_t)
1569
+
1570
+ if sigma_down > sigma_to:
1571
+ sigma_down = sigma_to
1572
+ sigma_up = (sigma_to ** 2 - sigma_down ** 2) ** 0.5
1573
+
1574
+ # same as euler ancestral
1575
+ d = to_d(x, sigma_from, denoised)
1576
+ dt = sigma_down - sigma_from
1577
+ x += d * dt
1578
+
1579
+ if sigma_to > 0 and gamma > 0:
1580
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigma_up
1581
+ return x
1582
+
1583
+ @torch.no_grad()
1584
+ def sample_tcd(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, gamma=0.3):
1585
+ # TCD sampling using modified DDPM.
1586
+ extra_args = {} if extra_args is None else extra_args
1587
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
1588
+ s_in = x.new_ones([x.shape[0]])
1589
+
1590
+ for i in trange(len(sigmas) - 1, disable=disable):
1591
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
1592
+ if callback is not None:
1593
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
1594
+
1595
+ sigma_from, sigma_to = sigmas[i], sigmas[i+1]
1596
+
1597
+ # TCD offset, based on gamma, and conversion between sigma and timestep
1598
+ t = model.inner_model.sigma_to_t(sigma_from)
1599
+ t_s = (1 - gamma) * t
1600
+ sigma_to_s = model.inner_model.t_to_sigma(t_s)
1601
+
1602
+ # if sigma_to_s > sigma_to:
1603
+ # sigma_to_s = sigma_to
1604
+ # if sigma_to_s < 0:
1605
+ # sigma_to_s = torch.tensor(1.0)
1606
+ #print(f"sigma_from: {sigma_from}, sigma_to: {sigma_to}, sigma_to_s: {sigma_to_s}")
1607
+
1608
+
1609
+ # The following is equivalent to the comfy DDPM implementation
1610
+ # x = DDPMSampler_step(x / torch.sqrt(1.0 + sigma_from ** 2.0), sigma_from, sigma_to, (x - denoised) / sigma_from, noise_sampler)
1611
+
1612
+ noise_est = (x - denoised) / sigma_from
1613
+ x /= torch.sqrt(1.0 + sigma_from ** 2.0)
1614
+
1615
+ alpha_cumprod = 1 / ((sigma_from * sigma_from) + 1) # _t
1616
+ alpha_cumprod_prev = 1 / ((sigma_to * sigma_to) + 1) # _t_prev
1617
+ alpha = (alpha_cumprod / alpha_cumprod_prev)
1618
+
1619
+ ## These values should approach 1.0?
1620
+ # print(f"alpha_cumprod: {alpha_cumprod}")
1621
+ # print(f"alpha_cumprod_prev: {alpha_cumprod_prev}")
1622
+ # print(f"alpha: {alpha}")
1623
+
1624
+
1625
+ # alpha_cumprod_down = 1 / ((sigma_to_s * sigma_to_s) + 1) # _s
1626
+ # alpha_d = (alpha_cumprod_prev / alpha_cumprod_down)
1627
+ # alpha2 = (alpha_cumprod / alpha_cumprod_down)
1628
+ # print(f"** alpha_cumprod_down: {alpha_cumprod_down}")
1629
+ # print(f"** alpha_d: {alpha_d}, alpha2: #{alpha2}")
1630
+
1631
+ # epsilon noise prediction from comfy DDPM implementation
1632
+ x = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise_est / (1 - alpha_cumprod).sqrt())
1633
+ # x = (1.0 / alpha_d).sqrt() * (x - (1 - alpha) * noise_est / (1 - alpha_cumprod).sqrt())
1634
+
1635
+ first_step = sigma_to == 0
1636
+ last_step = i == len(sigmas) - 2
1637
+
1638
+ if not first_step:
1639
+ if gamma > 0 and not last_step:
1640
+ noise = noise_sampler(sigma_from, sigma_to)
1641
+
1642
+ # x += ((1 - alpha_d) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise
1643
+ variance = ((1 - alpha_cumprod_prev) / (1 - alpha_cumprod)) * (1 - alpha_cumprod / alpha_cumprod_prev)
1644
+ x += variance.sqrt() * noise # scale noise by std deviation
1645
+
1646
+ # relevant diffusers code from scheduling_tcd.py
1647
+ # prev_sample = (alpha_prod_t_prev / alpha_prod_s).sqrt() * pred_noised_sample + (
1648
+ # 1 - alpha_prod_t_prev / alpha_prod_s
1649
+ # ).sqrt() * noise
1650
+
1651
+ x *= torch.sqrt(1.0 + sigma_to ** 2.0)
1652
+
1653
+ # beta_cumprod_t = 1 - alpha_cumprod
1654
+ # beta_cumprod_s = 1 - alpha_cumprod_down
1655
+
1656
+
1657
+ return x