razmars committed on
Commit
fe49d07
·
verified ·
1 Parent(s): b95c453

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +68 -58
modeling_super_linear.py CHANGED
@@ -192,71 +192,81 @@ class NLinear(nn.Module):
192
 
193
 
194
  class RLinear(nn.Module):
195
- def __init__(self, input_len, output_len):
196
- super(RLinear, self).__init__()
197
- self.Linear = nn.Linear(input_len, output_len)
198
- self.seq_len = input_len
199
- self.horizon = output_len
200
- self.revin_layer = RevIN(num_features = None, affine=False, norm_type = None, subtract_last = False)
201
- self.zero_shot_Linear = None
202
-
203
- def transform_model(self,new_lookback,mode):
204
- if mode == 1:
205
- W = self.Linear.weight.detach()
206
- new_W = W[:, -new_lookback:]
207
- original_norm = torch.norm(W, p=2)
208
- new_norm = torch.norm(new_W, p=2)
209
- final_scaling = original_norm / new_norm if new_norm.item() != 0 else 1.0
210
- new_W = new_W * final_scaling
211
-
212
- self.zero_shot_Linear = new_W
213
- else:
214
- W = self.Linear.weight.detach()
215
- W4d = W.unsqueeze(0).unsqueeze(0) # (1, 1, out, in)
216
-
217
- # resize H → self.horizon and W → new_lookback
218
- new_W = F.interpolate(
219
- W4d,
220
- size=(self.horizon, new_lookback), # (H_out, W_out)
221
- mode='bilinear',
222
- align_corners=False
223
- )[0, 0] # drop the two singleton dims
224
-
225
- self.zero_shot_Linear = new_W # shape (self.horizon, new_lookback)
226
-
 
 
 
 
 
 
 
 
227
 
 
228
 
229
-
 
230
 
231
- def forward(self, x):
232
- # x: [Batch, Input length,Channel]
233
- x_shape = x.shape
234
- if x.shape[1] < self.seq_len:
235
- #if self.zero_shot_Linear is None:
236
- #print(F"new Lookkback : {x.shape[1]}")
237
-
238
- self.transform_model(x.shape[1],3)
239
 
240
- x = x.clone()
241
- #x = x * (x.shape[1]/512)
242
- x = self.revin_layer(x, 'norm')
243
- x = F.linear(x, self.zero_shot_Linear)
244
- x = self.revin_layer(x, 'denorm')
245
- #x = x * (512/x.shape[1])
246
- return x
247
 
248
-
249
- if len(x_shape) == 2:
250
- x = x.unsqueeze(-1)
251
 
252
- x = x.clone()
253
- x = self.revin_layer(x, 'norm')
254
- x = self.Linear(x.permute(0,2,1)).permute(0,2,1).clone()
255
- x = self.revin_layer(x, 'denorm')
 
 
256
 
257
- if len(x_shape) == 2:
258
- x = x.squeeze(-1)
259
- return x # to [Batch, Output length, Channel]
 
 
 
 
260
 
261
 
262
  "-------------------------------------------------------------------------------------------------------------------"
 
192
 
193
 
194
  class RLinear(nn.Module):
195
+ """
196
+ Resizable linear projection from variable input length L to fixed horizon.
197
+ Each channel is projected independently (no mixing across channels).
198
+ """
199
+ def __init__(self, input_len: int, output_len: int):
200
+ super().__init__()
201
+ self.seq_len = input_len # “design” length
202
+ self.horizon = output_len
203
+ # plain weight + bias – we will *interpolate* weight when L ≠ seq_len
204
+ self.linear = nn.Linear(input_len, output_len, bias=True)
205
+
206
+ # RevIN your own implementation; keep it stateless for safety
207
+ self.revin = RevIN(num_features=None, affine=False,
208
+ norm_type=None, subtract_last=False)
209
+
210
+ def _resize_weight(self, weight: torch.Tensor, new_in: int) -> torch.Tensor:
211
+ """
212
+ Bilinearly interpolate the *columns* of `weight` so that
213
+ weight.shape becomes (horizon, new_in).
214
+ """
215
+ if new_in == weight.shape[1]:
216
+ return weight # no resizing needed
217
+
218
+ # weight: (out, in) → (1,1,out,in) so we can use interpolate
219
+ w4d = weight.unsqueeze(0).unsqueeze(0)
220
+ w_resized = F.interpolate(
221
+ w4d, size=(self.horizon, new_in), mode="bilinear",
222
+ align_corners=False
223
+ )[0, 0] # back to (out, new_in)
224
+ return w_resized
225
+
226
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
227
+ """
228
+ x: (B, L, C) or (B, L) ⇒ (B, horizon, C) or (B, horizon)
229
+ """
230
+ # make sure x is 3-D
231
+ squeeze_last_dim = False
232
+ if x.dim() == 2: # (B,L)
233
+ x = x.unsqueeze(-1) # (B,L,1)
234
+ squeeze_last_dim = True
235
 
236
+ B, L, C = x.shape
237
 
238
+ # RevIN normalisation (over time dimension)
239
+ x = self.revin(x, "norm")
240
 
241
+ if L == self.seq_len:
242
+ # fast path built-in linear works
243
+ # reshape so that each channel is treated independently
244
+ x = x.permute(0, 2, 1) # (B,C,L)
245
+ x = self.linear(x) # (B,C,horizon)
246
+ x = x.permute(0, 2, 1) # (B,horizon,C)
 
 
247
 
248
+ else:
249
+ # slow path resize the weight to match L
250
+ # freeze current weight & bias
251
+ W = self.linear.weight.detach() # (out,in)
252
+ b = self.linear.bias.detach() # (out)
 
 
253
 
254
+ W_resized = self._resize_weight(W, L) # (out,L)
 
 
255
 
256
+ # apply per channel
257
+ # x: (B,L,C) → (B,C,L) so that last dim is L
258
+ x = x.permute(0, 2, 1)
259
+ # out = x @ W_resized.T + b
260
+ out = torch.einsum("bcl,ol->bco", x, W_resized) + b
261
+ x = out.permute(0, 2, 1) # back to (B,horizon,C)
262
 
263
+ # RevIN denorm
264
+ x = self.revin(x, "denorm")
265
+
266
+ if squeeze_last_dim:
267
+ x = x.squeeze(-1) # (B,horizon)
268
+
269
+ return x
270
 
271
 
272
  "-------------------------------------------------------------------------------------------------------------------"