GeeeekExplorer committed on
Commit 1938e9d · 1 Parent(s): 9d2f599

fix indexer rope

Files changed (1)
  1. inference/model.py +29 -16
inference/model.py CHANGED
@@ -2,7 +2,6 @@ import math
 from dataclasses import dataclass
 from typing import Tuple, Optional, Literal

-from einops import rearrange
 import torch
 from torch import nn
 import torch.nn.functional as F
@@ -282,6 +281,7 @@ class RMSNorm(nn.Module):
         super().__init__()
         self.dim = dim
         self.eps = eps
+        # rmsnorm in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenience.
         self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))

     def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
@@ -315,6 +315,7 @@ class LayerNorm(nn.Module):
         super().__init__()
         self.dim = dim
         self.eps = eps
+        # layernorm in the checkpoint is stored in bf16, while the parameters here are stored in fp32 for convenience.
         self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
         self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

@@ -403,7 +404,7 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
     return freqs_cis


-def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, interleaved: bool = True) -> torch.Tensor:
     """
     Applies rotary positional embeddings to the input tensor.

@@ -415,9 +416,14 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
         torch.Tensor: Tensor with rotary embeddings applied.
     """
     dtype = x.dtype
-    x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
+    shape = x.shape
+    if not interleaved:
+        x = x.view(*shape[:-1], 2, -1).transpose(-1, -2).contiguous()
+    x = torch.view_as_complex(x.float().view(*shape[:-1], -1, 2))
     freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
     y = torch.view_as_real(x * freqs_cis).flatten(3)
+    if not interleaved:
+        y = torch.cat([y[..., 0::2], y[..., 1::2]], dim=-1)
     return y.to(dtype)


@@ -441,7 +447,8 @@ class Indexer(torch.nn.Module):
         self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.head_dim)
         self.wk = Linear(self.dim, self.head_dim)
         self.k_norm = LayerNorm(self.head_dim)
-        self.weights_proj = Linear(self.dim, self.n_heads, dtype=torch.get_default_dtype())
+        # weights_proj in the checkpoint is stored in bf16, while the parameters here are stored in fp32 for convenience.
+        self.weights_proj = Linear(self.dim, self.n_heads, dtype=torch.float32)
         self.softmax_scale = self.head_dim ** -0.5
         self.scale_fmt = args.scale_fmt

@@ -453,14 +460,16 @@ class Indexer(torch.nn.Module):
         bsz, seqlen, _ = x.size()
         end_pos = start_pos + seqlen
         q = self.wq_b(qr)
-        q = rearrange(q, 'b s (h d) -> b s h d', d=self.head_dim)
+        q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
         q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
-        q_pe = apply_rotary_emb(q_pe, freqs_cis)
+        # rope in indexer is not interleaved
+        q_pe = apply_rotary_emb(q_pe, freqs_cis, False)
         q = torch.cat([q_pe, q_nope], dim=-1)
         k = self.wk(x)
         k = self.k_norm(k)
         k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
-        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis).squeeze(2)
+        # rope in indexer is not interleaved
+        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis, False).squeeze(2)
         k = torch.cat([k_pe, k_nope], dim=-1)
         q = rotate_activation(q)
         k = rotate_activation(k)
@@ -468,7 +477,7 @@ class Indexer(torch.nn.Module):
         k_fp8, k_scale = act_quant(k, block_size, self.scale_fmt)
         self.k_cache[:bsz, start_pos:end_pos] = k_fp8
         self.k_scale_cache[:bsz, start_pos:end_pos] = k_scale
-        weights = self.weights_proj(x) * self.n_heads ** -0.5
+        weights = self.weights_proj(x.float()) * self.n_heads ** -0.5
         weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
         index_score = fp8_index(q_fp8.contiguous(), weights, self.k_cache[:bsz, :end_pos].contiguous(), self.k_scale_cache[:bsz, :end_pos].contiguous())
         if mask is not None:
@@ -524,6 +533,7 @@ class MLA(nn.Module):
         self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
         self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
         self.softmax_scale = self.qk_head_dim ** -0.5
+        self.scale_fmt = args.scale_fmt
         if args.max_seq_len > args.original_seq_len:
             mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
             self.softmax_scale = self.softmax_scale * mscale * mscale
@@ -558,6 +568,9 @@ class MLA(nn.Module):
         kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         kv = self.kv_norm(kv)
         k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
+        # we use an fp8 kv cache in actual deployment, so here we simulate the precision by casting kv to fp8 and then back to bf16.
+        kv_fp8, kv_scale = act_quant(kv, block_size, self.scale_fmt)
+        kv = (kv_fp8.view(-1, block_size).float() * kv_scale.view(-1, 1)).to(kv.dtype).view_as(kv)
         self.kv_cache[:bsz, start_pos:end_pos] = kv
         self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
         if mask is not None: # MHA prefill
@@ -566,7 +579,7 @@ class MLA(nn.Module):
             kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
             k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
-            scores = torch.einsum("bshd,bthd->bsht", q.float(), k.float()) * self.softmax_scale
+            scores = torch.einsum("bshd,bthd->bsht", q, k).mul_(self.softmax_scale)

             # indexer
             topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
@@ -574,24 +587,24 @@
             index_mask += mask
             scores += index_mask.unsqueeze(2)

-            scores = scores.softmax(dim=-1, dtype=torch.float32)
-            x = torch.einsum("bsht,bthd->bshd", scores.type_as(x), v)
-        else: # MHA decode
+            scores = scores.softmax(dim=-1)
+            x = torch.einsum("bsht,bthd->bshd", scores, v)
+        else: # MQA decode
             if self.dequant_wkv_b is None and self.wkv_b.scale is not None:
                 self.dequant_wkv_b = weight_dequant(self.wkv_b.weight, self.wkv_b.scale)
             wkv_b = self.wkv_b.weight if self.dequant_wkv_b is None else self.dequant_wkv_b
             wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
             q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
-            scores = (torch.einsum("bshc,btc->bsht", q_nope.float(), self.kv_cache[:bsz, :end_pos].float()) +
-                      torch.einsum("bshr,btr->bsht", q_pe.float(), self.pe_cache[:bsz, :end_pos].float())) * self.softmax_scale
+            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
+                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale

             # indexer
             topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
             index_mask = torch.full((bsz, 1, end_pos), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
             scores += index_mask.unsqueeze(2)

-            scores = scores.softmax(dim=-1, dtype=torch.float32)
-            x = torch.einsum("bsht,btc->bshc", scores.type_as(x), self.kv_cache[:bsz, :end_pos])
+            scores = scores.softmax(dim=-1)
+            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
             x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
         x = self.wo(x.flatten(2))
         return x
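
Note on the fix: with the new interleaved flag, apply_rotary_emb(..., False) treats the last dimension as a half-split layout, pairing dimension j with dimension j + d/2, instead of pairing adjacent dimensions (2j, 2j+1) as the default interleaved path does; the Indexer calls pass False because its rope is not interleaved, while MLA keeps the default. The sketch below is a minimal, self-contained check of that behaviour: it copies the patched function verbatim and compares it against a plain cos/sin half-split rotation. The helper name rope_half_split_reference, the toy tensor sizes, and the simple freqs_cis construction are illustrative assumptions, not code from the repo (the repo builds freqs_cis via precompute_freqs_cis).

import torch


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, interleaved: bool = True) -> torch.Tensor:
    # Local copy of the patched function above, so the sketch runs standalone.
    dtype = x.dtype
    shape = x.shape
    if not interleaved:
        # Regroup the half-split layout [a_0..a_{m-1}, b_0..b_{m-1}] into pairs (a_j, b_j).
        x = x.view(*shape[:-1], 2, -1).transpose(-1, -2).contiguous()
    x = torch.view_as_complex(x.float().view(*shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
    y = torch.view_as_real(x * freqs_cis).flatten(3)
    if not interleaved:
        # Gather real parts then imaginary parts, restoring the half-split layout.
        y = torch.cat([y[..., 0::2], y[..., 1::2]], dim=-1)
    return y.to(dtype)


def rope_half_split_reference(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Plain cos/sin form of the same rotation: pair x[..., j] with x[..., j + d/2].
    a, b = x.float().chunk(2, dim=-1)
    cos = freqs_cis.real.view(1, x.size(1), 1, -1)
    sin = freqs_cis.imag.view(1, x.size(1), 1, -1)
    return torch.cat([a * cos - b * sin, a * sin + b * cos], dim=-1).to(x.dtype)


if __name__ == "__main__":
    bsz, seqlen, n_heads, head_dim = 2, 8, 4, 64  # toy sizes, not the model's
    x = torch.randn(bsz, seqlen, n_heads, head_dim)
    # Plain RoPE frequencies as a stand-in for the repo's precompute_freqs_cis.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)  # (seqlen, head_dim // 2)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)       # complex64
    out = apply_rotary_emb(x, freqs_cis, interleaved=False)
    ref = rope_half_split_reference(x, freqs_cis)
    torch.testing.assert_close(out, ref, rtol=1e-5, atol=1e-5)
    print("interleaved=False matches the half-split (non-interleaved) rotation")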