Update test_stablehairv2.py

test_stablehairv2.py  CHANGED  (+24 −37)
@@ -9,7 +9,7 @@ import cv2
 import torch
 from PIL import Image
 from transformers import AutoTokenizer, CLIPVisionModelWithProjection
-from diffusers import AutoencoderKL, UniPCMultistepScheduler,
+from diffusers import AutoencoderKL, UniPCMultistepScheduler,UNet2DConditionModel
 from src.models.unet_3d import UNet3DConditionModel
 from ref_encoder.reference_unet import CCProjection
 from ref_encoder.latent_controlnet import ControlNetModel
@@ -61,11 +61,10 @@ def _maybe_align_image(image_path: str, output_size: int, prefer_cuda: bool = True):
         raise
     return cv2.resize(img, (output_size, output_size))
 
-
 def log_validation(
-    …
-    …
-    …
+    vae, tokenizer, image_encoder, denoising_unet,
+    args, device, logger, cc_projection,
+    controlnet, hair_encoder, feature_extractor=None
 ):
     """
     Run inference on validation pairs and save generated videos.
@@ -94,21 +93,14 @@ def log_validation(
 
     print(output_dir)
 
-    # …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    print("[cfg]guidance_scale:",gscale)
-    print("[cfg]视频帧数:",vlen)
-    print("[cfg]cframes:",cframes)
-    # Generate camera trajectory with exactly vlen frames
-    angles = np.linspace(0, 2 * np.pi, vlen, endpoint=False)
-    X = 0.4 * np.sin(angles)
-    Y = -0.05 + 0.3 * np.cos(angles)
+    # Generate camera trajectory
+    x_coords = [0.4 * np.sin(2 * np.pi * i / 120) for i in range(60)]
+    y_coords = [-0.05 + 0.3 * np.cos(2 * np.pi * i / 120) for i in range(60)]
+    X = [x_coords[0]]
+    Y = [y_coords[0]]
+    for i in range(20):
+        X.append(x_coords[i * 3 + 2])
+        Y.append(y_coords[i * 3 + 2])
     x_tensor = torch.tensor(X, dtype=torch.float32).unsqueeze(1).to(device)
     y_tensor = torch.tensor(Y, dtype=torch.float32).unsqueeze(1).to(device)
 
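For reference, the new trajectory block can be sanity-checked in isolation: it keeps the first of 60 precomputed points on a 120-step circle and then takes every third point starting at index 2, giving 21 camera poses over roughly half a revolution, which is what the hard-coded video_length=21 further down assumes. A minimal standalone sketch (NumPy only, not part of the commit):

import numpy as np

# Re-create the committed trajectory logic (sketch only).
x_coords = [0.4 * np.sin(2 * np.pi * i / 120) for i in range(60)]        # x sweep over a 120-step circle
y_coords = [-0.05 + 0.3 * np.cos(2 * np.pi * i / 120) for i in range(60)]

X = [x_coords[0]]                    # keep the starting pose
Y = [y_coords[0]]
for i in range(20):                  # then indices 2, 5, ..., 59
    X.append(x_coords[i * 3 + 2])
    Y.append(y_coords[i * 3 + 2])

assert len(X) == len(Y) == 21        # 21 poses -> matches video_length=21 in the pipeline call

The separate context_frames=12 value lines up with the 12-entry id_list built in the next hunk, presumably one condition image per context frame.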
@@ -132,8 +124,8 @@ def log_validation(
     # Re-load the bald ID image (RGB)
     id_image = cv2.cvtColor(cv2.imread(temp_bald_path), cv2.COLOR_BGR2RGB)
     id_image = cv2.resize(id_image, (512, 512))
-    …
-    id_list = [id_image for _ in range(
+
+    id_list = [id_image for _ in range(12)]
     if align_enabled:
         hair_image = _maybe_align_image(args.validation_hairs[0], output_size=align_size, prefer_cuda=prefer_cuda)
         prompt_img = _maybe_align_image(args.validation_ids[0], output_size=align_size, prefer_cuda=prefer_cuda)
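A minimal sketch of what this hunk settles on for the ControlNet condition: the bald identity frame is reloaded as RGB, resized to the 512x512 inference resolution, and replicated 12 times, one copy per context frame. The path below is a hypothetical stand-in for the temp_bald_path the script writes earlier:

import cv2

temp_bald_path = "temp_bald.png"                        # hypothetical path; the script produces this file earlier
id_image = cv2.cvtColor(cv2.imread(temp_bald_path), cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; the pipeline expects RGB
id_image = cv2.resize(id_image, (512, 512))             # match the 512x512 width/height used at inference
id_list = [id_image for _ in range(12)]                 # 12 copies, matching context_frames=12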
@@ -144,17 +136,16 @@ def log_validation(
         prompt_img = cv2.resize(prompt_img, (512, 512))
         hair_image = cv2.resize(hair_image, (512, 512))
         prompt_img = cv2.resize(prompt_img, (512, 512))
-    …
+
     prompt_img = [prompt_img]
 
     # Perform inference and save videos
-    …
     for idx in range(args.num_validation_images):
         result = pipeline(
             prompt="",
             negative_prompt="",
-            num_inference_steps=
-            guidance_scale=
+            num_inference_steps=30,
+            guidance_scale=1.5,
             width=512,
             height=512,
             controlnet_condition=id_list,
@@ -166,8 +157,8 @@ def log_validation(
             poses=None,
             x=x_tensor,
             y=y_tensor,
-            video_length=
-            context_frames=
+            video_length=21,
+            context_frames=12,
         )
         video = torch.cat([result.videos, result.videos], dim=0)
         video_path = os.path.join(output_dir, f"generated_video_{idx}.mp4")
@@ -269,15 +260,13 @@ def main():
     infer_config = OmegaConf.load('./configs/inference/inference_v2.yaml')
 
     unet2 = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", use_safetensors=True, revision=args.revision,
-        torch_dtype=torch.float16
+        args.pretrained_model_name_or_path, subfolder="unet", use_safetensors=True, revision=args.revision, torch_dtype=torch.float16
     ).to(device)
-    conv_in_8 = torch.nn.Conv2d(8, unet2.conv_in.out_channels, kernel_size=unet2.conv_in.kernel_size,
-                                padding=unet2.conv_in.padding)
+    conv_in_8 = torch.nn.Conv2d(8, unet2.conv_in.out_channels, kernel_size=unet2.conv_in.kernel_size, padding=unet2.conv_in.padding)
     conv_in_8.requires_grad_(False)
     unet2.conv_in.requires_grad_(False)
     torch.nn.init.zeros_(conv_in_8.weight)
-    conv_in_8.weight[
+    conv_in_8.weight[:,:4,:,:].copy_(unet2.conv_in.weight)
     conv_in_8.bias.copy_(unet2.conv_in.bias)
     unet2.conv_in = conv_in_8
 
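This hunk mostly reflows the conv_in surgery onto single lines, but the underlying pattern is worth spelling out: the pretrained 4-channel conv_in is swapped for an 8-channel one whose extra input channels start as a no-op. A self-contained sketch of that inflation pattern on a toy convolution (not the actual UNet layer, and using torch.no_grad() where the script toggles requires_grad_ instead):

import torch

# Toy stand-in for unet2.conv_in: 4 latent channels in, 320 features out.
old_conv = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)

# New input layer accepting 8 channels, same output channels / kernel / padding.
new_conv = torch.nn.Conv2d(8, old_conv.out_channels,
                           kernel_size=old_conv.kernel_size,
                           padding=old_conv.padding)

with torch.no_grad():
    torch.nn.init.zeros_(new_conv.weight)                # extra channels start as a no-op
    new_conv.weight[:, :4, :, :].copy_(old_conv.weight)  # keep the pretrained behaviour on the first 4 channels
    new_conv.bias.copy_(old_conv.bias)

# Sanity check: on the first 4 channels the new layer reproduces the old one.
x = torch.randn(1, 8, 64, 64)
assert torch.allclose(new_conv(x), old_conv(x[:, :4]), atol=1e-5)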
@@ -308,12 +297,11 @@ def main():
 
     from ref_encoder.reference_unet import ref_unet
     Hair_Encoder = ref_unet.from_pretrained(
-        …
-        device_map=None, ignore_mismatched_sizes=True
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False, device_map=None, ignore_mismatched_sizes=True
     ).to(device)
 
     state_dict2 = torch.load(os.path.join(args.model_path, "pytorch_model_2.bin"), map_location=torch.device('cpu'))
-    # …
+    #state_dict2 = torch.load(os.path.join('/home/jichao.zhang/code/3dhair/train_sv3d/checkpoint-30000/', "pytorch_model.bin"))
     Hair_Encoder.load_state_dict(state_dict2, strict=False)
 
     # Run validation inference
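Because the checkpoint is restored with strict=False, missing or unexpected keys are tolerated silently rather than raising. When that needs checking, load_state_dict returns those key lists; a small sketch with placeholder names (load_partial is not part of the script):

import torch

def load_partial(module: torch.nn.Module, ckpt_path: str) -> None:
    """Load a checkpoint without requiring an exact key match, then report what was skipped."""
    state_dict = torch.load(ckpt_path, map_location="cpu")
    result = module.load_state_dict(state_dict, strict=False)
    print("missing keys:", len(result.missing_keys))        # parameters left at their current values
    print("unexpected keys:", len(result.unexpected_keys))  # checkpoint entries that were ignored

# e.g. load_partial(Hair_Encoder, os.path.join(args.model_path, "pytorch_model_2.bin"))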
@@ -323,6 +311,5 @@ def main():
         cc_projection, controlnet, Hair_Encoder
     )
 
-
 if __name__ == "__main__":
     main()