import spaces

import os
import time
import datetime
import shutil
import subprocess
from pathlib import Path

import torch
import gradio as gr
from moviepy.editor import VideoFileClip
from huggingface_hub import snapshot_download

# Keep model weights on persistent storage when it exists (/data on HF Spaces),
# otherwise fall back to a local ./weights directory.
default_weights_root = Path("/data/weights") if Path("/data").exists() else Path("./weights")
WEIGHTS_ROOT = Path(os.environ.get("DIFFUERASER_WEIGHTS_ROOT", default_weights_root))
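# The weights root can also be overridden at launch, e.g. (path shown is only an
# illustrative example): DIFFUERASER_WEIGHTS_ROOT=/mnt/models/diffueraser python app.py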

# Create the expected checkpoint subdirectories up front.
subfolders = [
    "diffuEraser",
    "stable-diffusion-v1-5",
    "PCM_Weights",
    "propainter",
    "sd-vae-ft-mse",
]
for subfolder in subfolders:
    (WEIGHTS_ROOT / subfolder).mkdir(parents=True, exist_ok=True)

# Parts of the repo code resolve checkpoints relative to ./weights, so point that
# path at WEIGHTS_ROOT whenever the two differ.
workspace_weights = Path("./weights")
try:
    if WEIGHTS_ROOT.resolve() != workspace_weights.resolve():
        # Remove any stale file, directory, or (possibly broken) symlink first.
        if workspace_weights.is_symlink() or workspace_weights.exists():
            if workspace_weights.is_symlink() or workspace_weights.is_file():
                workspace_weights.unlink()
            else:
                shutil.rmtree(workspace_weights)
        workspace_weights.symlink_to(WEIGHTS_ROOT, target_is_directory=True)
except FileNotFoundError:
    pass

# Fetch all checkpoints into the weights root.
snapshot_download(repo_id="lixiaowen/diffuEraser", local_dir=str(WEIGHTS_ROOT / "diffuEraser"))
snapshot_download(repo_id="stable-diffusion-v1-5/stable-diffusion-v1-5", local_dir=str(WEIGHTS_ROOT / "stable-diffusion-v1-5"))
snapshot_download(repo_id="wangfuyun/PCM_Weights", local_dir=str(WEIGHTS_ROOT / "PCM_Weights"))
snapshot_download(repo_id="camenduru/ProPainter", local_dir=str(WEIGHTS_ROOT / "propainter"))
snapshot_download(repo_id="stabilityai/sd-vae-ft-mse", local_dir=str(WEIGHTS_ROOT / "sd-vae-ft-mse"))
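# Note: with a persistent weights root (e.g. /data on Spaces), snapshot_download
# should skip files that are already present and up to date, so restarts avoid
# re-downloading the full checkpoints.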

# Model wrappers from the local DiffuEraser and ProPainter packages.
from diffueraser.diffueraser import DiffuEraser
from propainter.inference import Propainter, get_device

# Checkpoint locations.
base_model_path = str(WEIGHTS_ROOT / "stable-diffusion-v1-5")
vae_path = str(WEIGHTS_ROOT / "sd-vae-ft-mse")
diffueraser_path = str(WEIGHTS_ROOT / "diffuEraser")
propainter_model_dir = str(WEIGHTS_ROOT / "propainter")

# Instantiate both models once at startup so they are reused across requests.
device = get_device()
ckpt = "2-Step"  # PCM checkpoint variant (number of sampling steps)
video_inpainting_sd = DiffuEraser(device, base_model_path, vae_path, diffueraser_path, ckpt=ckpt)
propainter = Propainter(propainter_model_dir, device=device)

def trim_video(input_path, output_path, max_duration=120):
    """Copy at most max_duration seconds of input_path to output_path without re-encoding."""
    # Measure the source duration with moviepy, then let ffmpeg do the cut.
    clip = VideoFileClip(input_path)
    duration = min(max_duration, clip.duration)
    clip.close()

    subprocess.run(
        [
            "ffmpeg",
            "-hide_banner",
            "-loglevel", "error",
            "-y",
            "-ss", "0",
            "-i", input_path,
            "-t", f"{duration:.6f}",
            "-c", "copy",  # stream copy: fast, no quality loss
            output_path,
        ],
        check=True,
    )
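# Note on trim_video: because it stream-copies ("-c copy"), ffmpeg cuts on packet/keyframe
# boundaries, so the trimmed video and mask may not end on exactly the same frame.
# Re-encoding (e.g. "-c:v libx264") would give frame-accurate cuts at the cost of speed.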

@spaces.GPU(duration=1200)  # request a ZeroGPU slot for up to 1200 s per call
def infer(input_video, input_mask):
    # Inference settings.
    save_path = "results"
    mask_dilation_iter = 8
    max_img_size = 1280
    ref_stride = 10
    neighbor_length = 10
    subvideo_length = 50

    os.makedirs(save_path, exist_ok=True)

    # Timestamped output names so successive runs do not overwrite each other.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    trimmed_video_path = os.path.join(save_path, f"trimmed_video_{timestamp}.mp4")
    trimmed_mask_path = os.path.join(save_path, f"trimmed_mask_{timestamp}.mp4")
    priori_path = os.path.join(save_path, f"priori_{timestamp}.mp4")
    output_path = os.path.join(save_path, f"diffueraser_result_{timestamp}.mp4")

    # Cap both inputs at two minutes to keep inference within the GPU time budget.
    trim_video(input_video, trimmed_video_path)
    trim_video(input_mask, trimmed_mask_path)

    # Approximate frame count of the trimmed clip, assuming 30 fps.
    clip = VideoFileClip(trimmed_video_path)
    video_duration = clip.duration
    clip.close()
    video_length = int(video_duration * 30)

    start_time = time.time()

    # Stage 1: ProPainter produces the prior (a first-pass inpainted video).
    propainter.forward(trimmed_video_path, trimmed_mask_path, priori_path,
                       video_length=video_length, ref_stride=ref_stride,
                       neighbor_length=neighbor_length, subvideo_length=subvideo_length,
                       mask_dilation=mask_dilation_iter)

    # Stage 2: DiffuEraser refines the prior with the diffusion model.
    guidance_scale = None  # None -> use the model's default guidance
    video_inpainting_sd.forward(trimmed_video_path, trimmed_mask_path, priori_path, output_path,
                                max_img_size=max_img_size, video_length=video_length,
                                mask_dilation_iter=mask_dilation_iter,
                                guidance_scale=guidance_scale)

    end_time = time.time()
    print(f"DiffuEraser inference time: {end_time - start_time:.2f} seconds")

    # Release cached GPU memory before handing the result back to Gradio.
    torch.cuda.empty_cache()
    return output_path

with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# DiffuEraser: A Diffusion Model for Video Inpainting")
        gr.Markdown(
            "DiffuEraser is a diffusion model for video inpainting. It outperforms the "
            "state-of-the-art model ProPainter in both content completeness and temporal "
            "consistency while maintaining acceptable efficiency."
        )
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href="https://github.com/lixiaowen-xw/DiffuEraser">
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a>
            <a href="https://lixiaowen-xw.github.io/DiffuEraser-page">
                <img src='https://img.shields.io/badge/Project-Page-green'>
            </a>
            <a href="https://lixiaowen-xw.github.io/DiffuEraser-page">
                <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
            </a>
            <a href="https://huggingface.co/spaces/fffiloni/DiffuEraser-demo?duplicate=true">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
            </a>
        </div>
        """)

        with gr.Row():
            with gr.Column():
                input_video = gr.Video(label="Input Video (MP4 ONLY)")
                input_mask = gr.Video(label="Input Mask Video (MP4 ONLY)")
                submit_btn = gr.Button("Submit")

            with gr.Column():
                video_result = gr.Video(label="Result")
                gr.Examples(
                    examples=[
                        ["./examples/example1/video.mp4", "./examples/example1/mask.mp4"],
                        ["./examples/example2/video.mp4", "./examples/example2/mask.mp4"],
                        ["./examples/example3/video.mp4", "./examples/example3/mask.mp4"],
                    ],
                    inputs=[input_video, input_mask],
                )

    submit_btn.click(fn=infer, inputs=[input_video, input_mask], outputs=[video_result])

demo.queue().launch(show_api=True, show_error=True, ssr_mode=False)