razmars commited on
Commit
391b2c1
·
verified ·
1 Parent(s): a63396c

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +1 -6
modeling_super_linear.py CHANGED
@@ -457,9 +457,6 @@ class superLinear(nn.Module):
457
  V = 1
458
 
459
  x = x.reshape(B * V, L)
460
- print("RAZ")
461
- print(x.shape)
462
-
463
  expert_probs = None
464
 
465
  if get_prob:
@@ -468,6 +465,7 @@ class superLinear(nn.Module):
468
  out, self.moe_loss = self.moe(x)
469
 
470
  if self.auto_regressive and self.max_horizon < self.inf_pred_len:
 
471
  outputs = [out]
472
  ar_x = torch.cat([x, out], dim=1)[:, -self.seq_len:]
473
  for i in range(0, self.inf_pred_len, self.max_horizon):
@@ -478,13 +476,10 @@ class superLinear(nn.Module):
478
 
479
  if len(x_enc.shape) > 2:
480
  out = out.reshape(B, V, out.shape[-1])
481
- print(F"out2 :{out.shape}")
482
  result = out.permute(0, 2, 1)
483
  else:
484
- print(F"out1 :{out.shape}")
485
  result = out
486
 
487
-
488
  if get_prob:
489
  expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
490
  return result, expert_probs
 
457
  V = 1
458
 
459
  x = x.reshape(B * V, L)
 
 
 
460
  expert_probs = None
461
 
462
  if get_prob:
 
465
  out, self.moe_loss = self.moe(x)
466
 
467
  if self.auto_regressive and self.max_horizon < self.inf_pred_len:
468
+ print("bitch")
469
  outputs = [out]
470
  ar_x = torch.cat([x, out], dim=1)[:, -self.seq_len:]
471
  for i in range(0, self.inf_pred_len, self.max_horizon):
 
476
 
477
  if len(x_enc.shape) > 2:
478
  out = out.reshape(B, V, out.shape[-1])
 
479
  result = out.permute(0, 2, 1)
480
  else:
 
481
  result = out
482
 
 
483
  if get_prob:
484
  expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
485
  return result, expert_probs