razmars commited on
Commit
a7dc532
·
verified ·
1 Parent(s): 38d596e

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +11 -3
modeling_super_linear.py CHANGED
@@ -211,9 +211,17 @@ class RLinear(nn.Module):
211
 
212
  self.zero_shot_Linear = new_W
213
  else:
214
- W = self.Linear.weight.detach()
215
- new_W = F.interpolate(W.unsqueeze(-1), size=(new_lookback, self.horizon ), mode='bilinear', align_corners=False).squeeze(-1)
216
- self.zero_shot_Linear = new_W
 
 
 
 
 
 
 
 
217
 
218
 
219
 
 
211
 
212
  self.zero_shot_Linear = new_W
213
  else:
214
+ W4d = W.unsqueeze(0).unsqueeze(0) # (1, 1, out, in)
215
+
216
+ # resize H → self.horizon and W → new_lookback
217
+ new_W = F.interpolate(
218
+ W4d,
219
+ size=(self.horizon, new_lookback), # (H_out, W_out)
220
+ mode='bilinear',
221
+ align_corners=False
222
+ )[0, 0] # drop the two singleton dims
223
+
224
+ self.zero_shot_Linear = new_W # shape (self.horizon, new_lookback)
225
 
226
 
227