Update modeling_llama.py
modeling_llama.py  (+6, -6)  CHANGED
@@ -60,14 +60,10 @@ def is_flash_attn_available():
         return False

     # Let's add an extra check to see if cuda is available
-    import torch

     return _is_package_available("flash_attn") and torch.cuda.is_available()

-
-from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func, flash_attn_with_kvcache
-# from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
-from flash_attn.bert_padding import unpad_input, pad_input
+



@@ -232,7 +228,10 @@ class LlamaAttention(nn.Module):

             attention_mask: [bsz, q_len]
         """
-
+        if is_flash_attn_available():
+            from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func, flash_attn_with_kvcache
+            # from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+            from flash_attn.bert_padding import unpad_input, pad_input
         bsz, q_len, *_ = qkv.size()

         if key_padding_mask is None:
@@ -342,6 +341,7 @@ class LlamaAttention(nn.Module):
             return attn_output, attn_weights, past_key_value
         # use flash attention
         elif past_key_value is not None:
+            from flash_attn.flash_attn_interface import flash_attn_with_kvcache
             output = flash_attn_with_kvcache(
                 query_states.transpose(1, 2),
                 key_states.transpose(1, 2),
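
Read as a whole, the change moves the flash-attn imports out of module scope and into the code paths that need them, guarded by is_flash_attn_available(), so importing modeling_llama.py no longer fails on machines without flash-attn or CUDA. Below is a minimal, self-contained sketch of that deferred-import pattern, not the repository's code: is_flash_attn_available here is a simplified stand-in (it uses importlib.util.find_spec rather than the _is_package_available helper seen in the diff), and run_attention is a hypothetical toy substitute for the real attention forward.

import importlib.util

import torch


def is_flash_attn_available() -> bool:
    # Simplified stand-in for the helper touched in the first hunk:
    # require both the flash_attn package and a visible CUDA device,
    # since the flash-attn kernels are CUDA-only.
    if importlib.util.find_spec("flash_attn") is None:
        return False
    return torch.cuda.is_available()


def run_attention(qkv: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for the attention forward. qkv is packed as
    # (bsz, seqlen, 3, n_heads, head_dim), matching the qkv.size()
    # unpacking visible in the second hunk.
    if is_flash_attn_available():
        # Deferred import: only executed on the path that actually uses
        # flash-attn, mirroring the hunk that adds the imports after the
        # availability check.
        from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func

        return flash_attn_qkvpacked_func(qkv, causal=True)

    # Fallback: plain PyTorch scaled dot-product attention.
    q, k, v = qkv.unbind(dim=2)                       # each (bsz, seqlen, n_heads, head_dim)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> (bsz, n_heads, seqlen, head_dim)
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    return out.transpose(1, 2)

With this structure, importing the module works on a CPU-only box, and the flash-attn path is only exercised when both checks pass.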
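
For the cached-decoding branch in the last hunk, flash_attn_with_kvcache is imported locally right before it is used. The snippet below is a hedged, standalone usage sketch rather than code from this repository: argument names and tensor layouts follow the flash-attn 2.x interface as I understand it (query and caches in (batch, seqlen, n_heads, head_dim) layout, which is why the diff transposes query_states and key_states), it requires a CUDA GPU with fp16/bf16 tensors, and every size below is an illustrative made-up value.

import torch
from flash_attn.flash_attn_interface import flash_attn_with_kvcache

bsz, n_heads, head_dim, max_cache_len = 2, 32, 128, 512

# Pre-allocated KV cache in (batch, seqlen, n_heads, head_dim) layout.
k_cache = torch.zeros(bsz, max_cache_len, n_heads, head_dim,
                      device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)

# One new token per sequence: a single decoding step.
q_new = torch.randn(bsz, 1, n_heads, head_dim, device="cuda", dtype=torch.float16)
k_new = torch.randn_like(q_new)
v_new = torch.randn_like(q_new)

# How many tokens are already stored in the cache for each sequence.
cache_seqlens = torch.full((bsz,), 17, dtype=torch.int32, device="cuda")

# flash_attn_with_kvcache writes k_new/v_new into the caches in place at
# position cache_seqlens and attends over the updated cache.
out = flash_attn_with_kvcache(
    q_new, k_cache, v_cache,
    k=k_new, v=v_new,
    cache_seqlens=cache_seqlens,
    causal=True,
)
print(out.shape)  # torch.Size([2, 1, 32, 128])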