|
|
import argparse |
|
|
import sys |
|
|
import os |
|
|
import warnings |
|
|
import tempfile |
|
|
import pandas as pd |
|
|
|
|
|
from Bio.PDB import PDBParser |
|
|
from pathlib import Path |
|
|
from rdkit import Chem |
|
|
from torch.utils.data import DataLoader |
|
|
from functools import partial |
|
|
|
|
|
basedir = Path(__file__).resolve().parent.parent |
|
|
sys.path.append(str(basedir)) |
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
from src import utils |
|
|
from src.data.dataset import ProcessedLigandPocketDataset |
|
|
from src.data.data_utils import TensorDict, process_raw_pair |
|
|
from src.model.lightning import DrugFlow |
|
|
from src.sbdd_metrics.metrics import FullEvaluator |
|
|
|
|
|
from tqdm import tqdm |
|
|
from pdb import set_trace |
|
|
|
|
|
|
|
|
def _average_prefixed_columns(table, prefix, excluded_suffix=None):
    """Store in ``table[prefix]`` the row-wise mean of the columns named ``prefix*``.

    Columns are selected when their name starts with ``prefix`` and is not
    exactly ``prefix``; if ``excluded_suffix`` is given, columns whose name
    ends with it are skipped as well. Missing values count as 0 and every
    selected column is cast to float before being summed.

    When no column matches, ``table[prefix]`` is filled with NaN.
    """
    # Reset (or create) the aggregate column first so that a stale value from
    # a previous call can never leak into the result.
    table[prefix] = 0

    members = [
        column for column in table.columns
        if column.startswith(prefix)
        and column != prefix
        and not (excluded_suffix is not None and column.endswith(excluded_suffix))
    ]

    if not members:
        # Nothing to average: make the "undefined" result explicit.
        table[prefix] = float('nan')
        return

    accumulated = 0.0
    for column in members:
        accumulated = accumulated + table[column].fillna(0).astype(float)
    table[prefix] = accumulated / len(members)


