Commit 541e21b · 1 parent: b6dcc99 · committed by p2o6e100

fix split error for attention

Files changed (1):
  modeling_llama.py  +1 -1
modeling_llama.py CHANGED
@@ -456,7 +456,7 @@ class LlamaAttention(nn.Module):
             bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim
         ).transpose(1, 2)
         query_states, key_states, value_states = torch.split(
-            qkv_states, [self.num_heads, self.num_heads + self.num_key_value_heads], dim=1
+            qkv_states, [self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=1
         )
 
         if position_embeddings is None:
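
For context, torch.split requires the section sizes to sum exactly to the size of the split dimension and to produce one chunk per unpacked variable. Below is a minimal sketch of the corrected three-way split; the head counts and sizes (num_heads=8, num_key_value_heads=2, head_dim=64) are illustrative assumptions, not taken from the model config.

    # Minimal sketch, not the model's actual code: the corrected split along
    # the head dimension of a fused QKV tensor, with made-up sizes.
    import torch

    bsz, q_len, head_dim = 2, 5, 64
    num_heads, num_key_value_heads = 8, 2  # grouped-query attention: fewer KV heads than Q heads

    # Fused QKV tensor laid out as (bsz, num_heads + 2 * num_key_value_heads, q_len, head_dim),
    # mirroring the shape after the .view(...).transpose(1, 2) in the diff above.
    qkv_states = torch.randn(bsz, num_heads + 2 * num_key_value_heads, q_len, head_dim)

    # The old sizes [num_heads, num_heads + num_key_value_heads] yielded only two
    # chunks whose total did not match dim 1, so the three-way unpacking failed.
    query_states, key_states, value_states = torch.split(
        qkv_states, [num_heads, num_key_value_heads, num_key_value_heads], dim=1
    )

    print(query_states.shape)  # torch.Size([2, 8, 5, 64])
    print(key_states.shape)    # torch.Size([2, 2, 5, 64])
    print(value_states.shape)  # torch.Size([2, 2, 5, 64])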