Update vision.py
#9
by
Keerthi23
- opened
vision.py
CHANGED
|
@@ -34,10 +34,16 @@ from .configuration_vmistral import VMistralVisionConfig
|
|
| 34 |
|
| 35 |
logger = logging.get_logger(__name__)
|
| 36 |
|
| 37 |
-
|
| 38 |
if is_flash_attn_2_available():
|
| 39 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 40 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
|
| 43 |
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
|
|
|
| 34 |
|
| 35 |
logger = logging.get_logger(__name__)
|
| 36 |
|
| 37 |
+
# is_flash_attn_2_available() checks if flash_attn is available, if yes, it imports it
|
| 38 |
if is_flash_attn_2_available():
|
| 39 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 40 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
| 41 |
+
# if flash_attn is not available, it raises an ImportError, intimating the user to install flash_attn
|
| 42 |
+
else:
|
| 43 |
+
raise ImportError(
|
| 44 |
+
"Flash Attention 2.0 is not available. Please install flash-attn>=2.1.0: "
|
| 45 |
+
"`pip install flash-attn>=2.1.0` or use an alternative attention implementation."
|
| 46 |
+
)
|
| 47 |
|
| 48 |
|
| 49 |
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|