matthewkram committed
Commit c6011cd · verified · 1 Parent(s): e93faf1

Update app.py

Files changed (1)
  1. app.py +174 -83
app.py CHANGED
@@ -1,104 +1,195 @@
  import os
  import sys
  import time
  import torch
  import numpy as np
  import tempfile
- from PIL import Image
- from datetime import datetime
- import gradio as gr
- from torch import autocast
- from pytorch_lightning import seed_everything
- import torchvision.transforms as T
- from diffusers import StableVideoDiffusionPipeline
- from diffusers.utils import load_image, export_to_video

- class WorldAnimate:
      def __init__(self):
-         model_id = "stabilityai/stable-video-diffusion-img2vid-xt"
          self.pipe = StableVideoDiffusionPipeline.from_pretrained(
-             model_id, torch_dtype=torch.float16, variant="fp16"
          )
-         self.pipe.enable_model_cpu_offload()
-         self.pipe.enable_vae_slicing()
-         self.pipe.unet.enable_forward_chunking(chunk_size=1, dim=1)
-         self.pipe.to("cuda" if torch.cuda.is_available() else "cpu")
-         torch.backends.cuda.matmul.allow_tf32 = True
-
-     def process_input(self, image, seed, num_frames, fps, decode_chunk_size, motion_bucket_id, noise_aug_strength):
-         if seed == -1:
-             seed = int.from_bytes(os.urandom(2), "big")
-         seed_everything(seed)
-
-         if isinstance(image, str):
-             image = load_image(image)
-         image = image.resize((1024, 576))
-
-         generator = torch.manual_seed(seed)
-         frames = self.pipe(
-             image,
-             num_frames=num_frames,
-             fps=fps,
-             decode_chunk_size=decode_chunk_size,
-             motion_bucket_id=motion_bucket_id,
-             noise_aug_strength=noise_aug_strength,
-             generator=generator,
-         ).frames[0]
-
-         return frames
-
- def app():
-     with gr.Blocks(title="World 2.2 Animate (Local No API)") as demo:
          gr.HTML("""
-             <h1 style="text-align: center; font-family: Arial; color: white;">World 2.2 Animate</h1>
-             <p style="text-align: center; font-family: Arial; color: white;">
-             This is a local processing app for image-to-video conversion using Stable Video Diffusion.<br>
-             Upload an image, adjust parameters, and generate a video with smooth motion.<br>
-             Parameters:<br>
-             - Seed: Random seed for reproducibility (-1 for random).<br>
-             - Num Frames: Number of frames in the video (default 25).<br>
-             - FPS: Frames per second (default 7).<br>
-             - Decode Chunk Size: For memory optimization (default 8).<br>
-             - Motion Bucket ID: Controls motion amount (1-255, default 127).<br>
-             - Noise Aug Strength: Adds noise for variation (0-1, default 0.02).
-             </p>
          """)

          with gr.Row():
-             with gr.Column():
-                 input_image = gr.Image(label="Upload Image", type="pil")
-                 seed = gr.Number(label="Seed", value=-1)
-                 num_frames = gr.Slider(label="Num Frames", minimum=1, maximum=25, value=25, step=1)
-                 fps = gr.Slider(label="FPS", minimum=1, maximum=30, value=7, step=1)
-                 decode_chunk_size = gr.Slider(label="Decode Chunk Size", minimum=1, maximum=16, value=8, step=1)
-                 motion_bucket_id = gr.Slider(label="Motion Bucket ID", minimum=1, maximum=255, value=127, step=1)
-                 noise_aug_strength = gr.Slider(label="Noise Aug Strength", minimum=0.0, maximum=1.0, value=0.02, step=0.01)
-                 generate_btn = gr.Button(value="Generate Video")

-             with gr.Column():
-                 output_video = gr.Video(label="Generated Video")
-                 status = gr.Textbox(label="Status")
-
-         generate_btn.click(
-             fn=process,
-             inputs=[input_image, seed, num_frames, fps, decode_chunk_size, motion_bucket_id, noise_aug_strength],
-             outputs=[output_video, status]
-         )

-     return demo  # Important: return demo!

- def process(image, seed, num_frames, fps, decode_chunk_size, motion_bucket_id, noise_aug_strength):
-     try:
-         animator = WorldAnimate()
-         frames = animator.process_input(image, seed, num_frames, fps, decode_chunk_size, motion_bucket_id, noise_aug_strength)
-         with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video:
-             export_to_video(frames, temp_video.name, fps=fps)
-         return temp_video.name, "Success!"
-     except Exception as e:
-         return None, f"Failed: {str(e)}"

- def start_app():
-     app().launch()

  if __name__ == "__main__":
      start_app()
 
  import os
  import sys
+ import uuid
+ import shutil
  import time
+ import gradio as gr
  import torch
+ from diffusers import StableVideoDiffusionPipeline
+ from PIL import Image
  import numpy as np
+ import cv2
+ import subprocess
  import tempfile

+ class WanAnimateApp:
      def __init__(self):
+         model_name = "stabilityai/stable-video-diffusion-img2vid-xt"
          self.pipe = StableVideoDiffusionPipeline.from_pretrained(
+             model_name,
+             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+             variant="fp16",
+             device_map="cpu"
          )
+
+     def predict(
+         self,
+         ref_img,
+         video,
+         model_id,
+         model,
+     ):
+         if ref_img is None or video is None:
+             return None, "Upload both image and video."
+
+         try:
+             # Local processing — PIL for image (no open for type="pil")
+             if isinstance(ref_img, Image.Image):
+                 ref_image = ref_img.convert("RGB").resize((576, 320))
+             else:
+                 ref_image = Image.open(ref_img).convert("RGB").resize((576, 320))
+
+             cap = cv2.VideoCapture(video)
+             frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+             cap.release()
+             motion_hint = f" with dynamic motion from {frame_count} frames"
+
+             # Prompt based on mode
+             if model_id == "wan2.2-animate-move":
+                 prompt = f"Animate the character in the reference image{motion_hint}, high quality, smooth movements."
+             else:
+                 prompt = f"Replace the character in the video with the reference image{motion_hint}, seamless, detailed."
+
+             # Parameters
+             num_frames = 25 if model == "wan-pro" else 14
+             num_steps = 25 if model == "wan-pro" else 15
+
+             # Local generation
+             generator = torch.Generator(device="cpu").manual_seed(42)
+             output = self.pipe(
+                 ref_image,
+                 num_inference_steps=num_steps,
+                 num_frames=num_frames,
+                 generator=generator,
+                 decode_chunk_size=2
+             ).frames[0]
+
+             # Save MP4 with ffmpeg
+             temp_dir = tempfile.mkdtemp()
+             for i, frame in enumerate(output):
+                 frame.save(f"{temp_dir}/frame_{i:04d}.png")
+             temp_video = f"/tmp/output_{uuid.uuid4()}.mp4"
+             subprocess.run([
+                 'ffmpeg', '-y', '-framerate', '7', '-i', f"{temp_dir}/frame_%04d.png",
+                 '-c:v', 'libx264', '-pix_fmt', 'yuv420p', temp_video
+             ], check=True)
+             shutil.rmtree(temp_dir)
+
+             return temp_video, "SUCCEEDED"
+
+         except Exception as e:
+             return None, f"Failed: {str(e)}"
+
+ def start_app():
+     app = WanAnimateApp()
+
+     with gr.Blocks(title="Wan2.2-Animate (Local No API)") as demo:
          gr.HTML("""
+             <div style="padding: 2rem; text-align: center; max-width: 1200px; margin: 0 auto; font-family: Arial, sans-serif;">
+                 <h1 style="font-size: 2.5rem; font-weight: bold; margin-bottom: 0.5rem; color: #333;">
+                     Wan2.2-Animate: Unified Character Animation and Replacement with Holistic Replication
+                 </h1>
+                 <h3 style="font-size: 1.5rem; font-weight: bold; margin-bottom: 0.5rem; color: #333;">
+                     Local version without API (SVD Proxy)
+                 </h3>
+                 <div style="font-size: 1.25rem; margin-bottom: 1.5rem; color: #555;">
+                     Tongyi Lab, Alibaba
+                 </div>
+                 <div style="display: flex; flex-wrap: wrap; justify-content: center; gap: 1rem; margin-bottom: 1.5rem;">
+                     <a href="https://arxiv.org/abs/2509.14055" target="_blank" style="display: inline-flex; align-items: center; padding: 0.5rem 1rem; background-color: #f0f0f0; color: #333; text-decoration: none; border-radius: 9999px; font-weight: 500;">
+                         <span style="margin-right: 0.5rem;">📄</span>Paper
+                     </a>
+                     <a href="https://github.com/Wan-Video/Wan2.2" target="_blank" style="display: inline-flex; align-items: center; padding: 0.5rem 1rem; background-color: #f0f0f0; color: #333; text-decoration: none; border-radius: 9999px; font-weight: 500;">
+                         <span style="margin-right: 0.5rem;">💻</span>GitHub
+                     </a>
+                     <a href="https://huggingface.co/Wan-AI/Wan2.2-Animate-14B" target="_blank" style="display: inline-flex; align-items: center; padding: 0.5rem 1rem; background-color: #f0f0f0; color: #333; text-decoration: none; border-radius: 9999px; font-weight: 500;">
+                         <span style="margin-right: 0.5rem;">🤗</span>HF Model
+                     </a>
+                 </div>
+             </div>
+         """)
+
+         gr.HTML("""
+             <details>
+             <summary>‼️Usage</summary>
+             Wan-Animate supports two modes:
+             <ul>
+                 <li>Move Mode: use the movements extracted from the input video to drive the character in the input image</li>
+                 <li>Mix Mode: use the character in the input image to replace the character in the input video</li>
+             </ul>
+             Currently, the following restrictions apply to inputs:
+             <ul>
+                 <li>Video file size: less than 200 MB</li>
+                 <li>Video resolution: the shorter side must be greater than 200, and the longer side must be less than 2048</li>
+                 <li>Video duration: 2s to 30s</li>
+                 <li>Video aspect ratio: 1:3 to 3:1</li>
+                 <li>Video formats: mp4, avi, mov</li>
+                 <li>Image file size: less than 5 MB</li>
+                 <li>Image resolution: the shorter side must be greater than 200, and the longer side must be less than 4096</li>
+                 <li>Image formats: jpg, png, jpeg, webp, bmp</li>
+             </ul>
+             Currently, inference quality has two variants. You can use our open-source code for more flexible configuration.
+             <ul>
+                 <li>wan-pro: 25 fps, 720p</li>
+                 <li>wan-std: 15 fps, 720p</li>
+             </ul>
+             </details>
          """)

          with gr.Row():
+             with gr.Column():
+                 ref_img = gr.Image(
+                     label="Reference Image (изображение)",
+                     type="pil",
+                     sources=["upload"],
+                 )
+
+                 video = gr.Video(
+                     label="Template Video (шаблонное видео)",
+                     sources=["upload"],
+                 )
+
+                 with gr.Row():
+                     model_id = gr.Dropdown(
+                         label="Mode (режим)",
+                         choices=["wan2.2-animate-move", "wan2.2-animate-mix"],
+                         value="wan2.2-animate-move",
+                         info=""
+                     )

+                     model = gr.Dropdown(
+                         label="Inference Quality (качество)",
+                         choices=["wan-pro", "wan-std"],
+                         value="wan-pro",
+                     )

+                 run_button = gr.Button("Generate Video (генерировать)")

+             with gr.Column():
+                 output_video = gr.Video(label="Output Video (результат)")
+                 output_status = gr.Textbox(label="Status (статус)")
+
+         run_button.click(
+             fn=app.predict,
+             inputs=[
+                 ref_img,
+                 video,
+                 model_id,
+                 model,
+             ],
+             outputs=[output_video, output_status],
+         )

+     demo.queue(default_concurrency_limit=1)
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860
+     )

  if __name__ == "__main__":
      start_app()
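
Note on the pipeline call: StableVideoDiffusionPipeline is image-conditioned only, so the prompt string built in the new predict() is never passed to the model; motion is steered through parameters such as motion_bucket_id and noise_aug_strength, which the removed UI exposed as sliders. A minimal sketch (not part of this commit; the input path and concrete values are illustrative) of the same call outside Gradio:

# Sketch only: the SVD parameters the old UI exposed, used directly with diffusers.
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.enable_model_cpu_offload()  # offload submodules to CPU between steps to save VRAM

# SVD-XT was trained around 1024x576 (the removed code's resize); the new code resizes
# to 576x320, presumably to reduce memory use.
image = load_image("input.png").resize((1024, 576))

generator = torch.manual_seed(42)  # Seed
frames = pipe(
    image,
    num_frames=25,            # Num Frames
    fps=7,                    # FPS conditioning signal
    decode_chunk_size=8,      # Decode Chunk Size: lower uses less VRAM during decoding
    motion_bucket_id=127,     # Motion Bucket ID: 1-255, higher means more motion
    noise_aug_strength=0.02,  # Noise Aug Strength: conditioning noise for variation
    generator=generator,
).frames[0]

export_to_video(frames, "output.mp4", fps=7)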
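The new predict() saves every frame as a PNG and shells out to ffmpeg. diffusers ships export_to_video, which the removed code already imported and used; the sketch below (the helper name frames_to_mp4 is hypothetical, and frames is assumed to be the list of PIL images returned by the pipeline) avoids the subprocess call and the manual temp-directory cleanup:

# Sketch only: write the generated frames to an MP4 with diffusers' helper.
import tempfile
from diffusers.utils import export_to_video

def frames_to_mp4(frames, fps=7):
    # Write the PIL frames to a temporary MP4 file and return its path.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        export_to_video(frames, tmp.name, fps=fps)
    return tmp.name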