Update modeling_super_linear.py
Browse files- modeling_super_linear.py +4 -3
modeling_super_linear.py
CHANGED
|
@@ -476,13 +476,14 @@ class superLinear(nn.Module):
|
|
| 476 |
ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.seq_len:]
|
| 477 |
out = torch.cat(outputs, dim=1)[:, :self.inf_pred_len]
|
| 478 |
|
| 479 |
-
print(F"out1 :{out.shape}")
|
| 480 |
if len(x_enc.shape) > 2:
|
| 481 |
-
result = out
|
| 482 |
-
else:
|
| 483 |
out = out.reshape(B, V, out.shape[-1])
|
| 484 |
print(F"out2 :{out.shape}")
|
| 485 |
result = out.permute(0, 2, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 486 |
|
| 487 |
if get_prob:
|
| 488 |
expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
|
|
|
|
| 476 |
ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.seq_len:]
|
| 477 |
out = torch.cat(outputs, dim=1)[:, :self.inf_pred_len]
|
| 478 |
|
|
|
|
| 479 |
if len(x_enc.shape) > 2:
|
|
|
|
|
|
|
| 480 |
out = out.reshape(B, V, out.shape[-1])
|
| 481 |
print(F"out2 :{out.shape}")
|
| 482 |
result = out.permute(0, 2, 1)
|
| 483 |
+
else:
|
| 484 |
+
print(F"out1 :{out.shape}")
|
| 485 |
+
result = out
|
| 486 |
+
|
| 487 |
|
| 488 |
if get_prob:
|
| 489 |
expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
|