import argparse
from argparse import Namespace
from pathlib import Path
import warnings

import torch
import pytorch_lightning as pl
import yaml
import numpy as np

from lightning_modules import LigandPocketDDPM


def merge_args_and_yaml(args, config_dict):
    # Copy all YAML config entries onto the argparse namespace.
    # Nested dictionaries become sub-namespaces; keys already present on the
    # command line are overwritten by the config file (with a warning).
    arg_dict = args.__dict__
    for key, value in config_dict.items():
        if key in arg_dict:
            warnings.warn(f"Command line argument '{key}' (value: "
                          f"{arg_dict[key]}) will be overwritten with value "
                          f"{value} provided in the config file.")
        if isinstance(value, dict):
            arg_dict[key] = Namespace(**value)
        else:
            arg_dict[key] = value

    return args


def merge_configs(config, resume_config):
    # Overwrite YAML config entries with the hyper-parameters stored in the
    # checkpoint so that a resumed run uses consistent settings.
    for key, value in resume_config.items():
        if isinstance(value, Namespace):
            value = value.__dict__
        if key in config and config[key] != value:
            warnings.warn(f"Config parameter '{key}' (value: "
                          f"{config[key]}) will be overwritten with value "
                          f"{value} from the checkpoint.")
        config[key] = value
    return config


# ------------------------------------------------------------------------------
# Training
# ______________________________________________________________________________
if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument('--config', type=str, required=True)
    p.add_argument('--resume', type=str, default=None)
    args = p.parse_args()

    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)
    assert 'resume' not in config

    # Get main config
    ckpt_path = None if args.resume is None else Path(args.resume)
    if args.resume is not None:
        resume_config = torch.load(
            ckpt_path, map_location=torch.device('cpu'))['hyper_parameters']
        config = merge_configs(config, resume_config)

    args = merge_args_and_yaml(args, config)

    out_dir = Path(args.logdir, args.run_name)

    # Load the empirical size distribution of the training data.
    histogram_file = Path(args.datadir, 'size_distribution.npy')
    histogram = np.load(histogram_file).tolist()

    pl_module = LigandPocketDDPM(
        outdir=out_dir,
        dataset=args.dataset,
        datadir=args.datadir,
        batch_size=args.batch_size,
        lr=args.lr,
        egnn_params=args.egnn_params,
        diffusion_params=args.diffusion_params,
        num_workers=args.num_workers,
        augment_noise=args.augment_noise,
        augment_rotation=args.augment_rotation,
        clip_grad=args.clip_grad,
        eval_epochs=args.eval_epochs,
        eval_params=args.eval_params,
        visualize_sample_epoch=args.visualize_sample_epoch,
        visualize_chain_epoch=args.visualize_chain_epoch,
        auxiliary_loss=args.auxiliary_loss,
        loss_params=args.loss_params,
        mode=args.mode,
        node_histogram=histogram,
        pocket_representation=args.pocket_representation,
        virtual_nodes=args.virtual_nodes
    )

    logger = pl.loggers.WandbLogger(
        save_dir=args.logdir,
        project='ligand-pocket-ddpm',
        group=args.wandb_params.group,
        name=args.run_name,
        id=args.run_name,
        resume='must' if args.resume is not None else False,
        entity=args.wandb_params.entity,
        mode=args.wandb_params.mode,
    )

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=Path(out_dir, 'checkpoints'),
        filename="best-model-epoch={epoch:02d}",
        monitor="loss/val",
        save_top_k=1,
        save_last=True,
        mode="min",
    )

    trainer = pl.Trainer(
        max_epochs=args.n_epochs,
        logger=logger,
        callbacks=[checkpoint_callback],
        enable_progress_bar=args.enable_progress_bar,
        num_sanity_val_steps=args.num_sanity_val_steps,
        accelerator='gpu',
        devices=args.gpus,
        # Use DistributedDataParallel only when training on more than one GPU.
        strategy=('ddp' if args.gpus > 1 else None)
    )

    trainer.fit(model=pl_module, ckpt_path=ckpt_path)
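
# Example invocations (a sketch; the script name and the config/checkpoint
# paths below are placeholders, not files shipped with this snippet). The YAML
# config must provide the keys read above via `args`, e.g. logdir, run_name,
# datadir, dataset, batch_size, lr, egnn_params, diffusion_params, wandb_params.
#
#   python train.py --config configs/my_config.yml
#
# Resuming picks up the hyper-parameters stored in the checkpoint (see
# merge_configs) and continues logging to the same W&B run; with
# save_last=True, ModelCheckpoint writes <logdir>/<run_name>/checkpoints/last.ckpt:
#
#   python train.py --config configs/my_config.yml \
#       --resume logs/<run_name>/checkpoints/last.ckpt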