Update modeling_super_linear.py
Browse files- modeling_super_linear.py +11 -3
modeling_super_linear.py
CHANGED
|
@@ -211,9 +211,17 @@ class RLinear(nn.Module):
|
|
| 211 |
|
| 212 |
self.zero_shot_Linear = new_W
|
| 213 |
else:
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
|
| 219 |
|
|
|
|
| 211 |
|
| 212 |
self.zero_shot_Linear = new_W
|
| 213 |
else:
|
| 214 |
+
W4d = W.unsqueeze(0).unsqueeze(0) # (1, 1, out, in)
|
| 215 |
+
|
| 216 |
+
# resize H → self.horizon and W → new_lookback
|
| 217 |
+
new_W = F.interpolate(
|
| 218 |
+
W4d,
|
| 219 |
+
size=(self.horizon, new_lookback), # (H_out, W_out)
|
| 220 |
+
mode='bilinear',
|
| 221 |
+
align_corners=False
|
| 222 |
+
)[0, 0] # drop the two singleton dims
|
| 223 |
+
|
| 224 |
+
self.zero_shot_Linear = new_W # shape (self.horizon, new_lookback)
|
| 225 |
|
| 226 |
|
| 227 |
|