|
|
import argparse |
|
|
import sys |
|
|
import yaml |
|
|
import torch |
|
|
import numpy as np |
|
|
import pickle |
|
|
from argparse import Namespace |
|
|
|
|
|
from pathlib import Path |
|
|
|
|
|
basedir = Path(__file__).resolve().parent.parent |
|
|
sys.path.append(str(basedir)) |
|
|
|
|
|
from src import utils |
|
|
from src.utils import dict_to_namespace, namespace_to_dict |
|
|
from src.analysis.visualization_utils import mols_to_pdbfile, mol_as_pdb |
|
|
from src.data.data_utils import TensorDict, Residues |
|
|
from src.data.postprocessing import process_all |
|
|
from src.model.lightning import DrugFlow |
|
|
from src.sbdd_metrics.evaluation import compute_all_metrics_drugflow |
|
|
|
|
|
from tqdm import tqdm |
|
|
from pdb import set_trace |
|
|
|
|
|
|
|
|
def combine(base_args, override_args): |
|
|
assert not isinstance(base_args, dict) |
|
|
assert not isinstance(override_args, dict) |
|
|
|
|
|
arg_dict = base_args.__dict__ |
|
|
for key, value in override_args.__dict__.items(): |
|
|
if key not in arg_dict or arg_dict[key] is None: |
|
|
print(f"Add parameter {key}: {value}") |
|
|
arg_dict[key] = value |
|
|
elif isinstance(value, Namespace): |
|
|
arg_dict[key] = combine(arg_dict[key], value) |
|
|
else: |
|
|
print(f"Replace parameter {key}: {arg_dict[key]} -> {value}") |
|
|
arg_dict[key] = value |
|
|
return base_args |
|
|
|
|
|
|
|
|
def path_to_str(input_dict): |
|
|
for key, value in input_dict.items(): |
|
|
if isinstance(value, dict): |
|
|
input_dict[key] = path_to_str(value) |
|
|
else: |
|
|
input_dict[key] = str(value) if isinstance(value, Path) else value |
|
|
return input_dict |
|
|
|
|
|
|
|
|
def sample(cfg, model_params, samples_dir, job_id=0, n_jobs=1): |
|
|
print('Sampling...') |
|
|
model = DrugFlow.load_from_checkpoint(cfg.checkpoint, map_location=cfg.device, strict=False, |
|
|
**model_params) |
|
|
model.setup(stage='fit' if cfg.set == 'train' else cfg.set) |
|
|
model.eval().to(cfg.device) |
|
|
|
|
|
dataloader = getattr(model, f'{cfg.set}_dataloader')() |
|
|
print(f'Real batch size is {dataloader.batch_size * cfg.n_samples}') |
|
|
|
|
|
name2count = {} |
|
|
for i, data in enumerate(tqdm(dataloader)): |
|
|
if i % n_jobs != job_id: |
|
|
print(f'Skipping batch {i}') |
|
|
continue |
|
|
|
|
|
new_data = { |
|
|
'ligand': TensorDict(**data['ligand']).to(cfg.device), |
|
|
'pocket': Residues(**data['pocket']).to(cfg.device), |
|
|
} |
|
|
try: |
|
|
rdmols, rdpockets, names = model.sample( |
|
|
data=new_data, |
|
|
n_samples=cfg.n_samples, |
|
|
num_nodes=("ground_truth" if cfg.sample_with_ground_truth_size else None) |
|
|
) |
|
|
except Exception as e: |
|
|
if cfg.set == 'train': |
|
|
names = data['ligand']['name'] |
|
|
print(f'Failed to sample for {names}: {e}') |
|
|
continue |
|
|
else: |
|
|
raise e |
|
|
|
|
|
for mol, pocket, name in zip(rdmols, rdpockets, names): |
|
|
name = name.replace('.sdf', '') |
|
|
idx = name2count.setdefault(name, 0) |
|
|
output_dir = Path(samples_dir, name) |
|
|
output_dir.mkdir(parents=True, exist_ok=True) |
|
|
if cfg.postprocess: |
|
|
mol = process_all(mol, largest_frag=True, adjust_aromatic_Ns=True, relax_iter=0) |
|
|
|
|
|
for prop in mol.GetAtoms()[0].GetPropsAsDict().keys(): |
|
|
|
|
|
mol.SetDoubleProp(prop, np.mean([a.GetDoubleProp(prop) for a in mol.GetAtoms()])) |
|
|
|
|
|
|
|
|
out_pdb_path = Path(output_dir, f'{idx}_ligand_{prop}.pdb') |
|
|
mol_as_pdb(mol, out_pdb_path, bfactor=prop) |
|
|
|
|
|
out_sdf_path = Path(output_dir, f'{idx}_ligand.sdf') |
|
|
out_pdb_path = Path(output_dir, f'{idx}_pocket.pdb') |
|
|
utils.write_sdf_file(out_sdf_path, [mol]) |
|
|
mols_to_pdbfile([pocket], out_pdb_path) |
|
|
|
|
|
name2count[name] += 1 |
|
|
|
|
|
|
|
|
def evaluate(cfg, model_params, samples_dir): |
|
|
print('Evaluation...') |
|
|
data, table_detailed, table_aggregated = compute_all_metrics_drugflow( |
|
|
in_dir=samples_dir, |
|
|
gnina_path=model_params['train_params'].gnina, |
|
|
reduce_path=cfg.reduce, |
|
|
reference_smiles_path=Path(model_params['train_params'].datadir, 'train_smiles.npy'), |
|
|
n_samples=cfg.n_samples, |
|
|
exclude_evaluators=[] if cfg.exclude_evaluators is None else cfg.exclude_evaluators, |
|
|
) |
|
|
with open(Path(samples_dir, 'metrics_data.pkl'), 'wb') as f: |
|
|
pickle.dump(data, f) |
|
|
table_detailed.to_csv(Path(samples_dir, 'metrics_detailed.csv'), index=False) |
|
|
table_aggregated.to_csv(Path(samples_dir, 'metrics_aggregated.csv'), index=False) |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
p = argparse.ArgumentParser() |
|
|
p.add_argument('--config', type=str) |
|
|
p.add_argument('--job_id', type=int, default=0, help='Job ID') |
|
|
p.add_argument('--n_jobs', type=int, default=1, help='Number of jobs') |
|
|
args = p.parse_args() |
|
|
|
|
|
with open(args.config, 'r') as f: |
|
|
cfg = yaml.safe_load(f) |
|
|
cfg = dict_to_namespace(cfg) |
|
|
|
|
|
utils.set_deterministic(seed=cfg.seed) |
|
|
utils.disable_rdkit_logging() |
|
|
|
|
|
model_params = torch.load(cfg.checkpoint, map_location=cfg.device)['hyper_parameters'] |
|
|
if 'model_args' in cfg: |
|
|
ckpt_args = dict_to_namespace(model_params) |
|
|
model_params = combine(ckpt_args, cfg.model_args).__dict__ |
|
|
|
|
|
ckpt_path = Path(cfg.checkpoint) |
|
|
ckpt_name = ckpt_path.parts[-1].split('.')[0] |
|
|
n_steps = model_params['simulation_params'].n_steps |
|
|
samples_dir = Path(cfg.sample_outdir, cfg.set, f'{ckpt_name}_T={n_steps}') or \ |
|
|
Path(ckpt_path.parent.parent, 'samples', cfg.set, f'{ckpt_name}_T={n_steps}') |
|
|
assert cfg.set in {'val', 'test', 'train'} |
|
|
samples_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
|
|
|
with open(Path(samples_dir, 'model_params.yaml'), 'w') as f: |
|
|
yaml.dump(path_to_str(namespace_to_dict(model_params)), f) |
|
|
with open(Path(samples_dir, 'sampling_params.yaml'), 'w') as f: |
|
|
yaml.dump(path_to_str(namespace_to_dict(cfg)), f) |
|
|
|
|
|
if cfg.sample: |
|
|
sample(cfg, model_params, samples_dir, job_id=args.job_id, n_jobs=args.n_jobs) |
|
|
|
|
|
if cfg.evaluate: |
|
|
assert args.job_id == 0 and args.n_jobs == 1, 'Evaluation is not parallelised on GPU machines' |
|
|
evaluate(cfg, model_params, samples_dir) |