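"""Sample ligands with a trained DrugFlow model and (optionally) evaluate them.

Reads a YAML config (an illustrative sketch is given above the __main__ block),
generates `n_samples` ligands per pocket of the chosen data split, writes them
as SDF/PDB files, and can then compute the full metric suite on the results.

Illustrative invocation (script and config file names are placeholders):
    python sample_and_evaluate.py --config sampling.yaml --job_id 0 --n_jobs 4
"""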
import argparse
import sys
import yaml
import torch
import numpy as np
import pickle
from argparse import Namespace
from pathlib import Path
basedir = Path(__file__).resolve().parent.parent
sys.path.append(str(basedir))
from src import utils
from src.utils import dict_to_namespace, namespace_to_dict
from src.analysis.visualization_utils import mols_to_pdbfile, mol_as_pdb
from src.data.data_utils import TensorDict, Residues
from src.data.postprocessing import process_all
from src.model.lightning import DrugFlow
from src.sbdd_metrics.evaluation import compute_all_metrics_drugflow
from tqdm import tqdm
from pdb import set_trace


def combine(base_args, override_args):
    """Recursively merge `override_args` into `base_args` (both Namespaces), in place.

    Values from `override_args` win; nested Namespaces are merged recursively.
    """
    assert not isinstance(base_args, dict)
    assert not isinstance(override_args, dict)
    arg_dict = base_args.__dict__
    for key, value in override_args.__dict__.items():
        if key not in arg_dict or arg_dict[key] is None:  # parameter not provided previously
            print(f"Add parameter {key}: {value}")
            arg_dict[key] = value
        elif isinstance(value, Namespace):
            arg_dict[key] = combine(arg_dict[key], value)
        else:
            print(f"Replace parameter {key}: {arg_dict[key]} -> {value}")
            arg_dict[key] = value
    return base_args
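
# Illustrative example (hypothetical values, not executed):
#   base = dict_to_namespace({'lr': 1e-4, 'sim': {'n_steps': 100}})
#   override = dict_to_namespace({'sim': {'n_steps': 500}})
#   combine(base, override)  # -> lr=1e-4, sim.n_steps=500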


def path_to_str(input_dict):
    """Recursively convert all `pathlib.Path` values to strings so the dict can be YAML-dumped."""
    for key, value in input_dict.items():
        if isinstance(value, dict):
            input_dict[key] = path_to_str(value)
        else:
            input_dict[key] = str(value) if isinstance(value, Path) else value
    return input_dict
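
# Example (illustrative): path_to_str({'data': {'dir': Path('/data')}}) -> {'data': {'dir': '/data'}}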


def sample(cfg, model_params, samples_dir, job_id=0, n_jobs=1):
    """Generate `cfg.n_samples` ligands per pocket of the `cfg.set` split and write them to disk.

    With `n_jobs > 1`, each job only processes the batches with index
    `i % n_jobs == job_id`, so sampling can be sharded across independent processes.
    """
    print('Sampling...')
    model = DrugFlow.load_from_checkpoint(cfg.checkpoint, map_location=cfg.device, strict=False,
                                          **model_params)
    model.setup(stage='fit' if cfg.set == 'train' else cfg.set)
    model.eval().to(cfg.device)

    dataloader = getattr(model, f'{cfg.set}_dataloader')()
    print(f'Effective batch size is {dataloader.batch_size * cfg.n_samples}')

    name2count = {}
    for i, data in enumerate(tqdm(dataloader)):
        # shard batches across parallel jobs
        if i % n_jobs != job_id:
            print(f'Skipping batch {i}')
            continue

        new_data = {
            'ligand': TensorDict(**data['ligand']).to(cfg.device),
            'pocket': Residues(**data['pocket']).to(cfg.device),
        }
        try:
            rdmols, rdpockets, names = model.sample(
                data=new_data,
                n_samples=cfg.n_samples,
                num_nodes=("ground_truth" if cfg.sample_with_ground_truth_size else None)
            )
        except Exception as e:
            if cfg.set == 'train':
                names = data['ligand']['name']
                print(f'Failed to sample for {names}: {e}')
                continue
            else:
                raise e

        for mol, pocket, name in zip(rdmols, rdpockets, names):
            name = name.replace('.sdf', '')
            idx = name2count.setdefault(name, 0)
            output_dir = Path(samples_dir, name)
            output_dir.mkdir(parents=True, exist_ok=True)

            if cfg.postprocess:
                mol = process_all(mol, largest_frag=True, adjust_aromatic_Ns=True, relax_iter=0)

            for prop in mol.GetAtoms()[0].GetPropsAsDict().keys():
                # compute avg uncertainty (cast to float: RDKit rejects numpy scalars)
                mol.SetDoubleProp(prop, float(np.mean([a.GetDoubleProp(prop) for a in mol.GetAtoms()])))
                # visualise local differences
                out_pdb_path = Path(output_dir, f'{idx}_ligand_{prop}.pdb')
                mol_as_pdb(mol, out_pdb_path, bfactor=prop)

            out_sdf_path = Path(output_dir, f'{idx}_ligand.sdf')
            out_pdb_path = Path(output_dir, f'{idx}_pocket.pdb')
            utils.write_sdf_file(out_sdf_path, [mol])
            mols_to_pdbfile([pocket], out_pdb_path)
            name2count[name] += 1
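
# Resulting layout per pocket `name` (as produced by the loop above):
#   <samples_dir>/<name>/<idx>_ligand.sdf          generated ligand
#   <samples_dir>/<name>/<idx>_pocket.pdb          corresponding pocket
#   <samples_dir>/<name>/<idx>_ligand_<prop>.pdb   ligand with per-atom property in the b-factor column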


def evaluate(cfg, model_params, samples_dir):
    """Compute the DrugFlow metric suite on all samples in `samples_dir` and save the results."""
    print('Evaluation...')
    data, table_detailed, table_aggregated = compute_all_metrics_drugflow(
        in_dir=samples_dir,
        gnina_path=model_params['train_params'].gnina,
        reduce_path=cfg.reduce,
        reference_smiles_path=Path(model_params['train_params'].datadir, 'train_smiles.npy'),
        n_samples=cfg.n_samples,
        exclude_evaluators=[] if cfg.exclude_evaluators is None else cfg.exclude_evaluators,
    )
    with open(Path(samples_dir, 'metrics_data.pkl'), 'wb') as f:
        pickle.dump(data, f)
    table_detailed.to_csv(Path(samples_dir, 'metrics_detailed.csv'), index=False)
    table_aggregated.to_csv(Path(samples_dir, 'metrics_aggregated.csv'), index=False)
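

# Illustrative config (keys read by this script; paths and values are placeholders):
#
#   checkpoint: /path/to/drugflow.ckpt
#   device: cuda:0
#   seed: 42
#   set: val                          # one of {train, val, test}
#   n_samples: 10                     # ligands generated per pocket
#   sample_with_ground_truth_size: false
#   postprocess: true
#   sample: true
#   evaluate: true
#   sample_outdir: /path/to/samples   # optional; falls back to a 'samples' dir next to the checkpoint
#   reduce: /path/to/reduce           # 'reduce' executable used during evaluation
#   exclude_evaluators: null
#   # model_args: {...}               # optional overrides merged into the checkpoint hyperparameters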
if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument('--config', type=str)
    p.add_argument('--job_id', type=int, default=0, help='Job ID')
    p.add_argument('--n_jobs', type=int, default=1, help='Number of jobs')
    args = p.parse_args()

    with open(args.config, 'r') as f:
        cfg = yaml.safe_load(f)
    cfg = dict_to_namespace(cfg)
    assert cfg.set in {'val', 'test', 'train'}

    utils.set_deterministic(seed=cfg.seed)
    utils.disable_rdkit_logging()

    model_params = torch.load(cfg.checkpoint, map_location=cfg.device)['hyper_parameters']
    if 'model_args' in cfg:
        ckpt_args = dict_to_namespace(model_params)
        model_params = combine(ckpt_args, cfg.model_args).__dict__

    ckpt_path = Path(cfg.checkpoint)
    ckpt_name = ckpt_path.parts[-1].split('.')[0]
    n_steps = model_params['simulation_params'].n_steps

    # NOTE: `Path(a) or Path(b)` always picks the first branch (a Path is always truthy),
    # so the output directory is chosen explicitly instead
    if getattr(cfg, 'sample_outdir', None) is not None:
        samples_dir = Path(cfg.sample_outdir, cfg.set, f'{ckpt_name}_T={n_steps}')
    else:
        samples_dir = Path(ckpt_path.parent.parent, 'samples', cfg.set, f'{ckpt_name}_T={n_steps}')
    samples_dir.mkdir(parents=True, exist_ok=True)

    # save configs
    with open(Path(samples_dir, 'model_params.yaml'), 'w') as f:
        yaml.dump(path_to_str(namespace_to_dict(model_params)), f)
    with open(Path(samples_dir, 'sampling_params.yaml'), 'w') as f:
        yaml.dump(path_to_str(namespace_to_dict(cfg)), f)

    if cfg.sample:
        sample(cfg, model_params, samples_dir, job_id=args.job_id, n_jobs=args.n_jobs)
    if cfg.evaluate:
        assert args.job_id == 0 and args.n_jobs == 1, 'Evaluation is not parallelised on GPU machines'
        evaluate(cfg, model_params, samples_dir)