razmars commited on
Commit
934cf7d
·
verified ·
1 Parent(s): 7ee3e50

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +8 -3
modeling_super_linear.py CHANGED
@@ -448,9 +448,14 @@ class superLinear(nn.Module):
448
 
449
 
450
  def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None, freq=[None], get_prob=False):
451
- print(x_enc.shape)
452
- x = x_enc.permute(0, 2, 1)
453
- B, V, L = x.shape
 
 
 
 
 
454
  x = x.reshape(B * V, L)
455
 
456
  expert_probs = None
 
448
 
449
 
450
  def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None, freq=[None], get_prob=False):
451
+ if len(x_enc.shape) > 2:
452
+ x = x_enc.permute(0, 2, 1)
453
+ B, V, L = x.shape
454
+ else:
455
+ x = x_enc
456
+ B, L = x.shape
457
+ V = 1
458
+
459
  x = x.reshape(B * V, L)
460
 
461
  expert_probs = None