Update modeling_super_linear.py
Browse files- modeling_super_linear.py +16 -10
modeling_super_linear.py
CHANGED
|
@@ -224,21 +224,27 @@ class RLinear(nn.Module):
|
|
| 224 |
|
| 225 |
self.zero_shot_Linear = new_W # shape (self.horizon, new_lookback)
|
| 226 |
else:
|
| 227 |
-
W
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
new_W = F.interpolate(
|
| 233 |
W4d,
|
| 234 |
-
size=(
|
|
|
|
| 235 |
mode='bilinear',
|
| 236 |
align_corners=False
|
| 237 |
-
)[0, 0]
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
|
|
|
|
| 242 |
|
| 243 |
|
| 244 |
|
|
|
|
| 224 |
|
| 225 |
self.zero_shot_Linear = new_W # shape (self.horizon, new_lookback)
|
| 226 |
else:
|
| 227 |
+
W = self.Linear.weight.detach() # (out_features, seq_len)
|
| 228 |
+
|
| 229 |
+
# 1️⃣ keep the last `new_lookback` columns
|
| 230 |
+
W_tail = W[:, -new_lookback:] # (out_features, new_lookback)
|
| 231 |
+
|
| 232 |
+
# 2️⃣ resize those columns back to `seq_len`
|
| 233 |
+
# ─ use 4-D shape (N=1, C=1, H=out, W=new_lookback)
|
| 234 |
+
W4d = W_tail.unsqueeze(0).unsqueeze(0)
|
| 235 |
+
|
| 236 |
new_W = F.interpolate(
|
| 237 |
W4d,
|
| 238 |
+
size=(out_features, self.seq_len), # keep rows = out_features,
|
| 239 |
+
# stretch cols to seq_len
|
| 240 |
mode='bilinear',
|
| 241 |
align_corners=False
|
| 242 |
+
)[0, 0] # → (out_features, seq_len)
|
| 243 |
+
|
| 244 |
+
# 3️⃣ concatenate on the column axis
|
| 245 |
+
W_now = torch.cat((W_tail, new_W), dim=1) # (out_features, new_lookback + seq_len)
|
| 246 |
|
| 247 |
+
self.zero_shot_Linear = W_now
|
| 248 |
|
| 249 |
|
| 250 |
|