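"""Sample ligands for every pocket in a test set with a trained
LigandPocketDDPM checkpoint and write raw and processed SDF files.

Example invocation (script name and paths are illustrative only):
    python test.py last.ckpt --test_dir data/test --outdir results \
        --n_samples 100 --sanitize --relax
"""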
import argparse
import warnings
from pathlib import Path
from time import time

import torch
from rdkit import Chem
from tqdm import tqdm

import utils
from analysis.molecule_builder import process_molecule
from lightning_modules import LigandPocketDDPM
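
# Abort a pocket after MAXITER sampling rounds within one attempt,
# and give up entirely after MAXNTRIES failed attempts.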
MAXITER = 10
MAXNTRIES = 10

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('checkpoint', type=Path)
    parser.add_argument('--test_dir', type=Path)
    parser.add_argument('--test_list', type=Path, default=None)
    parser.add_argument('--outdir', type=Path)
    parser.add_argument('--n_samples', type=int, default=100)
    parser.add_argument('--all_frags', action='store_true')
    parser.add_argument('--sanitize', action='store_true')
    parser.add_argument('--relax', action='store_true')
    parser.add_argument('--batch_size', type=int, default=120)
    parser.add_argument('--resamplings', type=int, default=10)
    parser.add_argument('--jump_length', type=int, default=1)
    parser.add_argument('--timesteps', type=int, default=None)
    parser.add_argument('--fix_n_nodes', action='store_true')
    parser.add_argument('--n_nodes_bias', type=int, default=0)
    parser.add_argument('--n_nodes_min', type=int, default=0)
    parser.add_argument('--skip_existing', action='store_true')
    args = parser.parse_args()
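
    # Select the compute device and create the output directory layout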
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    args.outdir.mkdir(exist_ok=args.skip_existing)
    raw_sdf_dir = Path(args.outdir, 'raw')
    raw_sdf_dir.mkdir(exist_ok=args.skip_existing)
    processed_sdf_dir = Path(args.outdir, 'processed')
    processed_sdf_dir.mkdir(exist_ok=args.skip_existing)
    times_dir = Path(args.outdir, 'pocket_times')
    times_dir.mkdir(exist_ok=args.skip_existing)

    # Load model
    model = LigandPocketDDPM.load_from_checkpoint(
        args.checkpoint, map_location=device)
    model = model.to(device)
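
    # Collect test-set ligand SDF files, optionally restricted to --test_list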
    test_files = list(args.test_dir.glob('[!.]*.sdf'))
    if args.test_list is not None:
        with open(args.test_list, 'r') as f:
            test_list = set(f.read().split(','))
        test_files = [x for x in test_files if x.stem in test_list]

    pbar = tqdm(test_files)
    time_per_pocket = {}
    for sdf_file in pbar:
        ligand_name = sdf_file.stem
        pdb_name, pocket_id, *suffix = ligand_name.split('_')
        pdb_file = Path(sdf_file.parent, f"{pdb_name}.pdb")
        txt_file = Path(sdf_file.parent, f"{ligand_name}.txt")
        sdf_out_file_raw = Path(raw_sdf_dir, f'{ligand_name}_gen.sdf')
        sdf_out_file_processed = Path(processed_sdf_dir,
                                      f'{ligand_name}_gen.sdf')
        time_file = Path(times_dir, f'{ligand_name}.txt')
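
        # Reuse previous results for this pocket if all outputs already exist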
        if args.skip_existing and time_file.exists() \
                and sdf_out_file_processed.exists() \
                and sdf_out_file_raw.exists():
            with open(time_file, 'r') as f:
                time_per_pocket[str(sdf_file)] = float(f.read().split()[1])
            continue
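
        # Retry the whole sampling procedure for this pocket up to MAXNTRIES times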
        for n_try in range(MAXNTRIES):
            try:
                t_pocket_start = time()

                with open(txt_file, 'r') as f:
                    resi_list = f.read().split()

                if args.fix_n_nodes:
                    # Some ligands (e.g. 6JWS_bio1_PT1:A:801) could not
                    # be read with sanitize=True
                    suppl = Chem.SDMolSupplier(str(sdf_file), sanitize=False)
                    num_nodes_lig = suppl[0].GetNumAtoms()
                else:
                    num_nodes_lig = None

                all_molecules = []
                valid_molecules = []
                processed_molecules = []  # only used as temporary variable
                n_iter = 0
                n_generated = 0
                n_valid = 0
                while len(valid_molecules) < args.n_samples:
                    n_iter += 1
                    if n_iter > MAXITER:
                        raise RuntimeError(
                            'Maximum number of iterations has been exceeded.')
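
                    # Replicate the reference ligand's atom count across the batch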
                    num_nodes_lig_inflated = None if num_nodes_lig is None else \
                        torch.ones(args.batch_size, dtype=int) * num_nodes_lig

                    # Turn all filters off first
                    mols_batch = model.generate_ligands(
                        pdb_file, args.batch_size, resi_list,
                        num_nodes_lig=num_nodes_lig_inflated,
                        timesteps=args.timesteps, sanitize=False,
                        largest_frag=False, relax_iter=0,
                        n_nodes_bias=args.n_nodes_bias,
                        n_nodes_min=args.n_nodes_min,
                        resamplings=args.resamplings,
                        jump_length=args.jump_length)
                    all_molecules.extend(mols_batch)

                    # Filter to find valid molecules
                    mols_batch_processed = [
                        process_molecule(m, sanitize=args.sanitize,
                                         relax_iter=(200 if args.relax else 0),
                                         largest_frag=not args.all_frags)
                        for m in mols_batch
                    ]
                    processed_molecules.extend(mols_batch_processed)
                    valid_mols_batch = [
                        m for m in mols_batch_processed if m is not None]
                    n_generated += args.batch_size
                    n_valid += len(valid_mols_batch)
                    valid_molecules.extend(valid_mols_batch)

                # Remove excess molecules from list
                valid_molecules = valid_molecules[:args.n_samples]

                # Reorder raw file: valid molecules first, failed ones last
                all_molecules = \
                    [all_molecules[i] for i, m in enumerate(processed_molecules)
                     if m is not None] + \
                    [all_molecules[i] for i, m in enumerate(processed_molecules)
                     if m is None]

                # Write SDF files
                utils.write_sdf_file(sdf_out_file_raw, all_molecules)
                utils.write_sdf_file(sdf_out_file_processed, valid_molecules)

                # Time the sampling process
                time_per_pocket[str(sdf_file)] = time() - t_pocket_start
                with open(time_file, 'w') as f:
                    f.write(f"{str(sdf_file)} {time_per_pocket[str(sdf_file)]}")
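
                # Report validity rate and sampling speed in the progress bar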
                pbar.set_description(
                    f'Last processed: {ligand_name}. '
                    f'Validity: {n_valid / n_generated * 100:.2f}%. '
                    f'{(time() - t_pocket_start) / len(valid_molecules):.2f} '
                    f'sec/mol.')

                break  # no more tries needed

            except (RuntimeError, ValueError) as e:
                if n_try >= MAXNTRIES - 1:
                    raise RuntimeError("Maximum number of retries exceeded")
                warnings.warn(f"Attempt {n_try + 1}/{MAXNTRIES} failed with "
                              f"error: '{e}'. Trying again...")
    with open(Path(args.outdir, 'pocket_times.txt'), 'w') as f:
        for k, v in time_per_pocket.items():
            f.write(f"{k} {v}\n")

    times_arr = torch.tensor(list(time_per_pocket.values()))
    print(f"Time per pocket: {times_arr.mean():.3f} \\pm "
          f"{times_arr.std(unbiased=False):.2f}")