Upload modeling_esm_plusplus.py with huggingface_hub
Browse files- modeling_esm_plusplus.py +2 -3
modeling_esm_plusplus.py
CHANGED
|
@@ -316,12 +316,11 @@ class MultiHeadAttention(nn.Module):
|
|
| 316 |
query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
|
| 317 |
|
| 318 |
if output_attentions: # Manual attention computation
|
| 319 |
-
b,
|
| 320 |
scale = 1 / math.sqrt(d)
|
| 321 |
-
attn_bias = torch.zeros(b,
|
| 322 |
if attention_mask is not None:
|
| 323 |
attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
|
| 324 |
-
|
| 325 |
attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
|
| 326 |
attn_weights += attn_bias
|
| 327 |
attn_weights = F.softmax(attn_weights, dim=-1)
|
|
|
|
| 316 |
query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
|
| 317 |
|
| 318 |
if output_attentions: # Manual attention computation
|
| 319 |
+
b, h, l, d = query_BHLD.shape
|
| 320 |
scale = 1 / math.sqrt(d)
|
| 321 |
+
attn_bias = torch.zeros(b, h, l, l, dtype=query_BLD.dtype, device=query_BLD.device)
|
| 322 |
if attention_mask is not None:
|
| 323 |
attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
|
|
|
|
| 324 |
attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
|
| 325 |
attn_weights += attn_bias
|
| 326 |
attn_weights = F.softmax(attn_weights, dim=-1)
|