Update modeling_minicpm.py
modeling_minicpm.py  +93 -42
@@ -70,13 +70,14 @@ from functools import lru_cache
 def compressed_attention(
     q: torch.Tensor,
     k: torch.Tensor,
-
+    k2: torch.Tensor,
     kernel_size: int,
     kernel_stride: int,
     block_size: int,
     topk: int,
     cu_seqlens_q: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
+    cu_seqlens_k2: torch.Tensor,
     max_seqlen_q: int,
     max_seqlen_k: int,
     sm_scale: float = None,
@@ -106,9 +107,10 @@ def compressed_attention(
     score = infllmv2_attn_stage1(
         q.contiguous(),
         k.contiguous(),
-
+        k2.contiguous(),
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
+        cu_seqlens_v=cu_seqlens_k2,
         max_seqlen_q=max_seqlen_q,
         max_seqlen_k=max_seqlen_k,
         causal=is_prefilling
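`compressed_attention` now scores queries against two compressed key streams: the existing fine-grained one (`k`, `cu_seqlens_k`) and a coarser one (`k2`, `cu_seqlens_k2`), which the stage-1 kernel receives through its `cu_seqlens_v` argument. The sketch below only illustrates the varlen bookkeeping behind the two `cu_seqlens` tensors; it does not call the fused `infllmv2_attn_stage1` kernel, and the window-count formula is an assumption (the real counts come from `calc_chunks_with_stride`), using the 32/16 and 128/64 kernel/stride pairs that appear as defaults later in this diff.

```python
import torch

def num_windows(seq_len: int, kernel_size: int, kernel_stride: int) -> int:
    # Assumed sliding-window count: one compressed position per full window.
    return 0 if seq_len < kernel_size else (seq_len - kernel_size) // kernel_stride + 1

def compressed_cu_seqlens(raw_lens, kernel_size, kernel_stride):
    # Same cu_seqlens convention as the varlen attention APIs: [0, n0, n0+n1, ...], int32.
    counts = torch.tensor([num_windows(n, kernel_size, kernel_stride) for n in raw_lens], dtype=torch.int32)
    return torch.cat([torch.zeros(1, dtype=torch.int32), torch.cumsum(counts, dim=0, dtype=torch.int32)])

raw_lens = [300, 1024]  # two sequences packed into one varlen batch
cu_seqlens_k = compressed_cu_seqlens(raw_lens, kernel_size=32, kernel_stride=16)    # fine level  -> k
cu_seqlens_k2 = compressed_cu_seqlens(raw_lens, kernel_size=128, kernel_stride=64)  # coarse level -> k2
print(cu_seqlens_k)   # tensor([ 0, 17, 80], dtype=torch.int32)
print(cu_seqlens_k2)  # tensor([ 0,  3, 18], dtype=torch.int32)
```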
@@ -142,12 +144,10 @@ def compressed_attention(
 def calc_chunks_with_stride(cu_seqlen, chunk_size, kernel_stride):
     """
     Compute the chunks that require Sparse attention, with stride support.
-
     Args:
         cu_seqlen (torch.Tensor): Cumulative sequence lengths for each sample.
         chunk_size (int): Chunk size used for Sparse attention.
         kernel_stride (int): Stride size when sliding over the sequence.
-
     Returns:
         filtered_indices (torch.Tensor): Indices used to directly index into the key/value tensors.
         cu_seqlens_compressed (torch.Tensor): Cumulative sequence lengths after compression.
@@ -190,7 +190,6 @@ class CompressK(torch.nn.Module):
     def __init__(self, head_num_k, head_dim, kernel_size, kernel_stride=16):
         """
         Module for compressing key (K) representations.
-
         Args:
             head_num_k (int): Number of key attention heads.
             head_dim (int): Dimension of each attention head.
@@ -206,15 +205,12 @@ class CompressK(torch.nn.Module):
     def forward(self, k: torch.Tensor, cu_seqlens):
         """
         Forward pass for compressing the key (K) tensor.
-
         Args:
             k (torch.Tensor): Input key tensor of shape (total_seq_len, num_heads, head_dim).
             cu_seqlens (torch.Tensor): Cumulative sequence lengths for each sample in the batch, typically used for handling variable-length sequences.
-
         Returns:
             compress_k (torch.Tensor): Compressed key tensor.
             cu_seqlens_compressed (torch.Tensor): Updated cumulative sequence lengths after compression.
-
         """
         # Compute chunk-related metadata, with stride support
         filtered_k_indices, cu_seqlens_compressed = calc_chunks_with_stride(
@@ -241,6 +237,11 @@ class InfLLMv2CacheLayer(DynamicLayer):
         self.no_compress_k_cache = []
         self.cached_compressed_cu_seqlens = torch.tensor([], dtype=torch.int32)
         self.compress_k_cache_varlen = torch.tensor([], dtype=torch.float32)
+        # Add support for compress_k2
+        self.compress_k2_cache = []
+        self.cached_compressed_cu_seqlens2 = torch.tensor([], dtype=torch.int32)
+        self.compress_k2_cache_varlen = torch.tensor([], dtype=torch.float32)
+        self.no_compress_k2_cache = []

     def update_no_rope_key(self, key_states):
         if self.no_rope_keys.numel() == 0:
@@ -282,6 +283,39 @@ class InfLLMv2CacheLayer(DynamicLayer):
                 k_chunk_list.append(None)
         return k_chunk_list

+    def update_compress_k2(self, key_states, cu_seqlens=None):
+        if len(self.compress_k2_cache) == 0:
+            if cu_seqlens is not None:
+                self.cached_compressed_cu_seqlens2 = cu_seqlens.clone()
+                self.compress_k2_cache_varlen = key_states
+                split_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+                self.compress_k2_cache = list(torch.split(key_states, split_sizes))
+        else:
+            for index, k in enumerate(key_states):
+                if k is not None:
+                    self.compress_k2_cache[index] = torch.cat([self.compress_k2_cache[index], k], dim=0)
+            new_seq_lens = torch.tensor([tensor.shape[0] for tensor in self.compress_k2_cache], dtype=torch.int32)
+            new_cumsum = torch.cumsum(new_seq_lens, dim=0, dtype=torch.int32)
+
+            self.compress_k2_cache_varlen = torch.cat(self.compress_k2_cache, dim=0)
+            self.cached_compressed_cu_seqlens2 = torch.cat([torch.tensor([0], dtype=torch.int32), new_cumsum]).to(self.compress_k2_cache_varlen.device)
+        return self.compress_k2_cache_varlen, self.cached_compressed_cu_seqlens2
+
+    def update_no_compress_k2(self, key_states, kernel_size=128, kernel_stride=64):
+        k_chunk_list = []
+        for index, k in enumerate(key_states):
+            if len(self.no_compress_k2_cache) <= index:
+                self.no_compress_k2_cache.append(k)
+            else:
+                self.no_compress_k2_cache[index] = torch.cat([self.no_compress_k2_cache[index], k], dim=0)
+            current_len = self.no_compress_k2_cache[index].shape[0]
+            if current_len >= kernel_size:
+                k_chunk_list.append(self.no_compress_k2_cache[index][:kernel_size])
+                self.no_compress_k2_cache[index] = self.no_compress_k2_cache[index][kernel_stride:]
+            else:
+                k_chunk_list.append(None)
+        return k_chunk_list
+
 class InfLLMv2Cache(DynamicCache):
     def __init__(self, config,num_hidden_layers: Optional[int] = None) -> None:
         super().__init__(config=config)
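The new buffers mirror the per-layer cache that already exists for the fine level: `update_compress_k2` keeps a flattened varlen copy plus per-sequence segments of the coarse compressed keys, while `update_no_compress_k2` is a rolling buffer that collects decode-step keys until a full coarse window is available. A toy, self-contained replay of that rolling-buffer rule (scaled down from the 128/64 defaults to 4/2 so the printout stays short; all shapes are made up):

```python
import torch

kernel_size, kernel_stride = 4, 2          # stands in for the 128/64 defaults of update_no_compress_k2
buffer = torch.empty(0, 2, 8)              # (n_tokens, n_heads_k, head_dim), hypothetical sizes

for step in range(6):
    new_key = torch.randn(1, 2, 8)         # each decode step appends one key for this sequence
    buffer = torch.cat([buffer, new_key], dim=0)
    if buffer.shape[0] >= kernel_size:
        window = buffer[:kernel_size]      # emitted so it can be compressed into one coarse key
        buffer = buffer[kernel_stride:]    # advance by the stride, keeping the overlap
        print(f"step {step}: emit window {tuple(window.shape)}, buffer keeps {buffer.shape[0]} tokens")
    else:
        print(f"step {step}: waiting ({buffer.shape[0]}/{kernel_size} tokens)")
```

Because the buffer advances by `kernel_stride` rather than `kernel_size`, consecutive emitted windows overlap by `kernel_size - kernel_stride` tokens, which matches the strided compression used at prefill time.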
@@ -303,6 +337,12 @@ class InfLLMv2Cache(DynamicCache):
     def update_no_compress_k(self, key_states, layer_idx, kernel_size=32, kernel_stride=16, cache_kwargs=None):
         return self.layers[layer_idx].update_no_compress_k(key_states, kernel_size, kernel_stride)

+    def update_compress_k2(self, key_states, layer_idx, cu_seqlens=None, cache_kwargs=None):
+        return self.layers[layer_idx].update_compress_k2(key_states, cu_seqlens)
+
+    def update_no_compress_k2(self, key_states, layer_idx, kernel_size=128, kernel_stride=64, cache_kwargs=None):
+        return self.layers[layer_idx].update_no_compress_k2(key_states, kernel_size, kernel_stride)
+
     def crop(self, max_length):
         for layer in self.layers:
             layer.crop(max_length)
@@ -489,7 +529,6 @@ def rotate_half(x):

 def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
-
     Args:
         q (`torch.Tensor`): The query tensor.
         k (`torch.Tensor`): The key tensor.
@@ -860,7 +899,6 @@ class MiniCPMFlashAttention2(MiniCPMAttention):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
         first unpad the input, then computes the attention scores and pad the final attention scores.
-
         Args:
             query_states (`torch.Tensor`):
                 Input query states to be passed to Flash Attention API
@@ -976,7 +1014,9 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         self.local_blocks = self.window_size // self.block_size # local_blocks
         self.topk = self.config.sparse_config.get('topk', 64) + (self.window_size//self.block_size)
         self.use_nope = self.config.sparse_config.get('use_nope', False)
+
         self.compress_k = CompressK(self.num_key_value_heads, self.head_dim, kernel_size=self.kernel_size, kernel_stride=self.kernel_stride)
+        self.compress_k2 = CompressK(self.num_key_value_heads, self.head_dim, kernel_size=self.kernel_size*4, kernel_stride=self.kernel_stride*4)

     def forward(
         self,
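The attention module now owns two `CompressK` instances: the existing one at (`kernel_size`, `kernel_stride`) and a second one at four times that granularity, so with the 32/16 defaults used elsewhere in this diff the coarse level works on 128-token windows advanced by 64 tokens. The snippet below is a stand-in rather than the real module: it approximates what a strided key compressor produces (one pooled key per window) with a plain `unfold` plus mean, just to make the output shapes of the two granularities concrete.

```python
import torch

def meanpool_compress(k: torch.Tensor, kernel_size: int, kernel_stride: int) -> torch.Tensor:
    """Illustrative stand-in for CompressK on one sequence; k is (seq_len, n_heads_k, head_dim)."""
    if k.shape[0] < kernel_size:
        return k.new_zeros(0, *k.shape[1:])
    windows = k.unfold(0, kernel_size, kernel_stride)  # (num_windows, n_heads_k, head_dim, kernel_size)
    return windows.mean(dim=-1)                        # one pooled key per window

k = torch.randn(1024, 2, 64)                           # hypothetical unpadded keys for one sequence
fine = meanpool_compress(k, kernel_size=32, kernel_stride=16)
coarse = meanpool_compress(k, kernel_size=128, kernel_stride=64)
print(fine.shape, coarse.shape)                        # torch.Size([63, 2, 64]) torch.Size([15, 2, 64])
```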
@@ -1088,7 +1128,6 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
         first unpad the input, then computes the attention scores and pad the final attention scores.
-
         Args:
             query_states (`torch.Tensor`):
                 Input query states to be passed to Flash Attention API
@@ -1114,7 +1153,7 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         batch_size = query_states.shape[0]
         # assert batch_size == 1, 'Only batch_size=1 is supported at the moment.'
         if past_key_value!=None:
-            compressed_k, compressed_cu_seqlens = self.get_compress_k(
+            compressed_k, compressed_cu_seqlens, compressed_k2, compressed_cu_seqlens2 = self.get_compress_k(
                 key_states=key_states if self.use_nope ==False else no_rope_param['key_states_no_rope'], # This can be optimized a bit;
                 attention_mask=attention_mask,
                 past_key_value=past_key_value,
@@ -1135,6 +1174,10 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         if past_key_value==None:
             # compress_k use varlen form
             compressed_k, compressed_cu_seqlens = self.compress_k(key_states,cu_seqlens_k)
+            compressed_k2, compressed_cu_seqlens2 = self.compress_k2(key_states,cu_seqlens_k)
+        else:
+            # compressed_k and compressed_k2 already retrieved from get_compress_k above
+            pass


         attn_output_unpad = self.sparse_forward(
@@ -1146,7 +1189,8 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
             max_seqlen_in_batch_q,
             max_seqlen_in_batch_k,
             no_rope_param=no_rope_param,
-            compressed_k=compressed_k, compressed_cu_seqlens=compressed_cu_seqlens
+            compressed_k=compressed_k, compressed_cu_seqlens=compressed_cu_seqlens,
+            compressed_k2=compressed_k2, compressed_cu_seqlens2=compressed_cu_seqlens2
         )

         attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
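Seen from the caller, the two paths stay symmetric: without a cache (prefill) both compression levels are computed directly from the varlen keys, and with a cache they come back, already maintained incrementally, from `get_compress_k`; either way all four values are handed to `sparse_forward`. A compressed control-flow sketch with the attention-specific pieces stubbed out (names and callables here are illustrative, not the module's API):

```python
from typing import Optional, Tuple


def pick_compressed_keys(
    past_key_value: Optional[object],
    compress_fine,    # callable: (keys, cu_seqlens) -> (compressed_k, cu_seqlens_k)
    compress_coarse,  # callable: (keys, cu_seqlens) -> (compressed_k2, cu_seqlens_k2)
    get_compress_k,   # cache-aware path that already returns all four values
    key_states,
    cu_seqlens_k,
) -> Tuple:
    if past_key_value is None:           # prefill: compress both levels from scratch
        compressed_k, cu_k = compress_fine(key_states, cu_seqlens_k)
        compressed_k2, cu_k2 = compress_coarse(key_states, cu_seqlens_k)
    else:                                # decode: the cache-aware helper supplies both levels
        compressed_k, cu_k, compressed_k2, cu_k2 = get_compress_k(key_states, cu_seqlens_k)
    return compressed_k, cu_k, compressed_k2, cu_k2


# Smoke test with placeholder callables; the returned strings only stand in for tensors.
fine = lambda k, cu: ("k", "cu_k")
coarse = lambda k, cu: ("k2", "cu_k2")
cached = lambda k, cu: ("k", "cu_k", "k2", "cu_k2")
assert pick_compressed_keys(None, fine, coarse, cached, None, None) == ("k", "cu_k", "k2", "cu_k2")
assert pick_compressed_keys(object(), fine, coarse, cached, None, None) == ("k", "cu_k", "k2", "cu_k2")
```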
@@ -1166,7 +1210,7 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
             no_rope_param: Optional parameter containing key states without rope

         Returns:
-            Tuple of (compressed_k, compressed_cu_seqlens)
+            Tuple of (compressed_k, compressed_cu_seqlens, compressed_k2, compressed_cu_seqlens2)
         """

         # Check if this is prefilling or initial compression condition
@@ -1182,9 +1226,12 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
             unpadded_key_states, indices, cu_seqlens, max_seqlen_in_batch = _unpad_one_tensor(key_states,attention_mask=attention_mask)
             # Compress the keys
             compressed_k, compressed_cu_seqlens = self.compress_k(unpadded_key_states, cu_seqlens)
+            compressed_k2, compressed_cu_seqlens2 = self.compress_k2(unpadded_key_states, cu_seqlens)

             past_key_value.update_compress_k(
                 compressed_k, self.layer_idx, compressed_cu_seqlens)
+            past_key_value.update_compress_k2(
+                compressed_k2, self.layer_idx, compressed_cu_seqlens2)

             no_compress_k_list = []
             # Compute and update no_compress_k
@@ -1196,6 +1243,17 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
             past_key_value.update_no_compress_k(
                 no_compress_k_list, self.layer_idx,kernel_stride=self.kernel_stride,
                 kernel_size=self.kernel_size)
+
+            # Also update no_compress_k2
+            no_compress_k2_list = []
+            for i in range(len(compressed_cu_seqlens2)-1):
+                no_compress_k2_start = (compressed_cu_seqlens2[i+1]- compressed_cu_seqlens2[i]) * self.kernel_stride * 4
+
+                no_compress_k2_list.append(unpadded_key_states[cu_seqlens[i]+no_compress_k2_start:cu_seqlens[i+1]].clone())
+
+            past_key_value.update_no_compress_k2(
+                no_compress_k2_list, self.layer_idx,kernel_stride=self.kernel_stride*4,
+                kernel_size=self.kernel_size*4)

         else:
             # Decode case: incremental update
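The tail offset follows from how many raw tokens the coarse windows have consumed: each coarse compressed position accounts for `kernel_stride * 4` tokens (64 with the defaults), so everything from `num_coarse_windows * 64` onward has not yet been folded into a full 128-token window and is buffered instead. A worked example, again assuming the sliding-window count used in the earlier sketches:

```python
# Assumed settings: kernel_size=32, kernel_stride=16, so the coarse level uses 128/64.
seq_len = 300                      # prompt length of one sequence
kernel_size2, kernel_stride2 = 128, 64

num_coarse = (seq_len - kernel_size2) // kernel_stride2 + 1  # 3 coarse compressed positions
tail_start = num_coarse * kernel_stride2                     # 3 * 64 = 192, i.e. no_compress_k2_start
tail_len = seq_len - tail_start                              # 108 raw keys carried into the buffer

print(num_coarse, tail_start, tail_len)                      # 3 192 108
assert kernel_size2 - kernel_stride2 <= tail_len < kernel_size2
```

The buffered tail always starts inside the last emitted window, so it carries the `kernel_size - kernel_stride` overlap tokens plus whatever has not yet formed a new window, which is exactly the state the decode-time rolling buffer expects.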
@@ -1220,8 +1278,23 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
                 else:
                     new_compressed_k_list.append(None)
             compressed_k, compressed_cu_seqlens = past_key_value.update_compress_k(new_compressed_k_list, self.layer_idx,)
+
+            # For compress_k2, update no_compress_k2 buffer and compress when ready
+            no_compress_k2_list = past_key_value.update_no_compress_k2(
+                key_states_split, self.layer_idx,
+                kernel_stride=self.kernel_stride*4,
+                kernel_size=self.kernel_size*4)
+            new_compressed_k2_list = []
+            for no_compress_k2 in no_compress_k2_list:
+                if no_compress_k2 is not None:
+                    # We have enough tokens to compress for k2
+                    new_compressed_k2 = no_compress_k2.mean(dim=0, keepdim=True) # [1, n_heads_k, head_dim]
+                    new_compressed_k2_list.append(new_compressed_k2)
+                else:
+                    new_compressed_k2_list.append(None)
+            compressed_k2, compressed_cu_seqlens2 = past_key_value.update_compress_k2(new_compressed_k2_list, self.layer_idx,)

-        return compressed_k, compressed_cu_seqlens
+        return compressed_k, compressed_cu_seqlens, compressed_k2, compressed_cu_seqlens2
     def sparse_forward(self,
                        query_layer,
                        key_layer,
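At decode time the coarse level does not re-run `CompressK`: once a sequence's buffer holds a full 128-token window, the diff collapses it with a plain mean over the token dimension, producing exactly one new compressed key of shape `[1, n_heads_k, head_dim]` that is appended to that sequence's segment of the varlen cache. A minimal reproduction with made-up sizes:

```python
import torch

n_heads_k, head_dim = 2, 64
kernel_size2 = 128                                      # kernel_size * 4 with the default of 32

# A full coarse window as returned by update_no_compress_k2 for one sequence.
window = torch.randn(kernel_size2, n_heads_k, head_dim)

new_compressed_k2 = window.mean(dim=0, keepdim=True)    # same op as in the diff
print(new_compressed_k2.shape)                          # torch.Size([1, 2, 64])
```

Sequences whose buffer is still short contribute `None` for this step, so `update_compress_k2` only extends the segments that actually produced a new key before rebuilding `cached_compressed_cu_seqlens2`.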
@@ -1231,7 +1304,8 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
                        max_seqlen_in_batch_q,
                        max_seqlen_in_batch_k,
                        no_rope_param=None,
-                       compressed_k=None, compressed_cu_seqlens=None):
+                       compressed_k=None, compressed_cu_seqlens=None,
+                       compressed_k2=None, compressed_cu_seqlens2=None):
         compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
         cache_lens = None
         if max_seqlen_in_batch_q==1 and max_seqlen_in_batch_k>1: #decoding
@@ -1241,13 +1315,14 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         topk_idx = compressed_attention(
             query_layer if no_rope_param is None else no_rope_param['query_states_no_rope'],
             compressed_k,
-
+            compressed_k2,
             self.kernel_size,
             self.kernel_stride,
             self.block_size,
             self.topk,
             cu_seqlens_q,
             compressed_cu_seqlens,
+            compressed_cu_seqlens2,
             max_seqlen_in_batch_q,
             compressed_seqlens.max().item(),
             None,
@@ -1280,7 +1355,6 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
         first unpad the input, then computes the attention scores and pad the final attention scores.
-
         Args:
             query_states (`torch.Tensor`):
                 Input query states to be passed to Flash Attention API
@@ -1544,11 +1618,9 @@ MINICPM_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
     etc.)
-
     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
     and behavior.
-
     Parameters:
         config ([`MiniCPMConfig`]):
             Model configuration class with all the parameters of the model. Initializing with a config file does not
@@ -1588,50 +1660,38 @@ MINICPM_INPUTS_DOCSTRING = r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
             it.
-
             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
             [`PreTrainedTokenizer.__call__`] for details.
-
             [What are input IDs?](../glossary#input-ids)
         attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.
-
             [What are attention masks?](../glossary#attention-mask)
-
             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
             [`PreTrainedTokenizer.__call__`] for details.
-
             If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
             `past_key_values`).
-
             If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
             and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
             information on the default strategy.
-
             - 1 indicates the head is **not masked**,
             - 0 indicates the head is **masked**.
         position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
             config.n_positions - 1]`.
-
             [What are position IDs?](../glossary#position-ids)
         past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
             Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
             blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
             returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
-
             Two formats are allowed:
             - a [`~cache_utils.Cache`] instance;
             - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
             shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
             cache format.
-
             The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
             legacy cache format will be returned.
-
             If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
             have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
             of shape `(batch_size, sequence_length)`.
@@ -1660,7 +1720,6 @@ MINICPM_INPUTS_DOCSTRING = r"""
 class MiniCPMModel(MiniCPMPreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniCPMDecoderLayer`]
-
     Args:
         config: MiniCPMConfig
     """
@@ -1887,20 +1946,14 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
                 Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                 config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                 (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
         Returns:
-
         Example:
-
         ```python
         >>> from transformers import AutoTokenizer, MiniCPMForCausalLM
-
         >>> model = MiniCPMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
         >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
-
         >>> prompt = "Hey, are you conscious? Can you talk to me?"
         >>> inputs = tokenizer(prompt, return_tensors="pt")
-
         >>> # Generate
         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
@@ -2080,10 +2133,8 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
 @add_start_docstrings(
     """
     The MiniCPM Model transformer with a sequence classification head on top (linear layer).
-
     [`MiniCPMForSequenceClassification`] uses the last token in order to do the classification, as other causal models
     (e.g. GPT-2) do.
-
     Since it does classification on the last token, it requires to know the position of the last token. If a
     `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
     no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
@@ -2196,4 +2247,4 @@ class MiniCPMForSequenceClassification(MiniCPMPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )