Spaces:
Running
on
Zero
Running
on
Zero
File size: 8,469 Bytes
1caa0d9 76cd760 1caa0d9 76cd760 1caa0d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 |
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import logging
import os
import sys
import warnings
from datetime import datetime
warnings.filterwarnings('ignore')
import random
import torch
import torch.distributed as dist
from PIL import Image
import wan
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.distributed.util import init_distributed_group
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import merge_video_audio, save_video, str2bool
# Default inputs per task. ``_validate_args`` reads these to fill any option
# the caller left as None. Two prompts and the sample image are shared between
# tasks, so each is defined once here and referenced below.
_BOXING_CATS_PROMPT = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
_BEACH_CAT_PROMPT = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
_SAMPLE_IMAGE = "examples/i2v_input.JPG"

EXAMPLE_PROMPT = {
    "t2v-A14B": dict(prompt=_BOXING_CATS_PROMPT),
    "i2v-A14B": dict(prompt=_BEACH_CAT_PROMPT, image=_SAMPLE_IMAGE),
    "ti2v-5B": dict(prompt=_BOXING_CATS_PROMPT),
    # Prompt (English): "The person in the video is performing movements."
    "animate-14B": dict(prompt="视频中的人在做动作", video="", pose="", mask=""),
    "s2v-14B": dict(
        prompt=_BEACH_CAT_PROMPT,
        image=_SAMPLE_IMAGE,
        audio="examples/talk.wav",
        tts_prompt_audio="examples/zero_shot_prompt.wav",
        # English: "I hope you'll do even better than me in the future."
        tts_prompt_text="希望你以后能够做的比我还好呦。",
        # English: "Receiving a birthday gift sent by a friend far away, the
        # unexpected surprise and deep good wishes filled my heart with sweet
        # joy, and my smile bloomed like a flower."
        tts_text="收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。",
    ),
}
def _validate_args(args):
    """Fill unset options on ``args`` in place and sanity-check them.

    Missing prompt / image / audio / TTS inputs are taken from
    ``EXAMPLE_PROMPT[args.task]``. Missing ``sample_steps``, ``sample_shift``,
    ``sample_guide_scale`` and ``frame_num`` come from
    ``WAN_CONFIGS[args.task]``. A negative ``base_seed`` is replaced with a
    random integer in ``[0, sys.maxsize]``.

    Raises:
        ValueError: if ``ckpt_dir`` is unset, the task is unknown, an
            ``i2v-A14B`` task has no image, or (for non-s2v tasks) ``size`` is
            not listed in ``SUPPORTED_SIZES[args.task]``.
    """
    # Basic check. Use explicit raises rather than ``assert``: assertions are
    # stripped under ``python -O`` and the validation would silently vanish.
    if args.ckpt_dir is None:
        raise ValueError("Please specify the checkpoint directory.")
    if args.task not in WAN_CONFIGS or args.task not in EXAMPLE_PROMPT:
        raise ValueError(f"Unsupport task: {args.task}")

    example = EXAMPLE_PROMPT[args.task]
    if args.prompt is None:
        args.prompt = example["prompt"]
    if args.image is None and "image" in example:
        args.image = example["image"]
    if args.audio is None and not args.enable_tts and "audio" in example:
        args.audio = example["audio"]

    # With TTS enabled, a missing reference audio or target text replaces the
    # whole (reference audio, reference text, target text) triple with the
    # example one, so the three stay mutually consistent. Key off the TTS
    # entries themselves rather than "audio": a task with an audio example but
    # no TTS example must not raise KeyError here.
    tts_keys = ("tts_prompt_audio", "tts_prompt_text", "tts_text")
    if (args.enable_tts
            and (args.tts_prompt_audio is None or args.tts_text is None)
            and all(key in example for key in tts_keys)):
        for key in tts_keys:
            setattr(args, key, example[key])

    if args.task == "i2v-A14B" and args.image is None:
        raise ValueError("Please specify the image path for i2v.")

    # Sampling defaults come from the task config when not set explicitly.
    cfg = WAN_CONFIGS[args.task]
    for name in ("sample_steps", "sample_shift", "sample_guide_scale", "frame_num"):
        if getattr(args, name) is None:
            setattr(args, name, getattr(cfg, name))

    # A negative seed means "pick one at random".
    if args.base_seed < 0:
        args.base_seed = random.randint(0, sys.maxsize)

    # Size check (s2v tasks are exempt).
    if "s2v" not in args.task and args.size not in SUPPORTED_SIZES[args.task]:
        supported = ", ".join(SUPPORTED_SIZES[args.task])
        raise ValueError(
            f"Unsupport size {args.size} for task {args.task}, "
            f"supported sizes are: {supported}")
