import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from .arch_model import NAFBlock
except ImportError:
    from arch_model import NAFBlock

class MoEBlock(nn.Module):
    """Mixture-of-Experts block that routes each sample to the top-`used` of `n` NAFBlock experts."""
    def __init__(self, c, n=5, used=3):
        super().__init__()
        self.used = int(used)
        self.num_experts = n
        self.experts = nn.ModuleList([NAFBlock(c=c) for _ in range(n)])

    # Sparse implementation: only the top-k experts selected per sample are evaluated,
    # which keeps the cost low when n is large.
    def forward(self, feat, weights):
        B, _, _, _ = feat.shape
        k = self.used
        # Get the top-k gating weights and the corresponding expert indices
        topk_weights, topk_indices = torch.topk(weights, k, dim=1)  # (B, k)
        # How many samples were routed to each expert (useful e.g. for load-balancing losses)
        expert_counts = torch.bincount(topk_indices.flatten(), minlength=self.num_experts)
        # L1-normalize so the kept weights sum to 1 while preserving their relative magnitudes
        topk_weights = topk_weights / topk_weights.sum(dim=1, keepdim=True)  # (B, k)
        # Binary routing mask: mask[b, e] = 1 if expert e is among sample b's top-k
        mask = torch.zeros(B, self.num_experts, dtype=torch.float32, device=feat.device)
        mask.scatter_(1, topk_indices, 1.0)

        # Initialize output tensor
        outputs = torch.zeros_like(feat)

        # Evaluate each expert only on the samples routed to it
        for expert_idx in range(self.num_experts):
            batch_mask = mask[:, expert_idx].bool()  # samples that use this expert
            if batch_mask.any():
                # Position of this expert inside each routed sample's top-k list
                pos_in_topk = (topk_indices[batch_mask] == expert_idx).nonzero()[:, 1]
                # Normalized gating weight of this expert for each routed sample
                expert_weights = topk_weights[batch_mask, pos_in_topk]
                expert_out = self.experts[expert_idx](feat[batch_mask])
                outputs[batch_mask] += expert_out * expert_weights.view(-1, 1, 1, 1)

        return outputs, expert_counts, weights
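
# Illustrative sketch (not part of the original module): a minimal gating head that
# could produce the `weights` tensor MoEBlock expects. The pooling + linear design
# below is an assumption for demonstration only; the original project may compute
# the gating weights differently.
class ExampleGate(nn.Module):
    def __init__(self, c, n=5):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)  # (B, C, H, W) -> (B, C, 1, 1)
        self.fc = nn.Linear(c, n)            # one logit per expert

    def forward(self, feat):
        logits = self.fc(self.pool(feat).flatten(1))  # (B, n)
        return F.softmax(logits, dim=1)               # per-sample weights summing to 1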

# 
#----------------------------------------------------------------------------------------------
if __name__ == '__main__':

    img_channel = 3
    # The settings below are not used by MoEBlock; they are only kept for reference.
    width = 32
    enc_blks = [1, 2, 3]
    middle_blk_num = 3
    dec_blks = [3, 1, 1]
    dilations = [1, 4, 9]
    extra_depth_wise = True

    net = MoEBlock(c=img_channel, n=5, used=3)

    inp_shape = (3, 256, 256)

    from ptflops import get_model_complexity_info

    # MoEBlock.forward needs both the feature map and the gating weights, so an
    # input_constructor builds the keyword arguments ptflops feeds to the model.
    def input_constructor(input_res):
        return {'feat': torch.randn(1, *input_res),
                'weights': F.softmax(torch.randn(1, 5), dim=1)}

    macs, params = get_model_complexity_info(net, inp_shape, input_constructor=input_constructor,
                                             verbose=False, print_per_layer_stat=False)
    output = net(torch.randn((4, 3, 256, 256)), F.softmax(torch.randn((4, 5)), dim=1))
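
    # Quick sanity check (illustrative addition, not in the original script):
    out, expert_counts, gate_weights = output
    print('macs:', macs, 'params:', params)
    print('output shape:', tuple(out.shape))       # expected: (4, 3, 256, 256)
    # Every sample is routed to exactly `used` experts, so the counts sum to B * used.
    assert expert_counts.sum().item() == 4 * net.used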