Update modeling_super_linear.py
Browse files- modeling_super_linear.py +1 -6
modeling_super_linear.py
CHANGED
|
@@ -457,9 +457,6 @@ class superLinear(nn.Module):
|
|
| 457 |
V = 1
|
| 458 |
|
| 459 |
x = x.reshape(B * V, L)
|
| 460 |
-
print("RAZ")
|
| 461 |
-
print(x.shape)
|
| 462 |
-
|
| 463 |
expert_probs = None
|
| 464 |
|
| 465 |
if get_prob:
|
|
@@ -468,6 +465,7 @@ class superLinear(nn.Module):
|
|
| 468 |
out, self.moe_loss = self.moe(x)
|
| 469 |
|
| 470 |
if self.auto_regressive and self.max_horizon < self.inf_pred_len:
|
|
|
|
| 471 |
outputs = [out]
|
| 472 |
ar_x = torch.cat([x, out], dim=1)[:, -self.seq_len:]
|
| 473 |
for i in range(0, self.inf_pred_len, self.max_horizon):
|
|
@@ -478,13 +476,10 @@ class superLinear(nn.Module):
|
|
| 478 |
|
| 479 |
if len(x_enc.shape) > 2:
|
| 480 |
out = out.reshape(B, V, out.shape[-1])
|
| 481 |
-
print(F"out2 :{out.shape}")
|
| 482 |
result = out.permute(0, 2, 1)
|
| 483 |
else:
|
| 484 |
-
print(F"out1 :{out.shape}")
|
| 485 |
result = out
|
| 486 |
|
| 487 |
-
|
| 488 |
if get_prob:
|
| 489 |
expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
|
| 490 |
return result, expert_probs
|
|
|
|
| 457 |
V = 1
|
| 458 |
|
| 459 |
x = x.reshape(B * V, L)
|
|
|
|
|
|
|
|
|
|
| 460 |
expert_probs = None
|
| 461 |
|
| 462 |
if get_prob:
|
|
|
|
| 465 |
out, self.moe_loss = self.moe(x)
|
| 466 |
|
| 467 |
if self.auto_regressive and self.max_horizon < self.inf_pred_len:
|
| 468 |
+
print("bitch")
|
| 469 |
outputs = [out]
|
| 470 |
ar_x = torch.cat([x, out], dim=1)[:, -self.seq_len:]
|
| 471 |
for i in range(0, self.inf_pred_len, self.max_horizon):
|
|
|
|
| 476 |
|
| 477 |
if len(x_enc.shape) > 2:
|
| 478 |
out = out.reshape(B, V, out.shape[-1])
|
|
|
|
| 479 |
result = out.permute(0, 2, 1)
|
| 480 |
else:
|
|
|
|
| 481 |
result = out
|
| 482 |
|
|
|
|
| 483 |
if get_prob:
|
| 484 |
expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
|
| 485 |
return result, expert_probs
|