fix split error for attention
modeling_llama.py +1 -1
```diff
@@ -456,7 +456,7 @@ class LlamaAttention(nn.Module):
             bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim
         ).transpose(1, 2)
         query_states, key_states, value_states = torch.split(
-            qkv_states, [self.num_heads, self.
+            qkv_states, [self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=1
         )

         if position_embeddings is None:
```
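For context, the split sizes and `dim=1` follow from the tensor layout produced by the `view(...).transpose(1, 2)` above: after the transpose, the fused head axis sits on dim 1, so the tensor must be split there into `num_heads` query heads plus `num_key_value_heads` key heads and `num_key_value_heads` value heads. Below is a minimal standalone sketch of that shape logic; the sizes are hypothetical and not taken from the commit.

```python
import torch

# Hypothetical sizes for illustration only (not from the commit).
bsz, q_len, head_dim = 2, 5, 64
num_heads, num_key_value_heads = 32, 8

# Fused QKV projection output: (bsz, q_len, (num_heads + 2 * num_kv_heads) * head_dim)
qkv = torch.randn(bsz, q_len, (num_heads + 2 * num_key_value_heads) * head_dim)

# Same reshape as the patched code: heads move to dim 1, sequence to dim 2.
qkv_states = qkv.view(
    bsz, q_len, num_heads + 2 * num_key_value_heads, head_dim
).transpose(1, 2)  # -> (bsz, num_heads + 2 * num_kv_heads, q_len, head_dim)

# The fix: split along the head dimension (dim=1), not the default dim=0.
query_states, key_states, value_states = torch.split(
    qkv_states, [num_heads, num_key_value_heads, num_key_value_heads], dim=1
)

assert query_states.shape == (bsz, num_heads, q_len, head_dim)
assert key_states.shape == (bsz, num_key_value_heads, q_len, head_dim)
assert value_states.shape == (bsz, num_key_value_heads, q_len, head_dim)
```

Note that `torch.split` defaults to `dim=0`, the batch axis, where these split sizes would not add up and the call would raise a size-mismatch error; passing `dim=1` explicitly is presumably the split error the commit title refers to.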