import os

os.environ.setdefault("GRADIO_TEMP_DIR", "/data2/lzliu/tmp/gradio")
os.environ.setdefault("TMPDIR", "/data2/lzliu/tmp")
os.makedirs("/data2/lzliu/tmp/gradio", exist_ok=True)
os.makedirs("/data2/lzliu/tmp", exist_ok=True)
# The rest is unchanged
import logging
import gradio as gr
import torch
import uuid
from test_stablehairv2 import log_validation
from test_stablehairv2 import UNet3DConditionModel, ControlNetModel, CCProjection
from test_stablehairv2 import AutoTokenizer, CLIPVisionModelWithProjection, AutoencoderKL, UNet2DConditionModel
from omegaconf import OmegaConf
import numpy as np
import cv2
from test_stablehairv2 import _maybe_align_image
from HairMapper.hair_mapper_run import bald_head
import base64
with open("imgs/background.jpg", "rb") as f:
    b64_img = base64.b64encode(f.read()).decode()
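# The background image is inlined as a base64 data URI in the CSS below, which
# avoids serving a separate static file through Gradio.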
def inference(id_image, hair_image):
    os.makedirs("gradio_inputs", exist_ok=True)
    os.makedirs("gradio_outputs", exist_ok=True)
    id_path = "gradio_inputs/id.png"
    hair_path = "gradio_inputs/hair.png"
    id_image.save(id_path)
    hair_image.save(hair_path)

    # ===== Image alignment =====
    aligned_id = _maybe_align_image(id_path, output_size=1024, prefer_cuda=True)
    aligned_hair = _maybe_align_image(hair_path, output_size=1024, prefer_cuda=True)

    # Save the aligned results (so Gradio can display them)
    aligned_id_path = "gradio_outputs/aligned_id.png"
    aligned_hair_path = "gradio_outputs/aligned_hair.png"
    cv2.imwrite(aligned_id_path, cv2.cvtColor(aligned_id, cv2.COLOR_RGB2BGR))
    cv2.imwrite(aligned_hair_path, cv2.cvtColor(aligned_hair, cv2.COLOR_RGB2BGR))

    # ===== Run HairMapper to baldify the identity image =====
    bald_id_path = "gradio_outputs/bald_id.png"
    cv2.imwrite(bald_id_path, cv2.cvtColor(aligned_id, cv2.COLOR_RGB2BGR))
    bald_head(bald_id_path, bald_id_path)
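    # bald_head is called with the same path for input and output, so the
    # aligned identity image is overwritten in place with its hair-removed
    # version.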
    # ===== Original Args =====
    class Args:
        pretrained_model_name_or_path = "./stable-diffusion-v1-5/stable-diffusion-v1-5"
        model_path = "./trained_model"
        image_encoder = "openai/clip-vit-large-patch14"
        controlnet_model_name_or_path = None
        revision = None
        output_dir = "gradio_outputs"
        seed = 42
        num_validation_images = 1
        validation_ids = [aligned_id_path]      # use the aligned image
        validation_hairs = [aligned_hair_path]  # use the aligned image
        use_fp16 = False

    args = Args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize the logger (basicConfig is a no-op after the first call)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger = logging.getLogger(__name__)
    # ===== Model loading (mirrors main()) =====
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer",
                                              revision=args.revision)
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder, revision=args.revision).to(device)
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision).to(
        device, dtype=torch.float32)
    infer_config = OmegaConf.load('./configs/inference/inference_v2.yaml')
    unet2 = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, torch_dtype=torch.float32
    ).to(device)

    # Expand conv_in from 4 to 8 input channels
    conv_in_8 = torch.nn.Conv2d(8, unet2.conv_in.out_channels, kernel_size=unet2.conv_in.kernel_size,
                                padding=unet2.conv_in.padding)
    conv_in_8.requires_grad_(False)
    unet2.conv_in.requires_grad_(False)
    torch.nn.init.zeros_(conv_in_8.weight)
    conv_in_8.weight[:, :4, :, :].copy_(unet2.conv_in.weight)
    conv_in_8.bias.copy_(unet2.conv_in.bias)
    unet2.conv_in = conv_in_8
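    # The expansion above is the usual trick for adding conditioning channels
    # to a pretrained UNet: the new 8-channel conv is zero-initialized, then
    # the pretrained 4-channel weights are copied into the first 4 input
    # channels, so at initialization the extra channels contribute nothing and
    # the UNet behaves exactly like the pretrained 2D model. An optional sanity
    # check (assuming no further weight edits happen before this point):
    #   assert torch.count_nonzero(unet2.conv_in.weight[:, 4:]) == 0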
    controlnet = ControlNetModel.from_unet(unet2).to(device)
    state_dict2 = torch.load(os.path.join(args.model_path, "pytorch_model.bin"), map_location="cpu")
    controlnet.load_state_dict(state_dict2, strict=False)
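    # strict=False here (and in the loads below) tolerates missing/unexpected
    # keys, presumably because each checkpoint stores only a subset of the
    # module's parameters. Capturing the returned key lists would make silent
    # mismatches visible, e.g.:
    #   missing, unexpected = controlnet.load_state_dict(state_dict2, strict=False)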
    prefix = "motion_module"
    ckpt_num = "4140000"
    save_path = os.path.join(args.model_path, f"{prefix}-{ckpt_num}.pth")
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        args.pretrained_model_name_or_path,
        save_path,
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    ).to(device)
    cc_projection = CCProjection().to(device)
    state_dict3 = torch.load(os.path.join(args.model_path, "pytorch_model_1.bin"), map_location="cpu")
    cc_projection.load_state_dict(state_dict3, strict=False)
    from ref_encoder.reference_unet import ref_unet
    Hair_Encoder = ref_unet.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False,
        device_map=None, ignore_mismatched_sizes=True
    ).to(device)
    state_dict4 = torch.load(os.path.join(args.model_path, "pytorch_model_2.bin"), map_location="cpu")
    Hair_Encoder.load_state_dict(state_dict4, strict=False)
    # Inference
    log_validation(
        vae, tokenizer, image_encoder, denoising_unet,
        args, device, logger,
        cc_projection, controlnet, Hair_Encoder
    )
    output_video = os.path.join(args.output_dir, "validation", "generated_video_0.mp4")

    # Extract video frames for the draggable preview
    frames_dir = os.path.join(args.output_dir, "frames", uuid.uuid4().hex)
    os.makedirs(frames_dir, exist_ok=True)
    cap = cv2.VideoCapture(output_video)
    frames_list = []
    idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        fp = os.path.join(frames_dir, f"{idx:03d}.png")
        cv2.imwrite(fp, frame)
        frames_list.append(fp)
        idx += 1
    cap.release()
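    # cv2.VideoCapture yields frames in BGR order, which is exactly what
    # cv2.imwrite expects, so no color conversion is needed here (unlike the
    # RGB arrays from the alignment step above).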
    max_frames = len(frames_list) if frames_list else 1
    first_frame = frames_list[0] if frames_list else None
    return (aligned_id_path, aligned_hair_path, bald_id_path, output_video,
            frames_list,
            gr.update(minimum=1, maximum=max_frames, value=1, step=1),
            first_frame)
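# The seven return values above map positionally onto the outputs list of
# run_btn.click below: the two aligned images, the bald image, the video, the
# frames state, a slider re-bound to the actual frame count, and the first
# preview frame.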
# Gradio frontend
# Original Interface version (kept in case of rollback)
# demo = gr.Interface(
#     fn=inference,
#     inputs=[
#         gr.Image(type="pil", label="Identity image (ID Image)"),
#         gr.Image(type="pil", label="Hair Reference Image")
#     ],
#     outputs=[
#         gr.Image(type="filepath", label="Aligned identity image"),
#         gr.Image(type="filepath", label="Aligned hair image"),
#         gr.Image(type="filepath", label="Baldified identity image"),
#         gr.Video(label="Generated video")
#     ],
#     title="StableHairV2 Multi-View Hairstyle Transfer",
#     description="Upload an identity image and a hair reference image to view the alignment results and generate a multi-view video"
# )
# if __name__ == "__main__":
#     demo.launch(server_name="0.0.0.0", server_port=7860)
# Styled Blocks version
css = f"""
html, body {{
    height: 100%;
    margin: 0;
    padding: 0;
}}
.gradio-container {{
    width: 100% !important;
    height: 100% !important;
    margin: 0 !important;
    padding: 0 !important;
    background-image: url("data:image/jpeg;base64,{b64_img}");
    background-size: cover;
    background-position: center;
    background-attachment: fixed;  /* fixed background */
}}
#title-card {{
    background: rgba(255, 255, 255, 0.8);
    border-radius: 12px;
    padding: 16px 24px;
    box-shadow: 0 2px 8px rgba(0,0,0,0.15);
    margin-bottom: 20px;
}}
#title-card h2 {{
    text-align: center;
    margin: 4px 0 12px 0;
    font-size: 28px;
}}
#title-card p {{
    text-align: center;
    font-size: 16px;
    color: #374151;
}}
.out-card {{
    border: 1px solid #e5e7eb; border-radius: 10px; padding: 10px;
    background: rgba(255,255,255,0.85);
}}
.two-col {{
    display: grid !important; grid-template-columns: 360px minmax(680px, 1fr); gap: 16px;
}}
.left-pane {{min-width: 360px}}
.right-pane {{min-width: 680px}}
/* Tab styling */
.tabs {{
    background: rgba(255,255,255,0.88);
    border-radius: 12px;
    box-shadow: 0 8px 24px rgba(0,0,0,0.08);
    padding: 8px;
    border: 1px solid #e5e7eb;
}}
.tab-nav {{
    display: flex; gap: 8px; margin-bottom: 8px;
    background: transparent;
    border-bottom: 1px solid #e5e7eb;
    padding-bottom: 6px;
}}
.tab-nav button {{
    background: rgba(255,255,255,0.7);
    border: 1px solid #e5e7eb;
    backdrop-filter: blur(6px);
    border-radius: 8px;
    padding: 6px 12px;
    color: #111827;
    transition: all .2s ease;
}}
.tab-nav button:hover {{
    transform: translateY(-1px);
    box-shadow: 0 4px 10px rgba(0,0,0,0.06);
}}
.tab-nav button[aria-selected="true"] {{
    background: #4f46e5;
    color: #fff;
    border-color: #4f46e5;
    box-shadow: 0 6px 14px rgba(79,70,229,0.25);
}}
.tabitem {{
    background: rgba(255,255,255,0.88);
    border-radius: 10px;
    padding: 8px;
}}
/* Hair gallery scroll container: fixed 260px height, scrollable inside */
#hair_gallery_wrap {{
    height: 260px !important;
    overflow-y: scroll !important;
    overflow-x: auto !important;
}}
#hair_gallery_wrap .grid, #hair_gallery_wrap .wrap {{
    height: 100% !important;
    overflow-y: scroll !important;
}}
/* Make the gallery itself fill the container height so the scrollbar
   doesn't end up at the bottom of the page */
#hair_gallery {{
    height: 100% !important;
}}
"""
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"),
    css=css
) as demo:
    # ==== Title panel ====
    with gr.Group(elem_id="title-card"):
        gr.Markdown("""
        <h2 id='title'>StableHairV2 Multi-View Hairstyle Transfer</h2>
        <p>Upload an identity image and a hair reference image; the app automatically runs <b>alignment → baldification → video generation</b>.</p>
        """)
    with gr.Row(elem_classes=["two-col"]):
        with gr.Column(scale=5, min_width=260, elem_classes=["left-pane"]):
            id_input = gr.Image(type="pil", label="Identity Image", height=200)
            hair_input = gr.Image(type="pil", label="Hair Reference Image", height=200)
            with gr.Row():
                run_btn = gr.Button("Generate", variant="primary")
                clear_btn = gr.Button("Clear")
            # ========= Hairstyle library (click to fill the hair reference input) =========
            def _list_imgs(dir_path: str):
                exts = (".png", ".jpg", ".jpeg", ".webp")
                try:
                    files = [os.path.join(dir_path, f) for f in sorted(os.listdir(dir_path))
                             if f.lower().endswith(exts)]
                    return files
                except Exception:
                    return []

            hair_list = _list_imgs("hair_resposity")
            with gr.Accordion("Hairstyle Library (click to auto-fill)", open=True):
                with gr.Group(elem_id="hair_gallery_wrap"):
                    gallery = gr.Gallery(
                        value=hair_list,
                        columns=4, rows=2, allow_preview=True, label="Hairstyle Library",
                        elem_id="hair_gallery"
                    )

            def _pick_hair(evt: gr.SelectData):  # type: ignore[name-defined]
                i = evt.index if hasattr(evt, 'index') else 0
                i = 0 if i is None else int(i)
                if 0 <= i < len(hair_list):
                    return gr.update(value=hair_list[i])
                return gr.update()

            gallery.select(_pick_hair, inputs=None, outputs=hair_input)
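            # gr.Gallery.select passes a gr.SelectData whose .index is the
            # position of the clicked thumbnail; _pick_hair above maps it back
            # into hair_list and pushes that file into the hair reference input.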
        with gr.Column(scale=7, min_width=520, elem_classes=["right-pane"]):
            with gr.Tabs():
                with gr.TabItem("Generated Video"):
                    with gr.Group(elem_classes=["out-card"]):
                        video_out = gr.Video(label="Generated video", height=340)
                        with gr.Row():
                            frame_slider = gr.Slider(1, 21, value=1, step=1, label="Multi-view preview (drag to browse frames)")
                            frame_preview = gr.Image(type="filepath", label="Preview frame", height=260)
                        frames_state = gr.State([])
                with gr.TabItem("Alignment Results"):
                    with gr.Group(elem_classes=["out-card"]):
                        with gr.Row():
                            aligned_id_out = gr.Image(type="filepath", label="Aligned identity image", height=240)
                            aligned_hair_out = gr.Image(type="filepath", label="Aligned hair image", height=240)
                with gr.TabItem("Baldification Result"):
                    with gr.Group(elem_classes=["out-card"]):
                        bald_id_out = gr.Image(type="filepath", label="Baldified identity image", height=260)
    # Wiring unchanged
    run_btn.click(fn=inference,
                  inputs=[id_input, hair_input],
                  outputs=[aligned_id_out, aligned_hair_out, bald_id_out,
                           video_out, frames_state, frame_slider, frame_preview])
    def _on_slide(frames, idx):
        if not frames:
            return gr.update()
        i = int(idx) - 1
        i = max(0, min(i, len(frames) - 1))
        return gr.update(value=frames[i])

    frame_slider.change(_on_slide, inputs=[frames_state, frame_slider], outputs=frame_preview)
    def _clear():
        return None, None, None, None, None

    clear_btn.click(_clear, None,
                    [id_input, hair_input, aligned_id_out, aligned_hair_out, bald_id_out])

if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)