|
|
import argparse |
|
|
from pathlib import Path |
|
|
import numpy as np |
|
|
import random |
|
|
import shutil |
|
|
from time import time |
|
|
from collections import defaultdict |
|
|
from Bio.PDB import PDBParser |
|
|
from rdkit import Chem |
|
|
import torch |
|
|
from tqdm import tqdm |
|
|
import pandas as pd |
|
|
from itertools import combinations |
|
|
|
|
|
import sys

# Directory three levels above this file; it is appended to ``sys.path`` so
# that the ``src.*`` imports further down can be resolved when the file is
# executed directly as a script.
basedir = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(basedir))
|
|
|
|
|
from src.sbdd_metrics.metrics import REOSEvaluator, MedChemEvaluator, PoseBustersEvaluator, GninaEvalulator |
|
|
from src.data.data_utils import process_raw_pair, rdmol_to_smiles |
|
|
|
|
|
def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options. ``--smplsdir`` and
        ``--datadir`` are mandatory; every other option has a default.
    """
    # (flag, keyword arguments) in the order the options are registered.
    option_table = (
        ('--smplsdir', {'type': Path, 'required': True}),
        ('--metrics-detailed', {'type': Path, 'required': False}),
        ('--ignore-missing-scores', {'action': 'store_true'}),
        ('--datadir', {'type': Path, 'required': True}),
        ('--dpo-criterion', {
            'type': str,
            'default': 'reos.all',
            'choices': ['reos.all', 'medchem.sa', 'medchem.qed',
                        'gnina.vina_efficiency', 'combined'],
        }),
        ('--basedir', {'type': Path, 'default': None}),
        ('--pocket', {
            'type': str,
            'default': 'CA+',
            'choices': ['side_chain_bead', 'CA+'],
        }),
        ('--gnina', {'type': Path, 'default': 'gnina'}),
        ('--random_seed', {'type': int, 'default': 42}),
        ('--normal_modes', {'action': 'store_true'}),
        ('--flex', {'action': 'store_true'}),
        ('--toy', {'action': 'store_true'}),
        ('--toy_size', {'type': int, 'default': 100}),
        ('--n_pairs', {'type': int, 'default': 5}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
|
|
|
|
def scan_smpl_dir(samples_dir):
    """Collect the usable sample sub-directories of ``samples_dir``.

    Args:
        samples_dir: directory (``str`` or ``Path``) whose direct children
            are inspected.

    Returns:
        list[Path]: the children that are directories and for which
        ``sample_dir_valid`` returns True, in ``iterdir`` order.
    """
    root = Path(samples_dir)
    children = tqdm(root.iterdir(), desc='Scanning samples')
    return [
        child for child in children
        if child.is_dir() and sample_dir_valid(child)
    ]
|
|
|
|
|
def sample_dir_valid(samples_dir):
    """Check whether a sample directory can yield at least one ligand pair.

    A directory is valid when it contains ``0_pocket.pdb`` and at least two
    ``*_ligand.sdf`` files, none of which is empty (zero bytes).

    Args:
        samples_dir: directory to inspect, as ``str`` or ``Path``.

    Returns:
        bool: True if all of the conditions above hold.
    """
    # Accept plain strings as well, like scan_smpl_dir does.
    samples_dir = Path(samples_dir)

    if not (samples_dir / '0_pocket.pdb').exists():
        return False

    ligands = list(samples_dir.glob('*_ligand.sdf'))
    if len(ligands) < 2:
        # A preference pair needs two different ligands.
        return False

    return all(ligand.stat().st_size > 0 for ligand in ligands)
|
|
|
|
|
# criterion -> (minimum absolute score gap, True if the higher score wins)
_SCALAR_PREFERENCE_RULES = {
    'medchem.sa': (0.5, False),
    'medchem.qed': (0.1, True),
    'gnina.vina_efficiency': (0.1, False),
}


def return_winning_losing_smpl(score_1, score_2, criterion):
    """Decide which of two samples is preferred under ``criterion``.

    Args:
        score_1: score of the first sample. For ``'combined'`` this is a
            mapping with the keys ``'reos.all'``, ``'medchem.sa'``,
            ``'medchem.qed'`` and ``'gnina.vina_efficiency'``; otherwise a
            scalar.
        score_2: score of the second sample, same layout as ``score_1``.
        criterion: one of ``'reos.all'``, ``'medchem.sa'``, ``'medchem.qed'``,
            ``'gnina.vina_efficiency'`` or ``'combined'``.

    Returns:
        ``True`` if the first sample wins, ``False`` if the second one wins,
        ``None`` if no preference can be derived (scores tied or closer than
        the criterion's threshold, disagreeing sub-criteria for
        ``'combined'``, or an unknown criterion). The result is always a
        built-in ``bool`` or ``None``, never a NumPy boolean, so callers can
        safely use identity checks such as ``sign is False``.
    """
    if criterion == 'reos.all':
        if score_1 == score_2:
            return None
        return bool(score_1 > score_2)

    if criterion in _SCALAR_PREFERENCE_RULES:
        min_gap, higher_wins = _SCALAR_PREFERENCE_RULES[criterion]
        if abs(score_1 - score_2) < min_gap:
            return None
        if higher_wins:
            return bool(score_1 > score_2)
        return bool(score_1 < score_2)

    if criterion == 'combined':
        if score_1['reos.all'] == score_2['reos.all']:
            return None
        # One vote per sub-criterion: True means "first sample is better".
        votes = [
            bool(score_1['reos.all'] > score_2['reos.all']),
            bool(score_1['medchem.sa'] < score_2['medchem.sa']),
            bool(score_1['medchem.qed'] > score_2['medchem.qed']),
            bool(score_1['gnina.vina_efficiency'] < score_2['gnina.vina_efficiency']),
        ]
        # Only unanimous votes produce a preference.
        if all(votes):
            return True
        if not any(votes):
            return False
        return None

    return None
|
|
|
|
|
_COMBINED_SCORE_KEYS = ('reos.all', 'medchem.sa', 'medchem.qed', 'gnina.vina_efficiency')


def _is_nan_score(value):
    """Return True if ``value`` is a floating-point NaN (e.g. an empty CSV cell)."""
    return isinstance(value, (float, np.floating)) and bool(np.isnan(value))


def _score_ligand(lig_path, pocket, evaluator, pose_evaluator, criterion,
                  precomp_scores, ignore_missing_scores):
    """Score one ligand file; return a sample dict, or None if it must be skipped."""
    try:
        mol = Chem.SDMolSupplier(str(lig_path))[0]
        if mol is None:
            return None
        smiles = rdmol_to_smiles(mol)
    except Exception:
        print('Failed to read ligand:', lig_path)
        return None

    mol_props = {}
    if precomp_scores is not None and str(lig_path) in precomp_scores.index:
        row = precomp_scores.loc[str(lig_path)].to_dict()
        # NaN entries carry no information; drop them so that they are
        # handled exactly like scores that are absent from the table.
        mol_props = {k: v for k, v in row.items() if not _is_nan_score(v)}
        if criterion == 'combined':
            if any(key not in mol_props for key in _COMBINED_SCORE_KEYS):
                print('Missing combined scores for ligand:', lig_path)
                return None
            combined = {key: mol_props[key] for key in _COMBINED_SCORE_KEYS}
            # The scalar stored under 'combined' is what pairs are ranked by.
            combined['combined'] = mol_props['gnina.vina_efficiency']
            mol_props['combined'] = combined

    if criterion in mol_props:
        score = mol_props[criterion]
    else:
        if ignore_missing_scores:
            print(f'Missing {criterion} for ligand:', lig_path)
            return None
        print(f'Recomputing {criterion} for ligand:', lig_path)
        try:
            eval_res = evaluator.evaluate(mol)
            criterion_cat = criterion.split('.')[0]
            eval_res = {f'{criterion_cat}.{k}': v for k, v in eval_res.items()}
            score = eval_res[criterion]
        except Exception as exc:
            # Best effort: a ligand that cannot be scored is skipped, but
            # interrupts (KeyboardInterrupt/SystemExit) are not swallowed.
            print(f'Failed to compute {criterion} for ligand:', lig_path,
                  f'({type(exc).__name__}: {exc})')
            return None

    if 'posebusters.all' in mol_props:
        pose_ok = mol_props['posebusters.all']
    else:
        if ignore_missing_scores:
            print('Missing PoseBusters for ligand:', lig_path)
            return None
        print('Recomputing PoseBusters for ligand:', lig_path)
        try:
            pose_eval_res = pose_evaluator.evaluate(lig_path, pocket)
        except Exception as exc:
            print('Failed to compute PoseBusters for ligand:', lig_path,
                  f'({type(exc).__name__}: {exc})')
            return None
        pose_ok = 'all' in pose_eval_res and pose_eval_res['all']

    if not pose_ok:
        return None

    return {
        'smiles': smiles,
        'score': score,
        'ligand_path': lig_path,
        'pocket_path': pocket,
    }


def _select_pairs(unique_samples, criterion, n_pairs):
    """Return up to ``n_pairs`` disjoint (winning, losing) pairs, largest score gap first."""
    ranked = []
    for s1, s2 in combinations(unique_samples, 2):
        sign = return_winning_losing_smpl(s1['score'], s2['score'], criterion)
        if sign is None:
            continue
        if criterion == 'combined':
            first, second = s1['score']['combined'], s2['score']['combined']
        else:
            first, second = s1['score'], s2['score']
        # float() keeps boolean scores (including NumPy booleans) subtractable.
        score_diff = abs(float(first) - float(second))
        # Truthiness instead of `is False`: NumPy booleans are not the
        # built-in singletons and would otherwise be dropped silently.
        ranked.append((s1, s2, score_diff) if sign else (s2, s1, score_diff))

    ranked.sort(key=lambda pair: pair[2], reverse=True)

    used_ligand_paths = set()
    selected = []
    for winning, losing, _ in ranked:
        if winning['ligand_path'] in used_ligand_paths or losing['ligand_path'] in used_ligand_paths:
            continue
        selected.append((winning, losing))
        used_ligand_paths.update((winning['ligand_path'], losing['ligand_path']))
        if len(selected) == n_pairs:
            break
    return selected


def _pair_record(winning, losing):
    """Flatten a (winning, losing) pair into one row of the samples table."""
    record = {
        'score_w': winning['score'],
        'score_l': losing['score'],
        'pocket_p': winning['pocket_path'],
        'ligand_p_w': winning['ligand_path'],
        'ligand_p_l': losing['ligand_path'],
    }
    for suffix, sample in (('w', winning), ('l', losing)):
        if isinstance(sample['score'], dict):
            # Combined criterion: expose every sub-score as its own column.
            for key, value in sample['score'].items():
                record[f'{key}_{suffix}'] = value
            record[f'score_{suffix}'] = sample['score']['combined']
    return record


def compute_scores(sample_dirs, evaluator, criterion, n_pairs=5, toy=False, toy_size=100,
                   precomp_scores=None, ignore_missing_scores=False):
    """Score the ligands of every sample directory and build preference pairs.

    For each directory, every ``*_ligand.sdf`` is read and scored under
    ``criterion``. Ligands without a passing PoseBusters result are
    dropped, and duplicates (same SMILES) are reduced to the first one
    encountered. All remaining ligand pairs with a usable preference are
    ranked by score gap, and up to ``n_pairs`` pairs that share no ligand are
    kept.

    Args:
        sample_dirs: iterable of ``Path`` objects, each containing
            ``0_pocket.pdb`` and ``*_ligand.sdf`` files.
        evaluator: object whose ``evaluate(mol)`` returns a mapping of scores;
            used only when a score has to be recomputed.
        criterion: preference criterion, passed on to
            ``return_winning_losing_smpl``.
        n_pairs: maximum number of pairs kept per directory.
        toy: if True, stop as soon as ``toy_size`` pairs were collected.
        toy_size: pair budget used when ``toy`` is set.
        precomp_scores: optional table indexed by the ligand path (as string)
            with precomputed scores; NaN entries count as missing.
        ignore_missing_scores: if True, ligands with missing scores are
            skipped instead of being re-evaluated.

    Returns:
        list[dict]: one record per selected pair with the keys ``score_w``,
        ``score_l``, ``pocket_p``, ``ligand_p_w`` and ``ligand_p_l`` (plus one
        ``<name>_w`` / ``<name>_l`` entry per sub-score for the combined
        criterion).
    """
    samples = []
    pose_evaluator = PoseBustersEvaluator()
    pbar = tqdm(sample_dirs, desc='Computing scores for samples')

    for sample_dir in pbar:
        pocket = sample_dir / '0_pocket.pdb'

        # Keep only the first ligand seen for each SMILES string.
        unique_samples = {}
        for lig_path in list(sample_dir.glob('*_ligand.sdf')):
            sample = _score_ligand(lig_path, pocket, evaluator, pose_evaluator, criterion,
                                   precomp_scores, ignore_missing_scores)
            if sample is not None:
                unique_samples.setdefault(sample['smiles'], sample)

        if len(unique_samples) < 2:
            continue

        for winning, losing in _select_pairs(list(unique_samples.values()), criterion, n_pairs):
            samples.append(_pair_record(winning, losing))

        pbar.set_postfix({'samples': len(samples)})

        if toy and len(samples) >= toy_size:
            break

    return samples
|
|
|
|
|
# Entries of --datadir that are copied next to the generated training set.
_REFERENCE_FILES = ('size_distribution.npy', 'ligand_type_histogram.npy',
                    'ligand_bond_type_histogram.npy', 'metadata.yml')
_REFERENCE_SPLITS = ('val', 'test')


def _check_reference_data(datadir):
    """Raise FileNotFoundError if ``datadir`` lacks anything that is copied later."""
    required = [Path(datadir, name) for name in _REFERENCE_FILES]
    for split in _REFERENCE_SPLITS:
        required.append(Path(datadir, split))
        required.append(Path(datadir, f'{split}.pt'))
    missing = [str(path) for path in required if not path.exists()]
    if missing:
        raise FileNotFoundError('Missing reference data in --datadir: ' + ', '.join(missing))


def _make_evaluator(args):
    """Instantiate the evaluator matching ``args.dpo_criterion`` (None for 'combined')."""
    if 'reos' in args.dpo_criterion:
        return REOSEvaluator()
    if 'medchem' in args.dpo_criterion:
        return MedChemEvaluator()
    if 'gnina' in args.dpo_criterion:
        return GninaEvalulator(gnina=args.gnina)
    if 'combined' in args.dpo_criterion:
        # The combined criterion is never recomputed; it relies on the table.
        if args.metrics_detailed is None:
            raise ValueError('For combined criterion, detailed metrics file has to be provided')
        if not args.ignore_missing_scores:
            raise ValueError('For combined criterion, --ignore-missing-scores flag has to be set')
        return None
    raise ValueError(f"Unknown DPO criterion: {args.dpo_criterion}")


def _output_dirname(args):
    """Compose the output directory name from the criterion, pocket type and flags."""
    dirname = f"dpo_{args.dpo_criterion.replace('.','_')}_{args.pocket}"
    if args.flex:
        dirname += '_flex'
    if args.normal_modes:
        dirname += '_nma'
    if args.toy:
        dirname += '_toy'
    return dirname


def _load_or_compute_samples(args, evaluator, processed_dir):
    """Load the cached pair table from ``processed_dir`` or compute and cache it."""
    samples_csv = processed_dir / f'samples_{args.dpo_criterion}.csv'
    if samples_csv.exists():
        print(f"Samples already computed for criterion {args.dpo_criterion}, loading from file")
        samples = pd.read_csv(samples_csv).to_dict('records')
        print(f"Found {len(samples)} winning/losing samples")
        return samples

    print('Scanning sample directory...')
    sample_dirs = scan_smpl_dir(Path(args.smplsdir))
    if args.metrics_detailed:
        print(f'Loading precomputed scores from {args.metrics_detailed}')
        precomp_scores = pd.read_csv(args.metrics_detailed).set_index('sdf_file')
    else:
        precomp_scores = None
    print(f'Found {len(sample_dirs)} valid sample directories')
    print('Computing scores...')
    samples = compute_scores(sample_dirs, evaluator, args.dpo_criterion,
                             n_pairs=args.n_pairs, toy=args.toy, toy_size=args.toy_size,
                             precomp_scores=precomp_scores,
                             ignore_missing_scores=args.ignore_missing_scores)
    print(f'Found {len(samples)} winning/losing samples, saving to file')
    pd.DataFrame(samples).to_csv(samples_csv, index=False)
    return samples


def _process_split(split, entries, args, failed, train_smiles):
    """Featurise every pair of one split; failures are recorded in ``failed``."""
    ligands_w = defaultdict(list)
    ligands_l = defaultdict(list)
    pockets = defaultdict(list)

    nerf_keys = ['fixed_coord', 'atom_mask', 'nerf_indices', 'length', 'theta', 'chi',
                 'ddihedral', 'chi_indices']
    feature_keys = ['x', 'one_hot', 'bonds', 'bond_one_hot', 'v', 'nma_vec'] + nerf_keys + ['axis_angle']

    pbar = tqdm(entries)
    for entry in pbar:
        pbar.set_description(f'#failed: {len(failed)}')

        pdbfile = Path(entry['pocket_p'])
        entry['ligand_p_w'] = Path(entry['ligand_p_w'])
        entry['ligand_p_l'] = Path(entry['ligand_p_l'])
        nma_input = pdbfile if args.normal_modes else None

        try:
            # Ligand loading happens inside the try block so that a single
            # unreadable file is recorded as a failure instead of aborting.
            entry['ligand_w'] = Chem.SDMolSupplier(str(entry['ligand_p_w']))[0]
            entry['ligand_l'] = Chem.SDMolSupplier(str(entry['ligand_p_l']))[0]
            pdb_model = PDBParser(QUIET=True).get_structure('', pdbfile)[0]

            ligand_w, pocket = process_raw_pair(
                pdb_model, entry['ligand_w'], pocket_representation=args.pocket,
                compute_nerf_params=args.flex, compute_bb_frames=args.flex,
                nma_input=nma_input)
            ligand_l, _ = process_raw_pair(
                pdb_model, entry['ligand_l'], pocket_representation=args.pocket,
                compute_nerf_params=args.flex, compute_bb_frames=args.flex,
                nma_input=nma_input)
        except (KeyError, AssertionError, IndexError, ValueError, AttributeError,
                OSError) as e:  # OSError also covers FileNotFoundError
            failed[(split, entry['ligand_p_w'], entry['ligand_p_l'], pdbfile)] \
                = (type(e).__name__, str(e))
            continue

        for k in feature_keys:
            if k in ligand_w:
                ligands_w[k].append(ligand_w[k])
                ligands_l[k].append(ligand_l[k])
            if k in pocket:
                pockets[k].append(pocket[k])

        smpl_n = pdbfile.parent.name
        ligands_w['name'].append(f'{smpl_n}__{entry["ligand_p_w"].stem}.sdf')
        ligands_l['name'].append(f'{smpl_n}__{entry["ligand_p_l"].stem}.sdf')
        pockets['name'].append(f'{smpl_n}__{pdbfile.stem}.pdb')
        train_smiles.append(rdmol_to_smiles(entry['ligand_w']))
        train_smiles.append(rdmol_to_smiles(entry['ligand_l']))

    return {'ligands_w': ligands_w,
            'ligands_l': ligands_l,
            'pockets': pockets}


def _copy_reference_data(datadir, processed_dir):
    """Copy statistics, metadata and the validation/test sets into ``processed_dir``."""
    for name in _REFERENCE_FILES:
        shutil.copy(Path(datadir, name), processed_dir)
    for split in _REFERENCE_SPLITS:
        target = processed_dir / split
        if target.exists():
            # copytree refuses to write into an existing directory.
            shutil.rmtree(target)
    for split in _REFERENCE_SPLITS:
        shutil.copytree(Path(datadir, split), processed_dir / split)
    for split in _REFERENCE_SPLITS:
        shutil.copy(Path(datadir, f'{split}.pt'), processed_dir)


def _format_failures(failed):
    """Render the ``failed`` mapping as the plain-text error report."""
    chunks = []
    for k, v in failed.items():
        chunks.append(f"{'Split':<15}: {k[0]}\n")
        chunks.append(f"{'Ligand W':<15}: {k[1]}\n")
        chunks.append(f"{'Ligand L':<15}: {k[2]}\n")
        chunks.append(f"{'Pocket':<15}: {k[3]}\n")
        chunks.append(f"{'Error type':<15}: {v[0]}\n")
        chunks.append(f"{'Error msg':<15}: {v[1]}\n\n")
    return ''.join(chunks)


def main():
    """Build a DPO training set of winning/losing ligand pairs.

    Steps: parse the command line, obtain the pair table (cached CSV in the
    output directory or freshly computed from ``--smplsdir``), featurise
    every pair with ``process_raw_pair``, store the result as ``train.pt``
    and ``train_smiles.npy``, copy the statistics, metadata and
    validation/test sets over from ``--datadir`` and finally write
    ``errors.txt`` and ``dataset_config.txt``.

    Raises:
        ValueError: if ``--basedir`` is not given or the criterion options
            are inconsistent.
        FileNotFoundError: if ``--datadir`` lacks one of the files or
            directories that are copied into the output directory.
    """
    args = parse_args()

    # Fail fast on problems that would otherwise surface only after the
    # (potentially long) scoring and featurisation steps.
    if args.basedir is None:
        raise ValueError('--basedir has to be provided (root of the output directory)')
    _check_reference_data(args.datadir)

    # Make the toy subsample reproducible.
    random.seed(args.random_seed)
    np.random.seed(args.random_seed)

    evaluator = _make_evaluator(args)

    processed_dir = Path(args.basedir, _output_dirname(args))
    processed_dir.mkdir(parents=True, exist_ok=True)

    samples = _load_or_compute_samples(args, evaluator, processed_dir)

    data_split = {'train': samples}
    if args.toy:
        data_split['train'] = random.sample(samples, min(args.toy_size, len(data_split['train'])))

    failed = {}
    train_smiles = []

    for split, entries in data_split.items():
        print(f"Processing {split} dataset...")
        tic = time()

        data = _process_split(split, entries, args, failed, train_smiles)
        torch.save(data, Path(processed_dir, f'{split}.pt'))

        if split == 'train':
            np.save(Path(processed_dir, 'train_smiles.npy'), train_smiles)

        print(f"Processing {split} set took {(time() - tic) / 60.0:.2f} minutes")

    _copy_reference_data(args.datadir, processed_dir)

    Path(processed_dir, 'errors.txt').write_text(_format_failures(failed))
    Path(processed_dir, 'dataset_config.txt').write_text(str(args))
|
|
|
|
|
# Run the dataset builder only when the file is executed as a script.
if __name__ == "__main__":
    main()