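"""Training entry point for DrugFlow.

Example invocations (config paths are illustrative, not shipped defaults):

    python src/train.py --config config.yml
    python src/train.py --config config.yml --backoff   # auto-resume from last.ckpt
    python src/train.py --config config.yml --resume runs/my_run/checkpoints/last.ckpt
"""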
import argparse
from argparse import Namespace
from pathlib import Path
import warnings
import torch
import pytorch_lightning as pl
import yaml
import sys
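# Make the repository root importable so `src.*` imports resolve when this
# script is run directly.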
basedir = Path(__file__).resolve().parent.parent
sys.path.append(str(basedir))
from src.model.lightning import DrugFlow
from src.model.dpo import DPO
from src.utils import set_deterministic, disable_rdkit_logging, dict_to_namespace, namespace_to_dict
def merge_args_and_yaml(args, config_dict):
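    """Merge a YAML config dict into an argparse Namespace.

    Config-file values take precedence: any key that also exists as a
    command-line argument is overwritten (with a warning). Values are
    converted to Namespace objects via ``dict_to_namespace``.
    """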
arg_dict = args.__dict__
for key, value in config_dict.items():
if key in arg_dict:
warnings.warn(f"Command line argument '{key}' (value: "
f"{arg_dict[key]}) will be overwritten with value "
f"{value} provided in the config file.")
arg_dict[key] = dict_to_namespace(value)
return args
def merge_configs(config, resume_config):
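    """Reconcile the current config with one stored in a checkpoint.

    Conflicts are resolved in favour of the current config; every key whose
    value differs from the checkpointed one is reported as a
    '[CONFIG UPDATE]'. Nested Namespaces/dicts are compared recursively.
    """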
for key, value in resume_config.items():
if isinstance(value, Namespace):
value = value.__dict__
        if isinstance(value, dict) and key in config:
            # update dictionaries recursively; skip keys that are missing
            # from the new config to avoid a KeyError
            value = merge_configs(config[key], value)
if key in config and config[key] != value:
print(f'[CONFIG UPDATE] {key}: {value} -> {config[key]}')
return config
# ------------------------------------------------------------------------------
# Training
# ------------------------------------------------------------------------------
if __name__ == "__main__":
p = argparse.ArgumentParser()
    p.add_argument('--config', type=str, required=True, help='path to a YAML config file')
    p.add_argument('--resume', type=str, default=None, help='checkpoint to resume training from')
    p.add_argument('--backoff', action='store_true', help='resume from last.ckpt in the run directory if it exists')
    p.add_argument('--finetune', action='store_true', help='load the --resume checkpoint but keep the new config')
    p.add_argument('--debug', action='store_true', help='debug mode: disable W&B logging and data loader workers')
    p.add_argument('--overfit', action='store_true', help='overfitting mode (passed through to the model)')
args = p.parse_args()
set_deterministic(seed=42)
disable_rdkit_logging()
with open(args.config, 'r') as f:
config = yaml.safe_load(f)
    # resuming is controlled via the command line, not the config file;
    # --resume and --backoff are mutually exclusive ways to do it
    assert 'resume' not in config, "Set --resume on the command line, not in the config file"
    assert not (args.resume is not None and args.backoff), "--resume and --backoff are mutually exclusive"
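    # DPO fine-tuning is off by default and requires a reference checkpoint
    # to compare against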
config['dpo_mode'] = config.get('dpo_mode', None)
assert not (config['dpo_mode'] and 'checkpoint' not in config), 'DPO mode requires a reference checkpoint'
if args.debug:
config['run_name'] = 'debug'
out_dir = Path(config['train_params']['logdir'], config['run_name'])
checkpoints_root_dir = Path(out_dir, 'checkpoints')
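    # --backoff: resume from the most recent checkpoint if one exists,
    # otherwise start training from scratch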
if args.backoff:
last_checkpoint = Path(checkpoints_root_dir, 'last.ckpt')
print(f'Checking if there is a checkpoint at: {last_checkpoint}')
if last_checkpoint.exists():
print(f'Found existing checkpoint: {last_checkpoint}')
args.resume = str(last_checkpoint)
else:
print(f'Did not find {last_checkpoint}')
    # When resuming (and not fine-tuning), restore the config stored in the checkpoint
ckpt_path = None if args.resume is None else Path(args.resume)
if args.resume is not None and not args.finetune:
ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
print(f'Resuming from epoch {ckpt["epoch"]}')
resume_config = ckpt['hyper_parameters']
config = merge_configs(config, resume_config)
args = merge_args_and_yaml(args, config)
if args.debug:
print('DEBUG MODE')
args.wandb_params.mode = 'disabled'
args.train_params.enable_progress_bar = True
args.train_params.num_workers = 0
if args.overfit:
print('OVERFITTING MODE')
args.eval_params.outdir = out_dir
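    # Select the LightningModule: preference fine-tuning (DPO) or standard DrugFlow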
model_class = DPO if args.dpo_mode else DrugFlow
model_args = {
'pocket_representation': args.pocket_representation,
'train_params': args.train_params,
'loss_params': args.loss_params,
'eval_params': args.eval_params,
'predictor_params': args.predictor_params,
'simulation_params': args.simulation_params,
'virtual_nodes': args.virtual_nodes,
'flexible': args.flexible,
'flexible_bb': args.flexible_bb,
'debug': args.debug,
'overfit': args.overfit,
}
if args.dpo_mode:
print('DPO MODE')
model_args.update({
'dpo_mode': args.dpo_mode,
'ref_checkpoint_p': args.checkpoint,
})
pl_module = model_class(**model_args)
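    # W&B resume semantics: 'must' enforces continuing an existing run when
    # resuming; 'allow' tolerates a new or existing run id when fine-tuning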
resume_logging = False
if args.finetune:
resume_logging = 'allow'
elif args.resume is not None:
resume_logging = 'must'
logger = pl.loggers.WandbLogger(
save_dir=args.train_params.logdir,
project='FlexFlow',
group=args.wandb_params.group,
name=args.run_name,
id=args.run_name,
resume=resume_logging,
entity=args.wandb_params.entity,
mode=args.wandb_params.mode,
)
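    # Two checkpoint streams: always keep the most recent epoch ('last.ckpt')
    # plus the five best checkpoints ranked by validation loss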
checkpoint_callbacks = [
pl.callbacks.ModelCheckpoint(
dirpath=checkpoints_root_dir,
save_last=True,
save_on_train_epoch_end=True,
),
pl.callbacks.ModelCheckpoint(
dirpath=Path(checkpoints_root_dir, 'val_loss'),
filename="epoch_{epoch:04d}_loss_{loss/val:.3f}",
monitor="loss/val",
save_top_k=5,
mode="min",
auto_insert_metric_name=False,
),
]
# For learning rate logging
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step')
    # compare the major version numerically; a plain string comparison would
    # misorder versions such as '10.0.0' < '2.0.0'
    default_strategy = 'auto' if int(pl.__version__.split('.')[0]) >= 2 else None
trainer = pl.Trainer(
max_epochs=args.train_params.n_epochs,
logger=logger,
callbacks=checkpoint_callbacks + [lr_monitor],
enable_progress_bar=args.train_params.enable_progress_bar,
check_val_every_n_epoch=args.eval_params.eval_epochs,
num_sanity_val_steps=args.train_params.num_sanity_val_steps,
accumulate_grad_batches=args.train_params.accumulate_grad_batches,
accelerator='gpu' if args.train_params.gpus > 0 else 'cpu',
devices=args.train_params.gpus if args.train_params.gpus > 0 else 'auto',
strategy=('ddp_find_unused_parameters_true' if args.train_params.gpus > 1 else default_strategy),
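        # keep the datamodule's own samplers; `use_distributed_sampler` is the
        # PyTorch Lightning >= 2.0 name for this switch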
use_distributed_sampler=False,
)
# add all arguments as dictionaries because WandB does not display
# nested Namespace objects correctly
logger.experiment.config.update({'as_dict': namespace_to_dict(args)}, allow_val_change=True)
trainer.fit(model=pl_module, ckpt_path=ckpt_path)
# # run test set
# result = trainer.test(ckpt_path='best')