Upload DeepSeekV2Lite model
modeling_deepseek.py CHANGED (+28 -28)
@@ -566,7 +566,7 @@ class DeepseekV2MoE(nn.Module):
             )

     def forward_original(self, hidden_states):
-        """
+        """Original forward method, kept for comparison and rollback"""
         identity = hidden_states
         orig_shape = hidden_states.shape
         topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
@@ -590,20 +590,20 @@ class DeepseekV2MoE(nn.Module):

     def forward(self, hidden_states):
         """
-
-
+        Forward pass implementing dense backward:
+        The forward output is identical to the official implementation (the sparse computation result), but gradients flow back through the dense computation during the backward pass.

-        Dense Backward
-        1.
-        2.
-        3.
-        4.
+        Dense Backward mechanism:
+        1. During the forward pass, only the top-k experts contribute to the output (sparse forward).
+        2. During the backward pass, gradients flow through the dense computation over all experts (dense backward).
+        3. Uses a straight-through gradient trick: sparse_output.detach() + (dense_output - dense_output.detach()).
+        4. Uses register_hook so that each expert is only updated by the tokens that actually routed to it.

         Args:
-            hidden_states:
+            hidden_states: Input tensor of shape (batch_size, seq_length, hidden_dim)

         Returns:
-
+            Output tensor of shape (batch_size, seq_length, hidden_dim)
         """
         batch_size, seq_length, hidden_dim = hidden_states.shape
         dtype = hidden_states.dtype
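The straight-through combination named in this docstring can be exercised in isolation. The following standalone sketch (toy scalars, not part of modeling_deepseek.py) confirms that the combined tensor takes its value from the sparse path while its gradient comes from the dense path:

import torch

# Toy parameter feeding both a "sparse" and a "dense" result.
w = torch.tensor(2.0, requires_grad=True)
sparse_out = 3.0 * w   # stand-in for the top-k (sparse) expert mix
dense_out = 5.0 * w    # stand-in for the all-experts (dense) mix

# Straight-through: value of sparse_out, gradient of dense_out.
final = sparse_out.detach() + (dense_out - dense_out.detach())

print(final.item())    # 6.0 -> matches the sparse forward value
final.backward()
print(w.grad.item())   # 5.0 -> gradient flows through the dense path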
@@ -612,14 +612,14 @@ class DeepseekV2MoE(nn.Module):
         identity = hidden_states
         orig_shape = hidden_states.shape

-        # Step 1:
+        # Step 1: Compute routing logic, select top-k experts
         topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
         flat_hidden = hidden_states.view(-1, hidden_dim) # (B*seq_len, hidden_dim)
         N_tokens = flat_hidden.size(0)
         flat_topk_idx = topk_idx.view(-1)

-        # Step 2:
-        #
+        # Step 2: Compute complete routing weights (for dense backward)
+        # Note: V2 version forces float32 computation
         router_logits = F.linear(
             flat_hidden.type(torch.float32),
             self.gate.weight.type(torch.float32),
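Step 2 rebuilds a routing distribution over every expert from the gate's weight matrix, instead of reusing only the top-k weights returned by self.gate; the unchanged lines between 625 and 632 that turn router_logits into routing_weights are not shown in this diff. A minimal standalone sketch of the assumed continuation (float32 logits followed by a softmax over the expert dimension, then a cast back to the activation dtype):

import torch
import torch.nn.functional as F

n_tokens, hidden_dim, n_experts = 4, 8, 6
flat_hidden = torch.randn(n_tokens, hidden_dim, dtype=torch.bfloat16)
gate_weight = torch.randn(n_experts, hidden_dim, dtype=torch.bfloat16)

# Logits for every expert, in float32 as in the hunk above.
router_logits = F.linear(flat_hidden.type(torch.float32),
                         gate_weight.type(torch.float32))

# Assumed continuation (not shown in the diff): softmax over experts,
# then cast back to the activation dtype.
routing_weights = F.softmax(router_logits, dim=-1).to(dtype=flat_hidden.dtype)

print(routing_weights.shape)        # torch.Size([4, 6])
print(routing_weights.sum(dim=-1))  # each row sums to ~1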
@@ -632,49 +632,49 @@ class DeepseekV2MoE(nn.Module):

         routing_weights = routing_weights.to(dtype=dtype)

-        # Step 3:
+        # Step 3: Prepare output accumulators
         dense_outputs = torch.zeros((N_tokens, hidden_dim), dtype=dtype, device=device)
         sparse_outputs = torch.zeros((N_tokens, hidden_dim), dtype=dtype, device=device)

-        # Step 4:
+        # Step 4: Compute for each expert
         if self.training:
-            #
+            # Training mode: implement dense backward
             for expert_idx in range(self.config.n_routed_experts):
-                # V2
+                # V2 version experts may be None (in EP mode)
                 if self.experts[expert_idx] is None:
                     continue

                 expert_layer = self.experts[expert_idx]

-                #
+                # Compute current expert output for all tokens (dense computation)
                 expert_output = expert_layer(flat_hidden) # (N_tokens, hidden_dim)

-                #
+                # Create activation mask: mark which tokens selected this expert
                 activation_mask = (topk_idx == expert_idx).any(dim=1).float().unsqueeze(-1).to(dtype)

-                #
+                # Register hook: only selected tokens can pass gradients to this expert
                 if expert_output.requires_grad:
                     expert_output.register_hook(lambda grad, mask=activation_mask: grad * mask)

                 expert_output = expert_output.to(dtype=dtype)

-                # Dense accumulation:
+                # Dense accumulation: use complete routing weights
                 weight_full = routing_weights[:, expert_idx].unsqueeze(-1) # (N_tokens, 1)
                 dense_outputs = dense_outputs + expert_output * weight_full

-                # Sparse accumulation:
+                # Sparse accumulation: only accumulate selected expert outputs
                 matches = (topk_idx == expert_idx)
                 if matches.any():
                     token_indices, k_indices = torch.where(matches)
                     weights_topk = topk_weight[token_indices, k_indices].unsqueeze(-1).to(sparse_outputs.dtype) # (num_matches, 1)
                     sparse_outputs[token_indices] = sparse_outputs[token_indices] + expert_output[token_indices] * weights_topk
         else:
-            #
+            # Inference mode: use original sparse computation logic
             sparse_outputs = self.moe_infer(flat_hidden, topk_idx, topk_weight)
-            #
+            # Dense outputs not needed during inference
             dense_outputs = sparse_outputs

-        # Step 5:
+        # Step 5: Add shared experts (if any)
         if self.config.n_shared_experts is not None:
             shared_expert_output = self.shared_experts(identity)
             sparse_outputs = sparse_outputs.view(*orig_shape) + shared_expert_output
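The register_hook call above is what reconciles the dense forward computation with sparse expert updates: every expert processes every token, but the hook zeroes the incoming gradient for tokens that did not route to that expert, so the expert's parameters are only updated from its own tokens. A standalone sketch of the mechanism with a toy expert (hypothetical shapes, not the real DeepseekV2MLP):

import torch

torch.manual_seed(0)
n_tokens, hidden_dim = 4, 3
expert = torch.nn.Linear(hidden_dim, hidden_dim)  # toy stand-in for one expert MLP
flat_hidden = torch.randn(n_tokens, hidden_dim)

# Pretend only tokens 0 and 2 routed to this expert.
activation_mask = torch.tensor([[1.0], [0.0], [1.0], [0.0]])  # (n_tokens, 1)

expert_output = expert(flat_hidden)  # dense: computed for all tokens
# Zero the gradient of non-selected tokens before it reaches the expert's weights.
expert_output.register_hook(lambda grad, mask=activation_mask: grad * mask)

expert_output.sum().backward()
# Only tokens 0 and 2 contribute to the weight gradient.
print(expert.weight.grad)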
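Step 4's two accumulators differ only in which tokens and which weights they use for a given expert: dense_outputs weights every token by the full routing probability, while sparse_outputs adds the expert's output only for tokens that selected it, weighted by the gate's top-k weight. A small standalone sketch of the two accumulation rules for a single expert (toy tensors, names chosen to mirror the hunk above):

import torch

n_tokens, hidden_dim, top_k = 4, 3, 2
expert_idx = 1
expert_output = torch.randn(n_tokens, hidden_dim)                 # this expert, all tokens
routing_weights = torch.softmax(torch.randn(n_tokens, 6), dim=-1) # full probabilities over 6 experts
topk_idx = torch.tensor([[1, 3], [0, 2], [1, 5], [4, 0]])         # top-2 expert ids per token
topk_weight = torch.rand(n_tokens, top_k)                         # gate weights for the top-2

dense_outputs = torch.zeros(n_tokens, hidden_dim)
sparse_outputs = torch.zeros(n_tokens, hidden_dim)

# Dense accumulation: every token, weighted by the full routing probability.
dense_outputs += expert_output * routing_weights[:, expert_idx].unsqueeze(-1)

# Sparse accumulation: only tokens whose top-k set contains this expert.
matches = (topk_idx == expert_idx)
token_indices, k_indices = torch.where(matches)
weights_topk = topk_weight[token_indices, k_indices].unsqueeze(-1)
sparse_outputs[token_indices] += expert_output[token_indices] * weights_topk

print(dense_outputs.abs().sum(dim=-1))   # nonzero for every token
print(sparse_outputs.abs().sum(dim=-1))  # nonzero only for tokens 0 and 2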
@@ -683,11 +683,11 @@ class DeepseekV2MoE(nn.Module):
         sparse_outputs = sparse_outputs.view(*orig_shape)
         dense_outputs = dense_outputs.view(*orig_shape)

-        # Step 6:
+        # Step 6: Use straight-through gradient technique to combine sparse forward and dense backward
         if self.training:
-            #
+            # Forward uses sparse, backward uses dense
             final_output = sparse_outputs.detach() + (dense_outputs - dense_outputs.detach())
-            #
+            # Add auxiliary loss
             final_output = AddAuxiliaryLoss.apply(final_output, aux_loss)
         else:
             final_output = sparse_outputs
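AddAuxiliaryLoss is defined elsewhere in modeling_deepseek.py and is not part of this diff; it returns the activations unchanged while wiring the gate's load-balancing loss into the backward pass. The sketch below shows one common way such a pass-through autograd.Function is written; it is an illustration under that assumption, not the file's actual definition:

import torch

class AddAuxLossSketch(torch.autograd.Function):
    """Identity on the activations; injects a gradient of 1 into aux_loss."""

    @staticmethod
    def forward(ctx, x, aux_loss):
        ctx.loss_requires_grad = aux_loss.requires_grad
        ctx.loss_dtype = aux_loss.dtype
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_loss = None
        if ctx.loss_requires_grad:
            grad_loss = torch.ones(1, dtype=ctx.loss_dtype, device=grad_output.device)
        return grad_output, grad_loss

x = torch.randn(2, 3, requires_grad=True)
aux_loss = torch.tensor([0.5], requires_grad=True)  # toy load-balancing loss

out = AddAuxLossSketch.apply(x, aux_loss)
out.sum().backward()

print(torch.equal(out, x))  # True: the forward output is unchanged
print(aux_loss.grad)        # tensor([1.]): the auxiliary loss still trains the gate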