import re

import torch
from diffusers import WanPipeline
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


@torch.no_grad()
def encode_video(pipe: WanPipeline, video_frames):
    """Encode video frames into normalized Wan VAE latents."""
    video_tensor = pipe.video_processor.preprocess_video(video_frames).to(
        dtype=pipe.dtype, device=pipe.device
    )
    # Take the mode of the latent posterior for a deterministic encode.
    posterior = pipe.vae.encode(video_tensor, return_dict=False)[0]
    z = posterior.mode()
    # Normalize with the per-channel latent statistics stored in the VAE config.
    latents_mean = (
        torch.tensor(pipe.vae.config.latents_mean)
        .view(1, pipe.vae.config.z_dim, 1, 1, 1)
        .to(z.device, z.dtype)
    )
    latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(
        1, pipe.vae.config.z_dim, 1, 1, 1
    ).to(z.device, z.dtype)
    latents = (z - latents_mean) * latents_std
    return latents


@torch.no_grad()
def decode_latents(pipe: WanPipeline, latents):
    """Denormalize latents and decode them back into video frames."""
    latents = latents.to(pipe.vae.dtype)
    latents_mean = (
        torch.tensor(pipe.vae.config.latents_mean)
        .view(1, pipe.vae.config.z_dim, 1, 1, 1)
        .to(latents.device, latents.dtype)
    )
    latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(
        1, pipe.vae.config.z_dim, 1, 1, 1
    ).to(latents.device, latents.dtype)
    # Invert the normalization applied in encode_video before decoding.
    latents = latents / latents_std + latents_mean
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video, output_type="np")
    return video


def name_convert(n: str):
    """Map an original Wan checkpoint parameter name to its diffusers equivalent."""
    # blocks.* attention
    m = re.match(
        r"blocks\.(\d+)\.(self_attn|cross_attn)\.(q|k|v|o|norm_k|norm_q)\.(weight|bias)",
        n,
    )
    if m:
        b, kind, comp, suf = m.groups()
        attn = "attn1" if kind == "self_attn" else "attn2"
        if comp in ("q", "k", "v"):
            return f"blocks.{b}.{attn}.to_{comp}.{suf}"
        if comp == "o":
            return f"blocks.{b}.{attn}.to_out.0.{suf}"
        return f"blocks.{b}.{attn}.{comp}.{suf}"
    # blocks.* ffn
    m = re.match(r"blocks\.(\d+)\.ffn\.(0|2)\.(weight|bias)", n)
    if m:
        b, idx, suf = m.groups()
        if idx == "0":
            return f"blocks.{b}.ffn.net.0.proj.{suf}"
        return f"blocks.{b}.ffn.net.2.{suf}"
    # blocks.* norm3 / modulation
    m = re.match(r"blocks\.(\d+)\.norm3\.(weight|bias)", n)
    if m:
        b, suf = m.groups()
        return f"blocks.{b}.norm2.{suf}"
    m = re.match(r"blocks\.(\d+)\.modulation$", n)
    if m:
        b = m.group(1)
        return f"blocks.{b}.scale_shift_table"
    # patch_embedding
    if n.startswith("patch_embedding."):
        return n
    # text / time embedding
    m = re.match(r"text_embedding\.(0|2)\.(weight|bias)", n)
    if m:
        idx, suf = m.groups()
        lin = "linear_1" if idx == "0" else "linear_2"
        return f"condition_embedder.text_embedder.{lin}.{suf}"
    m = re.match(r"time_embedding\.(0|2)\.(weight|bias)", n)
    if m:
        idx, suf = m.groups()
        lin = "linear_1" if idx == "0" else "linear_2"
        return f"condition_embedder.time_embedder.{lin}.{suf}"
    m = re.match(r"time_projection\.1\.(weight|bias)", n)
    if m:
        suf = m.group(1)
        return f"condition_embedder.time_proj.{suf}"
    # head
    if n == "head.head.weight":
        return "proj_out.weight"
    if n == "head.head.bias":
        return "proj_out.bias"
    if n == "head.modulation":
        return "scale_shift_table"
    return n


def load_vibt_weight(
    transformer, repo_name="Yuanshi/Bridge", weight_path=None, local_path=None
):
    """Load converted checkpoint weights into a diffusers Wan transformer,
    either from a local safetensors file (local_path) or from the Hugging Face
    Hub (repo_name / weight_path)."""
    assert (
        weight_path is not None or local_path is not None
    ), "Either weight_path or local_path must be provided."
    tensors = load_file(local_path or hf_hub_download(repo_name, weight_path))
    # Rename every checkpoint key to the diffusers naming convention.
    new_tensors = {}
    for key, value in tensors.items():
        key = name_convert(key)
        new_tensors[key] = value
    # Copy matching tensors into the transformer's parameters in place.
    for name, param in transformer.named_parameters():
        device, dtype = param.device, param.dtype
        if name in new_tensors:
            assert (
                param.shape == new_tensors[name].shape
            ), f"{name}: {param.shape} != {new_tensors[name].shape}"
            param.data = new_tensors[name].to(device=device, dtype=dtype)
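

# --- Usage sketch (illustrative, not part of the original helpers) ---
# A minimal example of wiring these functions together. The model id
# "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" and the safetensors file name passed as
# weight_path are assumptions; substitute the checkpoint and weight file you
# actually want to load.
if __name__ == "__main__":
    from PIL import Image

    pipe = WanPipeline.from_pretrained(
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
    ).to("cuda")

    # weight_path below is a placeholder file name inside Yuanshi/Bridge.
    load_vibt_weight(pipe.transformer, weight_path="transformer.safetensors")

    # Round-trip a dummy 17-frame clip through the VAE: frames -> latents -> frames.
    frames = [Image.new("RGB", (832, 480)) for _ in range(17)]
    latents = encode_video(pipe, frames)
    video = decode_latents(pipe, latents)
    print(latents.shape, video.shape)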