razmars commited on
Commit
2d77aa9
·
verified ·
1 Parent(s): 53d6bba

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +18 -2
modeling_super_linear.py CHANGED
@@ -214,6 +214,22 @@ class RLinear(nn.Module):
214
  def forward(self, x):
215
  # x: [Batch, Input length,Channel]
216
  x_shape = x.shape
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  if len(x_shape) == 2:
218
  x = x.unsqueeze(-1)
219
 
@@ -575,8 +591,8 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
575
  # backbone expects (B, C, L)
576
  x_enc = inputs_embeds
577
 
578
- if x_enc.shape[1] < 512:
579
- x_enc = self.fourier_interp_dim1(x_enc)
580
 
581
 
582
  # backbone returns (B, pred_len, C)
 
214
  def forward(self, x):
215
  # x: [Batch, Input length,Channel]
216
  x_shape = x.shape
217
+ if x.shape[1] < self.seq_len:
218
+ if self.zero_shot_Linear is None:
219
+ print(F"new Lookkback : {x.shape[1]}")
220
+ self.transform_model(x.shape[1])
221
+
222
+ if len(x_shape) == 2:
223
+ x = x.unsqueeze(-1)
224
+ x = x.clone()
225
+ x = self.revin_layer(x, 'norm')
226
+ x = F.linear(x.permute(0,2,1), self.zero_shot_Linear).permute(0,2,1).clone()
227
+ x = self.revin_layer(x, 'denorm')
228
+ if len(x_shape) == 2:
229
+ x = x.squeeze(-1)
230
+ return x
231
+
232
+
233
  if len(x_shape) == 2:
234
  x = x.unsqueeze(-1)
235
 
 
591
  # backbone expects (B, C, L)
592
  x_enc = inputs_embeds
593
 
594
+ '''if x_enc.shape[1] < 512:
595
+ x_enc = self.fourier_interp_dim1(x_enc)'''
596
 
597
 
598
  # backbone returns (B, pred_len, C)