def aggregate_metrics(table):
    """Add aggregate score columns to a per-molecule metrics table.

    Three columns are written, each holding the row-wise mean of the
    individual columns sharing its prefix (missing values counted as 0):

    * ``'posebusters'``: mean of all ``posebusters*`` columns,
    * ``'reos'``: mean of all ``reos*`` columns,
    * ``'chembl_ring_systems'``: mean of all ``chembl_ring_systems*`` columns
      except those whose name ends with ``'smi'``.

    A group without any matching column yields NaN.

    Args:
        table: pandas DataFrame of metrics. It is modified in place.

    Returns:
        The same ``table`` object, with the aggregate columns added
        (or overwritten if they already existed).
    """
    aggregates = (
        ('posebusters', None),
        ('reos', None),
        ('chembl_ring_systems', 'smi'),
    )
    for prefix, excluded_suffix in aggregates:
        _average_prefixed_columns(table, prefix, excluded_suffix)
    return table
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: load a DrugFlow checkpoint, build the pocket
    # around a reference ligand, sample molecules (optionally filtering them)
    # and write the samples to an SDF file plus, optionally, a metrics CSV.
    p = argparse.ArgumentParser()
    p.add_argument('--protein', type=str, required=True, help="Input PDB file.")
    p.add_argument('--ref_ligand', type=str, required=True, help="SDF file with reference ligand used to define the pocket.")
    p.add_argument('--checkpoint', type=str, required=True, help="Model checkpoint file.")
    p.add_argument('--molecule_size', type=str, required=False, default=None, help="Maximum number of atoms in the sampled molecules. Can be a single number or a range, e.g. '15,20'. If None, size will be sampled.")
    p.add_argument('--output', type=str, required=False, default='samples.sdf', help="Output file.")
    p.add_argument('--n_samples', type=int, required=False, default=10, help="Number of sampled molecules.")
    p.add_argument('--batch_size', type=int, required=False, default=32, help="Batch size.")
    p.add_argument('--pocket_distance_cutoff', type=float, required=False, default=8.0, help="Distance cutoff to define the pocket around the reference ligand.")
    p.add_argument('--n_steps', type=int, required=False, default=None, help="Number of denoising steps.")
    p.add_argument('--device', type=str, required=False, default='cuda:0', help="Device to use.")
    p.add_argument('--datadir', type=Path, required=False, default=Path(basedir, 'src', 'default'), help="Needs to be specified to sample molecule sizes.")
    p.add_argument('--seed', type=int, required=False, default=42, help="Random seed.")
    p.add_argument('--filter', action='store_true', required=False, default=False, help="Apply basic filters and keep sampling until `n_samples` molecules passing these filters are found.")
    p.add_argument('--metrics_output', type=str, required=False, default=None, help="If provided, metrics will be computed and saved in csv format at this location.")
    p.add_argument('--gnina', type=str, required=False, default=None, help="Path to a gnina executable. Required for computing docking scores.")
    p.add_argument('--reduce', type=str, required=False, default=None, help="Path to a reduce executable. Required for computing interactions.")
    args = p.parse_args()

    utils.set_deterministic(seed=args.seed)
    utils.disable_rdkit_logging()

    # Without an explicit size, the data directory must exist (see the error
    # message below for why).
    if args.molecule_size is None and (args.datadir is None or not args.datadir.exists()):
        raise NotImplementedError(
            "Please provide a path to the processed dataset (using `--datadir`) "
            "to infer the number of nodes. It contains the size distribution histogram."
        )

    # Without filtering there is no point in sampling more molecules per
    # batch than were requested in total.
    if not args.filter:
        args.batch_size = min(args.batch_size, args.n_samples)

    # Load the model
    model = DrugFlow.load_from_checkpoint(args.checkpoint, map_location=args.device, strict=False)
    if args.datadir is not None:
        model.datadir = args.datadir

    model.setup(stage='generation')
    model.batch_size = model.eval_batch_size = args.batch_size
    model.eval().to(args.device)
    if args.n_steps is not None:
        model.T = args.n_steps

    # Resolve the requested molecule size: either a fixed integer or a
    # "uniform_<left>_<right>" specification string; None is passed through
    # when no size was requested.
    molecule_size = None
    if args.molecule_size is not None:
        if args.molecule_size.isdigit():
            molecule_size = int(args.molecule_size)
            print(f'Will generate molecules of size {molecule_size}')
        else:
            boundaries = [x.strip() for x in args.molecule_size.split(',')]
            # Explicit check instead of `assert`, which is stripped under
            # `python -O` and would let malformed input through.
            if len(boundaries) != 2 or not all(b.isdigit() for b in boundaries):
                p.error(
                    "--molecule_size must be a single non-negative integer or two "
                    f"comma-separated integers (e.g. '15,20'), got {args.molecule_size!r}"
                )
            left = int(boundaries[0])
            right = int(boundaries[1])
            molecule_size = f"uniform_{left}_{right}"
            print(f'Will generate molecules with numbers of atoms sampled from U({left}, {right})')

    # Prepare the input: first model of the PDB file, first record of the SDF
    pdb_model = PDBParser(QUIET=True).get_structure('', args.protein)[0]
    rdmol = Chem.SDMolSupplier(str(args.ref_ligand))[0]
    if rdmol is None:
        # RDKit yields None for records it cannot parse; fail here with a
        # clear message rather than deep inside the preprocessing code.
        raise ValueError(f"Could not parse the reference ligand from '{args.ref_ligand}'.")

    ligand, pocket = process_raw_pair(
        pdb_model, rdmol,
        dist_cutoff=args.pocket_distance_cutoff,
        pocket_representation=model.pocket_representation,
        compute_nerf_params=True,
        nma_input=args.protein if model.dynamics.add_nma_feat else None
    )
    ligand['name'] = 'ligand'

    # The same ligand/pocket pair is repeated `batch_size` times, so one pass
    # over the dataloader yields a single batch.
    dataset = [{'ligand': ligand, 'pocket': pocket} for _ in range(args.batch_size)]
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        collate_fn=partial(ProcessedLigandPocketDataset.collate_fn, ligand_transform=None),
        pin_memory=True
    )

    # Start sampling
    smiles = set()          # SMILES of molecules accepted so far (filter mode only)
    sampled_molecules = []
    metrics = []            # one metrics table per batch
    Path(args.output).parent.absolute().mkdir(parents=True, exist_ok=True)
    print(f'Will generate {args.n_samples} samples')

    evaluator = FullEvaluator(gnina=args.gnina, reduce=args.reduce)
    needs_metrics = args.filter or (args.metrics_output is not None)

    with tqdm(total=args.n_samples) as pbar:
        # NOTE(review): whole batches are appended, so the final number of
        # molecules can exceed `n_samples` - confirm whether that is intended.
        while len(sampled_molecules) < args.n_samples:
            for data in dataloader:
                new_data = {
                    'ligand': TensorDict(**data['ligand']).to(args.device),
                    'pocket': TensorDict(**data['pocket']).to(args.device),
                }
                rdmols, rdpockets, _ = model.sample(
                    new_data,
                    n_samples=1,
                    timesteps=args.n_steps,
                    num_nodes=molecule_size,
                )

                if needs_metrics:
                    results = []
                    with tempfile.TemporaryDirectory() as tmpdir:
                        # The receptor file is rewritten for every molecule
                        # and handed to the evaluator as a path.
                        receptor_path = Path(tmpdir, 'receptor.pdb')
                        for mol, receptor in zip(rdmols, rdpockets):
                            Chem.MolToPDBFile(receptor, str(receptor_path))
                            results.append(evaluator(mol, receptor_path))

                    table = pd.DataFrame(results)
                    # "Novel" is relative to molecules accepted in earlier
                    # batches; `smiles` is only updated after this check.
                    table['novel'] = ~table['representation.smiles'].isin(smiles)
                    table = aggregate_metrics(table)

                added_molecules = 0
                if args.filter:
                    # The 'reos' aggregate is computed but is not part of
                    # the acceptance criterion.
                    table['passed_filters'] = (
                        (table['posebusters'] == 1) &
                        (table['chembl_ring_systems'] == 1) &
                        (table['novel'] == 1)
                    )
                    accepted = zip(table['passed_filters'], table['representation.smiles'], rdmols)
                    for passed, smi, mol in accepted:
                        if passed:
                            sampled_molecules.append(mol)
                            smiles.add(smi)
                            added_molecules += 1

                    if args.metrics_output is not None:
                        metrics.append(table[table['passed_filters']])

                else:
                    sampled_molecules.extend(rdmols)
                    added_molecules = len(rdmols)
                    if args.metrics_output is not None:
                        metrics.append(table)

                pbar.update(added_molecules)

    # Write results
    utils.write_sdf_file(args.output, sampled_molecules)

    if args.metrics_output is not None:
        metrics = pd.concat(metrics)
        metrics.to_csv(args.metrics_output, index=False)
|
|
|