autoprogrammer committed on
Commit
e290920
·
verified ·
1 Parent(s): aa36b25

Upload DeepSeekV2Lite model

Browse files
Files changed (1) hide show
  1. modeling_deepseek.py +28 -28
modeling_deepseek.py CHANGED
@@ -566,7 +566,7 @@ class DeepseekV2MoE(nn.Module):
566
  )
567
 
568
  def forward_original(self, hidden_states):
569
- """原始的forward方法,保留用于对比和回滚"""
570
  identity = hidden_states
571
  orig_shape = hidden_states.shape
572
  topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
@@ -590,20 +590,20 @@ class DeepseekV2MoE(nn.Module):
590
 
591
  def forward(self, hidden_states):
592
  """
593
- 实现dense backward功能的前向传播:
594
- 前向输出保持与官方相同(稀疏计算结果),但在反向传播时通过dense计算的梯度传递回来
595
 
596
- Dense Backward机制说明:
597
- 1. 前向传播时,只有top-k专家参与计算(稀疏前向)
598
- 2. 反向传播时,所有专家都接收梯度(密集反向)
599
- 3. 使用直通梯度技术:sparse_output.detach() + (dense_output - dense_output.detach())
600
- 4. 通过register_hook确保只有被激活的专家才真正更新参数
601
 
602
  Args:
603
- hidden_states: 输入张量,形状为 (batch_size, seq_length, hidden_dim)
604
 
605
  Returns:
606
- 输出张量,形状为 (batch_size, seq_length, hidden_dim)
607
  """
608
  batch_size, seq_length, hidden_dim = hidden_states.shape
609
  dtype = hidden_states.dtype
@@ -612,14 +612,14 @@ class DeepseekV2MoE(nn.Module):
612
  identity = hidden_states
613
  orig_shape = hidden_states.shape
614
 
615
- # Step 1: 计算路由逻辑,选择top-k专家
616
  topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
617
  flat_hidden = hidden_states.view(-1, hidden_dim) # (B*seq_len, hidden_dim)
618
  N_tokens = flat_hidden.size(0)
619
  flat_topk_idx = topk_idx.view(-1)
620
 
621
- # Step 2: 计算完整的路由权重(用于dense backward
622
- # 注意V2版本强制使用float32计算
623
  router_logits = F.linear(
624
  flat_hidden.type(torch.float32),
625
  self.gate.weight.type(torch.float32),
@@ -632,49 +632,49 @@ class DeepseekV2MoE(nn.Module):
632
 
633
  routing_weights = routing_weights.to(dtype=dtype)
634
 
635
- # Step 3: 准备输出累加器
636
  dense_outputs = torch.zeros((N_tokens, hidden_dim), dtype=dtype, device=device)
637
  sparse_outputs = torch.zeros((N_tokens, hidden_dim), dtype=dtype, device=device)
638
 
639
- # Step 4: 对每个专家进行计算
640
  if self.training:
641
- # 训练模式:实现dense backward
642
  for expert_idx in range(self.config.n_routed_experts):
643
- # V2版本的专家可能为None(在EP模式下)
644
  if self.experts[expert_idx] is None:
645
  continue
646
 
647
  expert_layer = self.experts[expert_idx]
648
 
649
- # 为所有tokens计算当前专家的输出(dense计算)
650
  expert_output = expert_layer(flat_hidden) # (N_tokens, hidden_dim)
651
 
652
- # 创建激活掩码:标记哪些token选择了该专家
653
  activation_mask = (topk_idx == expert_idx).any(dim=1).float().unsqueeze(-1).to(dtype)
654
 
655
- # 注册hook:只有被选中的token才能向该专家传递梯度
656
  if expert_output.requires_grad:
657
  expert_output.register_hook(lambda grad, mask=activation_mask: grad * mask)
658
 
659
  expert_output = expert_output.to(dtype=dtype)
660
 
661
- # Dense accumulation: 使用完整的路由权重
662
  weight_full = routing_weights[:, expert_idx].unsqueeze(-1) # (N_tokens, 1)
663
  dense_outputs = dense_outputs + expert_output * weight_full
664
 
665
- # Sparse accumulation: 只累加被选中的专家输出
666
  matches = (topk_idx == expert_idx)
667
  if matches.any():
668
  token_indices, k_indices = torch.where(matches)
669
  weights_topk = topk_weight[token_indices, k_indices].unsqueeze(-1).to(sparse_outputs.dtype) # (num_matches, 1)
670
  sparse_outputs[token_indices] = sparse_outputs[token_indices] + expert_output[token_indices] * weights_topk
671
  else:
672
- # 推理模式:使用原始的稀疏计算逻辑
673
  sparse_outputs = self.moe_infer(flat_hidden, topk_idx, topk_weight)
674
- # 推理时不需要dense_outputs
675
  dense_outputs = sparse_outputs
676
 
677
- # Step 5: 添加共享专家(如果有)
678
  if self.config.n_shared_experts is not None:
679
  shared_expert_output = self.shared_experts(identity)
680
  sparse_outputs = sparse_outputs.view(*orig_shape) + shared_expert_output
@@ -683,11 +683,11 @@ class DeepseekV2MoE(nn.Module):
683
  sparse_outputs = sparse_outputs.view(*orig_shape)
684
  dense_outputs = dense_outputs.view(*orig_shape)
685
 
686
- # Step 6: 使用直通梯度技术组合sparse前向和dense反向
687
  if self.training:
688
- # 前向用sparse,反向用dense
689
  final_output = sparse_outputs.detach() + (dense_outputs - dense_outputs.detach())
690
- # 添加辅助损失
691
  final_output = AddAuxiliaryLoss.apply(final_output, aux_loss)
692
  else:
693
  final_output = sparse_outputs
 
566
  )
