razmars commited on
Commit
3c7220d
·
verified ·
1 Parent(s): 29025e3

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +32 -32
modeling_super_linear.py CHANGED
@@ -451,38 +451,38 @@ class superLinear(nn.Module):
451
  # print(f"Path {path} does not exist. Skipping loading weights.")
452
 
453
 
454
- def map_to_cycle(self, freq):
455
- if "/" in freq:
456
- cycle = int(freq.split("/")[1])
457
- elif "h" in freq:
458
- cycle = 24
459
- elif "2h":
460
- cycle = 12
461
- elif "3h" in freq:
462
- cycle = 8
463
- elif "4h" in freq:
464
- cycle = 6
465
- elif "D" in freq:
466
- cycle = 7
467
- elif "DM" in freq:
468
- cycle = 30
469
- elif "W" in freq:
470
- cycle = 52
471
- elif "M" in freq:
472
- cycle = 12
473
- elif "min" in freq:
474
- cycle = 1440
475
- elif "5min" in freq:
476
- cycle = 288
477
- elif "10min" in freq:
478
- cycle = 144
479
- elif "15min" in freq:
480
- cycle = 96
481
- elif "30min" in freq:
482
- cycle = 48
483
- else:
484
- cycle = int(freq)
485
- return cycle
486
 
487
 
488
  def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None, freq=[None], get_prob=False, inf_pred_len=None):
 
451
  # print(f"Path {path} does not exist. Skipping loading weights.")
452
 
453
 
454
+ # def map_to_cycle(self, freq):
455
+ # if "/" in freq:
456
+ # cycle = int(freq.split("/")[1])
457
+ # elif "h" in freq:
458
+ # cycle = 24
459
+ # elif "2h":
460
+ # cycle = 12
461
+ # elif "3h" in freq:
462
+ # cycle = 8
463
+ # elif "4h" in freq:
464
+ # cycle = 6
465
+ # elif "D" in freq:
466
+ # cycle = 7
467
+ # elif "DM" in freq:
468
+ # cycle = 30
469
+ # elif "W" in freq:
470
+ # cycle = 52
471
+ # elif "M" in freq:
472
+ # cycle = 12
473
+ # elif "min" in freq:
474
+ # cycle = 1440
475
+ # elif "5min" in freq:
476
+ # cycle = 288
477
+ # elif "10min" in freq:
478
+ # cycle = 144
479
+ # elif "15min" in freq:
480
+ # cycle = 96
481
+ # elif "30min" in freq:
482
+ # cycle = 48
483
+ # else:
484
+ # cycle = int(freq)
485
+ # return cycle
486
 
487
 
488
  def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None, freq=[None], get_prob=False, inf_pred_len=None):