| | |
| |
|
| | import torch |
| | import torch as th |
| | import torch.fft as fft |
| | import torch.nn.functional as F |
| | import math |
| |
|
| | def normalize(latent, target_min=None, target_max=None): |
| | """ |
| | Normalize a tensor `latent` between `target_min` and `target_max`. |
| | |
| | Args: |
| | latent (torch.Tensor): The input tensor to be normalized. |
| | target_min (float, optional): The minimum value after normalization. |
| | - When `None` min will be tensor min range value. |
| | target_max (float, optional): The maximum value after normalization. |
| | - When `None` max will be tensor max range value. |
| | |
| | Returns: |
| | torch.Tensor: The normalized tensor |
| | """ |
| | min_val = latent.min() |
| | max_val = latent.max() |
| | |
| | if target_min is None: |
| | target_min = min_val |
| | if target_max is None: |
| | target_max = max_val |
| | |
| | normalized = (latent - min_val) / (max_val - min_val) |
| | scaled = normalized * (target_max - target_min) + target_min |
| | return scaled |
| |
|
| | def hslerp(a, b, t): |
| | """ |
| | Perform Hybrid Spherical Linear Interpolation (HSLERP) between two tensors. |
| | |
| | This function combines two input tensors `a` and `b` using HSLERP, which is a specialized |
| | interpolation method for smooth transitions between orientations or colors. |
| | |
| | Args: |
| | a (tensor): The first input tensor. |
| | b (tensor): The second input tensor. |
| | t (float): The blending factor, a value between 0 and 1 that controls the interpolation. |
| | |
| | Returns: |
| | tensor: The result of HSLERP interpolation between `a` and `b`. |
| | |
| | Note: |
| | HSLERP provides smooth transitions between orientations or colors, particularly useful |
| | in applications like image processing and 3D graphics. |
| | """ |
| | if a.shape != b.shape: |
| | raise ValueError("Input tensors a and b must have the same shape.") |
| |
|
| | num_channels = a.size(1) |
| | |
| | interpolation_tensor = torch.zeros(1, num_channels, 1, 1, device=a.device, dtype=a.dtype) |
| | interpolation_tensor[0, 0, 0, 0] = 1.0 |
| |
|
| | result = (1 - t) * a + t * b |
| |
|
| | if t < 0.5: |
| | result += (torch.norm(b - a, dim=1, keepdim=True) / 6) * interpolation_tensor |
| | else: |
| | result -= (torch.norm(b - a, dim=1, keepdim=True) / 6) * interpolation_tensor |
| |
|
| | return result |
| |
|
| | blending_modes = { |
| | |
| | |
| | |
| | |
| |
|
| | |
| | 'bislerp': lambda a, b, t: normalize((1 - t) * a + t * b), |
| | |
| | 'colorize': lambda a, b, t: a + (b - a) * t, |
| | |
| | 'cosine interp': lambda a, b, t: (a + b - (a - b) * torch.cos(t * torch.tensor(math.pi))) / 2, |
| | |
| | 'cuberp': lambda a, b, t: a + (b - a) * (3 * t ** 2 - 2 * t ** 3), |
| | |
| | |
| | 'hslerp': hslerp, |
| | |
| | 'inject': lambda a, b, t: a + b * t, |
| | |
| | 'lerp': lambda a, b, t: (1 - t) * a + t * b, |
| | |
| | 'linear dodge': lambda a, b, t: normalize(a + b * t), |
| | } |
| |
|
| | mscales = { |
| | "Default": None, |
| | "Bandpass": [ |
| | (5, 0.0), |
| | (15, 1.0), |
| | (25, 0.0), |
| | ], |
| | "Low-Pass": [ |
| | (10, 1.0), |
| | ], |
| | "High-Pass": [ |
| | (10, 0.0), |
| | ], |
| | "Pass-Through": [ |
| | (10, 1.0), |
| | ], |
| | "Gaussian-Blur": [ |
| | (10, 0.5), |
| | ], |
| | "Edge-Enhancement": [ |
| | (10, 2.0), |
| | ], |
| | "Sharpen": [ |
| | (10, 1.5), |
| | ], |
| | "Multi-Bandpass": [ |
| | [(5, 0.0), (15, 1.0), (25, 0.0)], |
| | ], |
| | "Multi-Low-Pass": [ |
| | [(5, 1.0), (10, 0.5), (15, 0.2)], |
| | ], |
| | "Multi-High-Pass": [ |
| | [(5, 0.0), (10, 0.5), (15, 0.8)], |
| | ], |
| | "Multi-Pass-Through": [ |
| | [(5, 1.0), (10, 1.0), (15, 1.0)], |
| | ], |
| | "Multi-Gaussian-Blur": [ |
| | [(5, 0.5), (10, 0.8), (15, 0.2)], |
| | ], |
| | "Multi-Edge-Enhancement": [ |
| | [(5, 1.2), (10, 1.5), (15, 2.0)], |
| | ], |
| | "Multi-Sharpen": [ |
| | [(5, 1.5), (10, 2.0), (15, 2.5)], |
| | ], |
| | } |
| |
|
| | |
| | |
| | def __temp__forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): |
| | """ |
| | Apply the model to an input batch. |
| | :param x: an [N x C x ...] Tensor of inputs. |
| | :param timesteps: a 1-D batch of timesteps. |
| | :param context: conditioning plugged in via crossattn |
| | :param y: an [N] Tensor of labels, if class-conditional. |
| | :return: an [N x C x ...] Tensor of outputs. |
| | """ |
| | transformer_options["original_shape"] = list(x.shape) |
| | transformer_options["current_index"] = 0 |
| | transformer_patches = transformer_options.get("patches", {}) |
| |
|
| | assert (y is not None) == ( |
| | self.num_classes is not None |
| | ), "must specify y if and only if the model is class-conditional" |
| | hs = [] |
| | t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) |
| | emb = self.time_embed(t_emb) |
| |
|
| | if self.num_classes is not None: |
| | assert y.shape[0] == x.shape[0] |
| | emb = emb + self.label_emb(y) |
| |
|
| | h = x.type(self.dtype) |
| | for id, module in enumerate(self.input_blocks): |
| | transformer_options["block"] = ("input", id) |
| | h = forward_timestep_embed(module, h, emb, context, transformer_options) |
| | if control is not None and 'input' in control and len(control['input']) > 0: |
| | ctrl = control['input'].pop() |
| | if ctrl is not None: |
| | h += ctrl |
| | hs.append(h) |
| | |
| | hsp = hs |
| | if "input_block_patch" in transformer_patches: |
| | patch = transformer_patches["input_block_patch"] |
| | for p in patch: |
| | h, hsp = p(h, hsp, transformer_options) |
| | del hsp |
| | |
| | transformer_options["block"] = ("middle", 0) |
| | h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) |
| | if control is not None and 'middle' in control and len(control['middle']) > 0: |
| | ctrl = control['middle'].pop() |
| | if ctrl is not None: |
| | h += ctrl |
| | |
| | hsp = [h] |
| | if "middle_block_patch" in transformer_patches: |
| | patch = transformer_patches["middle_block_patch"] |
| | for p in patch: |
| | h, hsp = p(h, hsp, transformer_options) |
| | del hsp |
| |
|
| | for id, module in enumerate(self.output_blocks): |
| | transformer_options["block"] = ("output", id) |
| | hsp = hs.pop() |
| | if control is not None and 'output' in control and len(control['output']) > 0: |
| | ctrl = control['output'].pop() |
| | if ctrl is not None: |
| | hsp += ctrl |
| |
|
| | if "output_block_patch" in transformer_patches: |
| | patch = transformer_patches["output_block_patch"] |
| | for p in patch: |
| | h, hsp = p(h, hsp, transformer_options) |
| |
|
| | h = th.cat([h, hsp], dim=1) |
| | del hsp |
| | if len(hs) > 0: |
| | output_shape = hs[-1].shape |
| | else: |
| | output_shape = None |
| | h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape) |
| | h = h.type(x.dtype) |
| | if self.predict_codebook_ids: |
| | return self.id_predictor(h) |
| | else: |
| | return self.out(h) |
| |
|
| | print("Patching UNetModel.forward") |
| | import comfy.ldm.modules.diffusionmodules.openaimodel |
| | from comfy.ldm.modules.diffusionmodules.openaimodel import forward_timestep_embed |
| | from comfy.ldm.modules.diffusionmodules.util import timestep_embedding |
| | comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = __temp__forward |
| | if comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward is __temp__forward: |
| | print("UNetModel.forward has been successfully patched.") |
| | else: |
| | print("UNetModel.forward patching failed.") |
| |
|
| | def Fourier_filter(x, threshold, scale, scales=None, strength=1.0): |
| | |
| | if isinstance(x, list): |
| | x = x[0] |
| | if isinstance(x, torch.Tensor): |
| | x_freq = fft.fftn(x.float(), dim=(-2, -1)) |
| | x_freq = fft.fftshift(x_freq, dim=(-2, -1)) |
| |
|
| | B, C, H, W = x_freq.shape |
| | mask = torch.ones((B, C, H, W), device=x.device) |
| |
|
| | crow, ccol = H // 2, W // 2 |
| | mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale |
| |
|
| | if scales is not None: |
| | if isinstance(scales[0], tuple): |
| | |
| | for scale_params in scales: |
| | if len(scale_params) == 2: |
| | scale_threshold, scale_value = scale_params |
| | scaled_scale_value = scale_value * strength |
| | scale_mask = torch.ones((B, C, H, W), device=x.device) |
| | scale_mask[..., crow - scale_threshold:crow + scale_threshold, ccol - scale_threshold:ccol + scale_threshold] = scaled_scale_value |
| | mask = mask + (scale_mask - mask) * strength |
| | else: |
| | |
| | for scale_params in scales: |
| | if isinstance(scale_params, list): |
| | for scale_tuple in scale_params: |
| | if len(scale_tuple) == 2: |
| | scale_threshold, scale_value = scale_tuple |
| | scaled_scale_value = scale_value * strength |
| | scale_mask = torch.ones((B, C, H, W), device=x.device) |
| | scale_mask[..., crow - scale_threshold:crow + scale_threshold, ccol - scale_threshold:ccol + scale_threshold] = scaled_scale_value |
| | mask = mask + (scale_mask - mask) * strength |
| |
|
| | x_freq = x_freq * mask |
| |
|
| | |
| | x_freq = fft.ifftshift(x_freq, dim=(-2, -1)) |
| | x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real |
| |
|
| | return x_filtered.to(x.dtype) |
| | |
| | return x |
| |
|
| | class WAS_FreeU: |
| | @classmethod |
| | def INPUT_TYPES(s): |
| | return {"required": { |
| | "model": ("MODEL",), |
| | "target_block": (["output_block", "middle_block", "input_block", "all"],), |
| | "multiscale_mode": (list(mscales.keys()),), |
| | "multiscale_strength": ("FLOAT", {"default": 1.0, "max": 1.0, "min": 0, "step": 0.001}), |
| | "slice_b1": ("INT", {"default": 640, "min": 64, "max": 1280, "step": 1}), |
| | "slice_b2": ("INT", {"default": 320, "min": 64, "max": 640, "step": 1}), |
| | "b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.001}), |
| | "b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.001}), |
| | "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.001}), |
| | "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.001}), |
| | }, |
| | "optional": { |
| | "b1_mode": (list(blending_modes.keys()),), |
| | "b1_blend": ("FLOAT", {"default": 1.0, "max": 100, "min": 0, "step": 0.001}), |
| | "b2_mode": (list(blending_modes.keys()),), |
| | "b2_blend": ("FLOAT", {"default": 1.0, "max": 100, "min": 0, "step": 0.001}), |
| | "threshold": ("INT", {"default": 1.0, "max": 10, "min": 1, "step": 1}), |
| | "use_override_scales": (["false", "true"],), |
| | "override_scales": ("STRING", {"default": '''# OVERRIDE SCALES |
| | |
| | # Sharpen |
| | # 10, 1.5''', "multiline": True}), |
| | } |
| | } |
| | |
| | RETURN_TYPES = ("MODEL",) |
| | FUNCTION = "patch" |
| |
|
| | CATEGORY = "_for_testing" |
| |
|
| | def patch(self, model, target_block, multiscale_mode, multiscale_strength, slice_b1, slice_b2, b1, b2, s1, s2, b1_mode="add", b1_blend=1.0, b2_mode="add", b2_blend=1.0, threshold=1.0, use_override_scales="false", override_scales=""): |
| | |
| | min_slice = 64 |
| | max_slice_b1 = 1280 |
| | max_slice_b2 = 640 |
| | slice_b1 = max(min(max_slice_b1, slice_b1), min_slice) |
| | slice_b2 = max(min(min(slice_b1, max_slice_b2), slice_b2), min_slice) |
| | |
| | scales_list = [] |
| | if use_override_scales == "true": |
| | if override_scales.strip() != "": |
| | scales_str = override_scales.strip().splitlines() |
| | for line in scales_str: |
| | if not line.strip().startswith('#') and not line.strip().startswith('!') and not line.strip().startswith('//'): |
| | scale_values = line.split(',') |
| | if len(scale_values) == 2: |
| | scales_list.append((int(scale_values[0]), float(scale_values[1]))) |
| | |
| | if use_override_scales == "true" and not scales_list: |
| | print("No valid override scales found. Using default scale.") |
| | scales_list = None |
| |
|
| | scales = mscales[multiscale_mode] if use_override_scales == "false" else scales_list |
| | |
| | print(f"FreeU Plate Portions: {slice_b1} over {slice_b2}") |
| | print(f"FreeU Multi-Scales: {scales}") |
| | |
| | def block_patch(h, hsp, transformer_options): |
| | if h.shape[1] == 1280: |
| | h_t = h[:,:slice_b1] |
| | h_r = h_t * b1 |
| | h[:,:slice_b1] = blending_modes[b1_mode](h_t, h_r, b1_blend) |
| | hsp = Fourier_filter(hsp, threshold=threshold, scale=s1, scales=scales, strength=multiscale_strength) |
| | if h.shape[1] == 640: |
| | h_t = h[:,:slice_b2] |
| | h_r = h_t * b2 |
| | h[:,:slice_b2] = blending_modes[b2_mode](h_t, h_r, b2_blend) |
| | hsp = Fourier_filter(hsp, threshold=threshold, scale=s2, scales=scales, strength=multiscale_strength) |
| | return h, hsp |
| |
|
| | print(f"Patching {target_block}") |
| | |
| | m = model.clone() |
| | if target_block == "all" or target_block == "output_block": |
| | m.set_model_output_block_patch(block_patch) |
| | if target_block == "all" or target_block == "input_block": |
| | m.set_model_patch(block_patch, "input_block_patch") |
| | if target_block == "all" or target_block == "middle_block": |
| | m.set_model_patch(block_patch, "middle_block_patch") |
| | return (m, ) |
| |
|
| |
|
| | NODE_CLASS_MAPPINGS = { |
| | "FreeU (Advanced)": WAS_FreeU, |
| | } |
| |
|
| | NODE_DISPLAY_NAME_MAPPINGS = { |
| | "FreeU (Advanced)": "FreeU (Advanced Plus)", |
| | } |
| |
|