File size: 14,558 Bytes
c336648
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
import gradio as gr

from modules import scripts
import modules.shared as shared
import torch, math
import torchvision.transforms.functional as TF

# The effect seems better when applied to the denoised result after CFG, rather than to cond/uncond before CFG.

class driftrForge(scripts.Script):
    def __init__(self):
        """Start with both correction methods switched off.

        The remaining settings are stored on the instance by ``process``.
        """
        self.method1 = self.method2 = "None"

    def title(self):
        """Return the display name of this script (used as the accordion label in ``ui``)."""
        name = "Latent Drift Correction"
        return name

    def show(self, is_img2img):
        """Return ``scripts.AlwaysVisible`` regardless of ``is_img2img``.

        The same value is returned for both tabs, so the extension is shown
        in txt2img and img2img alike.
        """
        visibility = scripts.AlwaysVisible
        return visibility

    def ui(self, *args, **kwargs):
        """Build the accordion of controls and register them for infotext pasting.

        Returns the twelve components in the order in which ``process``
        unpacks its ``script_args``:
        method1, method2, topK, blur, strengthC, strengthO, stepS, stepE,
        sigmaWeight, softClampS, softClampE, custom.
        """
        # Per-channel methods whose expression reads the 'quantiles' slider.
        quantileMethods = ("centered mean", "average of extremes", "average of quantiles")

        with gr.Accordion(open=False, label=self.title()):
            with gr.Row():
                method1 = gr.Dropdown(["None", "custom", "mean", "median", "mean/median average", "centered mean", "average of extremes", "average of quantiles"], value="None", type="value", label='Correction method (per channel)')
                method2 = gr.Dropdown(["None", "mean", "median", "mean/median average", "center to quantile", "local average"], value="None", type="value", label='Correction method (overall)')
            with gr.Row():
                strengthC = gr.Slider(minimum=-1.0, maximum=1.0, step=0.01, value=1.0, label='strength (per channel)')
                strengthO = gr.Slider(minimum=-1.0, maximum=1.0, step=0.01, value=0.8, label='strength (overall)')
            # Gradio spells this keyword 'equal_height'; the previous
            # 'equalHeight' is not a Row parameter.
            with gr.Row(equal_height=True):
                custom = gr.Textbox(value='0.5 * (M + m)', max_lines=1, label='custom function', visible=True)
                topK = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, value=0.5, label='quantiles', visible=False, scale=0)
                blur = gr.Slider(minimum=0, maximum=128, step=1, value=0, label='blur radius (x8)', visible=False, scale=0)
                sigmaWeight = gr.Dropdown(["Hard", "Soft", "None"], value="Hard", type="value", label='Limit effect by sigma', scale=0)
            with gr.Row():
                stepS = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.0, label='Start step')
                stepE = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label='End step')
            with gr.Row():
                softClampS = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label='Soft clamp start step')
                softClampE = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label='Soft clamp end step')

            def show_topK(m1, m2):
                """Return visibility updates for the (quantiles, blur) sliders.

                Each slider is decided on its own: a quantile-based per-channel
                method combined with the overall "local average" method uses
                both sliders, so both must be shown.
                """
                needsQuantile = m1 in quantileMethods or m2 == "center to quantile"
                needsBlur = m2 == "local average"
                return gr.update(visible=needsQuantile), gr.update(visible=needsBlur)

            # Either dropdown can change which of the two sliders is relevant.
            for dropdown in (method1, method2):
                dropdown.change(
                    fn=show_topK,
                    inputs=[method1, method2],
                    outputs=[topK, blur],
                    show_progress=False
                )

        # Pairs of (component, infotext key); the keys match those written
        # into extra_generation_params by process().
        self.infotext_fields = [
            (method1,       "ldc_method1"),
            (method2,       "ldc_method2"),
            (topK,          "ldc_topK"),
            (blur,          "ldc_blur"),
            (strengthC,     "ldc_strengthC"),
            (strengthO,     "ldc_strengthO"),
            (stepS,         "ldc_stepS"),
            (stepE,         "ldc_stepE"),
            (sigmaWeight,   "ldc_sigW"),
            (softClampS,    "ldc_softClampS"),
            (softClampE,    "ldc_softClampE"),
            (custom,        "ldc_custom"),
        ]

        return method1, method2, topK, blur, strengthC, strengthO, stepS, stepE, sigmaWeight, softClampS, softClampE, custom


    def patch(self, model):
        model_sampling = model.model.model_sampling
        sigmin = model_sampling.sigma(model_sampling.timestep(model_sampling.sigma_min))
        sigmax = model_sampling.sigma(model_sampling.timestep(model_sampling.sigma_max))


