import math
import warnings
from collections.abc import Callable
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from transformers import initialization as init
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import (
    GenericForSequenceClassification,
    GenericForTokenClassification,
    GradientCheckpointingLayer,
)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import logging
from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
    DeepseekV3Attention,
    DeepseekV3DecoderLayer,
    DeepseekV3ForCausalLM,
    DeepseekV3MLP,
    DeepseekV3Model,
    DeepseekV3MoE,
    DeepseekV3PreTrainedModel,
    DeepseekV3RMSNorm,
    DeepseekV3RotaryEmbedding,
    apply_rotary_pos_emb_interleave,
    yarn_get_mscale,
)
from transformers.models.llama.modeling_llama import (
    apply_rotary_pos_emb,
    eager_attention_forward,
)

from configuration_deepseek_v32 import DeepseekV32Config


logger = logging.get_logger(__name__)


class DeepseekV32RMSNorm(DeepseekV3RMSNorm):
    """RMSNorm reused unchanged from DeepSeek V3."""


class DeepseekV32RotaryEmbedding(DeepseekV3RotaryEmbedding):
    """Rotary embedding reused unchanged from DeepSeek V3."""


class DeepseekV32MLP(DeepseekV3MLP):
    """Dense MLP reused unchanged from DeepSeek V3."""


class DeepseekV32MoE(DeepseekV3MoE):
    """Mixture-of-experts block reused unchanged from DeepSeek V3."""


