razmars commited on
Commit
b672046
·
verified ·
1 Parent(s): f9ff639

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +5 -23
modeling_super_linear.py CHANGED
@@ -200,40 +200,22 @@ class RLinear(nn.Module):
200
  self.zero_shot_Linear = None
201
 
202
  def transform_model(self,new_lookback,mode):
203
- if mode == 2:
204
  W = self.Linear.weight.detach()
205
  new_W = W[:, -new_lookback:]
206
  original_norm = torch.norm(W, p=2)
207
  new_norm = torch.norm(new_W, p=2)
208
  final_scaling = original_norm / new_norm if new_norm.item() != 0 else 1.0
209
- final_scaling = 1
210
  new_W = new_W * final_scaling
211
- self.zero_shot_Linear = new_W
212
 
 
213
  else:
214
- W = self.Linear.weight.detach()
215
- target_indices = torch.linspace(0, self.seq_len - 1, steps=new_lookback, device=W.device)
216
- source_indices = torch.arange(0, self.seq_len, device=W.device).float()
217
-
218
- # Initialize the new weight matrix
219
- new_W = torch.zeros((W.size(0), new_lookback), device=W.device)
220
-
221
- # Linear interpolation for each row
222
- for i in range(W.size(0)):
223
- new_W[i] = torch.tensor([torch.sum(W[i] * (1 - torch.abs(idx - source_indices) / self.seq_len).clamp(min=0))
224
- for idx in target_indices], device=W.device)
225
-
226
- # Maintain the same norm as the original weights
227
- original_norm = torch.norm(W, p=2)
228
- new_norm = torch.norm(new_W, p=2)
229
- final_scaling = original_norm / new_norm if new_norm.item() != 0 else 1.0
230
- new_W = new_W * final_scaling
231
 
232
- self.zero_shot_Linear = new_W
233
 
234
 
235
-
236
-
237
  def forward(self, x):
238
  # x: [Batch, Input length,Channel]
239
  x_shape = x.shape
 
200
  self.zero_shot_Linear = None
201
 
202
  def transform_model(self,new_lookback,mode):
203
+ if mode == 1:
204
  W = self.Linear.weight.detach()
205
  new_W = W[:, -new_lookback:]
206
  original_norm = torch.norm(W, p=2)
207
  new_norm = torch.norm(new_W, p=2)
208
  final_scaling = original_norm / new_norm if new_norm.item() != 0 else 1.0
 
209
  new_W = new_W * final_scaling
 
210
 
211
+ self.zero_shot_Linear = new_W
212
  else:
213
+ W = self.Linear.weight.detach()
214
+ new_W = F.interpolate(W.unsqueeze(0), size=(new_lookback, self.horizon ), mode='bilinear', align_corners=False).squeeze(0)
215
+ self.zero_shot_Linear = new_W
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
 
217
 
218
 
 
 
219
  def forward(self, x):
220
  # x: [Batch, Input length,Channel]
221
  x_shape = x.shape