Update modeling_super_linear.py
modeling_super_linear.py CHANGED: +28 -28
@@ -367,18 +367,18 @@ class superLinear(nn.Module):
         else:
             self.freq_experts = configs.freq_experts.split('_')
 
-
+
 
         self.moe_loss = None
         self.top_k_experts = configs.top_k_experts
-        # self.noisy_gating = configs.noisy_gating
         self.n_experts = configs.moe_n_experts
         self.freeze_experts = configs.freeze_experts
         self.layer_type = configs.layer_type
         self.model_name = "SuperLinear"
 
+
         self.layer_dict = {'DLinear': DLinear, 'Linear': Linear, 'NLinear': NLinear, 'RLinear': RLinear}
+
         path = configs.linear_checkpoints_path + configs.linear_checkpoints_dir
         dirs = os.listdir(path)
         checkpoints_paths = [path + "/" + d + "/" + "checkpoint.pth" for d in dirs]
@@ -390,31 +390,31 @@ class superLinear(nn.Module):
             cycle = cp.split("/")
 
         self.experts = {}
-        if self.freq_experts is not None:
-            for expert_freq in self.freq_experts:
-                if expert_freq == "naive" or expert_freq == "Naive":
-                    self.experts[expert_freq] = Naive(self.seq_len, self.pred_len)
-                elif expert_freq == "mean" or expert_freq == "Mean":
-                    self.experts[expert_freq] = Mean(self.seq_len, self.pred_len)
-                else:
-                    self.experts[expert_freq] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
-                if configs.load_linear:
-                    cycle = self.map_to_cycle(expert_freq)
-                    cycle_str = f'cycle_{cycle}/'
-                    cycle_checkpoint_path = [cp for cp in checkpoints_paths if (cycle_str in cp and self.layer_type in cp)]
-                    if len(cycle_checkpoint_path) > 0:
-                        print()
-                        print(cycle_str)
-                        cycle_checkpoint_path = cycle_checkpoint_path[0]
-                        #print(f'loading checkpoint with layer type: {self.layer_type} and cycle: {cycle_str}')
-                        print(cycle_checkpoint_path)
-                        self.experts[expert_freq].load_state_dict(torch.load(cycle_checkpoint_path))
-                    else:
-                        print(f"Checkpoint for {cycle_str} not found in {path}")
-                        raise ValueError(f"Checkpoint for {cycle_str} not found in {path}")
-                if configs.freeze_experts:
-                    for param in self.experts[expert_freq].parameters():
-                        param.requires_grad = False
+        # if self.freq_experts is not None:
+        #     for expert_freq in self.freq_experts:
+        #         if expert_freq == "naive" or expert_freq == "Naive":
+        #             self.experts[expert_freq] = Naive(self.seq_len, self.pred_len)
+        #         elif expert_freq == "mean" or expert_freq == "Mean":
+        #             self.experts[expert_freq] = Mean(self.seq_len, self.pred_len)
+        #         else:
+        #             self.experts[expert_freq] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
+        #         if configs.load_linear:
+        #             cycle = self.map_to_cycle(expert_freq)
+        #             cycle_str = f'cycle_{cycle}/'
+        #             cycle_checkpoint_path = [cp for cp in checkpoints_paths if (cycle_str in cp and self.layer_type in cp)]
+        #             if len(cycle_checkpoint_path) > 0:
+        #                 print()
+        #                 print(cycle_str)
+        #                 cycle_checkpoint_path = cycle_checkpoint_path[0]
+        #                 #print(f'loading checkpoint with layer type: {self.layer_type} and cycle: {cycle_str}')
+        #                 print(cycle_checkpoint_path)
+        #                 self.experts[expert_freq].load_state_dict(torch.load(cycle_checkpoint_path))
+        #             else:
+        #                 print(f"Checkpoint for {cycle_str} not found in {path}")
+        #                 raise ValueError(f"Checkpoint for {cycle_str} not found in {path}")
+        #         if configs.freeze_experts:
+        #             for param in self.experts[expert_freq].parameters():
+        #                 param.requires_grad = False
 
         self.n_experts = len(self.experts)
         else:
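The commit disables rather than deletes the per-frequency expert construction: the block stays in the file as comments. For readers following the model code, here is a minimal runnable sketch of what that block builds. The Naive, Mean, and Linear classes below are simplified stand-ins (the real expert layers, including DLinear/NLinear/RLinear, are defined elsewhere in modeling_super_linear.py), and the frequency string "h_d_w_naive" is only an illustrative value for configs.freq_experts.

import torch
import torch.nn as nn

class Naive(nn.Module):
    # Stand-in: repeat the last observed value across the horizon.
    def __init__(self, seq_len, pred_len):
        super().__init__()
        self.pred_len = pred_len

    def forward(self, x):
        return x[..., -1:].repeat_interleave(self.pred_len, dim=-1)

class Mean(nn.Module):
    # Stand-in: forecast the historical mean at every horizon step.
    def __init__(self, seq_len, pred_len):
        super().__init__()
        self.pred_len = pred_len

    def forward(self, x):
        return x.mean(dim=-1, keepdim=True).repeat_interleave(self.pred_len, dim=-1)

class Linear(nn.Module):
    # Stand-in for the configurable linear expert (layer_dict dispatch).
    def __init__(self, seq_len, pred_len):
        super().__init__()
        self.proj = nn.Linear(seq_len, pred_len)

    def forward(self, x):
        return self.proj(x)

def build_experts(freq_experts, layer_dict, layer_type, seq_len, pred_len):
    # One expert per frequency token: "naive"/"mean" get the
    # parameter-free baselines, everything else gets the configured
    # linear layer type, mirroring the commented-out loop above.
    experts = {}
    for expert_freq in freq_experts:
        if expert_freq.lower() == "naive":
            experts[expert_freq] = Naive(seq_len, pred_len)
        elif expert_freq.lower() == "mean":
            experts[expert_freq] = Mean(seq_len, pred_len)
        else:
            experts[expert_freq] = layer_dict[layer_type](seq_len, pred_len)
    return experts

experts = build_experts("h_d_w_naive".split('_'), {'Linear': Linear}, 'Linear',
                        seq_len=96, pred_len=24)
x = torch.randn(8, 96)
print({freq: tuple(expert(x).shape) for freq, expert in experts.items()})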
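The second half of the commented-out block pairs each expert with a pretrained checkpoint and optionally freezes it. A hedged sketch of that load-and-freeze pattern, factored into a standalone helper: map_to_cycle is passed in as a callable here (in the model it is a method), while the cycle_{n}/ directory matching and checkpoint.pth layout come from the diff above.

import torch

def load_and_freeze(expert, expert_freq, checkpoints_paths, layer_type,
                    map_to_cycle, freeze=True):
    # Pick the checkpoint whose path contains both the expert's cycle
    # directory (e.g. "cycle_24/") and the configured layer type.
    cycle_str = f'cycle_{map_to_cycle(expert_freq)}/'
    hits = [cp for cp in checkpoints_paths
            if cycle_str in cp and layer_type in cp]
    if not hits:
        raise ValueError(f"Checkpoint for {cycle_str} not found")
    expert.load_state_dict(torch.load(hits[0], map_location="cpu"))
    if freeze:
        # A frozen expert keeps its pretrained weights; only the rest of
        # the mixture (presumably the gating network) continues to train.
        for param in expert.parameters():
            param.requires_grad = False
    return expert

Note that within this hunk nothing repopulates the dictionary after the block is commented out, so self.experts is empty when self.n_experts = len(self.experts) runs, unless it is filled elsewhere outside the diff's context.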