import torch
import os
from .resampler import Resampler

import contextlib
import comfy.utils
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention
from comfy.clip_vision import clip_preprocess

CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
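
# Output dimension of every patched attn2 to_k/to_v projection, in patch order
# (two entries per attention layer: k then v). SD_V12_CHANNELS covers SD v1.x/v2.x,
# SD_XL_CHANNELS covers SDXL; To_KV below builds one Linear layer per entry.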
SD_V12_CHANNELS = [320] * 4 + [640] * 4 + [1280] * 4 + [1280] * 6 + [640] * 6 + [320] * 6 + [1280] * 2
SD_XL_CHANNELS = [640] * 8 + [1280] * 40 + [1280] * 60 + [640] * 12 + [1280] * 20


def get_file_list(path):
    return [f for f in os.listdir(path) if f.endswith('.bin') or f.endswith('.safetensors')]


def set_model_patch_replace(model, patch_kwargs, key):
    to = model.model_options["transformer_options"]
    if "patches_replace" not in to:
        to["patches_replace"] = {}
    if "attn2" not in to["patches_replace"]:
        to["patches_replace"]["attn2"] = {}
    if key not in to["patches_replace"]["attn2"]:
        patch = CrossAttentionPatch(**patch_kwargs)
        to["patches_replace"]["attn2"][key] = patch
    else:
        to["patches_replace"]["attn2"][key].set_new_condition(**patch_kwargs)
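
# For illustration, after patching the transformer options end up shaped roughly
# like this (key shown for SD1.x; SDXL keys also carry a transformer index):
#   model.model_options["transformer_options"]["patches_replace"]["attn2"][("input", 1)]
#       -> CrossAttentionPatch (further adapters are appended via set_new_condition)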


def load_ipadapter(ckpt_path):
    model = comfy.utils.load_torch_file(ckpt_path, safe_load=True)

    if ckpt_path.lower().endswith(".safetensors"):
        st_model = {"image_proj": {}, "ip_adapter": {}}
        for key in model.keys():
            if key.startswith("image_proj."):
                st_model["image_proj"][key.replace("image_proj.", "")] = model[key]
            elif key.startswith("ip_adapter."):
                st_model["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]

        # rebuild the ip_adapter dict with its keys in numeric layer order, since
        # To_KV.load_state_dict assigns weights by iteration order
        model = {"image_proj": st_model["image_proj"], "ip_adapter": {}}
        sorted_keys = sorted(st_model["ip_adapter"].keys(), key=lambda x: int(x.split(".")[0]))
        for key in sorted_keys:
            model["ip_adapter"][key] = st_model["ip_adapter"][key]
        st_model = None

    if "ip_adapter" not in model.keys() or not model["ip_adapter"]:
        raise Exception("invalid IPAdapter model {}".format(ckpt_path))

    return model
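
# Expected layout of a loaded IP-Adapter checkpoint (key names as consumed below):
#   {"image_proj": {...},   # weights for ImageProjModel or Resampler
#    "ip_adapter": {"1.to_k_ip.weight": ..., "1.to_v_ip.weight": ..., ...}}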


class ImageProjModel(torch.nn.Module):
    """Projection Model"""
    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens
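
# Shape sketch for ImageProjModel.forward:
#   image_embeds: (batch, clip_embeddings_dim)
#   -> proj:      (batch, clip_extra_context_tokens * cross_attention_dim)
#   -> reshape:   (batch, clip_extra_context_tokens, cross_attention_dim)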


class To_KV(torch.nn.Module):
    def __init__(self, cross_attention_dim):
        super().__init__()

        channels = SD_XL_CHANNELS if cross_attention_dim == 2048 else SD_V12_CHANNELS
        self.to_kvs = torch.nn.ModuleList([torch.nn.Linear(cross_attention_dim, channel, bias=False) for channel in channels])

    def load_state_dict(self, state_dict):
        # intentionally shadows nn.Module.load_state_dict: weights are assigned
        # purely by position, not by parameter name
        for i, key in enumerate(state_dict.keys()):
            self.to_kvs[i].weight.data = state_dict[key]
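
# Indexing convention relied on by CrossAttentionPatch below: for patch number n,
# to_kvs[2 * n] is the to_k projection and to_kvs[2 * n + 1] the matching to_v.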


class IPAdapterModel(torch.nn.Module):
    def __init__(self, state_dict, plus, cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4, sdxl_plus=False):
        super().__init__()
        self.plus = plus
        if self.plus:
            self.image_proj_model = Resampler(
                dim=1280 if sdxl_plus else cross_attention_dim,
                depth=4,
                dim_head=64,
                heads=20 if sdxl_plus else 12,
                num_queries=clip_extra_context_tokens,
                embedding_dim=clip_embeddings_dim,
                output_dim=cross_attention_dim,
                ff_mult=4
            )
        else:
            self.image_proj_model = ImageProjModel(
                cross_attention_dim=cross_attention_dim,
                clip_embeddings_dim=clip_embeddings_dim,
                clip_extra_context_tokens=clip_extra_context_tokens
            )

        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        self.ip_layers = To_KV(cross_attention_dim)
        self.ip_layers.load_state_dict(state_dict["ip_adapter"])

    @torch.inference_mode()
    def get_image_embeds(self, cond, uncond):
        image_prompt_embeds = self.image_proj_model(cond)
        uncond_image_prompt_embeds = self.image_proj_model(uncond)
        return image_prompt_embeds, uncond_image_prompt_embeds
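
# "plus" checkpoints project the penultimate CLIP vision hidden states through the
# Resampler, while standard checkpoints project the pooled image embedding; either
# way the result is clip_extra_context_tokens extra tokens per image for attn2.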


class IPAdapter:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "image": ("IMAGE", ),
                "clip_vision": ("CLIP_VISION", ),
                "weight": ("FLOAT", {
                    "default": 1,
                    "min": -1,
                    "max": 3,
                    "step": 0.05
                }),
                "model_name": (get_file_list(os.path.join(CURRENT_DIR, "models")), ),
                "dtype": (["fp16", "fp32"], ),
            },
            "optional": {
                "mask": ("MASK",),
            }
        }

    RETURN_TYPES = ("MODEL", "CLIP_VISION_OUTPUT")
    FUNCTION = "adapter"
    CATEGORY = "loaders"

    def adapter(self, model, image, clip_vision, weight, model_name, dtype, mask=None):
        device = comfy.model_management.get_torch_device()
        # fp16 is not supported on mps, so fall back to fp32 there
        self.dtype = torch.float32 if dtype == "fp32" or device.type == "mps" else torch.float16
        self.weight = weight

        ip_state_dict = load_ipadapter(os.path.join(CURRENT_DIR, "models", model_name))
        self.plus = "latents" in ip_state_dict["image_proj"]

        # infer the model family from the width of the first to_k projection
        self.cross_attention_dim = ip_state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[1]
        self.sdxl = self.cross_attention_dim == 2048
        self.sdxl_plus = self.sdxl and self.plus

        if self.plus:
            self.clip_extra_context_tokens = ip_state_dict["image_proj"]["latents"].shape[1]
        else:
            self.clip_extra_context_tokens = ip_state_dict["image_proj"]["proj.weight"].shape[0] // self.cross_attention_dim

        cond, uncond, outputs = self.clip_vision_encode(clip_vision, image, self.plus)
        self.clip_embeddings_dim = cond.shape[-1]

        self.ipadapter = IPAdapterModel(
            ip_state_dict,
            plus=self.plus,
            cross_attention_dim=self.cross_attention_dim,
            clip_embeddings_dim=self.clip_embeddings_dim,
            clip_extra_context_tokens=self.clip_extra_context_tokens,
            sdxl_plus=self.sdxl_plus
        )

        self.ipadapter.to(device, dtype=self.dtype)

        self.image_emb, self.uncond_image_emb = self.ipadapter.get_image_embeds(cond.to(device, dtype=self.dtype), uncond.to(device, dtype=self.dtype))
        self.image_emb = self.image_emb.to(device, dtype=self.dtype)
        self.uncond_image_emb = self.uncond_image_emb.to(device, dtype=self.dtype)

        self.cond_uncond_image_emb = None

        new_model = model.clone()

        if mask is not None:
            mask = mask.squeeze().to(device)

        '''
        patch key for SD v1.x/v2.x: ("input" | "output" | "middle", block_id)
        patch key for SDXL: ("input" | "output" | "middle", block_id, transformer_index)
        '''
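        # e.g. ("input", 1) addresses an SD1.x input-block attn2, while
        # ("output", 2, 7) addresses the 8th transformer block of SDXL output block 2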
        patch_kwargs = {
            "number": 0,
            "weight": self.weight,
            "ipadapter": self.ipadapter,
            "dtype": self.dtype,
            "cond": self.image_emb,
            "uncond": self.uncond_image_emb,
            "mask": mask
        }

        if not self.sdxl:
            for id in [1, 2, 4, 5, 7, 8]:
                set_model_patch_replace(new_model, patch_kwargs, ("input", id))
                patch_kwargs["number"] += 1
            for id in [3, 4, 5, 6, 7, 8, 9, 10, 11]:
                set_model_patch_replace(new_model, patch_kwargs, ("output", id))
                patch_kwargs["number"] += 1
            set_model_patch_replace(new_model, patch_kwargs, ("middle", 0))
        else:
            for id in [4, 5, 7, 8]:  # input blocks that contain cross attention
                block_indices = range(2) if id in [4, 5] else range(10)  # transformer depth
                for index in block_indices:
                    set_model_patch_replace(new_model, patch_kwargs, ("input", id, index))
                    patch_kwargs["number"] += 1
            for id in range(6):  # output blocks that contain cross attention
                block_indices = range(2) if id in [3, 4, 5] else range(10)
                for index in block_indices:
                    set_model_patch_replace(new_model, patch_kwargs, ("output", id, index))
                    patch_kwargs["number"] += 1
            for index in range(10):
                set_model_patch_replace(new_model, patch_kwargs, ("middle", 0, index))
                patch_kwargs["number"] += 1
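        # sanity check on the numbering: SD1.x patches 6 + 9 + 1 = 16 attn2 layers
        # (len(SD_V12_CHANNELS) == 32 == 16 k/v pairs) and SDXL patches 24 + 36 + 10 = 70
        # (len(SD_XL_CHANNELS) == 140), so "number" stays aligned with To_KV.to_kvs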

        return (new_model, outputs)

    def clip_vision_encode(self, clip_vision, image, plus=False):
        inputs = clip_preprocess(image)
        comfy.model_management.load_model_gpu(clip_vision.patcher)
        pixel_values = inputs.to(clip_vision.load_device)

        if clip_vision.dtype != torch.float32:
            precision_scope = torch.autocast
        else:
            precision_scope = lambda a, b: contextlib.nullcontext(a)

        with precision_scope(comfy.model_management.get_autocast_device(clip_vision.load_device), torch.float32):
            outputs = clip_vision.model(pixel_values=pixel_values, output_hidden_states=True)

        if plus:
            # "plus" models use the penultimate hidden states; the negative embedding
            # comes from encoding an all-zero image
            cond = outputs.hidden_states[-2]
            with precision_scope(comfy.model_management.get_autocast_device(clip_vision.load_device), torch.float32):
                uncond = clip_vision.model(torch.zeros_like(pixel_values), output_hidden_states=True).hidden_states[-2]
        else:
            cond = outputs.image_embeds
            uncond = torch.zeros_like(cond)

        # drop the hidden states and move everything else back to the CPU
        for k in outputs:
            t = outputs[k]
            if k == "hidden_states":
                outputs[k] = None
            elif t is not None:
                outputs[k] = t.cpu()
        return cond, uncond, outputs
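
# Decoupled cross-attention as in IP-Adapter: at each patched layer the text attention
# and the image attention are computed separately and summed, i.e. roughly
#   out = Attention(q, k_text, v_text) + weight * Attention(q, k_image, v_image)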


class CrossAttentionPatch:
    def __init__(self, weight, ipadapter, dtype, number, cond, uncond, mask=None):
        self.weights = [weight]
        self.ipadapters = [ipadapter]
        self.conds = [cond]
        self.unconds = [uncond]
        self.dtype = dtype
        self.number = number
        self.masks = [mask]

    def set_new_condition(self, weight, ipadapter, cond, uncond, dtype, number, mask=None):
        self.weights.append(weight)
        self.ipadapters.append(ipadapter)
        self.conds.append(cond)
        self.unconds.append(uncond)
        self.masks.append(mask)
        self.dtype = dtype

    def __call__(self, n, context_attn2, value_attn2, extra_options):
        org_dtype = n.dtype
        cond_or_uncond = extra_options["cond_or_uncond"]
        original_shape = (extra_options["original_shape"][2], extra_options["original_shape"][3])
        with torch.autocast("cuda", dtype=self.dtype):
            q = n
            k = context_attn2
            v = value_attn2
            b, _, _ = q.shape
            batch_prompt = b // len(cond_or_uncond)
            out = optimized_attention(q, k, v, extra_options["n_heads"])

            for weight, cond, uncond, ipadapter, mask in zip(self.weights, self.conds, self.unconds, self.ipadapters, self.masks):
                # to_kvs[2n] is to_k and to_kvs[2n+1] is to_v for patch number n
                k_cond = ipadapter.ip_layers.to_kvs[self.number * 2](cond).repeat(batch_prompt, 1, 1)
                k_uncond = ipadapter.ip_layers.to_kvs[self.number * 2](uncond).repeat(batch_prompt, 1, 1)
                v_cond = ipadapter.ip_layers.to_kvs[self.number * 2 + 1](cond).repeat(batch_prompt, 1, 1)
                v_uncond = ipadapter.ip_layers.to_kvs[self.number * 2 + 1](uncond).repeat(batch_prompt, 1, 1)

                # interleave cond/uncond to match the order ComfyUI batched them in
                ip_k = torch.cat([(k_cond, k_uncond)[i] for i in cond_or_uncond], dim=0)
                ip_v = torch.cat([(v_cond, v_uncond)[i] for i in cond_or_uncond], dim=0)

                ip_k = ip_k.to(dtype=q.dtype)
                ip_v = ip_v.to(dtype=q.dtype)

                ip_out = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"])

                if mask is not None:
                    # infer how much the latent was downsampled at this layer from the
                    # number of query tokens, then resize the mask to match
                    if original_shape[0] * original_shape[1] == q.shape[1]:
                        down_sample_rate = 1
                    elif (original_shape[0] // 2) * (original_shape[1] // 2) == q.shape[1]:
                        down_sample_rate = 2
                    elif (original_shape[0] // 4) * (original_shape[1] // 4) == q.shape[1]:
                        down_sample_rate = 4
                    else:
                        down_sample_rate = 8
                    mask_downsample = torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(original_shape[0] // down_sample_rate, original_shape[1] // down_sample_rate), mode="nearest").squeeze(0)
                    mask_downsample = mask_downsample.view(1, -1, 1).repeat(out.shape[0], 1, out.shape[2])
                    ip_out = ip_out * mask_downsample

                out = out + ip_out * weight

        return out.to(dtype=org_dtype)
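
# Minimal usage sketch (hypothetical values; inside a ComfyUI workflow this node is
# driven by the graph rather than called directly):
#   node = IPAdapter()
#   patched_model, clip_vision_output = node.adapter(
#       model, image, clip_vision, weight=1.0,
#       model_name="ip-adapter_sd15.bin", dtype="fp16")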