Update modeling_super_linear.py
modeling_super_linear.py CHANGED (+17 -17)
@@ -379,9 +379,9 @@ class superLinear(nn.Module):
 
         self.layer_dict = {'DLinear': DLinear, 'Linear': Linear, 'NLinear': NLinear, 'RLinear': RLinear}
 
-        path = configs.linear_checkpoints_path + configs.linear_checkpoints_dir
-        dirs = os.listdir(path)
-        checkpoints_paths = [path + "/" + d + "/" + "checkpoint.pth" for d in dirs]
+        # path = configs.linear_checkpoints_path + configs.linear_checkpoints_dir
+        # dirs = os.listdir(path)
+        # checkpoints_paths = [path + "/" + d + "/" + "checkpoint.pth" for d in dirs]
 
         if self.freq_experts == "all":
             self.freq_experts = []
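The three disabled lines enumerated one checkpoint.pth per run directory under the configured linear-checkpoints path. For reference, a minimal standalone sketch of that enumeration; the function name and the existence guards are illustrative additions, not part of the original code:

import os

def list_expert_checkpoints(root: str) -> list[str]:
    # Mirror the commented-out layout: <root>/<run_dir>/checkpoint.pth
    # for every run directory. The isdir/isfile guards are assumptions;
    # the original lines listed every entry of os.listdir unchecked.
    if not os.path.isdir(root):
        return []
    return [
        os.path.join(root, d, "checkpoint.pth")
        for d in sorted(os.listdir(root))
        if os.path.isfile(os.path.join(root, d, "checkpoint.pth"))
    ]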
@@ -425,11 +425,11 @@ class superLinear(nn.Module):
 
         if configs.misc_moe>0:
             if configs.misc_moe == 1:
-
+                print("Creating misc expert")
                 self.experts["misc"] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
             else:
                 for i in range(configs.misc_moe):
-
+                    print(f"Creating misc expert {i}")
                     self.experts["misc_"+str(i)] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
 
 
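The added prints trace expert construction: misc_moe == 1 registers a single "misc" expert, while larger values register misc_0 through misc_{n-1}, each built from the configured layer_dict entry. A self-contained sketch of that pattern, with nn.Linear standing in for the repo's layer types (the helper name is hypothetical):

import torch.nn as nn

def build_misc_experts(misc_moe: int, seq_len: int, pred_len: int) -> nn.ModuleDict:
    # misc_moe == 1 -> one expert keyed "misc";
    # misc_moe > 1  -> experts keyed "misc_0" ... "misc_{n-1}".
    # nn.Linear is a stand-in for layer_dict[self.layer_type].
    experts = nn.ModuleDict()
    if misc_moe == 1:
        experts["misc"] = nn.Linear(seq_len, pred_len)
    else:
        for i in range(misc_moe):
            experts["misc_" + str(i)] = nn.Linear(seq_len, pred_len)
    return experts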
@@ -437,18 +437,18 @@ class superLinear(nn.Module):
         self.moe = SparseNoisyMoE(configs, experts=self.experts.values())
         self.dropout = nn.Dropout(configs.dropout)
 
-        if configs.load_weights:
-            print(f"Loading weights from {path}")
-            path = configs.load_weights_path + "" + configs.load_weights_dir + "/" + "checkpoint.pth"
-            if os.path.exists(path):
-                checkpoint = torch.load(path)
-                print(len(self.experts.keys()))
-                print(self.experts.keys())
-                print(self.state_dict().keys())
-                print(checkpoint.keys())
-                self.load_state_dict(checkpoint)
-            else:
-                print(f"Path {path} does not exist. Skipping loading weights.")
+        # if configs.load_weights:
+        #     print(f"Loading weights from {path}")
+        #     path = configs.load_weights_path + "" + configs.load_weights_dir + "/" + "checkpoint.pth"
+        #     if os.path.exists(path):
+        #         checkpoint = torch.load(path)
+        #         print(len(self.experts.keys()))
+        #         print(self.experts.keys())
+        #         print(self.state_dict().keys())
+        #         print(checkpoint.keys())
+        #         self.load_state_dict(checkpoint)
+        #     else:
+        #         print(f"Path {path} does not exist. Skipping loading weights.")
 
 
     def map_to_cycle(self, freq):
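The now-disabled block loaded a full state dict from the configured weights path, guarded by an existence check and with key printouts for debugging. A minimal sketch of the same guarded-load pattern, assuming the checkpoint stores a plain state dict as the commented code implies; the helper name and the map_location argument are additions:

import os
import torch

def try_load_weights(model: torch.nn.Module, path: str) -> bool:
    # Load a checkpoint only if it exists, as in the disabled block;
    # otherwise skip with a message instead of raising.
    if not os.path.exists(path):
        print(f"Path {path} does not exist. Skipping loading weights.")
        return False
    checkpoint = torch.load(path, map_location="cpu")  # assumes a plain state dict
    model.load_state_dict(checkpoint)
    return True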