import spaces
import torch
import os
import time
import datetime
import shutil
from pathlib import Path
from moviepy.editor import VideoFileClip
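# NOTE: moviepy<2.0 is assumed here; MoviePy 2.x removed the moviepy.editor module.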
import subprocess
import gradio as gr
# Download Weights
from huggingface_hub import snapshot_download
# Decide where to cache weights: prefer the persistent disk (/data) when available,
# falling back to the local workspace otherwise.
default_weights_root = Path("/data/weights") if Path("/data").exists() else Path("./weights")
WEIGHTS_ROOT = Path(os.environ.get("DIFFUERASER_WEIGHTS_ROOT", default_weights_root))
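# Set DIFFUERASER_WEIGHTS_ROOT to relocate the cache without editing this file.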
# List of subdirectories to create inside the weights root
subfolders = [
"diffuEraser",
"stable-diffusion-v1-5",
"PCM_Weights",
"propainter",
"sd-vae-ft-mse",
]
for subfolder in subfolders:
(WEIGHTS_ROOT / subfolder).mkdir(parents=True, exist_ok=True)
# Make sure legacy code that references ./weights still works by linking it to the persistent cache root.
workspace_weights = Path("./weights")
try:
if WEIGHTS_ROOT.resolve() != workspace_weights.resolve():
if workspace_weights.exists():
if workspace_weights.is_symlink() or workspace_weights.is_file():
workspace_weights.unlink()
else:
shutil.rmtree(workspace_weights)
workspace_weights.symlink_to(WEIGHTS_ROOT, target_is_directory=True)
except OSError:
    # Symlink creation can fail (read-only filesystems, existing non-empty targets);
    # code that builds paths from WEIGHTS_ROOT still works, so skip the legacy link.
    pass
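# Fetch every checkpoint; snapshot_download skips files already present in local_dir.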
snapshot_download(repo_id="lixiaowen/diffuEraser", local_dir=str(WEIGHTS_ROOT / "diffuEraser"))
snapshot_download(repo_id="stable-diffusion-v1-5/stable-diffusion-v1-5", local_dir=str(WEIGHTS_ROOT / "stable-diffusion-v1-5"))
snapshot_download(repo_id="wangfuyun/PCM_Weights", local_dir=str(WEIGHTS_ROOT / "PCM_Weights"))
snapshot_download(repo_id="camenduru/ProPainter", local_dir=str(WEIGHTS_ROOT / "propainter"))
snapshot_download(repo_id="stabilityai/sd-vae-ft-mse", local_dir=str(WEIGHTS_ROOT / "sd-vae-ft-mse"))
# Import model classes
from diffueraser.diffueraser import DiffuEraser
from propainter.inference import Propainter, get_device
base_model_path = str(WEIGHTS_ROOT / "stable-diffusion-v1-5")
vae_path = str(WEIGHTS_ROOT / "sd-vae-ft-mse")
diffueraser_path = str(WEIGHTS_ROOT / "diffuEraser")
propainter_model_dir = str(WEIGHTS_ROOT / "propainter")
# Model setup
device = get_device()
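# "2-Step" selects the 2-step PCM (Phased Consistency Model) checkpoint for fast sampling.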
ckpt = "2-Step"
video_inpainting_sd = DiffuEraser(device, base_model_path, vae_path, diffueraser_path, ckpt=ckpt)
propainter = Propainter(propainter_model_dir, device=device)
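# Both models are instantiated once at startup and reused for every request.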
# Helper to trim videos: cap at 120 s so longer uploads are truncated rather than rejected
def trim_video(input_path, output_path, max_duration=120):
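    """Copy the first max_duration seconds of input_path to output_path without re-encoding."""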
clip = VideoFileClip(input_path)
duration = min(max_duration, clip.duration)
clip.close()
# Preserve original encoding to avoid reintroducing H.264 macroblocking artefacts
subprocess.run(
[
"ffmpeg",
"-hide_banner",
"-loglevel",
"error",
"-y",
"-ss",
"0",
"-i",
input_path,
"-t",
f"{duration:.6f}",
"-c",
"copy",
output_path,
],
check=True,
)
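# ZeroGPU: the decorator below reserves a GPU for at most 1200 s (20 min) per call.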
@spaces.GPU(duration=1200)
def infer(input_video, input_mask):
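    """Erase the masked region: ProPainter builds a coarse prior, DiffuEraser refines it."""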
# Setup paths and parameters
save_path = "results"
    mask_dilation_iter = 8   # dilate masks to cover soft edges around the object
    max_img_size = 1280      # longest processed side; larger inputs are downscaled
    ref_stride = 10          # ProPainter: stride between global reference frames
    neighbor_length = 10     # ProPainter: local neighbouring frames per window
    subvideo_length = 50     # ProPainter: frames per sub-clip, bounds GPU memory
    os.makedirs(save_path, exist_ok=True)
# Timestamp for unique filenames
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
trimmed_video_path = os.path.join(save_path, f"trimmed_video_{timestamp}.mp4")
trimmed_mask_path = os.path.join(save_path, f"trimmed_mask_{timestamp}.mp4")
priori_path = os.path.join(save_path, f"priori_{timestamp}.mp4")
output_path = os.path.join(save_path, f"diffueraser_result_{timestamp}.mp4")
# Trim input videos
trim_video(input_video, trimmed_video_path)
trim_video(input_mask, trimmed_mask_path)
    # Compute video_length (in frames) from the trimmed clip's duration and frame rate
    clip = VideoFileClip(trimmed_video_path)
    video_duration = clip.duration
    fps = clip.fps or 30  # fall back to 30 fps if the container reports none
    clip.close()
    video_length = int(video_duration * fps)
# Run models
start_time = time.time()
    # Stage 1: ProPainter produces a coarse inpainting used as the prior ("priori")
propainter.forward(trimmed_video_path, trimmed_mask_path, priori_path,
video_length=video_length, ref_stride=ref_stride,
neighbor_length=neighbor_length, subvideo_length=subvideo_length,
mask_dilation=mask_dilation_iter)
    # Stage 2: DiffuEraser refines the prior with the diffusion model
    guidance_scale = None  # None defers to DiffuEraser's internal default
video_inpainting_sd.forward(trimmed_video_path, trimmed_mask_path, priori_path, output_path,
max_img_size=max_img_size, video_length=video_length,
mask_dilation_iter=mask_dilation_iter,
guidance_scale=guidance_scale)
end_time = time.time()
print(f"DiffuEraser inference time: {end_time - start_time:.2f} seconds")
    torch.cuda.empty_cache()  # release cached GPU memory between requests
return output_path
# Gradio interface
with gr.Blocks() as demo:
with gr.Column():
gr.Markdown("# DiffuEraser: A Diffusion Model for Video Inpainting")
gr.Markdown("DiffuEraser is a diffusion model for video inpainting, which outperforms state-of-the-art model ProPainter in both content completeness and temporal consistency while maintaining acceptable efficiency.")
gr.HTML("""
<div style="display:flex;column-gap:4px;">
<a href="https://github.com/lixiaowen-xw/DiffuEraser">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
<a href="https://lixiaowen-xw.github.io/DiffuEraser-page">
<img src='https://img.shields.io/badge/Project-Page-green'>
</a>
<a href="https://lixiaowen-xw.github.io/DiffuEraser-page">
<img src='https://img.shields.io/badge/ArXiv-Paper-red'>
</a>
<a href="https://huggingface.co/spaces/fffiloni/DiffuEraser-demo?duplicate=true">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
</a>
</div>
""")
with gr.Row():
with gr.Column():
input_video = gr.Video(label="Input Video (MP4 ONLY)")
input_mask = gr.Video(label="Input Mask Video (MP4 ONLY)")
submit_btn = gr.Button("Submit")
with gr.Column():
video_result = gr.Video(label="Result")
gr.Examples(
examples=[
["./examples/example1/video.mp4", "./examples/example1/mask.mp4"],
["./examples/example2/video.mp4", "./examples/example2/mask.mp4"],
["./examples/example3/video.mp4", "./examples/example3/mask.mp4"],
],
inputs=[input_video, input_mask]
)
submit_btn.click(fn=infer, inputs=[input_video, input_mask], outputs=[video_result])
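# queue() makes concurrent requests wait for the single GPU worker; show_error surfaces exceptions in the UI.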
demo.queue().launch(show_api=True, show_error=True, ssr_mode=False)