basic functionality
- .gitattributes +1 -0
- DeMoE.pt +3 -0
- README.md +1 -1
- app.py +96 -0
- archs/DeMoE.py +133 -0
- archs/__init__.py +59 -0
- archs/arch_model.py +105 -0
- archs/arch_util.py +79 -0
- archs/moeblocks.py +65 -0
- check_file.py +5 -0
- examples/000143.png +3 -0
- examples/0031.png +3 -0
- examples/12_blur.png +3 -0
- examples/1P0A1811.png +3 -0
- examples/blur_4.png +3 -0
- requirements.txt +18 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
DeMoE.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5b0aef148ffcb1a5572b4da6cbd33f86ed2e18639db28b46838aae46bcd011a5
+size 80848778
README.md
CHANGED
@@ -1,6 +1,6 @@
 ---
 title: DeMoE
-emoji:
+emoji: 🌪️
 colorFrom: gray
 colorTo: green
 sdk: gradio
app.py
ADDED
@@ -0,0 +1,96 @@
+import gradio as gr
+from PIL import Image
+import torch
+import torchvision.transforms as transforms
+import torch.nn.functional as F
+
+from archs import create_model, resume_model
+
+PATH_MODEL = './DeMoE.pt'
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model_opt = {
+    'name': 'DeMoE',
+    'img_channels': 3,
+    'width': 32,
+    'middle_blk_num': 2,
+    'enc_blk_nums': [2, 2, 2, 2],
+    'dec_blk_nums': [2, 2, 2, 2],
+    'num_experts': 5,
+    'k_used': 1
+}
+
+pil_to_tensor = transforms.ToTensor()
+tensor_to_pil = transforms.ToPILImage()
+
+model = create_model(model_opt, device)
+
+checkpoints = torch.load(PATH_MODEL, map_location=device, weights_only=False)
+model = resume_model(model, PATH_MODEL, device)
+
+def pad_tensor(tensor, multiple=16):
+    '''Pad the tensor so that H and W are multiples of `multiple`.'''
+    _, _, H, W = tensor.shape
+    pad_h = (multiple - H % multiple) % multiple
+    pad_w = (multiple - W % multiple) % multiple
+    tensor = F.pad(tensor, (0, pad_w, 0, pad_h), value=0)
+
+    return tensor
+
+def process_img(image, task='auto'):
+    tensor = pil_to_tensor(image).unsqueeze(0).to(device)
+    _, _, H, W = tensor.shape
+
+    tensor = pad_tensor(tensor)
+
+    with torch.no_grad():
+        output = model(tensor, task)['output']  # the model returns a dict; keep the restored image
+
+    output = torch.clamp(output, 0., 1.)
+    output = output[:, :, :H, :W].squeeze(0)  # crop the padding back off
+    return tensor_to_pil(output)
+
+title = 'DeMoE 🌪️'
+description = ''' >**Abstract**: Image deblurring, removing blurring artifacts from images, is a fundamental task in computational photography and low-level computer vision. Existing approaches focus on specialized solutions tailored to particular blur types; thus, these solutions lack generalization. This limitation in current methods implies requiring multiple models to cover several blur types, which is not practical in many real scenarios. In this paper, we introduce the first all-in-one deblurring method capable of efficiently restoring images affected by diverse blur degradations, including global motion, local motion, blur in low-light conditions, and defocus blur. We propose a mixture-of-experts (MoE) decoding module, which dynamically routes image features based on the recognized blur degradation, enabling precise and efficient restoration in an end-to-end manner. Our unified approach not only achieves performance comparable to dedicated task-specific models, but also shows promising generalization to unseen blur scenarios, particularly when leveraging appropriate expert selection.
+
+[Daniel Feijoo](https://github.com/danifei), Paula Garrido-Mellado, Jaesung Rim, Álvaro García, Marcos V. Conde
+
+[Fundación Cidaut](https://cidaut.ai/)
+
+
+Available code at [github](https://github.com/cidautai/DeMoE). More information in the [arXiv paper](https://arxiv.org/pdf/2508.06228).
+
+> **Disclaimer:** please remember this is not a product; thus, you will notice some limitations.
+**This demo expects an image with some Low-Light degradations.**
+
+<br>
+'''
+
+examples = [['examples/1P0A1811.png'],
+            ['examples/12_blur.png'],
+            ['examples/0031.png'],
+            ['examples/000143.png'],
+            ['examples/blur_4.png']]
+
+css = """
+.image-frame img, .image-container img {
+    width: auto;
+    height: auto;
+    max-width: none;
+}
+"""
+
+demo = gr.Interface(
+    fn=process_img,
+    inputs=[
+        gr.Image(type='pil', label='input')
+    ],
+    outputs=[gr.Image(type='pil', label='output')],
+    title=title,
+    description=description,
+    examples=examples,
+    css=css
+)
+
+if __name__ == '__main__':
+    demo.launch()
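
A minimal usage sketch (editorial, not part of this commit) of the restoration entry point defined above: `process_img` pads the image, runs the network, and crops back. Importing `app` builds the model and loads `./DeMoE.pt`, and the manual `task` keys are the ones defined in `archs/DeMoE.py`:

    from PIL import Image
    from app import process_img  # importing app builds the model and loads ./DeMoE.pt

    img = Image.open('examples/blur_4.png').convert('RGB')
    restored_auto = process_img(img)                    # routing weights predicted by the classifier
    restored_ll = process_img(img, task='low_light')    # one-hot routing to the low-light expert
    restored_auto.save('restored_auto.png')
    restored_ll.save('restored_low_light.png')
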
archs/DeMoE.py
ADDED
@@ -0,0 +1,133 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+    from .arch_util import CustomSequential
+    from .arch_model import EfficientClassificationHead, NAFBlock
+    from .moeblocks import MoEBlock
+except ImportError:
+    from arch_util import CustomSequential
+    from arch_model import EfficientClassificationHead, NAFBlock
+    from moeblocks import MoEBlock
+
+
+# One-hot routing weights used when the task is selected manually instead of predicted.
+TASKS = {'defocus':             [1.0, 0, 0, 0, 0],
+         'global_motion':       [0, 1.0, 0, 0, 0],
+         'local_motion':        [0, 0, 1.0, 0, 0],
+         'synth_global_motion': [0, 0, 0, 1.0, 0],
+         'low_light':           [0, 0, 0, 0, 1.0]}
+
+class DeMoE(nn.Module):
+
+    def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[], num_exp=5, k_used=3):
+        super().__init__()
+
+        self.num_experts = num_exp
+        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
+                               bias=True)
+        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
+                                bias=True)
+
+        self.encoders = nn.ModuleList()
+        self.decoders = nn.ModuleList()
+        self.middle_blks = nn.ModuleList()
+        self.ups = nn.ModuleList()
+        self.downs = nn.ModuleList()
+        self.experts = nn.ModuleList()
+
+        chan = width
+        for num in enc_blk_nums:
+            self.encoders.append(
+                CustomSequential(
+                    *[NAFBlock(chan) for _ in range(num)]
+                )
+            )
+            self.downs.append(
+                nn.Conv2d(chan, 2 * chan, 2, 2)
+            )
+            chan = chan * 2
+
+        self.middle_blks = \
+            CustomSequential(
+                *[NAFBlock(chan) for _ in range(middle_blk_num)]
+            )
+        self.experts.append(MoEBlock(c=chan, n=num_exp, used=k_used))
+
+        for num in dec_blk_nums:
+            self.ups.append(
+                nn.Sequential(
+                    nn.Conv2d(chan, chan * 2, 1, bias=False),
+                    nn.PixelShuffle(2)
+                )
+            )
+            chan = chan // 2
+            self.decoders.append(
+                CustomSequential(
+                    *[NAFBlock(chan) for _ in range(num)]
+                )
+            )
+            self.experts.append(MoEBlock(c=chan, n=num_exp, used=k_used))
+
+        # Classification head that predicts the blur type from the deepest encoder features.
+        self.mlp_branch = EfficientClassificationHead(in_channels=width * 2 ** len(enc_blk_nums), num_classes=num_exp)
+
+        self.padder_size = 2 ** len(self.encoders)
+
+    def forward(self, inp, task='auto'):
+        B, C, H, W = inp.shape
+        inp = self.check_image_size(inp)
+
+        x = self.intro(inp)
+
+        encs = []
+        bins = []
+        weights = []
+        for encoder, down in zip(self.encoders, self.downs):
+            x = encoder(x)
+            encs.append(x)
+            x = down(x)
+        class_weights_0 = self.mlp_branch(x)
+        class_weights = F.softmax(class_weights_0, dim=1)
+        # if the task is selected manually, override the predicted routing weights
+        if task != 'auto':
+            class_weights = torch.tensor(TASKS[task], device=x.device).unsqueeze(0).expand(B, -1)
+        x = self.middle_blks(x)
+        x, expert_bins, weight = self.experts[0].forward(x, class_weights)
+        bins.append(expert_bins)
+        weights.append(weight)
+        for decoder, up, enc_skip, expert in zip(self.decoders, self.ups, encs[::-1], self.experts[1:]):
+            x = up(x)
+            x = x + enc_skip
+            x = decoder(x)
+            x, expert_bins, weight = expert.forward(x, class_weights)
+            bins.append(expert_bins)
+            weights.append(weight)
+        x = self.ending(x)
+        x = x + inp
+
+        return {'output': x[:, :, :H, :W],
+                'bin_counts': torch.stack(bins, dim=0),
+                'pred_labels': class_weights,
+                'weights': weights}
+
+    def check_image_size(self, x):
+        _, _, h, w = x.size()
+        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
+        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
+        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), value=0)
+        return x
+
+if __name__ == '__main__':
+
+    from ptflops import get_model_complexity_info
+
+    net = DeMoE(img_channel=3, width=32,
+                middle_blk_num=2, enc_blk_nums=[2, 2, 2, 2], dec_blk_nums=[2, 2, 2, 2], k_used=1)
+    print('State dict: ', len(net.state_dict().keys()))
+    macs, params = get_model_complexity_info(net, input_res=(3, 256, 256), print_per_layer_stat=False, verbose=False)
+    print(macs, params)
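
A short sketch (editorial, not part of this commit) of the forward interface above: the network pads internally, returns a dict, and a manual `task` key replaces the predicted routing weights with the one-hot vector from `TASKS`:

    import torch
    from archs.DeMoE import DeMoE

    net = DeMoE(img_channel=3, width=32, middle_blk_num=2,
                enc_blk_nums=[2, 2, 2, 2], dec_blk_nums=[2, 2, 2, 2],
                num_exp=5, k_used=1)
    net.eval()  # BatchNorm layers in the classification head need eval mode for batch size 1

    x = torch.randn(1, 3, 270, 180)       # H and W need not be multiples of 16
    with torch.no_grad():
        out = net(x, task='defocus')      # one-hot routing to the defocus expert
    print(out['output'].shape)            # torch.Size([1, 3, 270, 180]), cropped back to the input size
    print(out['pred_labels'])             # tensor([[1., 0., 0., 0., 0.]])
    print(out['bin_counts'].shape)        # torch.Size([5, 5]): one row per MoE block, one column per expert
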
archs/__init__.py
ADDED
@@ -0,0 +1,59 @@
+import torch
+
+from .DeMoE import DeMoE
+
+def create_model(opt, device):
+    '''
+    Creates the model.
+    opt: a dictionary from the yaml config key `network`
+    '''
+    name = opt['name']
+
+    if name == 'DeMoE':
+        model = DeMoE(img_channel=opt['img_channels'],
+                      width=opt['width'],
+                      middle_blk_num=opt['middle_blk_num'],
+                      enc_blk_nums=opt['enc_blk_nums'],
+                      dec_blk_nums=opt['dec_blk_nums'],
+                      num_exp=opt['num_experts'],
+                      k_used=opt['k_used'])
+
+    else:
+        raise NotImplementedError('This network is not implemented')
+
+    model.to(device)
+
+    return model
+
+def load_weights(model, model_weights):
+    '''
+    Loads the weights of a pretrained model, picking only the weights that are
+    in the new model.
+    '''
+    new_weights = model.state_dict()
+    new_weights.update({k: v for k, v in model_weights.items() if k in new_weights})
+
+    model.load_state_dict(new_weights)
+
+    return model
+
+def resume_model(model,
+                 path_model,
+                 device):
+
+    '''
+    Loads the pretrained weights stored under the 'params' key of the checkpoint
+    at path_model into the given model.
+    '''
+
+    checkpoints = torch.load(path_model, map_location=device, weights_only=False)
+    weights = checkpoints['params']
+    model = load_weights(model, model_weights=weights)
+
+    return model
+
+
+__all__ = ['create_model', 'resume_model', 'load_weights']
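
For reference, a sketch (editorial, not part of this commit) of the loading path app.py uses: `create_model` builds the network from the config dict and `resume_model` copies the checkpoint's `'params'` state dict into it via `load_weights`:

    import torch
    from archs import create_model, resume_model

    opt = {'name': 'DeMoE', 'img_channels': 3, 'width': 32, 'middle_blk_num': 2,
           'enc_blk_nums': [2, 2, 2, 2], 'dec_blk_nums': [2, 2, 2, 2],
           'num_experts': 5, 'k_used': 1}
    device = torch.device('cpu')

    model = create_model(opt, device)
    model = resume_model(model, './DeMoE.pt', device)   # expects a checkpoint shaped like {'params': state_dict}
    model.eval()
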
archs/arch_model.py
ADDED
@@ -0,0 +1,105 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .arch_util import LayerNorm2d
+except ImportError:
+    from arch_util import LayerNorm2d
+
+# ------------------------------------------------------------------------
+# Modified from NAFNet (https://github.com/megvii-research/NAFNet)
+# ------------------------------------------------------------------------
+
+
+class SimpleGate(nn.Module):
+    def forward(self, x):
+        x1, x2 = x.chunk(2, dim=1)
+        return x1 * x2
+
+class NAFBlock(nn.Module):
+    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
+        super().__init__()
+        dw_channel = c * DW_Expand
+        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
+                               bias=True)  # the depth-wise conv
+        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+
+        # Simplified Channel Attention
+        self.sca = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
+                      groups=1, bias=True),
+        )
+
+        # SimpleGate
+        self.sg = SimpleGate()
+
+        ffn_channel = FFN_Expand * c
+        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+
+        self.norm1 = LayerNorm2d(c)
+        self.norm2 = LayerNorm2d(c)
+
+        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
+        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
+
+        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
+        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
+
+    def forward(self, inp):
+        x = inp                        # size [B, C, H, W]
+
+        x = self.norm1(x)              # size [B, C, H, W]
+
+        x = self.conv1(x)              # size [B, 2*C, H, W]
+        x = self.conv2(x)              # size [B, 2*C, H, W]
+        x = self.sg(x)                 # size [B, C, H, W]
+        x = x * self.sca(x)            # size [B, C, H, W]
+        x = self.conv3(x)              # size [B, C, H, W]
+
+        x = self.dropout1(x)
+
+        y = inp + x * self.beta        # size [B, C, H, W]
+
+        x = self.conv4(self.norm2(y))  # size [B, 2*C, H, W]
+        x = self.sg(x)                 # size [B, C, H, W]
+        x = self.conv5(x)              # size [B, C, H, W]
+
+        x = self.dropout2(x)
+
+        x = y + x * self.gamma
+
+        return x
+
+class EfficientClassificationHead(nn.Module):
+
+    def __init__(self, in_channels, num_classes=5):
+        super().__init__()
+        self.conv_bottleneck = nn.Sequential(
+            nn.Conv2d(in_channels, 256, kernel_size=1),  # Channel reduction
+            nn.BatchNorm2d(256),
+            nn.ReLU(inplace=True),
+            nn.Dropout2d(0.2))
+
+        self.attention = nn.Sequential(
+            nn.Conv2d(256, 1, kernel_size=1),
+            nn.Sigmoid())
+
+        self.classifier = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Flatten(),
+            nn.Linear(256, 128),
+            nn.BatchNorm1d(128),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.3),
+            nn.Linear(128, num_classes))
+
+    def forward(self, x):
+        x = self.conv_bottleneck(x)
+        attention_mask = self.attention(x)
+        x = x * attention_mask  # Spatial attention
+        return self.classifier(x)
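
A quick shape check (editorial, not part of this commit) for the two modules above: `NAFBlock` preserves the feature map shape, and `EfficientClassificationHead` reduces a feature map to per-class logits that DeMoE later passes through a softmax:

    import torch
    from archs.arch_model import NAFBlock, EfficientClassificationHead

    block = NAFBlock(c=32).eval()
    print(block(torch.randn(2, 32, 64, 64)).shape)   # torch.Size([2, 32, 64, 64])

    head = EfficientClassificationHead(in_channels=512, num_classes=5).eval()
    print(head(torch.randn(2, 512, 16, 16)).shape)   # torch.Size([2, 5]), raw logits
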
archs/arch_util.py
ADDED
@@ -0,0 +1,79 @@
+import torch
+import numpy as np
+from torch import nn as nn
+from torch.nn import init as init
+
+# ------------------------------------------------------------------------
+# Modified from NAFNet (https://github.com/megvii-research/NAFNet)
+# ------------------------------------------------------------------------
+
+class LayerNormFunction(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x, weight, bias, eps):
+        ctx.eps = eps
+        N, C, H, W = x.size()
+        mu = x.mean(1, keepdim=True)
+        var = (x - mu).pow(2).mean(1, keepdim=True)
+        y = (x - mu) / (var + eps).sqrt()
+        ctx.save_for_backward(y, var, weight)
+        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
+        return y
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        eps = ctx.eps
+
+        N, C, H, W = grad_output.size()
+        y, var, weight = ctx.saved_tensors
+        g = grad_output * weight.view(1, C, 1, 1)
+        mean_g = g.mean(dim=1, keepdim=True)
+
+        mean_gy = (g * y).mean(dim=1, keepdim=True)
+        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
+        return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
+            dim=0), None
+
+class LayerNorm2d(nn.Module):
+
+    def __init__(self, channels, eps=1e-6):
+        super(LayerNorm2d, self).__init__()
+        self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
+        self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
+        self.eps = eps
+
+    def forward(self, x):
+        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
+
+def calc_mean_std(feat, eps=1e-5):
+    """
+    Calculate mean and std for the given feature map.
+    feat: Tensor of shape [B, C, H, W]
+    eps: small value to avoid division by zero
+    """
+    B, C, _, _ = feat.size()
+
+    # Compute mean and std of the feature map across the channel dimension.
+    feat_mean = feat.mean(dim=1, keepdim=True)
+    feat_std = feat.var(dim=1, keepdim=True) + eps
+    feat_std = feat_std.sqrt()
+
+    return feat_mean, feat_std
+
+class CustomSequential(nn.Module):
+    '''
+    Similar to nn.Sequential, but it lets us introduce a second argument in the forward method
+    so adaptors can be considered in the inference.
+    '''
+    def __init__(self, *args):
+        super(CustomSequential, self).__init__()
+        self.modules_list = nn.ModuleList(args)
+
+    def forward(self, x):
+        for module in self.modules_list:
+            x = module(x)
+        return x
+
+if __name__ == '__main__':
+
+    pass
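
A small sanity check (editorial, not part of this commit) of what `LayerNorm2d` computes: each spatial position is normalized over the channel dimension, so it should match PyTorch's built-in layer norm applied channel-last with the same eps, weight, and bias:

    import torch
    import torch.nn.functional as F
    from archs.arch_util import LayerNorm2d

    x = torch.randn(2, 32, 8, 8)
    ln = LayerNorm2d(32, eps=1e-6)

    ref = F.layer_norm(x.permute(0, 2, 3, 1), (32,), ln.weight, ln.bias, eps=1e-6).permute(0, 3, 1, 2)
    print(torch.allclose(ln(x), ref, atol=1e-5))   # True: per-pixel normalization over channels
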
archs/moeblocks.py
ADDED
@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+    from .arch_model import NAFBlock
+except ImportError:
+    from arch_model import NAFBlock
+
+class MoEBlock(nn.Module):
+    def __init__(self, c, n=5, used=3):
+        super().__init__()
+        self.used = int(used)
+        self.num_experts = n
+        self.experts = nn.ModuleList([NAFBlock(c=c) for _ in range(n)])
+
+    # Sparse implementation for large n
+    def forward(self, feat, weights):
+        B, _, _, _ = feat.shape
+        k = self.used
+        # Get top-k weights and indices
+        topk_weights, topk_indices = torch.topk(weights, k, dim=1)  # (B, k)
+        expert_counts = torch.bincount(topk_indices.flatten(), minlength=self.num_experts)
+        # Apply l1 normalization so the kept weights sum to 1 while preserving their relative scale
+        topk_weights = topk_weights / topk_weights.sum(dim=1, keepdim=True)  # (B, k)
+        mask = torch.zeros(B, self.num_experts, dtype=torch.float32, device=feat.device)
+        mask.scatter_(1, topk_indices, 1.0)  # Set 1.0 for used experts
+
+        # Initialize output tensor
+        outputs = torch.zeros_like(feat)
+
+        # Process only used experts
+        for expert_idx in range(self.num_experts):
+            batch_mask = mask[:, expert_idx].bool()  # Convert to boolean mask
+            if batch_mask.any():
+                # Get the routing weights for this expert on the selected batch elements
+                expert_weights = topk_weights[batch_mask, (topk_indices[batch_mask] == expert_idx).nonzero()[:, 1]]
+                expert_out = self.experts[expert_idx](feat[batch_mask])
+                outputs[batch_mask] += expert_out * expert_weights.view(-1, 1, 1, 1)
+
+        return outputs, expert_counts, weights
+
+
+# ----------------------------------------------------------------------------------------------
+if __name__ == '__main__':
+
+    img_channel = 3
+    width = 32
+
+    enc_blks = [1, 2, 3]
+    middle_blk_num = 3
+    dec_blks = [3, 1, 1]
+    dilations = [1, 4, 9]
+    extra_depth_wise = True
+
+    net = MoEBlock(c=img_channel,
+                   n=5,
+                   used=3)
+
+    inp_shape = (3, 256, 256)
+
+    from ptflops import get_model_complexity_info
+
+    macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False)
+    output = net(torch.randn((4, 3, 256, 256)), F.softmax(torch.randn((4, 5)), dim=1))
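
A routing check (editorial, not part of this commit) for the block above: with one-hot weights and `used=1`, each sample is processed by exactly one expert, `expert_counts` records how many samples each expert received, and the weighted sum collapses to that expert's output:

    import torch
    from archs.moeblocks import MoEBlock

    moe = MoEBlock(c=16, n=5, used=1).eval()
    feat = torch.randn(2, 16, 32, 32)
    weights = torch.tensor([[0., 0., 1., 0., 0.],    # sample 0 -> expert 2
                            [0., 0., 0., 0., 1.]])   # sample 1 -> expert 4

    out, counts, w = moe(feat, weights)
    print(counts)                                                  # tensor([0, 0, 1, 0, 1])
    print(torch.allclose(out[0], moe.experts[2](feat[0:1])[0]))    # True: only the selected expert contributes
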
check_file.py
ADDED
@@ -0,0 +1,5 @@
+import torch
+
+pt_dict = torch.load('DeMoE.pt', map_location='cpu')
+print(pt_dict['params'].keys())
+print(len(pt_dict['params'].keys()))
examples/000143.png
ADDED (Git LFS)
examples/0031.png
ADDED (Git LFS)
examples/12_blur.png
ADDED (Git LFS)
examples/1P0A1811.png
ADDED (Git LFS)
examples/blur_4.png
ADDED (Git LFS)
requirements.txt
ADDED
@@ -0,0 +1,18 @@
+einops==0.8.0
+gradio==5.49.0
+kornia==0.7.2
+lpips==0.1.4
+numpy==2.0.0
+opencv-python==4.10.0.84
+pandas==2.2.2
+pillow==10.3.0
+ptflops==0.7.3
+pyiqa==0.1.13
+pytorch-msssim==1.0.0
+PyYAML==6.0.1
+scikit-image==0.24.0
+scipy==1.13.1
+torch==2.5.1
+torchaudio==2.5.1
+torchvision==0.20.1
+tqdm==4.66.4