razmars commited on
Commit
399c990
·
verified ·
1 Parent(s): 25a4f7f

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +16 -10
modeling_super_linear.py CHANGED
@@ -224,21 +224,27 @@ class RLinear(nn.Module):
224
 
225
  self.zero_shot_Linear = new_W # shape (self.horizon, new_lookback)
226
  else:
227
- W = self.Linear.weight.detach()
228
- W = W[:, -new_lookback:]
229
- W4d = W.unsqueeze(0).unsqueeze(0) # (1, 1, out, in)
230
-
231
- # resize H → self.horizon and W → new_lookback
 
 
 
 
232
  new_W = F.interpolate(
233
  W4d,
234
- size=( new_lookback,self.seq_len), # (H_out, W_out)
 
235
  mode='bilinear',
236
  align_corners=False
237
- )[0, 0] # drop the two singleton dims
238
-
239
- W_now = torch.cat((W, new_W), dim=1)
240
- self.zero_shot_Linear = new_W
241
 
 
242
 
243
 
244
 
 
224
 
225
  self.zero_shot_Linear = new_W # shape (self.horizon, new_lookback)
226
  else:
227
+ W = self.Linear.weight.detach() # (out_features, seq_len)
228
+
229
+ # 1️⃣ keep the last `new_lookback` columns
230
+ W_tail = W[:, -new_lookback:] # (out_features, new_lookback)
231
+
232
+ # 2️⃣ resize those columns back to `seq_len`
233
+ # ─ use 4-D shape (N=1, C=1, H=out, W=new_lookback)
234
+ W4d = W_tail.unsqueeze(0).unsqueeze(0)
235
+
236
  new_W = F.interpolate(
237
  W4d,
238
+ size=(out_features, self.seq_len), # keep rows = out_features,
239
+ # stretch cols to seq_len
240
  mode='bilinear',
241
  align_corners=False
242
+ )[0, 0] # (out_features, seq_len)
243
+
244
+ # 3️⃣ concatenate on the column axis
245
+ W_now = torch.cat((W_tail, new_W), dim=1) # (out_features, new_lookback + seq_len)
246
 
247
+ self.zero_shot_Linear = W_now
248
 
249
 
250