Update build/torch-universal/scattermoe/layers.py

#1
by winglian - opened
build/torch-universal/scattermoe/layers.py CHANGED
@@ -50,3 +50,113 @@ class ScatterMoEGatedMLP(nn.Module):
         layer_output = layer_output.view(bsz, length, emb_size)
         return layer_output
 
+
+
+
+class HFScatterMoEGatedMLP(nn.Module):
+    """
+    ScatterMoE-accelerated forward pass for HF MoEs based on Qwen2MoE.
+
+    This class adapts the ScatterMoE kernel to work with the standard Qwen2MoE parameter
+    layout, reusing the existing `gate_up_proj` and `down_proj` expert parameters.
+    """
+
+    @staticmethod
+    def forward(
+        self: nn.Module,
+        layer_input: torch.Tensor,
+    ):
+        """
+        Forward pass using ScatterMoE kernels with standard Qwen2MoE parameter names.
+
+        Args:
+            self: The XXXMoeSparseMoeBlock module containing:
+                - self.gate: router with a .weight parameter of shape [num_experts, hidden_size]
+                - self.experts: Qwen2MoeExperts (fused expert weights) with:
+                    - .gate_up_proj: [num_experts, 2*intermediate_size, hidden_size]
+                    - .down_proj: [num_experts, hidden_size, intermediate_size]
+                - self.shared_expert: optional shared expert
+                - self.shared_expert_gate: optional shared expert gate
+            layer_input: Input tensor of shape [batch_size, seq_len, hidden_size]
+
+        Returns:
+            Output tensor of shape [batch_size, seq_len, hidden_size].
+            (router_logits of shape [batch_size * seq_len, num_experts] are computed
+            internally but not returned, matching the current HF MoE modeling code.)
+        """
+        batch_size, sequence_length, hidden_dim = layer_input.shape
+        hidden_states_flat = layer_input.view(-1, hidden_dim)
+
+        # ============================================================================
+        # Shared Expert (if present)
+        # ============================================================================
+        if hasattr(self, 'shared_expert') and self.shared_expert is not None:
+            shared_expert_output = self.shared_expert(hidden_states_flat)
+            shared_expert_gate_output = F.sigmoid(
+                self.shared_expert_gate(hidden_states_flat)
+            )
+            shared_expert_output = shared_expert_output * shared_expert_gate_output
+        else:
+            shared_expert_output = None
+
+        # ============================================================================
+        # Router Computation
+        # ============================================================================
+        # Standard Qwen2MoE router: self.gate.weight is [num_experts, hidden_size]
+        router_logits = F.linear(hidden_states_flat, self.gate.weight)  # [num_tokens, num_experts]
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+
+        # Get top-k experts
+        top_k = self.gate.top_k
+        num_experts = self.gate.num_experts
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, top_k, dim=-1
+        )  # [num_tokens, top_k]
+
+        # Normalize top-k weights if required
+        if self.gate.norm_topk_prob:
+            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights = routing_weights.to(hidden_states_flat.dtype)
+
+        # Flatten and sort for the ScatterMoE kernel
+        sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
+            selected_experts, num_experts=num_experts
+        )
+
+        # Compute experts - input linear (gate + up projections)
+        gates, h = parallel_linear(
+            hidden_states_flat,
+            self.experts.gate_up_proj.transpose(2, 1),  # [num_experts, hidden, 2*intermediate]
+            top_k,
+            sorted_expert_idxs,
+            sorted_scattered_idxs,
+            expert_offsets,
+            grouped_in=False,
+            grouped_out=True,
+        ).chunk(2, dim=-1)
+
+        # Activation
+        h = self.experts.act_fn(gates) * h
+
+        # Experts - output linear
+        expert_output = parallel_linear(
+            h,
+            self.experts.down_proj.transpose(2, 1),  # [num_experts, intermediate, hidden]
+            1,  # each token has already been routed, so one expert per grouped row
+            sorted_expert_idxs,
+            sorted_scattered_idxs,
+            expert_offsets,
+            grouped_in=True,
+            grouped_out=False,
+            gates=routing_weights,
+        )
+
+        # Combine with the shared expert (if present)
+        if shared_expert_output is not None:
+            expert_output = expert_output + shared_expert_output
+
+        # Reshape to original dimensions
+        expert_output = expert_output.view(batch_size, sequence_length, hidden_dim)
+
+        return expert_output  # , router_logits  # HF MoE modeling doesn't return router_logits
+
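For reviewers, a minimal sketch of how this forward pass could be bound onto an existing model's sparse MoE blocks. It is illustrative only: the model id, the `model.model.layers[i].mlp` attribute path, and the direct `scattermoe.layers` import are assumptions (a torch-universal build is often loaded through the kernels hub instead), and it presumes each block exposes the fused `experts.gate_up_proj` / `experts.down_proj` tensors plus `gate.top_k`, `gate.num_experts`, and `gate.norm_topk_prob` as described in the docstring, with the surrounding decoder layer expecting only the hidden states back.

import types

import torch
from transformers import AutoModelForCausalLM

# Assumed import path; in practice this module may be loaded via the kernels hub.
from scattermoe.layers import HFScatterMoEGatedMLP

# Hypothetical Qwen2MoE-style checkpoint whose MoE blocks match the layout above.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen1.5-MoE-A2.7B", torch_dtype=torch.bfloat16, device_map="cuda"
)

for layer in model.model.layers:
    moe_block = layer.mlp  # assumed location of the sparse MoE block
    # Bind the ScatterMoE forward so `self` resolves to this block at call time.
    moe_block.forward = types.MethodType(HFScatterMoEGatedMLP.forward, moe_block)

Because `forward` is declared as a staticmethod that takes the module as its first argument, binding it with `types.MethodType` makes the patched block a drop-in replacement without changing its parameters or state dict.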