|
|
import gradio as gr |
|
|
from PIL import Image |
|
|
import torch |
|
|
import torchvision.transforms as transforms |
|
|
import torch.nn.functional as F |
|
|
import os |
|
|
import glob |
|
|
|
|
|
from archs import create_model, load_model |
|
|
|
|
|
|
|
|
# Lower-case image file extensions recognised by the example-gallery helpers
# (compared against lower-cased file names with str.endswith).
IMG_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".webp")


def list_subfolders(base="examples"):
    """Return the sorted names of the immediate sub-directories of ``base``.

    Args:
        base: directory to scan (default ``"examples"``).

    Returns:
        A sorted list of sub-directory names (not full paths), or ``[]`` when
        ``base`` does not exist or is not a directory.
    """
    if not os.path.isdir(base):
        return []
    names = sorted(os.listdir(base))
    return [name for name in names if os.path.isdir(os.path.join(base, name))]


def list_images(folder, base="examples"):
    """Return the sorted full paths of the image files inside ``<base>/<folder>``.

    Extensions are matched case-insensitively against ``IMG_EXTS``. Entries that
    are not regular files (e.g. a directory named ``foo.png``) are skipped,
    because they cannot be opened as images.

    Args:
        folder: sub-folder name under ``base``.
        base: root directory holding the example sub-folders. Defaults to
            ``"examples"``, the same default as ``list_subfolders``.

    Returns:
        A sorted list of paths of the form ``<base>/<folder>/<file>``; empty if
        the folder does not exist or contains no images.
    """
    pattern = os.path.join(base, folder, "*")
    return [
        path
        for path in sorted(glob.glob(pattern))
        if path.lower().endswith(IMG_EXTS) and os.path.isfile(path)
    ]
|
|
|
|
|
|
|
|
def update_gallery(folder):
    """Refresh the example gallery for the selected examples sub-folder.

    Args:
        folder: name of a sub-folder of ``examples/`` as chosen in the folder
            radio, or ``None``/empty when the selection is cleared.

    Returns:
        A ``(gallery_update, files)`` pair. ``gallery_update`` shows the gallery
        filled with the folder's image paths. ``files`` is the same list, which
        is stored in a ``gr.State`` so a gallery click index can be mapped back
        to a path.
    """
    if not folder:
        # A cleared selection carries no folder name; os.path.join(..., None)
        # would raise TypeError, so hide the gallery and reset the stored list.
        return gr.update(value=[], visible=False), []
    files = list_images(folder)
    print(files)
    return gr.update(value=files, visible=True), files
|
|
|
|
|
def load_from_gallery(evt: gr.SelectData, current_files):
    """Load the clicked gallery image into the input image component.

    Args:
        evt: selection event supplied by Gradio for a ``gallery.select``
            handler; ``evt.index`` is the position of the clicked thumbnail.
        current_files: list of image paths currently shown in the gallery.

    Returns:
        A fully decoded PIL image for the clicked entry, or ``gr.update()``
        (leave the input unchanged) when the index does not map to a file.
    """
    idx = evt.index
    # Reject missing, negative and out-of-range indices. A negative index would
    # otherwise silently select an image from the end of the list.
    if not current_files or idx is None or not 0 <= idx < len(current_files):
        return gr.update()
    path = current_files[idx]
    print(path)
    # Image.open is lazy and keeps the file handle open until the pixels are
    # read. copy() forces a full decode into a detached image, so the context
    # manager can release the handle immediately.
    with Image.open(path) as img:
        return img.copy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Path to the pretrained DeMoE checkpoint, relative to the working directory.
PATH_MODEL = './DeMoE.pt'

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Architecture options forwarded to archs.create_model.
# NOTE(review): these presumably must match the hyper-parameters the checkpoint
# was trained with. num_experts / k_used look like the MoE expert count and the
# number of experts used per input; confirm against archs.
model_opt = {
    'name': 'DeMoE',
    'img_channels': 3,
    'width': 32,
    'middle_blk_num': 2,
    'enc_blk_nums': [2, 2, 2, 2],
    'dec_blk_nums': [2, 2, 2, 2],
    'num_experts': 5,
    'k_used': 1
}

# PIL <-> tensor converters used by the inference function.
pil_to_tensor = transforms.ToTensor()
tensor_to_pil = transforms.ToPILImage()

model = create_model(model_opt, device)

# The checkpoint is read by archs.load_model. Do not also torch.load it here:
# a second load only doubles start-up time and memory.
# NOTE(review): if load_model unpickles with weights_only=False, only ship
# checkpoints from a trusted source (pickle can execute arbitrary code).
model = load_model(model, PATH_MODEL, device)

