Update modeling_super_linear.py

modeling_super_linear.py  (+5 -12)
@@ -235,8 +235,6 @@ class SparseNoisyMoE(nn.Module):
         if self.k > self.num_experts:
             print(f"Warning: k ({self.k}) is greater than the number of experts ({self.num_experts}). Setting k to {self.num_experts}.")
             self.k = self.num_experts
-        # self.ker_len = configs.ker_len
-        #self.con = configs.con
         self.d_model = configs.d_model
         self.mlp_gating = configs.mlp_gating
         self.moe_temp = configs.moe_temp
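Note: the clamp kept in this hunk guards the gate's top-k selection, since torch.topk raises a RuntimeError whenever k exceeds the size of the selected dimension. A minimal sketch of the failure being avoided (the gating tensor here is illustrative, not taken from the model):

    import torch

    num_experts, k = 4, 8
    logits = torch.randn(2, num_experts)      # (batch, num_experts) gating scores
    # torch.topk(logits, k, dim=-1)           # would raise RuntimeError: k out of range
    k = min(k, num_experts)                   # the same clamp as in the hunk above
    top_vals, top_idx = torch.topk(logits, k, dim=-1)   # fine: shapes (2, 4)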
@@ -367,7 +365,7 @@ class superLinear(nn.Module):
         else:
             self.freq_experts = configs.freq_experts.split('_')
 
-        print("self.freq_experts:", self.freq_experts)
+        #print("self.freq_experts:", self.freq_experts)
 
         self.moe_loss = None
         self.top_k_experts = configs.top_k_experts
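For context, freq_experts arrives as a single underscore-separated string, and split('_') turns it into one expert tag per token. A hedged example of the round trip (the "h/d/w/m" tag vocabulary is an assumption; the diff does not show the real spec format):

    freq_experts = "h_d_w_m".split('_')
    print(freq_experts)   # ['h', 'd', 'w', 'm'] -> one expert per frequency tag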
@@ -377,7 +375,7 @@ class superLinear(nn.Module):
         self.layer_type = configs.layer_type
         self.model_name = "SuperLinear"
 
-        print("self.layer_type", self.layer_type)
+        #print("self.layer_type", self.layer_type)
         self.layer_dict = {'DLinear': DLinear, 'Linear': Linear, 'NLinear': NLinear, 'RLinear': RLinear}
         path = configs.linear_checkpoints_path + configs.linear_checkpoints_dir
         dirs = os.listdir(path)
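Each value in layer_dict is a linear-forecaster class constructed as cls(seq_len, pred_len). A hedged sketch of the simplest such expert, assuming the LTSF-Linear convention of a single seq_len -> pred_len projection applied per channel (the real DLinear/NLinear/RLinear variants layer decomposition or normalization on top; this is not the repository's implementation):

    import torch
    import torch.nn as nn

    class Linear(nn.Module):
        # Maps a length-seq_len history to a length-pred_len forecast, per channel.
        def __init__(self, seq_len, pred_len):
            super().__init__()
            self.proj = nn.Linear(seq_len, pred_len)

        def forward(self, x):                    # x: (batch, seq_len, channels)
            # project along the time axis, keeping channels independent
            return self.proj(x.transpose(1, 2)).transpose(1, 2)   # (batch, pred_len, channels)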
@@ -422,20 +420,16 @@ class superLinear(nn.Module):
             print(f"creating expert {i}")
             self.experts[str(i)] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
 
-        #self.manual_moe = configs.manual_moe
-
 
         if configs.misc_moe>0:
             if configs.misc_moe == 1:
-                print("Creating misc expert")
+                #print("Creating misc expert")
                 self.experts["misc"] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
             else:
                 for i in range(configs.misc_moe):
-                    print(f"Creating misc expert {i}")
+                    #print(f"Creating misc expert {i}")
                     self.experts["misc_"+str(i)] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
-
-        print("Creating misc expert")
-        self.experts["misc2"] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)'''
+
 
 
         self.moe = SparseNoisyMoE(configs, experts=self.experts.values())
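SparseNoisyMoE then mixes the experts dict built above with a noisy top-k gate. A hedged sketch of the standard mixing step (after Shazeer et al.'s noisy top-k gating); the real class reads its sizes from configs, applies the moe_temp temperature, and also produces the load-balancing loss stored in self.moe_loss, all omitted here:

    import torch
    import torch.nn.functional as F

    def noisy_top_k_mix(x, experts, w_gate, w_noise, k):
        # x: (batch, d_in); experts: modules mapping d_in -> d_out;
        # w_gate, w_noise: (d_in, n_experts) parameter matrices
        raw = x @ w_noise
        logits = x @ w_gate + torch.randn_like(raw) * F.softplus(raw)
        top_val, top_idx = torch.topk(logits, k, dim=-1)
        gates = torch.zeros_like(logits).scatter(-1, top_idx, F.softmax(top_val, dim=-1))
        # dense evaluation for clarity; a real sparse MoE dispatches only the top-k experts
        out = torch.stack([e(x) for e in experts], dim=-1)   # (batch, d_out, n_experts)
        return (out * gates.unsqueeze(1)).sum(-1)            # (batch, d_out)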
@@ -445,7 +439,6 @@ class superLinear(nn.Module):
         print(f"Loading weights from {path}")
         path = configs.load_weights_path + "" + configs.load_weights_dir + "/" + "checkpoint.pth"
         if os.path.exists(path):
-            # print(f"Loading weights from {path}")
             checkpoint = torch.load(path)
             print(len(self.experts.keys()))
             print(self.experts.keys())
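After torch.load, the usual pattern is to carve per-expert sub-dicts out of the flat checkpoint state. A hedged sketch of that step; the "experts." key prefix and the nested "state_dict" layout are assumptions, since the diff only shows the checkpoint being loaded and the expert keys printed:

    checkpoint = torch.load(path, map_location="cpu")
    state = checkpoint.get("state_dict", checkpoint)   # tolerate either layout
    for name, expert in self.experts.items():
        prefix = f"experts.{name}."                    # assumed key naming
        sub = {key[len(prefix):]: val for key, val in state.items() if key.startswith(prefix)}
        expert.load_state_dict(sub)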