ouclxy committed (verified)
Commit d5a352b · Parent(s): 8924b08

Update gradio_app.py

Files changed (1)
  1. gradio_app.py +35 -0
gradio_app.py CHANGED

@@ -11,6 +11,7 @@ import spaces
 import torch
 import cv2
 import numpy as np
+import time
 
 from huggingface_hub import snapshot_download
 
@@ -297,16 +298,26 @@ def _load_models_cpu_once():
     from omegaconf import OmegaConf
 
     # Config
+    t0 = time.perf_counter()
+    t = time.perf_counter()
     G_INFER_CONFIG = OmegaConf.load('./configs/inference/inference_v2.yaml')
+    print(f"[timing:init] load infer config: {time.perf_counter()-t:.2f}s", flush=True)
 
     # Tokenizer / encoders / vae (CPU)
+    t = time.perf_counter()
     G_TOKENIZER = AutoTokenizer.from_pretrained(G_ARGS.pretrained_model_name_or_path, subfolder="tokenizer",
                                                 revision=G_ARGS.revision)
+    print(f"[timing:init] tokenizer: {time.perf_counter()-t:.2f}s", flush=True)
+    t = time.perf_counter()
     G_IMAGE_ENCODER = CLIPVisionModelWithProjection.from_pretrained(G_ARGS.image_encoder, revision=G_ARGS.revision)
+    print(f"[timing:init] image_encoder: {time.perf_counter()-t:.2f}s", flush=True)
+    t = time.perf_counter()
     G_VAE = AutoencoderKL.from_pretrained(G_ARGS.pretrained_model_name_or_path, subfolder="vae",
                                           revision=G_ARGS.revision)
+    print(f"[timing:init] vae: {time.perf_counter()-t:.2f}s", flush=True)
 
     # UNet2D with 8-channel conv_in (CPU)
+    t = time.perf_counter()
     G_UNET2 = UNet2DConditionModel.from_pretrained(
         G_ARGS.pretrained_model_name_or_path, subfolder="unet", revision=G_ARGS.revision, torch_dtype=torch.float32
     )
@@ -318,13 +329,17 @@ def _load_models_cpu_once():
     conv_in_8.weight[:, :4, :, :].copy_(G_UNET2.conv_in.weight)
     conv_in_8.bias.copy_(G_UNET2.conv_in.bias)
     G_UNET2.conv_in = conv_in_8
+    print(f"[timing:init] unet2 + conv_in adapt: {time.perf_counter()-t:.2f}s", flush=True)
 
     # ControlNet (CPU)
+    t = time.perf_counter()
     G_CONTROLNET = ControlNetModel.from_unet(G_UNET2)
     state_dict2 = torch.load(os.path.join(G_ARGS.model_path, "pytorch_model.bin"), map_location="cpu")
     G_CONTROLNET.load_state_dict(state_dict2, strict=False)
+    print(f"[timing:init] controlnet load_state: {time.perf_counter()-t:.2f}s", flush=True)
 
     # UNet3D (CPU)
+    t = time.perf_counter()
     prefix = "motion_module"
     ckpt_num = "4140000"
     save_path = os.path.join(G_ARGS.model_path, f"{prefix}-{ckpt_num}.pth")
@@ -334,13 +349,17 @@ def _load_models_cpu_once():
         subfolder="unet",
         unet_additional_kwargs=G_INFER_CONFIG.unet_additional_kwargs,
     )
+    print(f"[timing:init] unet3d from_pretrained_2d: {time.perf_counter()-t:.2f}s", flush=True)
 
     # CC projection (CPU)
+    t = time.perf_counter()
     G_CC_PROJ = CCProjection()
     state_dict3 = torch.load(os.path.join(G_ARGS.model_path, "pytorch_model_1.bin"), map_location="cpu")
     G_CC_PROJ.load_state_dict(state_dict3, strict=False)
+    print(f"[timing:init] cc_projection load_state: {time.perf_counter()-t:.2f}s", flush=True)
 
     # Hair encoder (CPU)
+    t = time.perf_counter()
     from ref_encoder.reference_unet import ref_unet
     G_HAIR_ENCODER = ref_unet.from_pretrained(
         G_ARGS.pretrained_model_name_or_path, subfolder="unet", revision=G_ARGS.revision, low_cpu_mem_usage=False,
@@ -348,6 +367,8 @@ def _load_models_cpu_once():
     )
     state_dict4 = torch.load(os.path.join(G_ARGS.model_path, "pytorch_model_2.bin"), map_location="cpu")
     G_HAIR_ENCODER.load_state_dict(state_dict4, strict=False)
+    print(f"[timing:init] hair_encoder load_state: {time.perf_counter()-t:.2f}s", flush=True)
+    print(f"[timing:init] total preload: {time.perf_counter()-t0:.2f}s", flush=True)
 
 
     try:
@@ -381,10 +402,12 @@ def _ensure_models_loaded():
 with open("imgs/background.png", "rb") as f:
     _b64_bg = base64.b64encode(f.read()).decode()
 
+
 @spaces.GPU(duration=300)
 def inference(id_image, hair_image):
     # ZeroGPU: force the 'cuda' device (torch.cuda.is_available may be False under ZeroGPU).
     device = torch.device("cuda")
+    t_total = time.perf_counter()
 
     # Make sure the global models are loaded
     _ensure_models_loaded()
@@ -412,8 +435,10 @@ def inference(id_image, hair_image):
     hair_image.save(hair_path)
 
     # Align
+    t = time.perf_counter()
     aligned_id = _maybe_align_image(id_path, output_size=1024, prefer_cuda=True)
     aligned_hair = _maybe_align_image(hair_path, output_size=1024, prefer_cuda=True)
+    print(f"[timing] align total: {time.perf_counter()-t:.2f}s", flush=True)
 
     aligned_id_path = "gradio_outputs/aligned_id.png"
     aligned_hair_path = "gradio_outputs/aligned_hair.png"
@@ -421,9 +446,11 @@ def inference(id_image, hair_image):
     cv2.imwrite(aligned_hair_path, cv2.cvtColor(aligned_hair, cv2.COLOR_RGB2BGR))
 
     # Balding
+    t = time.perf_counter()
     bald_id_path = "gradio_outputs/bald_id.png"
     cv2.imwrite(bald_id_path, cv2.cvtColor(aligned_id, cv2.COLOR_RGB2BGR))
     bald_head(bald_id_path, bald_id_path)
+    print(f"[timing] bald_head: {time.perf_counter()-t:.2f}s", flush=True)
 
     # Resolve trained model dir
     trained_model_dir = os.path.abspath("trained_model") if os.path.isdir("trained_model") else None
@@ -459,6 +486,7 @@ def inference(id_image, hair_image):
     logger = logging.getLogger(__name__)
 
     # Move the preloaded global models to the GPU
+    t = time.perf_counter()
     tokenizer = G_TOKENIZER
     image_encoder = G_IMAGE_ENCODER.to(device)
     vae = G_VAE.to(device, dtype=torch.float32)
@@ -467,17 +495,21 @@ def inference(id_image, hair_image):
     denoising_unet = G_DENOISING_UNET.to(device)
     cc_projection = G_CC_PROJ.to(device)
     Hair_Encoder = G_HAIR_ENCODER.to(device)
+    print(f"[timing] move models to cuda: {time.perf_counter()-t:.2f}s", flush=True)
 
     # Run inference
+    t = time.perf_counter()
     log_validation(
         vae, tokenizer, image_encoder, denoising_unet,
         args, device, logger,
         cc_projection, controlnet, Hair_Encoder
     )
+    print(f"[timing] sd pipeline (log_validation): {time.perf_counter()-t:.2f}s", flush=True)
 
     output_video = os.path.join(args.output_dir, "validation", "generated_video_0.mp4")
 
     # Extract frames for slider preview
+    t = time.perf_counter()
     frames_dir = os.path.join(args.output_dir, "frames", uuid.uuid4().hex)
     os.makedirs(frames_dir, exist_ok=True)
     cap = cv2.VideoCapture(output_video)
@@ -492,6 +524,9 @@ def inference(id_image, hair_image):
         frames_list.append(fp)
         idx += 1
     cap.release()
+    print(f"[timing] extract frames: {time.perf_counter()-t:.2f}s", flush=True)
+
+    print(f"[timing] total inference: {time.perf_counter()-t_total:.2f}s", flush=True)
 
     max_frames = len(frames_list) if frames_list else 1
     first_frame = frames_list[0] if frames_list else None
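
All 35 added lines follow the same pattern: record time.perf_counter() before a step, then print the elapsed seconds with flush=True so the line appears immediately in the streamed Space logs. The [timing:init] tag marks the one-time CPU preload in _load_models_cpu_once, and [timing] marks the per-request inference path. A minimal sketch of the same idea as a reusable helper follows; the timed context manager is hypothetical and not part of this commit:

import time
from contextlib import contextmanager

@contextmanager
def timed(label, prefix="[timing]"):
    # Measure wall-clock time of the wrapped block and print it with flush=True,
    # mirroring the perf_counter + print pattern this commit adds inline.
    t = time.perf_counter()
    try:
        yield
    finally:
        print(f"{prefix} {label}: {time.perf_counter() - t:.2f}s", flush=True)

# Usage sketch (names from gradio_app.py):
# with timed("align total"):
#     aligned_id = _maybe_align_image(id_path, output_size=1024, prefer_cuda=True)
#     aligned_hair = _maybe_align_image(hair_path, output_size=1024, prefer_cuda=True)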