|
|
import torch |
|
|
import tqdm |
|
|
import k_diffusion.sampling |
|
|
from modules import sd_samplers_common, sd_samplers_kdiffusion, sd_samplers |
|
|
from tqdm.auto import trange, tqdm |
|
|
from k_diffusion import utils |
|
|
from k_diffusion.sampling import to_d, default_noise_sampler, get_ancestral_step |
|
|
import math |
|
|
from importlib import import_module |
|
|
|
|
|
# Module-level handle on k_diffusion.sampling, resolved by dotted name at
# import time.
# NOTE(review): nothing in the visible part of this file reads `sampling`;
# kept in case other code looks it up on this module - confirm before removing.
sampling = import_module("k_diffusion.sampling")
|
|
# Name under which the sampler is registered; also used to detect whether it
# has already been added to the sampler list.
NAME = "Euler_A_Test"

# Single alias passed along with NAME when the sampler is registered.
ALIAS = "euler_a_test"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def sample_euler_ancestral_test(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Run ancestral sampling, advancing ``x`` by one Euler step per sigma pair.

    For every consecutive pair ``(sigmas[i], sigmas[i + 1])`` the model is
    evaluated at the current sigma, ``x`` is moved along the direction given
    by ``to_d`` down to the ``sigma_down`` level returned by
    ``get_ancestral_step``, and, when the next sigma is greater than zero,
    noise scaled by ``s_noise * sigma_up`` is added.

    Args:
        model: Callable invoked as ``model(x, sigma, **extra_args)``.
        x: Starting tensor; it is never modified in place, each step rebinds
            the local name to a new tensor.
        sigmas: Indexable sequence of sigma values; ``len(sigmas) - 1`` steps
            are performed.
        extra_args: Optional keyword arguments forwarded to ``model``.
        callback: Optional callable that receives a dict with the keys
            ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised`` once per
            step, before ``x`` is updated.
        disable: Forwarded to ``trange`` as its ``disable`` argument.
        eta: Forwarded to ``get_ancestral_step``.
        s_noise: Multiplier applied to the noise added after each step.
        noise_sampler: Callable invoked as
            ``noise_sampler(sigma, sigma_next)``; when ``None`` one is built
            with ``default_noise_sampler(x)``.

    Returns:
        The value of ``x`` after the final step.
    """
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)

    # Vector of ones with one entry per element along x's first dimension;
    # multiplying it by a sigma expands that sigma to the same length.
    ones = x.new_ones([x.shape[0]])

    for step in trange(len(sigmas) - 1, disable=disable):
        sigma_cur = sigmas[step]
        sigma_next = sigmas[step + 1]

        model_out = model(x, sigma_cur * ones, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma_cur, sigma_next, eta=eta)

        if callback is not None:
            # Reported before x is updated, so 'x' is the input of this step.
            callback({
                'x': x,
                'i': step,
                'sigma': sigma_cur,
                'sigma_hat': sigma_cur,
                'denoised': model_out,
            })

        # Euler step from the current sigma down to sigma_down.
        direction = to_d(x, sigma_cur, model_out)
        x = x + direction * (sigma_down - sigma_cur)

        # No noise is added once the schedule has reached zero.
        if sigma_next > 0:
            x = x + noise_sampler(sigma_cur, sigma_next) * s_noise * sigma_up

    return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Register the sampler only when no entry with this name exists yet, so
# running this module again does not append a duplicate.
if not any(entry.name == NAME for entry in sd_samplers.all_samplers):
    # (label, sampling function or k_diffusion.sampling attribute name,
    #  aliases, options)
    euler_max_samplers = [(NAME, sample_euler_ancestral_test, [ALIAS], {})]

    samplers_data_euler_max_samplers = [
        sd_samplers_common.SamplerData(
            sampler_label,
            # The default argument pins the function for this entry at
            # definition time instead of relying on the loop variable.
            lambda model, funcname=sampler_fn: sd_samplers_kdiffusion.KDiffusionSampler(funcname, model),
            sampler_aliases,
            sampler_options,
        )
        for sampler_label, sampler_fn, sampler_aliases, sampler_options in euler_max_samplers
        # Accept either a callable or the name of an attribute that exists
        # on k_diffusion.sampling.
        if callable(sampler_fn) or hasattr(k_diffusion.sampling, sampler_fn)
    ]

    sd_samplers.all_samplers += samplers_data_euler_max_samplers
    # Rebuild the name lookup so it includes the entries just appended.
    sd_samplers.all_samplers_map = {entry.name: entry for entry in sd_samplers.all_samplers}
    sd_samplers.set_samplers()
|
|
|