##  https://huggingface.co/blog/TimothyAlexisVass/explaining-the-sdxl-latent-space
        def soft_clamp_tensor(input_tensor, threshold=3.5, boundary=4):
            if max(abs(input_tensor.max()), abs(input_tensor.min())) < 4:
                return input_tensor
            channel_dim = 1

            max_vals = input_tensor.max(channel_dim, keepdim=True)[0]
            max_replace = ((input_tensor - threshold) / (max_vals - threshold)) * (boundary - threshold) + threshold
            over_mask = (input_tensor > threshold)

            min_vals = input_tensor.min(channel_dim, keepdim=True)[0]
            min_replace = ((input_tensor + threshold) / (min_vals + threshold)) * (-boundary + threshold) - threshold
            under_mask = (input_tensor < -threshold)

            return torch.where(over_mask, max_replace, torch.where(under_mask, min_replace, input_tensor))

        def center_latent_mean_values(latent, multiplier):
            thisStep = shared.state.sampling_step
            lastStep = shared.state.sampling_steps

            channelMultiplier = multiplier * self.strengthC
            fullMultiplier = multiplier * self.strengthO
          
            if thisStep >= self.stepS * lastStep and thisStep <= self.stepE * lastStep:
                for b in range(len(latent)):
                    for c in range(4):
                        custom = None
                        channel = latent[b][c]

                        if self.method1 == "mean":
                            custom = "M"
                            #averageMid = torch.mean(channel)
                            #latent[b][c] -= averageMid  * channelMultiplier

                        elif self.method1 == "median":
                            custom = "m"
                            #averageMid = torch.quantile(channel, 0.5)
                            #latent[b][c] -= averageMid  * channelMultiplier

                        elif self.method1 == "mean/median average":
                            custom = "0.5 * (M+m)"
                            #averageMid = 0.5 * (torch.mean(channel) + torch.quantile(channel, 0.5))
                            #latent[b][c] -= averageMid  * channelMultiplier
                            

                        elif self.method1 == "centered mean":
                            custom="rM(self.topK, 1.0-self.topK)"
##                            valuesHi = torch.topk(channel, int(len(channel)*self.topK), largest=True).values
##                            valuesLo = torch.topk(channel, int(len(channel)*self.topK), largest=False).values
##                            averageMid = torch.mean(channel).item() * len(channel)
##                            averageMid -= torch.mean(valuesHi).item() * len(channel)*self.topK
##                            averageMid -= torch.mean(valuesLo).item() * len(channel)*self.topK
##                            averageMid /= len(channel)*(1.0 - 2*self.topK)
##                            latent[b][c] -= averageMid  * channelMultiplier

                        elif self.method1 == "average of extremes":
                            custom="0.5 * (inner_rL(self.topK) + inner_rH(self.topK))"
##                            valuesHi = torch.topk(channel, int(len(channel)*self.topK), largest=True).values
##                            valuesLo = torch.topk(channel, int(len(channel)*self.topK), largest=False).values
##                            averageMid = 0.5 * (torch.mean(valuesHi).item() + torch.mean(valuesLo).item())
##                            latent[b][c] -= averageMid  * channelMultiplier

                        elif self.method1 == "average of quantiles":
                            custom="0.5 * (q(self.topK) + q(1.0-self.topK))"
