Files changed (1)
  1. vision.py +7 -1
vision.py CHANGED
@@ -34,10 +34,16 @@ from .configuration_vmistral import VMistralVisionConfig
 
 logger = logging.get_logger(__name__)
 
-
+# is_flash_attn_2_available() checks whether flash_attn is installed; if so, its functions are imported
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+# if flash_attn is not available, raise an ImportError prompting the user to install flash_attn
+else:
+    raise ImportError(
+        "Flash Attention 2.0 is not available. Please install flash-attn>=2.1.0: "
+        "`pip install flash-attn>=2.1.0` or use an alternative attention implementation."
+    )
 
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
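For reference, below is a minimal, self-contained sketch of the guard this change introduces, outside the diff context. One assumption: is_flash_attn_2_available is imported from transformers.utils (the hunk does not show where vision.py gets it from); the error message mirrors the one added above.

# Sketch of the availability guard added in this PR (assumption: is_flash_attn_2_available
# is the helper provided by transformers.utils).
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    # flash_attn is installed: import the fused attention kernels
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
else:
    # flash_attn is missing: fail at import time with an actionable message
    # rather than a NameError when the kernels are first used
    raise ImportError(
        "Flash Attention 2.0 is not available. Please install flash-attn>=2.1.0: "
        "`pip install flash-attn>=2.1.0` or use an alternative attention implementation."
    )

With this guard in place, importing vision.py on a machine without flash-attn fails immediately with installation instructions instead of erroring later inside the attention forward pass.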