class DeepseekV32SparseAttention(nn.Module):
    """
    DeepSeek V3.2 sparse attention mechanism with indexer.

    DeepSeek V3.2 is meant to use an indexer to select the top-k (``config.index_topk``) keys per query,
    making attention cheaper for long sequences.

    NOTE: the sparse (indexer + top-k) path is not implemented here. Every call runs the dense multi-head
    latent attention of DeepSeek V3 (``_standard_attention``). The indexer parameters are created so that
    the module layout has a place for them, but ``forward`` never reads them.
    """

    def __init__(self, config: DeepseekV32Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.attention_dropout = config.attention_dropout
        self.num_heads = config.num_attention_heads

        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_head_dim = config.qk_head_dim
        self.index_topk = config.index_topk

        self.is_causal = True

        # Query projection: either a single full projection, or a low-rank factorisation
        # (hidden -> q_lora_rank -> RMSNorm -> heads * qk_head_dim).
        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
        else:
            self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
            self.q_a_layernorm = DeepseekV32RMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)

        # Key-value projections: a compressed latent of size kv_lora_rank plus one rotary key part,
        # then an up-projection of the latent to per-head (nope key, value) pairs.
        self.kv_a_proj_with_mqa = nn.Linear(
            config.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = DeepseekV32RMSNorm(self.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )

        # Output projection.
        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

        # Indexer components for sparse attention (not used by ``forward`` yet).
        # The indexer query reads the low-rank query latent when q-LoRA is enabled. Without q-LoRA there is
        # no latent, so project from the hidden state instead of passing ``None`` to nn.Linear (which raises).
        indexer_q_in_features = self.q_lora_rank if self.q_lora_rank is not None else config.hidden_size
        self.wq_b = nn.Linear(indexer_q_in_features, self.num_heads * self.qk_head_dim, bias=False)
        self.wk = nn.Linear(config.hidden_size, self.qk_head_dim, bias=config.attention_bias)
        self.k_norm = DeepseekV32RMSNorm(self.qk_head_dim)
        self.weights_proj = nn.Linear(config.hidden_size, self.num_heads, bias=False)
        # NOTE(review): the released V3.2 indexer appears to use its own head count / head dim and to live
        # under an ``indexer.`` submodule; these shapes and names may not match real checkpoints. Confirm
        # before wiring in the sparse path.

        self.scaling = self.qk_head_dim ** (-0.5)
        # Non-default rope scaling: multiply the softmax scale by mscale squared when mscale_all_dim is set.
        if self.config.rope_parameters.get("rope_type", "default") != "default":
            mscale_all_dim = self.config.rope_parameters.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_parameters["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.scaling = self.scaling * mscale * mscale

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Run attention for one layer.

        The sparse path is not implemented, so this always delegates to ``_standard_attention``. A one-time
        warning is logged once the total number of keys (cached + current) exceeds ``index_topk``. Up to
        that point, top-k selection would keep every key, so the dense result presumably matches the sparse
        model. Beyond it, the outputs can differ.

        Returns:
            ``(attn_output, attn_weights)`` as produced by ``_standard_attention``.
        """
        seq_length = hidden_states.shape[1]

        # Compare the key length, not the query length: during cached decoding seq_length is 1, but
        # attention still spans every previously cached token. Read the cache length before this layer's
        # update, which happens inside _standard_attention.
        kv_length = seq_length
        if past_key_values is not None:
            kv_length += past_key_values.get_seq_length(self.layer_idx)

        if kv_length > self.index_topk:
            logger.warning_once(
                "DeepSeek V3.2 sparse attention is not fully implemented in this version. "
                "Falling back to standard attention. For production use, please use vLLM or "
                "other optimized inference engines."
            )

        # Sparse attention (indexer + top-k gather) would go here; it needs custom kernels to be efficient.
        return self._standard_attention(
            hidden_states, position_embeddings, attention_mask, past_key_values, cache_position, **kwargs
        )

    def _standard_attention(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Dense multi-head latent attention (same computation as DeepSeek V3).

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` is reshaped to ``(batch, seq, -1)`` and passed
            through ``o_proj``. ``attn_weights`` is whatever the selected attention backend returns.
        """
        batch_size, seq_length = hidden_states.shape[:-1]
        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
        key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)

        if self.q_lora_rank is None:
            q_states = self.q_proj(hidden_states)
        else:
            q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        # (batch, seq, heads, qk_head_dim) -> (batch, heads, seq, qk_head_dim)
        q_states = q_states.view(query_shape).transpose(1, 2)
        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        # Up-project the normalised latent into per-head (nope key, value) pairs.
        k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # The rotary key part has a single head dimension of size 1.
        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

        cos, sin = position_embeddings
        if self.config.rope_interleave:
            q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
        else:
            q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        # Broadcast the single rotary key head across all attention heads.
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # For flash_attention_2, pad values up to qk_head_dim (the backend presumably needs equal head dims);
        # the padding is sliced off again after the attention call below.
        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class DeepseekV32DecoderLayer(GradientCheckpointingLayer):
    """
    Pre-norm transformer block: RMSNorm -> attention -> residual, then RMSNorm -> MLP/MoE -> residual.

    Layers with index >= ``config.first_k_dense_replace`` use the MoE block; earlier layers use a dense MLP.
    Subclassing ``GradientCheckpointingLayer`` lets ``gradient_checkpointing_enable()`` act on this layer.
    """

    def __init__(self, config: DeepseekV32Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Use sparse attention for V3.2.
        self.self_attn = DeepseekV32SparseAttention(config=config, layer_idx=layer_idx)

        if layer_idx >= config.first_k_dense_replace:
            self.mlp = DeepseekV32MoE(config)
        else:
            self.mlp = DeepseekV32MLP(config)

        self.input_layernorm = DeepseekV32RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = DeepseekV32RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.Tensor:
        """Apply the block and return the updated hidden states (attention weights are discarded)."""
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self attention.
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Feed-forward (dense MLP or MoE).
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class DeepseekV32PreTrainedModel(DeepseekV3PreTrainedModel):
    """Base class carrying the V3.2-specific loading/compilation settings."""

    config_class = DeepseekV32Config
    _can_compile_fullgraph = False
    _keep_in_fp32_modules_strict = ["e_score_correction_bias"]
    # Point layer-level hints at the classes this model actually instantiates. The inherited V3 values
    # presumably name the V3 layer classes, which never appear in a V3.2 model.
    _no_split_modules = ["DeepseekV32DecoderLayer"]
    _can_record_outputs = {
        "hidden_states": DeepseekV32DecoderLayer,
        "attentions": DeepseekV32SparseAttention,
    }


class DeepseekV32Model(DeepseekV32PreTrainedModel, DeepseekV3Model):
    """
    DeepSeek V3.2 backbone.

    Reuses DeepSeek V3's ``forward`` (via ``DeepseekV3Model``) but builds V3.2 decoder layers. It also
    inherits from ``DeepseekV32PreTrainedModel`` first, so the V3.2 class settings take precedence in the MRO.
    """

    config_class = DeepseekV32Config
    _keys_to_ignore_on_load_unexpected = [r"model\.layers\.61.*"]

    def __init__(self, config: DeepseekV32Config):
        # Skip DeepseekV3Model.__init__ (it would build V3 layers) and go directly to PreTrainedModel setup.
        DeepseekV3PreTrainedModel.__init__(self, config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # Use V3.2-specific decoder layers.
        self.layers = nn.ModuleList(
            [DeepseekV32DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = DeepseekV32RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = DeepseekV32RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing.
        self.post_init()


class DeepseekV32ForCausalLM(DeepseekV32PreTrainedModel, DeepseekV3ForCausalLM):
    """
    DeepSeek V3.2 model with a language-modeling head.

    Reuses DeepSeek V3's causal-LM ``forward``/generation code on top of ``DeepseekV32Model``.
    ``_tied_weights_keys`` is deliberately inherited from ``DeepseekV3ForCausalLM`` so that it has the
    format expected by the installed transformers version.
    """

    config_class = DeepseekV32Config

    def __init__(self, config):
        # Bypass DeepseekV3ForCausalLM.__init__ (it would build a V3 backbone) and run the base setup only.
        super(DeepseekV3ForCausalLM, self).__init__(config)
        self.model = DeepseekV32Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()


class DeepseekV32ForSequenceClassification(GenericForSequenceClassification, DeepseekV32PreTrainedModel):
    """Sequence-classification head built from the generic transformers implementation."""


class DeepseekV32ForTokenClassification(GenericForTokenClassification, DeepseekV32PreTrainedModel):
    """Token-classification head built from the generic transformers implementation."""


__all__ = [
    "DeepseekV32PreTrainedModel",
    "DeepseekV32Model",
    "DeepseekV32ForCausalLM",
    "DeepseekV32ForSequenceClassification",
    "DeepseekV32ForTokenClassification",
]