Spaces:
Runtime error
Runtime error
Antoni Bigata
commited on
Commit
Β·
1982e66
1
Parent(s):
d9970e1
requirements
Browse files- WavLM_modules.py +1 -3
WavLM_modules.py
CHANGED
|
@@ -450,9 +450,7 @@ class MultiheadAttention(nn.Module):
|
|
| 450 |
relative_position_bucket = self._relative_positions_bucket(
|
| 451 |
relative_position, bidirectional=True
|
| 452 |
)
|
| 453 |
-
|
| 454 |
-
# self.relative_attention_bias.weight.device
|
| 455 |
-
# )
|
| 456 |
values = self.relative_attention_bias(relative_position_bucket)
|
| 457 |
values = values.permute([2, 0, 1])
|
| 458 |
return values
|
|
|
|
| 450 |
relative_position_bucket = self._relative_positions_bucket(
|
| 451 |
relative_position, bidirectional=True
|
| 452 |
)
|
| 453 |
+
relative_position_bucket.cuda()
|
|
|
|
|
|
|
| 454 |
values = self.relative_attention_bias(relative_position_bucket)
|
| 455 |
values = values.permute([2, 0, 1])
|
| 456 |
return values
|