Spaces:
Runtime error
Runtime error
Update modules/model.py
Browse files- modules/model.py +22 -0
modules/model.py
CHANGED
|
@@ -39,6 +39,20 @@ exists = lambda val: val is not None
|
|
| 39 |
default = lambda val, d: val if exists(val) else d
|
| 40 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
def get_attention_scores(attn, query, key, attention_mask=None):
|
| 44 |
|
|
@@ -528,6 +542,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|
| 528 |
noise_pred = noise_pred_uncond + guidance_scale * (
|
| 529 |
noise_pred_text - noise_pred_uncond
|
| 530 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 531 |
return noise_pred
|
| 532 |
|
| 533 |
sampler_args = self.get_sampler_extra_args_i2i(sigma_sched, sampler)
|
|
@@ -696,6 +714,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|
| 696 |
noise_pred = noise_pred_uncond + guidance_scale * (
|
| 697 |
noise_pred_text - noise_pred_uncond
|
| 698 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 699 |
return noise_pred
|
| 700 |
|
| 701 |
extra_args = self.get_sampler_extra_args_t2i(
|
|
|
|
| 39 |
default = lambda val, d: val if exists(val) else d
|
| 40 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 41 |
|
| 42 |
+
# from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
| 43 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
| 44 |
+
"""
|
| 45 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
| 46 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
| 47 |
+
"""
|
| 48 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
| 49 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
| 50 |
+
# rescale the results from guidance (fixes overexposure)
|
| 51 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
| 52 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
| 53 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
| 54 |
+
return noise_cfg
|
| 55 |
+
|
| 56 |
|
| 57 |
def get_attention_scores(attn, query, key, attention_mask=None):
|
| 58 |
|
|
|
|
| 542 |
noise_pred = noise_pred_uncond + guidance_scale * (
|
| 543 |
noise_pred_text - noise_pred_uncond
|
| 544 |
)
|
| 545 |
+
|
| 546 |
+
if guidance_rescale > 0.0:
|
| 547 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
| 548 |
+
|
| 549 |
return noise_pred
|
| 550 |
|
| 551 |
sampler_args = self.get_sampler_extra_args_i2i(sigma_sched, sampler)
|
|
|
|
| 714 |
noise_pred = noise_pred_uncond + guidance_scale * (
|
| 715 |
noise_pred_text - noise_pred_uncond
|
| 716 |
)
|
| 717 |
+
|
| 718 |
+
if guidance_rescale > 0.0:
|
| 719 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
| 720 |
+
|
| 721 |
return noise_pred
|
| 722 |
|
| 723 |
extra_args = self.get_sampler_extra_args_t2i(
|