razmars commited on
Commit
302f34b
·
verified ·
1 Parent(s): aa6df5b

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +7 -3
modeling_super_linear.py CHANGED
@@ -475,10 +475,14 @@ class superLinear(nn.Module):
475
  outputs.append(ar_out)
476
  ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.seq_len:]
477
  out = torch.cat(outputs, dim=1)[:, :self.inf_pred_len]
 
478
  print(F"out1 :{out.shape}")
479
- out = out.reshape(B, V, out.shape[-1])
480
- print(F"out2 :{out.shape}")
481
- result = out.permute(0, 2, 1)
 
 
 
482
 
483
  if get_prob:
484
  expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
 
475
  outputs.append(ar_out)
476
  ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.seq_len:]
477
  out = torch.cat(outputs, dim=1)[:, :self.inf_pred_len]
478
+
479
  print(F"out1 :{out.shape}")
480
+ if len(x_enc.shape) > 2:
481
+ result = out
482
+ else:
483
+ out = out.reshape(B, V, out.shape[-1])
484
+ print(F"out2 :{out.shape}")
485
+ result = out.permute(0, 2, 1)
486
 
487
  if get_prob:
488
  expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])