import logging
import math
import os
from inspect import isfunction

import gradio as gr
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn, einsum

import modules.scripts as scripts
from modules import shared
from modules.processing import StableDiffusionProcessing
from modules.script_callbacks import (
    AfterCFGCallbackParams,
    CFGDenoisedParams,
    CFGDenoiserParams,
    on_cfg_after_cfg,
    on_cfg_denoised,
    on_cfg_denoiser,
)
from scripts import xyz_grid_support_sag

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")

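# Small helper predicates shared by the attention code below.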
def exists(val):
    return val is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

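# Depthwise Gaussian blur: builds a separable kernel from sigma (the kernel size is
# derived from sigma when not given and forced to be odd), then convolves each channel
# independently with reflection padding so the output keeps the input's spatial size.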
def adaptive_gaussian_blur_2d(img, sigma, kernel_size=None):
    if kernel_size is None:
        kernel_size = max(5, int(sigma * 4 + 1))
    kernel_size = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size

    ksize_half = (kernel_size - 1) * 0.5
    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    x_kernel = pdf / pdf.sum()
    x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)

    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])

    padding = kernel_size // 2
    img = F.pad(img, (padding, padding, padding, padding), mode="reflect")
    img = F.conv2d(img, kernel2d, groups=img.shape[-3])

    return img

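# Self-attention module matching the UNet's CrossAttention layout that records its
# attention probabilities in self.attn_probs. It is not referenced elsewhere in this
# script; the live model is monkey-patched with xattn_forward_log below instead.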
class LoggedSelfAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )
        self.attn_probs = None

    def forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
        if additional_tokens is not None:
            n_tokens_to_mask = additional_tokens.shape[1]
            x = torch.cat([additional_tokens, x], dim=1)

        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        # Cross-frame attention reuses keys/values from every n-th sample, so it has to
        # run after k and v are computed.
        if n_times_crossframe_attn_in_self:
            assert x.shape[0] % n_times_crossframe_attn_in_self == 0
            k = repeat(
                k[::n_times_crossframe_attn_in_self],
                "b ... -> (b n) ...",
                n=n_times_crossframe_attn_in_self,
            )
            v = repeat(
                v[::n_times_crossframe_attn_in_self],
                "b ... -> (b n) ...",
                n=n_times_crossframe_attn_in_self,
            )

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        if _ATTN_PRECISION == "fp32":
            with torch.autocast(enabled=False, device_type='cuda'):
                q, k = q.float(), k.float()
                sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        else:
            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        del q, k

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        sim = sim.softmax(dim=-1)

        self.attn_probs = sim

        out = einsum('b i j, b j d -> b i d', sim, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)

        if additional_tokens is not None:
            out = out[:, n_tokens_to_mask:]
        return self.to_out(out)

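# Replacement forward() that gets bound onto the model's own self-attention module.
# Same attention math as LoggedSelfAttention.forward, but it additionally publishes the
# softmaxed attention map through the module-level current_selfattn_map global so the
# denoised callback can build the SAG mask from it.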
def xattn_forward_log(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
    if additional_tokens is not None:
        n_tokens_to_mask = additional_tokens.shape[1]
        x = torch.cat([additional_tokens, x], dim=1)

    h = self.heads

    q = self.to_q(x)
    context = default(context, x)
    k = self.to_k(context)
    v = self.to_v(context)

    # Cross-frame attention reuses keys/values from every n-th sample, so it has to
    # run after k and v are computed.
    if n_times_crossframe_attn_in_self:
        assert x.shape[0] % n_times_crossframe_attn_in_self == 0
        k = repeat(
            k[::n_times_crossframe_attn_in_self],
            "b ... -> (b n) ...",
            n=n_times_crossframe_attn_in_self,
        )
        v = repeat(
            v[::n_times_crossframe_attn_in_self],
            "b ... -> (b n) ...",
            n=n_times_crossframe_attn_in_self,
        )

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    if _ATTN_PRECISION == "fp32":
        with torch.autocast(enabled=False, device_type='cuda'):
            q, k = q.float(), k.float()
            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
    else:
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

    del q, k

    if exists(mask):
        mask = rearrange(mask, 'b ... -> b (...)')
        max_neg_value = -torch.finfo(sim.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=h)
        sim.masked_fill_(~mask, max_neg_value)

    sim = sim.softmax(dim=-1)

    self.attn_probs = sim
    global current_selfattn_map
    current_selfattn_map = sim

    sim = sim.to(dtype=v.dtype)
    out = einsum('b i j, b j d -> b i d', sim, v)
    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
    out = self.to_out(out)

    if additional_tokens is not None:
        out = out[:, n_tokens_to_mask:]
    global current_outsize
    current_outsize = out.shape[-2:]
    return out

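# Module-level state shared between the CFG callbacks below.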
current_degraded_pred_compensation = None
current_degraded_pred = None
saved_original_selfattn_forward = None

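# Searches a UNet block for a self-attention (attn1) module: the requested child layer
# is tried first, then the other common child names, with a few heuristics for
# architectures that nest their transformer blocks differently.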
def get_attention_module_for_block(block, layer_name):
    fallback_order = [layer_name] + [name for name in ("0", "1") if name != layer_name]

    for candidate_name in fallback_order:
        try:
            if hasattr(block, "_modules") and candidate_name in block._modules:
                layer = block._modules[candidate_name]
            else:
                for name, module in block.named_modules():
                    if name.endswith(candidate_name) and hasattr(module, "_modules"):
                        layer = module
                        break
                else:
                    raise ValueError(f"Layer {candidate_name} not found in block of type {type(block).__name__}")

            if hasattr(layer, "transformer_blocks"):
                return layer.transformer_blocks._modules['0'].attn1
            elif hasattr(layer, "attn1"):
                return layer.attn1

            if isinstance(layer, torch.nn.modules.container.Sequential):
                for sublayer in layer._modules.values():
                    try:
                        return get_attention_module_for_block(sublayer, "0")
                    except ValueError:
                        pass

            for name, module in layer.named_modules():
                if any(attn_keyword in name.lower() for attn_keyword in ["attn", "attention", "selfattn"]):
                    if hasattr(module, "transformer_blocks"):
                        return module.transformer_blocks._modules['0'].attn1
                    elif hasattr(module, "attn1"):
                        return module.attn1

        except (AttributeError, KeyError, ValueError) as e:
            logger.warning(f"Error accessing attention layer '{candidate_name}': {e}. Trying next fallback.")
            continue

    raise ValueError(f"No valid attention layer found within block {type(block).__name__}.")

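# AUTOMATIC1111 script: builds the Self Attention Guidance UI, hooks the CFG denoiser
# callbacks, and applies the SAG correction to the CFG output.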
class Script(scripts.Script):
    def __init__(self):
        self.custom_resolution = 512
        # Default target; overwritten in process() whenever SAG is enabled.
        self.original_attn_target = "middle"

    def title(self):
        return "Self Attention Guidance"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        with gr.Accordion('Self Attention Guidance', open=False):
            with gr.Row():
                enabled = gr.Checkbox(value=False, label="Enable Self Attention Guidance")
                method = gr.Checkbox(value=False, label="Use bilinear interpolation")
                attn = gr.Dropdown(label="Attention target", choices=["middle", "block5", "block8", "dynamic"], value="middle")
            with gr.Group():
                scale = gr.Slider(label='Guidance Scale', minimum=-2.0, maximum=10.0, step=0.01, value=0.75)
                mask_threshold = gr.Slider(label='Mask Threshold', minimum=0.0, maximum=2.0, step=0.01, value=1.0)
                blur_sigma = gr.Slider(label='Gaussian Blur Sigma', minimum=0.0, maximum=10.0, step=0.01, value=1.0)
                custom_resolution = gr.Slider(label='Base Reference Resolution', minimum=256, maximum=2048, step=64, value=512, info="Default base resolution for models: SD 1.5 = 512, SD 2.1 = 768, SDXL = 1024")
            enabled.change(fn=None, inputs=[enabled], show_progress=False)

        self.infotext_fields = (
            (enabled, lambda d: gr.Checkbox.update(value="SAG Guidance Enabled" in d)),
            (scale, "SAG Guidance Scale"),
            (mask_threshold, "SAG Mask Threshold"),
            (blur_sigma, "SAG Blur Sigma"),
            (method, lambda d: gr.Checkbox.update(value="SAG bilinear interpolation" in d)),
            (attn, "SAG Attention Target"),
            (custom_resolution, "SAG Custom Resolution"),
        )
        return [enabled, scale, mask_threshold, blur_sigma, method, attn, custom_resolution]

    def reset_attention_target(self):
        global sag_attn_target
        sag_attn_target = self.original_attn_target

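    # Copies the UI values into module globals, patches the chosen attention module with
    # xattn_forward_log (the "dynamic" target is handled per step in denoiser_callback),
    # registers the CFG callbacks once, and records the settings in the infotext.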
    def process(self, p: StableDiffusionProcessing, *args):
        enabled, scale, mask_threshold, blur_sigma, method, attn, custom_resolution = args
        global sag_enabled, sag_mask_threshold, sag_blur_sigma, sag_method_bilinear, sag_attn_target, current_sag_guidance_scale

        if enabled:
            sag_enabled = True
            sag_mask_threshold = mask_threshold
            sag_blur_sigma = blur_sigma
            sag_method_bilinear = method
            self.original_attn_target = attn
            sag_attn_target = attn
            current_sag_guidance_scale = scale
            self.custom_resolution = custom_resolution

            if attn != "dynamic":
                org_attn_module = self.get_attention_module(attn)
                global saved_original_selfattn_forward
                saved_original_selfattn_forward = org_attn_module.forward
                org_attn_module.forward = xattn_forward_log.__get__(org_attn_module, org_attn_module.__class__)

            p.extra_generation_params.update({
                "SAG Guidance Enabled": enabled,
                "SAG Guidance Scale": scale,
                "SAG Mask Threshold": mask_threshold,
                "SAG Blur Sigma": blur_sigma,
                "SAG bilinear interpolation": method,
                "SAG Attention Target": attn,
                "SAG Custom Resolution": custom_resolution,
            })
        else:
            sag_enabled = False

        if not hasattr(self, 'callbacks_added'):
            on_cfg_denoiser(self.denoiser_callback)
            on_cfg_denoised(self.denoised_callback)
            on_cfg_after_cfg(self.cfg_after_cfg_callback)
            self.callbacks_added = True

        self.reset_attention_target()
        return

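    # Runs before each denoiser call: caches the unconditional slice of the batch and the
    # per-step UNet kwargs. In "dynamic" mode it moves the logged-attention patch from the
    # middle block to block5 and then block8 as sigma falls below fractions of the
    # initial sigma.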
    def denoiser_callback(self, params: CFGDenoiserParams):
        if not sag_enabled:
            return

        global current_xin, current_batch_size, current_max_sigma, current_sag_block_index, current_unet_kwargs, sag_attn_target, current_sigma

        current_batch_size = params.text_uncond.shape[0]
        current_xin = params.x[-current_batch_size:]
        current_uncond_emb = params.text_uncond
        current_sigma = params.sigma
        current_image_cond_in = params.image_cond

        if params.sampling_step == 0:
            current_max_sigma = current_sigma[-current_batch_size:][0]
            current_sag_block_index = -1

        current_unet_kwargs = {
            "sigma": current_sigma[-current_batch_size:],
            "image_cond": current_image_cond_in[-current_batch_size:],
            "text_uncond": current_uncond_emb,
        }

        global current_degraded_pred
        if current_degraded_pred is None:
            current_degraded_pred = torch.zeros_like(params.x)

        global saved_original_selfattn_forward
        if sag_attn_target == "dynamic":
            if current_sag_block_index == -1:
                # Start of sampling: log attention in the middle block.
                org_attn_module = get_attention_module_for_block(shared.sd_model.model.diffusion_model.middle_block, '1')
                saved_original_selfattn_forward = org_attn_module.forward
                org_attn_module.forward = xattn_forward_log.__get__(org_attn_module, org_attn_module.__class__)
                current_sag_block_index = 0
            elif torch.any(current_unet_kwargs['sigma'] < current_max_sigma / 6.25):
                if current_sag_block_index == 1:
                    # Late in sampling: move the patch from block5 to block8.
                    attn_module = get_attention_module_for_block(shared.sd_model.model.diffusion_model.output_blocks[5], '0')
                    attn_module.forward = saved_original_selfattn_forward

                    try:
                        block8 = shared.sd_model.model.diffusion_model.output_blocks[8]
                        if shared.sd_model.is_sdxl:
                            if hasattr(block8, 'resnets'):
                                org_attn_module = block8.resnets[1].spatial_transformer.transformer_blocks._modules['0'].attn1
                            else:
                                org_attn_module = block8.transformer_blocks._modules['0'].attn1
                        else:
                            org_attn_module = block8._modules['1'].transformer_blocks._modules['0'].attn1
                    except (AttributeError, KeyError):
                        logger.warning("Attention layer not found in block8. Switching attention target to 'middle' block.")
                        sag_attn_target = "middle"
                        org_attn_module = get_attention_module_for_block(shared.sd_model.model.diffusion_model.middle_block, '1')

                    saved_original_selfattn_forward = org_attn_module.forward
                    org_attn_module.forward = xattn_forward_log.__get__(org_attn_module, org_attn_module.__class__)
                    current_sag_block_index = 2
            elif torch.any(current_unet_kwargs['sigma'] < current_max_sigma / 2.5):
                if current_sag_block_index == 0:
                    # Mid sampling: move the patch from the middle block to block5.
                    attn_module = get_attention_module_for_block(shared.sd_model.model.diffusion_model.middle_block, '1')
                    attn_module.forward = saved_original_selfattn_forward

                    try:
                        org_attn_module = shared.sd_model.model.diffusion_model.output_blocks[5]._modules['1'].transformer_blocks._modules['0'].attn1
                    except (AttributeError, KeyError):
                        logger.warning("Attention layer not found in block5. Switching attention target to 'middle' block.")
                        sag_attn_target = "middle"
                        org_attn_module = get_attention_module_for_block(shared.sd_model.model.diffusion_model.middle_block, '1')

                    saved_original_selfattn_forward = org_attn_module.forward
                    org_attn_module.forward = xattn_forward_log.__get__(org_attn_module, org_attn_module.__class__)
                    current_sag_block_index = 1

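    # Runs after each denoiser call: thresholds the logged self-attention map into a
    # spatial mask, Gaussian-blurs the masked regions of the unconditional latents,
    # re-noises them and runs the UNet once more to obtain the degraded prediction used
    # by the SAG correction.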
    def denoised_callback(self, params: CFGDenoisedParams):
        global current_degraded_pred_compensation, current_degraded_pred

        if not sag_enabled:
            return

        uncond_output = params.x[-current_batch_size:]
        original_latents = uncond_output
        global current_uncond_pred
        current_uncond_pred = uncond_output

        attn_map = current_selfattn_map[-current_batch_size * 8:]
        bh, hw1, hw2 = attn_map.shape
        b, latent_channel, latent_h, latent_w = original_latents.shape
        h = 8

        is_sdxl = shared.sd_model.is_sdxl if hasattr(shared.sd_model, 'is_sdxl') else False
        is_cross_attention_control = shared.sd_model.model.conditioning_key == "crossattn-adm"

        # Relative spatial scale of the attention map logged from the selected block.
        block_scale = {
            "dynamic": 2 ** current_sag_block_index,
            "block5": 2,
            "block8": 4,
            "middle": 1
        }.get(sag_attn_target, 1)

        attn_map = attn_map.reshape(b, h, hw1, hw2)

        middle_layer_latent_size = [
            math.ceil(latent_h / (h * block_scale)),
            math.ceil(latent_w / (h * block_scale))
        ]
        if middle_layer_latent_size[0] * middle_layer_latent_size[1] < hw1:
            middle_layer_latent_size = [
                math.ceil(latent_h / ((h / 2) * block_scale)),
                math.ceil(latent_w / ((h / 2) * block_scale))
            ]

        reference_resolution = self.custom_resolution

        # Scale the threshold and blur sigma with the ratio of the current latent size
        # to the base-resolution latent size.
        scale_factor = math.sqrt((latent_h * latent_w) / (reference_resolution / 8) ** 2)
        adaptive_threshold = sag_mask_threshold * scale_factor

        attn_mask = (attn_map.mean(1).sum(1) > adaptive_threshold).float()
        attn_mask = F.interpolate(
            attn_mask.unsqueeze(1).unsqueeze(1),
            (latent_h, latent_w),
            mode="nearest-exact" if not sag_method_bilinear else "bilinear"
        ).squeeze(1)

        adaptive_sigma = sag_blur_sigma * scale_factor
        degraded_latents = (
            adaptive_gaussian_blur_2d(original_latents, sigma=adaptive_sigma) * attn_mask.unsqueeze(1).expand_as(original_latents)
            + original_latents * (1 - attn_mask.unsqueeze(1).expand_as(original_latents))
        )

        renoised_degraded_latent = degraded_latents - (uncond_output - current_xin)

        if is_sdxl:
            if is_cross_attention_control:
                degraded_pred = params.inner_model(renoised_degraded_latent, current_unet_kwargs['sigma'], crossattn_latent=current_unet_kwargs['image_cond'], text_embedding=current_unet_kwargs['text_uncond'])
            else:
                cond = {**current_unet_kwargs['text_uncond'], "c_concat": [current_unet_kwargs['image_cond']]}
                degraded_pred = params.inner_model(renoised_degraded_latent, current_unet_kwargs['sigma'], cond=cond)
        else:
            cond = {"c_crossattn": [current_unet_kwargs['text_uncond']], "c_concat": [current_unet_kwargs['image_cond']]}
            degraded_pred = params.inner_model(renoised_degraded_latent, current_unet_kwargs['sigma'], cond=cond)

        current_degraded_pred_compensation = uncond_output - degraded_latents
        current_degraded_pred = degraded_pred

        logger.info(f"Attention map shape: {attn_map.shape}")
        logger.info(f"Original latents shape: {original_latents.shape}")
        logger.info(f"Middle layer latent size: {middle_layer_latent_size}")

        total_elements = attn_mask.numel()
        target_elements_per_batch = total_elements // b

        def find_closest_factors(num, target_h, target_w):
            h = target_h
            w = target_w
            while h * w != num:
                if h * w < num:
                    w += 1
                else:
                    h -= 1
            return h, w

        new_height, new_width = find_closest_factors(target_elements_per_batch, middle_layer_latent_size[0], middle_layer_latent_size[1])

        attn_mask = (
            attn_mask.reshape(b, new_height, new_width)
            .unsqueeze(1)
            .repeat(1, latent_channel, 1, 1)
            .type(attn_map.dtype)
        )

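    # Runs after CFG combines the conditional and unconditional predictions: adds
    # scale * (uncond_pred - (degraded_pred + compensation)) to the result, resizing the
    # cached tensors first if their shapes no longer match params.x.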
    def cfg_after_cfg_callback(self, params: AfterCFGCallbackParams):
        if not sag_enabled:
            return

        if current_degraded_pred is not None:
            logger.info(f"params.x shape: {params.x.shape}")
            logger.info(f"current_uncond_pred shape: {current_uncond_pred.shape}")
            logger.info(f"current_degraded_pred shape: {current_degraded_pred.shape}")
            logger.info(f"current_degraded_pred_compensation shape: {current_degraded_pred_compensation.shape}")

            if (params.x.size() != current_uncond_pred.size()
                    or params.x.size() != current_degraded_pred.size()
                    or params.x.size() != current_degraded_pred_compensation.size()):
                filter_method = "nearest-exact" if not sag_method_bilinear else "bilinear"
                current_uncond_pred_resized = F.interpolate(current_uncond_pred, size=params.x.shape[2:], mode=filter_method)
                current_degraded_pred_resized = F.interpolate(current_degraded_pred, size=params.x.shape[2:], mode=filter_method)
                current_degraded_pred_compensation_resized = F.interpolate(current_degraded_pred_compensation, size=params.x.shape[2:], mode=filter_method)

                params.x = params.x + (current_uncond_pred_resized - (current_degraded_pred_resized + current_degraded_pred_compensation_resized)) * float(current_sag_guidance_scale)
            else:
                params.x = params.x + (current_uncond_pred - (current_degraded_pred + current_degraded_pred_compensation)) * float(current_sag_guidance_scale)

            params.output_altered = True

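    # Restores the original attention forward once generation finishes so the patched
    # module does not leak into later generations.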
    def postprocess(self, p, processed, *args):
        enabled, scale, mask_threshold, blur_sigma, method, attn, custom_resolution = args
        global saved_original_selfattn_forward
        if enabled and saved_original_selfattn_forward is not None:
            attn_module = self.get_attention_module(attn)
            if attn_module is not None:
                attn_module.forward = saved_original_selfattn_forward
            saved_original_selfattn_forward = None
        return

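    # Maps a UI attention-target name ("middle", "block5", "block8", "dynamic") to the
    # concrete attn1 module inside the UNet, falling back to the middle block when the
    # expected layout is not found.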
    def get_attention_module(self, attn):
        try:
            if attn == "middle":
                return shared.sd_model.model.diffusion_model.middle_block._modules['1'].transformer_blocks._modules['0'].attn1
            elif attn == "block5":
                return shared.sd_model.model.diffusion_model.output_blocks[5]._modules['1'].transformer_blocks._modules['0'].attn1
            elif attn == "block8":
                if shared.sd_model.is_sdxl:
                    if hasattr(shared.sd_model.model.diffusion_model.output_blocks[8], 'resnets'):
                        return shared.sd_model.model.diffusion_model.output_blocks[8].resnets[1].spatial_transformer.transformer_blocks._modules['0'].attn1
                    else:
                        return shared.sd_model.model.diffusion_model.output_blocks[8].transformer_blocks._modules['0'].attn1
                else:
                    return get_attention_module_for_block(shared.sd_model.model.diffusion_model.output_blocks[8], '0')
            elif attn == "dynamic":
                if current_sag_block_index == 0:
                    return shared.sd_model.model.diffusion_model.middle_block._modules['1'].transformer_blocks._modules['0'].attn1
                elif current_sag_block_index == 1:
                    return shared.sd_model.model.diffusion_model.output_blocks[5]._modules['1'].transformer_blocks._modules['0'].attn1
                elif current_sag_block_index == 2:
                    if shared.sd_model.is_sdxl:
                        if hasattr(shared.sd_model.model.diffusion_model.output_blocks[8], 'resnets'):
                            return shared.sd_model.model.diffusion_model.output_blocks[8].resnets[1].spatial_transformer.transformer_blocks._modules['0'].attn1
                        else:
                            return shared.sd_model.model.diffusion_model.output_blocks[8].transformer_blocks._modules['0'].attn1
                    else:
                        return get_attention_module_for_block(shared.sd_model.model.diffusion_model.output_blocks[8], '0')
            return None
        except (KeyError, AttributeError):
            logger.warning(f"Attention target {attn} not found. Falling back to 'middle'.")
            return shared.sd_model.model.diffusion_model.middle_block._modules['1'].transformer_blocks._modules['0'].attn1


xyz_grid_support_sag.initialize(Script)