567
 
568
  def forward_original(self, hidden_states):
569
+ """Original forward method, kept for comparison and rollback"""
570
  identity = hidden_states
571
  orig_shape = hidden_states.shape
572
  topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
 
590
 
591
  def forward(self, hidden_states):
592
  """
593
+ Forward pass implementing dense backward functionality:
594
+ Forward output remains the same as official (sparse computation result), but gradients flow back through dense computation during backward pass
595
 
596
+ Dense Backward mechanism explanation:
597
+ 1. During forward pass, only top-k experts participate in computation (sparse forward)
598
+ 2. During backward pass, all experts receive gradients (dense backward)
599
+ 3. Uses straight-through gradient technique: sparse_output.detach() + (dense_output - dense_output.detach())
600
+ 4. Uses register_hook to ensure only activated experts actually update parameters
601
 
602
  Args:
603
+ hidden_states: Input tensor, shape (batch_size, seq_length, hidden_dim)
604
 
605
  Returns:
606
+ Output tensor, shape (batch_size, seq_length, hidden_dim)
607
  """
608
  batch_size, seq_length, hidden_dim = hidden_states.shape
609
  dtype = hidden_states.dtype
 
612
  identity = hidden_states
613
  orig_shape = hidden_states.shape
614
 
615
+ # Step 1: Compute routing logic, select top-k experts
616
  topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
617
  flat_hidden = hidden_states.view(-1, hidden_dim) # (B*seq_len, hidden_dim)
618
  N_tokens = flat_hidden.size(0)
619
  flat_topk_idx = topk_idx.view(-1)
620
 
621
+ # Step 2: Compute complete routing weights (for dense backward)
622
+ # Note: V2 version forces float32 computation
623
  router_logits = F.linear(
624
  flat_hidden.type(torch.float32),
625
  self.gate.weight.type(torch.float32),
 
632
 
633
  routing_weights = routing_weights.to(dtype=dtype)
634
 
635
+ # Step 3: Prepare output accumulators
636
  dense_outputs = torch.zeros((N_tokens, hidden_dim), dtype=dtype, device=device)
637
  sparse_outputs = torch.zeros((N_tokens, hidden_dim), dtype=dtype, device=device)
638
 
639
+ # Step 4: Compute for each expert
640
  if self.training:
641
+ # Training mode: implement dense backward
642
  for expert_idx in range(self.config.n_routed_experts):
643
+ # V2 version experts may be None (in EP mode)
644
  if self.experts[expert_idx] is None:
645
  continue
646
 
647
  expert_layer = self.experts[expert_idx]
648
 
649
+ # Compute current expert output for all tokens (dense computation)
650
  expert_output = expert_layer(flat_hidden) # (N_tokens, hidden_dim)
651
 
652
+ # Create activation mask: mark which tokens selected this expert
653
  activation_mask = (topk_idx == expert_idx).any(dim=1).float().unsqueeze(-1).to(dtype)
654
 
655
+ # Register hook: only selected tokens can pass gradients to this expert
656
  if expert_output.requires_grad:
657
  expert_output.register_hook(lambda grad, mask=activation_mask: grad * mask)
658
 
659
  expert_output = expert_output.to(dtype=dtype)
660
 
661
+ # Dense accumulation: use complete routing weights
662
  weight_full = routing_weights[:, expert_idx].unsqueeze(-1) # (N_tokens, 1)
663
  dense_outputs = dense_outputs + expert_output * weight_full
664
 
665
+ # Sparse accumulation: only accumulate selected expert outputs
666
  matches = (topk_idx == expert_idx)
667
  if matches.any():
668
  token_indices, k_indices = torch.where(matches)
669
  weights_topk = topk_weight[token_indices, k_indices].unsqueeze(-1).to(sparse_outputs.dtype) # (num_matches, 1)
670
  sparse_outputs[token_indices] = sparse_outputs[token_indices] + expert_output[token_indices] * weights_topk
671
  else:
672
+ # Inference mode: use original sparse computation logic
673
  sparse_outputs = self.moe_infer(flat_hidden, topk_idx, topk_weight)
674
+ # Dense outputs not needed during inference
675
  dense_outputs = sparse_outputs
676
 
677
+ # Step 5: Add shared experts (if any)
678
  if self.config.n_shared_experts is not None:
679
  shared_expert_output = self.shared_experts(identity)
680
  sparse_outputs = sparse_outputs.view(*orig_shape) + shared_expert_output
 
683
  sparse_outputs = sparse_outputs.view(*orig_shape)
684
  dense_outputs = dense_outputs.view(*orig_shape)
685
 
686
+ # Step 6: Use straight-through gradient technique to combine sparse forward and dense backward
687
  if self.training:
688
+ # Forward uses sparse, backward uses dense
689
  final_output = sparse_outputs.detach() + (dense_outputs - dense_outputs.detach())
690
+ # Add auxiliary loss
691
  final_output = AddAuxiliaryLoss.apply(final_output, aux_loss)
692
  else:
693
  final_output = sparse_outputs