suhmily committed on
Commit ec88d20 · verified · 1 Parent(s): 1ba2c83

Update modeling_minicpm.py

Files changed (1)
  1. modeling_minicpm.py +93 -42
modeling_minicpm.py CHANGED
@@ -70,13 +70,14 @@ from functools import lru_cache
 def compressed_attention(
     q: torch.Tensor,
     k: torch.Tensor,
-    v: torch.Tensor,
+    k2: torch.Tensor,
     kernel_size: int,
     kernel_stride: int,
     block_size: int,
     topk: int,
     cu_seqlens_q: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
+    cu_seqlens_k2: torch.Tensor,
     max_seqlen_q: int,
     max_seqlen_k: int,
     sm_scale: float = None,
@@ -106,9 +107,10 @@ def compressed_attention(
     score = infllmv2_attn_stage1(
         q.contiguous(),
         k.contiguous(),
-        v.contiguous(),
+        k2.contiguous(),
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
+        cu_seqlens_v=cu_seqlens_k2,
         max_seqlen_q=max_seqlen_q,
         max_seqlen_k=max_seqlen_k,
         causal=is_prefilling
@@ -142,12 +144,10 @@ def compressed_attention(
 def calc_chunks_with_stride(cu_seqlen, chunk_size, kernel_stride):
     """
     Compute the chunks that require Sparse attention, with stride support.
-
     Args:
         cu_seqlen (torch.Tensor): Cumulative sequence lengths for each sample.
         chunk_size (int): Chunk size used for Sparse attention.
         kernel_stride (int): Stride size when sliding over the sequence.
-
     Returns:
         filtered_indices (torch.Tensor): Indices used to directly index into the key/value tensors.
         cu_seqlens_compressed (torch.Tensor): Cumulative sequence lengths after compression.
@@ -190,7 +190,6 @@ class CompressK(torch.nn.Module):
     def __init__(self, head_num_k, head_dim, kernel_size, kernel_stride=16):
         """
         Module for compressing key (K) representations.
-
         Args:
             head_num_k (int): Number of key attention heads.
             head_dim (int): Dimension of each attention head.
@@ -206,15 +205,12 @@ class CompressK(torch.nn.Module):
     def forward(self, k: torch.Tensor, cu_seqlens):
         """
         Forward pass for compressing the key (K) tensor.
-
         Args:
             k (torch.Tensor): Input key tensor of shape (total_seq_len, num_heads, head_dim).
            cu_seqlens (torch.Tensor): Cumulative sequence lengths for each sample in the batch, typically used for handling variable-length sequences.
-
         Returns:
             compress_k (torch.Tensor): Compressed key tensor.
             cu_seqlens_compressed (torch.Tensor): Updated cumulative sequence lengths after compression.
-
         """
         # Compute chunk-related metadata, with stride support
         filtered_k_indices, cu_seqlens_compressed = calc_chunks_with_stride(
@@ -241,6 +237,11 @@ class InfLLMv2CacheLayer(DynamicLayer):
         self.no_compress_k_cache = []
         self.cached_compressed_cu_seqlens = torch.tensor([], dtype=torch.int32)
         self.compress_k_cache_varlen = torch.tensor([], dtype=torch.float32)
+        # Add support for compress_k2
+        self.compress_k2_cache = []
+        self.cached_compressed_cu_seqlens2 = torch.tensor([], dtype=torch.int32)
+        self.compress_k2_cache_varlen = torch.tensor([], dtype=torch.float32)
+        self.no_compress_k2_cache = []
 
     def update_no_rope_key(self, key_states):
         if self.no_rope_keys.numel() == 0:
@@ -282,6 +283,39 @@ class InfLLMv2CacheLayer(DynamicLayer):
             k_chunk_list.append(None)
         return k_chunk_list
 
+    def update_compress_k2(self, key_states, cu_seqlens=None):
+        if len(self.compress_k2_cache) == 0:
+            if cu_seqlens is not None:
+                self.cached_compressed_cu_seqlens2 = cu_seqlens.clone()
+            self.compress_k2_cache_varlen = key_states
+            split_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+            self.compress_k2_cache = list(torch.split(key_states, split_sizes))
+        else:
+            for index, k in enumerate(key_states):
+                if k is not None:
+                    self.compress_k2_cache[index] = torch.cat([self.compress_k2_cache[index], k], dim=0)
+            new_seq_lens = torch.tensor([tensor.shape[0] for tensor in self.compress_k2_cache], dtype=torch.int32)
+            new_cumsum = torch.cumsum(new_seq_lens, dim=0, dtype=torch.int32)
+
+            self.compress_k2_cache_varlen = torch.cat(self.compress_k2_cache, dim=0)
+            self.cached_compressed_cu_seqlens2 = torch.cat([torch.tensor([0], dtype=torch.int32), new_cumsum]).to(self.compress_k2_cache_varlen.device)
+        return self.compress_k2_cache_varlen, self.cached_compressed_cu_seqlens2
+
+    def update_no_compress_k2(self, key_states, kernel_size=128, kernel_stride=64):
+        k_chunk_list = []
+        for index, k in enumerate(key_states):
+            if len(self.no_compress_k2_cache) <= index:
+                self.no_compress_k2_cache.append(k)
+            else:
+                self.no_compress_k2_cache[index] = torch.cat([self.no_compress_k2_cache[index], k], dim=0)
+            current_len = self.no_compress_k2_cache[index].shape[0]
+            if current_len >= kernel_size:
+                k_chunk_list.append(self.no_compress_k2_cache[index][:kernel_size])
+                self.no_compress_k2_cache[index] = self.no_compress_k2_cache[index][kernel_stride:]
+            else:
+                k_chunk_list.append(None)
+        return k_chunk_list
+
 class InfLLMv2Cache(DynamicCache):
     def __init__(self, config,num_hidden_layers: Optional[int] = None) -> None:
         super().__init__(config=config)
@@ -303,6 +337,12 @@ class InfLLMv2Cache(DynamicCache):
     def update_no_compress_k(self, key_states, layer_idx, kernel_size=32, kernel_stride=16, cache_kwargs=None):
         return self.layers[layer_idx].update_no_compress_k(key_states, kernel_size, kernel_stride)
 
+    def update_compress_k2(self, key_states, layer_idx, cu_seqlens=None, cache_kwargs=None):
+        return self.layers[layer_idx].update_compress_k2(key_states, cu_seqlens)
+
+    def update_no_compress_k2(self, key_states, layer_idx, kernel_size=128, kernel_stride=64, cache_kwargs=None):
+        return self.layers[layer_idx].update_no_compress_k2(key_states, kernel_size, kernel_stride)
+
     def crop(self, max_length):
         for layer in self.layers:
             layer.crop(max_length)
@@ -489,7 +529,6 @@ def rotate_half(x):
 
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
-
     Args:
         q (`torch.Tensor`): The query tensor.
         k (`torch.Tensor`): The key tensor.
@@ -860,7 +899,6 @@ class MiniCPMFlashAttention2(MiniCPMAttention):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
         first unpad the input, then computes the attention scores and pad the final attention scores.
-
         Args:
             query_states (`torch.Tensor`):
                 Input query states to be passed to Flash Attention API
@@ -976,7 +1014,9 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         self.local_blocks = self.window_size // self.block_size # local_blocks
         self.topk = self.config.sparse_config.get('topk', 64) + (self.window_size//self.block_size)
         self.use_nope = self.config.sparse_config.get('use_nope', False)
+
         self.compress_k = CompressK(self.num_key_value_heads, self.head_dim, kernel_size=self.kernel_size, kernel_stride=self.kernel_stride)
+        self.compress_k2 = CompressK(self.num_key_value_heads, self.head_dim, kernel_size=self.kernel_size*4, kernel_stride=self.kernel_stride*4)
 
     def forward(
         self,
@@ -1088,7 +1128,6 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
         first unpad the input, then computes the attention scores and pad the final attention scores.
-
         Args:
             query_states (`torch.Tensor`):
                 Input query states to be passed to Flash Attention API
@@ -1114,7 +1153,7 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         batch_size = query_states.shape[0]
         # assert batch_size == 1, 'Only batch_size=1 is supported at the moment.'
         if past_key_value!=None:
-            compressed_k, compressed_cu_seqlens = self.get_compress_k(
+            compressed_k, compressed_cu_seqlens, compressed_k2, compressed_cu_seqlens2 = self.get_compress_k(
                 key_states=key_states if self.use_nope ==False else no_rope_param['key_states_no_rope'], # This can be optimized a bit;
                 attention_mask=attention_mask,
                 past_key_value=past_key_value,
@@ -1135,6 +1174,10 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         if past_key_value==None:
             # compress_k use varlen form
             compressed_k, compressed_cu_seqlens = self.compress_k(key_states,cu_seqlens_k)
+            compressed_k2, compressed_cu_seqlens2 = self.compress_k2(key_states,cu_seqlens_k)
+        else:
+            # compressed_k and compressed_k2 already retrieved from get_compress_k above
+            pass
 
 
         attn_output_unpad = self.sparse_forward(
@@ -1146,7 +1189,8 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
             max_seqlen_in_batch_q,
             max_seqlen_in_batch_k,
             no_rope_param=no_rope_param,
-            compressed_k=compressed_k, compressed_cu_seqlens=compressed_cu_seqlens
+            compressed_k=compressed_k, compressed_cu_seqlens=compressed_cu_seqlens,
+            compressed_k2=compressed_k2, compressed_cu_seqlens2=compressed_cu_seqlens2
         )
 
         attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
@@ -1166,7 +1210,7 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
             no_rope_param: Optional parameter containing key states without rope
 
         Returns:
-            Tuple of (compressed_k, compressed_cu_seqlens)
+            Tuple of (compressed_k, compressed_cu_seqlens, compressed_k2, compressed_cu_seqlens2)
         """
 
         # Check if this is prefilling or initial compression condition
@@ -1182,9 +1226,12 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
             unpadded_key_states, indices, cu_seqlens, max_seqlen_in_batch = _unpad_one_tensor(key_states,attention_mask=attention_mask)
             # Compress the keys
             compressed_k, compressed_cu_seqlens = self.compress_k(unpadded_key_states, cu_seqlens)
+            compressed_k2, compressed_cu_seqlens2 = self.compress_k2(unpadded_key_states, cu_seqlens)
 
             past_key_value.update_compress_k(
                 compressed_k, self.layer_idx, compressed_cu_seqlens)
+            past_key_value.update_compress_k2(
+                compressed_k2, self.layer_idx, compressed_cu_seqlens2)
 
             no_compress_k_list = []
             # Compute and update no_compress_k
@@ -1196,6 +1243,17 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
             past_key_value.update_no_compress_k(
                 no_compress_k_list, self.layer_idx,kernel_stride=self.kernel_stride,
                 kernel_size=self.kernel_size)
+
+            # Also update no_compress_k2
+            no_compress_k2_list = []
+            for i in range(len(compressed_cu_seqlens2)-1):
+                no_compress_k2_start = (compressed_cu_seqlens2[i+1]- compressed_cu_seqlens2[i]) * self.kernel_stride * 4
+
+                no_compress_k2_list.append(unpadded_key_states[cu_seqlens[i]+no_compress_k2_start:cu_seqlens[i+1]].clone())
+
+            past_key_value.update_no_compress_k2(
+                no_compress_k2_list, self.layer_idx,kernel_stride=self.kernel_stride*4,
+                kernel_size=self.kernel_size*4)
 
         else:
             # Decode case: incremental update
@@ -1220,8 +1278,23 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
                 else:
                     new_compressed_k_list.append(None)
             compressed_k, compressed_cu_seqlens = past_key_value.update_compress_k(new_compressed_k_list, self.layer_idx,)
+
+            # For compress_k2, update no_compress_k2 buffer and compress when ready
+            no_compress_k2_list = past_key_value.update_no_compress_k2(
+                key_states_split, self.layer_idx,
+                kernel_stride=self.kernel_stride*4,
+                kernel_size=self.kernel_size*4)
+            new_compressed_k2_list = []
+            for no_compress_k2 in no_compress_k2_list:
+                if no_compress_k2 is not None:
+                    # We have enough tokens to compress for k2
+                    new_compressed_k2 = no_compress_k2.mean(dim=0, keepdim=True) # [1, n_heads_k, head_dim]
+                    new_compressed_k2_list.append(new_compressed_k2)
+                else:
+                    new_compressed_k2_list.append(None)
+            compressed_k2, compressed_cu_seqlens2 = past_key_value.update_compress_k2(new_compressed_k2_list, self.layer_idx,)
 
-        return compressed_k, compressed_cu_seqlens
+        return compressed_k, compressed_cu_seqlens, compressed_k2, compressed_cu_seqlens2
     def sparse_forward(self,
                        query_layer,
                        key_layer,
@@ -1231,7 +1304,8 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
                        max_seqlen_in_batch_q,
                        max_seqlen_in_batch_k,
                        no_rope_param=None,
-                       compressed_k=None, compressed_cu_seqlens=None):
+                       compressed_k=None, compressed_cu_seqlens=None,
+                       compressed_k2=None, compressed_cu_seqlens2=None):
        compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
        cache_lens = None
        if max_seqlen_in_batch_q==1 and max_seqlen_in_batch_k>1: #decoding
@@ -1241,13 +1315,14 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
        topk_idx = compressed_attention(
            query_layer if no_rope_param is None else no_rope_param['query_states_no_rope'],
            compressed_k,
-            compressed_k.clone(),
+            compressed_k2,
            self.kernel_size,
            self.kernel_stride,
            self.block_size,
            self.topk,
            cu_seqlens_q,
            compressed_cu_seqlens,
+            compressed_cu_seqlens2,
            max_seqlen_in_batch_q,
            compressed_seqlens.max().item(),
            None,
@@ -1280,7 +1355,6 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
         first unpad the input, then computes the attention scores and pad the final attention scores.
-
         Args:
             query_states (`torch.Tensor`):
                 Input query states to be passed to Flash Attention API
@@ -1544,11 +1618,9 @@ MINICPM_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
     etc.)
-
     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
     and behavior.
-
     Parameters:
         config ([`MiniCPMConfig`]):
             Model configuration class with all the parameters of the model. Initializing with a config file does not
@@ -1588,50 +1660,38 @@ MINICPM_INPUTS_DOCSTRING = r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.
-
            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
-
            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
-
            [What are attention masks?](../glossary#attention-mask)
-
            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
-
            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).
-
            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
-
            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.
-
            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
-
            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.
-
            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.
-
            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
@@ -1660,7 +1720,6 @@ MINICPM_INPUTS_DOCSTRING = r"""
 class MiniCPMModel(MiniCPMPreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniCPMDecoderLayer`]
-
     Args:
         config: MiniCPMConfig
     """
@@ -1887,20 +1946,14 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
        Returns:
-
        Example:
-
        ```python
        >>> from transformers import AutoTokenizer, MiniCPMForCausalLM
-
        >>> model = MiniCPMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
-
        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")
-
        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
@@ -2080,10 +2133,8 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
 @add_start_docstrings(
    """
    The MiniCPM Model transformer with a sequence classification head on top (linear layer).
-
    [`MiniCPMForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.
-
    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
@@ -2196,4 +2247,4 @@ class MiniCPMForSequenceClassification(MiniCPMPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )