# Hugging Face page header (not part of the module):
#   uploader: yxc97
#   commit: "Upload folder using huggingface_hub" (62a2f1c, verified)
import os
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch_geometric.nn import to_hetero
import torch.optim as optim
from torch.utils.data.distributed import DistributedSampler
from loguru import logger
import numpy as np
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR, ExponentialLR, ReduceLROnPlateau
import torch.distributed as dist
from src.utils import seed_worker
from torch_geometric.loader import DataLoader
from torch.utils.data import ConcatDataset
from torch.cuda.amp import autocast
from .utils import count_parameters, AverageMeter, AVGMeter, Reporter, Timer
# torch.autograd.set_detect_anomaly(True)
class Oven(object):
def __init__(self, cfg):
    """Keep a reference to the run configuration and its GPU count.

    Args:
        cfg: Configuration object exposing ``get``; the optional
            ``'ngpus'`` entry is read here.
    """
    self.cfg = cfg
    # A configuration without an 'ngpus' entry is treated as one GPU.
    gpu_count = cfg.get('ngpus', 1)
    self.ngpus = gpu_count
def _init_training_wt_checkpoint(self, filepath_ckp):
    """Restore model weights and best-so-far bookkeeping from a checkpoint.

    Args:
        filepath_ckp: Path of the checkpoint file to resume from.

    Returns:
        If ``filepath_ckp`` does not exist, the 3-tuple ``(inf, -1, 0)``
        (best performance, best epoch, epoch). Otherwise the 4-tuple
        ``(best_performance, best_epoch, epoch, local_best_metrics)``
        read from the checkpoint, after its ``'model_state'`` entry has
        been loaded into ``self.model``.
    """
    if not os.path.exists(filepath_ckp):
        # ``np.Infinity`` was removed in NumPy 2.0; ``np.inf`` is the same
        # value and exists in every NumPy release.
        # NOTE(review): this branch yields three values while the resume
        # branch yields four -- confirm that callers handle both shapes.
        return np.inf, -1, 0
    checkpoint_resum = torch.load(filepath_ckp)
    self.model.load_state_dict(checkpoint_resum['model_state'])
    epoch = checkpoint_resum['epoch']
    previous_best = checkpoint_resum['best_performance']
    previous_best_epoch = checkpoint_resum["best_epoch"]
    previous_best_metrics = checkpoint_resum["local_best_metrics"]
    return previous_best, previous_best_epoch, epoch, previous_best_metrics
def _init_optim(self):
    """Build the optimizer and the optional LR scheduler from ``self.cfg``.

    Reads ``self.cfg['train']`` (``learning_rate``, ``epochs`` and the
    optional ``optimizer_type``, ``weight_decay``, ``momentum``) and
    ``self.cfg['scheduler']`` (``type`` plus ``eta_min`` or ``gamma``
    depending on the type).

    Returns:
        ``(optimizer, scheduler)``. ``optimizer`` is ``Adam`` when
        ``optimizer_type`` equals "adam" (case-insensitive, also the
        default) and ``SGD`` for any other value. ``scheduler`` is ``None``
        when the scheduler type is not one of 'Cosine', 'Exponential' or
        'ReduceLROnPlateau', i.e. the learning rate stays fixed.
    """
    train_cfg = self.cfg['train']
    # Numeric hyper-parameters are cast once so both optimizers accept
    # values given as strings (the Adam branch already cast the learning rate).
    learning_rate = float(train_cfg['learning_rate'])
    weight_decay = float(train_cfg.get("weight_decay", 1e-5))

    # Exact comparison: a substring test would also match "", "a", "dam", ...
    if train_cfg.get("optimizer_type", "Adam").lower() == "adam":
        optimizer = optim.Adam(self.model.parameters(),
                               lr=learning_rate,
                               weight_decay=weight_decay)
    else:  # SGD by default
        optimizer = optim.SGD(self.model.parameters(),
                              lr=learning_rate,
                              momentum=float(train_cfg.get("momentum", 0.9)),
                              weight_decay=weight_decay)

    scheduler_cfg = self.cfg['scheduler']
    scheduler_type = scheduler_cfg['type']
    if scheduler_type == 'Cosine':
        scheduler = CosineAnnealingLR(optimizer,
                                      T_max=train_cfg['epochs'],
                                      eta_min=float(scheduler_cfg['eta_min']))
    elif scheduler_type == 'Exponential':
        # ``last_epoch`` is left at its default; the deprecated ``verbose``
        # keyword is no longer passed.
        scheduler = ExponentialLR(optimizer, gamma=float(scheduler_cfg['gamma']))
    elif scheduler_type == 'ReduceLROnPlateau':
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=5, min_lr=1e-5)
    else:  # otherwise: fixed learning rate
        scheduler = None
    return optimizer, scheduler
def _init_data(self):
    """Create the training and validation data loaders.

    Datasets are built through ``self.get_dataset`` from
    ``self.cfg['data']['train']`` and ``self.cfg['data']['val']``. When
    ``self.cfg['distributed']`` is truthy each loader is driven by a
    ``DistributedSampler``; otherwise the loader's own ``shuffle`` flag is
    used. Both loaders drop the last incomplete batch.

    Returns:
        ``(train_loader, val_loader)``.
    """
    data_cfg = self.cfg['data']
    train_dataset = self.get_dataset(**data_cfg['train'])
    val_dataset = self.get_dataset(**data_cfg['val'])
    use_sampler = bool(self.cfg['distributed'])

    def build_loader(dataset, batch_size, shuffle):
        # Ordering comes from a DistributedSampler in distributed runs and
        # from the loader's shuffle flag in single-process runs.
        if use_sampler:
            ordering = {'sampler': DistributedSampler(dataset, shuffle=shuffle)}
        else:
            ordering = {'shuffle': shuffle}
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=data_cfg['num_workers'],
            drop_last=True,
            worker_init_fn=seed_worker,
            **ordering
        )

    train_loader = build_loader(train_dataset, data_cfg['batch_size'], True)
    # Validation falls back to the training batch size when no dedicated
    # "batch_size_test" entry is configured.
    val_loader = build_loader(
        val_dataset,
        data_cfg.get("batch_size_test", data_cfg['batch_size']),
        False,
    )
    return train_loader, val_loader
def get_dataset(self, dataset_type, **kwargs):
    """Instantiate the dataset named by ``dataset_type``.

    Args:
        dataset_type: Name of the dataset class; only
            ``'PowerFlowDataset'`` is handled here.
        **kwargs: Constructor options. For ``'PowerFlowDataset'`` the keys
            ``data_root``, ``split_txt``, ``pq_len``, ``pv_len``,
            ``slack_len`` and ``mask_num`` are required.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: If ``dataset_type`` is not supported (previously this
            returned ``None`` silently).
        KeyError: If a required keyword option is missing.
    """
    if dataset_type == 'PowerFlowDataset':
        # Imported lazily so the dataset module is only needed when used.
        from src.dataset.powerflow_dataset import PowerFlowDataset
        return PowerFlowDataset(
            data_root=kwargs['data_root'],
            split_txt=kwargs['split_txt'],
            pq_len=kwargs['pq_len'],
            pv_len=kwargs['pv_len'],
            slack_len=kwargs['slack_len'],
            mask_num=kwargs['mask_num']
        )
    raise ValueError(f"Unsupported dataset_type: {dataset_type!r}")
