razmars committed on
Commit
a3e7047
·
verified ·
1 Parent(s): 9019dcb

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +5 -12
modeling_super_linear.py CHANGED
@@ -235,8 +235,6 @@ class SparseNoisyMoE(nn.Module):
235
  if self.k > self.num_experts:
236
  print(f"Warning: k ({self.k}) is greater than the number of experts ({self.num_experts}). Setting k to {self.num_experts}.")
237
  self.k = self.num_experts
238
- # self.ker_len = configs.ker_len
239
- #self.con = configs.con
240
  self.d_model = configs.d_model
241
  self.mlp_gating = configs.mlp_gating
242
  self.moe_temp = configs.moe_temp
@@ -367,7 +365,7 @@ class superLinear(nn.Module):
367
  else:
368
  self.freq_experts = configs.freq_experts.split('_')
369
 
370
- print("self.freq_experts:", self.freq_experts)
371
 
372
  self.moe_loss = None
373
  self.top_k_experts = configs.top_k_experts
@@ -377,7 +375,7 @@ class superLinear(nn.Module):
377
  self.layer_type = configs.layer_type
378
  self.model_name = "SuperLinear"
379
 
380
- print("self.layer_type", self.layer_type)
381
  self.layer_dict = {'DLinear': DLinear, 'Linear': Linear, 'NLinear': NLinear, 'RLinear': RLinear}
382
  path = configs.linear_checkpoints_path + configs.linear_checkpoints_dir
383
  dirs = os.listdir(path)
@@ -422,20 +420,16 @@ class superLinear(nn.Module):
422
  print(f"creating expert {i}")
423
  self.experts[str(i)] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
424
 
425
- #self.manual_moe = configs.manual_moe
426
-
427
 
428
  if configs.misc_moe>0:
429
  if configs.misc_moe == 1:
430
- print("Creating misc expert")
431
  self.experts["misc"] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
432
  else:
433
  for i in range(configs.misc_moe):
434
- print(f"Creating misc expert {i}")
435
  self.experts["misc_"+str(i)] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
436
- '''if configs.misc_moe2==1:
437
- print("Creating misc expert")
438
- self.experts["misc2"] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)'''
439
 
440
 
441
  self.moe = SparseNoisyMoE(configs, experts=self.experts.values())
@@ -445,7 +439,6 @@ class superLinear(nn.Module):
445
  print(f"Loading weights from {path}")
446
  path = configs.load_weights_path + "" + configs.load_weights_dir + "/" + "checkpoint.pth"
447
  if os.path.exists(path):
448
- # print(f"Loading weights from {path}")
449
  checkpoint = torch.load(path)
450
  print(len(self.experts.keys()))
451
  print(self.experts.keys())
 
235
  if self.k > self.num_experts:
236
  print(f"Warning: k ({self.k}) is greater than the number of experts ({self.num_experts}). Setting k to {self.num_experts}.")
237
  self.k = self.num_experts
 
 
238
  self.d_model = configs.d_model
239
  self.mlp_gating = configs.mlp_gating
240
  self.moe_temp = configs.moe_temp
 
365
  else:
366
  self.freq_experts = configs.freq_experts.split('_')
367
 
368
+ #print("self.freq_experts:", self.freq_experts)
369
 
370
  self.moe_loss = None
371
  self.top_k_experts = configs.top_k_experts
 
375
  self.layer_type = configs.layer_type
376
  self.model_name = "SuperLinear"
377
 
378
+ #print("self.layer_type", self.layer_type)
379
  self.layer_dict = {'DLinear': DLinear, 'Linear': Linear, 'NLinear': NLinear, 'RLinear': RLinear}
380
  path = configs.linear_checkpoints_path + configs.linear_checkpoints_dir
381
  dirs = os.listdir(path)
 
420
  print(f"creating expert {i}")
421
  self.experts[str(i)] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
422
 
 
 
423
 
424
  if configs.misc_moe>0:
425
  if configs.misc_moe == 1:
426
+ #print("Creating misc expert")
427
  self.experts["misc"] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
428
  else:
429
  for i in range(configs.misc_moe):
430
+ #print(f"Creating misc expert {i}")
431
  self.experts["misc_"+str(i)] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
432
+
 
 
433
 
434
 
435
  self.moe = SparseNoisyMoE(configs, experts=self.experts.values())
 
439
  print(f"Loading weights from {path}")
440
  path = configs.load_weights_path + "" + configs.load_weights_dir + "/" + "checkpoint.pth"
441
  if os.path.exists(path):
 
442
  checkpoint = torch.load(path)
443
  print(len(self.experts.keys()))
444
  print(self.experts.keys())