##                            averageMid = 0.5 * (torch.quantile(channel, self.topK) + torch.quantile(channel, 1.0 - self.topK))
##                            latent[b][c] -= averageMid  * channelMultiplier

                        elif self.method1 == "custom":
                            custom = self.custom

                        if custom != None:
                            M = torch.mean(channel)
                            m = torch.quantile(channel, 0.5)
                            def q(quant):
                                return torch.quantile(channel, quant)
                            def qa(quant):
                                return torch.quantile(abs(channel), quant)
                            def inner_rL(lo):   #   mean of values from lowest to input(proportional)
                                valuesLo = torch.topk(channel, int(len(channel)*lo), largest=False).values
                                return torch.mean(valuesLo).item()
                            def inner_rH(hi):   #   mean of values from input(proportional) to highest
                                valuesHi = torch.topk(channel, int(len(channel)*hi), largest=True).values
                                return torch.mean(valuesHi).item()
                                
                            def rM(rangelo, rangehi):       #   mean of range
                                if rangelo == rangehi:
                                    return M
                                else:
                                    averageHi = inner_rH(1.0-rangehi)
                                    averageLo = inner_rL(rangelo)

                                    average = torch.mean(channel).item() * len(channel)
                                    average -= averageLo * len(channel) * rangelo
                                    average -= averageHi * len(channel) * (1.0-rangehi)
                                    average /= len(channel)*(rangehi - rangelo)
                                    return average

                            averageMid = eval(custom)
                            latent[b][c] -= averageMid  * channelMultiplier
                            
                    if self.method2 == "mean":
                        latent[b] -= latent[b].mean() * fullMultiplier
                    elif self.method2 == "median":
                        latent[b] -= latent[b].median() * fullMultiplier
                    elif self.method2 == "mean/median average":
                        mm = latent[b].mean() + latent[b].median()
                        latent[b] -= 0.5 * fullMultiplier * mm
                    elif self.method2 == "center to quantile":
                       quantile = torch.quantile(latent[b].flatten(), self.topK) #   0.5 is same as median
                       latent[b] -= quantile * fullMultiplier
                    elif self.method2 == "local average" and fullMultiplier != 0.0 and self.blur != 0:
                        minDim = min(latent.size(2), latent.size(3))
                        if minDim % 2 == 0:    #     blur kernel size must be odd
                            minDim -= 1
                        blurSize = min (minDim, 1+self.blur+self.blur)
                        
                        blurred = TF.gaussian_blur(latent[b], blurSize)
                        torch.lerp(latent[b], blurred, fullMultiplier, out=latent[b])
                        del blurred


            if thisStep >= self.softClampS * lastStep and thisStep <= self.softClampE * lastStep:
                for b in range(len(latent)):
                    latent[b] = soft_clamp_tensor (latent[b])


            return latent


        def map_sigma(sigma, sigmax, sigmin):
            return (sigma - sigmin) / (sigmax - sigmin)

        def center_mean_latent_post_cfg(args):
            denoised = args["denoised"]
            sigma    = args["sigma"][0]

            if self.sigmaWeight == "None":                  #   range 1 - always full correction
                mult = 1
            else:
                mult = map_sigma(sigma, sigmax, sigmin)     #   range 0.0 to 1.0
                if self.sigmaWeight == "Soft":              #   range 0.5 to 1.0
                    mult += 1.0
                    mult /= 2.0

            denoised = center_latent_mean_values(denoised, mult)        
            return denoised

        m = model.clone()
        m.set_model_sampler_post_cfg_function(center_mean_latent_post_cfg)

        return (m, )


    def process(self, params, *script_args, **kwargs):
        """Store the UI values on the instance and record them in the infotext.

        ``script_args`` must hold exactly the twelve values returned by
        ``ui``, in the same order. When both methods are "None" nothing is
        stored or recorded.
        """
        (method1, method2, topK, blur, strengthC, strengthO,
         stepS, stepE, sigmaWeight, softClampS, softClampE, custom) = script_args

        if (method1, method2) == ("None", "None"):
            return

        # Keep the settings on the instance so the sampler hook can read them.
        settings = {
            "method1": method1,
            "method2": method2,
            "topK": topK,
            "blur": blur,
            "strengthC": strengthC,
            "strengthO": strengthO,
            "stepS": stepS,
            "stepE": stepE,
            "sigmaWeight": sigmaWeight,
            "softClampS": softClampS,
            "softClampE": softClampE,
            "custom": custom,
        }
        for attribute, value in settings.items():
            setattr(self, attribute, value)

        # Entries added here appear in the text below the image outputs on
        # the UI; extra_generation_params does not influence results.
        recorded = {
            "ldc_method1": method1,
            "ldc_method2": method2,
            "ldc_strengthC": strengthC,
            "ldc_strengthO": strengthO,
            "ldc_stepS": stepS,
            "ldc_stepE": stepE,
            "ldc_sigW": sigmaWeight,
            "ldc_softClampS": softClampS,
            "ldc_softClampE": softClampE,
        }

        # Optional entries are only logged when the chosen methods use them.
        if method1 == "custom":
            recorded["ldc_custom"] = custom
        usesQuantile = method1 in ("centered mean", "average of extremes", "average of quantiles")
        if usesQuantile or method2 == "center to quantile":
            recorded["ldc_topK"] = topK
        if method2 == "local average":
            recorded["ldc_blur"] = blur

        params.extra_generation_params.update(recorded)
        return


    def process_before_every_sampling(self, params, *script_args, **kwargs):
        """Replace the UNet with a drift-correcting clone when a method is selected.

        Only the first two script arguments (the per-channel and overall
        method names) are inspected; when both are "None" the UNet is left
        untouched.
        """
        perChannel = script_args[0]
        overall = script_args[1]
        if perChannel == "None" and overall == "None":
            return

        forgeObjects = params.sd_model.forge_objects
        patched = driftrForge.patch(self, forgeObjects.unet)[0]
        forgeObjects.unet = patched

        return