ouclxy committed
Commit df5a87b · verified · Parent(s): 2cc7422

Update test_stablehairv2.py

Files changed (1): test_stablehairv2.py (+24 -37)

test_stablehairv2.py CHANGED
@@ -9,7 +9,7 @@ import cv2
 import torch
 from PIL import Image
 from transformers import AutoTokenizer, CLIPVisionModelWithProjection
-from diffusers import AutoencoderKL, UniPCMultistepScheduler, UNet2DConditionModel
+from diffusers import AutoencoderKL, UniPCMultistepScheduler,UNet2DConditionModel
 from src.models.unet_3d import UNet3DConditionModel
 from ref_encoder.reference_unet import CCProjection
 from ref_encoder.latent_controlnet import ControlNetModel
@@ -61,11 +61,10 @@ def _maybe_align_image(image_path: str, output_size: int, prefer_cuda: bool = Tr
         raise
     return cv2.resize(img, (output_size, output_size))

-
 def log_validation(
-    vae, tokenizer, image_encoder, denoising_unet,
-    args, device, logger, cc_projection,
-    controlnet, hair_encoder, feature_extractor=None
+    vae, tokenizer, image_encoder, denoising_unet,
+    args, device, logger, cc_projection,
+    controlnet, hair_encoder, feature_extractor=None
 ):
     """
     Run inference on validation pairs and save generated videos.
@@ -94,21 +93,14 @@ def log_validation(
 
     print(output_dir)
 
-    # Speed/length overrides via env/args
-    import os as _os
-    steps = int(_os.getenv('SH_STEPS', getattr(args, 'num_inference_steps', 30)))
-    gscale = float(_os.getenv('SH_GUIDANCE', getattr(args, 'guidance_scale', 1.5)))
-    vlen = int(_os.getenv('SH_VIDEO_LENGTH', getattr(args, 'video_length', 21)))
-    # Unified temporal length: the context-frame count always equals the video-frame count (SH_CONTEXT_FRAMES is no longer read)
-    cframes = int(_os.getenv('SH_CFRAMES', getattr(args, 'cframes', 12)))
-    print("[cfg] inference steps:", steps)
-    print("[cfg] guidance_scale:", gscale)
-    print("[cfg] video frames:", vlen)
-    print("[cfg] cframes:", cframes)
-    # Generate camera trajectory with exactly vlen frames
-    angles = np.linspace(0, 2 * np.pi, vlen, endpoint=False)
-    X = 0.4 * np.sin(angles)
-    Y = -0.05 + 0.3 * np.cos(angles)
+    # Generate camera trajectory
+    x_coords = [0.4 * np.sin(2 * np.pi * i / 120) for i in range(60)]
+    y_coords = [-0.05 + 0.3 * np.cos(2 * np.pi * i / 120) for i in range(60)]
+    X = [x_coords[0]]
+    Y = [y_coords[0]]
+    for i in range(20):
+        X.append(x_coords[i * 3 + 2])
+        Y.append(y_coords[i * 3 + 2])
     x_tensor = torch.tensor(X, dtype=torch.float32).unsqueeze(1).to(device)
     y_tensor = torch.tensor(Y, dtype=torch.float32).unsqueeze(1).to(device)
 
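The added trajectory block samples 21 points from a 120-step circle: frame 0, then every third index starting at 2. A compact numpy equivalent, shown only as a sketch (radii and centre are taken from the hunk above, everything else is assumed):

import numpy as np

# Sketch only: reproduce the 21 (x, y) camera offsets built by the loop above.
# Sample indices are 0, 2, 5, 8, ..., 59 on a 120-step circle.
idx = np.concatenate(([0], np.arange(20) * 3 + 2))
theta = 2 * np.pi * idx / 120
X = 0.4 * np.sin(theta)             # horizontal offset, radius 0.4
Y = -0.05 + 0.3 * np.cos(theta)     # vertical offset, radius 0.3, centred at -0.05
print(X.shape, Y.shape)             # (21,) (21,)

The 21 offsets line up with the video_length=21 hard-coded later in this commit.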
@@ -132,8 +124,8 @@ def log_validation(
     # Reload the bald and hairstyle images (RGB)
     id_image = cv2.cvtColor(cv2.imread(temp_bald_path), cv2.COLOR_BGR2RGB)
     id_image = cv2.resize(id_image, (512, 512))
-
-    id_list = [id_image for _ in range(cframes)]
+
+    id_list = [id_image for _ in range(12)]
     if align_enabled:
         hair_image = _maybe_align_image(args.validation_hairs[0], output_size=align_size, prefer_cuda=prefer_cuda)
         prompt_img = _maybe_align_image(args.validation_ids[0], output_size=align_size, prefer_cuda=prefer_cuda)
@@ -144,17 +136,16 @@ def log_validation(
         prompt_img = cv2.resize(prompt_img, (512, 512))
         hair_image = cv2.resize(hair_image, (512, 512))
         prompt_img = cv2.resize(prompt_img, (512, 512))
-
+
     prompt_img = [prompt_img]
 
     # Perform inference and save videos
-
     for idx in range(args.num_validation_images):
         result = pipeline(
             prompt="",
             negative_prompt="",
-            num_inference_steps=steps,
-            guidance_scale=gscale,
+            num_inference_steps=30,
+            guidance_scale=1.5,
             width=512,
             height=512,
             controlnet_condition=id_list,
@@ -166,8 +157,8 @@ def log_validation(
             poses=None,
             x=x_tensor,
             y=y_tensor,
-            video_length=vlen,
-            context_frames=cframes,
+            video_length=21,
+            context_frames=12,
         )
         video = torch.cat([result.videos, result.videos], dim=0)
         video_path = os.path.join(output_dir, f"generated_video_{idx}.mp4")
@@ -269,15 +260,13 @@ def main():
     infer_config = OmegaConf.load('./configs/inference/inference_v2.yaml')
 
     unet2 = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", use_safetensors=True, revision=args.revision,
-        torch_dtype=torch.float16
+        args.pretrained_model_name_or_path, subfolder="unet", use_safetensors=True, revision=args.revision, torch_dtype=torch.float16
     ).to(device)
-    conv_in_8 = torch.nn.Conv2d(8, unet2.conv_in.out_channels, kernel_size=unet2.conv_in.kernel_size,
-                                padding=unet2.conv_in.padding)
+    conv_in_8 = torch.nn.Conv2d(8, unet2.conv_in.out_channels, kernel_size=unet2.conv_in.kernel_size, padding=unet2.conv_in.padding)
     conv_in_8.requires_grad_(False)
     unet2.conv_in.requires_grad_(False)
     torch.nn.init.zeros_(conv_in_8.weight)
-    conv_in_8.weight[:, :4, :, :].copy_(unet2.conv_in.weight)
+    conv_in_8.weight[:,:4,:,:].copy_(unet2.conv_in.weight)
     conv_in_8.bias.copy_(unet2.conv_in.bias)
     unet2.conv_in = conv_in_8
 
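The conv_in lines joined above implement a common channel-expansion trick: an 8-channel input layer reuses the pretrained 4-channel kernel and zero-initialises the rest, so the modified UNet initially ignores the extra latents. A minimal standalone sketch, with layer sizes assumed (320 output channels, 3x3 kernel, as in a typical SD UNet) rather than read from this repository:

import torch

# Stand-in for unet2.conv_in (sizes assumed, not taken from the checkpoint).
old = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
new = torch.nn.Conv2d(8, old.out_channels, kernel_size=old.kernel_size, padding=old.padding)
with torch.no_grad():
    new.weight.zero_()                     # extra input channels start at zero
    new.weight[:, :4].copy_(old.weight)    # first 4 channels keep pretrained weights
    new.bias.copy_(old.bias)
# Feeding zeros in the extra 4 channels now reproduces the original layer's output.

In the script itself the kernel size and padding are read from the loaded UNet, as the hunk shows; only the formatting of those lines changes in this commit.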
@@ -308,12 +297,11 @@ def main():
 
     from ref_encoder.reference_unet import ref_unet
     Hair_Encoder = ref_unet.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False,
-        device_map=None, ignore_mismatched_sizes=True
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False, device_map=None, ignore_mismatched_sizes=True
     ).to(device)
 
     state_dict2 = torch.load(os.path.join(args.model_path, "pytorch_model_2.bin"), map_location=torch.device('cpu'))
-    # state_dict2 = torch.load(os.path.join('/home/jichao.zhang/code/3dhair/train_sv3d/checkpoint-30000/', "pytorch_model.bin"))
+    #state_dict2 = torch.load(os.path.join('/home/jichao.zhang/code/3dhair/train_sv3d/checkpoint-30000/', "pytorch_model.bin"))
     Hair_Encoder.load_state_dict(state_dict2, strict=False)
 
     # Run validation inference
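The checkpoint is loaded with strict=False (and the encoder is built with ignore_mismatched_sizes=True), so keys that do not line up are skipped instead of raising. A tiny self-contained illustration of that behaviour, using a made-up module and state dict rather than the real checkpoint:

import torch

# Sketch only: strict=False tolerates missing and unexpected keys.
model = torch.nn.Linear(4, 2)
ckpt = {"weight": torch.zeros(2, 4), "extra_key": torch.zeros(1)}   # no "bias", one stray key
result = model.load_state_dict(ckpt, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # ['extra_key']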
@@ -323,6 +311,5 @@ def main():
         cc_projection, controlnet, Hair_Encoder
     )
 
-
 if __name__ == "__main__":
     main()