# coding=utf-8 import math from dataclasses import dataclass from typing import Optional, Tuple, Union import torch import torch.nn.functional as F import triton import triton.language as tl from torch import nn from torch.library import triton_op, wrap_triton from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache from transformers.generation.utils import GenerationMixin from transformers.modeling_outputs import MoeModelOutputWithPast from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13 from transformers.utils import ( ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, ) from transformers.utils import logging as hf_logging from transformers.utils.import_utils import is_torch_fx_available from .configuration_bailing_moe_v2 import BailingMoeV2Config logger = hf_logging.get_logger(__name__) _CONFIG_FOR_DOC = "BailingMoeV2Config" # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. if is_torch_fx_available(): if not is_torch_greater_or_equal_than_1_13: import torch.fx # noqa: F401 # quantizers def twn_torch_ref(W): W_fp = W.float() dim = -1 # Always last dim absW = W_fp.abs() th = absW.mean(dim, keepdim=True) * 0.7 mask = absW > th mask_f = mask.float() alpha = (absW * mask_f).sum(dim, keepdim=True) / mask_f.sum(dim, keepdim=True).clamp(min=1.0) out = W_fp.sign() * mask_f * alpha return out.to(W.dtype) twn_torch_compiled = torch.compile(twn_torch_ref, mode="max-autotune") @triton.autotune( configs=[ triton.Config({"BLOCK_SIZE": 128}, num_warps=4, num_stages=3), triton.Config({"BLOCK_SIZE": 256}, num_warps=4, num_stages=3), triton.Config({"BLOCK_SIZE": 512}, num_warps=8, num_stages=3), triton.Config({"BLOCK_SIZE": 1024}, num_warps=8, num_stages=3), triton.Config({"BLOCK_SIZE": 2048}, num_warps=8, num_stages=3), ], key=["N"], ) @triton.jit def twn_quant_row_merged_bf16_kernel( w_ptr, out_ptr, M, N, stride_wm, stride_wn, stride_om, stride_on, BLOCK_SIZE: tl.constexpr, ): pid = tl.program_id(0) if pid >= M: return row_w_ptr = w_ptr + pid * stride_wm row_out_ptr = out_ptr + pid * stride_om # --- Pass 1: Threshold --- sum_abs = 0.0 count = 0.0 for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) mask = cols < N val = tl.load(row_w_ptr + cols * stride_wn, mask=mask, other=0.0).to(tl.float32) val_abs = tl.abs(val) sum_abs += tl.sum(val_abs, axis=0) count += tl.sum(mask.to(tl.float32), axis=0) th = (sum_abs / tl.maximum(count, 1.0)) * 0.7 # --- Pass 2: Alpha --- masked_sum = 0.0 masked_count = 0.0 for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) mask = cols < N val = tl.load(row_w_ptr + cols * stride_wn, mask=mask, other=0.0).to(tl.float32) val_abs = tl.abs(val) is_selected = (val_abs > th).to(tl.float32) masked_sum += tl.sum(val_abs * is_selected, axis=0) masked_count += tl.sum(is_selected, axis=0) alpha = masked_sum / tl.maximum(masked_count, 1.0) # --- Pass 3: Output --- for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) mask = cols < N val = tl.load(row_w_ptr + cols * stride_wn, mask=mask, other=0.0).to(tl.float32) is_selected = tl.abs(val) > th # Output is -alpha, 0, or +alpha sign = tl.where(val >= 0, alpha, -alpha) out_val = tl.where(is_selected, sign, 0.0) tl.store(row_out_ptr + cols * stride_on, out_val.to(tl.bfloat16), mask=mask) 
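# Illustrative sanity check: a sketch (not used by the model) showing that the Triton
# kernel above should agree with the eager `twn_torch_ref` reference up to bf16
# rounding (borderline values near the threshold may still differ). `_twn_selfcheck`
# is a hypothetical helper name added for documentation only; it assumes a CUDA device
# and a 2D bf16 weight, matching what `twn_triton` (defined below) expects.
def _twn_selfcheck(rows: int = 64, cols: int = 256) -> None:
    w = torch.randn(rows, cols, device="cuda", dtype=torch.bfloat16)
    ref = twn_torch_ref(w).float()  # eager reference (bf16 output), upcast for comparison
    tri = twn_triton(w).float()  # row-wise Triton kernel via the registered custom op
    torch.testing.assert_close(tri, ref, atol=2e-2, rtol=2e-2)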
@triton_op("grove_kernels::twn_triton", mutates_args={}) def twn_triton(W: torch.Tensor) -> torch.Tensor: M, N = W.shape out = torch.empty_like(W, dtype=torch.bfloat16) grid = (M,) wrap_triton(twn_quant_row_merged_bf16_kernel)[grid]( W, out, M, N, W.stride(0), W.stride(1), out.stride(0), out.stride(1), ) return out class QuantizeTernary(torch.autograd.Function): @staticmethod def forward(ctx, input): # with torch.no_grad(): if len(input.shape) == 3: return twn_torch_ref(input) # fatser when else: return twn_triton(input) @staticmethod def backward(ctx, grad_output): # Straight-Through Estimator: gradient is just passed through. return grad_output, None def quantize(input: torch.Tensor) -> torch.Tensor: return QuantizeTernary.apply(input) def conditionally_quantize(input: torch.Tensor, do_quantize: bool) -> torch.Tensor: if do_quantize: return quantize(input) else: return input def quantize_weight_inplace(owner: nn.Module, weight_name: str, enabled: bool = True) -> bool: """ Quantize `owner.` once and write back to the same Parameter storage. Returns True if this call performed quantization, False if skipped/already-done. """ if not enabled: return False done_attr = f"__inplace_quantized_{weight_name}" if bool(getattr(owner, done_attr, False)): return False weight = getattr(owner, weight_name) with torch.no_grad(): quantized = quantize(weight).to(device=weight.device, dtype=weight.dtype) weight.data.copy_(quantized) setattr(owner, done_attr, True) return True def conditionally_quantize_inplace_on_prefill( owner: nn.Module, weight_name: str, do_quantize: bool, *, quantize_inplace_now: bool = False, ) -> torch.Tensor: """ In eval mode, quantize the target weight once (during prefill) and write it back in-place. This avoids storing duplicate cached tensors while removing per-token quantization overhead. 
""" weight = getattr(owner, weight_name) if not do_quantize: return weight if owner.training: return quantize(weight) if not quantize_inplace_now: return weight quantize_weight_inplace(owner, weight_name, enabled=True) return weight @dataclass class MoEV2CausalLMOutputWithPast(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None past_key_values: Optional[Cache] = None hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None attentions: Optional[tuple[torch.FloatTensor, ...]] = None z_loss: Optional[torch.FloatTensor] = None aux_loss: Optional[torch.FloatTensor] = None router_logits: Optional[tuple[torch.FloatTensor]] = None mtp_loss: Optional[torch.FloatTensor] = None mtp_logits: Optional[tuple[torch.FloatTensor, ...]] = None class MoeV2ModelOutputWithPast(MoeModelOutputWithPast): def __init__(self, mtp_hidden_states=None, aux_loss=0.0, **kwargs): super().__init__(**kwargs) self.mtp_hidden_states = mtp_hidden_states self.aux_loss = aux_loss class BailingMoeV2RotaryEmbedding(nn.Module): def __init__(self, config: BailingMoeV2Config, device=None): super().__init__() if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @torch.no_grad() @dynamic_rope_update def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling freqs = torch.cat([freqs, freqs], dim=-1) return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype), freqs.float() def rotate_half(x): x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) rotary_dim = cos.shape[-1] q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) q_embed = torch.cat([q_embed, q_pass], dim=-1) k_embed = torch.cat([k_embed, k_pass], dim=-1) return q_embed, k_embed class BailingMoeV2MLP(nn.Module): def __init__(self, config: BailingMoeV2Config, intermediate_size: int): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x, 
                quantize_inplace_now: bool = False):
        down_weight, gate_weight, up_weight = (
            conditionally_quantize_inplace_on_prefill(
                self.down_proj,
                "weight",
                self.config.quantize,
                quantize_inplace_now=quantize_inplace_now,
            ),
            conditionally_quantize_inplace_on_prefill(
                self.gate_proj,
                "weight",
                self.config.quantize,
                quantize_inplace_now=quantize_inplace_now,
            ),
            conditionally_quantize_inplace_on_prefill(
                self.up_proj,
                "weight",
                self.config.quantize,
                quantize_inplace_now=quantize_inplace_now,
            ),
        )
        return torch.nn.functional.linear(
            self.act_fn(torch.nn.functional.linear(x, gate_weight)) * torch.nn.functional.linear(x, up_weight),
            down_weight,
        )


class BailingMoeV2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


try:
    from liger_kernel.transformers.rms_norm import LigerRMSNorm

    BailingMoeV2RMSNorm = LigerRMSNorm
except ImportError:
    logger.warning("liger_kernel is not available; falling back to the eager BailingMoeV2RMSNorm.")


class BailingMoeV2Gate(nn.Module):
    expert_bias: torch.Tensor

    def __init__(self, config: BailingMoeV2Config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_experts
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(torch.empty((self.num_experts, self.gating_dim)))
        # self.bias = nn.Parameter(torch.zeros((self.num_experts)))
        self.routed_scaling_factor = config.routed_scaling_factor
        self.register_buffer("expert_bias", torch.zeros((self.num_experts)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def group_limited_topk(self, scores: torch.Tensor):
        num_tokens, _ = scores.size()
        group_scores = scores.view(num_tokens, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(num_tokens, self.n_group, self.num_experts // self.n_group)
            .reshape(num_tokens, -1)
        )
        masked_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))
        probs, top_indices = torch.topk(masked_scores, k=self.top_k, dim=-1)
        return probs, top_indices

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
        scores = torch.sigmoid(logits.float()).type_as(logits)
        scores_for_routing = scores + self.expert_bias
        _, topk_idx = self.group_limited_topk(scores_for_routing)
        scores = torch.gather(scores, dim=1, index=topk_idx).type_as(logits)
        topk_weight = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if self.top_k > 1 else scores
        topk_weight = topk_weight * self.routed_scaling_factor
        return topk_idx, topk_weight, logits


class BailingMoeV2SparseMoeBlock(nn.Module):
    """
    Unfused MoE block matching Ling-mini HF layout (ModuleList experts).
""" def __init__(self, config) -> None: super().__init__() self.config = config self.num_experts_per_tok = config.num_experts_per_tok self._setup_experts() self.gate = BailingMoeV2Gate(config) if config.num_shared_experts is not None: self.shared_experts = BailingMoeV2MLP( config=config, intermediate_size=config.moe_intermediate_size * config.num_shared_experts, ) def _setup_experts(self): self.experts = nn.ModuleList( [ BailingMoeV2MLP( config=self.config, intermediate_size=self.config.moe_intermediate_size, ) for _ in range(self.config.num_experts) ] ) def forward( self, hidden_states: torch.Tensor, quantize_inplace_now: bool = False ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: original_shape = hidden_states.shape identity = hidden_states bsz, seq_len, h = hidden_states.shape topk_idx, topk_weight, router_logits = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) flat_topk_idx = topk_idx.view(-1) if self.training: hidden_states = hidden_states.repeat_interleave(self.num_experts_per_tok, dim=0) y = torch.empty_like(hidden_states) for i, expert in enumerate(self.experts): y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]) y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) y = y.to(hidden_states.dtype).view(bsz, seq_len, h) else: y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(bsz, seq_len, h) if self.config.num_shared_experts is not None: y = y + self.shared_experts(identity) return y @torch.no_grad() def moe_infer(self, x, topk_ids, topk_weight): cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) cnts.scatter_(1, topk_ids, 1) tokens_per_expert = cnts.sum(dim=0) idxs = topk_ids.view(-1).argsort() sorted_tokens = x[idxs // topk_ids.shape[1]] tokens_per_expert = tokens_per_expert.cpu().numpy() outputs = [] start_idx = 0 for i, num_tokens in enumerate(tokens_per_expert): end_idx = start_idx + num_tokens if num_tokens == 0: continue expert = self.experts[i] tokens_for_this_expert = sorted_tokens[start_idx:end_idx] expert_out = expert(tokens_for_this_expert) outputs.append(expert_out.to(x.device)) start_idx = end_idx outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) new_x = torch.empty_like(outs) new_x[idxs] = outs final_out = ( new_x.view(*topk_ids.shape, -1) .type(topk_weight.dtype) .mul_(topk_weight.unsqueeze(dim=-1)) .sum(dim=1) .type(new_x.dtype) ) return final_out class BailingMoeV2Attention(nn.Module): """Fixed wiring for modern HF attention APIs: uses prepared causal_mask + cache_position + Cache.update().""" def __init__(self, config: BailingMoeV2Config, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " "lead to errors during the forward call if caching is used. Please pass `layer_idx`." 
) self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = config.head_dim or self.hidden_size // self.num_heads self.scaling = self.head_dim**-0.5 partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) self.rope_dim = int(self.head_dim * partial_rotary_factor) self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.is_causal = True self.sliding_window = None self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None self.query_key_value = nn.Linear( self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=config.use_qkv_bias, ) if self.config.use_qk_norm: self.query_layernorm = BailingMoeV2RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.key_layernorm = BailingMoeV2RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.use_bias) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, # IMPORTANT: pass prepared causal_mask here position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, # IMPORTANT: needed for modern cache update position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: quantize_inplace_now = bool(kwargs.pop("quantize_inplace_now", False)) bsz, q_len, _ = hidden_states.size() qkv_weight = conditionally_quantize_inplace_on_prefill( self.query_key_value, "weight", self.config.quantize, quantize_inplace_now=quantize_inplace_now, ) out_qkv = torch.nn.functional.linear(hidden_states, qkv_weight) cos, sin, _freqs = position_embeddings # fused path # if self.sliding_window is not None: # query_states, key_states, value_states = functional_fused_split_transpose_rope_qknorm( # out_qkv, # self.query_layernorm.weight, # self.key_layernorm.weight, # _freqs.contiguous(), # self.config.num_attention_heads, # self.config.num_key_value_heads, # ) # else: # query_states, key_states, value_states = functional_fused_split_transpose_qknorm( # out_qkv, # self.query_layernorm.weight, # self.key_layernorm.weight, # _freqs.contiguous(), # self.config.num_attention_heads, # self.config.num_key_value_heads, # ) qkv = out_qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim) query_states, key_states, value_states = qkv.split( [self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=-2 ) query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) if self.config.use_qk_norm: query_states = self.query_layernorm(query_states) key_states = self.key_layernorm(key_states) if self.sliding_window is not None: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # ---- Modern Cache update wiring (DynamicCache / StaticCache compatible) ---- if use_cache and past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # fa should 
transpose internally attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, # prepared causal mask (or None for varlen flash path) dropout=0.0, position_ids=position_ids, scaling=self.scaling, sliding_window=self.sliding_window, # keep your prototype behavior **kwargs, ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() dense_weight = conditionally_quantize_inplace_on_prefill( self.dense, "weight", self.config.quantize, quantize_inplace_now=quantize_inplace_now, ) attn_output = torch.nn.functional.linear(attn_output, dense_weight) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value class BailingMoeV2DecoderLayer(nn.Module): def __init__(self, config: BailingMoeV2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.layer_idx = layer_idx self.attention = BailingMoeV2Attention(config=config, layer_idx=layer_idx) self.mlp = ( BailingMoeV2SparseMoeBlock(config) if (config.num_experts is not None and layer_idx >= config.first_k_dense_replace) else BailingMoeV2MLP(config=config, intermediate_size=config.intermediate_size) ) self.input_layernorm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, # prepared causal mask position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, # your MOE doesn't return router logits; kept for API use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[ torch.Tensor, Optional[torch.Tensor], Optional[Cache], torch.Tensor, Optional[torch.Tensor], ]: quantize_inplace_now = bool(kwargs.get("quantize_inplace_now", False)) residual = hidden_states hidden_states = self.input_layernorm(hidden_states) attn_out, self_attn_weights, present_key_value = self.attention( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=bool(output_attentions), use_cache=bool(use_cache), cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + attn_out residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) mlp_out = self.mlp(hidden_states, quantize_inplace_now=quantize_inplace_now) if isinstance(mlp_out, tuple): hidden_states, aux_loss = mlp_out else: hidden_states, aux_loss = mlp_out, 0.0 hidden_states = residual + hidden_states.to(residual.device) # Your MOE path does not provide router logits; keep placeholder. router_logits = None return ( hidden_states, self_attn_weights, present_key_value, aux_loss, router_logits, ) BAILINGMOEV2_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. 
""" @add_start_docstrings( "The bare BailingMoeV2 Model outputting raw hidden-states without any specific head on top.", BAILINGMOEV2_START_DOCSTRING, ) class BailingMoeV2PreTrainedModel(PreTrainedModel): config_class = BailingMoeV2Config base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["BailingMoeV2DecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_attention_backend = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() BAILINGMOEV2_INPUTS_DOCSTRING = r"""NA""" @add_start_docstrings( "The bare BailingMoeV2 Model outputting raw hidden-states without any specific head on top.", BAILINGMOEV2_START_DOCSTRING, ) class BailingMoeV2Model(BailingMoeV2PreTrainedModel): def __init__(self, config: BailingMoeV2Config): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.num_nextn_predict_layers = getattr(config, "num_nextn_predict_layers", 0) self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) layers = [] for layer_idx in range(config.num_hidden_layers + self.num_nextn_predict_layers): # NOTE: your prototype referenced BailingMoeV2MTPLayer but didn't include it. # Keep behavior: only decoder layers here unless you add MTP layers yourself. if layer_idx < config.num_hidden_layers: layers.append(BailingMoeV2DecoderLayer(config, layer_idx)) else: raise NotImplementedError("BailingMoeV2MTPLayer not included in this prototype file.") self.layers = nn.ModuleList(layers) self.config = config self.norm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = BailingMoeV2RotaryEmbedding(config=config) self.gradient_checkpointing = False self._cache_debug_calls = 0 self.post_init() def get_input_embeddings(self): return self.word_embeddings def set_input_embeddings(self, value): self.word_embeddings = value def quantize_inplace(self, verbose: bool = False) -> int: """ Quantize this base model in-place once (for inference). Returns the number of tensors newly quantized in this call. """ if not self.config.quantize: if verbose: print("[quantize-inplace] config.quantize is False; nothing to do") return 0 quantized_count = 0 for layer in self.layers: # Attention projections. quantized_count += int(quantize_weight_inplace(layer.attention.query_key_value, "weight", enabled=True)) quantized_count += int(quantize_weight_inplace(layer.attention.dense, "weight", enabled=True)) # MLP path: dense MLP or MoE experts (+ optional shared experts). 
            if isinstance(layer.mlp, BailingMoeV2MLP):
                quantized_count += int(quantize_weight_inplace(layer.mlp.down_proj, "weight", enabled=True))
                quantized_count += int(quantize_weight_inplace(layer.mlp.gate_proj, "weight", enabled=True))
                quantized_count += int(quantize_weight_inplace(layer.mlp.up_proj, "weight", enabled=True))
            elif isinstance(layer.mlp, BailingMoeV2SparseMoeBlock):
                # Unfused layout: each expert is a BailingMoeV2MLP with its own projections.
                for expert in layer.mlp.experts:
                    quantized_count += int(quantize_weight_inplace(expert.down_proj, "weight", enabled=True))
                    quantized_count += int(quantize_weight_inplace(expert.gate_proj, "weight", enabled=True))
                    quantized_count += int(quantize_weight_inplace(expert.up_proj, "weight", enabled=True))
            if hasattr(layer.mlp, "shared_experts"):
                quantized_count += int(
                    quantize_weight_inplace(layer.mlp.shared_experts.down_proj, "weight", enabled=True)
                )
                quantized_count += int(
                    quantize_weight_inplace(layer.mlp.shared_experts.gate_proj, "weight", enabled=True)
                )
                quantized_count += int(
                    quantize_weight_inplace(layer.mlp.shared_experts.up_proj, "weight", enabled=True)
                )
        if verbose:
            print(f"[quantize-inplace] newly quantized tensors: {quantized_count}")
        return quantized_count

    def prepare_fa2_from_position_ids(self, position_ids: torch.Tensor):
        position_ids = position_ids.flatten()
        T = position_ids.numel()
        indices_q = torch.arange(T, device=position_ids.device, dtype=torch.int32)
        starts = indices_q[position_ids == 0]
        # If no segment-start markers exist (common in decoding where pos ids are offset),
        # treat as a single sequence.
        if starts.numel() == 0:
            cu_seq_lens = torch.tensor([0, T], device=position_ids.device, dtype=torch.int32)
        else:
            # Ensure the boundaries are valid: a leading 0 and a trailing T.
            if starts[0].item() != 0:
                starts = torch.cat([starts.new_zeros(1), starts], dim=0)
            if starts[-1].item() != T:
                starts = torch.cat([starts, starts.new_tensor([T])], dim=0)
            cu_seq_lens = starts
        max_length = (cu_seq_lens[1:] - cu_seq_lens[:-1]).max().item()
        return (indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))

    # def prepare_fa2_from_position_ids(self, position_ids: torch.Tensor):
    #     position_ids = position_ids.flatten()
    #     indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
    #     cu_seq_lens = torch.cat(
    #         (
    #             indices_q[position_ids == 0],
    #             torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
    #         )
    #     )
    #     # max_length has a different type in different models:
    #     # it is a tensor in modeling_qwen3_moe_foundation / modeling_qwen2_5_omni,
    #     # and an int in modeling_qwen2_vl.
    #     # Use .item() here so the decoder layers receive an int max_length;
    #     # otherwise every decoder layer would trigger .item() again.
    #     max_length = cu_seq_lens.diff().max().item()
    #     return (indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))

    @add_start_docstrings_to_model_forward(BAILINGMOEV2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,  # 2D padding mask (B, S) coming in
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, MoeV2ModelOutputWithPast]:
        debug_cache = bool(kwargs.pop("debug_cache", False))
        if debug_cache:
            print(f"Debug cache enabled for call {self._cache_debug_calls}")
        debug_call_id = self._cache_debug_calls
        self._cache_debug_calls += 1
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # exactly one of input_ids / inputs_embeds if (input_ids is None) == (inputs_embeds is None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." ) use_cache = False if use_cache and past_key_values is None: past_key_values = DynamicCache() if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) # SGLang transformers backend passes `forward_batch`; use it to identify # decode mode (can have S>1 tokens due to token packing) and avoid # decode-only dynamic metadata that harms CUDA graph capture. forward_batch = kwargs.get("forward_batch", None) is_decode_step = False forward_mode = getattr(forward_batch, "forward_mode", None) if forward_batch is not None else None if forward_mode is not None: for mode_name in ( "is_decode", "is_decode_or_idle", "is_target_verify", "is_draft_decode", ): mode_fn = getattr(forward_mode, mode_name, None) if callable(mode_fn) and bool(mode_fn()): is_decode_step = True break # Perform one-time in-place weight quantization during prefill (S > 1), # then reuse the mutated weights for decode without extra memory cache. kwargs["quantize_inplace_now"] = bool( self.config.quantize and (not self.training) and (not is_decode_step) and inputs_embeds.shape[1] > 1 ) past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 if cache_position is None: cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device, ) if position_ids is not None: # For bsh cases, expand [1, S] position_ids to [B, S] before FA2 metadata prep. batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] if position_ids.shape[0] != batch_size: position_ids = position_ids.expand(batch_size, -1) # Decode does not need cu_seq_lens/max_length metadata and creating # them every step hurts CUDA graph capture stability. 
if (not is_decode_step) and inputs_embeds.shape[1] > 1: _, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = self.prepare_fa2_from_position_ids( position_ids ) kwargs["cu_seq_lens_q"] = cu_seq_lens_q kwargs["cu_seq_lens_k"] = cu_seq_lens_k kwargs["max_length_q"] = max_length_q kwargs["max_length_k"] = max_length_k if position_ids is None: position_ids = cache_position.unsqueeze(0) # IMPORTANT: build prepared causal_mask and pass it into layers (NOT raw attention_mask) # mask_function = create_causal_mask # swap to create_sliding_window_causal_mask if you enable sliding window # causal_mask = mask_function( # config=self.config, # input_embeds=inputs_embeds, # attention_mask=attention_mask, # cache_position=cache_position, # past_key_values=past_key_values, # position_ids=position_ids, # ) # TODO: Im just disabling causal mask right now idk fix this later when we need SWA causal_mask = None hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) # if self.config.hc: # hidden_states = self.expand_streams(hidden_states) all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None aux_loss_sum = 0.0 for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, causal_mask, position_ids, past_key_values, output_attentions, output_router_logits, use_cache, cache_position, position_embeddings, **kwargs, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, # <-- FIXED position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, # <-- FIXED position_embeddings=position_embeddings, **kwargs, ) hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) # aux loss is at index 3 in our layer return aux_loss_sum = aux_loss_sum + layer_outputs[3] if output_router_logits: all_router_logits += (layer_outputs[4],) hidden_states = self.norm(hidden_states) if debug_cache: past_after_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_start = int(cache_position[0].item()) if cache_position.numel() > 0 else -1 cache_end = int(cache_position[-1].item()) if cache_position.numel() > 0 else -1 cache_hit = bool(use_cache and inputs_embeds.shape[1] == 1 and past_seen_tokens > 0) print( "[cache-debug] " f"call={debug_call_id} use_cache={use_cache} " f"input_len={inputs_embeds.shape[1]} " f"past_before={past_seen_tokens} past_after={past_after_tokens} " f"cache_pos=[{cache_start},{cache_end}] " f"cache_hit_expected={cache_hit}" ) if output_hidden_states: all_hidden_states += (hidden_states,) moe_layer_count = len(self.layers) - 1 out = MoeV2ModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, aux_loss=aux_loss_sum / moe_layer_count, # keeping your prototype behavior ) return ( out if return_dict else ( out.last_hidden_state, out.past_key_values, out.hidden_states, out.attentions, ) ) class BailingMoeV2ForCausalLM(BailingMoeV2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, 
config: BailingMoeV2Config): super().__init__(config) self.model = BailingMoeV2Model(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.router_aux_loss_coef = 0.001 self.post_init() def get_input_embeddings(self): return self.model.word_embeddings def set_input_embeddings(self, value): self.model.word_embeddings = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model def quantize_inplace(self, verbose: bool = False) -> int: """ Quantize model (and lm_head) in-place once for inference. Returns the number of tensors newly quantized in this call. """ quantized_count = self.model.quantize_inplace(verbose=verbose) if self.config.quantize: quantized_count += int(quantize_weight_inplace(self.lm_head, "weight", enabled=True)) if verbose: print(f"[quantize-inplace] total newly quantized tensors (with lm_head): {quantized_count}") return quantized_count # @add_start_docstrings_to_model_forward(BAILINGMOEV2_INPUTS_DOCSTRING) # @replace_return_docstrings(output_type=MoEV2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ) -> Union[Tuple, MoEV2CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=True, # ensure attribute access **kwargs, ) hidden_states = outputs.last_hidden_state assert isinstance(hidden_states, torch.Tensor) # slice logits if requested loss = None logits = None if labels is not None: loss, logits = self.loss_function(hidden_states, self.lm_head.weight, labels) else: logits = self.lm_head(hidden_states) out = MoEV2CausalLMOutputWithPast( loss=loss, aux_loss=getattr(outputs, "aux_loss", 0.0), logits=logits, past_key_values=outputs.past_key_values if hasattr(outputs, "past_key_values") else None, hidden_states=outputs.hidden_states if hasattr(outputs, "hidden_states") else None, attentions=outputs.attentions if hasattr(outputs, "attentions") else None, router_logits=outputs.router_logits if hasattr(outputs, "router_logits") else None, ) return out ModelClass = BailingMoeV2ForCausalLM __all__ = [ "BailingMoeV2ForCausalLM", "BailingMoeV2Model", "BailingMoeV2PreTrainedModel", ]
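# Usage sketch (an assumption about how this file is typically driven, not part of the
# model definition): load the causal LM, run `quantize_inplace()` once to ternarize the
# weights in place for inference, then generate as usual. `model_path` is a hypothetical
# local checkpoint directory and `_example_quantized_generation` is an illustrative
# helper, not an API exported by this module.
def _example_quantized_generation(model_path: str, prompt: str = "Hello") -> str:
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = BailingMoeV2ForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, trust_remote_code=True
    ).eval()
    if getattr(model.config, "quantize", False):
        # One-time in-place ternary quantization of the projection weights.
        model.quantize_inplace(verbose=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=32)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)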