danifei committed
Commit 034f4b8 · 1 Parent(s): 266758d

basic functionality

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
DeMoE.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5b0aef148ffcb1a5572b4da6cbd33f86ed2e18639db28b46838aae46bcd011a5
+ size 80848778
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
  title: DeMoE
- emoji: 🌍
+ emoji: 🌪️
  colorFrom: gray
  colorTo: green
  sdk: gradio
app.py ADDED
@@ -0,0 +1,96 @@
+ import gradio as gr
+ from PIL import Image
+ import torch
+ import torchvision.transforms as transforms
+ import torch.nn.functional as F
+ 
+ from archs import create_model, resume_model
+ 
+ PATH_MODEL = './DeMoE.pt'
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model_opt = {
+     'name': 'DeMoE',
+     'img_channels': 3,
+     'width': 32,
+     'middle_blk_num': 2,
+     'enc_blk_nums': [2, 2, 2, 2],
+     'dec_blk_nums': [2, 2, 2, 2],
+     'num_experts': 5,
+     'k_used': 1
+ }
+ 
+ pil_to_tensor = transforms.ToTensor()
+ tensor_to_pil = transforms.ToPILImage()
+ 
+ model = create_model(model_opt, device)
+ model = resume_model(model, PATH_MODEL, device)
+ model.eval()  # the classification head uses BatchNorm/Dropout, so switch to eval mode
+ 
+ def pad_tensor(tensor, multiple=16):
+     '''Pad the tensor so that its spatial size is a multiple of `multiple`.'''
+     _, _, H, W = tensor.shape
+     pad_h = (multiple - H % multiple) % multiple
+     pad_w = (multiple - W % multiple) % multiple
+     tensor = F.pad(tensor, (0, pad_w, 0, pad_h), value=0)
+     return tensor
+ 
+ def process_img(image, task='auto'):
+     tensor = pil_to_tensor(image).unsqueeze(0).to(device)
+     _, _, H, W = tensor.shape
+ 
+     tensor = pad_tensor(tensor)
+ 
+     with torch.no_grad():
+         output = model(tensor, task)['output']  # the model returns a dict of outputs
+ 
+     output = torch.clamp(output, 0., 1.)
+     output = output[:, :, :H, :W].squeeze(0)
+     return tensor_to_pil(output)
+ 
+ title = 'DeMoE 🌪️'
+ description = ''' >**Abstract**: Image deblurring, i.e. removing blurring artifacts from images, is a fundamental task in computational photography and low-level computer vision. Existing approaches focus on specialized solutions tailored to particular blur types and thus lack generalization. This limitation means that covering several blur types requires multiple models, which is not practical in many real scenarios. In this paper, we introduce the first all-in-one deblurring method capable of efficiently restoring images affected by diverse blur degradations, including global motion, local motion, blur in low-light conditions, and defocus blur. We propose a mixture-of-experts (MoE) decoding module, which dynamically routes image features based on the recognized blur degradation, enabling precise and efficient restoration in an end-to-end manner. Our unified approach not only achieves performance comparable to dedicated task-specific models, but also shows promising generalization to unseen blur scenarios, particularly when leveraging appropriate expert selection.
+ 
+ [Daniel Feijoo](https://github.com/danifei), Paula Garrido-Mellado, Jaesung Rim, Álvaro García, Marcos V. Conde
+ 
+ [Fundación Cidaut](https://cidaut.ai/)
+ 
+ 
+ Code is available on [GitHub](https://github.com/cidautai/DeMoE). More information in the [arXiv paper](https://arxiv.org/pdf/2508.06228).
+ 
+ > **Disclaimer:** please remember this is not a product, so you may notice some limitations.
+ **This demo expects an image with some low-light degradation.**
+ 
+ <br>
+ '''
+ 
+ examples = [['examples/1P0A1811.png'],
+             ['examples/12_blur.png'],
+             ['examples/0031.png'],
+             ['examples/000143.png'],
+             ['examples/blur_4.png']]
+ 
+ css = """
+ .image-frame img, .image-container img {
+     width: auto;
+     height: auto;
+     max-width: none;
+ }
+ """
+ 
+ demo = gr.Interface(
+     fn=process_img,
+     inputs=[
+         gr.Image(type='pil', label='input')
+     ],
+     outputs=[gr.Image(type='pil', label='output')],
+     title=title,
+     description=description,
+     examples=examples,
+     css=css
+ )
+ 
+ if __name__ == '__main__':
+     demo.launch()
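
For reference, a minimal sketch of calling the same pipeline without the Gradio UI, run in the session where app.py has already loaded the model; the output filename is just an example:

# Offline usage sketch: reuse process_img defined above (no Gradio involved).
from PIL import Image

img = Image.open('examples/0031.png').convert('RGB')
restored = process_img(img, task='auto')          # let the classification head route the image
restored_ll = process_img(img, task='low_light')  # or force the low-light expert
restored.save('restored.png')                     # example output path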
archs/DeMoE.py ADDED
@@ -0,0 +1,133 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ 
+ try:
+     from .arch_util import CustomSequential
+     from .arch_model import EfficientClassificationHead, NAFBlock
+     from .moeblocks import MoEBlock
+ except ImportError:
+     from arch_util import CustomSequential
+     from arch_model import EfficientClassificationHead, NAFBlock
+     from moeblocks import MoEBlock
+ 
+ 
+ # One-hot routing weights used when the task is selected manually.
+ TASKS = {'defocus':             [1.0, 0, 0, 0, 0],
+          'global_motion':       [0, 1.0, 0, 0, 0],
+          'local_motion':        [0, 0, 1.0, 0, 0],
+          'synth_global_motion': [0, 0, 0, 1.0, 0],
+          'low_light':           [0, 0, 0, 0, 1.0]}
+ 
+ class DeMoE(nn.Module):
+ 
+     def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[], num_exp=5, k_used=3):
+         super().__init__()
+ 
+         self.num_experts = num_exp
+         self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
+                                bias=True)
+         self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
+                                 bias=True)
+ 
+         self.encoders = nn.ModuleList()
+         self.decoders = nn.ModuleList()
+         self.middle_blks = nn.ModuleList()
+         self.ups = nn.ModuleList()
+         self.downs = nn.ModuleList()
+         self.experts = nn.ModuleList()
+ 
+         chan = width
+         for num in enc_blk_nums:
+             self.encoders.append(
+                 CustomSequential(
+                     *[NAFBlock(chan) for _ in range(num)]
+                 )
+             )
+             self.downs.append(
+                 nn.Conv2d(chan, 2 * chan, 2, 2)
+             )
+             chan = chan * 2
+ 
+         self.middle_blks = \
+             CustomSequential(
+                 *[NAFBlock(chan) for _ in range(middle_blk_num)]
+             )
+         self.experts.append(MoEBlock(c=chan, n=num_exp, used=k_used))
+ 
+         for num in dec_blk_nums:
+             self.ups.append(
+                 nn.Sequential(
+                     nn.Conv2d(chan, chan * 2, 1, bias=False),
+                     nn.PixelShuffle(2)
+                 )
+             )
+             chan = chan // 2
+             self.decoders.append(
+                 CustomSequential(
+                     *[NAFBlock(chan) for _ in range(num)]
+                 )
+             )
+             self.experts.append(MoEBlock(c=chan, n=num_exp, used=k_used))
+ 
+         self.mlp_branch = EfficientClassificationHead(in_channels=width * 2 ** len(enc_blk_nums), num_classes=num_exp)
+ 
+         self.padder_size = 2 ** len(self.encoders)
+ 
+     def forward(self, inp, task='auto'):
+         B, C, H, W = inp.shape
+         inp = self.check_image_size(inp)
+ 
+         x = self.intro(inp)
+ 
+         encs = []
+         bins = []
+         weights = []
+         for encoder, down in zip(self.encoders, self.downs):
+             x = encoder(x)
+             encs.append(x)
+             x = down(x)
+         class_weights_0 = self.mlp_branch(x)
+         class_weights = F.softmax(class_weights_0, dim=1)
+         # if the task is selected manually, override the predicted routing weights
+         if task != 'auto':
+             class_weights = torch.tensor(TASKS[task], device=x.device).unsqueeze(0).expand(B, -1)
+         x = self.middle_blks(x)
+         x, expert_bins, weight = self.experts[0](x, class_weights)
+         bins.append(expert_bins)
+         weights.append(weight)
+         for decoder, up, enc_skip, expert in zip(self.decoders, self.ups, encs[::-1], self.experts[1:]):
+             x = up(x)
+             x = x + enc_skip
+             x = decoder(x)
+             x, expert_bins, weight = expert(x, class_weights)
+             bins.append(expert_bins)
+             weights.append(weight)
+         x = self.ending(x)
+         x = x + inp
+ 
+         return {'output': x[:, :, :H, :W],
+                 'bin_counts': torch.stack(bins, dim=0),
+                 'pred_labels': class_weights,
+                 'weights': weights}
+ 
+     def check_image_size(self, x):
+         _, _, h, w = x.size()
+         mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
+         mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
+         x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), value=0)
+         return x
+ 
+ if __name__ == '__main__':
+ 
+     from ptflops import get_model_complexity_info
+ 
+     net = DeMoE(img_channel=3, width=32,
+                 middle_blk_num=2, enc_blk_nums=[2, 2, 2, 2], dec_blk_nums=[2, 2, 2, 2], k_used=1)
+     print('State dict: ', len(net.state_dict().keys()))
+     macs, params = get_model_complexity_info(net, input_res=(3, 256, 256), print_per_layer_stat=False, verbose=False)
+     print(macs, params)
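
As a rough illustration of the two routing modes, here is a sketch with a randomly initialized model (no checkpoint loaded, CPU only); the dictionary keys come from the forward method above:

# Sketch: automatic vs. manual expert routing on a random input.
import torch
from archs import create_model

opt = {'name': 'DeMoE', 'img_channels': 3, 'width': 32, 'middle_blk_num': 2,
       'enc_blk_nums': [2, 2, 2, 2], 'dec_blk_nums': [2, 2, 2, 2],
       'num_experts': 5, 'k_used': 1}
net = create_model(opt, torch.device('cpu')).eval()

x = torch.rand(1, 3, 64, 64)              # any size works; forward pads to a multiple of 16
with torch.no_grad():
    out_auto = net(x)                     # classification head predicts the routing weights
    out_ll = net(x, task='low_light')     # one-hot weights from TASKS force the low-light expert

print(out_auto['output'].shape)           # torch.Size([1, 3, 64, 64]), cropped back to the input size
print(out_auto['pred_labels'])            # softmax scores over the 5 experts
print(out_auto['bin_counts'].shape)       # (number of MoE blocks, num_experts) usage counts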
archs/__init__.py ADDED
@@ -0,0 +1,59 @@
+ import torch
+ 
+ from .DeMoE import DeMoE
+ 
+ def create_model(opt, device):
+     '''
+     Creates the model.
+     opt: a dictionary built from the `network` key of the YAML config.
+     '''
+     name = opt['name']
+ 
+     if name == 'DeMoE':
+         model = DeMoE(img_channel=opt['img_channels'],
+                       width=opt['width'],
+                       middle_blk_num=opt['middle_blk_num'],
+                       enc_blk_nums=opt['enc_blk_nums'],
+                       dec_blk_nums=opt['dec_blk_nums'],
+                       num_exp=opt['num_experts'],
+                       k_used=opt['k_used'])
+ 
+     else:
+         raise NotImplementedError('This network is not implemented')
+ 
+     model.to(device)
+ 
+     return model
+ 
+ def load_weights(model, model_weights):
+     '''
+     Loads the weights of a pretrained model, keeping only the entries that
+     also exist in the new model's state dict.
+     '''
+     new_weights = model.state_dict()
+     new_weights.update({k: v for k, v in model_weights.items() if k in new_weights})
+ 
+     model.load_state_dict(new_weights)
+ 
+     return model
+ 
+ def resume_model(model, path_model, device):
+     '''
+     Loads the checkpoint stored at `path_model` into `model`.
+     '''
+     checkpoints = torch.load(path_model, map_location=device, weights_only=False)
+     weights = checkpoints['params']
+     model = load_weights(model, model_weights=weights)
+ 
+     return model
+ 
+ __all__ = ['create_model', 'resume_model', 'load_weights']
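
The create_model docstring refers to the `network` key of a YAML config; below is a hedged sketch of how such a fragment could be consumed. The file name `config.yml` and the exact YAML layout are assumptions; only the keys are fixed by create_model:

# Hypothetical YAML fragment (layout is an assumption):
#   network:
#     name: DeMoE
#     img_channels: 3
#     width: 32
#     middle_blk_num: 2
#     enc_blk_nums: [2, 2, 2, 2]
#     dec_blk_nums: [2, 2, 2, 2]
#     num_experts: 5
#     k_used: 1
import yaml
import torch
from archs import create_model, resume_model

with open('config.yml') as f:                         # hypothetical config path
    opt = yaml.safe_load(f)['network']
model = create_model(opt, torch.device('cpu'))
model = resume_model(model, './DeMoE.pt', torch.device('cpu'))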
archs/arch_model.py ADDED
@@ -0,0 +1,105 @@
+ import torch
+ import torch.nn as nn
+ 
+ try:
+     from .arch_util import LayerNorm2d
+ except ImportError:
+     from arch_util import LayerNorm2d
+ 
+ # ------------------------------------------------------------------------
+ # Modified from NAFNet (https://github.com/megvii-research/NAFNet)
+ # ------------------------------------------------------------------------
+ 
+ 
+ class SimpleGate(nn.Module):
+     def forward(self, x):
+         x1, x2 = x.chunk(2, dim=1)
+         return x1 * x2
+ 
+ class NAFBlock(nn.Module):
+     def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
+         super().__init__()
+         dw_channel = c * DW_Expand
+         self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+         self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
+                                bias=True)  # depthwise conv
+         self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+ 
+         # Simplified Channel Attention
+         self.sca = nn.Sequential(
+             nn.AdaptiveAvgPool2d(1),
+             nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
+                       groups=1, bias=True),
+         )
+ 
+         # SimpleGate
+         self.sg = SimpleGate()
+ 
+         ffn_channel = FFN_Expand * c
+         self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+         self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+ 
+         self.norm1 = LayerNorm2d(c)
+         self.norm2 = LayerNorm2d(c)
+ 
+         self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
+         self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
+ 
+         self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
+         self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
+ 
+     def forward(self, inp):
+         x = inp                        # size [B, C, H, W]
+ 
+         x = self.norm1(x)              # size [B, C, H, W]
+ 
+         x = self.conv1(x)              # size [B, 2*C, H, W]
+         x = self.conv2(x)              # size [B, 2*C, H, W]
+         x = self.sg(x)                 # size [B, C, H, W]
+         x = x * self.sca(x)            # size [B, C, H, W]
+         x = self.conv3(x)              # size [B, C, H, W]
+ 
+         x = self.dropout1(x)
+ 
+         y = inp + x * self.beta        # size [B, C, H, W]
+ 
+         x = self.conv4(self.norm2(y))  # size [B, 2*C, H, W]
+         x = self.sg(x)                 # size [B, C, H, W]
+         x = self.conv5(x)              # size [B, C, H, W]
+ 
+         x = self.dropout2(x)
+ 
+         x = y + x * self.gamma
+ 
+         return x
+ 
+ class EfficientClassificationHead(nn.Module):
+ 
+     def __init__(self, in_channels, num_classes=5):
+         super().__init__()
+         self.conv_bottleneck = nn.Sequential(
+             nn.Conv2d(in_channels, 256, kernel_size=1),  # channel reduction
+             nn.BatchNorm2d(256),
+             nn.ReLU(inplace=True),
+             nn.Dropout2d(0.2))
+ 
+         self.attention = nn.Sequential(
+             nn.Conv2d(256, 1, kernel_size=1),
+             nn.Sigmoid())
+ 
+         self.classifier = nn.Sequential(
+             nn.AdaptiveAvgPool2d(1),
+             nn.Flatten(),
+             nn.Linear(256, 128),
+             nn.BatchNorm1d(128),
+             nn.ReLU(inplace=True),
+             nn.Dropout(0.3),
+             nn.Linear(128, num_classes))
+ 
+     def forward(self, x):
+         x = self.conv_bottleneck(x)
+         attention_mask = self.attention(x)
+         x = x * attention_mask  # spatial attention
+         return self.classifier(x)
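
A quick shape check for the two blocks above, as a sketch with random tensors; the 512 input channels mirror the DeMoE bottleneck with width=32 and four encoder stages:

# Shape sanity check for NAFBlock and EfficientClassificationHead (random data).
import torch

blk = NAFBlock(c=32)
head = EfficientClassificationHead(in_channels=512, num_classes=5)  # 32 * 2**4 channels at the bottleneck

feat = torch.randn(2, 32, 64, 64)
assert blk(feat).shape == feat.shape             # NAFBlock keeps [B, C, H, W]

bottleneck = torch.randn(2, 512, 16, 16)
print(head(bottleneck).shape)                    # torch.Size([2, 5]) logits, softmaxed in DeMoE.forward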
archs/arch_util.py ADDED
@@ -0,0 +1,79 @@
+ import torch
+ import numpy as np
+ from torch import nn as nn
+ from torch.nn import init as init
+ 
+ # ------------------------------------------------------------------------
+ # Modified from NAFNet (https://github.com/megvii-research/NAFNet)
+ # ------------------------------------------------------------------------
+ 
+ class LayerNormFunction(torch.autograd.Function):
+ 
+     @staticmethod
+     def forward(ctx, x, weight, bias, eps):
+         ctx.eps = eps
+         N, C, H, W = x.size()
+         mu = x.mean(1, keepdim=True)
+         var = (x - mu).pow(2).mean(1, keepdim=True)
+         y = (x - mu) / (var + eps).sqrt()
+         ctx.save_for_backward(y, var, weight)
+         y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
+         return y
+ 
+     @staticmethod
+     def backward(ctx, grad_output):
+         eps = ctx.eps
+ 
+         N, C, H, W = grad_output.size()
+         y, var, weight = ctx.saved_tensors
+         g = grad_output * weight.view(1, C, 1, 1)
+         mean_g = g.mean(dim=1, keepdim=True)
+ 
+         mean_gy = (g * y).mean(dim=1, keepdim=True)
+         gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
+         return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
+             dim=0), None
+ 
+ class LayerNorm2d(nn.Module):
+ 
+     def __init__(self, channels, eps=1e-6):
+         super(LayerNorm2d, self).__init__()
+         self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
+         self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
+         self.eps = eps
+ 
+     def forward(self, x):
+         return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
+ 
+ def calc_mean_std(feat, eps=1e-5):
+     """
+     Calculate mean and std for the given feature map.
+     feat: Tensor of shape [B, C, H, W]
+     eps: small value to avoid division by zero
+     """
+     B, C, _, _ = feat.size()
+ 
+     # Compute mean and std across the channel dimension.
+     feat_mean = feat.mean(dim=1, keepdim=True)
+     feat_std = (feat.var(dim=1, keepdim=True) + eps).sqrt()
+ 
+     return feat_mean, feat_std
+ 
+ class CustomSequential(nn.Module):
+     '''
+     Similar to nn.Sequential, but built on a ModuleList so that extra arguments
+     (e.g. adaptors) can be threaded through the forward method if needed.
+     '''
+     def __init__(self, *args):
+         super(CustomSequential, self).__init__()
+         self.modules_list = nn.ModuleList(args)
+ 
+     def forward(self, x):
+         for module in self.modules_list:
+             x = module(x)
+         return x
+ 
+ if __name__ == '__main__':
+     pass
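
A small sanity sketch showing that LayerNormFunction matches a plain per-pixel, across-channel normalization at initialization (weight=1, bias=0):

# Sketch: compare the custom autograd LayerNorm2d against a plain-ops reference.
import torch

ln = LayerNorm2d(channels=8)
x = torch.randn(2, 8, 16, 16, requires_grad=True)

mu = x.mean(1, keepdim=True)
var = (x - mu).pow(2).mean(1, keepdim=True)
ref = (x - mu) / (var + ln.eps).sqrt()          # weight=1, bias=0 at init

out = ln(x)
print(torch.allclose(out, ref, atol=1e-5))      # True: same forward result
out.sum().backward()                            # the custom backward runs without error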
archs/moeblocks.py ADDED
@@ -0,0 +1,65 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ 
+ try:
+     from .arch_model import NAFBlock
+ except ImportError:
+     from arch_model import NAFBlock
+ 
+ class MoEBlock(nn.Module):
+     def __init__(self, c, n=5, used=3):
+         super().__init__()
+         self.used = int(used)
+         self.num_experts = n
+         self.experts = nn.ModuleList([NAFBlock(c=c) for _ in range(n)])
+ 
+     # Sparse implementation for large n
+     def forward(self, feat, weights):
+         B, _, _, _ = feat.shape
+         k = self.used
+         # Get top-k weights and indices
+         topk_weights, topk_indices = torch.topk(weights, k, dim=1)  # (B, k)
+         expert_counts = torch.bincount(topk_indices.flatten(), minlength=self.num_experts)
+         # L1-normalize so the kept weights sum to 1 while preserving their relative proportions
+         topk_weights = topk_weights / topk_weights.sum(dim=1, keepdim=True)  # (B, k)
+         mask = torch.zeros(B, self.num_experts, dtype=torch.float32, device=feat.device)
+         mask.scatter_(1, topk_indices, 1.0)  # set 1.0 for the experts that are used
+ 
+         # Initialize output tensor
+         outputs = torch.zeros_like(feat)
+ 
+         # Process only the experts that received at least one sample
+         for expert_idx in range(self.num_experts):
+             batch_mask = mask[:, expert_idx].bool()  # which samples route to this expert
+             if batch_mask.any():
+                 # Recover this expert's normalized weight for each routed sample
+                 expert_weights = topk_weights[batch_mask, (topk_indices[batch_mask] == expert_idx).nonzero()[:, 1]]
+                 expert_out = self.experts[expert_idx](feat[batch_mask])
+                 outputs[batch_mask] += expert_out * expert_weights.view(-1, 1, 1, 1)
+ 
+         return outputs, expert_counts, weights
+ 
+ 
+ # ----------------------------------------------------------------------------------------------
+ if __name__ == '__main__':
+ 
+     from ptflops import get_model_complexity_info
+ 
+     img_channel = 3
+ 
+     net = MoEBlock(c=img_channel, n=5, used=3)
+ 
+     inp_shape = (3, 256, 256)
+ 
+     # MoEBlock.forward needs both the features and the routing weights, so build them explicitly
+     macs, params = get_model_complexity_info(
+         net, inp_shape, verbose=False, print_per_layer_stat=False,
+         input_constructor=lambda res: {'feat': torch.randn(1, *res),
+                                        'weights': F.softmax(torch.randn(1, 5), dim=1)})
+     output = net(torch.randn((4, 3, 256, 256)), F.softmax(torch.randn((4, 5)), dim=1))
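
A tiny sketch of the routing behaviour with random weights: with used=1 every sample is processed by exactly one expert, and the kept weights renormalize to sum to 1.

# Sketch: with used=1, each sample goes through exactly one expert.
import torch
import torch.nn.functional as F

block = MoEBlock(c=4, n=5, used=1)
feat = torch.randn(3, 4, 8, 8)
weights = F.softmax(torch.randn(3, 5), dim=1)   # e.g. the classification head's output

out, counts, w = block(feat, weights)
print(out.shape)      # torch.Size([3, 4, 8, 8]), same shape as the input features
print(counts)         # per-expert usage over the batch
print(counts.sum())   # tensor(3) = batch size * k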
check_file.py ADDED
@@ -0,0 +1,5 @@
+ import torch
+ 
+ pt_dict = torch.load('DeMoE.pt', map_location='cpu')
+ print(pt_dict['params'].keys())
+ print(len(pt_dict['params'].keys()))
examples/000143.png ADDED

Git LFS Details

  • SHA256: 2522a72a4f83f6cf363848628914007efe6f2d98490a11a37985c934ec746e85
  • Pointer size: 131 Bytes
  • Size of remote file: 839 kB
examples/0031.png ADDED

Git LFS Details

  • SHA256: c566978dafd3282daa3f00e3366444f5acd96eedd4855d88f8b109289e7a7d31
  • Pointer size: 131 Bytes
  • Size of remote file: 246 kB
examples/12_blur.png ADDED

Git LFS Details

  • SHA256: a873478349366796559b7d9c934d66bfa24a496197cf23b465c55d7a39efddd3
  • Pointer size: 132 Bytes
  • Size of remote file: 3.47 MB
examples/1P0A1811.png ADDED

Git LFS Details

  • SHA256: 0f15ee1e9e132a78121d30807ac736ef2002bdfca77d3b6d6f12f7c917477b1b
  • Pointer size: 132 Bytes
  • Size of remote file: 7.22 MB
examples/blur_4.png ADDED

Git LFS Details

  • SHA256: 4daac3165f76b91c48f80562196d5c357f849d34a0db1024c264142331c216b3
  • Pointer size: 131 Bytes
  • Size of remote file: 553 kB
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ einops==0.8.0
+ gradio==5.49.0
+ kornia==0.7.2
+ lpips==0.1.4
+ numpy==2.0.0
+ opencv-python==4.10.0.84
+ pandas==2.2.2
+ pillow==10.3.0
+ ptflops==0.7.3
+ pyiqa==0.1.13
+ pytorch-msssim==1.0.0
+ PyYAML==6.0.1
+ scikit-image==0.24.0
+ scipy==1.13.1
+ torch==2.5.1
+ torchaudio==2.5.1
+ torchvision==0.20.1
+ tqdm==4.66.4