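# STLDM_official / train.py
# NOTE: the next three lines appear to be repository web-page residue rather
# than program text; they are kept as comments so the file stays valid Python.
# sqfoo's picture
# Upload 99 files
# 6021dd1 verified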
'''
'''
import os
import sys
import torch
import logging
import argparse
import numpy as np
from torch import nn
from torch.utils import tensorboard
from stldm import *
# Library Issue
from data import dutils
import utilspp as utpp
from utilspp import SequentialLR, warmup_lambda
from data.config import SEVIR_13_12, HKO7_5_20, METEONET_5_20
from data.loader import GET_TrainLoader
from data.dutils import resize
def _fetch_batch(loader, dataset_type, args):
    """Draw one (input, target) pair from ``loader`` for the given dataset family.

    The dataset family is selected by the prefix of ``dataset_type``:
    ``sevir*`` and ``hko*`` loaders are sampled via ``loader.sample``, while
    ``meteo*`` loaders are plain iterators yielding ``(x, y)``.

    Raises:
        ValueError: if ``dataset_type`` matches none of the known prefixes.
        StopIteration: if a ``meteo*`` iterator is exhausted.
    """
    if dataset_type.startswith('sevir'):
        frames = loader.sample(batch_size=args.batch_size)['vil']
        # Split along axis 1 into the first seq_len frames and the remainder.
        return frames[:, :args.seq_len], frames[:, args.seq_len:]
    if dataset_type.startswith('meteo'):
        x, y = next(loader)
        return x, y
    if dataset_type.startswith('hko'):
        x_seq, x_mask, dt_clip, _ = loader.sample(batch_size=args.batch_size)
        return utpp.hko7_preprocess(x_seq, x_mask, dt_clip, args)
    raise ValueError(f'Undefined dataset config name: {dataset_type}')


def _prepare_batch(x, y, device, img_size, pre):
    """Move a batch to ``device``, resize it to ``img_size`` if needed, then apply ``pre``."""
    x, y = x.to(device), y.to(device)
    if x.shape[-1] != img_size:
        x, y = resize(x, img_size), resize(y, img_size)
    if pre is not None:
        x, y = pre(x), pre(y)
    return x, y


def _load_weights(module, ckpt_path, device):
    """Load the state dict stored at ``ckpt_path`` into ``module``, mapped onto ``device``."""
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    module.load_state_dict(torch.load(ckpt_path, map_location=torch.device(device)))


def main():
    """Parse the CLI arguments, build the model and run the train/validate loop."""
    parser = argparse.ArgumentParser()

    # dataset related
    parser.add_argument('-d', '--dataset', type=str, default='', help='Dataset config to be trained')
    parser.add_argument('--seq_len', type=int, default=10, help='The input sequence length')
    parser.add_argument('--out_len', type=int, default=10, help='The output (prediction) sequence length')

    # model related
    parser.add_argument('-f', dest='checkpt', type=str, default='', help='model checkpoint to be loaded from (Empty = not loading)')
    parser.add_argument('-o', '--output', type=str, default='ckpts', help='The output directory')
    parser.add_argument('-m', '--model', type=str, default='', help='The global configuration to be used (The var name in config.py)')
    parser.add_argument('--type', type=str, default='3D', help='Determine which kind of model to use, 2D or 3D')

    # Training Components Related
    # NOTE: the three flags below use store_false, i.e. they default to True
    # and passing the flag switches the value to False.
    parser.add_argument('--ae_ckpt', type=str, default=None, help='Pre-trained AE checkpoint, freeze it during training')
    parser.add_argument('--ae_eval', action='store_false', help='Pass this flag to freeze the AE (requires_grad=False after loading --ae_ckpt, kept in eval mode); trainable by default')
    parser.add_argument('--back_ckpt', type=str, default=None, help='Pre-trained backbone checkpoint, freeze it during traing')
    parser.add_argument('--back_eval', action='store_false', help='Pass this flag to freeze the backbone (requires_grad=False after loading --back_ckpt, kept in eval mode); trainable by default')
    # NOTE(review): --set_mu_to_0 is parsed but never read in this script.
    parser.add_argument('--set_mu_to_0', action='store_false', help='Set the constraint loss to 0')

    # hyperparams
    parser.add_argument('--lr', type=float, default=0.0001, help='The initial learning rate')
    parser.add_argument('-e', '--epoch', type=int, default=50, help='The number of epochs to run')
    parser.add_argument('-s', "--training_steps", type=int, default=200000, help="number of training steps")
    parser.add_argument('-b', '--batch_size', type=int, default=4, help='The batch size')
    parser.add_argument('--micro_batch', type=int, default=1, help='Micro Batch size')

    # logging related
    parser.add_argument('--print_every', type=int, default=100, help='The number of steps to log the training loss')
    parser.add_argument('--validate_every', type=int, default=5, help='The number of epochs between two validations')
    parser.add_argument('--v_steps', type=int, default=50, help='Validation steps')
    parser.add_argument('--remarks', type=str, default='', help='This section will affect the model name to be saved')
    parser.add_argument('--save_every_epoch', action='store_true', help='Save ckpt for every validation epochs, otherwise save the best')

    args = parser.parse_args()

    # args validation (explicit errors instead of asserts, which vanish under -O)
    if args.model == '':
        parser.error('You must specify the model config using -m/--model!')
    if args.model not in globals():
        parser.error(f'Unknown model config: {args.model!r}')
    if args.dataset not in globals():
        parser.error(f'Unknown dataset config: {args.dataset!r} (set it with -d/--dataset)')
    if args.type not in ('2D', '3D'):
        parser.error('Please specify either 2D or 3D')
    if args.micro_batch < 1:
        parser.error('--micro_batch must be at least 1')
    if args.ae_ckpt is not None and args.ae_ckpt == args.back_ckpt:
        parser.error('Please specify from End to End (set both to None), LDM only (set args.back_ckpt), or LDM + Meta (set args.ae_ckpt)')

    # read the dataset / model configs by their variable names
    dataset_config = globals()[args.dataset]
    dataset_type = dataset_config['savedir']
    dataset_param, dataset_meta = dataset_config['param'], dataset_config['meta']

    model_config = globals()[args.model]
    model_type = model_config['model']
    save_path = utpp.build_model_path(args.output, dataset_type, model_type, timestamp=True) + args.remarks
    os.makedirs(save_path, exist_ok=True)
    img_size = model_config['vp_param']['shape_in'][-1]

    # prepare dataloader
    if dataset_type.startswith('meteo'):
        train_iter, validate_iter = GET_TrainLoader(dataset_meta, dataset_param, args.batch_size, args.seq_len, args.out_len)
        # Fresh iterators are created at the start of every epoch / validation.
        train_loader = valid_loader = None
    else:
        train_iter = validate_iter = None
        train_loader, valid_loader = GET_TrainLoader(dataset_meta, dataset_param, args.batch_size, args.seq_len, args.out_len)

    if dataset_type.startswith('sevir'):
        steps_per_epoch = len(train_loader)
        epochs = args.epoch
    elif dataset_type.startswith('hko'):
        steps_per_epoch = 2500
        epochs = args.training_steps // steps_per_epoch
    elif dataset_type.startswith('meteo'):
        steps_per_epoch = len(train_iter)
        epochs = args.training_steps // steps_per_epoch
    else:
        raise ValueError(f'Undefined dataset config name: {dataset_type}')

    total_training_steps = epochs * steps_per_epoch

    # define the model
    model_param = model_config['param']
    model_pathname = utpp.build_model_name(model_type, model_param)
    args.step = total_training_steps

    # prepare logger
    logfile_name = os.path.join(save_path, '_log.log')
    logging.basicConfig(level=logging.INFO, handlers=[logging.FileHandler(logfile_name), logging.StreamHandler()], format='%(message)s')
    logging.info(f'args: {args}')
    logging.info('The resulting model will be saved as: {}'.format(os.path.join(save_path, model_pathname)))
    logging.info(f'Training Epochs: {epochs} and Total Training Steps: {total_training_steps}')

    # Writing logs for tensorboard
    log_dir = os.path.join(save_path, 'logs')
    writer = tensorboard.SummaryWriter(log_dir)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.device = device

    model = n2n_setup[args.type](model_config).to(device)

    # Checkpoint loading is best-effort: a failure is logged with its traceback
    # and training continues with the current weights.
    # Load Pre-trained AutoEncoder
    if args.ae_ckpt is not None:
        try:
            _load_weights(model.backbone.vae, args.ae_ckpt, device)
            model.backbone.vae.requires_grad_(args.ae_eval)
            logging.info(f'Load pre-trained AE well, Set require grads to be {args.ae_eval}')
        except Exception:
            logging.exception('Failed to load pre-trained AE')

    if args.back_ckpt is not None:
        try:
            _load_weights(model.backbone, args.back_ckpt, device)
            model.backbone.requires_grad_(args.back_eval)
            logging.info(f'Load pre-trained backbone well, Set require grads to be {args.back_eval}')
        except Exception:
            logging.exception('Failed to load pre-trained backbone')

    if args.checkpt != '':
        try:
            _load_weights(model, args.checkpt, device)
        except Exception:
            logging.exception("Loading weights failed")
        # NOTE(review): the original indentation was lost; these two log lines
        # are assumed to belong to the checkpoint branch - confirm.
        logging.info(f'Set require grads of VAE to be {args.ae_eval}')
        logging.info(f'Set require grads of backbone to be {args.back_eval}')

    # The original methods in the NeurIPS 2015 paper
    # (comment kept from the original script; what it refers to is unclear)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=args.lr)

    warmup_iter = 2000
    warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lambda(warmup_steps=warmup_iter, min_lr_ratio=0.1))
    # Clamp T_max so that very short runs do not produce a non-positive period.
    cosine_steps = max(1, (total_training_steps - warmup_iter) // args.micro_batch)
    cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_steps, eta_min=1e-6)
    scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[warmup_iter])

    best_val_loss = 1e10
    total_step = 0
    pre, post = model_config['pre'], model_config['post']

    optimizer.zero_grad()
    for epoch in range(1, epochs + 1):
        if dataset_type.startswith('sevir'):
            train_loader.reset()
        elif dataset_type.startswith('meteo'):
            train_loader = iter(train_iter)

        for step in range(steps_per_epoch):
            total_step += 1
            model.train()
            # Frozen sub-modules stay in eval mode even while the rest trains.
            if not args.ae_eval:
                model.backbone.vae.eval()
            if not args.back_eval:
                model.backbone.eval()

            x, y = _fetch_batch(train_loader, dataset_type, args)
            x, y = _prepare_batch(x, y, device, img_size, pre)

            recon_loss, mu_loss, diff_loss, prior_loss = model.compute_loss(x, y)
            loss = recon_loss + mu_loss + diff_loss + prior_loss
            # Gradients accumulate over `micro_batch` consecutive steps; scaling
            # keeps the accumulated gradient an average rather than a sum.
            (loss / args.micro_batch).backward()

            if total_step % args.micro_batch == 0:
                if args.back_ckpt is None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                # Clear gradients only after they were consumed, so earlier
                # micro-batches are not discarded before the update.
                optimizer.zero_grad()

            # -----------------------------------------------------
            # On Step End
            # -----------------------------------------------------
            # terminal log every {print_every} steps.
            if total_step == 1 or total_step % args.print_every == 0:
                logging.info(f'[Epoch {epoch}][Step {step}] recon_loss: {float(recon_loss):.4}, mu_loss: {float(mu_loss):.4}, diff_loss: {float(diff_loss):.4}')
                writer.add_scalar('Training recon_loss', float(recon_loss), global_step=total_step)
                writer.add_scalar('Training mu_loss', float(mu_loss), global_step=total_step)
                writer.add_scalar('Training diff_loss', float(diff_loss), global_step=total_step)
                writer.add_scalar('LR', optimizer.param_groups[0]['lr'], global_step=total_step)

        # validate every {validate_every} epochs
        if epoch == 1 or epoch % args.validate_every == 0:
            rand_batch = np.random.randint(min(args.batch_size, 8))
            if dataset_type.startswith(('sevir', 'hko')):
                valid_loader.reset()
            elif dataset_type.startswith('meteo'):
                valid_loader = iter(validate_iter)

            model.eval()
            acc_ae, acc_diff, acc_mu = 0, 0, 0
            val_steps_done = 0
            with torch.no_grad():
                for _ in range(args.v_steps):
                    try:
                        x, y = _fetch_batch(valid_loader, dataset_type, args)
                    except StopIteration:
                        # The validation iterator holds fewer than v_steps batches.
                        break
                    x, y = _prepare_batch(x, y, device, img_size, pre)
                    ae_loss, mu_loss, diff_loss, _ = model.compute_loss(x, y, validate=True)
                    acc_ae += ae_loss
                    acc_diff += diff_loss
                    acc_mu += mu_loss
                    val_steps_done += 1

                    if post is not None:
                        x = post(x)
                        y = post(y)

            if val_steps_done == 0:
                logging.warning(f'[Epoch {epoch}][Validation] no validation batch available, skipped')
                continue

            # Average over the batches that were actually evaluated.
            acc_ae, acc_mu, acc_diff = acc_ae / val_steps_done, acc_mu / val_steps_done, acc_diff / val_steps_done
            writer.add_scalar('Val AE loss', float(acc_ae), global_step=total_step)
            writer.add_scalar('Val VP loss', float(acc_mu), global_step=total_step)
            writer.add_scalar('Val Diff loss', float(acc_diff), global_step=total_step)
            logging.info(f'[Epoch {epoch}][Validation] AE_loss:{float(acc_ae):.4}, VP_loss:{float(acc_mu):.4}, Diff_loss:{float(acc_diff):.4}')
            val_loss = float((acc_mu + acc_diff) / 2)

            # Visualise one sample of the last validation batch.
            with torch.no_grad():
                if pre is not None:
                    x = pre(x)
                y_pred, mu = model(x, include_mu=True)
                if post is not None:
                    y_pred = post(y_pred)
                    mu = post(mu)
                    x = post(x)

            # Guard against a last batch smaller than the drawn index.
            idx = min(rand_batch, x.shape[0] - 1)
            out_x, out_y = x[idx].unsqueeze(0), y[idx].unsqueeze(0)
            mu_pred, out_y_pred = mu[idx].unsqueeze(0), y_pred[idx].unsqueeze(0)
            utpp.torch_visualize({'input': out_x, 'ground truth': out_y, 'mu_pred': mu_pred, 'predicted': out_y_pred}, savedir=os.path.join(save_path, f'temp-{total_step}.png'), vmin=0, vmax=1)

            if args.save_every_epoch:
                torch.save(model.state_dict(), os.path.join(save_path, f'{model_pathname}_epoch={epoch}.pt'))
            elif val_loss < best_val_loss:
                torch.save(model.state_dict(), os.path.join(save_path, f'{model_pathname}_best.pt'))
                best_val_loss = val_loss

    torch.save(model.state_dict(), os.path.join(save_path, f'{model_pathname}_final.pt'))
    # Flush pending TensorBoard events before exiting.
    writer.close()


if __name__ == '__main__':
    main()