Update core/models/ddim/ddim_vd.py
Browse files- core/models/ddim/ddim_vd.py +16 -4
core/models/ddim/ddim_vd.py
CHANGED
|
@@ -6,6 +6,7 @@ import torch
|
|
| 6 |
import numpy as np
|
| 7 |
from tqdm import tqdm
|
| 8 |
from functools import partial
|
|
|
|
| 9 |
|
| 10 |
from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
| 11 |
|
|
@@ -27,7 +28,8 @@ class DDIMSampler_VD(DDIMSampler):
|
|
| 27 |
mix_weight=None,
|
| 28 |
noise_dropout=0.,
|
| 29 |
verbose=True,
|
| 30 |
-
log_every_t=100,
|
|
|
|
| 31 |
|
| 32 |
self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
|
| 33 |
print(f'Data shape for DDIM sampling is {shape}, eta {eta}')
|
|
@@ -42,7 +44,8 @@ class DDIMSampler_VD(DDIMSampler):
|
|
| 42 |
noise_dropout=noise_dropout,
|
| 43 |
temperature=temperature,
|
| 44 |
log_every_t=log_every_t,
|
| 45 |
-
mix_weight=mix_weight,
|
|
|
|
| 46 |
return samples, intermediates
|
| 47 |
|
| 48 |
@torch.no_grad()
|
|
@@ -58,7 +61,8 @@ class DDIMSampler_VD(DDIMSampler):
|
|
| 58 |
noise_dropout=0.,
|
| 59 |
temperature=1.,
|
| 60 |
mix_weight=None,
|
| 61 |
-
log_every_t=100,
|
|
|
|
| 62 |
|
| 63 |
device = self.model.device
|
| 64 |
dtype = condition[0][0].dtype
|
|
@@ -86,7 +90,12 @@ class DDIMSampler_VD(DDIMSampler):
|
|
| 86 |
|
| 87 |
pred_xt = xt
|
| 88 |
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
|
|
|
|
|
|
|
|
|
| 89 |
for i, step in enumerate(iterator):
|
|
|
|
|
|
|
| 90 |
index = total_steps - i - 1
|
| 91 |
ts = torch.full((bs,), step, device=device, dtype=torch.long)
|
| 92 |
|
|
@@ -107,6 +116,9 @@ class DDIMSampler_VD(DDIMSampler):
|
|
| 107 |
intermediates['pred_xt'].append(pred_xt)
|
| 108 |
intermediates['pred_x0'].append(pred_x0)
|
| 109 |
|
|
|
|
|
|
|
|
|
|
| 110 |
return pred_xt, intermediates
|
| 111 |
|
| 112 |
@torch.no_grad()
|
|
@@ -172,4 +184,4 @@ class DDIMSampler_VD(DDIMSampler):
|
|
| 172 |
x_prev_i = a_prev.sqrt() * pred_x0_i + dir_xt + noise
|
| 173 |
x_prev.append(x_prev_i)
|
| 174 |
pred_x0.append(pred_x0_i)
|
| 175 |
-
return x_prev, pred_x0
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
from tqdm import tqdm
|
| 8 |
from functools import partial
|
| 9 |
+
import streamlit as st
|
| 10 |
|
| 11 |
from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
| 12 |
|
|
|
|
| 28 |
mix_weight=None,
|
| 29 |
noise_dropout=0.,
|
| 30 |
verbose=True,
|
| 31 |
+
log_every_t=100,
|
| 32 |
+
progress_bar=False, ):
|
| 33 |
|
| 34 |
self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
|
| 35 |
print(f'Data shape for DDIM sampling is {shape}, eta {eta}')
|
|
|
|
| 44 |
noise_dropout=noise_dropout,
|
| 45 |
temperature=temperature,
|
| 46 |
log_every_t=log_every_t,
|
| 47 |
+
mix_weight=mix_weight,
|
| 48 |
+
progress_bar=progress_bar, )
|
| 49 |
return samples, intermediates
|
| 50 |
|
| 51 |
@torch.no_grad()
|
|
|
|
| 61 |
noise_dropout=0.,
|
| 62 |
temperature=1.,
|
| 63 |
mix_weight=None,
|
| 64 |
+
log_every_t=100,
|
| 65 |
+
progress_bar=False,):
|
| 66 |
|
| 67 |
device = self.model.device
|
| 68 |
dtype = condition[0][0].dtype
|
|
|
|
| 90 |
|
| 91 |
pred_xt = xt
|
| 92 |
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
| 93 |
+
if progress_bar is not None:
|
| 94 |
+
progress_bar.progress(0)
|
| 95 |
+
progress_bar.text("Generating samples...")
|
| 96 |
for i, step in enumerate(iterator):
|
| 97 |
+
if progress_bar is not None:
|
| 98 |
+
progress_bar.progress(i/total_steps)
|
| 99 |
index = total_steps - i - 1
|
| 100 |
ts = torch.full((bs,), step, device=device, dtype=torch.long)
|
| 101 |
|
|
|
|
| 116 |
intermediates['pred_xt'].append(pred_xt)
|
| 117 |
intermediates['pred_x0'].append(pred_x0)
|
| 118 |
|
| 119 |
+
if progress_bar is not None:
|
| 120 |
+
progress_bar.success("Sampling complete.")
|
| 121 |
+
|
| 122 |
return pred_xt, intermediates
|
| 123 |
|
| 124 |
@torch.no_grad()
|
|
|
|
| 184 |
x_prev_i = a_prev.sqrt() * pred_x0_i + dir_xt + noise
|
| 185 |
x_prev.append(x_prev_i)
|
| 186 |
pred_x0.append(pred_x0_i)
|
| 187 |
+
return x_prev, pred_x0
|