Commit 1938e9d
Parent(s): 9d2f599

fix indexer rope

inference/model.py  +29 -16
inference/model.py
CHANGED
@@ -2,7 +2,6 @@ import math
 from dataclasses import dataclass
 from typing import Tuple, Optional, Literal
 
-from einops import rearrange
 import torch
 from torch import nn
 import torch.nn.functional as F
@@ -282,6 +281,7 @@ class RMSNorm(nn.Module):
         super().__init__()
         self.dim = dim
         self.eps = eps
+        # rmsnorm in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenient.
         self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
 
     def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
@@ -315,6 +315,7 @@ class LayerNorm(nn.Module):
         super().__init__()
         self.dim = dim
         self.eps = eps
+        # layernorm in the checkpoint is stored in bf16, while the parameters here are stored in fp32 for convenient.
         self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
         self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
 
@@ -403,7 +404,7 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
     return freqs_cis
 
 
-def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, interleaved: bool = True) -> torch.Tensor:
     """
     Applies rotary positional embeddings to the input tensor.
 
@@ -415,9 +416,14 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
         torch.Tensor: Tensor with rotary embeddings applied.
     """
     dtype = x.dtype
-    x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
+    shape = x.shape
+    if not interleaved:
+        x = x.view(*shape[:-1], 2, -1).transpose(-1, -2).contiguous()
+    x = torch.view_as_complex(x.float().view(*shape[:-1], -1, 2))
     freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
     y = torch.view_as_real(x * freqs_cis).flatten(3)
+    if not interleaved:
+        y = torch.cat([y[..., 0::2], y[..., 1::2]], dim=-1)
     return y.to(dtype)
 
 
@@ -441,7 +447,8 @@ class Indexer(torch.nn.Module):
         self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.head_dim)
         self.wk = Linear(self.dim, self.head_dim)
         self.k_norm = LayerNorm(self.head_dim)
-        self.weights_proj = Linear(self.dim, self.n_heads)
+        # weights_proj in the checkpoint is stored in bf16, while the parameters here are stored in fp32 for convenient.
+        self.weights_proj = Linear(self.dim, self.n_heads, dtype=torch.float32)
         self.softmax_scale = self.head_dim ** -0.5
         self.scale_fmt = args.scale_fmt
 
@@ -453,14 +460,16 @@ class Indexer(torch.nn.Module):
         bsz, seqlen, _ = x.size()
         end_pos = start_pos + seqlen
         q = self.wq_b(qr)
-        q =
+        q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
         q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
-        q_pe = apply_rotary_emb(q_pe, freqs_cis)
+        # rope in indexer is not interleaved
+        q_pe = apply_rotary_emb(q_pe, freqs_cis, False)
         q = torch.cat([q_pe, q_nope], dim=-1)
         k = self.wk(x)
         k = self.k_norm(k)
         k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
-        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis).squeeze(2)
+        # rope in indexer is not interleaved
+        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis, False).squeeze(2)
         k = torch.cat([k_pe, k_nope], dim=-1)
         q = rotate_activation(q)
         k = rotate_activation(k)
@@ -468,7 +477,7 @@ class Indexer(torch.nn.Module):
         k_fp8, k_scale = act_quant(k, block_size, self.scale_fmt)
         self.k_cache[:bsz, start_pos:end_pos] = k_fp8
         self.k_scale_cache[:bsz, start_pos:end_pos] = k_scale
-        weights = self.weights_proj(x) * self.n_heads ** -0.5
+        weights = self.weights_proj(x.float()) * self.n_heads ** -0.5
         weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
         index_score = fp8_index(q_fp8.contiguous(), weights, self.k_cache[:bsz, :end_pos].contiguous(), self.k_scale_cache[:bsz, :end_pos].contiguous())
         if mask is not None:
@@ -524,6 +533,7 @@ class MLA(nn.Module):
         self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
         self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
         self.softmax_scale = self.qk_head_dim ** -0.5
+        self.scale_fmt = args.scale_fmt
         if args.max_seq_len > args.original_seq_len:
             mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
             self.softmax_scale = self.softmax_scale * mscale * mscale
@@ -558,6 +568,9 @@ class MLA(nn.Module):
         kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         kv = self.kv_norm(kv)
         k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
+        # we use fp8 kv cache in actual deployment, so here we simulate the precision by casting kv to fp8 and then back to bf16.
+        kv_fp8, kv_scale = act_quant(kv, block_size, self.scale_fmt)
+        kv = (kv_fp8.view(-1, block_size).float() * kv_scale.view(-1, 1)).to(kv.dtype).view_as(kv)
         self.kv_cache[:bsz, start_pos:end_pos] = kv
         self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
         if mask is not None: # MHA prefill
@@ -566,7 +579,7 @@ class MLA(nn.Module):
             kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
             k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
-            scores = torch.einsum("bshd,bthd->bsht", q
+            scores = torch.einsum("bshd,bthd->bsht", q, k).mul_(self.softmax_scale)
 
             # indexer
             topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
@@ -574,24 +587,24 @@ class MLA(nn.Module):
             index_mask += mask
             scores += index_mask.unsqueeze(2)
 
-            scores = scores.softmax(dim=-1
-            x = torch.einsum("bsht,bthd->bshd", scores
-        else: #
+            scores = scores.softmax(dim=-1)
+            x = torch.einsum("bsht,bthd->bshd", scores, v)
+        else: # MQA decode
             if self.dequant_wkv_b is None and self.wkv_b.scale is not None:
                 self.dequant_wkv_b = weight_dequant(self.wkv_b.weight, self.wkv_b.scale)
             wkv_b = self.wkv_b.weight if self.dequant_wkv_b is None else self.dequant_wkv_b
             wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
             q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
-            scores = (torch.einsum("bshc,btc->bsht", q_nope
-                      torch.einsum("bshr,btr->bsht", q_pe
+            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
+                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
 
             # indexer
             topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
             index_mask = torch.full((bsz, 1, end_pos), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
             scores += index_mask.unsqueeze(2)
 
-            scores = scores.softmax(dim=-1
-            x = torch.einsum("bsht,btc->bshc", scores
+            scores = scores.softmax(dim=-1)
+            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
             x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
         x = self.wo(x.flatten(2))
         return x
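
Note on the core change: with interleaved=False, apply_rotary_emb treats the last dimension as half-split (element i paired with element i + d/2, GPT-NeoX style) rather than adjacent-pair interleaved. The new branch converts the half-split layout into interleaved pairs before the complex rotation and converts the result back afterwards. A minimal self-contained sketch of that equivalence, assuming a recent PyTorch and made-up tensor sizes (only the function body is taken from the diff):

import torch

def apply_rotary_emb(x, freqs_cis, interleaved=True):
    # same logic as the updated function in the diff, docstring omitted
    dtype = x.dtype
    shape = x.shape
    if not interleaved:
        x = x.view(*shape[:-1], 2, -1).transpose(-1, -2).contiguous()
    x = torch.view_as_complex(x.float().view(*shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
    y = torch.view_as_real(x * freqs_cis).flatten(3)
    if not interleaved:
        y = torch.cat([y[..., 0::2], y[..., 1::2]], dim=-1)
    return y.to(dtype)

bsz, seqlen, heads, dim = 2, 5, 3, 8                      # toy sizes
x = torch.randn(bsz, seqlen, heads, dim)
angles = torch.outer(torch.arange(seqlen, dtype=torch.float32), torch.rand(dim // 2))
freqs_cis = torch.polar(torch.ones(seqlen, dim // 2), angles)

# reference: half-split [x0..x3 | x4..x7] -> interleaved pairs [x0, x4, x1, x5, ...],
# rotate with the interleaved path, then split back into halves
x_inter = x.view(bsz, seqlen, heads, 2, dim // 2).transpose(-1, -2).reshape(x.shape)
y_ref = apply_rotary_emb(x_inter, freqs_cis, interleaved=True)
y_ref = torch.cat([y_ref[..., 0::2], y_ref[..., 1::2]], dim=-1)

torch.testing.assert_close(apply_rotary_emb(x, freqs_cis, interleaved=False), y_ref)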
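
The MLA change that quantizes kv and immediately dequantizes it before filling kv_cache makes the bf16 reference path see the same rounding as the fp8 cache used in deployment. A rough illustration of that round trip; this is a sketch, not the repo's act_quant (the per-block absmax/e4m3 recipe and block_size = 128 here are assumptions, and it needs a PyTorch build with torch.float8_e4m3fn):

import torch

block_size = 128  # assumed to match the repo's quantization block size

def act_quant_ref(x: torch.Tensor, block_size: int):
    # stand-in: per-block absmax scaling to fp8 e4m3 over the last dimension
    flat = x.float().reshape(-1, block_size)
    scale = flat.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4) / 448.0  # 448 = e4m3 max
    q = (flat / scale).to(torch.float8_e4m3fn)
    return q.view(x.shape), scale.view(*x.shape[:-1], -1)

kv = torch.randn(2, 4, 512, dtype=torch.bfloat16)  # (bsz, seqlen, kv_lora_rank), toy sizes
kv_fp8, kv_scale = act_quant_ref(kv, block_size)

# what the diff does before writing to kv_cache: dequantize straight back to bf16
kv_sim = (kv_fp8.view(-1, block_size).float() * kv_scale.view(-1, 1)).to(kv.dtype).view_as(kv)
print((kv - kv_sim).abs().max())  # small but non-zero: the cache now "sees" fp8 precision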
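
Why the indexer forward now calls weights_proj(x.float()): the projection is constructed in fp32 while the activations are bf16, and a matmul between mismatched dtypes raises an error, so the input is cast up first. A toy version with nn.Linear standing in for the repo's Linear class (the head count and dimensions are made up):

import torch
from torch import nn

n_heads, dim = 4, 16                                      # toy sizes
weights_proj = nn.Linear(dim, n_heads, dtype=torch.float32)
x = torch.randn(2, 3, dim, dtype=torch.bfloat16)

# weights_proj(x) would raise a dtype mismatch (fp32 weight vs. bf16 input);
# casting the input first mirrors the diff's self.weights_proj(x.float())
weights = weights_proj(x.float()) * n_heads ** -0.5
print(weights.dtype)  # torch.float32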