# sdmklgdfmkl / app.py
# matthewkram — "Update app.py" (commit 4b86371, verified)
import os
import torch
import gradio as gr
from diffusers import StableVideoDiffusionPipeline
from PIL import Image
import numpy as np
import cv2
import tempfile
from diffusers.utils import export_to_video
# Lazily-created StableVideoDiffusionPipeline shared by every request; None until load() succeeds.
pipe = None


def load():
    """Return the shared SVD-XT pipeline, loading it onto the GPU on first use.

    The pipeline is built into a local variable and only published to the
    module-level ``pipe`` after it has been moved to CUDA. If loading or the
    device move raises (for example, out of GPU memory), ``pipe`` stays None
    and the next call retries. A CPU-resident pipeline is never cached.

    Returns:
        The fp16 ``StableVideoDiffusionPipeline`` on the "cuda" device.
    """
    global pipe
    if pipe is None:
        model = StableVideoDiffusionPipeline.from_pretrained(
            "stabilityai/stable-video-diffusion-img2vid-xt",
            torch_dtype=torch.float16,
            variant="fp16",
        )
        model.to("cuda")
        pipe = model  # publish only once the model is fully on the GPU
        gr.Info("Модель на GPU — генерация 30–60 сек")
    return pipe
def _video_frame_hint(video):
    """Return the frame-count suffix for the status message, or "" if *video* is missing or unreadable."""
    if not video:
        return ""
    cap = cv2.VideoCapture(video)
    try:
        if not cap.isOpened():
            return ""
        n = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()  # always free the capture handle, even if get() raises
    return f" ({n} кадров)"


def run(ref_img, video, mode, quality, prog=gr.Progress()):
    """Animate a reference photo with Stable Video Diffusion and return an MP4.

    Generation is driven only by ``ref_img``. The uploaded ``video`` is read
    solely to report its frame count in the status message.

    Args:
        ref_img: Reference photo as a NumPy array (the UI uses
            ``gr.Image(type="numpy")``).
        video: Path of the uploaded video, or None if none was uploaded.
        mode: ``"wan2.2-animate-mix"`` selects noise augmentation 0.1; any other
            value uses 0.02.
        quality: ``"wan-pro"`` uses 25 steps / 25 frames; any other value uses
            15 steps / 14 frames.
        prog: Gradio progress tracker injected by the framework.

    Returns:
        Tuple ``(path to the generated .mp4, status message)``.

    Raises:
        gr.Error: if no reference photo was provided.
    """
    # Validate before loading the (slow) model so the user gets fast feedback.
    if ref_img is None:
        raise gr.Error("Загрузи фото")

    pipe = load()
    prog(0, desc="Подготовка...")

    # Generation resolution; passed to the pipeline too, otherwise it would
    # upsample this downscaled image back to its own default size.
    width, height = 576, 320
    img = Image.fromarray(ref_img).convert("RGB").resize((width, height))
    hint = _video_frame_hint(video)

    pro = quality == "wan-pro"
    steps = 25 if pro else 15
    frames = 25 if pro else 14
    noise = 0.1 if mode == "wan2.2-animate-mix" else 0.02

    def on_step_end(pipeline, step, timestep, callback_kwargs):
        # diffusers invokes this as (pipeline, step_index, timestep, callback_kwargs)
        # and expects the callback_kwargs dict back (it pops "latents" from it).
        prog((step + 1) / steps, desc=f"Шаг {step+1}/{steps}")
        return callback_kwargs

    prog(0.1, desc="Генерация...")
    out = pipe(
        img,
        height=height,
        width=width,
        num_inference_steps=steps,
        num_frames=frames,
        decode_chunk_size=2,
        noise_aug_strength=noise,
        callback_on_step_end=on_step_end,
    ).frames[0]

    # Close the handle immediately; delete=False keeps the file for Gradio to serve.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp:
        path = tmp.name
    export_to_video(out, path, fps=7)
    return path, "Готово!" + hint
# Two-column UI: inputs (photo, video, mode, quality) on the left, result on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Wan2.2-Animate (GPU)")
    with gr.Accordion("Инструкция", open=False):
        gr.Markdown("Загрузи фото + видео → выбери режим → жми Generate")
    with gr.Row():
        with gr.Column():
            img = gr.Image(label="Фото", type="numpy")
            vid = gr.Video(label="Видео")
            with gr.Row():
                mode = gr.Dropdown(
                    ["wan2.2-animate-move", "wan2.2-animate-mix"],
                    label="Режим",
                    value="wan2.2-animate-move",
                )
                qual = gr.Dropdown(["wan-pro", "wan-std"], label="Качество", value="wan-pro")
            btn = gr.Button("Generate Video")
        with gr.Column():
            out = gr.Video(label="Результат")
            stat = gr.Textbox(label="Статус")
    # Event listeners must be registered inside the Blocks context.
    btn.click(run, [img, vid, mode, qual], [out, stat])

# queue() already enables request queuing (at most 2 waiting requests).
# launch() no longer accepts `enable_queue` in Gradio 4+ (it raised a TypeError),
# so it is not passed here.
demo.queue(max_size=2).launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=True,  # public share link (marked "FIX 1" by the original author)
)