import os
import cv2
import torch
import spaces
import imageio
import numpy as np
import gradio as gr

# Monkeypatch: make torch.jit.script a no-op so decorated model code runs
# eagerly instead of being compiled by TorchScript.
torch.jit.script = lambda f: f

import argparse
from utils.batch_inference import (
    BSRInferenceLoop, BIDInferenceLoop
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def get_example(task):
    case = {
        "dn": [
            ['examples/bus.mp4'],
            ['examples/koala.mp4'],
            ['examples/flamingo.mp4'],
            ['examples/rhino.mp4'],
            ['examples/elephant.mp4'],
            ['examples/sheep.mp4'],
            ['examples/dog-agility.mp4'],
        ],
        "sr": [
            ['examples/bus_sr.mp4'],
            ['examples/koala_sr.mp4'],
            ['examples/flamingo_sr.mp4'],
            ['examples/rhino_sr.mp4'],
            ['examples/elephant_sr.mp4'],
            ['examples/sheep_sr.mp4'],
            ['examples/dog-agility_sr.mp4'],
        ],
    }
    return case[task]
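
# Illustrative usage:
#   get_example("sr")  # -> [['examples/bus_sr.mp4'], ['examples/koala_sr.mp4'], ...]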


def update_prompt(input_video):
    # NOTE: `set_default_prompt` is not defined in this file and is assumed to
    # be provided elsewhere in the project; this helper is not wired into the
    # UI below.
    video_name = input_video.split('/')[-1]
    return set_default_prompt(video_name)


# Pre-extracted frame directories for the bundled example videos, so the demo
# does not re-run frame extraction for them.
video_to_image = {
    'bus.mp4': ['examples_frames/bus'],
    'koala.mp4': ['examples_frames/koala'],
    'dog-gooses.mp4': ['examples_frames/dog-gooses'],
    'flamingo.mp4': ['examples_frames/flamingo'],
    'rhino.mp4': ['examples_frames/rhino'],
    'elephant.mp4': ['examples_frames/elephant'],
    'sheep.mp4': ['examples_frames/sheep'],
    'dog-agility.mp4': ['examples_frames/dog-agility'],

    'bus_sr.mp4': ['examples_frames/bus_sr'],
    'koala_sr.mp4': ['examples_frames/koala_sr'],
    # Note the underscore in 'dog_gooses_sr'; assumed to match the frame
    # directory name on disk.
    'dog-gooses_sr.mp4': ['examples_frames/dog_gooses_sr'],
    'flamingo_sr.mp4': ['examples_frames/flamingo_sr'],
    'rhino_sr.mp4': ['examples_frames/rhino_sr'],
    'elephant_sr.mp4': ['examples_frames/elephant_sr'],
    'sheep_sr.mp4': ['examples_frames/sheep_sr'],
    'dog-agility_sr.mp4': ['examples_frames/dog-agility_sr'],
}


def images_to_video(image_list, output_path, fps=10):
    frames = [np.array(img).astype(np.uint8) for img in image_list]
    # Cap the output at 20 frames to keep demo encoding fast.
    frames = frames[:20]

    writer = imageio.get_writer(output_path, fps=fps, codec='libx264')
    for frame in frames:
        writer.append_data(frame)
    writer.close()
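
# Illustrative usage with hypothetical in-memory frames: writes a short
# H.264 preview clip.
#   preview = [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(5)]
#   images_to_video(preview, 'preview.mp4', fps=10)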


def video2frames(video_path):
    # Dump every frame of the video as JPEGs into a sibling directory named
    # after the file (e.g. 'clip.mp4' -> 'clip/00000.jpg', 'clip/00001.jpg', ...).
    video = cv2.VideoCapture(video_path)

    img_path = video_path[:-4]  # strip the file extension

    frame_count = 0
    os.makedirs(img_path, exist_ok=True)

    while True:
        ret, frame = video.read()
        if not ret:
            break

        frame_file = f"{img_path}/{frame_count:05}.jpg"
        cv2.imwrite(frame_file, frame)
        frame_count += 1

    video.release()
    return img_path
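
# Illustrative usage: extracting frames from a bundled clip yields a directory
# of JPEGs next to the video.
#   frames_dir = video2frames('examples/bus.mp4')  # -> 'examples/bus'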


@spaces.GPU(duration=120)
def DiffBIR_restore(input_video, prompt, sr_ratio, n_frames, n_steps, guidance_scale, seed, n_prompt, task):
    # Reuse pre-extracted frames for the bundled examples; otherwise extract
    # frames from the uploaded video.
    video_name = input_video.split('/')[-1]
    if video_name in video_to_image:
        frames_path = video_to_image[video_name][0]
    else:
        frames_path = video2frames(input_video)

    print(f"[INFO] input_video: {input_video}")
    print(f"[INFO] Frames path: {frames_path}")

    # Assemble the arguments expected by the batch-inference loops.
    args = argparse.Namespace()

    args.task = task
    args.upscale = sr_ratio

    args.steps = n_steps
    args.better_start = True
    args.tiled = False  # tiling disabled; tile_size/stride below are unused
    args.tile_size = 512
    args.tile_stride = 256
    args.pos_prompt = prompt
    args.neg_prompt = n_prompt
    args.cfg_scale = guidance_scale

    args.input = frames_path
    args.n_samples = 1
    args.batch_size = 10
    args.final_size = (480, 854)
    args.config = "configs/inference/my_cldm.yaml"

    # Restoration guidance is disabled (guidance=False, g_scale=0.0), so the
    # remaining g_* fields are effectively unused.
    args.guidance = False
    args.g_loss = "w_mse"
    args.g_scale = 0.0
    args.g_start = 1001
    args.g_stop = -1
    args.g_space = "latent"
    args.g_repeat = 1

    args.output = " "  # placeholder; the loop returns the restored video path

    args.seed = seed
    args.device = "cuda"

    args.n_frames = n_frames

    # Temporal-consistency settings: latent-warping, merging, and token-merging
    # (ToMe) schedules over the denoising process.
    args.warp_period = [0, 0.1]
    args.merge_period = [0, 0]
    args.ToMe_period = [0, 1]
    args.merge_ratio = [0.6, 0]

    if args.task == "sr":
        restored_vid_path = BSRInferenceLoop(args).run()
    elif args.task == "dn":
        restored_vid_path = BIDInferenceLoop(args).run()

    torch.cuda.empty_cache()
    return restored_vid_path
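
# Illustrative programmatic call that bypasses the UI; values mirror the
# Gradio defaults below, and the seed is arbitrary.
#   restored = DiffBIR_restore('examples/bus_sr.mp4', prompt='', sr_ratio=4,
#                              n_frames=10, n_steps=5, guidance_scale=4.0,
#                              seed=42, n_prompt='', task='sr')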


intro = """
<div style="text-align:center">
<h1 style="font-weight: 1400; text-align: center; margin-bottom: 7px;">
    DiffIR2VR - <small>Zero-Shot Video Restoration</small>
</h1>
<span>[<a target="_blank" href="https://jimmycv07.github.io/DiffIR2VR_web/">Project page</a>] [<a target="_blank" href="https://huggingface.co/papers/2406.06523">arXiv</a>]</span>
<div style="display:flex; justify-content: center;margin-top: 0.5em">Note that this page is a limited demo of DiffIR2VR.
For more configurations, please visit our GitHub page. The code will be released soon!</div>
<div style="display:flex; justify-content: center;margin-top: 0.5em; color: red;">For super-resolution,
it is recommended that the final frame size (original size * upscale ratio) be around 480x854;
otherwise the demo may fail due to lengthy inference times.</div>
</div>
"""


with gr.Blocks(css="style.css") as demo:

    gr.HTML(intro)

    with gr.Tab(label="Super-resolution with DiffBIR"):
        with gr.Row():
            input_video = gr.Video(label="Input Video")
            output_video = gr.Video(label="Restored Video", interactive=False)

        with gr.Row():
            run_button = gr.Button("Restore your video!", visible=True)

        with gr.Accordion('Advanced options', open=False):
            prompt = gr.Textbox(
                label="Prompt",
                max_lines=1,
                placeholder="describe your video content"
            )
            sr_ratio = gr.Slider(label='Upscale ratio',
                                 minimum=1,
                                 maximum=16,
                                 value=4,
                                 step=0.5)
            n_frames = gr.Slider(label='Frames',
                                 minimum=1,
                                 maximum=60,
                                 value=10,
                                 step=1)
            n_steps = gr.Slider(label='Steps',
                                minimum=1,
                                maximum=100,
                                value=5,
                                step=1)
            guidance_scale = gr.Slider(label='Guidance Scale',
                                       minimum=0.1,
                                       maximum=30.0,
                                       value=4.0,
                                       step=0.1)
            seed = gr.Slider(label='Seed',
                             minimum=-1,
                             maximum=1000,
                             step=1,
                             randomize=True)
            n_prompt = gr.Textbox(
                label='Negative Prompt',
                value="low quality, blurry, low-resolution, noisy, unsharp, weird textures"
            )
            # Hidden field routing this tab to the super-resolution task.
            task = gr.Textbox(value="sr", visible=False)

        run_button.click(fn=DiffBIR_restore,
                         inputs=[input_video,
                                 prompt,
                                 sr_ratio,
                                 n_frames,
                                 n_steps,
                                 guidance_scale,
                                 seed,
                                 n_prompt,
                                 task],
                         outputs=[output_video])
        gr.Examples(
            examples=get_example("sr"),
            label='Examples',
            inputs=[input_video],
            outputs=[output_video],
            examples_per_page=7
        )

    with gr.Tab(label="Denoise with DiffBIR"):
        with gr.Row():
            input_video = gr.Video(label="Input Video")
            output_video = gr.Video(label="Restored Video", interactive=False)

        with gr.Row():
            run_button = gr.Button("Restore your video!", visible=True)

        with gr.Accordion('Advanced options', open=False):
            prompt = gr.Textbox(
                label="Prompt",
                max_lines=1,
                placeholder="describe your video content"
            )
            n_frames = gr.Slider(label='Frames',
                                 minimum=1,
                                 maximum=60,
                                 value=10,
                                 step=1)
            n_steps = gr.Slider(label='Steps',
                                minimum=1,
                                maximum=100,
                                value=5,
                                step=1)
            guidance_scale = gr.Slider(label='Guidance Scale',
                                       minimum=0.1,
                                       maximum=30.0,
                                       value=4.0,
                                       step=0.1)
            seed = gr.Slider(label='Seed',
                             minimum=-1,
                             maximum=1000,
                             step=1,
                             randomize=True)
            n_prompt = gr.Textbox(
                label='Negative Prompt',
                value="low quality, blurry, low-resolution, noisy, unsharp, weird textures"
            )
            # Hidden fields: this tab always runs denoising ("dn") at 1x scale.
            task = gr.Textbox(value="dn", visible=False)
            sr_ratio = gr.Number(value=1, visible=False)

        run_button.click(fn=DiffBIR_restore,
                         inputs=[input_video,
                                 prompt,
                                 sr_ratio,
                                 n_frames,
                                 n_steps,
                                 guidance_scale,
                                 seed,
                                 n_prompt,
                                 task],
                         outputs=[output_video])
        gr.Examples(
            examples=get_example("dn"),
            label='Examples',
            inputs=[input_video],
            outputs=[output_video],
            examples_per_page=7
        )

demo.queue()
demo.launch()