class _Args:
pass
def _parse_args():
    """Build the fixed generation options and validate them.

    Despite the name, nothing is read from the command line: every option is
    the hard-coded value below. ``_validate_args`` then fills the entries left
    as ``None`` (prompt, sampling options, ...) and checks the result.

    Returns:
        _Args: the populated, validated options object.
    """
    settings = {
        # core generation options
        "task": "animate-14B",
        "size": "720*1280",  # alternative: "1280*720"
        "frame_num": None,
        "ckpt_dir": "./Wan2.2-Animate-14B/",
        "offload_model": False,
        "ulysses_size": 1,
        "t5_fsdp": False,
        "t5_cpu": False,
        "dit_fsdp": False,
        "prompt": None,
        "use_prompt_extend": False,
        "prompt_extend_method": "local_qwen",  # one of "dashscope", "local_qwen"
        "prompt_extend_model": None,
        "prompt_extend_target_lang": "zh",  # one of "zh", "en"
        "base_seed": 1234,
        "image": None,
        "sample_solver": "unipc",  # one of "unipc", "dpm++"
        "sample_steps": None,
        "sample_shift": None,
        "sample_guide_scale": None,
        "convert_model_dtype": True,
        # animate
        "refert_num": 1,
        # s2v-only
        "num_clip": None,
        "audio": None,
        "enable_tts": False,
        "tts_prompt_audio": None,
        "tts_prompt_text": None,
        "tts_text": None,
        "pose_video": None,
        "start_from_ref": False,
        "infer_frames": 80,
    }
    options = _Args()
    for option_name, option_value in settings.items():
        setattr(options, option_name, option_value)
    _validate_args(options)
    return options
def _init_logging(rank):
# logging
if rank == 0:
# set format
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] %(levelname)s: %(message)s",
handlers=[logging.StreamHandler(stream=sys.stdout)])
else:
logging.basicConfig(level=logging.ERROR)
def load_model(use_relighting_lora=False, *,
               checkpoint_dir="./Wan2.2-Animate-14B/", device_id=0):
    """Construct a ``wan.WanAnimate`` pipeline for the animate-14B task.

    The pipeline is built with rank 0 and with T5/DiT FSDP, sequence
    parallelism, T5-on-CPU and dtype conversion all disabled.

    Args:
        use_relighting_lora: forwarded unchanged to ``wan.WanAnimate``.
        checkpoint_dir: directory holding the Wan2.2-Animate-14B weights.
            The default is the path this script previously hard-coded.
        device_id: device index forwarded as ``device_id``. The default 0
            matches the previously hard-coded value.

    Returns:
        The constructed ``wan.WanAnimate`` instance.
    """
    cfg = WAN_CONFIGS["animate-14B"]
    return wan.WanAnimate(
        config=cfg,
        checkpoint_dir=checkpoint_dir,
        device_id=device_id,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_sp=False,
        t5_cpu=False,
        convert_model_dtype=False,
        use_relighting_lora=use_relighting_lora,
    )
def generate(wan_animate, preprocess_dir, save_file, replace_flag=False):
    """Run an animate pipeline on a preprocessed input and save the video.

    Options (steps, shift, solver, guide scale, seed, clip length, ...) come
    from ``_parse_args()``. The output is written only on rank 0, where rank
    is read from the ``RANK`` environment variable (default 0).

    Args:
        wan_animate: pipeline object exposing ``generate(...)``, e.g. the
            return value of ``load_model``.
        preprocess_dir: forwarded to ``wan_animate.generate`` as
            ``src_root_path``.
        save_file: output path passed to ``save_video``.
        replace_flag: forwarded to ``wan_animate.generate`` unchanged.
    """
    args = _parse_args()
    rank = int(os.getenv("RANK", 0))
    _init_logging(rank)
    cfg = WAN_CONFIGS[args.task]

    logging.info(f"Input prompt: {args.prompt}")
    logging.info(f"rank: {rank}")
    logging.info("Generating video ...")
    video = wan_animate.generate(
        src_root_path=preprocess_dir,
        replace_flag=replace_flag,
        refert_num=args.refert_num,
        clip_len=args.frame_num,
        shift=args.sample_shift,
        sample_solver=args.sample_solver,
        sampling_steps=args.sample_steps,
        guide_scale=args.sample_guide_scale,
        seed=args.base_seed,
        offload_model=args.offload_model)

    if rank == 0:
        # video[None] adds a leading axis before handing it to save_video.
        save_video(
            tensor=video[None],
            save_file=save_file,
            fps=cfg.sample_fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))
    del video

    # Wait for outstanding CUDA work before returning; skip on CPU-only hosts
    # where torch.cuda.synchronize() would raise.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    if dist.is_initialized():
        dist.barrier()
        # NOTE(review): this tears the process group down after every call,
        # so a second generate() in the same process would run without it.
        # Confirm this is intended if distributed init is ever added.
        dist.destroy_process_group()

    logging.info("Finished.")
|