razmars commited on
Commit
25927a3
·
verified ·
1 Parent(s): 03ea7d0

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +24 -2
modeling_super_linear.py CHANGED
@@ -208,9 +208,31 @@ class RLinear(nn.Module):
208
  final_scaling = original_norm / new_norm if new_norm.item() != 0 else 1.0
209
  #final_scaling = 1
210
  new_W = new_W * final_scaling
 
211
 
212
- self.zero_shot_Linear = new_W
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
 
 
 
 
 
214
 
215
  def forward(self, x):
216
  # x: [Batch, Input length,Channel]
@@ -218,7 +240,7 @@ class RLinear(nn.Module):
218
  if x.shape[1] < self.seq_len:
219
  if self.zero_shot_Linear is None:
220
  #print(F"new Lookkback : {x.shape[1]}")
221
- self.transform_model(x.shape[1],2)
222
 
223
  x = x.clone()
224
  #x = x * (x.shape[1]/512)
 
208
  final_scaling = original_norm / new_norm if new_norm.item() != 0 else 1.0
209
  #final_scaling = 1
210
  new_W = new_W * final_scaling
211
+ self.zero_shot_Linear = new_W
212
 
213
+ else:
214
+ W = self.Linear.weight.detach()
215
+ target_indices = torch.linspace(0, self.seq_len - 1, steps=new_lookback, device=W.device)
216
+ source_indices = torch.arange(0, self.seq_len, device=W.device).float()
217
+
218
+ # Initialize the new weight matrix
219
+ new_W = torch.zeros((W.size(0), new_lookback), device=W.device)
220
+
221
+ # Linear interpolation for each row
222
+ for i in range(W.size(0)):
223
+ new_W[i] = torch.tensor([torch.sum(W[i] * (1 - torch.abs(idx - source_indices) / self.seq_len).clamp(min=0))
224
+ for idx in target_indices], device=W.device)
225
+
226
+ # Maintain the same norm as the original weights
227
+ original_norm = torch.norm(W, p=2)
228
+ new_norm = torch.norm(new_W, p=2)
229
+ final_scaling = original_norm / new_norm if new_norm.item() != 0 else 1.0
230
+ new_W = new_W * final_scaling
231
 
232
+ self.zero_shot_Linear = new_W
233
+
234
+
235
+
236
 
237
  def forward(self, x):
238
  # x: [Batch, Input length,Channel]
 
240
  if x.shape[1] < self.seq_len:
241
  if self.zero_shot_Linear is None:
242
  #print(F"new Lookkback : {x.shape[1]}")
243
+ self.transform_model(x.shape[1],3)
244
 
245
  x = x.clone()
246
  #x = x * (x.shape[1]/512)