|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import k_diffusion.sampling |
|
|
import torch |
|
|
|
|
|
|
|
|
@torch.no_grad()
def sample_dpmpp_2m_alt(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M) alternative sampler.

    Multistep second-order DPM-Solver++ that walks ``x`` along the noise
    schedule ``sigmas``. It makes one model call per step, for
    ``len(sigmas) - 1`` steps. A first-order update is used on the first
    step and on any step whose target sigma is 0. Every other step
    extrapolates from the current and previous denoised predictions.

    "Alt" tweak: the denoised prediction remembered for the next step is
    scaled by ``1 + 0.15 * (i / len(sigmas)) ** 2``. The scale therefore
    grows as the step index advances.

    Source: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457

    Args:
        model: Callable invoked as ``model(x, sigma_batch, **extra_args)``.
            Its return value is used as the denoised prediction.
        x: Starting tensor. Its first dimension is treated as the batch
            dimension, and the current sigma is broadcast over it.
        sigmas: Sigmas indexed once per step. Presumably decreasing and
            ending at 0 (confirm against the scheduler).
        extra_args: Extra keyword arguments forwarded to ``model``.
        callback: Optional callable. It receives a dict with keys ``x``,
            ``i``, ``sigma``, ``sigma_hat`` and ``denoised`` once per step.
        disable: Forwarded to ``k_diffusion.sampling.trange``. It presumably
            turns off the progress bar.

    Returns:
        The tensor obtained after the final step.
    """
    if extra_args is None:
        extra_args = {}
    batch_ones = x.new_ones([x.shape[0]])
    total = len(sigmas)

    def sigma_from_t(t):
        # sigma = exp(-t); inverse of t_from_sigma.
        return t.neg().exp()

    def t_from_sigma(sigma):
        # DPM-Solver++ time variable t = -log(sigma).
        return sigma.log().neg()

    prev_denoised = None

    for step in k_diffusion.sampling.trange(total - 1, disable=disable):
        denoised = model(x, sigmas[step] * batch_ones, **extra_args)

        if callback is not None:
            step_info = {
                "x": x,
                "i": step,
                "sigma": sigmas[step],
                "sigma_hat": sigmas[step],
                "denoised": denoised,
            }
            callback(step_info)

        t_now = t_from_sigma(sigmas[step])
        t_next = t_from_sigma(sigmas[step + 1])
        h = t_next - t_now
        decay = sigma_from_t(t_next) / sigma_from_t(t_now)
        blend = (-h).expm1()

        # Use a first-order update when there is no history yet, or when
        # stepping to sigma == 0 (there t_next is +inf and blend == -1).
        use_first_order = prev_denoised is None or sigmas[step + 1] == 0
        if use_first_order:
            target = denoised
        else:
            h_prev = t_now - t_from_sigma(sigmas[step - 1])
            # 1 / (2 * r), where r = h_prev / h is the step-size ratio.
            half_inv_ratio = 1 / (2 * (h_prev / h))
            target = (1 + half_inv_ratio) * denoised - half_inv_ratio * prev_denoised
        x = decay * x - blend * target

        # "Alt" heuristic: scale the remembered prediction by a factor
        # that grows quadratically with the step index.
        progress = step / total
        prev_denoised = denoised * (1 + (0.15 * (progress * progress)))

    return x
|
|
|
|
|
|
|
|
def add_sample_dpmpp_2m_alt_webui() -> None:
    """Adds DPM-Solver++(2M) alternative sampler to the list of available samplers.

    Builds a ``SamplerData`` entry that wraps ``sample_dpmpp_2m_alt`` in a
    ``KDiffusionSampler``. It appends that entry to
    ``sd_samplers.all_samplers``, indexes it in
    ``sd_samplers.all_samplers_map``, and then calls
    ``sd_samplers.set_samplers()``.

    Behaviour:
        * Outside the WebUI (``modules`` cannot be imported) this is a
          deliberate no-op.
        * Samplers whose name is already present in ``all_samplers_map`` are
          skipped. Running this function again, for example when the script
          module is re-executed, therefore does not append duplicate entries
          to ``all_samplers``.
    """
    try:
        from modules import (
            sd_samplers,
            sd_samplers_common,
            sd_samplers_kdiffusion,
        )
    except ImportError:
        # Not running inside the WebUI: nothing to register.
        return

    samplers_dpmpp_2m_alt = [
        (
            "DPM++ 2M alt",
            sample_dpmpp_2m_alt,
            ["k_dpmpp_2m_alt"],
            {"scheduler": "karras"},
        ),
    ]

    samplers_data_dpmpp_2m_alt = [
        sd_samplers_common.SamplerData(
            label,
            # Bind funcname as a default argument so each constructor keeps
            # its own sampler function instead of late-binding to the loop.
            lambda model, funcname=funcname: sd_samplers_kdiffusion.KDiffusionSampler(
                funcname, model
            ),
            aliases,
            options,
        )
        for label, funcname, aliases, options in samplers_dpmpp_2m_alt
    ]

    # Only register names that are not already known. Otherwise repeated
    # calls would keep appending duplicates to the all_samplers list.
    new_samplers = [
        sampler_data
        for sampler_data in samplers_data_dpmpp_2m_alt
        if sampler_data.name not in sd_samplers.all_samplers_map
    ]

    sd_samplers.all_samplers.extend(new_samplers)
    for sampler_data in new_samplers:
        sd_samplers.all_samplers_map[sampler_data.name] = sampler_data

    # Rebuild the WebUI's derived sampler lists, as the original always did.
    sd_samplers.set_samplers()
|
|
|
|
|
|
|
|
# Register the sampler as a side effect of importing this module. Outside the
# WebUI the call returns early because the ``modules`` package cannot be
# imported.
add_sample_dpmpp_2m_alt_webui()
|
|
|