Update modeling_openelm.py
Browse files- modeling_openelm.py +3 -1
modeling_openelm.py
CHANGED
|
@@ -779,7 +779,9 @@ class OpenELMModel(OpenELMPreTrainedModel):
|
|
| 779 |
:, None, None, :
|
| 780 |
].eq(0.0)
|
| 781 |
causal_mask = causal_mask.clone()
|
| 782 |
-
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
|
|
|
|
|
|
|
| 783 |
#causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
|
| 784 |
# padding_mask, min_dtype
|
| 785 |
#)
|
|
|
|
| 779 |
:, None, None, :
|
| 780 |
].eq(0.0)
|
| 781 |
causal_mask = causal_mask.clone()
|
| 782 |
+
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
|
| 783 |
+
padding_mask, min_dtype
|
| 784 |
+
)
|
| 785 |
#causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
|
| 786 |
# padding_mask, min_dtype
|
| 787 |
#)
|