razmars committed on
Commit
762a522
·
verified ·
1 Parent(s): fe49d07

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +31 -52
modeling_super_linear.py CHANGED
@@ -193,82 +193,61 @@ class NLinear(nn.Module):
193
 
194
  class RLinear(nn.Module):
195
  """
196
- Resizable linear projection from variable input length L to fixed horizon.
197
- Each channel is projected independently (no mixing across channels).
198
  """
199
  def __init__(self, input_len: int, output_len: int):
200
  super().__init__()
201
- self.seq_len = input_len # “design” length
202
- self.horizon = output_len
203
- # plain weight + bias – we will *interpolate* weight when L ≠ seq_len
204
- self.linear = nn.Linear(input_len, output_len, bias=True)
 
205
 
206
- # RevIN — your own implementation; keep it stateless for safety
207
  self.revin = RevIN(num_features=None, affine=False,
208
  norm_type=None, subtract_last=False)
209
 
210
- def _resize_weight(self, weight: torch.Tensor, new_in: int) -> torch.Tensor:
211
- """
212
- Bilinearly interpolate the *columns* of `weight` so that
213
- weight.shape becomes (horizon, new_in).
214
- """
215
  if new_in == weight.shape[1]:
216
- return weight # no resizing needed
217
-
218
- # weight: (out, in) → (1,1,out,in) so we can use interpolate
219
- w4d = weight.unsqueeze(0).unsqueeze(0)
220
  w_resized = F.interpolate(
221
- w4d, size=(self.horizon, new_in), mode="bilinear",
222
  align_corners=False
223
- )[0, 0] # back to (out, new_in)
224
  return w_resized
225
 
226
  def forward(self, x: torch.Tensor) -> torch.Tensor:
227
  """
228
- x: (B, L, C) or (B, L) ⇒ (B, horizon, C) or (B, horizon)
229
  """
230
- # make sure x is 3-D
231
- squeeze_last_dim = False
232
- if x.dim() == 2: # (B,L)
233
- x = x.unsqueeze(-1) # (B,L,1)
234
- squeeze_last_dim = True
235
 
236
  B, L, C = x.shape
237
-
238
- # RevIN normalisation (over time dimension)
239
  x = self.revin(x, "norm")
240
 
241
- if L == self.seq_len:
242
- # fast path — built-in linear works
243
- # reshape so that each channel is treated independently
244
- x = x.permute(0, 2, 1) # (B,C,L)
245
- x = self.linear(x) # (B,C,horizon)
246
- x = x.permute(0, 2, 1) # (B,horizon,C)
247
 
248
- else:
249
- # slow path — resize the weight to match L
250
- # freeze current weight & bias
251
- W = self.linear.weight.detach() # (out,in)
252
- b = self.linear.bias.detach() # (out)
253
 
254
- W_resized = self._resize_weight(W, L) # (out,L)
255
-
256
- # apply per channel
257
- # x: (B,L,C) → (B,C,L) so that last dim is L
258
- x = x.permute(0, 2, 1)
259
- # out = x @ W_resized.T + b
260
- out = torch.einsum("bcl,ol->bco", x, W_resized) + b
261
- x = out.permute(0, 2, 1) # back to (B,horizon,C)
262
-
263
- # RevIN denorm
264
  x = self.revin(x, "denorm")
265
-
266
- if squeeze_last_dim:
267
- x = x.squeeze(-1) # (B,horizon)
268
-
269
  return x
270
 
271
-
272
  "-------------------------------------------------------------------------------------------------------------------"
273
  class SparseNoisyMoE(nn.Module):
274
  def __init__(self, configs, experts=None):
 
193
 
194
  class RLinear(nn.Module):
195
  """
196
+ Resizable linear projection from variable input length L to fixed horizon,
197
+ applied independently to every channel, **without bias**.
198
  """
199
  def __init__(self, input_len: int, output_len: int):
200
  super().__init__()
201
+ self.seq_len = input_len
202
+ self.horizon = output_len
203
+
204
+ # ★ bias removed → bias=False
205
+ self.linear = nn.Linear(input_len, output_len)
206
 
 
207
  self.revin = RevIN(num_features=None, affine=False,
208
  norm_type=None, subtract_last=False)
209
 
210
+ @staticmethod
211
+ def _resize_weight(weight: torch.Tensor, new_in: int, horizon: int) -> torch.Tensor:
212
+ """Interpolate columns so weight becomes (horizon, new_in)."""
 
 
213
  if new_in == weight.shape[1]:
214
+ return weight
215
+ w4d = weight.unsqueeze(0).unsqueeze(0) # (1,1,out,in)
 
 
216
  w_resized = F.interpolate(
217
+ w4d, size=(horizon, new_in), mode="bilinear",
218
  align_corners=False
219
+ )[0, 0] # (out,new_in)
220
  return w_resized
221
 
222
  def forward(self, x: torch.Tensor) -> torch.Tensor:
223
  """
224
+ x: (B,L,C) or (B,L) ➜ (B,horizon,C) or (B,horizon)
225
  """
226
+ squeeze_last = False
227
+ if x.dim() == 2: # (B,L)
228
+ x = x.unsqueeze(-1) # (B,L,1)
229
+ squeeze_last = True
 
230
 
231
  B, L, C = x.shape
 
 
232
  x = self.revin(x, "norm")
233
 
234
+ if L == self.seq_len: # fast path
235
+ x = self.linear(x.permute(0, 2, 1)) # (B,C,horizon)
236
+ x = x.permute(0, 2, 1) # (B,horizon,C)
237
+ else: # resize path
238
+ W = self.linear.weight.detach() # (out,in)
239
+ W_resized = self._resize_weight(W, L, self.horizon)
240
 
241
+ # ★ bias removed → no "+ b"
242
+ x = x.permute(0, 2, 1) # (B,C,L)
243
+ x = torch.einsum("bcl,ol->bco", x, W_resized) # (B,C,out)
244
+ x = x.permute(0, 2, 1) # (B,horizon,C)
 
245
 
 
 
 
 
 
 
 
 
 
 
246
  x = self.revin(x, "denorm")
247
+ if squeeze_last:
248
+ x = x.squeeze(-1)
 
 
249
  return x
250
 
 
251
  "-------------------------------------------------------------------------------------------------------------------"
252
  class SparseNoisyMoE(nn.Module):
253
  def __init__(self, configs, experts=None):