Update modeling_super_linear.py
Browse files- modeling_super_linear.py +8 -3
modeling_super_linear.py
CHANGED
|
@@ -448,9 +448,14 @@ class superLinear(nn.Module):
|
|
| 448 |
|
| 449 |
|
| 450 |
def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None, freq=[None], get_prob=False):
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
x = x.reshape(B * V, L)
|
| 455 |
|
| 456 |
expert_probs = None
|
|
|
|
| 448 |
|
| 449 |
|
| 450 |
def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None, freq=[None], get_prob=False):
|
| 451 |
+
if len(x_enc.shape) > 2:
|
| 452 |
+
x = x_enc.permute(0, 2, 1)
|
| 453 |
+
B, V, L = x.shape
|
| 454 |
+
else:
|
| 455 |
+
x = x_enc
|
| 456 |
+
B, L = x.shape
|
| 457 |
+
V = 1
|
| 458 |
+
|
| 459 |
x = x.reshape(B * V, L)
|
| 460 |
|
| 461 |
expert_probs = None
|