# Uploaded by dikdimon using the SD-Hub extension (commit c336648, verified).
import gradio as gr
from modules import scripts
import modules.shared as shared
import torch, math
import torchvision.transforms.functional as TF
# The effect seems better when applied to the denoised result after CFG, rather than to cond/uncond before CFG.
class driftrForge(scripts.Script):
    """Latent Drift Correction: recentre the denoised latent after CFG.

    The correction is installed as a post-CFG sampler function on a clone of the
    UNet object (see ``patch``). ``process`` copies the UI settings onto the
    instance, and the installed hook reads them from ``self`` when it runs.
    """

    # Per-channel expressions used by the preset (non-"custom") methods. They are
    # evaluated with the names built in ``patch`` -> ``channel_namespace``
    # (M, m, q, qa, inner_rL, inner_rH, rM, channel, self) plus module globals.
    _METHOD1_EXPRESSIONS = {
        "mean": "M",
        "median": "m",
        "mean/median average": "0.5 * (M+m)",
        "centered mean": "rM(self.topK, 1.0-self.topK)",
        "average of extremes": "0.5 * (inner_rL(self.topK) + inner_rH(self.topK))",
        "average of quantiles": "0.5 * (q(self.topK) + q(1.0-self.topK))",
    }

    def __init__(self):
        # Defaults so these attributes exist before ``process`` has run.
        self.method1 = "None"
        self.method2 = "None"

    def title(self):
        return "Latent Drift Correction"

    def show(self, is_img2img):
        # make this extension visible in both txt2img and img2img tab.
        return scripts.AlwaysVisible

    def ui(self, *args, **kwargs):
        """Build the accordion UI and return its components in ``script_args`` order."""
        with gr.Accordion(open=False, label=self.title()):
            with gr.Row():
                method1 = gr.Dropdown(["None", "custom", "mean", "median", "mean/median average", "centered mean", "average of extremes", "average of quantiles"], value="None", type="value", label='Correction method (per channel)')
                method2 = gr.Dropdown(["None", "mean", "median", "mean/median average", "center to quantile", "local average"], value="None", type="value", label='Correction method (overall)')
            with gr.Row():
                strengthC = gr.Slider(minimum=-1.0, maximum=1.0, step=0.01, value=1.0, label='strength (per channel)')
                strengthO = gr.Slider(minimum=-1.0, maximum=1.0, step=0.01, value=0.8, label='strength (overall)')
            # Gradio's keyword is `equal_height` (the previous `equalHeight` is not a Row parameter).
            with gr.Row(equal_height=True):
                custom = gr.Textbox(value='0.5 * (M + m)', max_lines=1, label='custom function', visible=True)
                topK = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, value=0.5, label='quantiles', visible=False, scale=0)
                blur = gr.Slider(minimum=0, maximum=128, step=1, value=0, label='blur radius (x8)', visible=False, scale=0)
                sigmaWeight = gr.Dropdown(["Hard", "Soft", "None"], value="Hard", type="value", label='Limit effect by sigma', scale=0)
            with gr.Row():
                stepS = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.0, label='Start step')
                stepE = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label='End step')
            with gr.Row():
                softClampS = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label='Soft clamp start step')
                softClampE = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label='Soft clamp end step')

            def show_topK(m1, m2):
                """Toggle visibility of (topK, blur) for the methods that use them."""
                if m1 in ("centered mean", "average of extremes", "average of quantiles"):
                    return gr.update(visible=True), gr.update(visible=False)
                if m2 == "center to quantile":
                    return gr.update(visible=True), gr.update(visible=False)
                if m2 == "local average":
                    return gr.update(visible=False), gr.update(visible=True)
                return gr.update(visible=False), gr.update(visible=False)

            for dropdown in (method1, method2):
                dropdown.change(
                    fn=show_topK,
                    inputs=[method1, method2],
                    outputs=[topK, blur],
                    show_progress=False
                )

        self.infotext_fields = [
            (method1, "ldc_method1"),
            (method2, "ldc_method2"),
            (topK, "ldc_topK"),
            (blur, "ldc_blur"),
            (strengthC, "ldc_strengthC"),
            (strengthO, "ldc_strengthO"),
            (stepS, "ldc_stepS"),
            (stepE, "ldc_stepE"),
            (sigmaWeight, "ldc_sigW"),
            (softClampS, "ldc_softClampS"),
            (softClampE, "ldc_softClampE"),
            (custom, "ldc_custom"),
        ]
        return method1, method2, topK, blur, strengthC, strengthO, stepS, stepE, sigmaWeight, softClampS, softClampE, custom

    def patch(self, model):
        """Return ``(clone,)``: a clone of *model* with the drift-correction post-CFG hook set.

        The hook reads its settings from ``self`` at call time, so ``process``
        must have stored them before sampling starts.
        """
        model_sampling = model.model.model_sampling
        sigmin = model_sampling.sigma(model_sampling.timestep(model_sampling.sigma_min))
        sigmax = model_sampling.sigma(model_sampling.timestep(model_sampling.sigma_max))

        ## https://huggingface.co/blog/TimothyAlexisVass/explaining-the-sdxl-latent-space
        def soft_clamp_tensor(input_tensor, threshold=3.5, boundary=4):
            """Compress values beyond +/-threshold so they end up within +/-boundary.

            Written for a (batch, channel, H, W) tensor: the max/min used for the
            rescaling is taken over dim 1 (channels).
            """
            # Nothing reaches the boundary: leave the tensor untouched.
            if max(abs(input_tensor.max()), abs(input_tensor.min())) < boundary:
                return input_tensor
            channel_dim = 1
            max_vals = input_tensor.max(channel_dim, keepdim=True)[0]
            max_replace = ((input_tensor - threshold) / (max_vals - threshold)) * (boundary - threshold) + threshold
            over_mask = input_tensor > threshold
            min_vals = input_tensor.min(channel_dim, keepdim=True)[0]
            min_replace = ((input_tensor + threshold) / (min_vals + threshold)) * (-boundary + threshold) - threshold
            under_mask = input_tensor < -threshold
            return torch.where(over_mask, max_replace, torch.where(under_mask, min_replace, input_tensor))

        def channel_namespace(channel):
            """Return the names available to per-channel expressions for one 2D channel."""
            # topk must operate on ALL values of the channel; on the 2D tensor it
            # would select per row (last dim) and k derived from len() == H could
            # exceed W on non-square latents.
            flat = channel.flatten()
            count = flat.numel()
            M = torch.mean(channel)
            m = torch.quantile(channel, 0.5)

            def q(quant):  # quantile of the channel values
                return torch.quantile(channel, quant)

            def qa(quant):  # quantile of the absolute channel values
                return torch.quantile(abs(channel), quant)

            def topk_count(fraction):
                # Number of values covering `fraction` of the channel, kept in
                # [1, count]: k=0 would make the mean NaN, k>count is an error.
                return min(count, max(1, int(count * fraction)))

            def inner_rL(lo):  # mean of the lowest `lo` fraction of values
                return torch.topk(flat, topk_count(lo), largest=False).values.mean().item()

            def inner_rH(hi):  # mean of the highest `hi` fraction of values
                return torch.topk(flat, topk_count(hi), largest=True).values.mean().item()

            def rM(rangelo, rangehi):  # mean of the values ranked between rangelo and rangehi
                if rangelo == rangehi:
                    return M
                averageHi = inner_rH(1.0 - rangehi)
                averageLo = inner_rL(rangelo)
                # mean = lo*avgLo + (hi-lo)*avgMid + (1-hi)*avgHi, solved for avgMid.
                average = M.item() - averageLo * rangelo - averageHi * (1.0 - rangehi)
                return average / (rangehi - rangelo)

            return {
                "self": self,
                "channel": channel,
                "M": M,
                "m": m,
                "q": q,
                "qa": qa,
                "inner_rL": inner_rL,
                "inner_rH": inner_rH,
                "rM": rM,
            }

        def center_latent_mean_values(latent, multiplier):
            """Apply the configured corrections to *latent* in place and return it.

            latent: 4D tensor indexed as latent[batch][channel] with spatial
                dims 2 and 3 (see ``latent.size(2)``/``latent.size(3)`` below).
            multiplier: sigma-dependent factor that scales both strengths.
            """
            thisStep = shared.state.sampling_step
            lastStep = shared.state.sampling_steps
            channelMultiplier = multiplier * self.strengthC
            fullMultiplier = multiplier * self.strengthO

            if self.stepS * lastStep <= thisStep <= self.stepE * lastStep:
                if self.method1 == "custom":
                    # A blank custom field means "no per-channel correction"
                    # instead of a SyntaxError in the middle of sampling.
                    expression = (self.custom or "").strip() or None
                else:
                    expression = self._METHOD1_EXPRESSIONS.get(self.method1)
                # SECURITY NOTE: the expression is evaluated as arbitrary Python.
                # The custom text comes from the UI textbox, which can also be
                # filled from pasted infotext (ldc_custom) - only run trusted input.
                # Compiled once here instead of re-parsed per (batch, channel).
                compiled = None if expression is None else compile(expression, "<ldc expression>", "eval")

                for b in range(latent.shape[0]):
                    if compiled is not None:
                        # All channels, not just the first 4.
                        for c in range(latent.shape[1]):
                            channel = latent[b, c]
                            averageMid = eval(compiled, globals(), channel_namespace(channel))
                            latent[b, c] -= averageMid * channelMultiplier

                    if self.method2 == "mean":
                        latent[b] -= latent[b].mean() * fullMultiplier
                    elif self.method2 == "median":
                        latent[b] -= latent[b].median() * fullMultiplier
                    elif self.method2 == "mean/median average":
                        mm = latent[b].mean() + latent[b].median()
                        latent[b] -= 0.5 * fullMultiplier * mm
                    elif self.method2 == "center to quantile":
                        quantile = torch.quantile(latent[b].flatten(), self.topK)  # 0.5 ~ median (interpolated)
                        latent[b] -= quantile * fullMultiplier
                    elif self.method2 == "local average" and fullMultiplier != 0.0 and self.blur != 0:
                        # The Gaussian kernel size must be an odd int no larger than the latent.
                        minDim = min(latent.size(2), latent.size(3))
                        if minDim % 2 == 0:
                            minDim -= 1
                        blurSize = min(minDim, 1 + 2 * int(self.blur))
                        blurred = TF.gaussian_blur(latent[b], blurSize)
                        torch.lerp(latent[b], blurred, fullMultiplier, out=latent[b])
                        del blurred

            if self.softClampS * lastStep <= thisStep <= self.softClampE * lastStep:
                for b in range(latent.shape[0]):
                    # Keep the batch dim so that soft_clamp_tensor's dim 1 is the
                    # channel dim (on latent[b] it would be a spatial dim).
                    latent[b:b + 1] = soft_clamp_tensor(latent[b:b + 1])
            return latent

        def map_sigma(sigma, sigmax, sigmin):
            # Linear map: sigmin -> 0.0, sigmax -> 1.0.
            return (sigma - sigmin) / (sigmax - sigmin)

        def center_mean_latent_post_cfg(args):
            denoised = args["denoised"]
            sigma = args["sigma"][0]  # first sigma of the batch
            if self.sigmaWeight == "None":  # always full correction
                mult = 1
            else:
                mult = map_sigma(sigma, sigmax, sigmin)  # 0.0 at sigma_min .. 1.0 at sigma_max
                if self.sigmaWeight == "Soft":  # rescaled to 0.5 .. 1.0
                    mult = (mult + 1.0) / 2.0
            return center_latent_mean_values(denoised, mult)

        m = model.clone()
        m.set_model_sampler_post_cfg_function(center_mean_latent_post_cfg)
        return (m, )

    def process(self, params, *script_args, **kwargs):
        """Store the UI settings on the instance and record them in the infotext."""
        (method1, method2, topK, blur, strengthC, strengthO, stepS, stepE,
         sigmaWeight, softClampS, softClampE, custom) = script_args
        if method1 == "None" and method2 == "None":
            return

        self.method1 = method1
        self.method2 = method2
        self.topK = topK
        self.blur = blur
        self.strengthC = strengthC
        self.strengthO = strengthO
        self.stepS = stepS
        self.stepE = stepE
        self.sigmaWeight = sigmaWeight
        self.softClampS = softClampS
        self.softClampE = softClampE
        self.custom = custom

        # Below codes will add some logs to the texts below the image outputs on UI.
        # The extra_generation_params does not influence results.
        params.extra_generation_params.update(dict(
            ldc_method1=method1,
            ldc_method2=method2,
            ldc_strengthC=strengthC,
            ldc_strengthO=strengthO,
            ldc_stepS=stepS,
            ldc_stepE=stepE,
            ldc_sigW=sigmaWeight,
            ldc_softClampS=softClampS,
            ldc_softClampE=softClampE,
        ))
        if method1 == "custom":
            params.extra_generation_params.update(dict(ldc_custom=custom))
        if method1 in ("centered mean", "average of extremes", "average of quantiles") or method2 == "center to quantile":
            params.extra_generation_params.update(dict(ldc_topK=topK))
        if method2 == "local average":
            params.extra_generation_params.update(dict(ldc_blur=blur))
        return

    def process_before_every_sampling(self, params, *script_args, **kwargs):
        """Swap in a patched UNet when any correction method is enabled."""
        method1, method2 = script_args[0], script_args[1]
        if method1 != "None" or method2 != "None":
            unet = params.sd_model.forge_objects.unet
            params.sd_model.forge_objects.unet = self.patch(unet)[0]
        return