razmars committed · Commit fc6971b · verified · Parent: 1678415

Update modeling_super_linear.py

Files changed (1):
  1. modeling_super_linear.py +28 -28
modeling_super_linear.py CHANGED
@@ -367,18 +367,18 @@ class superLinear(nn.Module):
         else:
             self.freq_experts = configs.freq_experts.split('_')
 
-        #print("self.freq_experts:", self.freq_experts)
+
 
         self.moe_loss = None
         self.top_k_experts = configs.top_k_experts
-        # self.noisy_gating = configs.noisy_gating
         self.n_experts = configs.moe_n_experts
         self.freeze_experts = configs.freeze_experts
         self.layer_type = configs.layer_type
         self.model_name = "SuperLinear"
 
-        #print("self.layer_type", self.layer_type)
+
         self.layer_dict = {'DLinear': DLinear, 'Linear': Linear, 'NLinear': NLinear, 'RLinear': RLinear}
+
         path = configs.linear_checkpoints_path + configs.linear_checkpoints_dir
         dirs = os.listdir(path)
         checkpoints_paths = [path + "/" + d + "/" + "checkpoint.pth" for d in dirs]
@@ -390,31 +390,31 @@
             cycle = cp.split("/")
 
         self.experts = {}
-        if self.freq_experts is not None:
-            for expert_freq in self.freq_experts:
-                if expert_freq == "naive" or expert_freq == "Naive":
-                    self.experts[expert_freq] = Naive(self.seq_len, self.pred_len)
-                elif expert_freq == "mean" or expert_freq == "Mean":
-                    self.experts[expert_freq] = Mean(self.seq_len, self.pred_len)
-                else:
-                    self.experts[expert_freq] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
-                if configs.load_linear:
-                    cycle = self.map_to_cycle(expert_freq)
-                    cycle_str = f'cycle_{cycle}/'
-                    cycle_checkpoint_path = [cp for cp in checkpoints_paths if (cycle_str in cp and self.layer_type in cp)]
-                    if len(cycle_checkpoint_path) > 0:
-                        print()
-                        print(cycle_str)
-                        cycle_checkpoint_path = cycle_checkpoint_path[0]
-                        #print(f'loading checkpoint with layer type: {self.layer_type} and cycle: {cycle_str}')
-                        print(cycle_checkpoint_path)
-                        self.experts[expert_freq].load_state_dict(torch.load(cycle_checkpoint_path))
-                    else:
-                        print(f"Checkpoint for {cycle_str} not found in {path}")
-                        raise ValueError(f"Checkpoint for {cycle_str} not found in {path}")
-                if configs.freeze_experts:
-                    for param in self.experts[expert_freq].parameters():
-                        param.requires_grad = False
+        # if self.freq_experts is not None:
+        #     for expert_freq in self.freq_experts:
+        #         if expert_freq == "naive" or expert_freq == "Naive":
+        #             self.experts[expert_freq] = Naive(self.seq_len, self.pred_len)
+        #         elif expert_freq == "mean" or expert_freq == "Mean":
+        #             self.experts[expert_freq] = Mean(self.seq_len, self.pred_len)
+        #         else:
+        #             self.experts[expert_freq] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
+        #         if configs.load_linear:
+        #             cycle = self.map_to_cycle(expert_freq)
+        #             cycle_str = f'cycle_{cycle}/'
+        #             cycle_checkpoint_path = [cp for cp in checkpoints_paths if (cycle_str in cp and self.layer_type in cp)]
+        #             if len(cycle_checkpoint_path) > 0:
+        #                 print()
+        #                 print(cycle_str)
+        #                 cycle_checkpoint_path = cycle_checkpoint_path[0]
+        #                 #print(f'loading checkpoint with layer type: {self.layer_type} and cycle: {cycle_str}')
+        #                 print(cycle_checkpoint_path)
+        #                 self.experts[expert_freq].load_state_dict(torch.load(cycle_checkpoint_path))
+        #             else:
+        #                 print(f"Checkpoint for {cycle_str} not found in {path}")
+        #                 raise ValueError(f"Checkpoint for {cycle_str} not found in {path}")
+        #         if configs.freeze_experts:
+        #             for param in self.experts[expert_freq].parameters():
+        #                 param.requires_grad = False
 
         self.n_experts = len(self.experts)
         else:
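For context, the block this commit comments out builds one expert per frequency token (Naive, Mean, or a linear layer chosen via self.layer_dict), optionally restores each expert from a per-cycle checkpoint with load_state_dict(torch.load(...)), and freezes expert parameters by clearing requires_grad. Below is a minimal, runnable sketch of that construct-and-freeze pattern; the Linear stand-in and the build_experts helper are illustrative assumptions, not the repository's actual classes.

import torch.nn as nn

class Linear(nn.Module):
    # Illustrative stand-in for the repo's Linear expert: a single
    # linear map from seq_len input steps to pred_len forecast steps.
    def __init__(self, seq_len, pred_len):
        super().__init__()
        self.proj = nn.Linear(seq_len, pred_len)

    def forward(self, x):
        return self.proj(x)

def build_experts(freq_experts, seq_len, pred_len, freeze=False):
    # One expert per frequency token, mirroring the commented-out loop.
    # A checkpoint restore would slot in here, as in the original code:
    #   experts[expert_freq].load_state_dict(torch.load(checkpoint_path))
    experts = nn.ModuleDict()
    for expert_freq in freq_experts:
        experts[expert_freq] = Linear(seq_len, pred_len)
        if freeze:
            # Same freeze step as the original block: these experts no
            # longer receive gradient updates during training.
            for param in experts[expert_freq].parameters():
                param.requires_grad = False
    return experts

experts = build_experts("h_d_w".split('_'), seq_len=96, pred_len=24, freeze=True)
print(len(experts), all(not p.requires_grad for p in experts.parameters()))

One design note: the original loop stores experts in a plain dict, while the sketch uses nn.ModuleDict so the experts' parameters register with the parent module automatically.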