ouclxy commited on
Commit
34b4555
·
verified ·
1 Parent(s): f662ad1

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +5 -4
gradio_app.py CHANGED
@@ -282,8 +282,9 @@ def inference(id_image, hair_image):
282
  # Balding
283
  bald_id_path = "gradio_outputs/bald_id.png"
284
  cv2.imwrite(bald_id_path, cv2.cvtColor(aligned_id, cv2.COLOR_RGB2BGR))
 
285
  bald_head(bald_id_path, bald_id_path)
286
-
287
  # Resolve trained model dir
288
  trained_model_dir = os.path.abspath("trained_model") if os.path.isdir("trained_model") else None
289
  if trained_model_dir is None and os.path.isdir("pretrain"):
@@ -365,14 +366,14 @@ def inference(id_image, hair_image):
365
  ).to(device)
366
  state_dict4 = torch.load(os.path.join(args.model_path, "pytorch_model_2.bin"), map_location="cpu")
367
  Hair_Encoder.load_state_dict(state_dict4, strict=False)
368
-
369
  # Run inference
370
  log_validation(
371
  vae, tokenizer, image_encoder, denoising_unet,
372
  args, device, logger,
373
  cc_projection, controlnet, Hair_Encoder
374
  )
375
-
376
  output_video = os.path.join(args.output_dir, "validation", "generated_video_0.mp4")
377
 
378
  # Extract frames for slider preview
@@ -565,6 +566,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"),
565
 
566
 
567
  if __name__ == "__main__":
568
- demo.queue().launch(server_name="0.0.0.0", server_port=7860)
569
 
570
 
 
282
  # Balding
283
  bald_id_path = "gradio_outputs/bald_id.png"
284
  cv2.imwrite(bald_id_path, cv2.cvtColor(aligned_id, cv2.COLOR_RGB2BGR))
285
+ print("start bald", flush=True)
286
  bald_head(bald_id_path, bald_id_path)
287
+ print("done bald", flush=True)
288
  # Resolve trained model dir
289
  trained_model_dir = os.path.abspath("trained_model") if os.path.isdir("trained_model") else None
290
  if trained_model_dir is None and os.path.isdir("pretrain"):
 
366
  ).to(device)
367
  state_dict4 = torch.load(os.path.join(args.model_path, "pytorch_model_2.bin"), map_location="cpu")
368
  Hair_Encoder.load_state_dict(state_dict4, strict=False)
369
+ print("start sd", flush=True)
370
  # Run inference
371
  log_validation(
372
  vae, tokenizer, image_encoder, denoising_unet,
373
  args, device, logger,
374
  cc_projection, controlnet, Hair_Encoder
375
  )
376
+ print("done sd", flush=True)
377
  output_video = os.path.join(args.output_dir, "validation", "generated_video_0.mp4")
378
 
379
  # Extract frames for slider preview
 
566
 
567
 
568
  if __name__ == "__main__":
569
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860,show_error=True)
570
 
571