def summary_epoch(self,
                  epoch,
                  train_loss, train_matrix,
                  valid_loss, val_matrix,
                  timer, local_best,
                  local_best_ep=-1,
                  local_best_metrics=None,
                  local_best_ema=100,
                  local_best_ep_ema=-1,
                  local_best_metrics_ema=None,
                  valid_loss_ema=None, val_matrix_ema=None):
    """Report one finished epoch, update best-so-far state and checkpoint.

    In a distributed run only rank 0 reports and saves; other ranks return
    the incoming best-so-far values unchanged. The reporting process:

    * records losses, learning rate and metric meters on ``self.reporter``;
    * logs the epoch header and the aggregated train / validation metrics;
    * takes the largest validation value whose key contains "rmse" as the
      epoch's score and, if it is below ``local_best``, saves the model to
      ``ckpt_best.pt`` under ``self.cfg['log_path']``;
    * always writes the resumable state to ``ckpt_latest.pt``.

    Args:
        epoch: Zero-based epoch index; logged as ``epoch + 1``.
        train_loss, valid_loss: Meter objects exposing ``agg()``.
        train_matrix, val_matrix: Mappings from metric name to meter.
        timer: Object exposing ``eta`` and ``elapsed_time``.
        local_best: Best (lowest) score so far.
        local_best_ep: Epoch index of ``local_best``.
        local_best_metrics: "rmse" metrics of the best epoch; ``None``
            means an empty dict.
        local_best_ema, local_best_ep_ema, local_best_metrics_ema: The same
            bookkeeping for the EMA evaluation.
        valid_loss_ema: Accepted for interface compatibility; not used.
        val_matrix_ema: Optional mapping of EMA validation meters.

    Returns:
        ``(local_best, local_best_ep, local_best_metrics)`` when
        ``val_matrix_ema`` is ``None``; otherwise ``(local_best,
        local_best_ep, local_best_ema, local_best_ep_ema,
        local_best_metrics_ema)``.

    Raises:
        ValueError: If no validation metric key contains "rmse"
            (``max`` of an empty sequence).
    """
    # Fresh dicts per call instead of shared mutable defaults.
    if local_best_metrics is None:
        local_best_metrics = {}
    if local_best_metrics_ema is None:
        local_best_metrics_ema = {}

    def describe(title, meters, precision):
        """Return the log line for ``meters`` and their aggregated values."""
        pieces = [title]
        aggregated = {}
        for key, meter in meters.items():
            value = meter.agg()
            short_key = str(key).split("/")[-1]
            pieces.append(f"{short_key}:{value: .{precision}f} ")
            aggregated[key] = value
        return "".join(pieces), aggregated

    distributed = self.cfg['distributed']
    # Only rank 0 reports in distributed mode; single-process runs always do.
    if not distributed or dist.get_rank() == 0:
        cur_lr = self.optim.param_groups[0]["lr"]
        self.reporter.record({'loss/train_loss': train_loss}, epoch=epoch)
        self.reporter.record({'loss/val_loss': valid_loss}, epoch=epoch)
        self.reporter.record({'lr': cur_lr}, epoch=epoch)
        self.reporter.record(train_matrix, epoch=epoch)
        self.reporter.record(val_matrix, epoch=epoch)

        # Epochs are shown 1-based everywhere; the distributed header is
        # additionally zero-padded, as before.
        epoch_label = str(epoch + 1).zfill(3) if distributed else str(epoch + 1)
        logger.info(f"Epoch {epoch_label}/{self.cfg['train']['epochs']},"
                    + f" lr: {cur_lr: .8f}, eta: {timer.eta}h, "
                    + f"train_loss: {train_loss.agg(): .5f}, "
                    + f"valid_loss: {valid_loss.agg(): .5f}")

        # The two modes have always logged with different precision.
        precision = 6 if distributed else 8
        train_matrix_info, _ = describe("Train: ", train_matrix, precision)
        logger.info(f"\t{train_matrix_info}")
        val_matrix_info, performance_record = describe("ZTest: ", val_matrix, precision)
        logger.info(f"\t{val_matrix_info}")

        if val_matrix_ema is not None:
            val_matrix_info_ema, performance_record_ema = describe("ZTest-ema: ", val_matrix_ema, 6)
            logger.info(f"\t{val_matrix_info_ema}")
            checked_performance_ema = {x: y for x, y in performance_record_ema.items() if "rmse" in x}
            # The worst (largest) rmse-type value is the EMA score.
            best_performance_ema = max(checked_performance_ema.values())
            if best_performance_ema < local_best_ema:
                local_best_ema = best_performance_ema
                local_best_ep_ema = epoch
                local_best_metrics_ema = checked_performance_ema
            logger.info(f"\t ValOfEMA:{best_performance_ema:.6f}/{local_best_ema:.6f}, Epoch:{epoch+1}/{local_best_ep_ema+1}")

        checked_performance = {x: y for x, y in performance_record.items() if "rmse" in x}
        # The worst (largest) rmse-type value is the epoch's score.
        best_performance = max(checked_performance.values())
        if best_performance < local_best:  # save best
            local_best = best_performance
            local_best_ep = epoch
            local_best_metrics = checked_performance
            # Distributed runs save the wrapped module, not the wrapper.
            model_to_save = self.model.module if distributed else self.model
            torch.save(model_to_save, os.path.join(self.cfg['log_path'], 'ckpt_best.pt'))

        state = {
            "epoch": epoch + 1,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optim.state_dict(),
            # The scheduler is None when a fixed learning rate is configured.
            "scheduler_state": self.scheduler.state_dict() if self.scheduler is not None else None,
            "best_performance": local_best,
            "best_epoch": local_best_ep,
            "local_best_metrics": local_best_metrics,
        }
        torch.save(state, os.path.join(self.cfg['log_path'], 'ckpt_latest.pt'))
        logger.info(f"\tTime(ep):{int(timer.elapsed_time)}s, Val(curr/best):{best_performance:.6f}/{local_best:.6f}, Epoch(curr/best):{epoch+1}/{local_best_ep+1}")

    if val_matrix_ema is not None:
        return local_best, local_best_ep, local_best_ema, local_best_ep_ema, local_best_metrics_ema
    else:
        return local_best, local_best_ep, local_best_metrics