razmars committed on
Commit
0b03843
·
verified ·
1 Parent(s): 762a522

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +41 -55
modeling_super_linear.py CHANGED
@@ -192,62 +192,48 @@ class NLinear(nn.Module):
192
 
193
 
194
  class RLinear(nn.Module):
195
- """
196
- Resizable linear projection from variable input length L to fixed horizon,
197
- applied independently to every channel, **without bias**.
198
- """
199
- def __init__(self, input_len: int, output_len: int):
200
- super().__init__()
201
- self.seq_len = input_len
202
- self.horizon = output_len
203
-
204
- # ★ bias removed → bias=False
205
- self.linear = nn.Linear(input_len, output_len)
206
-
207
- self.revin = RevIN(num_features=None, affine=False,
208
- norm_type=None, subtract_last=False)
209
-
210
- @staticmethod
211
- def _resize_weight(weight: torch.Tensor, new_in: int, horizon: int) -> torch.Tensor:
212
- """Interpolate columns so weight becomes (horizon, new_in)."""
213
- if new_in == weight.shape[1]:
214
- return weight
215
- w4d = weight.unsqueeze(0).unsqueeze(0) # (1,1,out,in)
216
- w_resized = F.interpolate(
217
- w4d, size=(horizon, new_in), mode="bilinear",
218
- align_corners=False
219
- )[0, 0] # (out,new_in)
220
- return w_resized
221
-
222
- def forward(self, x: torch.Tensor) -> torch.Tensor:
223
- """
224
- x: (B,L,C) or (B,L) ➜ (B,horizon,C) or (B,horizon)
225
- """
226
- squeeze_last = False
227
- if x.dim() == 2: # (B,L)
228
- x = x.unsqueeze(-1) # (B,L,1)
229
- squeeze_last = True
230
-
231
- B, L, C = x.shape
232
- x = self.revin(x, "norm")
233
-
234
- if L == self.seq_len: # fast path
235
- x = self.linear(x.permute(0, 2, 1)) # (B,C,horizon)
236
- x = x.permute(0, 2, 1) # (B,horizon,C)
237
- else: # resize path
238
- W = self.linear.weight.detach() # (out,in)
239
- W_resized = self._resize_weight(W, L, self.horizon)
240
-
241
- # ★ bias removed → no "+ b"
242
- x = x.permute(0, 2, 1) # (B,C,L)
243
- x = torch.einsum("bcl,ol->bco", x, W_resized) # (B,C,out)
244
- x = x.permute(0, 2, 1) # (B,horizon,C)
245
-
246
- x = self.revin(x, "denorm")
247
- if squeeze_last:
248
- x = x.squeeze(-1)
249
- return x
250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  "-------------------------------------------------------------------------------------------------------------------"
252
  class SparseNoisyMoE(nn.Module):
253
  def __init__(self, configs, experts=None):
 
192
 
193
 
194
  class RLinear(nn.Module):
195
+ def __init__(self, input_len, output_len):
196
+ super(RLinear, self).__init__()
197
+ self.Linear = nn.Linear(input_len, output_len)
198
+ self.seq_len = input_len
199
+ self.horizon = output_len
200
+ self.revin_layer = RevIN(num_features = None, affine=False, norm_type = None, subtract_last = False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
+
203
+ def forward(self, x):
204
+ # x: [Batch, Input length,Channel]
205
+ x_shape = x.shape
206
+ if len(x_shape) == 2:
207
+ x = x.unsqueeze(-1)
208
+
209
+ B,L,V = x.shape
210
+ if L < self.seq_len:
211
+ in_features = L
212
+ W = self.Linear.weight.detach()
213
+ fixed_weights = self.weights[:, :L]
214
+ dynamic_weights = self.weights[:, L:]
215
+
216
+ if in_features != self.weights.size(1) or out_features != self.weights.size(0):
217
+ dynamic_weights = F.interpolate(dynamic_weights.unsqueeze(0).unsqueeze(0), size=(self.horizon, in_features-self.seq_len), mode='bilinear', align_corners=False).squeeze(0).squeeze(0)
218
+ if self.fixed_in != 0:
219
+ fixed_weights = F.interpolate(fixed_weights.unsqueeze(0).unsqueeze(0), size=(self.horizon, self.fixed_in), mode='bilinear', align_corners=False).squeeze(0).squeeze(0)
220
+
221
+ x = self.revin_layer(x, 'norm')
222
+ x = F.linear(x, torch.cat((fixed_weights, dynamic_weights), dim=1))
223
+ x = self.revin_layer(x, 'denorm')
224
+ if len(x_shape) == 2:
225
+ x = x.squeeze(-1)
226
+ return x
227
+
228
+
229
+ x = x.clone()
230
+ x = self.revin_layer(x, 'norm')
231
+ x = self.Linear(x.permute(0,2,1)).permute(0,2,1).clone()
232
+ x = self.revin_layer(x, 'denorm')
233
+
234
+ if len(x_shape) == 2:
235
+ x = x.squeeze(-1)
236
+ return x # to [Batch, Output length, Channel]
237
  "-------------------------------------------------------------------------------------------------------------------"
238
  class SparseNoisyMoE(nn.Module):
239
  def __init__(self, configs, experts=None):