import gradio as gr
from PIL import Image
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
import os
import glob

from archs import create_model, load_model

# -------- Detect folders & images (examples/<folder>) --------
IMG_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".webp")

def list_subfolders(base="examples"):
    """Return a sorted list of immediate subfolders inside base."""
    if not os.path.isdir(base):
        return []
    subs = [d for d in sorted(os.listdir(base)) if os.path.isdir(os.path.join(base, d))]
    return subs

def list_images(folder):
    """Return full paths of images inside examples/<folder>."""
    paths = sorted(glob.glob(os.path.join("examples", folder, "*")))
    return [p for p in paths if p.lower().endswith(IMG_EXTS)]
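
# Illustrative layout this demo assumes (actual contents depend on the checkout):
#   examples/
#       defocus/                *.png, *.jpg, ...
#       global_motion/
#       local_motion/
#       low_light/
#       synth_global_motion/
# With that layout, list_subfolders() returns the five folder names and
# list_images("low_light") returns the full paths of the images inside that folder.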

# -------- Folder/Gallery interactions --------
def update_gallery(folder):
    """Given a folder name, return the gallery items (list of image paths) and store the same list in state."""
    files = list_images(folder)
    print(files)
    return gr.update(value=files, visible=True), files 

def load_from_gallery(evt: gr.SelectData, current_files):
    """On gallery click, load the clicked image path into the input image."""
    idx = evt.index
    if not current_files or idx is None or idx >= len(current_files):
        return gr.update()
    path = current_files[idx]
    print(path)
    return Image.open(path)


# -------- Model & preprocessing --------

PATH_MODEL = './DeMoE.pt'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_opt = {
    'name': 'DeMoE',
    'img_channels': 3,
    'width': 32,
    'middle_blk_num': 2,
    'enc_blk_nums': [2, 2, 2, 2],
    'dec_blk_nums': [2, 2, 2, 2],
    'num_experts': 5,
    'k_used': 1
}

pil_to_tensor = transforms.ToTensor()
tensor_to_pil = transforms.ToPILImage()

model = create_model(model_opt, device)

model = load_model(model, PATH_MODEL, device)  # load_model reads the checkpoint from PATH_MODEL

model.eval()

def pad_tensor(tensor, multiple=16):
    """Zero-pad a (N, C, H, W) tensor on the bottom/right so H and W become multiples of `multiple`."""
    _, _, H, W = tensor.shape
    pad_h = (multiple - H % multiple) % multiple
    pad_w = (multiple - W % multiple) % multiple
    tensor = F.pad(tensor, (0, pad_w, 0, pad_h), value=0)
    return tensor
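
# Worked example (illustrative): for a 1x3x721x481 input, pad_h = (16 - 721 % 16) % 16 = 15 and
# pad_w = (16 - 481 % 16) % 16 = 15, so the padded tensor is 1x3x736x496; process_img() below
# crops the model output back to the original H x W.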

TASK_LABELS = ["Auto", "Defocus", "Low-Light", "Global-Motion", "Synth-Global-Motion", "Local-Motion"]

# Map pretty label -> internal task code used by the model
LABEL_TO_TASK = {
    "Auto": "auto",     
    "Low-Light": "low_light",
    "Global-Motion": "global_motion",  
    "Defocus": "defocus",  
    "Synth-Global-Motion": "synth_global_motion",
    "Local-Motion": "local_motion",          
}

def process_img(image, task_label = 'auto'):
    """Main inference: converts PIL -> tensor, pads, runs the model with selected task, clamps, crops, returns PIL."""
    task_label = LABEL_TO_TASK.get(task_label, 'auto')
    tensor = pil_to_tensor(image).unsqueeze(0).to(device)
    _, _, H, W = tensor.shape
    print('Using task:', task_label)
    tensor = pad_tensor(tensor)

    with torch.no_grad():
        output_dict = model(tensor, task_label)

    output = output_dict['output']
    # print(output.shape)
    output = torch.clamp(output, 0., 1.)
    output = output[:,:, :H, :W].squeeze(0)    
    return tensor_to_pil(output)
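
# Standalone usage sketch (illustrative; not executed by the demo). The image path is
# hypothetical -- point it at any file under examples/:
#   img = Image.open("examples/low_light/some_image.png")
#   restored = process_img(img, "Low-Light")   # or "Auto" to let the model infer the blur type
#   restored.save("restored.png")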




title = 'DeMoE 🌪️'
description = ''' >**Abstract**: Image deblurring, removing blurring artifacts from images, is a fundamental task in computational photography and low-level computer vision. Existing approaches focus on specialized solutions tailored to particular blur types, thus, these solutions lack generalization. This limitation in current methods implies requiring multiple models to cover several blur types, which is not practical in many real scenarios. In this paper, we introduce the first all-in-one deblurring method capable of efficiently restoring images affected by diverse blur degradations, including global motion, local motion, blur in low-light conditions, and defocus blur. We propose a mixture-of-experts (MoE) decoding module, which dynamically routes image features based on the recognized blur degradation, enabling precise and efficient restoration in an end-to-end manner. Our unified approach not only achieves performance comparable to dedicated task-specific models, but also shows promising generalization to unseen blur scenarios, particularly when leveraging appropriate expert selection.

[Daniel Feijoo](https://github.com/danifei), Paula Garrido-Mellado, Jaesung Rim, Álvaro García, Marcos V. Conde

[Fundación Cidaut](https://cidaut.ai/)


Code is available on [GitHub](https://github.com/cidautai/DeMoE). More information in the [arXiv paper](https://arxiv.org/pdf/2508.06228).

> **Disclaimer:** please remember this is not a product, so you will notice some limitations.
**This demo expects an image with some Low-Light degradations.**

<br>
'''

css = """
.fitbox img,
    .fitbox canvas {
    width: 100% !important;
    height: 100% !important;
    object-fit: contain !important;
    }
"""


# Example lists per folder under ./examples

def list_basenames(folder):
    """Return [[basename, default_task], ...] pairs for the images in examples/<folder>."""
    paths = sorted(glob.glob(f"examples/{folder}/*"))
    basenames = [os.path.basename(p) for p in paths if p.lower().endswith(IMG_EXTS)]
    # Default task per folder (tweak as you like)
    default_task = "auto"
    return [[name, default_task] for name in basenames]


examples_low_light = list_basenames("low_light")
examples_global_motion = list_basenames("global_motion")
examples_synth_global_motion = list_basenames("synth_global_motion")
examples_local_motion = list_basenames("local_motion")
examples_defocus = list_basenames("defocus")

# print(examples_defocus, examples_global_motion, examples_low_light, examples_synth_global_motion, examples_local_motion)
# -----------------------------
# Gradio Blocks layout
# -----------------------------
with gr.Blocks(css=css, title=title) as demo:
    gr.Markdown(f"# {title}\n\n{description}")

    with gr.Row():
        # Input image and the task selector (Radio)
        inp_img = gr.Image(type='pil', label='input', height=320)
        # Output image and action button
        out_img = gr.Image(type='pil', label='output', height=320)
    task_selector = gr.Radio(
        choices=TASK_LABELS,
        value="Auto",
        label="Blur type",
    )

    btn = gr.Button("Restore", variant="primary")

    # Connect the button to the inference function
    btn.click(
        fn=process_img,
        inputs=[inp_img, task_selector],
        outputs=[out_img]
    )

    # Examples grouped by folder (each item loads image + task automatically)
    gr.Markdown("## Examples")
    with gr.Row():
        # List folders found in ./examples
        folders = list_subfolders("examples")
        print(folders)
        folder_radio = gr.Radio(choices=folders, label="Examples Folders", interactive=True)

        gallery = gr.Gallery(
            label="Images from the selected folder",
            visible=False,
            allow_preview=True,
            columns=6,
            height=320,
        )

    # State holds the current file list shown in the gallery (to resolve clicks)
    current_files_state = gr.State([])

    # When changing folder -> update gallery and state
    folder_radio.change(
        fn=update_gallery,
        inputs=folder_radio,
        outputs=[gallery, current_files_state]
    )

    # When clicking a thumbnail -> load it into the input image
    gallery.select(
        fn=load_from_gallery,
        inputs=[current_files_state],
        outputs=inp_img
    )

if __name__ == '__main__':
    demo.launch(show_error=True)