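"""LoRA Block Weight (LBW) utility.

Converts diffusers-style LoRA key names to compvis-style names, maps each
converted name to a block id (BASE / INxx / M00 / OUTxx), and rescales the
lora_up weights of each block by a user-supplied per-block ratio.
"""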
import os
import re
import torch
from safetensors import safe_open
from safetensors.torch import save_file
import hashlib
from io import BytesIO
import safetensors.torch
from typing import Callable, Union, Optional
re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_compiled = {}
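
# diffusers and compvis use different names for the resnet sub-modules; the
# attention sub-module names are identical in both, so that table stays empty.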
suffix_conversion = {
    "attentions": {},
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    }
}

def convert_diffusers_name_to_compvis(key, is_sd2):
    def match(match_list, regex_text):
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex

        r = re.match(regex, key)
        if not r:
            return False

        match_list.clear()
        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
        return True

    m = []

    if match(m, r"lora_unet_conv_in(.*)"):
        return f'diffusion_model_input_blocks_0_0{m[0]}'

    if match(m, r"lora_unet_conv_out(.*)"):
        return f'diffusion_model_out_2{m[0]}'

    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
        return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"

    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"

    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"

    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
        if is_sd2:
            if 'mlp_fc1' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
            elif 'mlp_fc2' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
            else:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
        if 'mlp_fc1' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
        elif 'mlp_fc2' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
        else:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

    return key
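
# Illustrative conversions (diffusers-style key on the left, compvis-style on the
# right; actual keys depend on the LoRA file):
#   lora_unet_down_blocks_0_attentions_0_proj_in -> diffusion_model_input_blocks_1_1_proj_in
#   lora_unet_mid_block_attentions_0_proj_out    -> diffusion_model_middle_block_1_proj_out
#   lora_te_text_model_encoder_layers_0_mlp_fc1  -> transformer_text_model_encoder_layers_0_mlp_fc1  (is_sd2=False)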

def safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks to
    save time on indexing the model later."""
    # Because writing user metadata to the file can change the result of
    # sd_models.model_hash(), only retain the training metadata for purposes of
    # calculating the hash, as they are meant to be immutable
    metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}

    serialized = safetensors.torch.save(tensors, metadata)
    b = BytesIO(serialized)

    model_hash = addnet_hash_safetensors(b)
    legacy_hash = addnet_hash_legacy(b)
    return model_hash, legacy_hash
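
# Usage sketch (hedged): the computed hashes are typically written back into the
# file's metadata before saving; "sshs_model_hash" / "sshs_legacy_hash" are the
# metadata keys assumed here for what sd-webui-additional-networks reads.
#
#   model_hash, legacy_hash = safetensors_hashes(tensors, metadata)
#   metadata["sshs_model_hash"] = model_hash
#   metadata["sshs_legacy_hash"] = legacy_hash
#   save_file(tensors, path, metadata)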

def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
    m = hashlib.sha256()

    b.seek(0x100000)
    m.update(b.read(0x10000))
    return m.hexdigest()[0:8]

def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files"""
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024
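    # safetensors layout: the first 8 bytes are the JSON header length as a
    # little-endian uint64, so seeking to n + 8 skips the header and hashes only
    # the tensor data; user metadata lives in the header and does not affect the hash.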
    b.seek(0)
    header = b.read(8)
    n = int.from_bytes(header, "little")
    offset = n + 8
    b.seek(offset)
    for chunk in iter(lambda: b.read(blksize), b""):
        hash_sha256.update(chunk)

    return hash_sha256.hexdigest()

def lbw_lora(input_, output, ratios):
    print("Apply LBW")

    assert isinstance(input_, str)
    assert isinstance(output, str)
    assert isinstance(ratios, str)
    assert os.path.exists(input_), f"{input_} does not exist"
    assert not os.path.exists(output), f"{output} already exists"

    LOAD_PATH = input_
    SAVE_PATH = output
    RATIOS = [float(x) for x in ratios.split(",")]

    LAYERS = len(RATIOS)
    assert LAYERS in [17, 26]
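    # The ratio string must supply either 17 values (BASE + the 16 blocks in
    # BLOCKID17 below) or 26 values (BASE + IN00-IN11 + M00 + OUT00-OUT11).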
    BLOCKID17 = [
        "BASE", "IN01", "IN02", "IN04", "IN05", "IN07", "IN08", "M00",
        "OUT03", "OUT04", "OUT05", "OUT06", "OUT07", "OUT08", "OUT09", "OUT10", "OUT11"]
    BLOCKID26 = [
        "BASE", "IN00", "IN01", "IN02", "IN03", "IN04", "IN05", "IN06", "IN07", "IN08", "IN09", "IN10", "IN11", "M00",
        "OUT00", "OUT01", "OUT02", "OUT03", "OUT04", "OUT05", "OUT06", "OUT07", "OUT08", "OUT09", "OUT10", "OUT11"]

    if LAYERS == 17:
        RATIO_OF_ = dict(zip(BLOCKID17, RATIOS))
    if LAYERS == 26:
        RATIO_OF_ = dict(zip(BLOCKID26, RATIOS))
    print(RATIO_OF_)
    PATTERNS = [
        r"^transformer_text_model_(encoder)_layers_(\d+)_.*",
        r"^diffusion_model_(in)put_blocks_(\d+)_.*",
        r"^diffusion_model_(middle)_block_(\d+)_.*",
        r"^diffusion_model_(out)put_blocks_(\d+)_.*"]

    def replacement(match):
        g1 = str(match.group(1))  # encoder, in, middle, out
        g2 = int(match.group(2))  # number
        assert g1 in ["encoder", "in", "middle", "out"]
        assert isinstance(g2, int)
        if g1 == "encoder":
            return "BASE"
        if g1 == "middle":
            return "M00"
        return f"{g1.upper()}{g2:02}"
    def compvis_name_to_blockid(compvis_name):
        strings = compvis_name
        for pattern in PATTERNS:
            strings = re.sub(pattern, replacement, strings)
            if strings != compvis_name:
                break
        assert strings != compvis_name
        blockid = strings
        if LAYERS == 17:
            assert blockid in BLOCKID26, f"Incorrect layer {blockid}"
            assert blockid in BLOCKID17, f"{blockid} is not among the 17 blocks; maybe this LoRA needs 26 ratios?"
        if LAYERS == 26:
            assert blockid in BLOCKID26, f"Incorrect layer {blockid}"
        return blockid
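
    # Illustrative block-id mappings produced by the patterns above:
    #   transformer_text_model_encoder_layers_5_mlp_fc1.alpha                           -> BASE
    #   diffusion_model_input_blocks_7_1_transformer_blocks_0_attn1_to_k.lora_up.weight -> IN07
    #   diffusion_model_middle_block_1_proj_out.lora_up.weight                          -> M00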
    with safe_open(LOAD_PATH, framework="pt", device="cpu") as f:
        tensors = {}
        for key in f.keys():
            tensors[key] = f.get_tensor(key)  # key = diffusers_name
            compvis_name = convert_diffusers_name_to_compvis(key, is_sd2=False)
            blockid = compvis_name_to_blockid(compvis_name)
            if compvis_name.endswith("lora_up.weight"):
                tensors[key] *= RATIO_OF_[blockid]
                print(f"({blockid}) {compvis_name} "
                      f"updated with factor {RATIO_OF_[blockid]}")

    save_file(tensors, SAVE_PATH)
    print("Done")