import math
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
import gradio as gr
from inspect import isfunction
from torch import nn, einsum
from modules.processing import StableDiffusionProcessing
import modules.scripts as scripts
from modules import shared
from modules.script_callbacks import on_cfg_denoiser, CFGDenoiserParams, CFGDenoisedParams, on_cfg_denoised, AfterCFGCallbackParams, on_cfg_after_cfg
import os
from scripts import xyz_grid_support_sag
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
def exists(val):
return val is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
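# Separable Gaussian blur: builds a 1-D Gaussian kernel from `sigma`, forms a 2-D kernel by outer
# product, and applies it per channel via a grouped convolution after reflect padding. The kernel
# size adapts to sigma (roughly 4*sigma, forced odd) when not given explicitly.
# Illustrative use: adaptive_gaussian_blur_2d(torch.randn(1, 4, 64, 64), sigma=2.0) -> shape (1, 4, 64, 64).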
def adaptive_gaussian_blur_2d(img, sigma, kernel_size=None):
if kernel_size is None:
kernel_size = max(5, int(sigma * 4 + 1))
kernel_size = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
x_kernel = pdf / pdf.sum()
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
padding = kernel_size // 2
img = F.pad(img, (padding, padding, padding, padding), mode="reflect")
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
return img
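# Mirror of the UNet self-attention module that additionally keeps the softmaxed attention map in
# `self.attn_probs`. The class itself does not appear to be instantiated in this script; in practice
# only `xattn_forward_log` below is bound onto the existing attention module at runtime.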
class LoggedSelfAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
self.attn_probs = None
def forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
if additional_tokens is not None:
n_tokens_to_mask = additional_tokens.shape[1]
x = torch.cat([additional_tokens, x], dim=1)
        h = self.heads
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)
        if n_times_crossframe_attn_in_self:
            # Cross-frame attention: reuse keys/values from every n-th sample.
            # (Moved after k/v are computed so they exist when this branch runs.)
            assert x.shape[0] % n_times_crossframe_attn_in_self == 0
            k = repeat(
                k[::n_times_crossframe_attn_in_self],
                "b ... -> (b n) ...",
                n=n_times_crossframe_attn_in_self,
            )
            v = repeat(
                v[::n_times_crossframe_attn_in_self],
                "b ... -> (b n) ...",
                n=n_times_crossframe_attn_in_self,
            )
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
if _ATTN_PRECISION == "fp32":
with torch.autocast(enabled=False, device_type='cuda'):
q, k = q.float(), k.float()
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
del q, k
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
sim = sim.softmax(dim=-1)
self.attn_probs = sim
out = einsum('b i j, b j d -> b i d', sim, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
if additional_tokens is not None:
out = out[:, n_tokens_to_mask:]
return self.to_out(out)
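# Patched forward for the targeted self-attention module. It mirrors LoggedSelfAttention.forward but
# also publishes the attention map and the output spatial size through the module-level globals
# `current_selfattn_map` and `current_outsize`, which the denoised callback reads back.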
def xattn_forward_log(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
if additional_tokens is not None:
n_tokens_to_mask = additional_tokens.shape[1]
x = torch.cat([additional_tokens, x], dim=1)
    h = self.heads
    q = self.to_q(x)
    context = default(context, x)
    k = self.to_k(context)
    v = self.to_v(context)
    if n_times_crossframe_attn_in_self:
        # Cross-frame attention: reuse keys/values from every n-th sample.
        # (Moved after k/v are computed so they exist when this branch runs.)
        assert x.shape[0] % n_times_crossframe_attn_in_self == 0
        k = repeat(
            k[::n_times_crossframe_attn_in_self],
            "b ... -> (b n) ...",
            n=n_times_crossframe_attn_in_self,
        )
        v = repeat(
            v[::n_times_crossframe_attn_in_self],
            "b ... -> (b n) ...",
            n=n_times_crossframe_attn_in_self,
        )
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
if _ATTN_PRECISION == "fp32":
with torch.autocast(enabled=False, device_type='cuda'):
q, k = q.float(), k.float()
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
del q, k
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
sim = sim.softmax(dim=-1)
self.attn_probs = sim
global current_selfattn_map
current_selfattn_map = sim
sim = sim.to(dtype=v.dtype)
out = einsum('b i j, b j d -> b i d', sim, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
out = self.to_out(out)
if additional_tokens is not None:
out = out[:, n_tokens_to_mask:]
global current_outsize
current_outsize = out.shape[-2:]
return out
# Module-level state shared between the callbacks below
current_degraded_pred = None  # Initialized to None to indicate it has not been set yet
current_degraded_pred_compensation = None
saved_original_selfattn_forward = None  # Original forward of whichever attention module is currently patched
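# Best-effort lookup of the self-attention (attn1) module inside a UNet block: probes the requested
# sub-layer, falls back to common layer names, descends into Sequential containers, and finally
# scans module names for attention-related keywords before giving up.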
def get_attention_module_for_block(block, layer_name):
    # Try the requested layer first, then fall back to the common sub-module names within the block.
    fallback_order = [layer_name] + [name for name in ("0", "1") if name != layer_name]
    for candidate in fallback_order:
        try:
            # First, try to get the layer directly
            if hasattr(block, "_modules") and candidate in block._modules:
                layer = block._modules[candidate]
            else:
                # If not found directly, search deeper within the block
                for name, module in block.named_modules():
                    if name.endswith(candidate) and hasattr(module, "_modules"):
                        layer = module
                        break
                else:
                    raise ValueError(f"Layer {candidate} not found in block of type {type(block).__name__}")
# Check if the layer itself is the attention module
if hasattr(layer, "transformer_blocks"):
return layer.transformer_blocks._modules['0'].attn1
elif hasattr(layer, "attn1"):
return layer.attn1
# Handle nested Sequential modules
if isinstance(layer, torch.nn.modules.container.Sequential):
for sublayer_name, sublayer in layer._modules.items():
try:
return get_attention_module_for_block(sublayer, "0") # Recursively check for the attention module
except ValueError:
pass # If not found in this sublayer, continue to the next
# Generic attention module search (based on module names)
for name, module in layer.named_modules():
if any(attn_keyword in name.lower() for attn_keyword in ["attn", "attention", "selfattn"]):
if hasattr(module, "transformer_blocks"):
return module.transformer_blocks._modules['0'].attn1
elif hasattr(module, "attn1"):
return module.attn1
        except (AttributeError, KeyError, ValueError) as e:
            logger.warning(f"Error accessing attention layer '{candidate}': {e}. Trying next fallback.")
continue # Try the next layer in the fallback order
# If all fallbacks fail, raise an error
raise ValueError(f"No valid attention layer found within block {type(block).__name__}.")
class Script(scripts.Script):
def __init__(self):
self.custom_resolution = 512
def title(self):
return "Self Attention Guidance"
def show(self, is_img2img):
return scripts.AlwaysVisible
def ui(self, is_img2img):
with gr.Accordion('Self Attention Guidance', open=False):
with gr.Row():
enabled = gr.Checkbox(value=False, label="Enable Self Attention Guidance")
method = gr.Checkbox(value=False, label="Use bilinear interpolation")
attn = gr.Dropdown(label="Attention target", choices=["middle", "block5", "block8", "dynamic"], value="middle")
with gr.Group():
scale = gr.Slider(label='Guidance Scale', minimum=-2.0, maximum=10.0, step=0.01, value=0.75)
mask_threshold = gr.Slider(label='Mask Threshold', minimum=0.0, maximum=2.0, step=0.01, value=1.0)
blur_sigma = gr.Slider(label='Gaussian Blur Sigma', minimum=0.0, maximum=10.0, step=0.01, value=1.0)
                custom_resolution = gr.Slider(label='Base Reference Resolution', minimum=256, maximum=2048, step=64, value=512, info="Default base resolution per model: SD 1.5 = 512, SD 2.1 = 768, SDXL = 1024")
enabled.change(fn=None, inputs=[enabled], show_progress=False)
self.infotext_fields = (
(enabled, lambda d: gr.Checkbox.update(value="SAG Guidance Enabled" in d)),
(scale, "SAG Guidance Scale"),
(mask_threshold, "SAG Mask Threshold"),
(blur_sigma, "SAG Blur Sigma"),
(method, lambda d: gr.Checkbox.update(value="SAG bilinear interpolation" in d)),
(attn, "SAG Attention Target"),
(custom_resolution, "SAG Custom Resolution"))
return [enabled, scale, mask_threshold, blur_sigma, method, attn, custom_resolution]
def reset_attention_target(self):
global sag_attn_target
sag_attn_target = self.original_attn_target
def process(self, p: StableDiffusionProcessing, *args):
enabled, scale, mask_threshold, blur_sigma, method, attn, custom_resolution = args
global sag_enabled, sag_mask_threshold, sag_blur_sigma, sag_method_bilinear, sag_attn_target, current_sag_guidance_scale
if enabled:
sag_enabled = True
sag_mask_threshold = mask_threshold
sag_blur_sigma = blur_sigma
sag_method_bilinear = method
self.original_attn_target = attn # Save the original attention target
sag_attn_target = attn
current_sag_guidance_scale = scale
self.custom_resolution = custom_resolution
if attn != "dynamic":
org_attn_module = self.get_attention_module(attn)
global saved_original_selfattn_forward
saved_original_selfattn_forward = org_attn_module.forward
org_attn_module.forward = xattn_forward_log.__get__(org_attn_module, org_attn_module.__class__)
            p.extra_generation_params.update({
                "SAG Guidance Enabled": enabled,
                "SAG Guidance Scale": scale,
                "SAG Mask Threshold": mask_threshold,
                "SAG Blur Sigma": blur_sigma,
                "SAG bilinear interpolation": method,
                "SAG Attention Target": attn,
                "SAG Custom Resolution": custom_resolution
            })
else:
sag_enabled = False
if not hasattr(self, 'callbacks_added'):
on_cfg_denoiser(self.denoiser_callback)
on_cfg_denoised(self.denoised_callback)
on_cfg_after_cfg(self.cfg_after_cfg_callback)
self.callbacks_added = True
        # Reset the attention target for each image (only meaningful once SAG has been enabled)
        if enabled and hasattr(self, 'original_attn_target'):
            self.reset_attention_target()
return
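    # Runs before each UNet call: caches the latent, sigma and conditioning for the unconditional
    # half of the batch, and in "dynamic" mode re-targets the patched attention module
    # (middle block -> output block 5 -> output block 8) as sigma drops below fixed fractions of
    # the initial sigma.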
def denoiser_callback(self, params: CFGDenoiserParams):
if not sag_enabled:
return
global current_xin, current_batch_size, current_max_sigma, current_sag_block_index, current_unet_kwargs, sag_attn_target, current_sigma
current_batch_size = params.text_uncond.shape[0]
current_xin = params.x[-current_batch_size:]
current_uncond_emb = params.text_uncond
current_sigma = params.sigma
current_image_cond_in = params.image_cond
if params.sampling_step == 0:
current_max_sigma = current_sigma[-current_batch_size:][0]
current_sag_block_index = -1
current_unet_kwargs = {
"sigma": current_sigma[-current_batch_size:],
"image_cond": current_image_cond_in[-current_batch_size:],
"text_uncond": current_uncond_emb,
}
# Initialize current_degraded_pred when params is available
global current_degraded_pred
if current_degraded_pred is None: # Check if it's the first call
current_degraded_pred = torch.zeros_like(params.x)
        # The 6.25 and 2.5 sigma divisors below were chosen empirically; better values may exist.
global saved_original_selfattn_forward
if sag_attn_target == "dynamic":
if current_sag_block_index == -1:
org_attn_module = get_attention_module_for_block(shared.sd_model.model.diffusion_model.middle_block, '1')
saved_original_selfattn_forward = org_attn_module.forward
org_attn_module.forward = xattn_forward_log.__get__(org_attn_module, org_attn_module.__class__)
current_sag_block_index = 0
elif torch.any(current_unet_kwargs['sigma'] < current_max_sigma / 6.25):
if current_sag_block_index == 1:
attn_module = get_attention_module_for_block(shared.sd_model.model.diffusion_model.output_blocks[5], '0')
attn_module.forward = saved_original_selfattn_forward
                    # Fallback logic for block8
                    try:
                        block8 = shared.sd_model.model.diffusion_model.output_blocks[8]
                        if shared.sd_model.is_sdxl:
                            # Handle potential variations in SDXL architecture
                            if hasattr(block8, 'resnets'):
                                org_attn_module = block8.resnets[1].spatial_transformer.transformer_blocks._modules['0'].attn1
                            else:
                                org_attn_module = block8.transformer_blocks._modules['0'].attn1
                        else:
                            # SD 1.x / 2.x layout
                            org_attn_module = block8._modules['1'].transformer_blocks._modules['0'].attn1
                    except (AttributeError, KeyError):
                        logger.warning("Attention layer not found in block8. Switching attention target to 'middle' block.")
                        sag_attn_target = "middle"  # Change to middle block
                        org_attn_module = get_attention_module_for_block(shared.sd_model.model.diffusion_model.middle_block, '1')
saved_original_selfattn_forward = org_attn_module.forward
org_attn_module.forward = xattn_forward_log.__get__(org_attn_module, org_attn_module.__class__)
current_sag_block_index = 2
# Handle the absence of '1' for the output_blocks[5] module
elif torch.any(current_unet_kwargs['sigma'] < current_max_sigma / 2.5):
if current_sag_block_index == 0:
attn_module = get_attention_module_for_block(shared.sd_model.model.diffusion_model.middle_block, '1')
attn_module.forward = saved_original_selfattn_forward
# Fallback logic for block5
try:
org_attn_module = shared.sd_model.model.diffusion_model.output_blocks[5]._modules['1'].transformer_blocks._modules['0'].attn1
except AttributeError:
logger.warning("Attention layer not found in block5. Switching attention target to 'middle' block.")
sag_attn_target = "middle" # Change to middle block
org_attn_module = get_attention_module_for_block(shared.sd_model.model.diffusion_model.middle_block, '1')
saved_original_selfattn_forward = org_attn_module.forward
org_attn_module.forward = xattn_forward_log.__get__(org_attn_module, org_attn_module.__class__)
current_sag_block_index = 1
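    # Runs after each UNet call: thresholds the logged self-attention map (assuming 8 heads) into a
    # spatial mask, blurs the masked regions of the unconditional prediction, re-noises the result,
    # and runs the model once more to obtain the degraded prediction consumed by
    # cfg_after_cfg_callback.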
def denoised_callback(self, params: CFGDenoisedParams):
global current_degraded_pred_compensation, current_degraded_pred
if not sag_enabled:
return
uncond_output = params.x[-current_batch_size:]
original_latents = uncond_output
global current_uncond_pred
current_uncond_pred = uncond_output
attn_map = current_selfattn_map[-current_batch_size*8:]
bh, hw1, hw2 = attn_map.shape
b, latent_channel, latent_h, latent_w = original_latents.shape
h = 8
# Detect model type (SD 1.x/2.x or SDXL)
is_sdxl = shared.sd_model.is_sdxl if hasattr(shared.sd_model, 'is_sdxl') else False
is_cross_attention_control = shared.sd_model.model.conditioning_key == "crossattn-adm" # This is for SDXL
# Dynamic block scale calculation
block_scale = {
"dynamic": 2 ** current_sag_block_index,
"block5": 2,
"block8": 4,
"middle": 1
}.get(sag_attn_target, 1) # Default to 1 if sag_attn_target is invalid
attn_map = attn_map.reshape(b, h, hw1, hw2)
# Dynamic middle layer size calculation
middle_layer_latent_size = [
math.ceil(latent_h / (h * block_scale)),
math.ceil(latent_w / (h * block_scale))
]
if middle_layer_latent_size[0] * middle_layer_latent_size[1] < hw1:
middle_layer_latent_size = [
math.ceil(latent_h / ((h/2) * block_scale)),
math.ceil(latent_w / ((h/2) * block_scale))
]
# Get reference resolution
reference_resolution = self.custom_resolution
# Calculate scale factor and adaptive mask threshold
scale_factor = math.sqrt((latent_h * latent_w) / (reference_resolution / 8) ** 2)
adaptive_threshold = sag_mask_threshold * scale_factor
# Calculate attention mask and ensure correct dimensions
attn_mask = (attn_map.mean(1).sum(1) > adaptive_threshold).float()
        interp_mode = "bilinear" if sag_method_bilinear else "nearest-exact"
        attn_mask = F.interpolate(attn_mask.unsqueeze(1).unsqueeze(1), (latent_h, latent_w), mode=interp_mode).squeeze(1)
# Adaptive blur sigma and Gaussian blur
adaptive_sigma = sag_blur_sigma * scale_factor
        blurred_latents = adaptive_gaussian_blur_2d(original_latents, sigma=adaptive_sigma)
        blur_mask = attn_mask.unsqueeze(1).expand_as(original_latents)
        degraded_latents = blurred_latents * blur_mask + original_latents * (1 - blur_mask)
renoised_degraded_latent = degraded_latents - (uncond_output - current_xin)
# Handle the cond parameter for inner_model differently for SDXL
if is_sdxl:
if is_cross_attention_control:
# Use crossattn_latent and text_embedding if it is using cross attention control
degraded_pred = params.inner_model(renoised_degraded_latent, current_unet_kwargs['sigma'], crossattn_latent=current_unet_kwargs['image_cond'], text_embedding=current_unet_kwargs['text_uncond'])
else:
# If not using cross attention control, then use the normal cond dict format.
cond = {**current_unet_kwargs['text_uncond'], "c_concat": [current_unet_kwargs['image_cond']]}
degraded_pred = params.inner_model(renoised_degraded_latent, current_unet_kwargs['sigma'], cond=cond)
else:
# For SD1.5 and SD2.1
cond = {"c_crossattn": [current_unet_kwargs['text_uncond']], "c_concat": [current_unet_kwargs['image_cond']]}
degraded_pred = params.inner_model(renoised_degraded_latent, current_unet_kwargs['sigma'], cond=cond)
current_degraded_pred_compensation = uncond_output - degraded_latents
current_degraded_pred = degraded_pred
logger.info(f"Attention map shape: {attn_map.shape}")
logger.info(f"Original latents shape: {original_latents.shape}")
logger.info(f"Middle layer latent size: {middle_layer_latent_size}")
# Calculate the correct size for reshaping
total_elements = attn_mask.numel()
target_elements_per_batch = total_elements // b
def find_closest_factors(num, target_h, target_w):
h = target_h
w = target_w
while h * w != num:
if h * w < num:
w += 1
else:
h -= 1
return h, w
# Find the factors closest to middle_layer_latent_size
new_height, new_width = find_closest_factors(target_elements_per_batch, middle_layer_latent_size[0], middle_layer_latent_size[1])
attn_mask = (
attn_mask.reshape(b, new_height, new_width)
.unsqueeze(1)
.repeat(1, latent_channel, 1, 1)
.type(attn_map.dtype)
)
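    # Runs after CFG is combined: applies the SAG update
    #   x <- x + scale * (uncond_pred - (degraded_pred + compensation)),
    # resizing the cached tensors to match params.x if their spatial sizes differ.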
def cfg_after_cfg_callback(self, params: AfterCFGCallbackParams):
if not sag_enabled:
return
if current_degraded_pred is not None: # Check if current_degraded_pred is defined
logger.info(f"params.x shape: {params.x.shape}")
logger.info(f"current_uncond_pred shape: {current_uncond_pred.shape}")
logger.info(f"current_degraded_pred shape: {current_degraded_pred.shape}")
logger.info(f"current_degraded_pred_compensation shape: {current_degraded_pred_compensation.shape}")
# Ensure tensors have matching sizes
if params.x.size() != current_uncond_pred.size() or params.x.size() != current_degraded_pred.size() or params.x.size() != current_degraded_pred_compensation.size():
# Resize tensors to match params.x
filter_method = "nearest-exact" if not sag_method_bilinear else "bilinear"
current_uncond_pred_resized = F.interpolate(current_uncond_pred, size=params.x.shape[2:], mode=filter_method)
current_degraded_pred_resized = F.interpolate(current_degraded_pred, size=params.x.shape[2:], mode=filter_method)
current_degraded_pred_compensation_resized = F.interpolate(current_degraded_pred_compensation, size=params.x.shape[2:], mode=filter_method)
params.x = params.x + (current_uncond_pred_resized - (current_degraded_pred_resized + current_degraded_pred_compensation_resized)) * float(current_sag_guidance_scale)
else:
params.x = params.x + (current_uncond_pred - (current_degraded_pred + current_degraded_pred_compensation)) * float(current_sag_guidance_scale)
params.output_altered = True
    def postprocess(self, p, processed, *args):
        enabled, scale, sag_mask_threshold, blur_sigma, method, attn, custom_resolution = args
        global saved_original_selfattn_forward
        # saved_original_selfattn_forward is stored as a module-level global, not on self
        if enabled and saved_original_selfattn_forward is not None:  # SAG was enabled and a forward method was saved
            attn_module = self.get_attention_module(attn)  # Get the attention module that was patched
            if attn_module is not None:  # Check if we successfully got the attention module
                attn_module.forward = saved_original_selfattn_forward  # Restore the original forward method
                saved_original_selfattn_forward = None
        return
def get_attention_module(self, attn):
try:
if attn == "middle":
return shared.sd_model.model.diffusion_model.middle_block._modules['1'].transformer_blocks._modules['0'].attn1
elif attn == "block5":
return shared.sd_model.model.diffusion_model.output_blocks[5]._modules['1'].transformer_blocks._modules['0'].attn1
elif attn == "block8":
if shared.sd_model.is_sdxl:
if hasattr(shared.sd_model.model.diffusion_model.output_blocks[8], 'resnets'):
return shared.sd_model.model.diffusion_model.output_blocks[8].resnets[1].spatial_transformer.transformer_blocks._modules['0'].attn1
else:
# If spatial_transformer not present, use standard SDXL attention block location
return shared.sd_model.model.diffusion_model.output_blocks[8].transformer_blocks._modules['0'].attn1
else:
# Non-SDXL logic
return get_attention_module_for_block(shared.sd_model.model.diffusion_model.output_blocks[8], '0')
elif attn == "dynamic":
if current_sag_block_index == 0:
return shared.sd_model.model.diffusion_model.middle_block._modules['1'].transformer_blocks._modules['0'].attn1
elif current_sag_block_index == 1:
return shared.sd_model.model.diffusion_model.output_blocks[5]._modules['1'].transformer_blocks._modules['0'].attn1
elif current_sag_block_index == 2:
if shared.sd_model.is_sdxl:
if hasattr(shared.sd_model.model.diffusion_model.output_blocks[8], 'resnets'):
return shared.sd_model.model.diffusion_model.output_blocks[8].resnets[1].spatial_transformer.transformer_blocks._modules['0'].attn1
else:
return shared.sd_model.model.diffusion_model.output_blocks[8].transformer_blocks._modules['0'].attn1
else:
return get_attention_module_for_block(shared.sd_model.model.diffusion_model.output_blocks[8], '0')
except (KeyError, AttributeError):
# If the specific block can't be accessed, gracefully fall back to the middle block
logger.warning(f"Attention target {attn} not found. Falling back to 'middle'.")
return shared.sd_model.model.diffusion_model.middle_block._modules['1'].transformer_blocks._modules['0'].attn1
# Initialize the script (if needed)
xyz_grid_support_sag.initialize(Script)