# SwarmComfyCommon/SwarmExtractLora.py
import comfy.model_management
import comfy.utils
import safetensors.torch
import torch, json
# ATTRIBUTION: This code is a mix of code from kohya-ss, comfy, and Swarm. It would be annoying to disentangle but it's all FOSS and relatively short so it's fine.
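# Extracted factor values get clamped symmetrically at this quantile of their joint
# value distribution, so rare outliers don't dominate the fp16 output.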
CLAMP_QUANTILE = 0.99
def extract_lora(diff, rank):
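    """Factor a weight delta into a rank-limited (up, down) pair via truncated SVD.

    `diff` is `other - base` for a single weight tensor. 1x1 convs are squeezed down
    to plain matrices; 3x3 convs are flattened so the kernel dims fold into Vh.
    """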
    conv2d = (len(diff.shape) == 4)
    kernel_size = None if not conv2d else diff.size()[2:4]
    conv2d_3x3 = conv2d and kernel_size != (1, 1)
    out_dim, in_dim = diff.size()[0:2]
    rank = min(rank, in_dim, out_dim)
    if conv2d:
        if conv2d_3x3:
            diff = diff.flatten(start_dim=1)
        else:
            diff = diff.squeeze()
    # Truncated SVD: keep only the top-`rank` singular values/vectors of the delta.
    U, S, Vh = torch.linalg.svd(diff.float())
    U = U[:, :rank]
    S = S[:rank]
    U = U @ torch.diag(S)
    Vh = Vh[:rank, :]
    # Symmetric clamp at the CLAMP_QUANTILE percentile to tame outliers before fp16 storage.
    dist = torch.cat([U.flatten(), Vh.flatten()])
    hi_val = torch.quantile(dist, CLAMP_QUANTILE)
    low_val = -hi_val
    U = U.clamp(low_val, hi_val)
    Vh = Vh.clamp(low_val, hi_val)
    if conv2d:
        U = U.reshape(out_dim, rank, 1, 1)
        Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
    return (U, Vh)
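# Rough usage sketch, not part of the node flow: for a plain 2D (linear) weight delta,
# the product of the returned factors approximates the delta, up to rank truncation
# and quantile clamping:
#   up, down = extract_lora(diff, rank=16)
#   approx_diff = up @ down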
def do_lora_handle(base_data, other_data, rank, prefix, require, do_bias, callback):
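    """Diff two state dicts and emit kohya-style LoRA tensors for the changed keys.

    `prefix` names the output keys (eg "unet"); `require`, if set, filters input keys
    to that namespace and strips it; `do_bias` also emits ".diff_b" bias deltas;
    `callback` is invoked once per key for progress reporting.
    """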
    out_data = {}
    device = comfy.model_management.get_torch_device()
    for key in base_data.keys():
        callback()
        if key not in other_data:
            continue
        base_tensor = base_data[key].float()
        other_tensor = other_data[key].float()
        # Remap SDXL-style dual text encoder key prefixes onto numeric encoder indices.
        if key.startswith("clip_g"):
            key = "1." + key[len("clip_g."):]
        elif key.startswith("clip_l"):
            key = "0." + key[len("clip_l."):]
        if require:
            if not key.startswith(require):
                print(f"Ignore unmatched key {key} (doesn't match {require})")
                continue
            key = key[len(require):]
        if base_tensor.shape != other_tensor.shape:
            continue
        diff = other_tensor.to(device) - base_tensor.to(device)
        other_tensor = other_tensor.cpu()
        base_tensor = base_tensor.cpu()
        max_diff = float(diff.abs().max())
        if max_diff < 1e-5:
            print(f"discard unaltered key {key} ({max_diff})")
            continue
        if key.endswith(".weight"):
            fixed_key = key[:-len(".weight")].replace('.', '_')
            name = f"lora_{prefix}_{fixed_key}"
            if len(base_tensor.shape) >= 2:
                print(f"extract key {name} ({max_diff})")
                out = extract_lora(diff, rank)
                out_data[f"{name}.lora_up.weight"] = out[0].contiguous().half().cpu()
                out_data[f"{name}.lora_down.weight"] = out[1].contiguous().half().cpu()
            else:
                print(f"ignore valid raw pass-through key {name} ({max_diff})")
                #out_data[name] = other_tensor.contiguous().half().cpu()
        elif key.endswith(".bias") and do_bias:
            fixed_key = key[:-len(".bias")].replace('.', '_')
            name = f"lora_{prefix}_{fixed_key}"
            print(f"extract bias key {name} ({max_diff})")
            out_data[f"{name}.diff_b"] = diff.contiguous().half().cpu()
    return out_data
class SwarmExtractLora:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "base_model": ("MODEL", ),
                "base_model_clip": ("CLIP", ),
                "other_model": ("MODEL", ),
                "other_model_clip": ("CLIP", ),
                "rank": ("INT", {"default": 16, "min": 1, "max": 320}),
                "save_rawpath": ("STRING", {"multiline": False}),
                "save_filename": ("STRING", {"multiline": False}),
                "save_clip": ("BOOLEAN", {"default": True}),
                "metadata": ("STRING", {"multiline": True}),
            }
        }

    CATEGORY = "SwarmUI/models"
    RETURN_TYPES = ()
    FUNCTION = "extract_lora"
    OUTPUT_NODE = True
    DESCRIPTION = "Internal node, do not use directly - extracts a LoRA from the difference between two models. This is used by the SwarmUI Utilities tab."
    def extract_lora(self, base_model, base_model_clip, other_model, other_model_clip, rank, save_rawpath, save_filename, save_clip, metadata):
        base_data = base_model.model_state_dict()
        other_data = other_model.model_state_dict()
        key_count = len(base_data.keys())
        if save_clip:
            key_count += len(base_model_clip.get_sd().keys())
        pbar = comfy.utils.ProgressBar(key_count)
        class Helper:
            steps = 0
            def callback(self):
                self.steps += 1
                pbar.update_absolute(self.steps, key_count, None)
        helper = Helper()
        out_data = do_lora_handle(base_data, other_data, rank, "unet", "diffusion_model.", True, helper.callback)
        if save_clip:
            # TODO: CLIP keys get wonky, this probably doesn't work? Model-arch-dependent.
            # The two passes cover dual text encoders (eg SDXL's CLIP-L at "0." and CLIP-G at "1.").
            out_data.update(do_lora_handle(base_model_clip.get_sd(), other_model_clip.get_sd(), rank, "te_text_model_encoder_layers", "0.transformer.text_model.encoder.layers.", False, helper.callback))
            out_data.update(do_lora_handle(base_model_clip.get_sd(), other_model_clip.get_sd(), rank, "te2_text_model_encoder_layers", "1.transformer.text_model.encoder.layers.", False, helper.callback))
        # Can't easily autodetect all the correct modelspec info, but at least supply some basics
        out_metadata = {
            "modelspec.title": f"(Extracted LoRA) {save_filename}",
            "modelspec.description": "LoRA extracted in SwarmUI",
        }
        if metadata:
            out_metadata.update(json.loads(metadata))
        path = f"{save_rawpath}{save_filename}.safetensors"
        print(f"saving to path {path}")
        safetensors.torch.save_file(out_data, path, metadata=out_metadata)
        return ()
NODE_CLASS_MAPPINGS = {
"SwarmExtractLora": SwarmExtractLora,
}
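# Rough standalone sketch with hypothetical paths: the core diff logic only needs two
# plain state dicts, so something like this can run on raw checkpoint files too (note
# the `require` key prefix varies by model architecture and checkpoint layout):
#   base = safetensors.torch.load_file("base.safetensors")
#   tuned = safetensors.torch.load_file("tuned.safetensors")
#   lora = do_lora_handle(base, tuned, 16, "unet", "model.diffusion_model.", True, lambda: None)
#   safetensors.torch.save_file(lora, "extracted_lora.safetensors")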