Update build/torch-universal/scattermoe/layers.py
#1
by winglian · opened
build/torch-universal/scattermoe/layers.py
CHANGED
@@ -50,3 +50,113 @@ class ScatterMoEGatedMLP(nn.Module):
         layer_output = layer_output.view(bsz, length, emb_size)
         return layer_output
 
+
+
+
+class HFScatterMoEGatedMLP(nn.Module):
+    """
+    ScatterMoE-accelerated forward pass for HF MoEs based on Qwen2MoE.
+
+    This class adapts the ScatterMoE kernel to work with the standard Qwen2MoE parameter names:
+    - Uses the existing `gate_up_proj` and `down_proj` parameters
+    """
+
+    @staticmethod
+    def forward(
+        self: nn.Module,
+        layer_input: torch.Tensor
+    ):
+        """
+        Forward pass using ScatterMoE kernels with standard Qwen2MoE parameter names.
+
+        Args:
+            self: The XXXMoeSparseMoeBlock module containing:
+                - self.gate: Router with .weight parameter [num_experts, hidden_size]
+                - self.experts: Qwen2MoeExperts with:
+                    - .gate_up_proj: [num_experts, 2*intermediate_size, hidden_size]
+                    - .down_proj: [num_experts, hidden_size, intermediate_size]
+                - self.shared_expert: Optional shared expert
+                - self.shared_expert_gate: Optional shared expert gate
+            layer_input: Input tensor [batch_size, seq_len, hidden_size]
+
+        Returns:
+            torch.Tensor: Output of shape [batch_size, seq_len, hidden_size]. Router logits
+            of shape [batch_size * seq_len, num_experts] are computed internally but not
+            returned, matching the HF MoE modeling convention.
+        """
+        batch_size, sequence_length, hidden_dim = layer_input.shape
+        hidden_states_flat = layer_input.view(-1, hidden_dim)
+
+        # ============================================================================
+        # Shared Expert (if present)
+        # ============================================================================
+        if hasattr(self, 'shared_expert') and self.shared_expert is not None:
+            shared_expert_output = self.shared_expert(hidden_states_flat)
+            shared_expert_gate_output = F.sigmoid(
+                self.shared_expert_gate(hidden_states_flat)
+            )
+            shared_expert_output = shared_expert_output * shared_expert_gate_output
+        else:
+            shared_expert_output = None
+
+        # ============================================================================
+        # Router Computation
+        # ============================================================================
+        # Standard Qwen2MoE router: self.gate.weight is [num_experts, hidden_size]
+        router_logits = F.linear(hidden_states_flat, self.gate.weight)  # [num_tokens, num_experts]
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+
+        # Get top-k experts
+        top_k = self.gate.top_k
+        num_experts = self.gate.num_experts
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, top_k, dim=-1
+        )  # [num_tokens, top_k]
+
+        # Normalize top-k weights if required
+        if self.gate.norm_topk_prob:
+            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights = routing_weights.to(hidden_states_flat.dtype)
+
+        # Flatten and sort for the ScatterMoE kernel
+        sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
+            selected_experts, num_experts=num_experts
+        )
+
+        # Compute experts - input linear (gate + up projections)
+        gates, h = parallel_linear(
+            hidden_states_flat,
+            self.experts.gate_up_proj.transpose(2, 1),  # [num_experts, hidden, 2*intermediate]
+            top_k,
+            sorted_expert_idxs,
+            sorted_scattered_idxs,
+            expert_offsets,
+            grouped_in=False,
+            grouped_out=True,
+        ).chunk(2, dim=-1)
+
+        # Activation
+        h = self.experts.act_fn(gates) * h
+
+        # Experts - output linear
+        expert_output = parallel_linear(
+            h,
+            self.experts.down_proj.transpose(2, 1),  # [num_experts, intermediate, hidden]
+            1,  # Each token goes to 1 expert for the output (already routed)
+            sorted_expert_idxs,
+            sorted_scattered_idxs,
+            expert_offsets,
+            grouped_in=True,
+            grouped_out=False,
+            gates=routing_weights,
+        )
+
+        # Combine with the shared expert (if present)
+        if shared_expert_output is not None:
+            expert_output = expert_output + shared_expert_output
+
+        # Reshape to original dimensions
+        expert_output = expert_output.view(batch_size, sequence_length, hidden_dim)
+
+        return expert_output  # , router_logits  # HF MoE modeling doesn't return router_logits
+
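
For reviewers unfamiliar with the fused `gate_up_proj` layout, the sketch below shows in plain PyTorch what the two `parallel_linear` calls compute for the tokens routed to a single expert. This is a minimal dense reference, not the kernel path: `act_fn` is assumed to be SiLU, and `x_e` stands for a hypothetical subset of flattened tokens assigned to expert `e`.

import torch
import torch.nn.functional as F

def reference_expert_forward(x_e, gate_up_w_e, down_w_e, act_fn=F.silu):
    """Dense per-expert reference for the fused gate/up + down computation.

    x_e:         [tokens_e, hidden_size]                 tokens routed to expert e
    gate_up_w_e: [2 * intermediate_size, hidden_size]    experts.gate_up_proj[e]
    down_w_e:    [hidden_size, intermediate_size]        experts.down_proj[e]
    """
    # Fused gate + up projection, then split into the two halves,
    # mirroring `.chunk(2, dim=-1)` after the first parallel_linear call.
    gates, h = F.linear(x_e, gate_up_w_e).chunk(2, dim=-1)
    # SwiGLU-style gating followed by the down projection, mirroring
    # `act_fn(gates) * h` and the second parallel_linear call. In the kernel
    # path each token's output is additionally scaled by its routing weight
    # (`gates=routing_weights`), which this reference omits.
    return F.linear(act_fn(gates) * h, down_w_e)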
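
Because `forward` is a `@staticmethod` that takes the module as an explicit `self`, one possible way to use it is to bind it onto an existing sparse-MoE block, roughly as sketched below. The checkpoint name, the `layer.mlp` attribute path, and the `scattermoe.layers` import path are illustrative assumptions; the target block must expose `gate`, `experts.gate_up_proj`, `experts.down_proj`, and (optionally) `shared_expert`/`shared_expert_gate` as described in the docstring.

import types

import torch
from transformers import AutoModelForCausalLM

# Hypothetical import path; adjust to wherever this class is installed from.
from scattermoe.layers import HFScatterMoEGatedMLP

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen1.5-MoE-A2.7B",  # example Qwen2MoE-style checkpoint
    torch_dtype=torch.bfloat16,
)

for layer in model.model.layers:
    moe_block = layer.mlp  # assumed to be the sparse-MoE block per the docstring
    # Bind the ScatterMoE forward so that `moe_block` is passed as `self`.
    moe_block.forward = types.MethodType(HFScatterMoEGatedMLP.forward, moe_block)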