# The model is only used for inference.
model.eval()
|
|
|
|
|
def pad_tensor(tensor, multiple = 16):
    """Zero-pad the bottom/right of ``tensor`` so its last two dims are multiples of ``multiple``.

    The last two dimensions are treated as (height, width). Padding is added
    only after the existing content, so the original data stays at
    ``[..., :H, :W]`` and can be cropped back after processing.
    NOTE(review): the default of 16 presumably matches the total downsampling
    of the 4-level encoder (2**4); confirm in archs.

    Args:
        tensor: a tensor with at least two dimensions, e.g. (N, C, H, W).
        multiple: positive integer the padded height and width must divide by.

    Returns:
        The padded tensor, or a same-shaped result when both sizes already divide.
    """
    height, width = tensor.shape[-2:]
    pad_h = (multiple - height % multiple) % multiple
    pad_w = (multiple - width % multiple) % multiple
    # F.pad lists padding from the last dim backwards: (left, right, top, bottom).
    return F.pad(tensor, (0, pad_w, 0, pad_h), value=0)
|
|
|
|
|
# Display labels offered by the blur-type selector, in on-screen order.
TASK_LABELS = [
    "Auto",
    "Defocus",
    "Low-Light",
    "Global-Motion",
    "Synth-Global-Motion",
    "Local-Motion",
]

