cfli committed
Commit cbcf91e · verified · 1 Parent(s): 183b636

Update modeling_minicpm.py

Files changed (1)
  1. modeling_minicpm.py +172 -111
modeling_minicpm.py CHANGED
@@ -36,7 +36,8 @@ from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask,
     _prepare_4d_causal_attention_mask_for_sdpa,
 )
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, \
+    SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
 from transformers.utils import (
@@ -57,7 +58,6 @@ try:
 except:
     pass
 
-
 # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
 # It means that the function will not be traced through and simply appear as a node in the graph.
 if is_torch_fx_available():
@@ -66,7 +66,6 @@ if is_torch_fx_available():
 
     _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
 
-
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "MiniCPMConfig"
@@ -92,7 +91,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
 
 
 def _make_causal_mask(
-    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+        input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
 ):
     warnings.warn(
         "Calling `transformers.models.minicpm.modeling_minicpm._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.minicpm.modeling_minicpm.AttentionMaskConverter._make_causal_mask"
@@ -101,6 +100,7 @@ def _make_causal_mask(
         input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
     )
 
+
 # @torch.jit.script # type: ignore
 def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
     old_dtype = hidden.dtype
@@ -193,7 +193,7 @@ class MiniCPMDynamicNTKScalingRotaryEmbedding(MiniCPMRotaryEmbedding):
 
         if seq_len > self.max_position_embeddings:
             base = self.base * (
-                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
+                    (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
             ) ** (self.dim / (self.dim - 2))
             inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
             self.register_buffer("inv_freq", inv_freq, persistent=False)
@@ -211,7 +211,7 @@ class MiniCPMDynamicNTKScalingRotaryEmbedding(MiniCPMRotaryEmbedding):
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
+    x2 = x[..., x.shape[-1] // 2:]
     return torch.cat((-x2, x1), dim=-1)
 
 
@@ -249,6 +249,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)
     return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)
 
+
 class MiniCPMMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -295,7 +296,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
-
 class MiniCPMAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -363,14 +363,14 @@ class MiniCPMAttention(nn.Module):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        **kwargs,
+            self,
+            hidden_states: torch.Tensor,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_value: Optional[Cache] = None,
+            output_attentions: bool = False,
+            use_cache: bool = False,
+            **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if "padding_mask" in kwargs:
             warnings.warn(
@@ -463,7 +463,7 @@ class MiniCPMAttention(nn.Module):
 
         if not output_attentions:
             attn_weights = None
-
+
         return attn_output, attn_weights, past_key_value
 
 
@@ -483,14 +483,14 @@ class MiniCPMFlashAttention2(MiniCPMAttention):
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
 
     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        **kwargs,
+            self,
+            hidden_states: torch.Tensor,
+            attention_mask: Optional[torch.LongTensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_value: Optional[Cache] = None,
+            output_attentions: bool = False,
+            use_cache: bool = False,
+            **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         # MiniCPMFlashAttention2 attention does not support output_attentions
         if "padding_mask" in kwargs:
@@ -571,7 +571,7 @@ class MiniCPMFlashAttention2(MiniCPMAttention):
         return attn_output, attn_weights, past_key_value
 
     def _flash_attention_forward(
-        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+            self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
     ):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -675,13 +675,13 @@ class MiniCPMSdpaAttention(MiniCPMAttention):
 
     # Adapted from MiniCPMAttention.forward
     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
+            self,
+            hidden_states: torch.Tensor,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_value: Optional[Cache] = None,
+            output_attentions: bool = False,
+            use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -774,14 +774,14 @@ class MiniCPMDecoderLayer(nn.Module):
         self.num_hidden_layers = config.num_hidden_layers
 
     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
-        **kwargs,
+            self,
+            hidden_states: torch.Tensor,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_value: Optional[Tuple[torch.Tensor]] = None,
+            output_attentions: Optional[bool] = False,
+            use_cache: Optional[bool] = False,
+            **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -814,7 +814,7 @@ class MiniCPMDecoderLayer(nn.Module):
             use_cache=use_cache,
             **kwargs,
         )
-
+
         hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
 
         # Fully Connected
@@ -952,7 +952,7 @@ MINICPM_INPUTS_DOCSTRING = r"""
     "The bare MiniCPM Model outputting raw hidden-states without any specific head on top.",
     MINICPM_START_DOCSTRING,
 )
-class MiniCPMModel(MiniCPMPreTrainedModel):
+class LayerWiseMiniCPMModel(MiniCPMPreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniCPMDecoderLayer`]
 
@@ -986,17 +986,17 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
 
     @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING)
     def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        layer_cutoff: Optional[int] = None,
+            self,
+            input_ids: torch.LongTensor = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[List[torch.FloatTensor]] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            cutoff_layers: Optional[Union[int, List]] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1066,11 +1066,21 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
+        if cutoff_layers is None:
+            max_layer = self.config.num_hidden_layers
+            cutoff_layers = [max_layer]
+        if isinstance(cutoff_layers, int):
+            max_layer = cutoff_layers
+            cutoff_layers = [cutoff_layers]
+        else:
+            max_layer = max(cutoff_layers)
+
         for idx, decoder_layer in enumerate(self.layers):
-            if layer_cutoff is not None and idx == layer_cutoff:
+            if idx in cutoff_layers and output_hidden_states:
+                all_hidden_states += (self.norm(hidden_states),)
+
+            if idx == max_layer:
                 break
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
@@ -1103,7 +1113,7 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
         hidden_states = self.norm(hidden_states)
 
         # add hidden states from the last decoder layer
-        if output_hidden_states:
+        if output_hidden_states and self.config.num_hidden_layers == max_layer:
             all_hidden_states += (hidden_states,)
 
         next_cache = None
@@ -1119,14 +1129,21 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
         )
 
 
-class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
+class LayerWiseMiniCPMForCausalLM(MiniCPMPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
         super().__init__(config)
-        self.model = MiniCPMModel(config)
+        self.model = LayerWiseMiniCPMModel(config)
         self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        if not self.config.classifier_multi:
+            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        else:
+            self.lm_head = nn.ModuleList([nn.Linear(
+                config.hidden_size, config.vocab_size, bias=False) for _ in range(
+                self.config.start_layer,
+                self.model.config.num_hidden_layers)])
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1152,18 +1169,19 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
     @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        layer_cutoff: Optional[int] = None,
+            self,
+            input_ids: torch.LongTensor = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[List[torch.FloatTensor]] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            labels: Optional[torch.LongTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            cutoff_layers: Optional[Union[int, List]] = None,
+            only_for_one_logit: Optional[int] = None
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -1196,6 +1214,19 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        if cutoff_layers is None:
+            cutoff_layers = [self.config.num_hidden_layers]
+        elif isinstance(cutoff_layers, int):
+            cutoff_layers = [cutoff_layers]
+
+        remove_layers = [i for i in cutoff_layers if self.config.start_layer > i or i > self.config.num_hidden_layers]
+        if len(remove_layers) > 0:
+            logger.warning_once(
+                f"layers {remove_layers} is incompatible with the setting. They will be removed..."
+            )
+
+        cutoff_layers = [i for i in cutoff_layers if i not in remove_layers]
+
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
             input_ids=input_ids,
@@ -1205,32 +1236,62 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            output_hidden_states=True,
             return_dict=return_dict,
-            layer_cutoff=layer_cutoff
+            cutoff_layers=cutoff_layers
         )
 
         hidden_states = outputs[0]
-        if self.config.pretraining_tp > 1:
-            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
-            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
-            logits = torch.cat(logits, dim=-1)
+
+        all_logits = ()
+        if only_for_one_logit is None:
+            for i in range(len(outputs.hidden_states)):
+                if self.config.classifier_multi == False:
+                    if self.config.pretraining_tp > 1:
+                        lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+                        logits = [F.linear(outputs.hidden_states[i], lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
+                        logits = torch.cat(logits, dim=-1)
+                    else:
+                        logits = self.lm_head(outputs.hidden_states[i] / (self.config.hidden_size / self.config.dim_model_base))
+                else:
+                    if self.config.pretraining_tp > 1:
+                        lm_head_slices = self.lm_head[cutoff_layers[i] - self.config.start_layer].weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+                        logits = [F.linear(outputs.hidden_states[i], lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
+                        logits = torch.cat(logits, dim=-1)
+                    else:
+                        logits = self.lm_head[cutoff_layers[i] - self.config.start_layer](outputs.hidden_states[i] / (self.config.hidden_size / self.config.dim_model_base))
+                logits = logits.float()
+                all_logits = all_logits + (logits, )
         else:
-            logits = self.lm_head(hidden_states / (self.config.hidden_size / self.config.dim_model_base))
-        logits = logits.float()
+            if self.config.classifier_multi == False:
+                lm_head_slices = self.lm_head.weight.split(1, dim=0)
+                for i in range(len(outputs.hidden_states)):
+                    logits = F.linear(outputs.hidden_states[i], lm_head_slices[only_for_one_logit])
+                    logits = logits.float()
+                    all_logits = all_logits + (logits,)
+            else:
+                for i in range(len(outputs.hidden_states)):
+                    lm_head_slices = self.lm_head[cutoff_layers[i] - self.config.start_layer].weight.split(1, dim=0)
+                    logits = F.linear(outputs.hidden_states[i], lm_head_slices[only_for_one_logit])
+                    logits = logits.float()
+                    all_logits = all_logits + (logits, )
 
         loss = None
-        if labels is not None:
+        if labels is not None and not only_for_one_logit:
             # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
+            loss = 0
+            for logits in all_logits:
+                shift_logits = logits[..., :-1, :].contiguous()
+                shift_labels = labels[..., 1:].contiguous()
+                # Flatten the tokens
+                loss_fct = CrossEntropyLoss()
+                shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                shift_labels = shift_labels.view(-1)
+                # Enable model parallelism
+                shift_labels = shift_labels.to(shift_logits.device)
+                loss += loss_fct(shift_logits, shift_labels)
+
+        outputs.hidden_states = None if not output_hidden_states else outputs.hidden_states
 
         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -1245,7 +1306,7 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
         )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+            self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
@@ -1261,7 +1322,7 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
             # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
             # input)
             if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
             elif past_length < input_ids.shape[1]:
@@ -1270,9 +1331,9 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
 
         # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
         if (
-            max_cache_length is not None
-            and attention_mask is not None
-            and cache_length + input_ids.shape[1] > max_cache_length
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
         ):
             attention_mask = attention_mask[:, -max_cache_length:]
 
@@ -1282,7 +1343,7 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
+                position_ids = position_ids[:, -input_ids.shape[1]:]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
@@ -1308,7 +1369,7 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
-
+
     @torch.inference_mode()
     def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
              max_length: int = 4096, num_beams=1, do_sample=True, top_p=0.8, temperature=0.3, logits_processor=None,
@@ -1317,11 +1378,11 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
             history = []
         if logits_processor:
             gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
-                          "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+                          "temperature": temperature, "logits_processor": logits_processor, **kwargs}
         else:
             gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
-                          "temperature": temperature, "logits_processor": logits_processor, **kwargs}
-
+                          "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+
         history.append({"role": role, "content": query})
         history_str = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=False)
         inputs = tokenizer(history_str, return_tensors='pt').to(self.device)
@@ -1369,17 +1430,17 @@ class MiniCPMForSequenceClassification(MiniCPMPreTrainedModel):
 
     @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING)
     def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+            self,
+            input_ids: torch.LongTensor = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[List[torch.FloatTensor]] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            labels: Optional[torch.LongTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
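
For orientation, the interface introduced by this commit could be exercised roughly as below. This is a hypothetical sketch, not part of the commit: the checkpoint name is a placeholder, and it assumes the hosting repository maps `AutoModelForCausalLM` to `LayerWiseMiniCPMForCausalLM` via `auto_map` in its config and is loaded with `trust_remote_code=True`. `cutoff_layers` replaces the old `layer_cutoff` integer and now accepts either an int or a list of layer indices; `only_for_one_logit` optionally restricts every per-layer head to a single vocabulary row.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "org/layerwise-minicpm-checkpoint"  # placeholder repo id

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True).eval()

inputs = tokenizer("How relevant is this passage to the query?", return_tensors="pt").to(model.device)

with torch.no_grad():
    # Request logits from layers 8 and 16 instead of only the final layer.
    # Indices outside [config.start_layer, config.num_hidden_layers] are
    # filtered out with a warning by the forward pass shown in the diff.
    outputs = model(**inputs, cutoff_layers=[8, 16])

The forward pass collects one logits tensor per surviving entry of `cutoff_layers` (summing a shifted cross-entropy loss over them when `labels` are given); how that tuple is packed into the returned `CausalLMOutputWithPast` lies outside the hunks shown above, so the sketch stops at the call itself.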