alex
WAN animate PRO
ea97ae7
raw
history blame
2.46 kB
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import os
import argparse
from process_pipepline import ProcessPipeline
# Simple attribute container for the settings assembled in _parse_args().
class _Args:
    """Empty attribute holder.

    The class defines no fields of its own; callers assign whatever
    settings they need directly onto an instance.
    """
def _parse_args():
    """Build the default preprocessing settings.

    Despite the name, nothing is read from the command line: every value
    below is a fixed default assigned onto a fresh ``_Args`` instance.

    Returns:
        _Args: a new container carrying the defaults listed in the table.
    """
    # Built inside the function so each call gets its own list object
    # for ``resolution_area``.
    defaults = {
        # general paths
        "ckpt_path": "./Wan2.2-Animate-14B/process_checkpoint",
        "video_path": None,
        "refer_path": None,
        "save_path": None,
        # processing parameters
        "resolution_area": [1280, 720],
        "fps": 30,
        # feature flags
        "replace_flag": True,
        "retarget_flag": False,
        "use_flux": False,
        # mask strategy parameters (replacement mode)
        "iterations": 3,
        "k": 7,
        "w_len": 1,
        "h_len": 1,
    }

    settings = _Args()
    for attr_name, value in defaults.items():
        setattr(settings, attr_name, value)
    return settings
def load_preprocess_models(ckpt_path="./Wan2.2-Animate-14B/process_checkpoint"):
    """Construct the ``ProcessPipeline`` from the preprocessing checkpoints.

    Args:
        ckpt_path: Directory that holds the ``pose2d``, ``det`` and ``sam2``
            checkpoint sub-directories. The default is the location that
            used to be hard-coded here, so calls without arguments behave
            exactly as before.

    Returns:
        The ``ProcessPipeline`` instance created from the detector, 2D pose
        and SAM2 checkpoint paths. No FLUX Kontext path is supplied
        (``flux_kontext_path=None``).
    """
    pose2d_checkpoint_path = os.path.join(ckpt_path, 'pose2d/vitpose_h_wholebody.onnx')
    det_checkpoint_path = os.path.join(ckpt_path, 'det/yolov10m.onnx')
    sam2_checkpoint_path = os.path.join(ckpt_path, 'sam2/sam2_hiera_large.pt')

    return ProcessPipeline(
        det_checkpoint_path=det_checkpoint_path,
        pose2d_checkpoint_path=pose2d_checkpoint_path,
        sam_checkpoint_path=sam2_checkpoint_path,
        flux_kontext_path=None,
    )
def run(process_pipeline, input_video, edited_frame, preprocess_dir, w, h, tag_string,
        pts_by_frame: dict, lbs_by_frame: dict):
    """Invoke the preprocessing pipeline for one video.

    Args:
        process_pipeline: Callable pipeline object; it is called once with
            the keyword arguments assembled below.
        input_video: Passed to the pipeline as ``video_path``.
        edited_frame: Passed to the pipeline as ``refer_image_path``.
        preprocess_dir: Output directory; created if missing and passed to
            the pipeline as ``output_path``.
        w: First element of the ``resolution_area`` pair.
        h: Second element of the ``resolution_area`` pair.
        tag_string: Mode selector. The exact string ``"retarget_flag"``
            enables retargeting; any other value enables replacement.
        pts_by_frame: Forwarded unchanged to the pipeline (presumably
            per-frame point prompts; confirm against ProcessPipeline).
        lbs_by_frame: Forwarded unchanged to the pipeline (presumably
            per-frame point labels; confirm against ProcessPipeline).

    Returns:
        None. The pipeline's return value is not propagated.
    """
    defaults = _parse_args()

    # The two modes are mutually exclusive and driven solely by tag_string;
    # the replace_flag / retarget_flag defaults are not consulted here.
    retarget_flag = tag_string == "retarget_flag"
    replace_flag = not retarget_flag

    os.makedirs(preprocess_dir, exist_ok=True)

    pipeline_kwargs = dict(
        video_path=input_video,
        refer_image_path=edited_frame,
        output_path=preprocess_dir,
        resolution_area=[w, h],
        fps=defaults.fps,
        iterations=defaults.iterations,
        k=defaults.k,
        w_len=defaults.w_len,
        h_len=defaults.h_len,
        retarget_flag=retarget_flag,
        use_flux=defaults.use_flux,
        replace_flag=replace_flag,
        pts_by_frame=pts_by_frame,
        lbs_by_frame=lbs_by_frame,
    )
    process_pipeline(**pipeline_kwargs)