# Display label -> task identifier handed to the model. Each identifier is the
# label lower-cased with hyphens replaced by underscores ("Low-Light" -> "low_light").
LABEL_TO_TASK = {label: label.lower().replace("-", "_") for label in TASK_LABELS}
|
|
|
|
|
def process_img(image, task_label = 'auto'):
    """Restore one image with DeMoE.

    Pipeline: PIL -> (1, 3, H, W) tensor on ``device`` -> zero-pad to a multiple
    of 16 -> model forward with the selected task -> clamp to [0, 1] -> crop
    back to H x W -> PIL.

    Args:
        image: input PIL image, or ``None`` when the input widget is empty.
        task_label: a display label from ``TASK_LABELS`` (e.g. "Auto",
            "Defocus"). Unknown labels fall back to the ``'auto'`` task.

    Returns:
        The restored PIL image, with the same width and height as the input.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of a TypeError from ToTensor.
        raise gr.Error("Please provide an input image first.")

    task = LABEL_TO_TASK.get(task_label, 'auto')
    # The model is built with img_channels=3. Converting to RGB means RGBA or
    # grayscale inputs do not reach it with the wrong channel count.
    tensor = pil_to_tensor(image.convert("RGB")).unsqueeze(0).to(device)
    _, _, H, W = tensor.shape
    print('Using task:', task)
    tensor = pad_tensor(tensor)

    with torch.no_grad():
        output_dict = model(tensor, task)

    output = torch.clamp(output_dict['output'], 0., 1.)
    # Crop away the padding added above, then drop the batch dim for ToPILImage.
    output = output[:, :, :H, :W].squeeze(0).cpu()
    return tensor_to_pil(output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
title = 'DeMoE 🌪️'

# Markdown rendered under the page title (see the gr.Markdown call in the UI).
description = ''' >**Abstract**: Image deblurring, removing blurring artifacts from images, is a fundamental task in computational photography and low-level computer vision. Existing approaches focus on specialized solutions tailored to particular blur types, thus, these solutions lack generalization. This limitation in current methods implies requiring multiple models to cover several blur types, which is not practical in many real scenarios. In this paper, we introduce the first all-in-one deblurring method capable of efficiently restoring images affected by diverse blur degradations, including global motion, local motion, blur in low-light conditions, and defocus blur. We propose a mixture-of-experts (MoE) decoding module, which dynamically routes image features based on the recognized blur degradation, enabling precise and efficient restoration in an end-to-end manner. Our unified approach not only achieves performance comparable to dedicated task-specific models, but also shows promising generalization to unseen blur scenarios, particularly when leveraging appropriate expert selection.

[Daniel Feijoo](https://github.com/danifei), Paula Garrido-Mellado, Jaesung Rim, Álvaro García, Marcos V. Conde

[Fundación Cidaut](https://cidaut.ai/)


Available code at [github](https://github.com/cidautai/DeMoE). More information on the [Arxiv paper](https://arxiv.org/pdf/2508.06228).

> **Disclaimer:** please remember this is not a product, thus, you will notice some limitations.
**This demo expects a blurry image: defocus, global or local motion, or blur in low-light conditions. Pick the matching blur type, or leave it on "Auto".**

<br>
'''

# Extra CSS passed to gr.Blocks. The rules make images/canvases inside elements
# with class "fitbox" fill the box while keeping their aspect ratio.
# NOTE(review): no component in this file sets elem_classes=["fitbox"], so these
# rules currently match nothing; confirm whether the class should be applied.
css = """
.fitbox img,
.fitbox canvas {
    width: 100% !important;
    height: 100% !important;
    object-fit: contain !important;
}
"""
|
|
|
|
|
|
|
|
|
|
|
# Image extensions accepted for example rows (same set as IMG_EXTS above).
exts = (".png", ".jpg", ".jpeg", ".bmp", ".webp")


def list_basenames(folder, base="examples"):
    """Return ``[[basename, task_label], ...]`` rows for the images in ``<base>/<folder>``.

    The rows presumably target ``gr.Examples`` over (image, blur-type) inputs.
    The task label is the blur-type radio's "Auto" choice. The lower-case task
    name "auto" is not one of the radio's choices.

    Args:
        folder: sub-folder name under ``base``.
        base: root directory of the examples (default ``"examples"``).

    Returns:
        A list of ``[file_name, "Auto"]`` pairs sorted by file name. Extensions
        are matched case-insensitively against ``exts``.
    """
    paths = sorted(glob.glob(os.path.join(base, folder, "*")))
    default_task = "Auto"
    return [
        [os.path.basename(path), default_task]
        for path in paths
        if path.lower().endswith(exts)
    ]
|
|
|
|
|
|
|
|
# Per-folder example rows of [basename, task_label], built once at import time,
# one list_basenames call per folder in the order below.
# NOTE(review): none of these lists are referenced elsewhere in this file (the
# UI uses the folder gallery instead). Either wire them into gr.Examples or drop them.
(
    examples_low_light,
    examples_global_motion,
    examples_synth_global_motion,
    examples_local_motion,
    examples_defocus,
) = (
    list_basenames(folder_name)
    for folder_name in (
        "low_light",
        "global_motion",
        "synth_global_motion",
        "local_motion",
        "defocus",
    )
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI with two sections:
#   1. a restore panel: input/output images, blur-type selector, Restore button;
#   2. an example browser: pick a sub-folder of examples/, then click a
#      thumbnail to load it as the input.
# NOTE(review): no component below sets elem_classes=["fitbox"], so the .fitbox
# rules in `css` do not apply to anything; confirm intent.
with gr.Blocks(css=css, title=title) as demo:
    gr.Markdown(f"# {title}\n\n{description}")

    with gr.Row():
        inp_img = gr.Image(type='pil', label='input', height=320)

        out_img = gr.Image(type='pil', label='output', height=320)

    # The selected display label is mapped to a model task inside process_img.
    task_selector = gr.Radio(
        choices=TASK_LABELS,
        value="Auto",
        label="Blur type"
    )

    btn = gr.Button("Restore", variant="primary")

    # Run inference on the current input image with the selected blur type.
    btn.click(
        fn=process_img,
        inputs=[inp_img, task_selector],
        outputs=[out_img]
    )

    gr.Markdown("## Examples")
    with gr.Row():
        # The folder list is read once, when the UI is built; folders added
        # later only appear after a restart.
        folders = list_subfolders("examples")
        print(folders)
        folder_radio = gr.Radio(choices=folders, label="Examples Folders", interactive=True)

    # Hidden until a folder is chosen; update_gallery returns visible=True.
    gallery = gr.Gallery(
        label="Images from the selected folder",
        visible=False,
        allow_preview=True,
        columns=6,
        height=320,
    )

    # Paths currently shown in the gallery. It is written by update_gallery and
    # read by load_from_gallery to map a click index back to a file.
    current_files_state = gr.State([])

    # Choosing a folder refreshes both the gallery and the stored path list.
    folder_radio.change(
        fn=update_gallery,
        inputs=folder_radio,
        outputs=[gallery, current_files_state]
    )

    # Clicking a thumbnail loads that image into the input component.
    # load_from_gallery also receives the select event (gr.SelectData).
    gallery.select(
        fn=load_from_gallery,
        inputs=[current_files_state],
        outputs=inp_img
    )
|
|
|
|
|
def _launch():
    """Start the Gradio server for ``demo``; show_error=True reports handler errors in the UI."""
    demo.launch(show_error=True)


if __name__ == '__main__':
    _launch()