# DrugFlow: src/sample_and_evaluate.py
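"""Sample ligands with a trained DrugFlow model and optionally evaluate them.

Loads a YAML config and a Lightning checkpoint, generates ligand/pocket
structures for the chosen data split, writes them to disk as SDF/PDB files,
and optionally computes the full DrugFlow metric suite.
"""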
import argparse
import sys
import yaml
import torch
import numpy as np
import pickle
from argparse import Namespace
from pathlib import Path
basedir = Path(__file__).resolve().parent.parent
sys.path.append(str(basedir))
from src import utils
from src.utils import dict_to_namespace, namespace_to_dict
from src.analysis.visualization_utils import mols_to_pdbfile, mol_as_pdb
from src.data.data_utils import TensorDict, Residues
from src.data.postprocessing import process_all
from src.model.lightning import DrugFlow
from src.sbdd_metrics.evaluation import compute_all_metrics_drugflow
from tqdm import tqdm


def combine(base_args, override_args):
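    """Recursively merge `override_args` into `base_args`; override values win."""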
assert not isinstance(base_args, dict)
assert not isinstance(override_args, dict)
arg_dict = base_args.__dict__
for key, value in override_args.__dict__.items():
if key not in arg_dict or arg_dict[key] is None: # parameter not provided previously
print(f"Add parameter {key}: {value}")
arg_dict[key] = value
elif isinstance(value, Namespace):
arg_dict[key] = combine(arg_dict[key], value)
else:
print(f"Replace parameter {key}: {arg_dict[key]} -> {value}")
arg_dict[key] = value
return base_args


def path_to_str(input_dict):
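    """Recursively convert pathlib.Path values in a nested dict to strings for YAML dumping."""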
for key, value in input_dict.items():
if isinstance(value, dict):
input_dict[key] = path_to_str(value)
else:
input_dict[key] = str(value) if isinstance(value, Path) else value
return input_dict


def sample(cfg, model_params, samples_dir, job_id=0, n_jobs=1):
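    """Sample cfg.n_samples ligands per pocket and write SDF/PDB files into samples_dir.

    Work is sharded across parallel workers: the worker with a given `job_id`
    processes every batch i for which i % n_jobs == job_id.
    """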
print('Sampling...')
model = DrugFlow.load_from_checkpoint(cfg.checkpoint, map_location=cfg.device, strict=False,
**model_params)
model.setup(stage='fit' if cfg.set == 'train' else cfg.set)
model.eval().to(cfg.device)
dataloader = getattr(model, f'{cfg.set}_dataloader')()
    print(f'Effective batch size: {dataloader.batch_size * cfg.n_samples}')
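    # per-target counter so multiple samples for the same pocket get unique file indices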
name2count = {}
for i, data in enumerate(tqdm(dataloader)):
if i % n_jobs != job_id:
print(f'Skipping batch {i}')
continue
new_data = {
'ligand': TensorDict(**data['ligand']).to(cfg.device),
'pocket': Residues(**data['pocket']).to(cfg.device),
}
try:
rdmols, rdpockets, names = model.sample(
data=new_data,
n_samples=cfg.n_samples,
num_nodes=("ground_truth" if cfg.sample_with_ground_truth_size else None)
)
except Exception as e:
if cfg.set == 'train':
names = data['ligand']['name']
print(f'Failed to sample for {names}: {e}')
continue
else:
raise e
for mol, pocket, name in zip(rdmols, rdpockets, names):
name = name.replace('.sdf', '')
idx = name2count.setdefault(name, 0)
output_dir = Path(samples_dir, name)
output_dir.mkdir(parents=True, exist_ok=True)
if cfg.postprocess:
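                # keep only the largest fragment and fix aromatic nitrogens;
                # relax_iter=0 disables force-field relaxation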
mol = process_all(mol, largest_frag=True, adjust_aromatic_Ns=True, relax_iter=0)
                for prop in mol.GetAtomWithIdx(0).GetPropsAsDict().keys():
                    # molecule-level property: average per-atom uncertainty
                    mol.SetDoubleProp(prop, np.mean([a.GetDoubleProp(prop) for a in mol.GetAtoms()]))
# visualise local differences
out_pdb_path = Path(output_dir, f'{idx}_ligand_{prop}.pdb')
mol_as_pdb(mol, out_pdb_path, bfactor=prop)
out_sdf_path = Path(output_dir, f'{idx}_ligand.sdf')
out_pdb_path = Path(output_dir, f'{idx}_pocket.pdb')
utils.write_sdf_file(out_sdf_path, [mol])
mols_to_pdbfile([pocket], out_pdb_path)
name2count[name] += 1


def evaluate(cfg, model_params, samples_dir):
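    """Compute all DrugFlow metrics for the samples in samples_dir and write them to disk."""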
print('Evaluation...')
data, table_detailed, table_aggregated = compute_all_metrics_drugflow(
in_dir=samples_dir,
gnina_path=model_params['train_params'].gnina,
reduce_path=cfg.reduce,
reference_smiles_path=Path(model_params['train_params'].datadir, 'train_smiles.npy'),
n_samples=cfg.n_samples,
exclude_evaluators=[] if cfg.exclude_evaluators is None else cfg.exclude_evaluators,
)
with open(Path(samples_dir, 'metrics_data.pkl'), 'wb') as f:
pickle.dump(data, f)
table_detailed.to_csv(Path(samples_dir, 'metrics_detailed.csv'), index=False)
table_aggregated.to_csv(Path(samples_dir, 'metrics_aggregated.csv'), index=False)


if __name__ == "__main__":
p = argparse.ArgumentParser()
p.add_argument('--config', type=str)
p.add_argument('--job_id', type=int, default=0, help='Job ID')
p.add_argument('--n_jobs', type=int, default=1, help='Number of jobs')
args = p.parse_args()
with open(args.config, 'r') as f:
cfg = yaml.safe_load(f)
cfg = dict_to_namespace(cfg)
utils.set_deterministic(seed=cfg.seed)
utils.disable_rdkit_logging()
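    # restore the hyperparameters stored in the checkpoint; an optional `model_args`
    # section in the config can override or extend them (see combine() above)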
model_params = torch.load(cfg.checkpoint, map_location=cfg.device)['hyper_parameters']
if 'model_args' in cfg:
ckpt_args = dict_to_namespace(model_params)
model_params = combine(ckpt_args, cfg.model_args).__dict__
ckpt_path = Path(cfg.checkpoint)
ckpt_name = ckpt_path.parts[-1].split('.')[0]
n_steps = model_params['simulation_params'].n_steps
    assert cfg.set in {'train', 'val', 'test'}, f'Unknown data split: {cfg.set}'
    # a Path instance is always truthy, so `Path(...) or Path(...)` could never
    # select the fallback; choose the output directory explicitly instead
    if getattr(cfg, 'sample_outdir', None) is not None:
        samples_dir = Path(cfg.sample_outdir, cfg.set, f'{ckpt_name}_T={n_steps}')
    else:
        samples_dir = Path(ckpt_path.parent.parent, 'samples', cfg.set, f'{ckpt_name}_T={n_steps}')
samples_dir.mkdir(parents=True, exist_ok=True)
# save configs
with open(Path(samples_dir, 'model_params.yaml'), 'w') as f:
yaml.dump(path_to_str(namespace_to_dict(model_params)), f)
with open(Path(samples_dir, 'sampling_params.yaml'), 'w') as f:
yaml.dump(path_to_str(namespace_to_dict(cfg)), f)
if cfg.sample:
sample(cfg, model_params, samples_dir, job_id=args.job_id, n_jobs=args.n_jobs)
if cfg.evaluate:
        assert args.job_id == 0 and args.n_jobs == 1, 'Evaluation is not parallelised; run with job_id=0 and n_jobs=1'
evaluate(cfg, model_params, samples_dir)
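
# Example usage (hypothetical paths): shard sampling across 4 parallel jobs, e.g.
#   python src/sample_and_evaluate.py --config configs/sampling.yaml --job_id 2 --n_jobs 4
# then run evaluation in a single job (job_id=0, n_jobs=1, cfg.evaluate=true).
#
# Minimal config sketch; field names are inferred from the attributes this script
# reads from `cfg`, and all values are placeholders:
#
#   checkpoint: checkpoints/drugflow.ckpt
#   device: cuda:0
#   seed: 42
#   set: test                        # one of: train, val, test
#   n_samples: 10                    # samples drawn per pocket
#   sample_with_ground_truth_size: false
#   postprocess: true
#   sample: true
#   evaluate: true
#   sample_outdir: null              # defaults to <checkpoint_dir>/../samples
#   reduce: /usr/local/bin/reduce    # path to the `reduce` executable
#   exclude_evaluators: null
#   model_args: {}                   # optional overrides for checkpoint hyperparameters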