DeMoE / archs /moeblocks.py
danifei's picture
basic functionality
034f4b8
raw
history blame
2.41 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
# Prefer the package-relative import; fall back to a plain import when this
# file is executed directly as a script (no parent package available).
# Only ImportError is caught so that genuine failures inside arch_model
# (and KeyboardInterrupt/SystemExit) are not silently swallowed.
try:
    from .arch_model import NAFBlock
except ImportError:
    from arch_model import NAFBlock
class MoEBlock(nn.Module):
    """Sparse mixture-of-experts layer built from ``NAFBlock`` experts.

    Every sample is routed to its ``used`` highest-weighted experts. The
    outputs of those experts are blended using the selected weights after
    L1 normalisation over the selected experts.

    Args:
        c: Channel count handed to every ``NAFBlock`` expert.
        n: Number of experts held by the layer.
        used: Number of experts evaluated per sample (the top-k size).

    Raises:
        ValueError: If ``used`` is negative or larger than ``n``.
    """

    def __init__(self, c, n=5, used=3):
        used = int(used)
        # Fail fast on an impossible routing configuration instead of
        # surfacing an opaque torch.topk error on the first forward pass.
        if used < 0 or used > n:
            raise ValueError(
                f"used must satisfy 0 <= used <= n, got used={used}, n={n}"
            )
        super().__init__()
        self.used = used
        self.num_experts = n
        self.experts = nn.ModuleList([NAFBlock(c=c) for _ in range(n)])

    # Sparse implementation for large n
    def forward(self, feat, weights):
        """Blend the top-k experts of each sample.

        Args:
            feat: 4-D feature tensor whose first dimension is the batch.
            weights: Per-sample expert scores; top-k is taken along dim 1
                (presumably shaped ``(B, num_experts)`` -- confirm with the
                caller that produces them).

        Returns:
            A tuple ``(outputs, expert_counts, weights)``: the blended
            features (same shape as ``feat``), how many samples selected each
            expert, and the ``weights`` argument passed through unchanged.
        """
        batch_size, _, _, _ = feat.shape

        # Get top-k weights and indices
        topk_weights, topk_indices = torch.topk(weights, self.used, dim=1)  # (B, k)
        expert_counts = torch.bincount(topk_indices.flatten(), minlength=self.num_experts)

        # L1-normalise so the selected weights sum to 1 while keeping their
        # relative proportions. A row whose selected weights sum to exactly
        # zero would otherwise divide by zero and fill the output with NaNs;
        # dividing by 1 instead leaves those weights (and the output) at zero.
        norm = topk_weights.sum(dim=1, keepdim=True)
        norm = torch.where(norm == 0, torch.ones_like(norm), norm)
        topk_weights = topk_weights / norm  # (B, k)

        # selected[b, e] is True when sample b routes through expert e.
        selected = torch.zeros(batch_size, self.num_experts, dtype=torch.bool, device=feat.device)
        selected.scatter_(1, topk_indices, True)

        # Initialize output tensor
        outputs = torch.zeros_like(feat)

        # Process only used experts
        for expert_idx, expert in enumerate(self.experts):
            batch_mask = selected[:, expert_idx]
            if not batch_mask.any():
                continue
            # top-k indices are distinct within a row, so each selected row
            # contains expert_idx exactly once; column 1 of nonzero() is the
            # top-k slot holding this expert's weight for that row.
            slot = (topk_indices[batch_mask] == expert_idx).nonzero()[:, 1]
            expert_weights = topk_weights[batch_mask, slot]
            expert_out = expert(feat[batch_mask])
            outputs[batch_mask] += expert_out * expert_weights.view(-1, 1, 1, 1)

        return outputs, expert_counts, weights
#
#----------------------------------------------------------------------------------------------
if __name__ == '__main__':
    # Smoke test: measure the complexity of one MoEBlock and run a forward pass.
    img_channel = 3
    num_experts = 5

    net = MoEBlock(c=img_channel,
                   n=num_experts,
                   used=3)

    inp_shape = (img_channel, 256, 256)

    from ptflops import get_model_complexity_info

    def _moe_inputs(resolution):
        # forward() needs both the features and the routing weights, whereas
        # ptflops on its own would only pass a single input tensor.
        return {
            'feat': torch.randn(1, *resolution),
            'weights': F.softmax(torch.randn(1, num_experts), dim=1),
        }

    macs, params = get_model_complexity_info(net, inp_shape,
                                             input_constructor=_moe_inputs,
                                             verbose=False,
                                             print_per_layer_stat=False)
    print(macs, params)

    # dim=1 normalises over the experts of each sample; omitting dim relies on
    # deprecated implicit-dimension behaviour.
    output = net(torch.randn((4, img_channel, 256, 256)),
                 F.softmax(torch.randn((4, num_experts)), dim=1))