|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
try: |
|
|
from .arch_model import NAFBlock |
|
|
except: |
|
|
from arch_model import NAFBlock |
|
|
|
|
|
class MoEBlock(nn.Module): |
|
|
def __init__(self, c, n=5, used=3): |
|
|
super().__init__() |
|
|
self.used = int(used) |
|
|
self.num_experts = n |
|
|
self.experts = nn.ModuleList([NAFBlock(c=c) for _ in range(n)]) |
|
|
|
|
|
|
|
|
def forward(self, feat, weights): |
|
|
B, _, _, _ = feat.shape |
|
|
k = self.used |
|
|
|
|
|
topk_weights, topk_indices = torch.topk(weights, k, dim=1) |
|
|
expert_counts = torch.bincount(topk_indices.flatten(), minlength=self.num_experts) |
|
|
|
|
|
topk_weights = topk_weights / topk_weights.sum(dim=1, keepdim=True) |
|
|
mask = torch.zeros(B, self.num_experts, dtype=torch.float32, device=feat.device) |
|
|
mask.scatter_(1, topk_indices, 1.0) |
|
|
|
|
|
|
|
|
outputs = torch.zeros_like(feat) |
|
|
|
|
|
|
|
|
for expert_idx in range(self.num_experts): |
|
|
batch_mask = mask[:, expert_idx].bool() |
|
|
if batch_mask.any(): |
|
|
|
|
|
expert_weights = topk_weights[batch_mask, (topk_indices[batch_mask] == expert_idx).nonzero()[:, 1]] |
|
|
expert_out = self.experts[expert_idx](feat[batch_mask]) |
|
|
outputs[batch_mask] += expert_out * expert_weights.view(-1, 1, 1, 1) |
|
|
|
|
|
return outputs, expert_counts, weights |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|
|
img_channel = 3 |
|
|
width = 32 |
|
|
|
|
|
enc_blks = [1, 2, 3] |
|
|
middle_blk_num = 3 |
|
|
dec_blks = [3, 1, 1] |
|
|
dilations = [1, 4, 9] |
|
|
extra_depth_wise = True |
|
|
|
|
|
net = MoEBlock(c = img_channel, |
|
|
n=5, |
|
|
used=3) |
|
|
|
|
|
inp_shape = (3, 256, 256) |
|
|
|
|
|
from ptflops import get_model_complexity_info |
|
|
|
|
|
macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False) |
|
|
output = net(torch.randn((4, 3, 256, 256)), F.softmax(torch.randn((4,5)))) |
|
|
|