| | |
| | import math |
| | from dataclasses import dataclass |
| | from typing import Optional, Tuple, Union |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | import triton |
| | import triton.language as tl |
| | from torch import nn |
| | from torch.library import triton_op, wrap_triton |
| | from transformers.activations import ACT2FN |
| | from transformers.cache_utils import Cache, DynamicCache |
| | from transformers.generation.utils import GenerationMixin |
| | from transformers.modeling_outputs import MoeModelOutputWithPast |
| | from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
| | from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
| | from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13 |
| | from transformers.utils import ( |
| | ModelOutput, |
| | add_start_docstrings, |
| | add_start_docstrings_to_model_forward, |
| | ) |
| | from transformers.utils import logging as hf_logging |
| | from transformers.utils.import_utils import is_torch_fx_available |
| |
|
| | from .configuration_bailing_moe_v2 import BailingMoeV2Config |
| |
|
| |
|
| | logger = hf_logging.get_logger(__name__) |
| | _CONFIG_FOR_DOC = "BailingMoeV2Config" |
| |
|
| |
|
| | |
| | if is_torch_fx_available(): |
| | if not is_torch_greater_or_equal_than_1_13: |
| | import torch.fx |
| |
|
| |
|
| | |
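| | # Ternary Weight Networks (TWN) reference: per-row threshold = 0.7 * mean(|W|); |
| | # alpha = mean(|W|) over entries above the threshold; output = sign(W) * alpha where |
| | # |W| > threshold, and 0 elsewhere. Computed in fp32 and cast back to the input dtype. |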
| | def twn_torch_ref(W): |
| | W_fp = W.float() |
| | dim = -1 |
| | absW = W_fp.abs() |
| | th = absW.mean(dim, keepdim=True) * 0.7 |
| | mask = absW > th |
| | mask_f = mask.float() |
| | alpha = (absW * mask_f).sum(dim, keepdim=True) / mask_f.sum(dim, keepdim=True).clamp(min=1.0) |
| | out = W_fp.sign() * mask_f * alpha |
| | return out.to(W.dtype) |
| |
|
| |
|
| | twn_torch_compiled = torch.compile(twn_torch_ref, mode="max-autotune") |
| |
|
| |
|
| | @triton.autotune( |
| | configs=[ |
| | triton.Config({"BLOCK_SIZE": 128}, num_warps=4, num_stages=3), |
| | triton.Config({"BLOCK_SIZE": 256}, num_warps=4, num_stages=3), |
| | triton.Config({"BLOCK_SIZE": 512}, num_warps=8, num_stages=3), |
| | triton.Config({"BLOCK_SIZE": 1024}, num_warps=8, num_stages=3), |
| | triton.Config({"BLOCK_SIZE": 2048}, num_warps=8, num_stages=3), |
| | ], |
| | key=["N"], |
| | ) |
| | @triton.jit |
| | def twn_quant_row_merged_bf16_kernel( |
| | w_ptr, |
| | out_ptr, |
| | M, |
| | N, |
| | stride_wm, |
| | stride_wn, |
| | stride_om, |
| | stride_on, |
| | BLOCK_SIZE: tl.constexpr, |
| | ): |
| | pid = tl.program_id(0) |
| | if pid >= M: |
| | return |
| |
|
| | row_w_ptr = w_ptr + pid * stride_wm |
| | row_out_ptr = out_ptr + pid * stride_om |
| |
|
| | |
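| | # Pass 1: accumulate the per-row mean of |w| to derive the ternarization threshold (0.7 * mean). |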
| | sum_abs = 0.0 |
| | count = 0.0 |
| | for off in range(0, N, BLOCK_SIZE): |
| | cols = off + tl.arange(0, BLOCK_SIZE) |
| | mask = cols < N |
| | val = tl.load(row_w_ptr + cols * stride_wn, mask=mask, other=0.0).to(tl.float32) |
| | val_abs = tl.abs(val) |
| | sum_abs += tl.sum(val_abs, axis=0) |
| | count += tl.sum(mask.to(tl.float32), axis=0) |
| |
|
| | th = (sum_abs / tl.maximum(count, 1.0)) * 0.7 |
| |
|
| | |
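| | # Pass 2: compute alpha as the mean of |w| over the entries above the threshold. |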
| | masked_sum = 0.0 |
| | masked_count = 0.0 |
| | for off in range(0, N, BLOCK_SIZE): |
| | cols = off + tl.arange(0, BLOCK_SIZE) |
| | mask = cols < N |
| | val = tl.load(row_w_ptr + cols * stride_wn, mask=mask, other=0.0).to(tl.float32) |
| | val_abs = tl.abs(val) |
| | is_selected = (val_abs > th).to(tl.float32) |
| | masked_sum += tl.sum(val_abs * is_selected, axis=0) |
| | masked_count += tl.sum(is_selected, axis=0) |
| |
|
| | alpha = masked_sum / tl.maximum(masked_count, 1.0) |
| |
|
| | |
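| | # Pass 3: write sign(w) * alpha for selected entries (0 elsewhere) and store as bfloat16. |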
| | for off in range(0, N, BLOCK_SIZE): |
| | cols = off + tl.arange(0, BLOCK_SIZE) |
| | mask = cols < N |
| | val = tl.load(row_w_ptr + cols * stride_wn, mask=mask, other=0.0).to(tl.float32) |
| | is_selected = tl.abs(val) > th |
| |
|
| | |
| | sign = tl.where(val >= 0, alpha, -alpha) |
| | out_val = tl.where(is_selected, sign, 0.0) |
| |
|
| | tl.store(row_out_ptr + cols * stride_on, out_val.to(tl.bfloat16), mask=mask) |
| |
|
| |
|
| | @triton_op("grove_kernels::twn_triton", mutates_args={}) |
| | def twn_triton(W: torch.Tensor) -> torch.Tensor: |
| | M, N = W.shape |
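| | # One Triton program per row; BLOCK_SIZE over the N columns is chosen by the autotuner. |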
| | out = torch.empty_like(W, dtype=torch.bfloat16) |
| | grid = (M,) |
| | wrap_triton(twn_quant_row_merged_bf16_kernel)[grid]( |
| | W, |
| | out, |
| | M, |
| | N, |
| | W.stride(0), |
| | W.stride(1), |
| | out.stride(0), |
| | out.stride(1), |
| | ) |
| | return out |
| |
|
| |
|
| | class QuantizeTernary(torch.autograd.Function): |
| | @staticmethod |
| | def forward(ctx, input): |
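| | # 3-D inputs (e.g. stacked per-expert weight tensors) use the PyTorch reference path; |
| | # 2-D weight matrices go through the row-wise Triton kernel. |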
| | |
| | if len(input.shape) == 3: |
| | return twn_torch_ref(input) |
| | else: |
| | return twn_triton(input) |
| |
|
| | @staticmethod |
| | def backward(ctx, grad_output): |
| | # Straight-through estimator: ternarization is treated as identity in the backward pass. |
| | return grad_output |
| |
|
| |
|
| | def quantize(input: torch.Tensor) -> torch.Tensor: |
| | return QuantizeTernary.apply(input) |
| |
|
| |
|
| | def conditionally_quantize(input: torch.Tensor, do_quantize: bool) -> torch.Tensor: |
| | if do_quantize: |
| | return quantize(input) |
| | else: |
| | return input |
| |
|
| |
|
| | def quantize_weight_inplace(owner: nn.Module, weight_name: str, enabled: bool = True) -> bool: |
| | """ |
| | Quantize `owner.<weight_name>` once and write back to the same Parameter storage. |
| | Returns True if this call performed quantization, False if skipped/already-done. |
| | """ |
| | if not enabled: |
| | return False |
| | done_attr = f"__inplace_quantized_{weight_name}" |
| | if bool(getattr(owner, done_attr, False)): |
| | return False |
| |
|
| | weight = getattr(owner, weight_name) |
| | with torch.no_grad(): |
| | quantized = quantize(weight).to(device=weight.device, dtype=weight.dtype) |
| | weight.data.copy_(quantized) |
| | setattr(owner, done_attr, True) |
| | return True |
| |
|
| |
|
| | def conditionally_quantize_inplace_on_prefill( |
| | owner: nn.Module, |
| | weight_name: str, |
| | do_quantize: bool, |
| | *, |
| | quantize_inplace_now: bool = False, |
| | ) -> torch.Tensor: |
| | """ |
| | In eval mode, quantize the target weight once (during prefill) and write it back in-place. |
| | This avoids storing duplicate cached tensors while removing per-token quantization overhead. |
| | """ |
| | weight = getattr(owner, weight_name) |
| | if not do_quantize: |
| | return weight |
| | if owner.training: |
| | return quantize(weight) |
| |
|
| | if not quantize_inplace_now: |
| | return weight |
| | quantize_weight_inplace(owner, weight_name, enabled=True) |
| | return weight |
| |
|
| |
|
| | @dataclass |
| | class MoEV2CausalLMOutputWithPast(ModelOutput): |
| | loss: Optional[torch.FloatTensor] = None |
| | logits: Optional[torch.FloatTensor] = None |
| | past_key_values: Optional[Cache] = None |
| | hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None |
| | attentions: Optional[tuple[torch.FloatTensor, ...]] = None |
| | z_loss: Optional[torch.FloatTensor] = None |
| | aux_loss: Optional[torch.FloatTensor] = None |
| | router_logits: Optional[tuple[torch.FloatTensor]] = None |
| | mtp_loss: Optional[torch.FloatTensor] = None |
| | mtp_logits: Optional[tuple[torch.FloatTensor, ...]] = None |
| |
|
| |
|
| | class MoeV2ModelOutputWithPast(MoeModelOutputWithPast): |
| | def __init__(self, mtp_hidden_states=None, aux_loss=0.0, **kwargs): |
| | super().__init__(**kwargs) |
| | self.mtp_hidden_states = mtp_hidden_states |
| | self.aux_loss = aux_loss |
| |
|
| |
|
| | class BailingMoeV2RotaryEmbedding(nn.Module): |
| | def __init__(self, config: BailingMoeV2Config, device=None): |
| | super().__init__() |
| | if hasattr(config, "rope_scaling") and config.rope_scaling is not None: |
| | self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) |
| | else: |
| | self.rope_type = "default" |
| | self.max_seq_len_cached = config.max_position_embeddings |
| | self.original_max_seq_len = config.max_position_embeddings |
| |
|
| | self.config = config |
| | self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] |
| |
|
| | inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) |
| | self.register_buffer("inv_freq", inv_freq, persistent=False) |
| | self.original_inv_freq = self.inv_freq |
| |
|
| | @torch.no_grad() |
| | @dynamic_rope_update |
| | def forward(self, x, position_ids): |
| | inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) |
| | position_ids_expanded = position_ids[:, None, :].float() |
| |
|
| | device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" |
| | with torch.autocast(device_type=device_type, enabled=False): |
| | freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) |
| | emb = torch.cat((freqs, freqs), dim=-1) |
| | cos = emb.cos() * self.attention_scaling |
| | sin = emb.sin() * self.attention_scaling |
| | return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype), emb.float() |
| |
|
| |
|
| | def rotate_half(x): |
| | x1 = x[..., : x.shape[-1] // 2] |
| | x2 = x[..., x.shape[-1] // 2 :] |
| | return torch.cat((-x2, x1), dim=-1) |
| |
|
| |
|
| | def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): |
| | cos = cos.unsqueeze(unsqueeze_dim) |
| | sin = sin.unsqueeze(unsqueeze_dim) |
| |
|
| | rotary_dim = cos.shape[-1] |
| | q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] |
| | k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] |
| |
|
| | q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) |
| | k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) |
| |
|
| | q_embed = torch.cat([q_embed, q_pass], dim=-1) |
| | k_embed = torch.cat([k_embed, k_pass], dim=-1) |
| | return q_embed, k_embed |
| |
|
| |
|
| | class BailingMoeV2MLP(nn.Module): |
| | def __init__(self, config: BailingMoeV2Config, intermediate_size: int): |
| | super().__init__() |
| | self.config = config |
| | self.hidden_size = config.hidden_size |
| | self.intermediate_size = intermediate_size |
| |
|
| | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
| | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
| | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) |
| | self.act_fn = ACT2FN[config.hidden_act] |
| |
|
| | def forward(self, x, quantize_inplace_now: bool = False): |
| | down_weight, gate_weight, up_weight = ( |
| | conditionally_quantize_inplace_on_prefill( |
| | self.down_proj, |
| | "weight", |
| | self.config.quantize, |
| | quantize_inplace_now=quantize_inplace_now, |
| | ), |
| | conditionally_quantize_inplace_on_prefill( |
| | self.gate_proj, |
| | "weight", |
| | self.config.quantize, |
| | quantize_inplace_now=quantize_inplace_now, |
| | ), |
| | conditionally_quantize_inplace_on_prefill( |
| | self.up_proj, |
| | "weight", |
| | self.config.quantize, |
| | quantize_inplace_now=quantize_inplace_now, |
| | ), |
| | ) |
| | return torch.nn.functional.linear( |
| | self.act_fn(torch.nn.functional.linear(x, gate_weight)) * torch.nn.functional.linear(x, up_weight), |
| | down_weight, |
| | ) |
| |
|
| |
|
| | class BailingMoeV2RMSNorm(nn.Module): |
| | def __init__(self, hidden_size, eps=1e-6): |
| | super().__init__() |
| | self.weight = nn.Parameter(torch.ones(hidden_size)) |
| | self.variance_epsilon = eps |
| |
|
| | def forward(self, hidden_states): |
| | input_dtype = hidden_states.dtype |
| | hidden_states = hidden_states.to(torch.float32) |
| | variance = hidden_states.pow(2).mean(-1, keepdim=True) |
| | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) |
| | return self.weight * hidden_states.to(input_dtype) |
| |
|
| |
|
| | try: |
| | from liger_kernel.transformers.rms_norm import LigerRMSNorm |
| | |
| | # Prefer the fused Liger RMSNorm kernel when the optional dependency is installed. |
| | BailingMoeV2RMSNorm = LigerRMSNorm |
| | except ImportError: |
| | logger.info("liger_kernel not available; falling back to the pure-PyTorch BailingMoeV2RMSNorm.") |
| |
|
| |
|
| | class BailingMoeV2Gate(nn.Module): |
| | expert_bias: torch.Tensor |
| |
|
| | def __init__(self, config: BailingMoeV2Config): |
| | super().__init__() |
| | self.config = config |
| | self.top_k = config.num_experts_per_tok |
| | self.num_experts = config.num_experts |
| |
|
| | self.n_group = config.n_group |
| | self.topk_group = config.topk_group |
| |
|
| | self.gating_dim = config.hidden_size |
| | self.weight = nn.Parameter(torch.empty((self.num_experts, self.gating_dim))) |
| | |
| | self.routed_scaling_factor = config.routed_scaling_factor |
| |
|
| | self.register_buffer("expert_bias", torch.zeros((self.num_experts))) |
| | self.reset_parameters() |
| |
|
| | def reset_parameters(self) -> None: |
| | import torch.nn.init as init |
| |
|
| | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) |
| |
|
| | def group_limited_topk(self, scores: torch.Tensor): |
| | num_tokens, _ = scores.size() |
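| | # Group-limited routing: score each expert group by the sum of its top-2 expert scores, |
| | # keep the best `topk_group` groups, then take the top-k experts within the surviving groups. |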
| | group_scores = scores.view(num_tokens, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1) |
| | group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] |
| | group_mask = torch.zeros_like(group_scores) |
| | group_mask.scatter_(1, group_idx, 1) |
| |
|
| | score_mask = ( |
| | group_mask.unsqueeze(-1) |
| | .expand(num_tokens, self.n_group, self.num_experts // self.n_group) |
| | .reshape(num_tokens, -1) |
| | ) |
| |
|
| | masked_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) |
| | probs, top_indices = torch.topk(masked_scores, k=self.top_k, dim=-1) |
| | return probs, top_indices |
| |
|
| | def forward(self, hidden_states: torch.Tensor): |
| | hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) |
| | logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) |
| |
|
| | scores = torch.sigmoid(logits.float()).type_as(logits) |
| | scores_for_routing = scores + self.expert_bias |
| | _, topk_idx = self.group_limited_topk(scores_for_routing) |
| |
|
| | scores = torch.gather(scores, dim=1, index=topk_idx).type_as(logits) |
| | topk_weight = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if self.top_k > 1 else scores |
| | topk_weight = topk_weight * self.routed_scaling_factor |
| |
|
| | return topk_idx, topk_weight, logits |
| |
|
| |
|
| | class BailingMoeV2SparseMoeBlock(nn.Module): |
| | """ |
| | Unfused MoE block matching Ling-mini HF layout (ModuleList experts). |
| | """ |
| |
|
| | def __init__(self, config) -> None: |
| | super().__init__() |
| | self.config = config |
| | self.num_experts_per_tok = config.num_experts_per_tok |
| | self._setup_experts() |
| | self.gate = BailingMoeV2Gate(config) |
| | if config.num_shared_experts is not None: |
| | self.shared_experts = BailingMoeV2MLP( |
| | config=config, |
| | intermediate_size=config.moe_intermediate_size * config.num_shared_experts, |
| | ) |
| |
|
| | def _setup_experts(self): |
| | self.experts = nn.ModuleList( |
| | [ |
| | BailingMoeV2MLP( |
| | config=self.config, |
| | intermediate_size=self.config.moe_intermediate_size, |
| | ) |
| | for _ in range(self.config.num_experts) |
| | ] |
| | ) |
| |
|
| | def forward( |
| | self, hidden_states: torch.Tensor, quantize_inplace_now: bool = False |
| | ) -> torch.Tensor: |
| | original_shape = hidden_states.shape |
| | identity = hidden_states |
| |
|
| | bsz, seq_len, h = hidden_states.shape |
| | topk_idx, topk_weight, router_logits = self.gate(hidden_states) |
| | hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) |
| | flat_topk_idx = topk_idx.view(-1) |
| |
|
| | if self.training: |
| | hidden_states = hidden_states.repeat_interleave(self.num_experts_per_tok, dim=0) |
| | y = torch.empty_like(hidden_states) |
| | for i, expert in enumerate(self.experts): |
| | y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]) |
| | y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) |
| | y = y.to(hidden_states.dtype).view(bsz, seq_len, h) |
| | else: |
| | y = self.moe_infer(hidden_states, topk_idx, topk_weight, quantize_inplace_now).view(bsz, seq_len, h) |
| |
|
| | if self.config.num_shared_experts is not None: |
| | y = y + self.shared_experts(identity, quantize_inplace_now=quantize_inplace_now) |
| |
|
| | return y |
| |
|
| | @torch.no_grad() |
| | def moe_infer(self, x, topk_ids, topk_weight, quantize_inplace_now: bool = False): |
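| | # Sort tokens by assigned expert, run each expert on its contiguous slice, scatter the |
| | # outputs back into token order, and combine them with the routing weights. |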
| | cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) |
| | cnts.scatter_(1, topk_ids, 1) |
| | tokens_per_expert = cnts.sum(dim=0) |
| | idxs = topk_ids.view(-1).argsort() |
| | sorted_tokens = x[idxs // topk_ids.shape[1]] |
| | tokens_per_expert = tokens_per_expert.cpu().numpy() |
| | outputs = [] |
| | start_idx = 0 |
| | for i, num_tokens in enumerate(tokens_per_expert): |
| | end_idx = start_idx + num_tokens |
| | if num_tokens == 0: |
| | continue |
| | expert = self.experts[i] |
| | tokens_for_this_expert = sorted_tokens[start_idx:end_idx] |
| | expert_out = expert(tokens_for_this_expert, quantize_inplace_now=quantize_inplace_now) |
| | outputs.append(expert_out.to(x.device)) |
| | start_idx = end_idx |
| |
|
| | outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) |
| | new_x = torch.empty_like(outs) |
| | new_x[idxs] = outs |
| | final_out = ( |
| | new_x.view(*topk_ids.shape, -1) |
| | .type(topk_weight.dtype) |
| | .mul_(topk_weight.unsqueeze(dim=-1)) |
| | .sum(dim=1) |
| | .type(new_x.dtype) |
| | ) |
| | return final_out |
| |
|
| |
|
| | class BailingMoeV2Attention(nn.Module): |
| | """Fixed wiring for modern HF attention APIs: uses prepared causal_mask + cache_position + Cache.update().""" |
| |
|
| | def __init__(self, config: BailingMoeV2Config, layer_idx: Optional[int] = None): |
| | super().__init__() |
| | self.config = config |
| | self.layer_idx = layer_idx |
| | if layer_idx is None: |
| | logger.warning_once( |
| | f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " |
| | "lead to errors during the forward call if caching is used. Please pass `layer_idx`." |
| | ) |
| |
|
| | self.attention_dropout = config.attention_dropout |
| | self.hidden_size = config.hidden_size |
| | self.num_heads = config.num_attention_heads |
| | self.head_dim = config.head_dim or self.hidden_size // self.num_heads |
| | self.scaling = self.head_dim**-0.5 |
| |
|
| | partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) |
| | self.rope_dim = int(self.head_dim * partial_rotary_factor) |
| |
|
| | self.num_key_value_heads = config.num_key_value_heads |
| | self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
| | self.is_causal = True |
| |
|
| | self.sliding_window = None |
| | self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None |
| | self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None |
| |
|
| | self.query_key_value = nn.Linear( |
| | self.hidden_size, |
| | (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, |
| | bias=config.use_qkv_bias, |
| | ) |
| |
|
| | if self.config.use_qk_norm: |
| | self.query_layernorm = BailingMoeV2RMSNorm(self.head_dim, eps=config.rms_norm_eps) |
| | self.key_layernorm = BailingMoeV2RMSNorm(self.head_dim, eps=config.rms_norm_eps) |
| |
|
| | self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.use_bias) |
| |
|
| | def forward( |
| | self, |
| | hidden_states: torch.Tensor, |
| | attention_mask: Optional[torch.Tensor] = None, |
| | position_ids: Optional[torch.LongTensor] = None, |
| | past_key_value: Optional[Cache] = None, |
| | output_attentions: bool = False, |
| | use_cache: bool = False, |
| | cache_position: Optional[torch.LongTensor] = None, |
| | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, |
| | **kwargs, |
| | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: |
| | quantize_inplace_now = bool(kwargs.pop("quantize_inplace_now", False)) |
| | bsz, q_len, _ = hidden_states.size() |
| | qkv_weight = conditionally_quantize_inplace_on_prefill( |
| | self.query_key_value, |
| | "weight", |
| | self.config.quantize, |
| | quantize_inplace_now=quantize_inplace_now, |
| | ) |
| | out_qkv = torch.nn.functional.linear(hidden_states, qkv_weight) |
| | cos, sin, _freqs = position_embeddings |
| | qkv = out_qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim) |
| |
|
| | query_states, key_states, value_states = qkv.split( |
| | [self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=-2 |
| | ) |
| | query_states = query_states.transpose(1, 2) |
| | key_states = key_states.transpose(1, 2) |
| | value_states = value_states.transpose(1, 2) |
| |
|
| | if self.config.use_qk_norm: |
| | query_states = self.query_layernorm(query_states) |
| | key_states = self.key_layernorm(key_states) |
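| | # RoPE is applied only on sliding-window layers; full-attention layers run without rotary here. |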
| | if self.sliding_window is not None: |
| | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) |
| |
|
| | |
| | if use_cache and past_key_value is not None: |
| | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} |
| | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) |
| |
|
| | |
| | attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] |
| | attn_output, attn_weights = attention_interface( |
| | self, |
| | query_states, |
| | key_states, |
| | value_states, |
| | attention_mask, |
| | dropout=0.0, |
| | position_ids=position_ids, |
| | scaling=self.scaling, |
| | sliding_window=self.sliding_window, |
| | **kwargs, |
| | ) |
| |
|
| | attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() |
| | dense_weight = conditionally_quantize_inplace_on_prefill( |
| | self.dense, |
| | "weight", |
| | self.config.quantize, |
| | quantize_inplace_now=quantize_inplace_now, |
| | ) |
| | attn_output = torch.nn.functional.linear(attn_output, dense_weight) |
| |
|
| | if not output_attentions: |
| | attn_weights = None |
| |
|
| | return attn_output, attn_weights, past_key_value |
| |
|
| |
|
| | class BailingMoeV2DecoderLayer(nn.Module): |
| | def __init__(self, config: BailingMoeV2Config, layer_idx: int): |
| | super().__init__() |
| | self.hidden_size = config.hidden_size |
| | self.layer_idx = layer_idx |
| |
|
| | self.attention = BailingMoeV2Attention(config=config, layer_idx=layer_idx) |
| |
|
| | self.mlp = ( |
| | BailingMoeV2SparseMoeBlock(config) |
| | if (config.num_experts is not None and layer_idx >= config.first_k_dense_replace) |
| | else BailingMoeV2MLP(config=config, intermediate_size=config.intermediate_size) |
| | ) |
| |
|
| | self.input_layernorm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| | self.post_attention_layernorm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| |
|
| | def forward( |
| | self, |
| | hidden_states: torch.Tensor, |
| | attention_mask: Optional[torch.Tensor] = None, |
| | position_ids: Optional[torch.LongTensor] = None, |
| | past_key_value: Optional[Cache] = None, |
| | output_attentions: Optional[bool] = False, |
| | output_router_logits: Optional[bool] = False, |
| | use_cache: Optional[bool] = False, |
| | cache_position: Optional[torch.LongTensor] = None, |
| | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, |
| | **kwargs, |
| | ) -> Tuple[ |
| | torch.Tensor, |
| | Optional[torch.Tensor], |
| | Optional[Cache], |
| | torch.Tensor, |
| | Optional[torch.Tensor], |
| | ]: |
| | quantize_inplace_now = bool(kwargs.get("quantize_inplace_now", False)) |
| | residual = hidden_states |
| | hidden_states = self.input_layernorm(hidden_states) |
| |
|
| | attn_out, self_attn_weights, present_key_value = self.attention( |
| | hidden_states=hidden_states, |
| | attention_mask=attention_mask, |
| | position_ids=position_ids, |
| | past_key_value=past_key_value, |
| | output_attentions=bool(output_attentions), |
| | use_cache=bool(use_cache), |
| | cache_position=cache_position, |
| | position_embeddings=position_embeddings, |
| | **kwargs, |
| | ) |
| | hidden_states = residual + attn_out |
| |
|
| | residual = hidden_states |
| | hidden_states = self.post_attention_layernorm(hidden_states) |
| |
|
| | mlp_out = self.mlp(hidden_states, quantize_inplace_now=quantize_inplace_now) |
| | if isinstance(mlp_out, tuple): |
| | hidden_states, aux_loss = mlp_out |
| | else: |
| | hidden_states, aux_loss = mlp_out, 0.0 |
| |
|
| | hidden_states = residual + hidden_states.to(residual.device) |
| |
|
| | |
| | router_logits = None |
| |
|
| | return ( |
| | hidden_states, |
| | self_attn_weights, |
| | present_key_value, |
| | aux_loss, |
| | router_logits, |
| | ) |
| |
|
| |
|
| | BAILINGMOEV2_START_DOCSTRING = r""" |
| | This model inherits from [`PreTrainedModel`]. |
| | """ |
| |
|
| |
|
| | @add_start_docstrings( |
| | "The bare BailingMoeV2 Model outputting raw hidden-states without any specific head on top.", |
| | BAILINGMOEV2_START_DOCSTRING, |
| | ) |
| | class BailingMoeV2PreTrainedModel(PreTrainedModel): |
| | config_class = BailingMoeV2Config |
| | base_model_prefix = "model" |
| | supports_gradient_checkpointing = True |
| | _no_split_modules = ["BailingMoeV2DecoderLayer"] |
| | _skip_keys_device_placement = "past_key_values" |
| | _supports_attention_backend = True |
| | _supports_flash_attn_2 = True |
| | _supports_sdpa = True |
| | _supports_cache_class = True |
| |
|
| | def _init_weights(self, module): |
| | std = self.config.initializer_range |
| | if isinstance(module, nn.Linear): |
| | module.weight.data.normal_(mean=0.0, std=std) |
| | if module.bias is not None: |
| | module.bias.data.zero_() |
| | elif isinstance(module, nn.Embedding): |
| | module.weight.data.normal_(mean=0.0, std=std) |
| | if module.padding_idx is not None: |
| | module.weight.data[module.padding_idx].zero_() |
| |
|
| |
|
| | BAILINGMOEV2_INPUTS_DOCSTRING = r"""NA""" |
| |
|
| |
|
| | @add_start_docstrings( |
| | "The bare BailingMoeV2 Model outputting raw hidden-states without any specific head on top.", |
| | BAILINGMOEV2_START_DOCSTRING, |
| | ) |
| | class BailingMoeV2Model(BailingMoeV2PreTrainedModel): |
| | def __init__(self, config: BailingMoeV2Config): |
| | super().__init__(config) |
| | self.padding_idx = config.pad_token_id |
| | self.vocab_size = config.vocab_size |
| | self.num_nextn_predict_layers = getattr(config, "num_nextn_predict_layers", 0) |
| |
|
| | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) |
| |
|
| | layers = [] |
| | for layer_idx in range(config.num_hidden_layers + self.num_nextn_predict_layers): |
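| | # Only the base decoder stack is built; MTP (multi-token prediction) layers are not part of this prototype. |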
| | |
| | |
| | if layer_idx < config.num_hidden_layers: |
| | layers.append(BailingMoeV2DecoderLayer(config, layer_idx)) |
| | else: |
| | raise NotImplementedError("BailingMoeV2MTPLayer not included in this prototype file.") |
| | self.layers = nn.ModuleList(layers) |
| | self.config = config |
| |
|
| | self.norm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| | self.rotary_emb = BailingMoeV2RotaryEmbedding(config=config) |
| | self.gradient_checkpointing = False |
| | self._cache_debug_calls = 0 |
| | self.post_init() |
| |
|
| | def get_input_embeddings(self): |
| | return self.word_embeddings |
| |
|
| | def set_input_embeddings(self, value): |
| | self.word_embeddings = value |
| |
|
| | def quantize_inplace(self, verbose: bool = False) -> int: |
| | """ |
| | Quantize this base model in-place once (for inference). |
| | Returns the number of tensors newly quantized in this call. |
| | """ |
| | if not self.config.quantize: |
| | if verbose: |
| | print("[quantize-inplace] config.quantize is False; nothing to do") |
| | return 0 |
| |
|
| | quantized_count = 0 |
| | for layer in self.layers: |
| | |
| | quantized_count += int(quantize_weight_inplace(layer.attention.query_key_value, "weight", enabled=True)) |
| | quantized_count += int(quantize_weight_inplace(layer.attention.dense, "weight", enabled=True)) |
| |
|
| | |
| | if isinstance(layer.mlp, BailingMoeV2MLP): |
| | quantized_count += int(quantize_weight_inplace(layer.mlp.down_proj, "weight", enabled=True)) |
| | quantized_count += int(quantize_weight_inplace(layer.mlp.gate_proj, "weight", enabled=True)) |
| | quantized_count += int(quantize_weight_inplace(layer.mlp.up_proj, "weight", enabled=True)) |
| | elif isinstance(layer.mlp, BailingMoeV2SparseMoeBlock): |
| | # Routed experts are a ModuleList of BailingMoeV2MLP here, so quantize each expert's projections. |
| | for expert in layer.mlp.experts: |
| | quantized_count += int(quantize_weight_inplace(expert.gate_proj, "weight", enabled=True)) |
| | quantized_count += int(quantize_weight_inplace(expert.up_proj, "weight", enabled=True)) |
| | quantized_count += int(quantize_weight_inplace(expert.down_proj, "weight", enabled=True)) |
| | if hasattr(layer.mlp, "shared_experts"): |
| | quantized_count += int( |
| | quantize_weight_inplace(layer.mlp.shared_experts.down_proj, "weight", enabled=True) |
| | ) |
| | quantized_count += int( |
| | quantize_weight_inplace(layer.mlp.shared_experts.gate_proj, "weight", enabled=True) |
| | ) |
| | quantized_count += int( |
| | quantize_weight_inplace(layer.mlp.shared_experts.up_proj, "weight", enabled=True) |
| | ) |
| |
|
| | if verbose: |
| | print(f"[quantize-inplace] newly quantized tensors: {quantized_count}") |
| | return quantized_count |
| |
|
| | def prepare_fa2_from_position_ids(self, position_ids: torch.Tensor): |
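| | # Packed-sequence helper: treat each position_id == 0 as a sequence start and derive |
| | # FlashAttention varlen metadata (cu_seqlens and max_seqlen) from those boundaries. |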
| | position_ids = position_ids.flatten() |
| | T = position_ids.numel() |
| | indices_q = torch.arange(T, device=position_ids.device, dtype=torch.int32) |
| |
|
| | starts = indices_q[position_ids == 0] |
| |
|
| | |
| | |
| | if starts.numel() == 0: |
| | cu_seq_lens = torch.tensor([0, T], device=position_ids.device, dtype=torch.int32) |
| | else: |
| | |
| | if starts[0].item() != 0: |
| | starts = torch.cat([starts.new_zeros(1), starts], dim=0) |
| | if starts[-1].item() != T: |
| | starts = torch.cat([starts, starts.new_tensor([T])], dim=0) |
| | cu_seq_lens = starts |
| |
|
| | max_length = (cu_seq_lens[1:] - cu_seq_lens[:-1]).max().item() |
| | return (indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) |
| |
|
| | @add_start_docstrings_to_model_forward(BAILINGMOEV2_INPUTS_DOCSTRING) |
| | def forward( |
| | self, |
| | input_ids: torch.LongTensor = None, |
| | attention_mask: Optional[torch.Tensor] = None, |
| | position_ids: Optional[torch.LongTensor] = None, |
| | past_key_values: Optional[Cache] = None, |
| | inputs_embeds: Optional[torch.FloatTensor] = None, |
| | use_cache: Optional[bool] = None, |
| | output_attentions: Optional[bool] = None, |
| | output_hidden_states: Optional[bool] = None, |
| | output_router_logits: Optional[bool] = None, |
| | return_dict: Optional[bool] = None, |
| | cache_position: Optional[torch.LongTensor] = None, |
| | **kwargs, |
| | ) -> Union[Tuple, MoeV2ModelOutputWithPast]: |
| | debug_cache = bool(kwargs.pop("debug_cache", False)) |
| | if debug_cache: |
| | print(f"Debug cache enabled for call {self._cache_debug_calls}") |
| | debug_call_id = self._cache_debug_calls |
| | self._cache_debug_calls += 1 |
| |
|
| | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| | output_router_logits = ( |
| | output_router_logits if output_router_logits is not None else self.config.output_router_logits |
| | ) |
| | output_hidden_states = ( |
| | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| | ) |
| | use_cache = use_cache if use_cache is not None else self.config.use_cache |
| | return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| |
|
| | |
| | if (input_ids is None) == (inputs_embeds is None): |
| | raise ValueError("You must specify exactly one of input_ids or inputs_embeds") |
| |
|
| | if self.gradient_checkpointing and self.training and use_cache: |
| | logger.warning_once( |
| | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." |
| | ) |
| | use_cache = False |
| |
|
| | if use_cache and past_key_values is None: |
| | past_key_values = DynamicCache() |
| |
|
| | if inputs_embeds is None: |
| | inputs_embeds = self.word_embeddings(input_ids) |
| |
|
| | # Optional inference-engine integration (e.g. an SGLang-style `forward_batch`): detect decode |
| | # steps so that in-place weight quantization is only triggered on prefill. |
| | forward_batch = kwargs.get("forward_batch", None) |
| | is_decode_step = False |
| | forward_mode = getattr(forward_batch, "forward_mode", None) if forward_batch is not None else None |
| | if forward_mode is not None: |
| | for mode_name in ( |
| | "is_decode", |
| | "is_decode_or_idle", |
| | "is_target_verify", |
| | "is_draft_decode", |
| | ): |
| | mode_fn = getattr(forward_mode, mode_name, None) |
| | if callable(mode_fn) and bool(mode_fn()): |
| | is_decode_step = True |
| | break |
| |
|
| | # Trigger in-place quantization only on multi-token prefill steps in eval mode; decode steps |
| | # then reuse the already-quantized parameters without per-token quantization overhead. |
| | kwargs["quantize_inplace_now"] = bool( |
| | self.config.quantize and (not self.training) and (not is_decode_step) and inputs_embeds.shape[1] > 1 |
| | ) |
| |
|
| | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| |
|
| | if cache_position is None: |
| | cache_position = torch.arange( |
| | past_seen_tokens, |
| | past_seen_tokens + inputs_embeds.shape[1], |
| | device=inputs_embeds.device, |
| | ) |
| |
|
| | if position_ids is not None: |
| | # Broadcast shared position ids across the batch if needed. |
| | batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] |
| | if position_ids.shape[0] != batch_size: |
| | position_ids = position_ids.expand(batch_size, -1) |
| | else: |
| | # Default positions must exist before the varlen metadata below is built from them. |
| | position_ids = cache_position.unsqueeze(0) |
| | |
| | # Build FlashAttention varlen metadata from the position ids on prefill steps. |
| | if (not is_decode_step) and inputs_embeds.shape[1] > 1: |
| | _, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = self.prepare_fa2_from_position_ids( |
| | position_ids |
| | ) |
| | kwargs["cu_seq_lens_q"] = cu_seq_lens_q |
| | kwargs["cu_seq_lens_k"] = cu_seq_lens_k |
| | kwargs["max_length_q"] = max_length_q |
| | kwargs["max_length_k"] = max_length_k |
| |
|
| | # No explicit 4-D causal mask is materialized here: causality is handled by the attention |
| | # backend via `is_causal` (and, for packed prefill, the FA2 varlen metadata above). |
| | causal_mask = None |
| |
|
| | hidden_states = inputs_embeds |
| | position_embeddings = self.rotary_emb(hidden_states, position_ids) |
| | |
| | |
| |
|
| | all_hidden_states = () if output_hidden_states else None |
| | all_self_attns = () if output_attentions else None |
| | all_router_logits = () if output_router_logits else None |
| |
|
| | aux_loss_sum = 0.0 |
| |
|
| | for decoder_layer in self.layers: |
| | if output_hidden_states: |
| | all_hidden_states += (hidden_states,) |
| |
|
| | if self.gradient_checkpointing and self.training: |
| | layer_outputs = self._gradient_checkpointing_func( |
| | decoder_layer.__call__, |
| | hidden_states, |
| | causal_mask, |
| | position_ids, |
| | past_key_values, |
| | output_attentions, |
| | output_router_logits, |
| | use_cache, |
| | cache_position, |
| | position_embeddings, |
| | **kwargs, |
| | ) |
| | else: |
| | layer_outputs = decoder_layer( |
| | hidden_states, |
| | attention_mask=causal_mask, |
| | position_ids=position_ids, |
| | past_key_value=past_key_values, |
| | output_attentions=output_attentions, |
| | output_router_logits=output_router_logits, |
| | use_cache=use_cache, |
| | cache_position=cache_position, |
| | position_embeddings=position_embeddings, |
| | **kwargs, |
| | ) |
| |
|
| | hidden_states = layer_outputs[0] |
| |
|
| | if output_attentions: |
| | all_self_attns += (layer_outputs[1],) |
| |
|
| | |
| | aux_loss_sum = aux_loss_sum + layer_outputs[3] |
| |
|
| | if output_router_logits: |
| | all_router_logits += (layer_outputs[4],) |
| |
|
| | hidden_states = self.norm(hidden_states) |
| |
|
| | if debug_cache: |
| | past_after_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| | cache_start = int(cache_position[0].item()) if cache_position.numel() > 0 else -1 |
| | cache_end = int(cache_position[-1].item()) if cache_position.numel() > 0 else -1 |
| | cache_hit = bool(use_cache and inputs_embeds.shape[1] == 1 and past_seen_tokens > 0) |
| | print( |
| | "[cache-debug] " |
| | f"call={debug_call_id} use_cache={use_cache} " |
| | f"input_len={inputs_embeds.shape[1]} " |
| | f"past_before={past_seen_tokens} past_after={past_after_tokens} " |
| | f"cache_pos=[{cache_start},{cache_end}] " |
| | f"cache_hit_expected={cache_hit}" |
| | ) |
| |
|
| | if output_hidden_states: |
| | all_hidden_states += (hidden_states,) |
| | # Average the aux loss over the MoE layers (the first `first_k_dense_replace` layers are dense). |
| | moe_layer_count = max(len(self.layers) - self.config.first_k_dense_replace, 1) |
| | out = MoeV2ModelOutputWithPast( |
| | last_hidden_state=hidden_states, |
| | past_key_values=past_key_values if use_cache else None, |
| | hidden_states=all_hidden_states, |
| | attentions=all_self_attns, |
| | router_logits=all_router_logits, |
| | aux_loss=aux_loss_sum / moe_layer_count, |
| | ) |
| | return ( |
| | out |
| | if return_dict |
| | else ( |
| | out.last_hidden_state, |
| | out.past_key_values, |
| | out.hidden_states, |
| | out.attentions, |
| | ) |
| | ) |
| |
|
| |
|
| | class BailingMoeV2ForCausalLM(BailingMoeV2PreTrainedModel, GenerationMixin): |
| | _tied_weights_keys = ["lm_head.weight"] |
| |
|
| | def __init__(self, config: BailingMoeV2Config): |
| | super().__init__(config) |
| | self.model = BailingMoeV2Model(config) |
| | self.vocab_size = config.vocab_size |
| |
|
| | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
| | self.router_aux_loss_coef = 0.001 |
| | self.post_init() |
| |
|
| | def get_input_embeddings(self): |
| | return self.model.word_embeddings |
| |
|
| | def set_input_embeddings(self, value): |
| | self.model.word_embeddings = value |
| |
|
| | def get_output_embeddings(self): |
| | return self.lm_head |
| |
|
| | def set_output_embeddings(self, new_embeddings): |
| | self.lm_head = new_embeddings |
| |
|
| | def set_decoder(self, decoder): |
| | self.model = decoder |
| |
|
| | def get_decoder(self): |
| | return self.model |
| |
|
| | def quantize_inplace(self, verbose: bool = False) -> int: |
| | """ |
| | Quantize model (and lm_head) in-place once for inference. |
| | Returns the number of tensors newly quantized in this call. |
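| | |
| | Example (sketch; assumes a checkpoint saved with `config.quantize=True`): |
| | |
| | model = BailingMoeV2ForCausalLM.from_pretrained(checkpoint_dir).eval() |
| | model.quantize_inplace(verbose=True) |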
| | """ |
| | quantized_count = self.model.quantize_inplace(verbose=verbose) |
| | if self.config.quantize: |
| | quantized_count += int(quantize_weight_inplace(self.lm_head, "weight", enabled=True)) |
| | if verbose: |
| | print(f"[quantize-inplace] total newly quantized tensors (with lm_head): {quantized_count}") |
| | return quantized_count |
| |
|
| | |
| | |
| | def forward( |
| | self, |
| | input_ids: torch.LongTensor = None, |
| | attention_mask: Optional[torch.Tensor] = None, |
| | position_ids: Optional[torch.LongTensor] = None, |
| | past_key_values: Optional[Cache] = None, |
| | inputs_embeds: Optional[torch.FloatTensor] = None, |
| | labels: Optional[torch.Tensor] = None, |
| | use_cache: Optional[bool] = None, |
| | output_attentions: Optional[bool] = None, |
| | output_hidden_states: Optional[bool] = None, |
| | output_router_logits: Optional[bool] = None, |
| | return_dict: Optional[bool] = None, |
| | logits_to_keep: Union[int, torch.Tensor] = 0, |
| | **kwargs, |
| | ) -> Union[Tuple, MoEV2CausalLMOutputWithPast]: |
| | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| | output_hidden_states = ( |
| | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| | ) |
| | output_router_logits = ( |
| | output_router_logits if output_router_logits is not None else self.config.output_router_logits |
| | ) |
| | return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| |
|
| | outputs = self.model( |
| | input_ids=input_ids, |
| | attention_mask=attention_mask, |
| | position_ids=position_ids, |
| | past_key_values=past_key_values, |
| | inputs_embeds=inputs_embeds, |
| | use_cache=use_cache, |
| | output_attentions=output_attentions, |
| | output_hidden_states=output_hidden_states, |
| | output_router_logits=output_router_logits, |
| | return_dict=True, |
| | **kwargs, |
| | ) |
| |
|
| | hidden_states = outputs.last_hidden_state |
| | assert isinstance(hidden_states, torch.Tensor) |
| |
|
| | loss = None |
| | logits = None |
| | if labels is not None: |
| | # Assumes `self.loss_function` is a fused linear-cross-entropy variant that takes the final |
| | # hidden states and the lm_head weight and returns (loss, logits). |
| | loss, logits = self.loss_function(hidden_states, self.lm_head.weight, labels) |
| | else: |
| | # Respect `logits_to_keep` so decode-time calls only materialize logits for the last positions. |
| | slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
| | logits = self.lm_head(hidden_states[:, slice_indices, :]) |
| | out = MoEV2CausalLMOutputWithPast( |
| | loss=loss, |
| | aux_loss=getattr(outputs, "aux_loss", 0.0), |
| | logits=logits, |
| | past_key_values=outputs.past_key_values if hasattr(outputs, "past_key_values") else None, |
| | hidden_states=outputs.hidden_states if hasattr(outputs, "hidden_states") else None, |
| | attentions=outputs.attentions if hasattr(outputs, "attentions") else None, |
| | router_logits=outputs.router_logits if hasattr(outputs, "router_logits") else None, |
| | ) |
| | return out |
| |
|
| |
|
| | ModelClass = BailingMoeV2ForCausalLM |
| |
|
| | __all__ = [ |
| | "BailingMoeV2ForCausalLM", |
| | "BailingMoeV2Model", |
| | "BailingMoeV2PreTrainedModel", |
| | ] |
| |
|