razmars commited on
Commit
266c1c3
·
verified ·
1 Parent(s): 203f7df

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +12 -17
modeling_super_linear.py CHANGED
@@ -200,7 +200,7 @@ class RLinear(nn.Module):
200
  self.revin_layer = RevIN(num_features = None, affine=False, norm_type = None, subtract_last = False)
201
  self.zero_shot_Linear = None
202
 
203
- def transform_model(self,new_lookback,mode):
204
  if mode == 1:
205
  W = self.Linear.weight.detach()
206
  new_W = W[:, -new_lookback:]
@@ -226,25 +226,19 @@ class RLinear(nn.Module):
226
  else:
227
  W = self.Linear.weight.detach() # (out_features, seq_len)
228
 
229
- # 1️⃣ keep the last `new_lookback` columns
230
- W_tail = W[:, -new_lookback:] # (out_features, new_lookback)
231
-
232
- # 2️⃣ resize those columns back to `seq_len`
233
- # ─ use 4-D shape (N=1, C=1, H=out, W=new_lookback)
234
- W4d = W_tail.unsqueeze(0).unsqueeze(0)
235
-
236
  new_W = F.interpolate(
237
  W4d,
238
- size=(new_lookback, self.seq_len), # keep rows = out_features,
239
- # stretch cols to seq_len
240
  mode='bilinear',
241
  align_corners=False
242
- )[0, 0] # (out_features, seq_len)
243
-
244
- # 3️⃣ concatenate on the column axis
245
- W_now = torch.cat((W_tail, new_W), dim=1) # (out_features, new_lookback + seq_len)
246
-
247
- self.zero_shot_Linear = W_now
248
 
249
 
250
 
@@ -254,7 +248,8 @@ class RLinear(nn.Module):
254
  if x.shape[1] < self.seq_len:
255
  #if self.zero_shot_Linear is None:
256
  #print(F"new Lookkback : {x.shape[1]}")
257
- self.transform_model(x.shape[1],3)
 
258
 
259
  x = x.clone()
260
  #x = x * (x.shape[1]/512)
 
200
  self.revin_layer = RevIN(num_features = None, affine=False, norm_type = None, subtract_last = False)
201
  self.zero_shot_Linear = None
202
 
203
+ def transform_model(self,x,new_lookback,mode):
204
  if mode == 1:
205
  W = self.Linear.weight.detach()
206
  new_W = W[:, -new_lookback:]
 
226
  else:
227
  W = self.Linear.weight.detach() # (out_features, seq_len)
228
 
229
+ W4d = W.unsqueeze(0).unsqueeze(0) # (1, 1, out, in)
230
+
231
+ # resize H → self.horizon and W → new_lookback
 
 
 
 
232
  new_W = F.interpolate(
233
  W4d,
234
+ size=(self.horizon, new_lookback), # (H_out, W_out)
 
235
  mode='bilinear',
236
  align_corners=False
237
+ )[0, 0] # drop the two singleton dims
238
+
239
+ x = F.linear(x, new_W)
240
+ self.zero_shot_Linear = W
241
+ return x
 
242
 
243
 
244
 
 
248
  if x.shape[1] < self.seq_len:
249
  #if self.zero_shot_Linear is None:
250
  #print(F"new Lookkback : {x.shape[1]}")
251
+
252
+ x = self.transform_model(x,x.shape[1],3)
253
 
254
  x = x.clone()
255
  #x = x * (x.shape[1]/512)