Update modeling_super_linear.py
Browse files- modeling_super_linear.py +30 -30
modeling_super_linear.py
CHANGED
|
@@ -390,37 +390,37 @@ class superLinear(nn.Module):
|
|
| 390 |
cycle = cp.split("/")
|
| 391 |
|
| 392 |
self.experts = {}
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
|
| 425 |
|
| 426 |
if configs.misc_moe>0:
|
|
|
|
| 390 |
cycle = cp.split("/")
|
| 391 |
|
| 392 |
self.experts = {}
|
| 393 |
+
if self.freq_experts is not None:
|
| 394 |
+
for expert_freq in self.freq_experts:
|
| 395 |
+
if expert_freq == "naive" or expert_freq == "Naive":
|
| 396 |
+
self.experts[expert_freq] = Naive(self.seq_len, self.pred_len)
|
| 397 |
+
elif expert_freq == "mean" or expert_freq == "Mean":
|
| 398 |
+
self.experts[expert_freq] = Mean(self.seq_len, self.pred_len)
|
| 399 |
+
else:
|
| 400 |
+
self.experts[expert_freq] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
|
| 401 |
+
# if configs.load_linear:
|
| 402 |
+
# cycle = self.map_to_cycle(expert_freq)
|
| 403 |
+
# cycle_str = f'cycle_{cycle}/'
|
| 404 |
+
# cycle_checkpoint_path = [cp for cp in checkpoints_paths if (cycle_str in cp and self.layer_type in cp)]
|
| 405 |
+
# if len(cycle_checkpoint_path) > 0:
|
| 406 |
+
# print()
|
| 407 |
+
# print(cycle_str)
|
| 408 |
+
# cycle_checkpoint_path = cycle_checkpoint_path[0]
|
| 409 |
+
# #print(f'loading checkpoint with layer type: {self.layer_type} and cycle: {cycle_str}')
|
| 410 |
+
# print(cycle_checkpoint_path)
|
| 411 |
+
# self.experts[expert_freq].load_state_dict(torch.load(cycle_checkpoint_path))
|
| 412 |
+
# else:
|
| 413 |
+
# print(f"Checkpoint for {cycle_str} not found in {path}")
|
| 414 |
+
# raise ValueError(f"Checkpoint for {cycle_str} not found in {path}")
|
| 415 |
+
# if configs.freeze_experts:
|
| 416 |
+
# for param in self.experts[expert_freq].parameters():
|
| 417 |
+
# param.requires_grad = False
|
| 418 |
|
| 419 |
+
self.n_experts = len(self.experts)
|
| 420 |
+
else:
|
| 421 |
+
for i in range(self.n_experts):
|
| 422 |
+
print(f"creating expert {i}")
|
| 423 |
+
self.experts[str(i)] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
|
| 424 |
|
| 425 |
|
| 426 |
if configs.misc_moe>0:
|