lirannoc commited on
Commit
5aef5e6
·
verified ·
1 Parent(s): 41441d8

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +2 -2
modeling_super_linear.py CHANGED
@@ -422,11 +422,11 @@ class Model(nn.Module):
422
 
423
  if self.train_pred_len < pred_len:
424
  outputs = [out]
425
- ar_x = torch.cat([x, out], dim=1)[:, -self.seq_len:]
426
  for i in range(0, inf_pred_len, self.train_pred_len):
427
  ar_out, _ = self.moe(ar_x)
428
  outputs.append(ar_out)
429
- ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.seq_len:]
430
  out = torch.cat(outputs, dim=1)[:, :pred_len]
431
 
432
  # Reshape back
 
422
 
423
  if self.train_pred_len < pred_len:
424
  outputs = [out]
425
+ ar_x = torch.cat([x, out], dim=1)[:, -self.train_seq_len:]
426
  for i in range(0, inf_pred_len, self.train_pred_len):
427
  ar_out, _ = self.moe(ar_x)
428
  outputs.append(ar_out)
429
+ ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.train_seq_len:]
430
  out = torch.cat(outputs, dim=1)[:, :pred_len]
431
 
432
  # Reshape back