# STLDM_official / data / loader.py
# Uploaded by sqfoo in commit 6021dd1 ("Upload 99 files", verified).
import torch.nn.functional as F
from data import dutils
from nowcasting.hko_iterator import HKOIterator
def GET_TrainLoader(meta, param, batch_size, in_len, out_len):
if meta['dataset'] == 'SEVIR':
total_seq_len = in_len + out_len
train_config = {
'data_types': ['vil'],
'layout': 'NTCHW',
'seq_len': total_seq_len,
'raw_seq_len': total_seq_len,
'end_date': dutils.SEVIR_TRAIN_TEST_SPLIT_DATE,
'start_date': None
}
test_config = {
'data_types': ['vil'],
'layout': 'NTCHW',
'seq_len': total_seq_len,
'raw_seq_len': total_seq_len,
'end_date': None,
'start_date': dutils.SEVIR_TRAIN_TEST_SPLIT_DATE
}
train_loader = dutils.SEVIRDataIterator(**train_config, batch_size=batch_size)
test_loader = dutils.SEVIRDataIterator(**test_config, batch_size=8 if batch_size > 8 else batch_size)
return train_loader, test_loader
elif meta['dataset'].startswith('HKO'):
total_seq_len = in_len + out_len
pkl_path = param['pd_path']
train_loader = HKOIterator(pd_path=pkl_path.replace('test', 'train'), sample_mode="random", seq_len=total_seq_len, stride=1)
test_loader = HKOIterator(pd_path=pkl_path, sample_mode="sequent", seq_len=total_seq_len, stride=in_len)
return train_loader, test_loader
elif meta['dataset'] == 'meteonet':
train_loader, test_loader = dutils.load_meteonet(batch_size=batch_size, val_batch_size=8 if batch_size > 8 else batch_size, train=True, **param)
return train_loader, test_loader
else:
raise Exception(f'Undefined dataset config name: {dataset_config["dataset"]}')
def GET_TestLoader(meta, param, batch_size):
    """Build the evaluation data iterator for a dataset.

    Args:
        meta: Dataset metadata; only ``meta['dataset']`` is read. Supported
            values are ``'SEVIR'``, any name starting with ``'HKO'``, and
            ``'meteonet'``.
        param: Keyword arguments forwarded to the dataset-specific loader
            (``dutils.SEVIRDataIterator``, ``HKOIterator`` or
            ``dutils.load_meteonet``).
        batch_size: Batch size passed to the SEVIR iterator and as
            ``batch_size`` to ``dutils.load_meteonet``. Not read for HKO.

    Returns:
        For SEVIR and HKO, the iterator object. For meteonet, ``iter()`` of
        the second value returned by ``dutils.load_meteonet``.

    Raises:
        ValueError: If ``meta['dataset']`` is not a supported dataset.
    """
    dataset = meta['dataset']

    if dataset == 'SEVIR':
        return dutils.SEVIRDataIterator(**param, batch_size=batch_size)

    if dataset.startswith('HKO'):
        return HKOIterator(**param)

    if dataset == 'meteonet':
        # Only the second returned loader is used; the first is discarded.
        # NOTE(review): val_batch_size is fixed at 8 here, whereas
        # GET_TrainLoader uses min(batch_size, 8) — confirm this is intended.
        _, test_loader = dutils.load_meteonet(
            batch_size=batch_size, val_batch_size=8, train=False, **param
        )
        return iter(test_loader)

    # Previously referenced an undefined name, which raised NameError instead.
    raise ValueError(f'Undefined dataset config name: {dataset}')