import torch.nn.functional as F

from data import dutils
from nowcasting.hko_iterator import HKOIterator

def GET_TrainLoader(meta, param, batch_size, in_len, out_len):
    """Build (train, test) data iterators for the dataset named in meta['dataset']."""
    if meta['dataset'] == 'SEVIR':
        total_seq_len = in_len + out_len
        # Training split: ends at the SEVIR train/test split date.
        train_config = {
            'data_types': ['vil'],
            'layout': 'NTCHW',
            'seq_len': total_seq_len,
            'raw_seq_len': total_seq_len,
            'end_date': dutils.SEVIR_TRAIN_TEST_SPLIT_DATE,
            'start_date': None
        }
        # Test split: starts at the SEVIR train/test split date.
        test_config = {
            'data_types': ['vil'],
            'layout': 'NTCHW',
            'seq_len': total_seq_len,
            'raw_seq_len': total_seq_len,
            'end_date': None,
            'start_date': dutils.SEVIR_TRAIN_TEST_SPLIT_DATE
        }
        train_loader = dutils.SEVIRDataIterator(**train_config, batch_size=batch_size)
        # Cap the test-loader batch size at 8 to keep evaluation memory bounded.
        test_loader = dutils.SEVIRDataIterator(**test_config, batch_size=min(batch_size, 8))
        return train_loader, test_loader
    elif meta['dataset'].startswith('HKO'):
        total_seq_len = in_len + out_len
        # param['pd_path'] points at the test-set pickle; the training pickle path is
        # derived from it by swapping 'test' for 'train'.
        pkl_path = param['pd_path']
        train_loader = HKOIterator(pd_path=pkl_path.replace('test', 'train'), sample_mode="random",
                                   seq_len=total_seq_len, stride=1)
        test_loader = HKOIterator(pd_path=pkl_path, sample_mode="sequent",
                                  seq_len=total_seq_len, stride=in_len)
        return train_loader, test_loader
    elif meta['dataset'] == 'meteonet':
        train_loader, test_loader = dutils.load_meteonet(batch_size=batch_size,
                                                         val_batch_size=min(batch_size, 8),
                                                         train=True, **param)
        return train_loader, test_loader
    else:
        raise Exception(f'Undefined dataset config name: {meta["dataset"]}')

def GET_TestLoader(meta, param, batch_size):
    """Build a test-only data iterator for the dataset named in meta['dataset']."""
    if meta['dataset'] == 'SEVIR':
        return dutils.SEVIRDataIterator(**param, batch_size=batch_size)
    elif meta['dataset'].startswith('HKO'):
        return HKOIterator(**param)
    elif meta['dataset'] == 'meteonet':
        _, test_iter = dutils.load_meteonet(batch_size=batch_size, val_batch_size=8, train=False, **param)
        return iter(test_iter)
    else:
        raise Exception(f'Undefined dataset config name: {meta["dataset"]}')
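

if __name__ == '__main__':
    # Minimal smoke-test sketch (illustrative only): the exact keys expected in
    # `meta` and `param` come from the experiment configs elsewhere in the repo,
    # and the 13-in / 12-out SEVIR sequence lengths below are assumptions.
    meta = {'dataset': 'SEVIR'}
    param = {}  # the SEVIR branch builds its own configs, so no extra params are passed here
    train_loader, test_loader = GET_TrainLoader(meta, param, batch_size=4, in_len=13, out_len=12)
    print(type(train_loader).__name__, type(test_loader).__name__)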