'''
Training script for the spatio-temporal latent diffusion model (STLDM).

Trains on the SEVIR, HKO-7, or MeteoNet radar datasets, with optional loading
(and freezing) of pre-trained autoencoder (AE) and backbone checkpoints,
gradient accumulation, warmup + cosine LR scheduling, periodic validation,
and TensorBoard logging.
'''

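# Example invocation (the script name and the model config name below are
# illustrative; use the dataset/model config variable names defined in
# data/config.py and the model config module):
#   python train.py -d SEVIR_13_12 -m <MODEL_CONFIG> --type 3D -b 4 -e 50
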
import os
import logging
import argparse

import numpy as np
import torch
from torch.utils import tensorboard

# The wildcard import supplies the model configs and the n2n_setup registry
# that are looked up by name via globals() below.
from stldm import *

import utilspp as utpp
from utilspp import SequentialLR, warmup_lambda
from data.config import SEVIR_13_12, HKO7_5_20, METEONET_5_20
from data.loader import GET_TrainLoader
from data.dutils import resize

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('-d', '--dataset', type=str, default='', help='Dataset config to train on (a variable name from data.config)')
    parser.add_argument('--seq_len', type=int, default=10, help='The input sequence length')
    parser.add_argument('--out_len', type=int, default=10, help='The output (prediction) sequence length')

    parser.add_argument('-f', dest='checkpt', type=str, default='', help='Model checkpoint to resume from (empty = do not load)')
    parser.add_argument('-o', '--output', type=str, default='ckpts', help='The output directory')
    parser.add_argument('-m', '--model', type=str, default='', help='The model configuration to use (a variable name from config.py)')
    parser.add_argument('--type', type=str, default='3D', help='Which kind of model to use: 2D or 3D')

    parser.add_argument('--ae_ckpt', type=str, default=None, help='Pre-trained AE checkpoint to load')
    parser.add_argument('--ae_eval', action='store_false', help='Freeze the AE during training (sets its requires_grad to False)')
    parser.add_argument('--back_ckpt', type=str, default=None, help='Pre-trained backbone checkpoint to load')
    parser.add_argument('--back_eval', action='store_false', help='Freeze the backbone during training (sets its requires_grad to False)')
    parser.add_argument('--set_mu_to_0', action='store_false', help='Disable setting the constraint (mu) loss to 0')

    parser.add_argument('--lr', type=float, default=0.0001, help='The initial learning rate')
    parser.add_argument('-e', '--epoch', type=int, default=50, help='The number of epochs to run')
    parser.add_argument('-s', '--training_steps', type=int, default=200000, help='The number of training steps')
    parser.add_argument('-b', '--batch_size', type=int, default=4, help='The batch size')
    parser.add_argument('--micro_batch', type=int, default=1, help='Number of gradient-accumulation steps per optimizer step')

    parser.add_argument('--print_every', type=int, default=100, help='Log the training loss every this many steps')
    parser.add_argument('--validate_every', type=int, default=5, help='Run validation every this many epochs')
    parser.add_argument('--v_steps', type=int, default=50, help='Number of validation steps')
    parser.add_argument('--remarks', type=str, default='', help='Suffix appended to the save directory name')
    parser.add_argument('--save_every_epoch', action='store_true', help='Save a checkpoint at every validation epoch instead of only the best one')

    args = parser.parse_args()

    assert args.model != '', 'You must specify the model config using -m/--model!'
    assert args.dataset != '', 'You must specify the dataset config using -d/--dataset!'

    # Resolve the dataset and model configs by name from the imported globals.
    dataset_config = globals()[args.dataset]
    dataset_type = dataset_config['savedir']
    dataset_param, dataset_meta = dataset_config['param'], dataset_config['meta']

    model_config = globals()[args.model]
    model_type = model_config['model']
    save_path = utpp.build_model_path(args.output, dataset_type, model_type, timestamp=True) + args.remarks
    os.makedirs(save_path, exist_ok=True)
    img_size = model_config['vp_param']['shape_in'][-1]

    total_seq_len = args.seq_len + args.out_len

    # MeteoNet exposes epoch-style iterables, so keep the underlying loaders
    # around to re-create fresh iterators each epoch.
    if dataset_type.startswith('meteo'):
        train_iter, validate_iter = GET_TrainLoader(dataset_meta, dataset_param, args.batch_size, args.seq_len, args.out_len)
        train_loader, valid_loader = iter(train_iter), iter(validate_iter)
    else:
        train_loader, valid_loader = GET_TrainLoader(dataset_meta, dataset_param, args.batch_size, args.seq_len, args.out_len)

    # SEVIR trains for a fixed number of epochs; HKO-7 and MeteoNet train for
    # a fixed number of steps, converted here into whole epochs.
    if dataset_type.startswith('sevir'):
        steps_per_epoch = len(train_loader)
        epochs = args.epoch
    elif dataset_type.startswith('hko'):
        steps_per_epoch = 2500
        epochs = args.training_steps // steps_per_epoch
    elif dataset_type.startswith('meteo'):
        steps_per_epoch = len(train_loader)
        epochs = args.training_steps // steps_per_epoch
    else:
        raise ValueError(f'Undefined dataset config name: {dataset_type}')
    total_training_steps = epochs * steps_per_epoch

    model_param = model_config['param']
    model_pathname = utpp.build_model_name(model_type, model_param)
    args.step = total_training_steps

    # Log both to a file inside the save directory and to stdout.
    logfile_name = os.path.join(save_path, '_log.log')
    logging.basicConfig(level=logging.INFO, handlers=[logging.FileHandler(logfile_name), logging.StreamHandler()], format='%(message)s')
    logging.info(f'args: {args}')
    logging.info('The resulting model will be saved as: {}'.format(os.path.join(save_path, model_pathname)))
    logging.info(f'Training Epochs: {epochs} and Total Training Steps: {total_training_steps}')

    log_dir = os.path.join(save_path, 'logs')
    writer = tensorboard.SummaryWriter(log_dir)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.device = device

    assert args.type in ['2D', '3D'], 'Please specify either 2D or 3D'
    model = n2n_setup[args.type](model_config).to(device)

    assert args.ae_ckpt != args.back_ckpt or (args.ae_ckpt is None and args.back_ckpt is None), \
        'Please choose end-to-end (leave both unset), LDM only (set --back_ckpt), or LDM + Meta (set --ae_ckpt)'

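    # Three training modes are supported:
    #   * end-to-end: no pre-trained checkpoints, everything is trained;
    #   * LDM only: load a pre-trained backbone (--back_ckpt);
    #   * LDM + Meta: load a pre-trained AE (--ae_ckpt).
    # Loaded components stay trainable unless the matching --ae_eval /
    # --back_eval flag freezes them.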
    if args.ae_ckpt is not None:
        try:
            model.backbone.vae.load_state_dict(torch.load(args.ae_ckpt, map_location=torch.device(device)))
            model.backbone.vae.requires_grad_(args.ae_eval)
            logging.info(f'Loaded pre-trained AE; set requires_grad to {args.ae_eval}')
        except Exception:
            logging.exception('Failed to load pre-trained AE')

    if args.back_ckpt is not None:
        try:
            model.backbone.load_state_dict(torch.load(args.back_ckpt, map_location=torch.device(device)))
            model.backbone.requires_grad_(args.back_eval)
            logging.info(f'Loaded pre-trained backbone; set requires_grad to {args.back_eval}')
        except Exception:
            logging.exception('Failed to load pre-trained backbone')

    if args.checkpt != '':
        try:
            model.load_state_dict(torch.load(args.checkpt, map_location=torch.device(device)))
        except Exception:
            logging.exception('Loading weights failed')

    logging.info(f'requires_grad of the VAE is set to {args.ae_eval}')
    logging.info(f'requires_grad of the backbone is set to {args.back_eval}')

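    # Optimize only the trainable parameters with AdamW, using a linear
    # warmup followed by cosine annealing. The scheduler is stepped once per
    # optimizer step, i.e. once per micro_batch accumulated backward passes.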
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=args.lr)

    warmup_iter = 2000
    warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lambda(warmup_steps=warmup_iter, min_lr_ratio=0.1))
    cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_training_steps // args.micro_batch - warmup_iter, eta_min=1e-6)
    scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[warmup_iter])

    best_val_loss = 1e10
    total_step = 0
    # Gradients are zeroed after each optimizer step below, so that they can
    # accumulate across micro batches in between.
    optimizer.zero_grad()
    for epoch in range(1, epochs + 1):
        # Re-prime the data source at the start of every epoch.
        if dataset_type.startswith('sevir'):
            train_loader.reset()
        elif dataset_type.startswith('meteo'):
            train_loader = iter(train_iter)

        for step in range(steps_per_epoch):
            total_step += 1
            model.train()

            # Keep frozen components in eval mode, since model.train() above
            # switches the whole model back to train mode.
            if not args.ae_eval:
                model.backbone.vae.eval()
            if not args.back_eval:
                model.backbone.eval()

            # Draw one training batch; each dataset has its own interface.
            if dataset_type.startswith('sevir'):
                data = train_loader.sample(batch_size=args.batch_size)
                x, y = data['vil'][:, :args.seq_len], data['vil'][:, args.seq_len:]
            elif dataset_type.startswith('meteo'):
                x, y = next(train_loader)
            elif dataset_type.startswith('hko'):
                x_seq, x_mask, dt_clip, _ = train_loader.sample(batch_size=args.batch_size)
                x, y = utpp.hko7_preprocess(x_seq, x_mask, dt_clip, args)

            x, y = x.to(device), y.to(device)
            if x.shape[-1] != img_size:
                x, y = resize(x, img_size), resize(y, img_size)
            if model_config['pre'] is not None:
                x = model_config['pre'](x)
                y = model_config['pre'](y)

            recon_loss, mu_loss, diff_loss, prior_loss = model.compute_loss(x, y)
            loss = recon_loss + mu_loss + diff_loss + prior_loss
            # Scale the loss so that gradients accumulated over micro_batch
            # backward passes average rather than sum.
            (loss / args.micro_batch).backward()

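            # Gradient accumulation: only step the optimizer (and scheduler)
            # once every micro_batch backward passes.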
            if total_step % args.micro_batch == 0:
                # Clip gradients only when the backbone itself is being trained.
                if args.back_ckpt is None:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if total_step == 1 or total_step % args.print_every == 0:
                logging.info(f'[Epoch {epoch}][Step {step}] recon_loss: {float(recon_loss):.4}, mu_loss: {float(mu_loss):.4}, diff_loss: {float(diff_loss):.4}')
                writer.add_scalar('Training recon_loss', float(recon_loss), global_step=total_step)
                writer.add_scalar('Training mu_loss', float(mu_loss), global_step=total_step)
                writer.add_scalar('Training diff_loss', float(diff_loss), global_step=total_step)
                writer.add_scalar('LR', optimizer.param_groups[0]['lr'], global_step=total_step)

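        # Periodic validation: average the losses over v_steps batches and
        # dump a qualitative prediction for one random sample.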
        if epoch == 1 or epoch % args.validate_every == 0:
            rand_batch = np.random.randint(min(args.batch_size, 8))
            if dataset_type.startswith('sevir') or dataset_type.startswith('hko'):
                valid_loader.reset()
            elif dataset_type.startswith('meteo'):
                valid_loader = iter(validate_iter)

            acc_ae, acc_diff, acc_mu = 0, 0, 0
            for v_step in range(args.v_steps):
                model.eval()

                if dataset_type.startswith('sevir'):
                    data = valid_loader.sample(batch_size=args.batch_size)
                    x, y = data['vil'][:, :args.seq_len], data['vil'][:, args.seq_len:]
                elif dataset_type.startswith('meteo'):
                    x, y = next(valid_loader)
                elif dataset_type.startswith('hko'):
                    x_seq, x_mask, dt_clip, _ = valid_loader.sample(batch_size=args.batch_size)
                    x, y = utpp.hko7_preprocess(x_seq, x_mask, dt_clip, args)
                x, y = x.to(device), y.to(device)

                with torch.no_grad():
                    if x.shape[-1] != img_size:
                        x, y = resize(x, img_size), resize(y, img_size)
                    if model_config['pre'] is not None:
                        x = model_config['pre'](x)
                        y = model_config['pre'](y)
                    ae_loss, mu_loss, diff_loss, _ = model.compute_loss(x, y, validate=True)
                    acc_ae += float(ae_loss)
                    acc_diff += float(diff_loss)
                    acc_mu += float(mu_loss)
                    # Undo the pre-transform so the last batch can be
                    # visualized (and re-transformed) below.
                    if model_config['post'] is not None:
                        x = model_config['post'](x)
                        y = model_config['post'](y)

            acc_ae, acc_mu, acc_diff = acc_ae / args.v_steps, acc_mu / args.v_steps, acc_diff / args.v_steps
            writer.add_scalar('Val AE loss', float(acc_ae), global_step=total_step)
            writer.add_scalar('Val VP loss', float(acc_mu), global_step=total_step)
            writer.add_scalar('Val Diff loss', float(acc_diff), global_step=total_step)
            logging.info(f'[Epoch {epoch}][Validation] AE_loss: {float(acc_ae):.4}, VP_loss: {float(acc_mu):.4}, Diff_loss: {float(acc_diff):.4}')
            val_loss = (acc_mu + acc_diff) / 2

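            # Qualitative check: run a full prediction on the last validation
            # batch and save input / ground truth / mu / prediction side by side.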
            with torch.no_grad():
                if model_config['pre'] is not None:
                    x = model_config['pre'](x)
                y_pred, mu = model(x, include_mu=True)
                if model_config['post'] is not None:
                    y_pred = model_config['post'](y_pred)
                    mu = model_config['post'](mu)
                    x = model_config['post'](x)

            out_x, out_y, mu_pred, out_y_pred = x[rand_batch].unsqueeze(0), y[rand_batch].unsqueeze(0), mu[rand_batch].unsqueeze(0), y_pred[rand_batch].unsqueeze(0)
            utpp.torch_visualize({'input': out_x, 'ground truth': out_y, 'mu_pred': mu_pred, 'predicted': out_y_pred}, savedir=os.path.join(save_path, f'temp-{total_step}.png'), vmin=0, vmax=1)

            if args.save_every_epoch:
                torch.save(model.state_dict(), os.path.join(save_path, f'{model_pathname}_epoch={epoch}.pt'))
            elif val_loss < best_val_loss:
                torch.save(model.state_dict(), os.path.join(save_path, f'{model_pathname}_best.pt'))
                best_val_loss = val_loss

    torch.save(model.state_dict(), os.path.join(save_path, f'{model_pathname}_final.pt'))