Update gradio_app.py
gradio_app.py  CHANGED  (+35 -0)
@@ -11,6 +11,7 @@ import spaces
 import torch
 import cv2
 import numpy as np
+import time
 
 from huggingface_hub import snapshot_download
 
@@ -297,16 +298,26 @@ def _load_models_cpu_once():
     from omegaconf import OmegaConf
 
     # Config
+    t0 = time.perf_counter()
+    t = time.perf_counter()
     G_INFER_CONFIG = OmegaConf.load('./configs/inference/inference_v2.yaml')
+    print(f"[timing:init] load infer config: {time.perf_counter()-t:.2f}s", flush=True)
 
     # Tokenizer / encoders / vae (CPU)
+    t = time.perf_counter()
     G_TOKENIZER = AutoTokenizer.from_pretrained(G_ARGS.pretrained_model_name_or_path, subfolder="tokenizer",
                                                 revision=G_ARGS.revision)
+    print(f"[timing:init] tokenizer: {time.perf_counter()-t:.2f}s", flush=True)
+    t = time.perf_counter()
     G_IMAGE_ENCODER = CLIPVisionModelWithProjection.from_pretrained(G_ARGS.image_encoder, revision=G_ARGS.revision)
+    print(f"[timing:init] image_encoder: {time.perf_counter()-t:.2f}s", flush=True)
+    t = time.perf_counter()
     G_VAE = AutoencoderKL.from_pretrained(G_ARGS.pretrained_model_name_or_path, subfolder="vae",
                                           revision=G_ARGS.revision)
+    print(f"[timing:init] vae: {time.perf_counter()-t:.2f}s", flush=True)
 
     # UNet2D with 8-channel conv_in (CPU)
+    t = time.perf_counter()
     G_UNET2 = UNet2DConditionModel.from_pretrained(
         G_ARGS.pretrained_model_name_or_path, subfolder="unet", revision=G_ARGS.revision, torch_dtype=torch.float32
     )
@@ -318,13 +329,17 @@ def _load_models_cpu_once():
     conv_in_8.weight[:, :4, :, :].copy_(G_UNET2.conv_in.weight)
     conv_in_8.bias.copy_(G_UNET2.conv_in.bias)
     G_UNET2.conv_in = conv_in_8
+    print(f"[timing:init] unet2 + conv_in adapt: {time.perf_counter()-t:.2f}s", flush=True)
 
     # ControlNet (CPU)
+    t = time.perf_counter()
     G_CONTROLNET = ControlNetModel.from_unet(G_UNET2)
     state_dict2 = torch.load(os.path.join(G_ARGS.model_path, "pytorch_model.bin"), map_location="cpu")
     G_CONTROLNET.load_state_dict(state_dict2, strict=False)
+    print(f"[timing:init] controlnet load_state: {time.perf_counter()-t:.2f}s", flush=True)
 
     # UNet3D (CPU)
+    t = time.perf_counter()
     prefix = "motion_module"
     ckpt_num = "4140000"
     save_path = os.path.join(G_ARGS.model_path, f"{prefix}-{ckpt_num}.pth")
@@ -334,13 +349,17 @@ def _load_models_cpu_once():
         subfolder="unet",
         unet_additional_kwargs=G_INFER_CONFIG.unet_additional_kwargs,
     )
+    print(f"[timing:init] unet3d from_pretrained_2d: {time.perf_counter()-t:.2f}s", flush=True)
 
     # CC projection (CPU)
+    t = time.perf_counter()
     G_CC_PROJ = CCProjection()
     state_dict3 = torch.load(os.path.join(G_ARGS.model_path, "pytorch_model_1.bin"), map_location="cpu")
     G_CC_PROJ.load_state_dict(state_dict3, strict=False)
+    print(f"[timing:init] cc_projection load_state: {time.perf_counter()-t:.2f}s", flush=True)
 
     # Hair encoder (CPU)
+    t = time.perf_counter()
     from ref_encoder.reference_unet import ref_unet
     G_HAIR_ENCODER = ref_unet.from_pretrained(
         G_ARGS.pretrained_model_name_or_path, subfolder="unet", revision=G_ARGS.revision, low_cpu_mem_usage=False,
@@ -348,6 +367,8 @@ def _load_models_cpu_once():
     )
     state_dict4 = torch.load(os.path.join(G_ARGS.model_path, "pytorch_model_2.bin"), map_location="cpu")
     G_HAIR_ENCODER.load_state_dict(state_dict4, strict=False)
+    print(f"[timing:init] hair_encoder load_state: {time.perf_counter()-t:.2f}s", flush=True)
+    print(f"[timing:init] total preload: {time.perf_counter()-t0:.2f}s", flush=True)
 
 
 try:
@@ -381,10 +402,12 @@ def _ensure_models_loaded():
 with open("imgs/background.png", "rb") as f:
     _b64_bg = base64.b64encode(f.read()).decode()
 
+
 @spaces.GPU(duration=300)
 def inference(id_image, hair_image):
     # ZeroGPU: force the 'cuda' device (torch.cuda.is_available may report False under ZeroGPU).
     device = torch.device("cuda")
+    t_total = time.perf_counter()
 
     # Make sure the global models are loaded
     _ensure_models_loaded()
@@ -412,8 +435,10 @@ def inference(id_image, hair_image):
     hair_image.save(hair_path)
 
     # Align
+    t = time.perf_counter()
     aligned_id = _maybe_align_image(id_path, output_size=1024, prefer_cuda=True)
     aligned_hair = _maybe_align_image(hair_path, output_size=1024, prefer_cuda=True)
+    print(f"[timing] align total: {time.perf_counter()-t:.2f}s", flush=True)
 
     aligned_id_path = "gradio_outputs/aligned_id.png"
     aligned_hair_path = "gradio_outputs/aligned_hair.png"
@@ -421,9 +446,11 @@ def inference(id_image, hair_image):
     cv2.imwrite(aligned_hair_path, cv2.cvtColor(aligned_hair, cv2.COLOR_RGB2BGR))
 
     # Balding
+    t = time.perf_counter()
     bald_id_path = "gradio_outputs/bald_id.png"
     cv2.imwrite(bald_id_path, cv2.cvtColor(aligned_id, cv2.COLOR_RGB2BGR))
     bald_head(bald_id_path, bald_id_path)
+    print(f"[timing] bald_head: {time.perf_counter()-t:.2f}s", flush=True)
 
     # Resolve trained model dir
     trained_model_dir = os.path.abspath("trained_model") if os.path.isdir("trained_model") else None
@@ -459,6 +486,7 @@ def inference(id_image, hair_image):
     logger = logging.getLogger(__name__)
 
     # Move the preloaded global models to the GPU
+    t = time.perf_counter()
     tokenizer = G_TOKENIZER
     image_encoder = G_IMAGE_ENCODER.to(device)
     vae = G_VAE.to(device, dtype=torch.float32)
@@ -467,17 +495,21 @@ def inference(id_image, hair_image):
     denoising_unet = G_DENOISING_UNET.to(device)
     cc_projection = G_CC_PROJ.to(device)
     Hair_Encoder = G_HAIR_ENCODER.to(device)
+    print(f"[timing] move models to cuda: {time.perf_counter()-t:.2f}s", flush=True)
 
     # Run inference
+    t = time.perf_counter()
     log_validation(
         vae, tokenizer, image_encoder, denoising_unet,
         args, device, logger,
         cc_projection, controlnet, Hair_Encoder
     )
+    print(f"[timing] sd pipeline (log_validation): {time.perf_counter()-t:.2f}s", flush=True)
 
     output_video = os.path.join(args.output_dir, "validation", "generated_video_0.mp4")
 
     # Extract frames for slider preview
+    t = time.perf_counter()
     frames_dir = os.path.join(args.output_dir, "frames", uuid.uuid4().hex)
     os.makedirs(frames_dir, exist_ok=True)
     cap = cv2.VideoCapture(output_video)
@@ -492,6 +524,9 @@ def inference(id_image, hair_image):
         frames_list.append(fp)
         idx += 1
     cap.release()
+    print(f"[timing] extract frames: {time.perf_counter()-t:.2f}s", flush=True)
+
+    print(f"[timing] total inference: {time.perf_counter()-t_total:.2f}s", flush=True)
 
     max_frames = len(frames_list) if frames_list else 1
     first_frame = frames_list[0] if frames_list else None
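Note: the instrumentation above repeats the same measure-and-print pattern at every step. A minimal sketch of how that pattern could be factored into a reusable helper follows; the `timed` context manager and its label format are illustrative only and are not part of this commit.

import time
from contextlib import contextmanager

@contextmanager
def timed(label):
    # Measure the wall-clock time of the wrapped block and print it in the
    # same "[timing] <label>: <seconds>s" style used by the lines added above.
    start = time.perf_counter()
    try:
        yield
    finally:
        print(f"[timing] {label}: {time.perf_counter() - start:.2f}s", flush=True)

# Hypothetical usage inside inference():
#     with timed("align total"):
#         aligned_id = _maybe_align_image(id_path, output_size=1024, prefer_cuda=True)
#         aligned_hair = _maybe_align_image(hair_path, output_size=1024, prefer_cuda=True)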