Update gradio_app.py
gradio_app.py CHANGED (+5 -4)
@@ -282,8 +282,9 @@ def inference(id_image, hair_image):
     # Balding
     bald_id_path = "gradio_outputs/bald_id.png"
     cv2.imwrite(bald_id_path, cv2.cvtColor(aligned_id, cv2.COLOR_RGB2BGR))
+    print("start bald", flush=True)
     bald_head(bald_id_path, bald_id_path)
-
+    print("done bald", flush=True)
     # Resolve trained model dir
     trained_model_dir = os.path.abspath("trained_model") if os.path.isdir("trained_model") else None
     if trained_model_dir is None and os.path.isdir("pretrain"):
@@ -365,14 +366,14 @@ def inference(id_image, hair_image):
     ).to(device)
     state_dict4 = torch.load(os.path.join(args.model_path, "pytorch_model_2.bin"), map_location="cpu")
     Hair_Encoder.load_state_dict(state_dict4, strict=False)
-
+    print("start sd", flush=True)
     # Run inference
     log_validation(
         vae, tokenizer, image_encoder, denoising_unet,
         args, device, logger,
         cc_projection, controlnet, Hair_Encoder
     )
-
+    print("done sd", flush=True)
     output_video = os.path.join(args.output_dir, "validation", "generated_video_0.mp4")
 
     # Extract frames for slider preview
@@ -565,6 +566,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"),
 
 
 if __name__ == "__main__":
-    demo.queue().launch(server_name="0.0.0.0", server_port=7860)
+    demo.queue().launch(server_name="0.0.0.0", server_port=7860, show_error=True)
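The print(..., flush=True) calls added here bracket the two slow stages (balding, then the diffusion pass) so their progress shows up in the Space's container logs. When stdout is not attached to a TTY, Python block-buffers it, so unflushed messages can sit in the buffer until the process exits, and are lost if it crashes; flushing makes each marker appear in real time. Running with python -u or setting PYTHONUNBUFFERED=1 achieves the same effect process-wide. A minimal sketch of the same idea as a reusable helper, assuming nothing beyond the standard library; the stage context manager is illustrative and not part of gradio_app.py:

    import time
    from contextlib import contextmanager

    @contextmanager
    def stage(name):
        # Flush immediately so the marker reaches the container logs
        # in real time, even with block-buffered stdout.
        print(f"start {name}", flush=True)
        t0 = time.time()
        try:
            yield
        finally:
            print(f"done {name} ({time.time() - t0:.1f}s)", flush=True)

    # Hypothetical usage mirroring this commit:
    # with stage("bald"):
    #     bald_head(bald_id_path, bald_id_path)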
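The other change, show_error=True on launch(), tells Gradio to surface the text of exceptions raised in event handlers as an alert in the client UI rather than a generic error indicator, which is handy while a Space keeps hitting runtime errors. A self-contained sketch of the behavior under the same launch settings; the boom handler is illustrative only:

    import gradio as gr

    def boom(text):
        # With show_error=True, this message is shown in an alert
        # in the browser instead of a bare "Error".
        raise ValueError(f"could not process {text!r}")

    with gr.Blocks() as demo:
        inp = gr.Textbox(label="Input")
        out = gr.Textbox(label="Output")
        inp.submit(boom, inp, out)

    if __name__ == "__main__":
        demo.queue().launch(server_name="0.0.0.0", server_port=7860, show_error=True)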