Upload 53 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- LICENSE +9 -0
- configs/sampling/sample_and_maybe_evaluate.yml +25 -0
- configs/sampling/sample_train_split.yml +25 -0
- configs/training/drugflow.yml +82 -0
- configs/training/drugflow_no_virtual_nodes.yml +82 -0
- configs/training/drugflow_ood.yml +83 -0
- configs/training/flexflow.yml +90 -0
- configs/training/preference_alignment.yml +93 -0
- docs/drugflow.jpg +3 -0
- environment.yaml +30 -0
- examples/kras.pdb +0 -0
- examples/kras_ref_ligand.sdf +74 -0
- scripts/python/evaluate_baselines.py +53 -0
- scripts/python/postprocess_metrics.py +271 -0
- src/analysis/SA_Score/README.md +1 -0
- src/analysis/SA_Score/fpscores.pkl.gz +3 -0
- src/analysis/SA_Score/sascorer.py +173 -0
- src/analysis/metrics.py +544 -0
- src/analysis/visualization_utils.py +192 -0
- src/constants.py +256 -0
- src/data/data_utils.py +901 -0
- src/data/dataset.py +208 -0
- src/data/misc.py +19 -0
- src/data/molecule_builder.py +107 -0
- src/data/nerf.py +250 -0
- src/data/normal_modes.py +69 -0
- src/data/postprocessing.py +93 -0
- src/data/process_crossdocked.py +176 -0
- src/data/process_dpo_dataset.py +406 -0
- src/data/sanifix.py +159 -0
- src/data/so3_utils.py +450 -0
- src/default/size_distribution.npy +3 -0
- src/generate.py +204 -0
- src/model/diffusion_utils.py +206 -0
- src/model/dpo.py +252 -0
- src/model/dynamics.py +791 -0
- src/model/dynamics_hetero.py +1008 -0
- src/model/flows.py +448 -0
- src/model/gvp.py +650 -0
- src/model/gvp_transformer.py +471 -0
- src/model/lightning.py +1426 -0
- src/model/loss_utils.py +79 -0
- src/model/markov_bridge.py +163 -0
- src/sample_and_evaluate.py +164 -0
- src/sbdd_metrics/evaluation.py +239 -0
- src/sbdd_metrics/fpscores.pkl.gz +3 -0
- src/sbdd_metrics/interactions.py +231 -0
- src/sbdd_metrics/metrics.py +929 -0
- src/sbdd_metrics/sascorer.py +173 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
docs/drugflow.jpg filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Arne Schneuing, Ilia Igashov, Adrian Dobbelstein
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
| 6 |
+
|
| 7 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
| 8 |
+
|
| 9 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
configs/sampling/sample_and_maybe_evaluate.yml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
checkpoint: <TODO>
|
| 2 |
+
set: test
|
| 3 |
+
sample_outdir: ./samples
|
| 4 |
+
n_samples: 100
|
| 5 |
+
sample_with_ground_truth_size: False
|
| 6 |
+
device: cuda
|
| 7 |
+
seed: 42
|
| 8 |
+
sample: True
|
| 9 |
+
postprocess: False
|
| 10 |
+
evaluate: False
|
| 11 |
+
reduce: reduce
|
| 12 |
+
|
| 13 |
+
# Override training config parameters if necessary
|
| 14 |
+
model_args:
|
| 15 |
+
|
| 16 |
+
virtual_nodes: [0, 5]
|
| 17 |
+
|
| 18 |
+
train_params:
|
| 19 |
+
datadir: ./processed_crossdocked
|
| 20 |
+
gnina: gnina
|
| 21 |
+
|
| 22 |
+
eval_params:
|
| 23 |
+
n_sampling_steps: 500
|
| 24 |
+
eval_batch_size: 1
|
| 25 |
+
|
configs/sampling/sample_train_split.yml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
checkpoint: <TODO>
|
| 2 |
+
set: train
|
| 3 |
+
sample_outdir: ./samples
|
| 4 |
+
n_samples: 50
|
| 5 |
+
sample_with_ground_truth_size: False
|
| 6 |
+
device: cuda
|
| 7 |
+
seed: 42
|
| 8 |
+
sample: True
|
| 9 |
+
postprocess: False
|
| 10 |
+
evaluate: False
|
| 11 |
+
reduce: reduce
|
| 12 |
+
|
| 13 |
+
# Override training config parameters if necessary
|
| 14 |
+
model_args:
|
| 15 |
+
|
| 16 |
+
virtual_nodes: [0, 10]
|
| 17 |
+
|
| 18 |
+
train_params:
|
| 19 |
+
datadir: ./processed_crossdocked
|
| 20 |
+
gnina: gnina
|
| 21 |
+
batch_size: 2
|
| 22 |
+
|
| 23 |
+
eval_params:
|
| 24 |
+
n_sampling_steps: 100
|
| 25 |
+
|
configs/training/drugflow.yml
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run_name: drugflow # iclr_drugflow_T5000
|
| 2 |
+
pocket_representation: CA+
|
| 3 |
+
virtual_nodes: [0, 10]
|
| 4 |
+
flexible: False
|
| 5 |
+
flexible_bb: False
|
| 6 |
+
|
| 7 |
+
train_params:
|
| 8 |
+
logdir: ./runs # symlink to any location you like
|
| 9 |
+
datadir: ./processed_crossdocked # symlink to the dataset location
|
| 10 |
+
enable_progress_bar: True
|
| 11 |
+
num_sanity_val_steps: 0
|
| 12 |
+
batch_size: 64
|
| 13 |
+
accumulate_grad_batches: 2
|
| 14 |
+
lr: 5.0e-4
|
| 15 |
+
n_epochs: 1000
|
| 16 |
+
num_workers: 0
|
| 17 |
+
gpus: 1
|
| 18 |
+
clip_grad: True
|
| 19 |
+
gnina: gnina
|
| 20 |
+
sample_from_clusters: False
|
| 21 |
+
sharded_dataset: False
|
| 22 |
+
|
| 23 |
+
wandb_params:
|
| 24 |
+
mode: online # disabled, offline, online
|
| 25 |
+
entity:
|
| 26 |
+
group: crossdocked
|
| 27 |
+
|
| 28 |
+
loss_params:
|
| 29 |
+
discrete_loss: VLB # VLB or CE
|
| 30 |
+
lambda_x: 1.0
|
| 31 |
+
lambda_h: 50.0
|
| 32 |
+
lambda_e: 50.0
|
| 33 |
+
lambda_chi: null
|
| 34 |
+
lambda_trans: null
|
| 35 |
+
lambda_rot: null
|
| 36 |
+
lambda_clash: null
|
| 37 |
+
timestep_weights: null
|
| 38 |
+
|
| 39 |
+
simulation_params:
|
| 40 |
+
n_steps: 5000
|
| 41 |
+
prior_h: marginal # uniform, marginal
|
| 42 |
+
prior_e: uniform # uniform, marginal
|
| 43 |
+
predict_final: False
|
| 44 |
+
predict_confidence: False
|
| 45 |
+
|
| 46 |
+
eval_params:
|
| 47 |
+
eval_epochs: 100
|
| 48 |
+
n_eval_samples: 4
|
| 49 |
+
n_sampling_steps: 500
|
| 50 |
+
eval_batch_size: 16
|
| 51 |
+
visualize_sample_epoch: 1
|
| 52 |
+
n_visualize_samples: 100
|
| 53 |
+
visualize_chain_epoch: 1
|
| 54 |
+
keep_frames: 100
|
| 55 |
+
sample_with_ground_truth_size: True
|
| 56 |
+
|
| 57 |
+
predictor_params:
|
| 58 |
+
heterogeneous_graph: True
|
| 59 |
+
backbone: gvp
|
| 60 |
+
num_rbf_time: 16
|
| 61 |
+
edge_cutoff_ligand: null
|
| 62 |
+
edge_cutoff_pocket: 10.0
|
| 63 |
+
edge_cutoff_interaction: 10.0
|
| 64 |
+
cycle_counts: True
|
| 65 |
+
spectral_feat: False
|
| 66 |
+
reflection_equivariant: False
|
| 67 |
+
num_rbf: 16
|
| 68 |
+
d_max: 15.0
|
| 69 |
+
self_conditioning: True
|
| 70 |
+
augment_residue_sc: False
|
| 71 |
+
augment_ligand_sc: False
|
| 72 |
+
normal_modes: False
|
| 73 |
+
add_chi_as_feature: False
|
| 74 |
+
angle_act_fn: null
|
| 75 |
+
add_all_atom_diff: False
|
| 76 |
+
|
| 77 |
+
gvp_params:
|
| 78 |
+
n_layers: 5
|
| 79 |
+
node_h_dim: [ 128, 32 ] # (s, V)
|
| 80 |
+
edge_h_dim: [ 128, 32 ]
|
| 81 |
+
dropout: 0.0
|
| 82 |
+
vector_gate: True
|
configs/training/drugflow_no_virtual_nodes.yml
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run_name: drugflow_no_virtual_nodes # iclr_drugflow_T5000_no_virtual_nodes
|
| 2 |
+
pocket_representation: CA+
|
| 3 |
+
virtual_nodes: null
|
| 4 |
+
flexible: False
|
| 5 |
+
flexible_bb: False
|
| 6 |
+
|
| 7 |
+
train_params:
|
| 8 |
+
logdir: ./runs # symlink to any location you like
|
| 9 |
+
datadir: ./processed_crossdocked # symlink to the dataset location
|
| 10 |
+
enable_progress_bar: True
|
| 11 |
+
num_sanity_val_steps: 0
|
| 12 |
+
batch_size: 64
|
| 13 |
+
accumulate_grad_batches: 2
|
| 14 |
+
lr: 5.0e-4
|
| 15 |
+
n_epochs: 1000
|
| 16 |
+
num_workers: 0
|
| 17 |
+
gpus: 1
|
| 18 |
+
clip_grad: True
|
| 19 |
+
gnina: gnina
|
| 20 |
+
sample_from_clusters: False
|
| 21 |
+
sharded_dataset: False
|
| 22 |
+
|
| 23 |
+
wandb_params:
|
| 24 |
+
mode: online # disabled, offline, online
|
| 25 |
+
entity: lpdi
|
| 26 |
+
group: crossdocked
|
| 27 |
+
|
| 28 |
+
loss_params:
|
| 29 |
+
discrete_loss: VLB # VLB or CE
|
| 30 |
+
lambda_x: 1.0
|
| 31 |
+
lambda_h: 50.0
|
| 32 |
+
lambda_e: 50.0
|
| 33 |
+
lambda_chi: null
|
| 34 |
+
lambda_trans: null
|
| 35 |
+
lambda_rot: null
|
| 36 |
+
lambda_clash: null
|
| 37 |
+
timestep_weights: null
|
| 38 |
+
|
| 39 |
+
simulation_params:
|
| 40 |
+
n_steps: 5000
|
| 41 |
+
prior_h: marginal # uniform, marginal
|
| 42 |
+
prior_e: uniform # uniform, marginal
|
| 43 |
+
predict_final: False
|
| 44 |
+
predict_confidence: False
|
| 45 |
+
|
| 46 |
+
eval_params:
|
| 47 |
+
eval_epochs: 100
|
| 48 |
+
n_eval_samples: 4
|
| 49 |
+
n_sampling_steps: 500
|
| 50 |
+
eval_batch_size: 16
|
| 51 |
+
visualize_sample_epoch: 1
|
| 52 |
+
n_visualize_samples: 100
|
| 53 |
+
visualize_chain_epoch: 1
|
| 54 |
+
keep_frames: 100
|
| 55 |
+
sample_with_ground_truth_size: True
|
| 56 |
+
|
| 57 |
+
predictor_params:
|
| 58 |
+
heterogeneous_graph: True
|
| 59 |
+
backbone: gvp
|
| 60 |
+
num_rbf_time: 16
|
| 61 |
+
edge_cutoff_ligand: null
|
| 62 |
+
edge_cutoff_pocket: 10.0
|
| 63 |
+
edge_cutoff_interaction: 10.0
|
| 64 |
+
cycle_counts: True
|
| 65 |
+
spectral_feat: False
|
| 66 |
+
reflection_equivariant: False
|
| 67 |
+
num_rbf: 16
|
| 68 |
+
d_max: 15.0
|
| 69 |
+
self_conditioning: True
|
| 70 |
+
augment_residue_sc: False
|
| 71 |
+
augment_ligand_sc: False
|
| 72 |
+
normal_modes: False
|
| 73 |
+
add_chi_as_feature: False
|
| 74 |
+
angle_act_fn: null
|
| 75 |
+
add_all_atom_diff: False
|
| 76 |
+
|
| 77 |
+
gvp_params:
|
| 78 |
+
n_layers: 5
|
| 79 |
+
node_h_dim: [ 128, 32 ] # (s, V)
|
| 80 |
+
edge_h_dim: [ 128, 32 ]
|
| 81 |
+
dropout: 0.0
|
| 82 |
+
vector_gate: True
|
configs/training/drugflow_ood.yml
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run_name: drugflow_ood # iclr_drugflow_T5000_confidence_ru10
|
| 2 |
+
pocket_representation: CA+
|
| 3 |
+
virtual_nodes: [0, 10]
|
| 4 |
+
flexible: False
|
| 5 |
+
flexible_bb: False
|
| 6 |
+
|
| 7 |
+
train_params:
|
| 8 |
+
logdir: ./runs # symlink to any location you like
|
| 9 |
+
datadir: ./processed_crossdocked # symlink to the dataset location
|
| 10 |
+
enable_progress_bar: True
|
| 11 |
+
num_sanity_val_steps: 0
|
| 12 |
+
batch_size: 64
|
| 13 |
+
accumulate_grad_batches: 2
|
| 14 |
+
lr: 5.0e-4
|
| 15 |
+
n_epochs: 1000
|
| 16 |
+
num_workers: 0
|
| 17 |
+
gpus: 1
|
| 18 |
+
clip_grad: True
|
| 19 |
+
gnina: gnina
|
| 20 |
+
sample_from_clusters: False
|
| 21 |
+
sharded_dataset: False
|
| 22 |
+
|
| 23 |
+
wandb_params:
|
| 24 |
+
mode: online # disabled, offline, online
|
| 25 |
+
entity: lpdi
|
| 26 |
+
group: crossdocked
|
| 27 |
+
|
| 28 |
+
loss_params:
|
| 29 |
+
discrete_loss: VLB # VLB or CE
|
| 30 |
+
lambda_x: 1.0
|
| 31 |
+
lambda_h: 50.0
|
| 32 |
+
lambda_e: 50.0
|
| 33 |
+
lambda_chi: null
|
| 34 |
+
lambda_trans: null
|
| 35 |
+
lambda_rot: null
|
| 36 |
+
lambda_clash: null
|
| 37 |
+
timestep_weights: null
|
| 38 |
+
regularize_uncertainty: 10.0
|
| 39 |
+
|
| 40 |
+
simulation_params:
|
| 41 |
+
n_steps: 5000
|
| 42 |
+
prior_h: marginal # uniform, marginal
|
| 43 |
+
prior_e: uniform # uniform, marginal
|
| 44 |
+
predict_final: False
|
| 45 |
+
predict_confidence: True
|
| 46 |
+
|
| 47 |
+
eval_params:
|
| 48 |
+
eval_epochs: 100
|
| 49 |
+
n_eval_samples: 4
|
| 50 |
+
n_sampling_steps: 500
|
| 51 |
+
eval_batch_size: 16
|
| 52 |
+
visualize_sample_epoch: 1
|
| 53 |
+
n_visualize_samples: 100
|
| 54 |
+
visualize_chain_epoch: 1
|
| 55 |
+
keep_frames: 100
|
| 56 |
+
sample_with_ground_truth_size: True
|
| 57 |
+
|
| 58 |
+
predictor_params:
|
| 59 |
+
heterogeneous_graph: True
|
| 60 |
+
backbone: gvp
|
| 61 |
+
num_rbf_time: 16
|
| 62 |
+
edge_cutoff_ligand: null
|
| 63 |
+
edge_cutoff_pocket: 10.0
|
| 64 |
+
edge_cutoff_interaction: 10.0
|
| 65 |
+
cycle_counts: True
|
| 66 |
+
spectral_feat: False
|
| 67 |
+
reflection_equivariant: False
|
| 68 |
+
num_rbf: 16
|
| 69 |
+
d_max: 15.0
|
| 70 |
+
self_conditioning: True
|
| 71 |
+
augment_residue_sc: False
|
| 72 |
+
augment_ligand_sc: False
|
| 73 |
+
normal_modes: False
|
| 74 |
+
add_chi_as_feature: False
|
| 75 |
+
angle_act_fn: null
|
| 76 |
+
add_all_atom_diff: False
|
| 77 |
+
|
| 78 |
+
gvp_params:
|
| 79 |
+
n_layers: 5
|
| 80 |
+
node_h_dim: [ 128, 32 ] # (s, V)
|
| 81 |
+
edge_h_dim: [ 128, 32 ]
|
| 82 |
+
dropout: 0.0
|
| 83 |
+
vector_gate: True
|
configs/training/flexflow.yml
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run_name: flexflow
|
| 2 |
+
pocket_representation: CA+
|
| 3 |
+
virtual_nodes: [0, 10]
|
| 4 |
+
flexible: True
|
| 5 |
+
flexible_bb: False
|
| 6 |
+
|
| 7 |
+
train_params:
|
| 8 |
+
logdir: ./runs # symlink to any location you like
|
| 9 |
+
datadir: ./processed_crossdocked # symlink to the dataset location
|
| 10 |
+
enable_progress_bar: False
|
| 11 |
+
num_sanity_val_steps: 0
|
| 12 |
+
batch_size: 64
|
| 13 |
+
accumulate_grad_batches: 2
|
| 14 |
+
lr: 5.0e-4
|
| 15 |
+
lr_step_size: null
|
| 16 |
+
lr_gamma: null
|
| 17 |
+
n_epochs: 700
|
| 18 |
+
num_workers: 4
|
| 19 |
+
gpus: 1
|
| 20 |
+
clip_grad: True
|
| 21 |
+
gnina: gnina # add Gnina location to path
|
| 22 |
+
sample_from_clusters: False
|
| 23 |
+
sharded_dataset: False
|
| 24 |
+
|
| 25 |
+
wandb_params:
|
| 26 |
+
mode: online # disabled, offline, online
|
| 27 |
+
entity:
|
| 28 |
+
group: crossdocked
|
| 29 |
+
|
| 30 |
+
loss_params:
|
| 31 |
+
discrete_loss: VLB # VLB or CE
|
| 32 |
+
reduce: sum # 'mean' or 'sum'
|
| 33 |
+
lambda_x: 0.015
|
| 34 |
+
lambda_h: 2.5
|
| 35 |
+
lambda_e: 0.25
|
| 36 |
+
lambda_chi: 0.002
|
| 37 |
+
lambda_trans: null
|
| 38 |
+
lambda_rot: null
|
| 39 |
+
lambda_clash: null
|
| 40 |
+
regularize_uncertainty: null
|
| 41 |
+
timestep_weights: null
|
| 42 |
+
|
| 43 |
+
simulation_params:
|
| 44 |
+
n_steps: 5000
|
| 45 |
+
prior_h: marginal # uniform, marginal
|
| 46 |
+
prior_e: uniform # uniform, marginal
|
| 47 |
+
predict_final: False
|
| 48 |
+
predict_confidence: False
|
| 49 |
+
scheduler_chi:
|
| 50 |
+
type: polynomial
|
| 51 |
+
k: 3 # constant for exponential scheduler kappa(t)=(1-t)^k
|
| 52 |
+
|
| 53 |
+
eval_params:
|
| 54 |
+
eval_epochs: 100
|
| 55 |
+
n_loss_per_sample: 100
|
| 56 |
+
n_eval_samples: 4
|
| 57 |
+
n_sampling_steps: 500
|
| 58 |
+
eval_batch_size: 16
|
| 59 |
+
visualize_sample_epoch: 1
|
| 60 |
+
n_visualize_samples: 100
|
| 61 |
+
visualize_chain_epoch: 1
|
| 62 |
+
keep_frames: 100
|
| 63 |
+
sample_with_ground_truth_size: True
|
| 64 |
+
|
| 65 |
+
predictor_params:
|
| 66 |
+
heterogeneous_graph: True
|
| 67 |
+
backbone: gvp
|
| 68 |
+
num_rbf_time: 16
|
| 69 |
+
edge_cutoff_ligand: null
|
| 70 |
+
edge_cutoff_pocket: 10.0
|
| 71 |
+
edge_cutoff_interaction: 10.0
|
| 72 |
+
cycle_counts: True
|
| 73 |
+
spectral_feat: False
|
| 74 |
+
reflection_equivariant: False
|
| 75 |
+
num_rbf: 16
|
| 76 |
+
d_max: 15.0
|
| 77 |
+
self_conditioning: True
|
| 78 |
+
augment_residue_sc: False
|
| 79 |
+
augment_ligand_sc: False
|
| 80 |
+
normal_modes: False
|
| 81 |
+
add_chi_as_feature: False
|
| 82 |
+
angle_act_fn: null
|
| 83 |
+
add_all_atom_diff: True
|
| 84 |
+
|
| 85 |
+
gvp_params:
|
| 86 |
+
n_layers: 5
|
| 87 |
+
node_h_dim: [ 128, 32 ] # (s, V)
|
| 88 |
+
edge_h_dim: [ 128, 32 ]
|
| 89 |
+
dropout: 0.0
|
| 90 |
+
vector_gate: True
|
configs/training/preference_alignment.yml
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run_name: drugflow_preference_alignment
|
| 2 |
+
|
| 3 |
+
checkpoint: ./reference.ckpt # TODO: specify reference checkpoint
|
| 4 |
+
dpo_mode: single_dpo_comp_v3
|
| 5 |
+
|
| 6 |
+
pocket_representation: CA+
|
| 7 |
+
virtual_nodes: [0, 10]
|
| 8 |
+
flexible: False
|
| 9 |
+
flexible_bb: False
|
| 10 |
+
|
| 11 |
+
train_params:
|
| 12 |
+
logdir: ./runs # symlink to any location you like
|
| 13 |
+
datadir: ./processed_crossdocked # symlink to the dataset location
|
| 14 |
+
enable_progress_bar: True
|
| 15 |
+
num_sanity_val_steps: 0
|
| 16 |
+
batch_size: 64
|
| 17 |
+
accumulate_grad_batches: 2
|
| 18 |
+
lr: 5.0e-5
|
| 19 |
+
n_epochs: 500
|
| 20 |
+
num_workers: 0
|
| 21 |
+
gpus: 1
|
| 22 |
+
clip_grad: True
|
| 23 |
+
gnina: gnina # path to gnina binary
|
| 24 |
+
sample_from_clusters: False
|
| 25 |
+
sharded_dataset: False
|
| 26 |
+
|
| 27 |
+
wandb_params:
|
| 28 |
+
mode: online # disabled, offline, online
|
| 29 |
+
entity:
|
| 30 |
+
group: crossdocked
|
| 31 |
+
|
| 32 |
+
loss_params:
|
| 33 |
+
discrete_loss: VLB # VLB or CE
|
| 34 |
+
lambda_x: 1.0
|
| 35 |
+
lambda_h: 500
|
| 36 |
+
dpo_lambda_h: 2500
|
| 37 |
+
lambda_e: 500
|
| 38 |
+
dpo_lambda_e: 2500
|
| 39 |
+
lambda_chi: 0.5 # only effective if flexible=True
|
| 40 |
+
lambda_trans: 1.0 # only effective if flexible_bb=True
|
| 41 |
+
lambda_rot: 0.1 # only effective if flexible_bb=True
|
| 42 |
+
lambda_clash: null
|
| 43 |
+
timestep_weights: null # sigmoid_a=1_b=10 # null, sigmoid_a=?_b=?
|
| 44 |
+
dpo_beta: 100.0
|
| 45 |
+
dpo_beta_schedule: 't'
|
| 46 |
+
dpo_lambda_w: 1.0
|
| 47 |
+
dpo_lambda_l: 0.2
|
| 48 |
+
clamp_dpo: False
|
| 49 |
+
|
| 50 |
+
simulation_params:
|
| 51 |
+
n_steps: 5000
|
| 52 |
+
prior_h: marginal # uniform, marginal
|
| 53 |
+
prior_e: uniform # uniform, marginal
|
| 54 |
+
predict_final: False
|
| 55 |
+
predict_confidence: False
|
| 56 |
+
|
| 57 |
+
eval_params:
|
| 58 |
+
eval_epochs: 4
|
| 59 |
+
n_eval_samples: 1
|
| 60 |
+
n_sampling_steps: 500
|
| 61 |
+
eval_batch_size: 16
|
| 62 |
+
visualize_sample_epoch: 1
|
| 63 |
+
n_visualize_samples: 10
|
| 64 |
+
visualize_chain_epoch: 1
|
| 65 |
+
keep_frames: 100
|
| 66 |
+
sample_with_ground_truth_size: True
|
| 67 |
+
|
| 68 |
+
predictor_params:
|
| 69 |
+
heterogeneous_graph: True
|
| 70 |
+
backbone: gvp
|
| 71 |
+
num_rbf_time: 16
|
| 72 |
+
edge_cutoff_ligand: null
|
| 73 |
+
edge_cutoff_pocket: 10.0
|
| 74 |
+
edge_cutoff_interaction: 10.0
|
| 75 |
+
cycle_counts: True
|
| 76 |
+
spectral_feat: False
|
| 77 |
+
reflection_equivariant: False
|
| 78 |
+
num_rbf: 16
|
| 79 |
+
d_max: 15.0
|
| 80 |
+
self_conditioning: True
|
| 81 |
+
augment_residue_sc: False
|
| 82 |
+
augment_ligand_sc: False
|
| 83 |
+
normal_modes: False
|
| 84 |
+
add_chi_as_feature: False
|
| 85 |
+
angle_act_fn: null
|
| 86 |
+
add_all_atom_diff: False
|
| 87 |
+
|
| 88 |
+
gvp_params:
|
| 89 |
+
n_layers: 5
|
| 90 |
+
node_h_dim: [ 128, 32 ] # (s, V)
|
| 91 |
+
edge_h_dim: [ 128, 32 ]
|
| 92 |
+
dropout: 0.0
|
| 93 |
+
vector_gate: True
|
docs/drugflow.jpg
ADDED
|
Git LFS Details
|
environment.yaml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: sbdd
|
| 2 |
+
|
| 3 |
+
channels:
|
| 4 |
+
- pytorch
|
| 5 |
+
- conda-forge
|
| 6 |
+
- anaconda
|
| 7 |
+
- pyg
|
| 8 |
+
- nvidia
|
| 9 |
+
|
| 10 |
+
dependencies:
|
| 11 |
+
- python=3.11.8
|
| 12 |
+
- pytorch=2.2.1=*cuda12.1*
|
| 13 |
+
- pytorch-cuda=12.1
|
| 14 |
+
- pytorch-lightning=2.2.1
|
| 15 |
+
- rdkit=2023.09.6
|
| 16 |
+
- openbabel=3.1.1
|
| 17 |
+
- biopython=1.83
|
| 18 |
+
- scipy=1.12.0
|
| 19 |
+
- pyg=2.5.1
|
| 20 |
+
- pytorch-scatter=2.1.2
|
| 21 |
+
- ProDy=2.4.0
|
| 22 |
+
- wandb=0.16.3
|
| 23 |
+
- pandas=2.2.2
|
| 24 |
+
- pip=24.0
|
| 25 |
+
- pip:
|
| 26 |
+
- posebusters==0.3.1
|
| 27 |
+
- useful_rdkit_utils==0.65
|
| 28 |
+
- fcd==1.2.2
|
| 29 |
+
- webdataset==0.2.86
|
| 30 |
+
- prolif==2.0.3
|
examples/kras.pdb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/kras_ref_ligand.sdf
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
8AZR
|
| 2 |
+
PyMOL2.5 3D 0
|
| 3 |
+
|
| 4 |
+
32 36 0 0 0 0 0 0 0 0999 V2000
|
| 5 |
+
15.7084 1.6569 4.9428 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 6 |
+
16.2939 1.9182 6.3219 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 7 |
+
17.7757 1.5677 6.3468 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 8 |
+
18.0388 0.0580 6.1328 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 9 |
+
16.1458 0.3026 4.4709 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 10 |
+
17.1748 -0.4207 4.9854 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 11 |
+
17.2894 -1.6945 4.3617 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 12 |
+
16.3332 -1.9132 3.3763 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 13 |
+
15.2948 -0.5437 3.2188 S 0 0 0 0 0 0 0 0 0 0 0 0
|
| 14 |
+
17.6856 -0.7371 7.4005 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 15 |
+
19.5008 -0.1084 5.7694 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 16 |
+
19.9420 0.4778 4.6523 O 0 0 0 0 0 0 0 0 0 0 0 0
|
| 17 |
+
21.3366 0.1893 4.6052 N 0 0 0 0 0 0 0 0 0 0 0 0
|
| 18 |
+
21.5306 -0.5212 5.6843 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 19 |
+
20.3929 -0.7319 6.4483 N 0 0 0 0 0 0 0 0 0 0 0 0
|
| 20 |
+
16.1651 -3.0052 2.6033 N 0 0 0 0 0 0 0 0 0 0 0 0
|
| 21 |
+
22.8349 -1.0932 6.0768 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 22 |
+
23.9207 -0.6312 5.4365 N 0 0 0 0 0 0 0 0 0 0 0 0
|
| 23 |
+
25.1129 -1.1528 5.7755 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 24 |
+
25.2639 -2.1387 6.7500 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 25 |
+
24.1280 -2.5941 7.3940 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 26 |
+
22.8945 -2.0709 7.0591 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 27 |
+
18.2816 -2.6789 4.6625 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 28 |
+
19.0589 -3.4973 4.8688 N 0 0 0 0 0 0 0 0 0 0 0 0
|
| 29 |
+
26.1982 -0.6750 5.0820 N 0 0 0 0 0 0 0 0 0 0 0 0
|
| 30 |
+
26.0358 0.4071 4.0954 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 31 |
+
26.8978 0.1468 2.8491 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 32 |
+
28.2989 -0.1678 3.2648 N 0 0 0 0 0 0 0 0 0 0 0 0
|
| 33 |
+
28.3171 -1.4142 4.0851 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 34 |
+
27.5312 -1.2091 5.3777 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 35 |
+
29.1988 -0.2741 2.0804 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 36 |
+
26.3415 1.7618 4.7132 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 37 |
+
1 2 1 0 0 0 0
|
| 38 |
+
2 3 1 0 0 0 0
|
| 39 |
+
3 4 1 0 0 0 0
|
| 40 |
+
4 6 1 0 0 0 0
|
| 41 |
+
4 10 1 0 0 0 0
|
| 42 |
+
4 11 1 0 0 0 0
|
| 43 |
+
1 5 1 0 0 0 0
|
| 44 |
+
5 6 4 0 0 0 0
|
| 45 |
+
5 9 4 0 0 0 0
|
| 46 |
+
6 7 4 0 0 0 0
|
| 47 |
+
7 8 4 0 0 0 0
|
| 48 |
+
7 23 1 0 0 0 0
|
| 49 |
+
8 9 4 0 0 0 0
|
| 50 |
+
8 16 1 0 0 0 0
|
| 51 |
+
11 12 4 0 0 0 0
|
| 52 |
+
11 15 4 0 0 0 0
|
| 53 |
+
12 13 4 0 0 0 0
|
| 54 |
+
13 14 4 0 0 0 0
|
| 55 |
+
14 15 4 0 0 0 0
|
| 56 |
+
14 17 1 0 0 0 0
|
| 57 |
+
17 18 4 0 0 0 0
|
| 58 |
+
17 22 4 0 0 0 0
|
| 59 |
+
18 19 4 0 0 0 0
|
| 60 |
+
19 25 1 0 0 0 0
|
| 61 |
+
19 20 4 0 0 0 0
|
| 62 |
+
20 21 4 0 0 0 0
|
| 63 |
+
21 22 4 0 0 0 0
|
| 64 |
+
23 24 3 0 0 0 0
|
| 65 |
+
25 26 1 0 0 0 0
|
| 66 |
+
26 27 1 0 0 0 0
|
| 67 |
+
26 32 1 0 0 0 0
|
| 68 |
+
27 28 1 0 0 0 0
|
| 69 |
+
28 29 1 0 0 0 0
|
| 70 |
+
29 30 1 0 0 0 0
|
| 71 |
+
25 30 1 0 0 0 0
|
| 72 |
+
28 31 1 0 0 0 0
|
| 73 |
+
M END
|
| 74 |
+
$$$$
|
scripts/python/evaluate_baselines.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import pickle
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
basedir = Path(__file__).resolve().parent.parent.parent
|
| 7 |
+
sys.path.append(str(basedir))
|
| 8 |
+
|
| 9 |
+
from src.sbdd_metrics.evaluation import compute_all_metrics_drugflow
|
| 10 |
+
|
| 11 |
+
if __name__ == '__main__':
|
| 12 |
+
p = argparse.ArgumentParser()
|
| 13 |
+
p.add_argument('--in_dir', type=Path, required=True, help='Directory with samples')
|
| 14 |
+
p.add_argument('--out_dir', type=str, required=True, help='Output directory')
|
| 15 |
+
p.add_argument('--reference_smiles', type=str, default=None, help='Path to the .npy file with reference SMILES (optional)')
|
| 16 |
+
p.add_argument('--gnina', type=str, default=None, help='Path to the gnina binary file (optional)')
|
| 17 |
+
p.add_argument('--reduce', type=str, default=None, help='Path to the reduce binary file (optional)')
|
| 18 |
+
p.add_argument('--n_samples', type=int, default=None, help='Top-N sampels to evaluate (optional)')
|
| 19 |
+
p.add_argument('--exclude', type=str, nargs='+', default=[], help='Evaluator IDs to exclude')
|
| 20 |
+
p.add_argument('--job_id', type=int, default=0, help='Job ID')
|
| 21 |
+
p.add_argument('--n_jobs', type=int, default=1, help='Number of jobs')
|
| 22 |
+
args = p.parse_args()
|
| 23 |
+
|
| 24 |
+
Path(args.out_dir).mkdir(exist_ok=True, parents=True)
|
| 25 |
+
if args.job_id == 0 and args.n_jobs == 1:
|
| 26 |
+
out_detailed_table = Path(args.out_dir, 'metrics_detailed.csv')
|
| 27 |
+
out_aggregated_table = Path(args.out_dir, 'metrics_aggregated.csv')
|
| 28 |
+
out_distributions_file = Path(args.out_dir, 'metrics_data.pkl')
|
| 29 |
+
else:
|
| 30 |
+
out_detailed_table = Path(args.out_dir, f'metrics_detailed_{args.job_id}.csv')
|
| 31 |
+
out_aggregated_table = Path(args.out_dir, f'metrics_aggregated_{args.job_id}.csv')
|
| 32 |
+
out_distributions_file = Path(args.out_dir, f'metrics_data_{args.job_id}.pkl')
|
| 33 |
+
|
| 34 |
+
if out_detailed_table.exists() and out_aggregated_table.exists() and out_distributions_file.exists():
|
| 35 |
+
print(f'Data already exist. Terminating')
|
| 36 |
+
sys.exit(0)
|
| 37 |
+
|
| 38 |
+
print(f'Evaluating: {args.in_dir}')
|
| 39 |
+
data, detailed, aggregated = compute_all_metrics_drugflow(
|
| 40 |
+
in_dir=args.in_dir,
|
| 41 |
+
gnina_path=args.gnina,
|
| 42 |
+
reduce_path=args.reduce,
|
| 43 |
+
reference_smiles_path=args.reference_smiles,
|
| 44 |
+
n_samples=args.n_samples,
|
| 45 |
+
exclude_evaluators=args.exclude,
|
| 46 |
+
job_id=args.job_id,
|
| 47 |
+
n_jobs=args.n_jobs,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
detailed.to_csv(out_detailed_table, index=False)
|
| 51 |
+
aggregated.to_csv(out_aggregated_table, index=False)
|
| 52 |
+
with open(Path(out_distributions_file), 'wb') as f:
|
| 53 |
+
pickle.dump(data, f)
|
scripts/python/postprocess_metrics.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import pickle
|
| 4 |
+
import sys
|
| 5 |
+
from collections import Counter, defaultdict
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
from rdkit import Chem
|
| 11 |
+
from scipy.stats import wasserstein_distance
|
| 12 |
+
from scipy.spatial.distance import jensenshannon
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
basedir = Path(__file__).resolve().parent.parent.parent
|
| 16 |
+
sys.path.append(str(basedir))
|
| 17 |
+
|
| 18 |
+
from src.data.data_utils import atom_encoder, bond_encoder, encode_atom
|
| 19 |
+
from src.sbdd_metrics.evaluation import VALIDITY_METRIC_NAME, aggregated_metrics, collection_metrics, get_data_type
|
| 20 |
+
from src.sbdd_metrics.metrics import FullEvaluator
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
DATA_TYPES = data_types = FullEvaluator().dtypes
|
| 24 |
+
|
| 25 |
+
MEDCHEM_PROPS = [
|
| 26 |
+
'medchem.qed',
|
| 27 |
+
'medchem.sa',
|
| 28 |
+
'medchem.logp',
|
| 29 |
+
'medchem.lipinski',
|
| 30 |
+
'medchem.size',
|
| 31 |
+
'medchem.n_rotatable_bonds',
|
| 32 |
+
'energy.energy',
|
| 33 |
+
]
|
| 34 |
+
|
| 35 |
+
DOCKING_PROPS = [
|
| 36 |
+
'gnina.vina_score',
|
| 37 |
+
'gnina.gnina_score',
|
| 38 |
+
'gnina.vina_efficiency',
|
| 39 |
+
'gnina.gnina_efficiency',
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
RELEVANT_INTERACTIONS = [
|
| 43 |
+
'interactions.HBAcceptor',
|
| 44 |
+
'interactions.HBDonor',
|
| 45 |
+
'interactions.HB',
|
| 46 |
+
'interactions.PiStacking',
|
| 47 |
+
'interactions.Hydrophobic',
|
| 48 |
+
#
|
| 49 |
+
'interactions.HBAcceptor.normalized',
|
| 50 |
+
'interactions.HBDonor.normalized',
|
| 51 |
+
'interactions.HB.normalized',
|
| 52 |
+
'interactions.PiStacking.normalized',
|
| 53 |
+
'interactions.Hydrophobic.normalized'
|
| 54 |
+
]
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def compute_discrete_distributions(smiles, name):
|
| 58 |
+
atom_counter = Counter()
|
| 59 |
+
bond_counter = Counter()
|
| 60 |
+
|
| 61 |
+
for smi in tqdm(smiles, desc=name):
|
| 62 |
+
mol = Chem.MolFromSmiles(smi)
|
| 63 |
+
mol = Chem.RemoveAllHs(mol, sanitize=False)
|
| 64 |
+
for atom in mol.GetAtoms():
|
| 65 |
+
try:
|
| 66 |
+
encoded_atom = encode_atom(atom, atom_encoder=atom_encoder)
|
| 67 |
+
except KeyError:
|
| 68 |
+
continue
|
| 69 |
+
atom_counter[encoded_atom] += 1
|
| 70 |
+
for bond in mol.GetBonds():
|
| 71 |
+
bond_counter[bond_encoder[str(bond.GetBondType())]] += 1
|
| 72 |
+
|
| 73 |
+
atom_distribution = np.zeros(len(atom_encoder))
|
| 74 |
+
bond_distribution = np.zeros(len(bond_encoder))
|
| 75 |
+
|
| 76 |
+
for k, v in atom_counter.items():
|
| 77 |
+
atom_distribution[k] = v
|
| 78 |
+
for k, v in bond_counter.items():
|
| 79 |
+
bond_distribution[k] = v
|
| 80 |
+
|
| 81 |
+
atom_distribution = atom_distribution / atom_distribution.sum()
|
| 82 |
+
bond_distribution = bond_distribution / bond_distribution.sum()
|
| 83 |
+
|
| 84 |
+
return atom_distribution, bond_distribution
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def flatten_distribution(data, name, table):
|
| 88 |
+
aux = ['sample', 'sdf_file', 'pdb_file']
|
| 89 |
+
method_distributions = defaultdict(list)
|
| 90 |
+
|
| 91 |
+
sdf2sample2size = defaultdict(dict)
|
| 92 |
+
for _, row in table.iterrows():
|
| 93 |
+
sdf2sample2size[row['sdf_file']][int(row['sample'])] = row['medchem.size']
|
| 94 |
+
|
| 95 |
+
for item in tqdm(data, desc=name):
|
| 96 |
+
if item['medchem.valid'] is not True:
|
| 97 |
+
continue
|
| 98 |
+
|
| 99 |
+
if 'interactions.HBAcceptor' in item and 'interactions.HBDonor' in item:
|
| 100 |
+
item['interactions.HB'] = item['interactions.HBAcceptor'] + item['interactions.HBDonor']
|
| 101 |
+
|
| 102 |
+
new_entries = {}
|
| 103 |
+
for key, value in item.items():
|
| 104 |
+
if key.startswith('interactions'):
|
| 105 |
+
size = sdf2sample2size.get(item['sdf_file'], dict()).get(int(item['sample']))
|
| 106 |
+
if size is not None:
|
| 107 |
+
new_entries[key + '.normalized'] = value / size
|
| 108 |
+
item.update(new_entries)
|
| 109 |
+
|
| 110 |
+
for key, value in item.items():
|
| 111 |
+
if value is None:
|
| 112 |
+
continue
|
| 113 |
+
if key in aux:
|
| 114 |
+
continue
|
| 115 |
+
if key == 'energy.energy' and abs(value) > 1000:
|
| 116 |
+
continue
|
| 117 |
+
|
| 118 |
+
if get_data_type(key, DATA_TYPES, default=type(value)) == list:
|
| 119 |
+
method_distributions[key] += value
|
| 120 |
+
else:
|
| 121 |
+
method_distributions[key].append(value)
|
| 122 |
+
|
| 123 |
+
return method_distributions
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def prepare_baseline_data(root_path, baseline_name):
|
| 127 |
+
metrics_detailed = pd.read_csv(f'{root_path}/metrics_detailed.csv')
|
| 128 |
+
metrics_detailed = metrics_detailed[metrics_detailed['medchem.valid']]
|
| 129 |
+
distributions = pickle.load(open(f'{root_path}/metrics_data.pkl', 'rb'))
|
| 130 |
+
distributions = flatten_distribution(distributions, name=baseline_name, table=metrics_detailed)
|
| 131 |
+
distributions['energy.energy'] = [v for v in distributions['energy.energy'] if -1000 <= v <= 1000]
|
| 132 |
+
for prop in MEDCHEM_PROPS + DOCKING_PROPS:
|
| 133 |
+
distributions[prop] = metrics_detailed[prop].dropna().values.tolist()
|
| 134 |
+
|
| 135 |
+
smiles = metrics_detailed['representation.smiles']
|
| 136 |
+
atom_distribution, bond_distribution = compute_discrete_distributions(smiles, name=baseline_name)
|
| 137 |
+
discrete_distributions = {
|
| 138 |
+
'atom_types': atom_distribution,
|
| 139 |
+
'bond_types': bond_distribution,
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
return distributions, discrete_distributions
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
if __name__ == '__main__':
|
| 146 |
+
p = argparse.ArgumentParser()
|
| 147 |
+
p.add_argument('--in_dir', type=Path, required=True, help='Directory with samples')
|
| 148 |
+
p.add_argument('--out_dir', type=str, required=True, help='Output directory')
|
| 149 |
+
p.add_argument('--n_samples', type=int, required=False, default=None, help='N samples per target')
|
| 150 |
+
p.add_argument('--reference_smiles', type=str, default=None, help='Path to the .npy file with reference SMILES (optional)')
|
| 151 |
+
p.add_argument('--crossdocked_dir', type=str, required=False, default=None, help='Crossdocked data dir for computing distances between distributions')
|
| 152 |
+
args = p.parse_args()
|
| 153 |
+
|
| 154 |
+
Path(args.out_dir).mkdir(parents=True, exist_ok=True)
|
| 155 |
+
|
| 156 |
+
print('Combining data')
|
| 157 |
+
data = []
|
| 158 |
+
for file_path in tqdm(Path(args.in_dir).glob('metrics_data_*.pkl')):
|
| 159 |
+
with open(file_path, 'rb') as f:
|
| 160 |
+
d = pickle.load(f)
|
| 161 |
+
if args.n_samples is not None:
|
| 162 |
+
d = d[:args.n_samples]
|
| 163 |
+
data += d
|
| 164 |
+
with open(Path(args.out_dir, 'metrics_data.pkl'), 'wb') as f:
|
| 165 |
+
pickle.dump(data, f)
|
| 166 |
+
|
| 167 |
+
print('Combining detailed metrics')
|
| 168 |
+
tables = []
|
| 169 |
+
for file_path in tqdm(Path(args.in_dir).glob('metrics_detailed_*.csv')):
|
| 170 |
+
table = pd.read_csv(file_path)
|
| 171 |
+
if args.n_samples is not None:
|
| 172 |
+
table = table.head(args.n_samples)
|
| 173 |
+
tables.append(table)
|
| 174 |
+
|
| 175 |
+
table_detailed = pd.concat(tables)
|
| 176 |
+
table_detailed.to_csv(Path(args.out_dir, 'metrics_detailed.csv'), index=False)
|
| 177 |
+
|
| 178 |
+
print('Computing aggregated metrics')
|
| 179 |
+
evaluator = FullEvaluator(gnina='gnina', reduce='reduce')
|
| 180 |
+
table_aggregated = aggregated_metrics(
|
| 181 |
+
table_detailed,
|
| 182 |
+
data_types=evaluator.dtypes,
|
| 183 |
+
validity_metric_name=VALIDITY_METRIC_NAME
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
if args.reference_smiles is not None:
|
| 187 |
+
reference_smiles = np.load(args.reference_smiles)
|
| 188 |
+
col_metrics = collection_metrics(
|
| 189 |
+
table=table_detailed,
|
| 190 |
+
reference_smiles=reference_smiles,
|
| 191 |
+
validity_metric_name=VALIDITY_METRIC_NAME,
|
| 192 |
+
exclude_evaluators=[],
|
| 193 |
+
)
|
| 194 |
+
table_aggregated = pd.concat([table_aggregated, col_metrics])
|
| 195 |
+
|
| 196 |
+
table_aggregated.to_csv(Path(args.out_dir, 'metrics_aggregated.csv'), index=False)
|
| 197 |
+
|
| 198 |
+
# Computing distributions
|
| 199 |
+
if args.crossdocked_dir is not None:
|
| 200 |
+
|
| 201 |
+
# Loading training data distributions
|
| 202 |
+
crossdocked_distributions = None
|
| 203 |
+
crossdocked_discrete_distributions = None
|
| 204 |
+
precomputed_distr_path = f'{args.crossdocked_dir}/crossdocked_distributions.pkl'
|
| 205 |
+
precomputed_discrete_distr_path = f'{args.crossdocked_dir}/crossdocked_discrete_distributions.pkl'
|
| 206 |
+
if os.path.exists(precomputed_distr_path) and os.path.exists(precomputed_discrete_distr_path):
|
| 207 |
+
# Use precomputed distributions in case they exist
|
| 208 |
+
with open(precomputed_distr_path, 'rb') as f:
|
| 209 |
+
crossdocked_distributions = pickle.load(f)
|
| 210 |
+
with open(precomputed_discrete_distr_path, 'rb') as f:
|
| 211 |
+
crossdocked_discrete_distributions = pickle.load(f)
|
| 212 |
+
else:
|
| 213 |
+
assert os.path.exists(f'{args.crossdocked_dir}/metrics_detailed.csv')
|
| 214 |
+
assert os.path.exists(f'{args.crossdocked_dir}/metrics_data.pkl')
|
| 215 |
+
crossdocked_distributions, crossdocked_discrete_distributions = prepare_baseline_data(
|
| 216 |
+
root_path=args.crossdocked_dir,
|
| 217 |
+
baseline_name='crossdocked'
|
| 218 |
+
)
|
| 219 |
+
# Save precomputed distributions for faster next runs
|
| 220 |
+
with open(precomputed_distr_path, 'wb') as f:
|
| 221 |
+
pickle.dump(crossdocked_distributions, f)
|
| 222 |
+
with open(precomputed_discrete_distr_path, 'wb') as f:
|
| 223 |
+
pickle.dump(crossdocked_discrete_distributions, f)
|
| 224 |
+
|
| 225 |
+
# Selecting top-5 most frequent atom types, bond types, angles and torsions
|
| 226 |
+
bonds = sorted([
|
| 227 |
+
(k, len(v)) for k, v in crossdocked_distributions.items()
|
| 228 |
+
if k.startswith('geometry.') and sum(s.isalpha() for s in k.split('.')[1]) == 2
|
| 229 |
+
], key=lambda t: t[1], reverse=True)[:5]
|
| 230 |
+
top_5_bonds = [t[0] for t in bonds]
|
| 231 |
+
|
| 232 |
+
angles = sorted([
|
| 233 |
+
(k, len(v)) for k, v in crossdocked_distributions.items()
|
| 234 |
+
if k.startswith('geometry.') and sum(s.isalpha() for s in k.split('.')[1]) == 3
|
| 235 |
+
], key=lambda t: t[1], reverse=True)[:5]
|
| 236 |
+
top_5_angles = [t[0] for t in angles]
|
| 237 |
+
|
| 238 |
+
# Loading distributions of samples
|
| 239 |
+
distributions, discrete_distributions = prepare_baseline_data(args.out_dir, 'samples')
|
| 240 |
+
|
| 241 |
+
# Computing distances between distributions
|
| 242 |
+
distances = {'method': 'method',}
|
| 243 |
+
relevant_columns = MEDCHEM_PROPS + DOCKING_PROPS + RELEVANT_INTERACTIONS + top_5_bonds + top_5_angles
|
| 244 |
+
for metric in distributions.keys():
|
| 245 |
+
if metric not in relevant_columns:
|
| 246 |
+
continue
|
| 247 |
+
|
| 248 |
+
ref = crossdocked_distributions.get(metric)
|
| 249 |
+
# cur = distributions.get(metric)
|
| 250 |
+
cur = [x for x in distributions.get(metric) if not pd.isna(x)]
|
| 251 |
+
|
| 252 |
+
if ref is not None and cur is not None and len(cur) > 0:
|
| 253 |
+
try:
|
| 254 |
+
distance = wasserstein_distance(ref, cur)
|
| 255 |
+
except:
|
| 256 |
+
from pdb import set_trace; set_trace()
|
| 257 |
+
num_ref = len(ref)
|
| 258 |
+
num_cur = len(cur)
|
| 259 |
+
distances[f'WD.{metric}'] = distance
|
| 260 |
+
|
| 261 |
+
for metric in crossdocked_discrete_distributions.keys():
|
| 262 |
+
ref = crossdocked_discrete_distributions.get(metric)
|
| 263 |
+
cur = discrete_distributions.get(metric)
|
| 264 |
+
if ref is not None and cur is not None:
|
| 265 |
+
distance = jensenshannon(p=ref, q=cur)
|
| 266 |
+
num_ref = len(ref)
|
| 267 |
+
num_cur = len(cur)
|
| 268 |
+
distances[f'JS.{metric}'] = distance
|
| 269 |
+
|
| 270 |
+
dist_table = pd.DataFrame([distances])
|
| 271 |
+
dist_table.to_csv(Path(args.out_dir, 'metrics_distances.csv'), index=False)
|
src/analysis/SA_Score/README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Files taken from: https://github.com/rdkit/rdkit/tree/master/Contrib/SA_Score
|
src/analysis/SA_Score/fpscores.pkl.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:10dcef9340c873e7b987924461b0af5365eb8dd96be607203debe8ddf80c1e73
|
| 3 |
+
size 3848394
|
src/analysis/SA_Score/sascorer.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# calculation of synthetic accessibility score as described in:
|
| 3 |
+
#
|
| 4 |
+
# Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
|
| 5 |
+
# Peter Ertl and Ansgar Schuffenhauer
|
| 6 |
+
# Journal of Cheminformatics 1:8 (2009)
|
| 7 |
+
# http://www.jcheminf.com/content/1/1/8
|
| 8 |
+
#
|
| 9 |
+
# several small modifications to the original paper are included
|
| 10 |
+
# particularly slightly different formula for marocyclic penalty
|
| 11 |
+
# and taking into account also molecule symmetry (fingerprint density)
|
| 12 |
+
#
|
| 13 |
+
# for a set of 10k diverse molecules the agreement between the original method
|
| 14 |
+
# as implemented in PipelinePilot and this implementation is r2 = 0.97
|
| 15 |
+
#
|
| 16 |
+
# peter ertl & greg landrum, september 2013
|
| 17 |
+
#
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
from rdkit import Chem
|
| 21 |
+
from rdkit.Chem import rdMolDescriptors
|
| 22 |
+
import pickle
|
| 23 |
+
|
| 24 |
+
import math
|
| 25 |
+
from collections import defaultdict
|
| 26 |
+
|
| 27 |
+
import os.path as op
|
| 28 |
+
|
| 29 |
+
_fscores = None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def readFragmentScores(name='fpscores'):
|
| 33 |
+
import gzip
|
| 34 |
+
global _fscores
|
| 35 |
+
# generate the full path filename:
|
| 36 |
+
if name == "fpscores":
|
| 37 |
+
name = op.join(op.dirname(__file__), name)
|
| 38 |
+
data = pickle.load(gzip.open('%s.pkl.gz' % name))
|
| 39 |
+
outDict = {}
|
| 40 |
+
for i in data:
|
| 41 |
+
for j in range(1, len(i)):
|
| 42 |
+
outDict[i[j]] = float(i[0])
|
| 43 |
+
_fscores = outDict
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def numBridgeheadsAndSpiro(mol, ri=None):
|
| 47 |
+
nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
|
| 48 |
+
nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
|
| 49 |
+
return nBridgehead, nSpiro
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def calculateScore(m):
|
| 53 |
+
if _fscores is None:
|
| 54 |
+
readFragmentScores()
|
| 55 |
+
|
| 56 |
+
# fragment score
|
| 57 |
+
fp = rdMolDescriptors.GetMorganFingerprint(m,
|
| 58 |
+
2) # <- 2 is the *radius* of the circular fingerprint
|
| 59 |
+
fps = fp.GetNonzeroElements()
|
| 60 |
+
score1 = 0.
|
| 61 |
+
nf = 0
|
| 62 |
+
for bitId, v in fps.items():
|
| 63 |
+
nf += v
|
| 64 |
+
sfp = bitId
|
| 65 |
+
score1 += _fscores.get(sfp, -4) * v
|
| 66 |
+
score1 /= nf
|
| 67 |
+
|
| 68 |
+
# features score
|
| 69 |
+
nAtoms = m.GetNumAtoms()
|
| 70 |
+
nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
|
| 71 |
+
ri = m.GetRingInfo()
|
| 72 |
+
nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
|
| 73 |
+
nMacrocycles = 0
|
| 74 |
+
for x in ri.AtomRings():
|
| 75 |
+
if len(x) > 8:
|
| 76 |
+
nMacrocycles += 1
|
| 77 |
+
|
| 78 |
+
sizePenalty = nAtoms**1.005 - nAtoms
|
| 79 |
+
stereoPenalty = math.log10(nChiralCenters + 1)
|
| 80 |
+
spiroPenalty = math.log10(nSpiro + 1)
|
| 81 |
+
bridgePenalty = math.log10(nBridgeheads + 1)
|
| 82 |
+
macrocyclePenalty = 0.
|
| 83 |
+
# ---------------------------------------
|
| 84 |
+
# This differs from the paper, which defines:
|
| 85 |
+
# macrocyclePenalty = math.log10(nMacrocycles+1)
|
| 86 |
+
# This form generates better results when 2 or more macrocycles are present
|
| 87 |
+
if nMacrocycles > 0:
|
| 88 |
+
macrocyclePenalty = math.log10(2)
|
| 89 |
+
|
| 90 |
+
score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty
|
| 91 |
+
|
| 92 |
+
# correction for the fingerprint density
|
| 93 |
+
# not in the original publication, added in version 1.1
|
| 94 |
+
# to make highly symmetrical molecules easier to synthetise
|
| 95 |
+
score3 = 0.
|
| 96 |
+
if nAtoms > len(fps):
|
| 97 |
+
score3 = math.log(float(nAtoms) / len(fps)) * .5
|
| 98 |
+
|
| 99 |
+
sascore = score1 + score2 + score3
|
| 100 |
+
|
| 101 |
+
# need to transform "raw" value into scale between 1 and 10
|
| 102 |
+
min = -4.0
|
| 103 |
+
max = 2.5
|
| 104 |
+
sascore = 11. - (sascore - min + 1) / (max - min) * 9.
|
| 105 |
+
# smooth the 10-end
|
| 106 |
+
if sascore > 8.:
|
| 107 |
+
sascore = 8. + math.log(sascore + 1. - 9.)
|
| 108 |
+
if sascore > 10.:
|
| 109 |
+
sascore = 10.0
|
| 110 |
+
elif sascore < 1.:
|
| 111 |
+
sascore = 1.0
|
| 112 |
+
|
| 113 |
+
return sascore
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def processMols(mols):
|
| 117 |
+
print('smiles\tName\tsa_score')
|
| 118 |
+
for i, m in enumerate(mols):
|
| 119 |
+
if m is None:
|
| 120 |
+
continue
|
| 121 |
+
|
| 122 |
+
s = calculateScore(m)
|
| 123 |
+
|
| 124 |
+
smiles = Chem.MolToSmiles(m)
|
| 125 |
+
print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
if __name__ == '__main__':
|
| 129 |
+
import sys
|
| 130 |
+
import time
|
| 131 |
+
|
| 132 |
+
t1 = time.time()
|
| 133 |
+
readFragmentScores("fpscores")
|
| 134 |
+
t2 = time.time()
|
| 135 |
+
|
| 136 |
+
suppl = Chem.SmilesMolSupplier(sys.argv[1])
|
| 137 |
+
t3 = time.time()
|
| 138 |
+
processMols(suppl)
|
| 139 |
+
t4 = time.time()
|
| 140 |
+
|
| 141 |
+
print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
|
| 142 |
+
file=sys.stderr)
|
| 143 |
+
|
| 144 |
+
#
|
| 145 |
+
# Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
|
| 146 |
+
# All rights reserved.
|
| 147 |
+
#
|
| 148 |
+
# Redistribution and use in source and binary forms, with or without
|
| 149 |
+
# modification, are permitted provided that the following conditions are
|
| 150 |
+
# met:
|
| 151 |
+
#
|
| 152 |
+
# * Redistributions of source code must retain the above copyright
|
| 153 |
+
# notice, this list of conditions and the following disclaimer.
|
| 154 |
+
# * Redistributions in binary form must reproduce the above
|
| 155 |
+
# copyright notice, this list of conditions and the following
|
| 156 |
+
# disclaimer in the documentation and/or other materials provided
|
| 157 |
+
# with the distribution.
|
| 158 |
+
# * Neither the name of Novartis Institutes for BioMedical Research Inc.
|
| 159 |
+
# nor the names of its contributors may be used to endorse or promote
|
| 160 |
+
# products derived from this software without specific prior written permission.
|
| 161 |
+
#
|
| 162 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
| 163 |
+
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
| 164 |
+
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
| 165 |
+
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
| 166 |
+
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
| 167 |
+
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
| 168 |
+
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
| 169 |
+
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
| 170 |
+
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
| 171 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 172 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 173 |
+
#
|
src/analysis/metrics.py
ADDED
|
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import subprocess
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import tempfile
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from rdkit import Chem, DataStructs
|
| 8 |
+
from rdkit.Chem import AllChem
|
| 9 |
+
from rdkit.Chem import Descriptors, Crippen, Lipinski, QED
|
| 10 |
+
from rdkit.Chem import AtomKekulizeException, AtomValenceException, \
|
| 11 |
+
KekulizeException, MolSanitizeException
|
| 12 |
+
from src.analysis.SA_Score.sascorer import calculateScore
|
| 13 |
+
from src.utils import write_sdf_file
|
| 14 |
+
|
| 15 |
+
from copy import deepcopy
|
| 16 |
+
|
| 17 |
+
from pdb import set_trace
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class CategoricalDistribution:
|
| 21 |
+
EPS = 1e-10
|
| 22 |
+
|
| 23 |
+
def __init__(self, histogram_dict, mapping):
|
| 24 |
+
histogram = np.zeros(len(mapping))
|
| 25 |
+
for k, v in histogram_dict.items():
|
| 26 |
+
histogram[mapping[k]] = v
|
| 27 |
+
|
| 28 |
+
# Normalize histogram
|
| 29 |
+
self.p = histogram / histogram.sum()
|
| 30 |
+
self.mapping = deepcopy(mapping)
|
| 31 |
+
|
| 32 |
+
def kl_divergence(self, other_sample):
|
| 33 |
+
sample_histogram = np.zeros(len(self.mapping))
|
| 34 |
+
for x in other_sample:
|
| 35 |
+
# sample_histogram[self.mapping[x]] += 1
|
| 36 |
+
sample_histogram[x] += 1
|
| 37 |
+
|
| 38 |
+
# Normalize
|
| 39 |
+
q = sample_histogram / sample_histogram.sum()
|
| 40 |
+
|
| 41 |
+
return -np.sum(self.p * np.log(q / (self.p + self.EPS) + self.EPS))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def check_mol(rdmol):
|
| 45 |
+
"""
|
| 46 |
+
See also: https://www.rdkit.org/docs/RDKit_Book.html#molecular-sanitization
|
| 47 |
+
"""
|
| 48 |
+
if rdmol is None:
|
| 49 |
+
return 'is_none'
|
| 50 |
+
|
| 51 |
+
_rdmol = Chem.Mol(rdmol)
|
| 52 |
+
try:
|
| 53 |
+
Chem.SanitizeMol(_rdmol)
|
| 54 |
+
return 'valid'
|
| 55 |
+
except ValueError as e:
|
| 56 |
+
assert isinstance(e, MolSanitizeException)
|
| 57 |
+
return type(e).__name__
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def validity_analysis(rdmol_list):
|
| 61 |
+
"""
|
| 62 |
+
For explanations, see: https://www.rdkit.org/docs/RDKit_Book.html#molecular-sanitization
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
result = {
|
| 66 |
+
'AtomValenceException': 0, # atoms in higher-than-allowed valence states
|
| 67 |
+
'AtomKekulizeException': 0,
|
| 68 |
+
'KekulizeException': 0, # ring cannot be kekulized or aromatic bonds found outside of rings
|
| 69 |
+
'other': 0,
|
| 70 |
+
'valid': 0
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
for rdmol in rdmol_list:
|
| 74 |
+
flag = check_mol(rdmol)
|
| 75 |
+
|
| 76 |
+
try:
|
| 77 |
+
result[flag] += 1
|
| 78 |
+
except KeyError:
|
| 79 |
+
result['other'] += 1
|
| 80 |
+
|
| 81 |
+
assert sum(result.values()) == len(rdmol_list)
|
| 82 |
+
|
| 83 |
+
return result
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class MoleculeValidity:
|
| 87 |
+
def __init__(self, connectivity_thresh=1.0):
|
| 88 |
+
self.connectivity_thresh = connectivity_thresh
|
| 89 |
+
|
| 90 |
+
def compute_validity(self, generated):
|
| 91 |
+
""" generated: list of RDKit molecules. """
|
| 92 |
+
if len(generated) < 1:
|
| 93 |
+
return [], 0.0
|
| 94 |
+
|
| 95 |
+
# Return copies of the valid molecules
|
| 96 |
+
valid = [Chem.Mol(mol) for mol in generated if check_mol(mol) == 'valid']
|
| 97 |
+
return valid, len(valid) / len(generated)
|
| 98 |
+
|
| 99 |
+
def compute_connectivity(self, valid):
|
| 100 |
+
"""
|
| 101 |
+
Consider molecule connected if its largest fragment contains at
|
| 102 |
+
least <self.connectivity_thresh * 100>% of all atoms.
|
| 103 |
+
:param valid: list of valid RDKit molecules
|
| 104 |
+
"""
|
| 105 |
+
if len(valid) < 1:
|
| 106 |
+
return [], 0.0
|
| 107 |
+
|
| 108 |
+
for mol in valid:
|
| 109 |
+
Chem.SanitizeMol(mol) # all molecules should be valid
|
| 110 |
+
|
| 111 |
+
connected = []
|
| 112 |
+
for mol in valid:
|
| 113 |
+
|
| 114 |
+
if mol.GetNumAtoms() < 1:
|
| 115 |
+
continue
|
| 116 |
+
|
| 117 |
+
try:
|
| 118 |
+
mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True)
|
| 119 |
+
except MolSanitizeException as e:
|
| 120 |
+
print('Error while computing connectivity:', e)
|
| 121 |
+
continue
|
| 122 |
+
|
| 123 |
+
largest_frag = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
|
| 124 |
+
if largest_frag.GetNumAtoms() / mol.GetNumAtoms() >= self.connectivity_thresh:
|
| 125 |
+
connected.append(largest_frag)
|
| 126 |
+
|
| 127 |
+
return connected, len(connected) / len(valid)
|
| 128 |
+
|
| 129 |
+
def __call__(self, rdmols, verbose=False):
|
| 130 |
+
"""
|
| 131 |
+
:param rdmols: list of RDKit molecules
|
| 132 |
+
"""
|
| 133 |
+
|
| 134 |
+
results = {}
|
| 135 |
+
results['n_total'] = len(rdmols)
|
| 136 |
+
|
| 137 |
+
valid, validity = self.compute_validity(rdmols)
|
| 138 |
+
results['n_valid'] = len(valid)
|
| 139 |
+
results['validity'] = validity
|
| 140 |
+
|
| 141 |
+
connected, connectivity = self.compute_connectivity(valid)
|
| 142 |
+
results['n_connected'] = len(connected)
|
| 143 |
+
results['connectivity'] = connectivity
|
| 144 |
+
results['valid_and_connected'] = results['n_connected'] / results['n_total']
|
| 145 |
+
|
| 146 |
+
if verbose:
|
| 147 |
+
print(f"Validity over {results['n_total']} molecules: {validity * 100 :.2f}%")
|
| 148 |
+
print(f"Connectivity over {results['n_valid']} valid molecules: {connectivity * 100 :.2f}%")
|
| 149 |
+
|
| 150 |
+
return results
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class MolecularMetrics:
|
| 154 |
+
def __init__(self, connectivity_thresh=1.0):
|
| 155 |
+
self.connectivity_thresh = connectivity_thresh
|
| 156 |
+
|
| 157 |
+
@staticmethod
|
| 158 |
+
def is_valid(rdmol):
|
| 159 |
+
if rdmol.GetNumAtoms() < 1:
|
| 160 |
+
return False
|
| 161 |
+
|
| 162 |
+
_mol = Chem.Mol(rdmol)
|
| 163 |
+
try:
|
| 164 |
+
Chem.SanitizeMol(_mol)
|
| 165 |
+
except ValueError:
|
| 166 |
+
return False
|
| 167 |
+
|
| 168 |
+
return True
|
| 169 |
+
|
| 170 |
+
def is_connected(self, rdmol):
|
| 171 |
+
|
| 172 |
+
if rdmol.GetNumAtoms() < 1:
|
| 173 |
+
return False
|
| 174 |
+
|
| 175 |
+
mol_frags = Chem.rdmolops.GetMolFrags(rdmol, asMols=True)
|
| 176 |
+
|
| 177 |
+
largest_frag = max(mol_frags, default=rdmol, key=lambda m: m.GetNumAtoms())
|
| 178 |
+
if largest_frag.GetNumAtoms() / rdmol.GetNumAtoms() >= self.connectivity_thresh:
|
| 179 |
+
return True
|
| 180 |
+
else:
|
| 181 |
+
return False
|
| 182 |
+
|
| 183 |
+
@staticmethod
|
| 184 |
+
def calculate_qed(rdmol):
|
| 185 |
+
return QED.qed(rdmol)
|
| 186 |
+
|
| 187 |
+
@staticmethod
|
| 188 |
+
def calculate_sa(rdmol):
|
| 189 |
+
sa = calculateScore(rdmol)
|
| 190 |
+
return sa
|
| 191 |
+
|
| 192 |
+
@staticmethod
|
| 193 |
+
def calculate_logp(rdmol):
|
| 194 |
+
return Crippen.MolLogP(rdmol)
|
| 195 |
+
|
| 196 |
+
@staticmethod
|
| 197 |
+
def calculate_lipinski(rdmol):
|
| 198 |
+
rule_1 = Descriptors.ExactMolWt(rdmol) < 500
|
| 199 |
+
rule_2 = Lipinski.NumHDonors(rdmol) <= 5
|
| 200 |
+
rule_3 = Lipinski.NumHAcceptors(rdmol) <= 10
|
| 201 |
+
rule_4 = (logp := Crippen.MolLogP(rdmol) >= -2) & (logp <= 5)
|
| 202 |
+
rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(rdmol) <= 10
|
| 203 |
+
return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]])
|
| 204 |
+
|
| 205 |
+
def __call__(self, rdmol):
|
| 206 |
+
valid = self.is_valid(rdmol)
|
| 207 |
+
|
| 208 |
+
if valid:
|
| 209 |
+
Chem.SanitizeMol(rdmol)
|
| 210 |
+
|
| 211 |
+
connected = None if not valid else self.is_connected(rdmol)
|
| 212 |
+
qed = None if not valid else self.calculate_qed(rdmol)
|
| 213 |
+
sa = None if not valid else self.calculate_sa(rdmol)
|
| 214 |
+
logp = None if not valid else self.calculate_logp(rdmol)
|
| 215 |
+
lipinski = None if not valid else self.calculate_lipinski(rdmol)
|
| 216 |
+
|
| 217 |
+
return {
|
| 218 |
+
'valid': valid,
|
| 219 |
+
'connected': connected,
|
| 220 |
+
'qed': qed,
|
| 221 |
+
'sa': sa,
|
| 222 |
+
'logp': logp,
|
| 223 |
+
'lipinski': lipinski
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class Diversity:
|
| 228 |
+
@staticmethod
|
| 229 |
+
def similarity(fp1, fp2):
|
| 230 |
+
return DataStructs.TanimotoSimilarity(fp1, fp2)
|
| 231 |
+
|
| 232 |
+
def get_fingerprint(self, mol):
|
| 233 |
+
# fp = AllChem.GetMorganFingerprintAsBitVect(
|
| 234 |
+
# mol, 2, nBits=2048, useChirality=False)
|
| 235 |
+
fp = Chem.RDKFingerprint(mol)
|
| 236 |
+
return fp
|
| 237 |
+
|
| 238 |
+
def __call__(self, pocket_mols):
|
| 239 |
+
|
| 240 |
+
if len(pocket_mols) < 2:
|
| 241 |
+
return 0.0
|
| 242 |
+
|
| 243 |
+
pocket_fps = [self.get_fingerprint(m) for m in pocket_mols]
|
| 244 |
+
|
| 245 |
+
div = 0
|
| 246 |
+
total = 0
|
| 247 |
+
for i in range(len(pocket_fps)):
|
| 248 |
+
for j in range(i + 1, len(pocket_fps)):
|
| 249 |
+
div += 1 - self.similarity(pocket_fps[i], pocket_fps[j])
|
| 250 |
+
total += 1
|
| 251 |
+
|
| 252 |
+
return div / total
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class MoleculeUniqueness:
|
| 256 |
+
def __call__(self, smiles_list):
|
| 257 |
+
""" smiles_list: list of SMILES strings. """
|
| 258 |
+
if len(smiles_list) < 1:
|
| 259 |
+
return 0.0
|
| 260 |
+
|
| 261 |
+
return len(set(smiles_list)) / len(smiles_list)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class MoleculeNovelty:
|
| 265 |
+
def __init__(self, reference_smiles):
|
| 266 |
+
"""
|
| 267 |
+
:param reference_smiles: list of SMILES strings
|
| 268 |
+
"""
|
| 269 |
+
self.reference_smiles = set(reference_smiles)
|
| 270 |
+
|
| 271 |
+
def __call__(self, smiles_list):
|
| 272 |
+
if len(smiles_list) < 1:
|
| 273 |
+
return 0.0
|
| 274 |
+
|
| 275 |
+
novel = [smi for smi in smiles_list if smi not in self.reference_smiles]
|
| 276 |
+
return len(novel) / len(smiles_list)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class MolecularProperties:
|
| 280 |
+
|
| 281 |
+
@staticmethod
|
| 282 |
+
def calculate_qed(rdmol):
|
| 283 |
+
return QED.qed(rdmol)
|
| 284 |
+
|
| 285 |
+
@staticmethod
|
| 286 |
+
def calculate_sa(rdmol):
|
| 287 |
+
sa = calculateScore(rdmol)
|
| 288 |
+
# return round((10 - sa) / 9, 2) # from pocket2mol
|
| 289 |
+
return sa
|
| 290 |
+
|
| 291 |
+
@staticmethod
|
| 292 |
+
def calculate_logp(rdmol):
|
| 293 |
+
return Crippen.MolLogP(rdmol)
|
| 294 |
+
|
| 295 |
+
@staticmethod
|
| 296 |
+
def calculate_lipinski(rdmol):
|
| 297 |
+
rule_1 = Descriptors.ExactMolWt(rdmol) < 500
|
| 298 |
+
rule_2 = Lipinski.NumHDonors(rdmol) <= 5
|
| 299 |
+
rule_3 = Lipinski.NumHAcceptors(rdmol) <= 10
|
| 300 |
+
rule_4 = (logp := Crippen.MolLogP(rdmol) >= -2) & (logp <= 5)
|
| 301 |
+
rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(rdmol) <= 10
|
| 302 |
+
return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]])
|
| 303 |
+
|
| 304 |
+
@classmethod
|
| 305 |
+
def calculate_diversity(cls, pocket_mols):
|
| 306 |
+
if len(pocket_mols) < 2:
|
| 307 |
+
return 0.0
|
| 308 |
+
|
| 309 |
+
div = 0
|
| 310 |
+
total = 0
|
| 311 |
+
for i in range(len(pocket_mols)):
|
| 312 |
+
for j in range(i + 1, len(pocket_mols)):
|
| 313 |
+
div += 1 - cls.similarity(pocket_mols[i], pocket_mols[j])
|
| 314 |
+
total += 1
|
| 315 |
+
return div / total
|
| 316 |
+
|
| 317 |
+
@staticmethod
|
| 318 |
+
def similarity(mol_a, mol_b):
|
| 319 |
+
# fp1 = AllChem.GetMorganFingerprintAsBitVect(
|
| 320 |
+
# mol_a, 2, nBits=2048, useChirality=False)
|
| 321 |
+
# fp2 = AllChem.GetMorganFingerprintAsBitVect(
|
| 322 |
+
# mol_b, 2, nBits=2048, useChirality=False)
|
| 323 |
+
fp1 = Chem.RDKFingerprint(mol_a)
|
| 324 |
+
fp2 = Chem.RDKFingerprint(mol_b)
|
| 325 |
+
return DataStructs.TanimotoSimilarity(fp1, fp2)
|
| 326 |
+
|
| 327 |
+
def evaluate_pockets(self, pocket_rdmols, verbose=False):
|
| 328 |
+
"""
|
| 329 |
+
Run full evaluation
|
| 330 |
+
Args:
|
| 331 |
+
pocket_rdmols: list of lists, the inner list contains all RDKit
|
| 332 |
+
molecules generated for a pocket
|
| 333 |
+
Returns:
|
| 334 |
+
QED, SA, LogP, Lipinski (per molecule), and Diversity (per pocket)
|
| 335 |
+
"""
|
| 336 |
+
|
| 337 |
+
for pocket in pocket_rdmols:
|
| 338 |
+
for mol in pocket:
|
| 339 |
+
Chem.SanitizeMol(mol) # only evaluate valid molecules
|
| 340 |
+
|
| 341 |
+
all_qed = []
|
| 342 |
+
all_sa = []
|
| 343 |
+
all_logp = []
|
| 344 |
+
all_lipinski = []
|
| 345 |
+
per_pocket_diversity = []
|
| 346 |
+
for pocket in tqdm(pocket_rdmols):
|
| 347 |
+
all_qed.append([self.calculate_qed(mol) for mol in pocket])
|
| 348 |
+
all_sa.append([self.calculate_sa(mol) for mol in pocket])
|
| 349 |
+
all_logp.append([self.calculate_logp(mol) for mol in pocket])
|
| 350 |
+
all_lipinski.append([self.calculate_lipinski(mol) for mol in pocket])
|
| 351 |
+
per_pocket_diversity.append(self.calculate_diversity(pocket))
|
| 352 |
+
|
| 353 |
+
qed_flattened = [x for px in all_qed for x in px]
|
| 354 |
+
sa_flattened = [x for px in all_sa for x in px]
|
| 355 |
+
logp_flattened = [x for px in all_logp for x in px]
|
| 356 |
+
lipinski_flattened = [x for px in all_lipinski for x in px]
|
| 357 |
+
|
| 358 |
+
if verbose:
|
| 359 |
+
print(f"{sum([len(p) for p in pocket_rdmols])} molecules from "
|
| 360 |
+
f"{len(pocket_rdmols)} pockets evaluated.")
|
| 361 |
+
print(f"QED: {np.mean(qed_flattened):.3f} \pm {np.std(qed_flattened):.2f}")
|
| 362 |
+
print(f"SA: {np.mean(sa_flattened):.3f} \pm {np.std(sa_flattened):.2f}")
|
| 363 |
+
print(f"LogP: {np.mean(logp_flattened):.3f} \pm {np.std(logp_flattened):.2f}")
|
| 364 |
+
print(f"Lipinski: {np.mean(lipinski_flattened):.3f} \pm {np.std(lipinski_flattened):.2f}")
|
| 365 |
+
print(f"Diversity: {np.mean(per_pocket_diversity):.3f} \pm {np.std(per_pocket_diversity):.2f}")
|
| 366 |
+
|
| 367 |
+
return all_qed, all_sa, all_logp, all_lipinski, per_pocket_diversity
|
| 368 |
+
|
| 369 |
+
def __call__(self, rdmols):
|
| 370 |
+
"""
|
| 371 |
+
Run full evaluation and return mean of each property
|
| 372 |
+
Args:
|
| 373 |
+
rdmols: list of RDKit molecules
|
| 374 |
+
Returns:
|
| 375 |
+
Dictionary with mean QED, SA, LogP, Lipinski, and Diversity values
|
| 376 |
+
"""
|
| 377 |
+
|
| 378 |
+
if len(rdmols) < 1:
|
| 379 |
+
return {'QED': 0.0, 'SA': 0.0, 'LogP': 0.0, 'Lipinski': 0.0,
|
| 380 |
+
'Diversity': 0.0}
|
| 381 |
+
|
| 382 |
+
_rdmols = []
|
| 383 |
+
for mol in rdmols:
|
| 384 |
+
try:
|
| 385 |
+
Chem.SanitizeMol(mol) # only evaluate valid molecules
|
| 386 |
+
_rdmols.append(mol)
|
| 387 |
+
except ValueError as e:
|
| 388 |
+
print("Tried to analyze invalid molecule")
|
| 389 |
+
rdmols = _rdmols
|
| 390 |
+
|
| 391 |
+
qed = np.mean([self.calculate_qed(mol) for mol in rdmols])
|
| 392 |
+
sa = np.mean([self.calculate_sa(mol) for mol in rdmols])
|
| 393 |
+
logp = np.mean([self.calculate_logp(mol) for mol in rdmols])
|
| 394 |
+
lipinski = np.mean([self.calculate_lipinski(mol) for mol in rdmols])
|
| 395 |
+
diversity = self.calculate_diversity(rdmols)
|
| 396 |
+
|
| 397 |
+
return {'QED': qed, 'SA': sa, 'LogP': logp, 'Lipinski': lipinski,
|
| 398 |
+
'Diversity': diversity}
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def compute_gnina_scores(ligands, receptors, gnina):
|
| 402 |
+
metrics = ['minimizedAffinity', 'minimizedRMSD', 'CNNscore', 'CNNaffinity', 'CNN_VS', 'CNNaffinity_variance']
|
| 403 |
+
out = {m: [] for m in metrics}
|
| 404 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 405 |
+
for ligand, receptor in zip(tqdm(ligands, desc='Docking'), receptors):
|
| 406 |
+
in_ligand_path = Path(tmpdir, 'in_ligand.sdf')
|
| 407 |
+
out_ligand_path = Path(tmpdir, 'out_ligand.sdf')
|
| 408 |
+
receptor_path = Path(tmpdir, 'receptor.pdb')
|
| 409 |
+
write_sdf_file(in_ligand_path, [ligand], catch_errors=True)
|
| 410 |
+
Chem.MolToPDBFile(receptor, str(receptor_path))
|
| 411 |
+
if (
|
| 412 |
+
(not in_ligand_path.exists()) or
|
| 413 |
+
(not receptor_path.exists()) or
|
| 414 |
+
in_ligand_path.read_text() == '' or
|
| 415 |
+
receptor_path.read_text() == ''
|
| 416 |
+
):
|
| 417 |
+
continue
|
| 418 |
+
|
| 419 |
+
cmd = (
|
| 420 |
+
f'{gnina} -r {receptor_path} -l {in_ligand_path} '
|
| 421 |
+
f'--minimize --seed 42 -o {out_ligand_path} --no_gpu 1> /dev/null'
|
| 422 |
+
)
|
| 423 |
+
subprocess.run(cmd, shell=True)
|
| 424 |
+
if not out_ligand_path.exists() or out_ligand_path.read_text() == '':
|
| 425 |
+
continue
|
| 426 |
+
|
| 427 |
+
mol = Chem.SDMolSupplier(str(out_ligand_path), sanitize=False)[0]
|
| 428 |
+
for metric in metrics:
|
| 429 |
+
out[metric].append(float(mol.GetProp(metric)))
|
| 430 |
+
|
| 431 |
+
for metric in metrics:
|
| 432 |
+
out[metric] = sum(out[metric]) / len(out[metric]) if len(out[metric]) > 0 else 0
|
| 433 |
+
|
| 434 |
+
return out
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def legacy_clash_score(rdmol1, rdmol2=None, margin=0.75):
|
| 438 |
+
"""
|
| 439 |
+
Computes a clash score as the number of atoms that have at least one
|
| 440 |
+
clash divided by the number of atoms in the molecule.
|
| 441 |
+
|
| 442 |
+
INTERMOLECULAR CLASH SCORE
|
| 443 |
+
If rdmol2 is provided, the score is the percentage of atoms in rdmol1
|
| 444 |
+
that have at least one clash with rdmol2.
|
| 445 |
+
We define a clash if two atoms are closer than "margin times the sum of
|
| 446 |
+
their van der Waals radii".
|
| 447 |
+
|
| 448 |
+
INTRAMOLECULAR CLASH SCORE
|
| 449 |
+
If rdmol2 is not provided, the score is the percentage of atoms in rdmol1
|
| 450 |
+
that have at least one clash with other atoms in rdmol1.
|
| 451 |
+
In this case, a clash is defined by margin times the atoms' smallest
|
| 452 |
+
covalent radii (among single, double and triple bond radii). This is done
|
| 453 |
+
so that this function is applicable even if no connectivity information is
|
| 454 |
+
available.
|
| 455 |
+
"""
|
| 456 |
+
# source: https://en.wikipedia.org/wiki/Van_der_Waals_radius
|
| 457 |
+
vdw_radii = {'N': 1.55, 'O': 1.52, 'C': 1.70, 'H': 1.10, 'S': 1.80, 'P': 1.80,
|
| 458 |
+
'Se': 1.90, 'K': 2.75, 'Na': 2.27, 'Mg': 1.73, 'Zn': 1.39, 'B': 1.92,
|
| 459 |
+
'Br': 1.85, 'Cl': 1.75, 'I': 1.98, 'F': 1.47}
|
| 460 |
+
|
| 461 |
+
# https://en.wikipedia.org/wiki/Covalent_radius#Radii_for_multiple_bonds
|
| 462 |
+
covalent_radii = {'H': 0.32, 'C': 0.60, 'N': 0.54, 'O': 0.53, 'F': 0.53, 'B': 0.73,
|
| 463 |
+
'Al': 1.11, 'Si': 1.02, 'P': 0.94, 'S': 0.94, 'Cl': 0.93, 'As': 1.06,
|
| 464 |
+
'Br': 1.09, 'I': 1.25, 'Hg': 1.33, 'Bi': 1.35}
|
| 465 |
+
|
| 466 |
+
coord1 = rdmol1.GetConformer().GetPositions()
|
| 467 |
+
|
| 468 |
+
if rdmol2 is None:
|
| 469 |
+
radii1 = np.array([covalent_radii[a.GetSymbol()] for a in rdmol1.GetAtoms()])
|
| 470 |
+
assert coord1.shape[0] == radii1.shape[0]
|
| 471 |
+
|
| 472 |
+
dist = np.sqrt(np.sum((coord1[:, None, :] - coord1[None, :, :]) ** 2, axis=-1))
|
| 473 |
+
np.fill_diagonal(dist, np.inf)
|
| 474 |
+
clashes = dist < margin * (radii1[:, None] + radii1[None, :])
|
| 475 |
+
|
| 476 |
+
else:
|
| 477 |
+
coord2 = rdmol2.GetConformer().GetPositions()
|
| 478 |
+
|
| 479 |
+
radii1 = np.array([vdw_radii[a.GetSymbol()] for a in rdmol1.GetAtoms()])
|
| 480 |
+
assert coord1.shape[0] == radii1.shape[0]
|
| 481 |
+
radii2 = np.array([vdw_radii[a.GetSymbol()] for a in rdmol2.GetAtoms()])
|
| 482 |
+
assert coord2.shape[0] == radii2.shape[0]
|
| 483 |
+
|
| 484 |
+
dist = np.sqrt(np.sum((coord1[:, None, :] - coord2[None, :, :]) ** 2, axis=-1))
|
| 485 |
+
clashes = dist < margin * (radii1[:, None] + radii2[None, :])
|
| 486 |
+
|
| 487 |
+
clashes = np.any(clashes, axis=1)
|
| 488 |
+
return np.mean(clashes)
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def clash_score(rdmol1, rdmol2=None, margin=0.75, ignore={'H'}):
|
| 492 |
+
"""
|
| 493 |
+
Computes a clash score as the number of atoms that have at least one
|
| 494 |
+
clash divided by the number of atoms in the molecule.
|
| 495 |
+
|
| 496 |
+
INTERMOLECULAR CLASH SCORE
|
| 497 |
+
If rdmol2 is provided, the score is the percentage of atoms in rdmol1
|
| 498 |
+
that have at least one clash with rdmol2.
|
| 499 |
+
We define a clash if two atoms are closer than "margin times the sum of
|
| 500 |
+
their van der Waals radii".
|
| 501 |
+
|
| 502 |
+
INTRAMOLECULAR CLASH SCORE
|
| 503 |
+
If rdmol2 is not provided, the score is the percentage of atoms in rdmol1
|
| 504 |
+
that have at least one clash with other atoms in rdmol1.
|
| 505 |
+
In this case, a clash is defined by margin times the atoms' smallest
|
| 506 |
+
covalent radii (among single, double and triple bond radii). This is done
|
| 507 |
+
so that this function is applicable even if no connectivity information is
|
| 508 |
+
available.
|
| 509 |
+
"""
|
| 510 |
+
|
| 511 |
+
intramolecular = rdmol2 is None
|
| 512 |
+
|
| 513 |
+
_periodic_table = AllChem.GetPeriodicTable()
|
| 514 |
+
|
| 515 |
+
def _coord_and_radii(rdmol):
|
| 516 |
+
coord = rdmol.GetConformer().GetPositions()
|
| 517 |
+
radii = np.array([_get_radius(a.GetSymbol()) for a in rdmol.GetAtoms()])
|
| 518 |
+
|
| 519 |
+
mask = np.array([a.GetSymbol() not in ignore for a in rdmol.GetAtoms()])
|
| 520 |
+
coord = coord[mask]
|
| 521 |
+
radii = radii[mask]
|
| 522 |
+
|
| 523 |
+
assert coord.shape[0] == radii.shape[0]
|
| 524 |
+
return coord, radii
|
| 525 |
+
|
| 526 |
+
# INTRAMOLECULAR CLASH SCORE
|
| 527 |
+
if intramolecular:
|
| 528 |
+
rdmol2 = rdmol1
|
| 529 |
+
_get_radius = _periodic_table.GetRcovalent # covalent radii
|
| 530 |
+
|
| 531 |
+
# INTERMOLECULAR CLASH SCORE
|
| 532 |
+
else:
|
| 533 |
+
_get_radius = _periodic_table.GetRvdw # vdW radii
|
| 534 |
+
|
| 535 |
+
coord1, radii1 = _coord_and_radii(rdmol1)
|
| 536 |
+
coord2, radii2 = _coord_and_radii(rdmol2)
|
| 537 |
+
|
| 538 |
+
dist = np.sqrt(np.sum((coord1[:, None, :] - coord2[None, :, :]) ** 2, axis=-1))
|
| 539 |
+
if intramolecular:
|
| 540 |
+
np.fill_diagonal(dist, np.inf)
|
| 541 |
+
|
| 542 |
+
clashes = dist < margin * (radii1[:, None] + radii2[None, :])
|
| 543 |
+
clashes = np.any(clashes, axis=1)
|
| 544 |
+
return np.mean(clashes)
|
src/analysis/visualization_utils.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from rdkit import Chem
|
| 5 |
+
from rdkit.Chem import Draw, AllChem
|
| 6 |
+
from rdkit.Chem import SanitizeFlags
|
| 7 |
+
from src.analysis.metrics import check_mol
|
| 8 |
+
from src import utils
|
| 9 |
+
from src.data.molecule_builder import build_molecule
|
| 10 |
+
from src.data.misc import protein_letters_1to3
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# def pocket_to_rdkit(pocket, pocket_representation, atom_encoder=None,
|
| 14 |
+
# atom_decoder=None, aa_decoder=None, residue_decoder=None,
|
| 15 |
+
# aa_atom_index=None):
|
| 16 |
+
#
|
| 17 |
+
# rdpockets = []
|
| 18 |
+
# for i in torch.unique(pocket['mask']):
|
| 19 |
+
#
|
| 20 |
+
# node_coord = pocket['x'][pocket['mask'] == i]
|
| 21 |
+
# h = pocket['one_hot'][pocket['mask'] == i]
|
| 22 |
+
#
|
| 23 |
+
# if pocket_representation == 'side_chain_bead':
|
| 24 |
+
# coord = node_coord
|
| 25 |
+
#
|
| 26 |
+
# node_types = [residue_decoder[b] for b in h[:, -len(residue_decoder):].argmax(-1)]
|
| 27 |
+
# atom_types = ['C' if r == 'CA' else 'F' for r in node_types]
|
| 28 |
+
#
|
| 29 |
+
# elif pocket_representation == 'CA+':
|
| 30 |
+
# aa_types = [aa_decoder[b] for b in h.argmax(-1)]
|
| 31 |
+
# side_chain_vec = pocket['v'][pocket['mask'] == i]
|
| 32 |
+
#
|
| 33 |
+
# coord = []
|
| 34 |
+
# atom_types = []
|
| 35 |
+
# for xyz, aa, vec in zip(node_coord, aa_types, side_chain_vec):
|
| 36 |
+
# # C_alpha
|
| 37 |
+
# coord.append(xyz)
|
| 38 |
+
# atom_types.append('C')
|
| 39 |
+
#
|
| 40 |
+
# # all other atoms
|
| 41 |
+
# for atom_name, idx in aa_atom_index[aa].items():
|
| 42 |
+
# coord.append(xyz + vec[idx])
|
| 43 |
+
# atom_types.append(atom_name[0])
|
| 44 |
+
#
|
| 45 |
+
# coord = torch.stack(coord, dim=0)
|
| 46 |
+
#
|
| 47 |
+
# else:
|
| 48 |
+
# raise NotImplementedError(f"{pocket_representation} residue representation not supported")
|
| 49 |
+
#
|
| 50 |
+
# atom_types = torch.tensor([atom_encoder[a] for a in atom_types])
|
| 51 |
+
# rdpockets.append(build_molecule(coord, atom_types, atom_decoder=atom_decoder))
|
| 52 |
+
#
|
| 53 |
+
# return rdpockets
|
| 54 |
+
def pocket_to_rdkit(pocket, pocket_representation, atom_encoder=None,
|
| 55 |
+
atom_decoder=None, aa_decoder=None, residue_decoder=None,
|
| 56 |
+
aa_atom_index=None):
|
| 57 |
+
|
| 58 |
+
rdpockets = []
|
| 59 |
+
for i in torch.unique(pocket['mask']):
|
| 60 |
+
|
| 61 |
+
node_coord = pocket['x'][pocket['mask'] == i]
|
| 62 |
+
h = pocket['one_hot'][pocket['mask'] == i]
|
| 63 |
+
atom_mask = pocket['atom_mask'][pocket['mask'] == i]
|
| 64 |
+
|
| 65 |
+
pdb_infos = []
|
| 66 |
+
|
| 67 |
+
if pocket_representation == 'side_chain_bead':
|
| 68 |
+
coord = node_coord
|
| 69 |
+
|
| 70 |
+
node_types = [residue_decoder[b] for b in h[:, -len(residue_decoder):].argmax(-1)]
|
| 71 |
+
atom_types = ['C' if r == 'CA' else 'F' for r in node_types]
|
| 72 |
+
|
| 73 |
+
elif pocket_representation == 'CA+':
|
| 74 |
+
aa_types = [aa_decoder[b] for b in h.argmax(-1)]
|
| 75 |
+
side_chain_vec = pocket['v'][pocket['mask'] == i]
|
| 76 |
+
|
| 77 |
+
coord = []
|
| 78 |
+
atom_types = []
|
| 79 |
+
for resi, (xyz, aa, vec, am) in enumerate(zip(node_coord, aa_types, side_chain_vec, atom_mask)):
|
| 80 |
+
|
| 81 |
+
# CA not treated differently with updated atom dictionary
|
| 82 |
+
for atom_name, idx in aa_atom_index[aa].items():
|
| 83 |
+
|
| 84 |
+
if ~am[idx]:
|
| 85 |
+
warnings.warn(f"Missing atom {atom_name} in {aa}:{resi}")
|
| 86 |
+
continue
|
| 87 |
+
|
| 88 |
+
coord.append(xyz + vec[idx])
|
| 89 |
+
atom_types.append(atom_name[0])
|
| 90 |
+
|
| 91 |
+
info = Chem.AtomPDBResidueInfo()
|
| 92 |
+
# info.SetChainId('A')
|
| 93 |
+
info.SetResidueName(protein_letters_1to3[aa])
|
| 94 |
+
info.SetResidueNumber(resi + 1)
|
| 95 |
+
info.SetOccupancy(1.0)
|
| 96 |
+
info.SetTempFactor(0.0)
|
| 97 |
+
info.SetName(f' {atom_name:<3}')
|
| 98 |
+
pdb_infos.append(info)
|
| 99 |
+
|
| 100 |
+
coord = torch.stack(coord, dim=0)
|
| 101 |
+
|
| 102 |
+
else:
|
| 103 |
+
raise NotImplementedError(f"{pocket_representation} residue representation not supported")
|
| 104 |
+
|
| 105 |
+
atom_types = torch.tensor([atom_encoder[a] for a in atom_types])
|
| 106 |
+
rdmol = build_molecule(coord, atom_types, atom_decoder=atom_decoder)
|
| 107 |
+
|
| 108 |
+
if len(pdb_infos) == len(rdmol.GetAtoms()):
|
| 109 |
+
for a, info in zip(rdmol.GetAtoms(), pdb_infos):
|
| 110 |
+
a.SetPDBResidueInfo(info)
|
| 111 |
+
|
| 112 |
+
rdpockets.append(rdmol)
|
| 113 |
+
|
| 114 |
+
return rdpockets
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def mols_to_pdbfile(rdmols, filename, flavor=0):
|
| 118 |
+
pdb_str = ""
|
| 119 |
+
for i, mol in enumerate(rdmols):
|
| 120 |
+
pdb_str += f"MODEL{i + 1:>9}\n"
|
| 121 |
+
block = Chem.MolToPDBBlock(mol, flavor=flavor)
|
| 122 |
+
block = "\n".join(block.split("\n")[:-2]) # remove END
|
| 123 |
+
pdb_str += block + "\n"
|
| 124 |
+
pdb_str += f"ENDMDL\n"
|
| 125 |
+
pdb_str += f"END\n"
|
| 126 |
+
|
| 127 |
+
with open(filename, 'w') as f:
|
| 128 |
+
f.write(pdb_str)
|
| 129 |
+
|
| 130 |
+
return pdb_str
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def mol_as_pdb(rdmol, filename=None, bfactor=None):
|
| 134 |
+
|
| 135 |
+
_rdmol = Chem.Mol(rdmol) # copy
|
| 136 |
+
for a in _rdmol.GetAtoms():
|
| 137 |
+
a.SetIsAromatic(False)
|
| 138 |
+
for b in _rdmol.GetBonds():
|
| 139 |
+
b.SetIsAromatic(False)
|
| 140 |
+
|
| 141 |
+
if bfactor is not None:
|
| 142 |
+
for a in _rdmol.GetAtoms():
|
| 143 |
+
val = a.GetPropsAsDict()[bfactor]
|
| 144 |
+
|
| 145 |
+
info = Chem.AtomPDBResidueInfo()
|
| 146 |
+
info.SetResidueName('UNL')
|
| 147 |
+
info.SetResidueNumber(1)
|
| 148 |
+
info.SetName(f' {a.GetSymbol():<3}')
|
| 149 |
+
info.SetIsHeteroAtom(True)
|
| 150 |
+
info.SetOccupancy(1.0)
|
| 151 |
+
info.SetTempFactor(val)
|
| 152 |
+
a.SetPDBResidueInfo(info)
|
| 153 |
+
|
| 154 |
+
pdb_str = Chem.MolToPDBBlock(_rdmol)
|
| 155 |
+
|
| 156 |
+
if filename is not None:
|
| 157 |
+
with open(filename, 'w') as f:
|
| 158 |
+
f.write(pdb_str)
|
| 159 |
+
|
| 160 |
+
return pdb_str
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def draw_grid(molecules, mols_per_row=5, fig_size=(200, 200),
|
| 164 |
+
label=check_mol,
|
| 165 |
+
highlight_atom=lambda atom: False,
|
| 166 |
+
highlight_bond=lambda bond: False):
|
| 167 |
+
|
| 168 |
+
draw_mols = []
|
| 169 |
+
marked_atoms = []
|
| 170 |
+
marked_bonds = []
|
| 171 |
+
for mol in molecules:
|
| 172 |
+
draw_mol = Chem.Mol(mol) # copy
|
| 173 |
+
Chem.SanitizeMol(draw_mol, sanitizeOps=SanitizeFlags.SANITIZE_NONE)
|
| 174 |
+
AllChem.Compute2DCoords(draw_mol)
|
| 175 |
+
draw_mol = Draw.rdMolDraw2D.PrepareMolForDrawing(draw_mol,
|
| 176 |
+
kekulize=False)
|
| 177 |
+
draw_mols.append(draw_mol)
|
| 178 |
+
marked_atoms.append([a.GetIdx() for a in draw_mol.GetAtoms() if highlight_atom(a)])
|
| 179 |
+
marked_bonds.append([b.GetIdx() for b in draw_mol.GetBonds() if highlight_bond(b)])
|
| 180 |
+
|
| 181 |
+
drawOptions = Draw.rdMolDraw2D.MolDrawOptions()
|
| 182 |
+
drawOptions.prepareMolsBeforeDrawing = False
|
| 183 |
+
drawOptions.highlightBondWidthMultiplier = 20
|
| 184 |
+
|
| 185 |
+
return Draw.MolsToGridImage(draw_mols,
|
| 186 |
+
molsPerRow=mols_per_row,
|
| 187 |
+
subImgSize=fig_size,
|
| 188 |
+
drawOptions=drawOptions,
|
| 189 |
+
highlightAtomLists=marked_atoms,
|
| 190 |
+
highlightBondLists=marked_bonds,
|
| 191 |
+
legends=[f'[{i}] {label(mol)}' for
|
| 192 |
+
i, mol in enumerate(draw_mols)])
|
src/constants.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from rdkit import Chem
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
# ------------------------------------------------------------------------------
|
| 7 |
+
# Computational
|
| 8 |
+
# ------------------------------------------------------------------------------
|
| 9 |
+
FLOAT_TYPE = torch.float32
|
| 10 |
+
INT_TYPE = torch.int64
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# ------------------------------------------------------------------------------
|
| 14 |
+
# Type encoding/decoding
|
| 15 |
+
# ------------------------------------------------------------------------------
|
| 16 |
+
|
| 17 |
+
atom_dict = os.environ.get('ATOM_DICT')
|
| 18 |
+
if atom_dict == 'simple':
|
| 19 |
+
atom_encoder = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'B': 4, 'Br': 5, 'Cl': 6, 'P': 7, 'I': 8, 'F': 9, 'NOATOM': 10}
|
| 20 |
+
atom_decoder = ['C', 'N', 'O', 'S', 'B', 'Br', 'Cl', 'P', 'I', 'F', 'NOATOM']
|
| 21 |
+
|
| 22 |
+
else:
|
| 23 |
+
atom_encoder = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'B': 4, 'Br': 5, 'Cl': 6, 'P': 7, 'I': 8, 'F': 9, 'NH': 10, 'N+': 11, 'O-': 12, 'NOATOM': 13}
|
| 24 |
+
atom_decoder = ['C', 'N', 'O', 'S', 'B', 'Br', 'Cl', 'P', 'I', 'F', 'NH', 'N+', 'O-', 'NOATOM']
|
| 25 |
+
|
| 26 |
+
bond_encoder = {"NOBOND": 0, "SINGLE": 1, "DOUBLE": 2, "TRIPLE": 3, 'AROMATIC': 4}
|
| 27 |
+
bond_decoder = [None, Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
|
| 28 |
+
|
| 29 |
+
aa_encoder = {'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4, 'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9, 'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14, 'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19}
|
| 30 |
+
aa_decoder = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']
|
| 31 |
+
|
| 32 |
+
residue_encoder = {'CA': 0, 'SS': 1}
|
| 33 |
+
residue_decoder = ['CA', 'SS']
|
| 34 |
+
|
| 35 |
+
residue_bond_encoder = {'CA-CA': 0, 'CA-SS': 1, 'NOBOND': 2}
|
| 36 |
+
residue_bond_decoder = ['CA-CA', 'CA-SS', None]
|
| 37 |
+
|
| 38 |
+
# aa_atom_index = {
|
| 39 |
+
# 'A': {'N': 0, 'C': 1, 'O': 2, 'CB': 3},
|
| 40 |
+
# 'C': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'SG': 4},
|
| 41 |
+
# 'D': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'OD1': 5, 'OD2': 6},
|
| 42 |
+
# 'E': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD': 5, 'OE1': 6, 'OE2': 7},
|
| 43 |
+
# 'F': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD1': 5, 'CD2': 6, 'CE1': 7, 'CE2': 8, 'CZ': 9},
|
| 44 |
+
# 'G': {'N': 0, 'C': 1, 'O': 2},
|
| 45 |
+
# 'H': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'ND1': 5, 'CD2': 6, 'CE1': 7, 'NE2': 8},
|
| 46 |
+
# 'I': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG1': 4, 'CG2': 5, 'CD1': 6},
|
| 47 |
+
# 'K': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD': 5, 'CE': 6, 'NZ': 7},
|
| 48 |
+
# 'L': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD1': 5, 'CD2': 6},
|
| 49 |
+
# 'M': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'SD': 5, 'CE': 6},
|
| 50 |
+
# 'N': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'OD1': 5, 'ND2': 6},
|
| 51 |
+
# 'P': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD': 5},
|
| 52 |
+
# 'Q': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD': 5, 'OE1': 6, 'NE2': 7},
|
| 53 |
+
# 'R': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD': 5, 'NE': 6, 'CZ': 7, 'NH1': 8, 'NH2': 9},
|
| 54 |
+
# 'S': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'OG': 4},
|
| 55 |
+
# 'T': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'OG1': 4, 'CG2': 5},
|
| 56 |
+
# 'V': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG1': 4, 'CG2': 5},
|
| 57 |
+
# 'W': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD1': 5, 'CD2': 6, 'NE1': 7, 'CE2': 8, 'CE3': 9, 'CZ2': 10, 'CZ3': 11, 'CH2': 12},
|
| 58 |
+
# 'Y': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD1': 5, 'CD2': 6, 'CE1': 7, 'CE2': 8, 'CZ': 9, 'OH': 10},
|
| 59 |
+
# }
|
| 60 |
+
aa_atom_index = {
|
| 61 |
+
'A': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4},
|
| 62 |
+
'C': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'SG': 5},
|
| 63 |
+
'D': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'OD1': 6, 'OD2': 7},
|
| 64 |
+
'E': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD': 6, 'OE1': 7, 'OE2': 8},
|
| 65 |
+
'F': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD1': 6, 'CD2': 7, 'CE1': 8, 'CE2': 9, 'CZ': 10},
|
| 66 |
+
'G': {'N': 0, 'CA': 1, 'C': 2, 'O': 3},
|
| 67 |
+
'H': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'ND1': 6, 'CD2': 7, 'CE1': 8, 'NE2': 9},
|
| 68 |
+
'I': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG1': 5, 'CG2': 6, 'CD1': 7},
|
| 69 |
+
'K': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD': 6, 'CE': 7, 'NZ': 8},
|
| 70 |
+
'L': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD1': 6, 'CD2': 7},
|
| 71 |
+
'M': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'SD': 6, 'CE': 7},
|
| 72 |
+
'N': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'OD1': 6, 'ND2': 7},
|
| 73 |
+
'P': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD': 6},
|
| 74 |
+
'Q': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD': 6, 'OE1': 7, 'NE2': 8},
|
| 75 |
+
'R': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD': 6, 'NE': 7, 'CZ': 8, 'NH1': 9, 'NH2': 10},
|
| 76 |
+
'S': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'OG': 5},
|
| 77 |
+
'T': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'OG1': 5, 'CG2': 6},
|
| 78 |
+
'V': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG1': 5, 'CG2': 6},
|
| 79 |
+
'W': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD1': 6, 'CD2': 7, 'NE1': 8, 'CE2': 9, 'CE3': 10, 'CZ2': 11, 'CZ3': 12, 'CH2': 13},
|
| 80 |
+
'Y': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD1': 6, 'CD2': 7, 'CE1': 8, 'CE2': 9, 'CZ': 10, 'OH': 11},
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
# ------------------------------------------------------------------------------
|
| 84 |
+
# NERF
|
| 85 |
+
# ------------------------------------------------------------------------------
|
| 86 |
+
|
| 87 |
+
# indicates whether atom exists
|
| 88 |
+
aa_atom_mask = {
|
| 89 |
+
'A': [True, True, True, True, True, False, False, False, False, False, False, False, False, False],
|
| 90 |
+
'C': [True, True, True, True, True, True, False, False, False, False, False, False, False, False],
|
| 91 |
+
'D': [True, True, True, True, True, True, True, True, False, False, False, False, False, False],
|
| 92 |
+
'E': [True, True, True, True, True, True, True, True, True, False, False, False, False, False],
|
| 93 |
+
'F': [True, True, True, True, True, True, True, True, True, True, True, False, False, False],
|
| 94 |
+
'G': [True, True, True, True, False, False, False, False, False, False, False, False, False, False],
|
| 95 |
+
'H': [True, True, True, True, True, True, True, True, True, True, False, False, False, False],
|
| 96 |
+
'I': [True, True, True, True, True, True, True, True, False, False, False, False, False, False],
|
| 97 |
+
'K': [True, True, True, True, True, True, True, True, True, False, False, False, False, False],
|
| 98 |
+
'L': [True, True, True, True, True, True, True, True, False, False, False, False, False, False],
|
| 99 |
+
'M': [True, True, True, True, True, True, True, True, False, False, False, False, False, False],
|
| 100 |
+
'N': [True, True, True, True, True, True, True, True, False, False, False, False, False, False],
|
| 101 |
+
'P': [True, True, True, True, True, True, True, False, False, False, False, False, False, False],
|
| 102 |
+
'Q': [True, True, True, True, True, True, True, True, True, False, False, False, False, False],
|
| 103 |
+
'R': [True, True, True, True, True, True, True, True, True, True, True, False, False, False],
|
| 104 |
+
'S': [True, True, True, True, True, True, False, False, False, False, False, False, False, False],
|
| 105 |
+
'T': [True, True, True, True, True, True, True, False, False, False, False, False, False, False],
|
| 106 |
+
'V': [True, True, True, True, True, True, True, False, False, False, False, False, False, False],
|
| 107 |
+
'W': [True, True, True, True, True, True, True, True, True, True, True, True, True, True],
|
| 108 |
+
'Y': [True, True, True, True, True, True, True, True, True, True, True, True, False, False],
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
# (14, 3) index tensor with atom indices of atoms a, b and c for NERF reconstruction
|
| 112 |
+
# in principle, columns 1 and 2 can be inferred from column one (immediate predecessor) alone
|
| 113 |
+
aa_nerf_indices = {
|
| 114 |
+
'A': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
|
| 115 |
+
'C': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
|
| 116 |
+
'D': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [5, 4, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
|
| 117 |
+
'E': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [6, 5, 4], [6, 5, 4], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
|
| 118 |
+
'F': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [5, 4, 1], [6, 5, 4], [7, 5, 4], [8, 6, 5], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
|
| 119 |
+
'G': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
|
| 120 |
+
'H': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [5, 4, 1], [6, 5, 4], [7, 5, 4], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
|
| 121 |
+
'I': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [4, 1, 0], [5, 4, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
|
| 122 |
+
'K': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [6, 5, 4], [7, 6, 5], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
|
| 123 |
+
'L': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [5, 4, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
|
| 124 |
+
'M': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [6, 5, 4], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
|
| 125 |
+
'N': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [5, 4, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
|
| 126 |
+
'P': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
|
| 127 |
+
'Q': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [6, 5, 4], [6, 5, 4], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
|
| 128 |
+
'R': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [6, 5, 4], [7, 6, 5], [8, 7, 6], [8, 7, 6], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
|
| 129 |
+
'S': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
|
| 130 |
+
'T': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [4, 1, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
|
| 131 |
+
'V': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [4, 1, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
|
| 132 |
+
'W': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [5, 4, 1], [6, 5, 4], [7, 5, 4], [7, 5, 4], [9, 7, 5], [10, 7, 5], [11, 9, 7]],
|
| 133 |
+
'Y': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [5, 4, 1], [6, 5, 4], [7, 5, 4], [8, 6, 5], [10, 8, 6], [0, 0, 0], [0, 0, 0]],
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
# unique id for each rotatable bond (0=chi1, 1=chi, ...)
|
| 137 |
+
aa_bond_to_chi = {
|
| 138 |
+
'A': {},
|
| 139 |
+
'C': {('CA', 'CB'): 0},
|
| 140 |
+
'D': {('CA', 'CB'): 0, ('CB', 'CG'): 1},
|
| 141 |
+
'E': {('CA', 'CB'): 0, ('CB', 'CG'): 1, ('CG', 'CD'): 2},
|
| 142 |
+
'F': {('CA', 'CB'): 0, ('CB', 'CG'): 1},
|
| 143 |
+
'G': {},
|
| 144 |
+
'H': {('CA', 'CB'): 0, ('CB', 'CG'): 1},
|
| 145 |
+
'I': {('CA', 'CB'): 0, ('CB', 'CG2'): 1},
|
| 146 |
+
'K': {('CA', 'CB'): 0, ('CB', 'CG'): 1, ('CG', 'CD'): 2, ('CD', 'CE'): 3},
|
| 147 |
+
'L': {('CA', 'CB'): 0, ('CB', 'CG'): 1},
|
| 148 |
+
'M': {('CA', 'CB'): 0, ('CB', 'CG'): 1, ('CG', 'SD'): 2},
|
| 149 |
+
'N': {('CA', 'CB'): 0, ('CB', 'CG'): 1},
|
| 150 |
+
'P': {},
|
| 151 |
+
'Q': {('CA', 'CB'): 0, ('CB', 'CG'): 1, ('CG', 'CD'): 2},
|
| 152 |
+
'R': {('CA', 'CB'): 0, ('CB', 'CG'): 1, ('CG', 'CD'): 2, ('CD', 'NE'): 3, ('NE', 'CZ'): 4},
|
| 153 |
+
'S': {('CA', 'CB'): 0},
|
| 154 |
+
'T': {('CA', 'CB'): 0},
|
| 155 |
+
'V': {('CA', 'CB'): 0},
|
| 156 |
+
'W': {('CA', 'CB'): 0, ('CB', 'CG'): 1},
|
| 157 |
+
'Y': {('CA', 'CB'): 0, ('CB', 'CG'): 1},
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
# index between 0 and 4 to retrieve chi angles, -1 means not a rotatable bond
|
| 161 |
+
aa_chi_indices = {
|
| 162 |
+
'A': [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
| 163 |
+
'C': [-1, -1, -1, -1, -1, 0, -1, -1, -1, -1, -1, -1, -1, -1],
|
| 164 |
+
'D': [-1, -1, -1, -1, -1, 0, 1, 1, -1, -1, -1, -1, -1, -1],
|
| 165 |
+
'E': [-1, -1, -1, -1, -1, 0, 1, 2, 2, -1, -1, -1, -1, -1],
|
| 166 |
+
'F': [-1, -1, -1, -1, -1, 0, 1, 1, -1, -1, -1, -1, -1, -1],
|
| 167 |
+
'G': [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
| 168 |
+
'H': [-1, -1, -1, -1, -1, 0, 1, 1, -1, -1, -1, -1, -1, -1],
|
| 169 |
+
'I': [-1, -1, -1, -1, -1, 0, 0, 1, -1, -1, -1, -1, -1, -1],
|
| 170 |
+
'K': [-1, -1, -1, -1, -1, 0, 1, 2, 3, -1, -1, -1, -1, -1],
|
| 171 |
+
'L': [-1, -1, -1, -1, -1, 0, 1, 1, -1, -1, -1, -1, -1, -1],
|
| 172 |
+
'M': [-1, -1, -1, -1, -1, 0, 1, 2, -1, -1, -1, -1, -1, -1],
|
| 173 |
+
'N': [-1, -1, -1, -1, -1, 0, 1, 1, -1, -1, -1, -1, -1, -1],
|
| 174 |
+
'P': [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
| 175 |
+
'Q': [-1, -1, -1, -1, -1, 0, 1, 2, 2, -1, -1, -1, -1, -1],
|
| 176 |
+
'R': [-1, -1, -1, -1, -1, 0, 1, 2, 3, 4, 4, -1, -1, -1],
|
| 177 |
+
'S': [-1, -1, -1, -1, -1, 0, -1, -1, -1, -1, -1, -1, -1, -1],
|
| 178 |
+
'T': [-1, -1, -1, -1, -1, 0, 0, -1, -1, -1, -1, -1, -1, -1],
|
| 179 |
+
'V': [-1, -1, -1, -1, -1, 0, 0, -1, -1, -1, -1, -1, -1, -1],
|
| 180 |
+
'W': [-1, -1, -1, -1, -1, 0, 1, 1, -1, -1, -1, -1, -1, -1],
|
| 181 |
+
'Y': [-1, -1, -1, -1, -1, 0, 1, 1, -1, -1, -1, -1, -1, -1],
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
# key: chi index (0=chi1, 1=chi, ...); value: index of atom that defines the chi angle (together with its three predecessors)
|
| 185 |
+
aa_chi_anchor_atom = {
|
| 186 |
+
'A': {},
|
| 187 |
+
'C': {0: 5},
|
| 188 |
+
'D': {0: 5, 1: 6},
|
| 189 |
+
'E': {0: 5, 1: 6, 2: 7},
|
| 190 |
+
'F': {0: 5, 1: 6},
|
| 191 |
+
'G': {},
|
| 192 |
+
'H': {0: 5, 1: 6},
|
| 193 |
+
'I': {0: 5, 1: 7},
|
| 194 |
+
'K': {0: 5, 1: 6, 2: 7, 3: 8},
|
| 195 |
+
'L': {0: 5, 1: 6},
|
| 196 |
+
'M': {0: 5, 1: 6, 2: 7},
|
| 197 |
+
'N': {0: 5, 1: 6},
|
| 198 |
+
'P': {},
|
| 199 |
+
'Q': {0: 5, 1: 6, 2: 7},
|
| 200 |
+
'R': {0: 5, 1: 6, 2: 7, 3: 8, 4: 9},
|
| 201 |
+
'S': {0: 5},
|
| 202 |
+
'T': {0: 5},
|
| 203 |
+
'V': {0: 5},
|
| 204 |
+
'W': {0: 5, 1: 6},
|
| 205 |
+
'Y': {0: 5, 1: 6},
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
# ------------------------------------------------------------------------------
|
| 209 |
+
# Visualization
|
| 210 |
+
# ------------------------------------------------------------------------------
|
| 211 |
+
# PyMOL colors, see: https://pymolwiki.org/index.php/Color_Values#Chemical_element_colours
|
| 212 |
+
colors_dic = ['#33ff33', '#3333ff', '#ff4d4d', '#e6c540', '#ffb5b5', '#A62929', '#1FF01F', '#ff8000', '#940094', '#B3FFFF', '#b3e3f5']
|
| 213 |
+
radius_dic = [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3]
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
# ------------------------------------------------------------------------------
|
| 217 |
+
# Backbone geometry
|
| 218 |
+
# Taken from: Bhagavan, N. V., and C. E. Ha.
|
| 219 |
+
# "Chapter 4-Three-dimensional structure of proteins and disorders of protein misfolding."
|
| 220 |
+
# Essentials of Medical Biochemistry (2015): 31-51.
|
| 221 |
+
# https://www.sciencedirect.com/science/article/pii/B978012416687500004X
|
| 222 |
+
# ------------------------------------------------------------------------------
|
| 223 |
+
N_CA_DIST = 1.47
|
| 224 |
+
CA_C_DIST = 1.53
|
| 225 |
+
N_CA_C_ANGLE = 110 * np.pi / 180
|
| 226 |
+
|
| 227 |
+
# ------------------------------------------------------------------------------
|
| 228 |
+
# Atom radii
|
| 229 |
+
# ------------------------------------------------------------------------------
|
| 230 |
+
# # https://en.wikipedia.org/wiki/Covalent_radius#Radii_for_multiple_bonds
|
| 231 |
+
# # (2023/04/14)
|
| 232 |
+
# covalent_radii = {'H': [32, None, None],
|
| 233 |
+
# 'C': [75, 67, 60],
|
| 234 |
+
# 'N': [71, 60, 54],
|
| 235 |
+
# 'O': [63, 57, 53],
|
| 236 |
+
# 'F': [64, 59, 53],
|
| 237 |
+
# 'B': [85, 78, 73],
|
| 238 |
+
# 'Al': [126, 113, 111],
|
| 239 |
+
# 'Si': [116, 107, 102],
|
| 240 |
+
# 'P': [111, 102, 94],
|
| 241 |
+
# 'S': [103, 94, 95],
|
| 242 |
+
# 'Cl': [99, 95, 93],
|
| 243 |
+
# 'As': [121, 114, 106],
|
| 244 |
+
# 'Br': [114, 109, 110],
|
| 245 |
+
# 'I': [133, 129, 125],
|
| 246 |
+
# 'Hg': [133, 142, None],
|
| 247 |
+
# 'Bi': [151, 141, 135]}
|
| 248 |
+
|
| 249 |
+
# source: https://en.wikipedia.org/wiki/Van_der_Waals_radius
|
| 250 |
+
vdw_radii = {'N': 1.55, 'O': 1.52, 'C': 1.70, 'H': 1.10, 'S': 1.80, 'P': 1.80,
|
| 251 |
+
'Se': 1.90, 'K': 2.75, 'Na': 2.27, 'Mg': 1.73, 'Zn': 1.39, 'B': 1.92,
|
| 252 |
+
'Br': 1.85, 'Cl': 1.75, 'I': 1.98, 'F': 1.47}
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
WEBDATASET_SHARD_SIZE = 50000
|
| 256 |
+
WEBDATASET_VAL_SIZE = 100
|
src/data/data_utils.py
ADDED
|
@@ -0,0 +1,901 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
from itertools import accumulate, chain
|
| 3 |
+
from copy import deepcopy
|
| 4 |
+
import random
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import numpy as np
|
| 8 |
+
from rdkit import Chem
|
| 9 |
+
from torch_scatter import scatter_mean
|
| 10 |
+
from Bio.PDB import StructureBuilder, Chain, Model, Structure
|
| 11 |
+
from Bio.PDB.PICIO import read_PIC, write_PIC
|
| 12 |
+
from scipy.ndimage import gaussian_filter
|
| 13 |
+
from pdb import set_trace
|
| 14 |
+
|
| 15 |
+
from src.constants import FLOAT_TYPE, INT_TYPE
|
| 16 |
+
from src.constants import atom_encoder, bond_encoder, aa_encoder, residue_encoder, residue_bond_encoder, aa_atom_index
|
| 17 |
+
from src import utils
|
| 18 |
+
from src.data.misc import protein_letters_3to1, is_aa
|
| 19 |
+
from src.data.normal_modes import pdb_to_normal_modes
|
| 20 |
+
from src.data.nerf import get_nerf_params, ic_to_coords
|
| 21 |
+
import src.data.so3_utils as so3
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class TensorDict(dict):
|
| 25 |
+
def __init__(self, **kwargs):
|
| 26 |
+
super(TensorDict, self).__init__(**kwargs)
|
| 27 |
+
|
| 28 |
+
def _apply(self, func: str, *args, **kwargs):
|
| 29 |
+
""" Apply function to all tensors. """
|
| 30 |
+
for k, v in self.items():
|
| 31 |
+
if torch.is_tensor(v):
|
| 32 |
+
self[k] = getattr(v, func)(*args, **kwargs)
|
| 33 |
+
return self
|
| 34 |
+
|
| 35 |
+
# def to(self, device):
|
| 36 |
+
# for k, v in self.items():
|
| 37 |
+
# if torch.is_tensor(v):
|
| 38 |
+
# self[k] = v.to(device)
|
| 39 |
+
# return self
|
| 40 |
+
|
| 41 |
+
def cuda(self):
|
| 42 |
+
return self.to('cuda')
|
| 43 |
+
|
| 44 |
+
def cpu(self):
|
| 45 |
+
return self.to('cpu')
|
| 46 |
+
|
| 47 |
+
def to(self, device):
|
| 48 |
+
return self._apply("to", device)
|
| 49 |
+
|
| 50 |
+
def detach(self):
|
| 51 |
+
return self._apply("detach")
|
| 52 |
+
|
| 53 |
+
def __repr__(self):
|
| 54 |
+
def val_to_str(val):
|
| 55 |
+
if isinstance(val, torch.Tensor):
|
| 56 |
+
# if val.isnan().any():
|
| 57 |
+
# return "(!nan)"
|
| 58 |
+
return "%r" % list(val.size())
|
| 59 |
+
if isinstance(val, list):
|
| 60 |
+
return "[%r,]" % len(val)
|
| 61 |
+
else:
|
| 62 |
+
return "?"
|
| 63 |
+
|
| 64 |
+
return f"{type(self).__name__}({', '.join(f'{k}={val_to_str(v)}' for k, v in self.items())})"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def collate_entity(batch):
|
| 68 |
+
|
| 69 |
+
out = {}
|
| 70 |
+
for prop in batch[0].keys():
|
| 71 |
+
|
| 72 |
+
if prop == 'name':
|
| 73 |
+
out[prop] = [x[prop] for x in batch]
|
| 74 |
+
|
| 75 |
+
elif prop == 'size' or prop == 'n_bonds':
|
| 76 |
+
out[prop] = torch.tensor([x[prop] for x in batch])
|
| 77 |
+
|
| 78 |
+
elif prop == 'bonds':
|
| 79 |
+
# index offset
|
| 80 |
+
offset = list(accumulate([x['size'] for x in batch], initial=0))
|
| 81 |
+
out[prop] = torch.cat([x[prop] + offset[i] for i, x in enumerate(batch)], dim=1)
|
| 82 |
+
|
| 83 |
+
elif prop == 'residues':
|
| 84 |
+
out[prop] = list(chain.from_iterable(x[prop] for x in batch))
|
| 85 |
+
|
| 86 |
+
elif prop in {'mask', 'bond_mask'}:
|
| 87 |
+
pass # batch masks will be written later
|
| 88 |
+
|
| 89 |
+
else:
|
| 90 |
+
out[prop] = torch.cat([x[prop] for x in batch], dim=0)
|
| 91 |
+
|
| 92 |
+
# Create batch masks
|
| 93 |
+
# make sure indices in batch start at zero (needed for torch_scatter)
|
| 94 |
+
if prop == 'x':
|
| 95 |
+
out['mask'] = torch.cat([i * torch.ones(len(x[prop]), dtype=torch.int64, device=x[prop].device)
|
| 96 |
+
for i, x in enumerate(batch)], dim=0)
|
| 97 |
+
if prop == 'bond_one_hot':
|
| 98 |
+
# TODO: this is not necessary as it can be computed on-the-fly as bond_mask = mask[bonds[0]] or bond_mask = mask[bonds[1]]
|
| 99 |
+
out['bond_mask'] = torch.cat([i * torch.ones(len(x[prop]), dtype=torch.int64, device=x[prop].device)
|
| 100 |
+
for i, x in enumerate(batch)], dim=0)
|
| 101 |
+
|
| 102 |
+
return out
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def split_entity(
|
| 106 |
+
batch,
|
| 107 |
+
*,
|
| 108 |
+
index_types={'bonds'},
|
| 109 |
+
edge_types={'bond_one_hot', 'bond_mask'},
|
| 110 |
+
no_split={'name', 'size', 'n_bonds'},
|
| 111 |
+
skip={'fragments'},
|
| 112 |
+
batch_mask=None,
|
| 113 |
+
edge_mask=None
|
| 114 |
+
):
|
| 115 |
+
""" Splits a batch into items and returns a list. """
|
| 116 |
+
|
| 117 |
+
batch_mask = batch["mask"] if batch_mask is None else batch_mask
|
| 118 |
+
edge_mask = batch["bond_mask"] if edge_mask is None else edge_mask
|
| 119 |
+
sizes = batch['size'] if 'size' in batch else torch.unique(batch_mask, return_counts=True)[1].tolist()
|
| 120 |
+
|
| 121 |
+
batch_size = len(torch.unique(batch['mask']))
|
| 122 |
+
out = {}
|
| 123 |
+
for prop in batch.keys():
|
| 124 |
+
if prop in skip:
|
| 125 |
+
continue
|
| 126 |
+
if prop in no_split:
|
| 127 |
+
out[prop] = batch[prop] # already a list
|
| 128 |
+
|
| 129 |
+
elif prop in index_types:
|
| 130 |
+
offsets = list(accumulate(sizes[:-1], initial=0))
|
| 131 |
+
out[prop] = utils.batch_to_list_for_indices(batch[prop], edge_mask, offsets)
|
| 132 |
+
|
| 133 |
+
elif prop in edge_types:
|
| 134 |
+
out[prop] = utils.batch_to_list(batch[prop], edge_mask)
|
| 135 |
+
|
| 136 |
+
else:
|
| 137 |
+
out[prop] = utils.batch_to_list(batch[prop], batch_mask)
|
| 138 |
+
|
| 139 |
+
out = [{k: v[i] for k, v in out.items()} for i in range(batch_size)]
|
| 140 |
+
return out
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def repeat_items(batch, repeats):
|
| 144 |
+
batch_list = split_entity(batch)
|
| 145 |
+
out = collate_entity([x for _ in range(repeats) for x in batch_list])
|
| 146 |
+
return type(batch)(**out)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def get_side_chain_bead_coord(biopython_residue):
|
| 150 |
+
"""
|
| 151 |
+
Places side chain bead at the location of the farthest side chain atom.
|
| 152 |
+
"""
|
| 153 |
+
if biopython_residue.get_resname() == 'GLY':
|
| 154 |
+
return None
|
| 155 |
+
if biopython_residue.get_resname() == 'ALA':
|
| 156 |
+
return biopython_residue['CB'].get_coord()
|
| 157 |
+
|
| 158 |
+
ca_coord = biopython_residue['CA'].get_coord()
|
| 159 |
+
side_chain_atoms = [a for a in biopython_residue.get_atoms() if
|
| 160 |
+
a.id not in {'N', 'CA', 'C', 'O'} and a.element != 'H']
|
| 161 |
+
side_chain_coords = np.stack([a.get_coord() for a in side_chain_atoms])
|
| 162 |
+
|
| 163 |
+
atom_idx = np.argmax(np.sum((side_chain_coords - ca_coord[None, :]) ** 2, axis=-1))
|
| 164 |
+
|
| 165 |
+
return side_chain_coords[atom_idx, :]
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def get_side_chain_vectors(res, index_dict, size=None):
|
| 169 |
+
if size is None:
|
| 170 |
+
size = max([x for aa in index_dict.values() for x in aa.values()]) + 1
|
| 171 |
+
|
| 172 |
+
resname = protein_letters_3to1[res.get_resname()]
|
| 173 |
+
|
| 174 |
+
out = np.zeros((size, 3))
|
| 175 |
+
for atom in res.get_atoms():
|
| 176 |
+
if atom.get_name() in index_dict[resname]:
|
| 177 |
+
idx = index_dict[resname][atom.get_name()]
|
| 178 |
+
out[idx] = atom.get_coord() - res['CA'].get_coord()
|
| 179 |
+
# else:
|
| 180 |
+
# if atom.get_name() != 'CA' and not atom.get_name().startswith('H'):
|
| 181 |
+
# print(resname, atom.get_name())
|
| 182 |
+
|
| 183 |
+
return out
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def get_normal_modes(res, normal_mode_dict):
|
| 187 |
+
nm = normal_mode_dict[(res.get_parent().id, res.id[1], 'CA')] # (n_modes, 3)
|
| 188 |
+
return nm
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def get_torsion_angles(res, device=None):
|
| 192 |
+
"""
|
| 193 |
+
Return the five chi angles. Missing angles are filled with zeros.
|
| 194 |
+
"""
|
| 195 |
+
ANGLES = ['chi1', 'chi2', 'chi3', 'chi4', 'chi5']
|
| 196 |
+
|
| 197 |
+
ic_res = res.internal_coord
|
| 198 |
+
chi_angles = [ic_res.get_angle(chi) for chi in ANGLES]
|
| 199 |
+
chi_angles = [chi if chi is not None else float('nan') for chi in chi_angles]
|
| 200 |
+
|
| 201 |
+
return torch.tensor(chi_angles, device=device) * np.pi / 180
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def apply_torsion_angles(res, chi_angles):
|
| 205 |
+
"""
|
| 206 |
+
Set side chain torsion angles of a biopython residue object with
|
| 207 |
+
internal coordinates.
|
| 208 |
+
"""
|
| 209 |
+
ANGLES = ['chi1', 'chi2', 'chi3', 'chi4', 'chi5']
|
| 210 |
+
|
| 211 |
+
chi_angles = chi_angles * 180 / np.pi
|
| 212 |
+
|
| 213 |
+
# res.parent.internal_coord.build_atomArray() # rebuild atom pointers
|
| 214 |
+
|
| 215 |
+
ic_res = res.internal_coord
|
| 216 |
+
for chi, angle in zip(ANGLES, chi_angles):
|
| 217 |
+
if ic_res.pick_angle(chi) is None:
|
| 218 |
+
continue
|
| 219 |
+
ic_res.bond_set(chi, angle)
|
| 220 |
+
|
| 221 |
+
res.parent.internal_to_atom_coordinates(verbose=False)
|
| 222 |
+
# res.parent.internal_coord.init_atom_coords()
|
| 223 |
+
# res.internal_coord.assemble()
|
| 224 |
+
|
| 225 |
+
return res
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def prepare_internal_coord(res):
|
| 229 |
+
|
| 230 |
+
# Make new structure with a single residue
|
| 231 |
+
new_struct = Structure.Structure('X')
|
| 232 |
+
new_struct.header = {}
|
| 233 |
+
new_model = Model.Model(0)
|
| 234 |
+
new_struct.add(new_model)
|
| 235 |
+
new_chain = Chain.Chain('X')
|
| 236 |
+
new_model.add(new_chain)
|
| 237 |
+
new_chain.add(res)
|
| 238 |
+
res.set_parent(new_chain) # update pointer
|
| 239 |
+
|
| 240 |
+
# Compute internal coordinates
|
| 241 |
+
new_chain.atom_to_internal_coordinates()
|
| 242 |
+
|
| 243 |
+
pic_io = io.StringIO()
|
| 244 |
+
write_PIC(new_struct, pic_io)
|
| 245 |
+
return pic_io.getvalue()
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def residue_from_internal_coord(ic_string):
|
| 249 |
+
pic_io = io.StringIO(ic_string)
|
| 250 |
+
struct = read_PIC(pic_io, quick=True)
|
| 251 |
+
res = struct.child_list[0].child_list[0].child_list[0]
|
| 252 |
+
res.parent.internal_to_atom_coordinates(verbose=False)
|
| 253 |
+
return res
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def prepare_pocket(biopython_residues, amino_acid_encoder, residue_encoder,
|
| 257 |
+
residue_bond_encoder, pocket_representation='side_chain_bead',
|
| 258 |
+
compute_nerf_params=False, compute_bb_frames=False,
|
| 259 |
+
nma_input=None):
|
| 260 |
+
|
| 261 |
+
assert nma_input is None or pocket_representation == 'CA+', \
|
| 262 |
+
"vector features are only supported for CA+ pockets"
|
| 263 |
+
|
| 264 |
+
# sort residues
|
| 265 |
+
biopython_residues = sorted(biopython_residues, key=lambda x: (x.parent.id, x.id[1]))
|
| 266 |
+
|
| 267 |
+
if nma_input is not None:
|
| 268 |
+
# preprocessed normal mode eigenvectors
|
| 269 |
+
if isinstance(nma_input, dict):
|
| 270 |
+
nma_dict = nma_input
|
| 271 |
+
|
| 272 |
+
# PDB file
|
| 273 |
+
else:
|
| 274 |
+
nma_dict = pdb_to_normal_modes(str(nma_input))
|
| 275 |
+
|
| 276 |
+
if pocket_representation == 'side_chain_bead':
|
| 277 |
+
ca_coords = np.zeros((len(biopython_residues), 3))
|
| 278 |
+
ca_types = np.zeros(len(biopython_residues), dtype='int64')
|
| 279 |
+
side_chain_coords = []
|
| 280 |
+
side_chain_aa_types = []
|
| 281 |
+
edges = [] # CA-CA and CA-side_chain
|
| 282 |
+
edge_types = []
|
| 283 |
+
last_res_id = None
|
| 284 |
+
for i, res in enumerate(biopython_residues):
|
| 285 |
+
aa = amino_acid_encoder[protein_letters_3to1[res.get_resname()]]
|
| 286 |
+
ca_coords[i, :] = res['CA'].get_coord()
|
| 287 |
+
ca_types[i] = aa
|
| 288 |
+
side_chain_coord = get_side_chain_bead_coord(res)
|
| 289 |
+
if side_chain_coord is not None:
|
| 290 |
+
side_chain_coords.append(side_chain_coord)
|
| 291 |
+
side_chain_aa_types.append(aa)
|
| 292 |
+
edges.append((i, len(ca_coords) + len(side_chain_coords) - 1))
|
| 293 |
+
edge_types.append(residue_bond_encoder['CA-SS'])
|
| 294 |
+
|
| 295 |
+
# add edges between contiguous CA atoms
|
| 296 |
+
if i > 0 and res.id[1] == last_res_id + 1:
|
| 297 |
+
edges.append((i - 1, i))
|
| 298 |
+
edge_types.append(residue_bond_encoder['CA-CA'])
|
| 299 |
+
|
| 300 |
+
last_res_id = res.id[1]
|
| 301 |
+
|
| 302 |
+
# Coordinates
|
| 303 |
+
side_chain_coords = np.stack(side_chain_coords)
|
| 304 |
+
pocket_coords = np.concatenate([ca_coords, side_chain_coords], axis=0)
|
| 305 |
+
pocket_coords = torch.from_numpy(pocket_coords)
|
| 306 |
+
|
| 307 |
+
# Features
|
| 308 |
+
amino_acid_onehot = F.one_hot(
|
| 309 |
+
torch.cat([torch.from_numpy(ca_types), torch.tensor(side_chain_aa_types, dtype=torch.int64)], dim=0),
|
| 310 |
+
num_classes=len(amino_acid_encoder)
|
| 311 |
+
)
|
| 312 |
+
side_chain_onehot = np.concatenate([
|
| 313 |
+
np.tile(np.eye(1, len(residue_encoder), residue_encoder['CA']),
|
| 314 |
+
[len(ca_coords), 1]),
|
| 315 |
+
np.tile(np.eye(1, len(residue_encoder), residue_encoder['SS']),
|
| 316 |
+
[len(side_chain_coords), 1])
|
| 317 |
+
], axis=0)
|
| 318 |
+
side_chain_onehot = torch.from_numpy(side_chain_onehot)
|
| 319 |
+
pocket_onehot = torch.cat([amino_acid_onehot, side_chain_onehot], dim=1)
|
| 320 |
+
|
| 321 |
+
vector_features = None
|
| 322 |
+
nma_features = None
|
| 323 |
+
|
| 324 |
+
# Bonds
|
| 325 |
+
edges = torch.tensor(edges).T
|
| 326 |
+
edge_types = F.one_hot(torch.tensor(edge_types), num_classes=len(residue_bond_encoder))
|
| 327 |
+
|
| 328 |
+
elif pocket_representation == 'CA+':
|
| 329 |
+
ca_coords = np.zeros((len(biopython_residues), 3))
|
| 330 |
+
ca_types = np.zeros(len(biopython_residues), dtype='int64')
|
| 331 |
+
|
| 332 |
+
v_dim = max([x for aa in aa_atom_index.values() for x in aa.values()]) + 1
|
| 333 |
+
vec_feats = np.zeros((len(biopython_residues), v_dim, 3), dtype='float32')
|
| 334 |
+
nf_nma = 5
|
| 335 |
+
nma_feats = np.zeros((len(biopython_residues), nf_nma, 3), dtype='float32')
|
| 336 |
+
|
| 337 |
+
edges = [] # CA-CA and CA-side_chain
|
| 338 |
+
edge_types = []
|
| 339 |
+
last_res_id = None
|
| 340 |
+
for i, res in enumerate(biopython_residues):
|
| 341 |
+
aa = amino_acid_encoder[protein_letters_3to1[res.get_resname()]]
|
| 342 |
+
ca_coords[i, :] = res['CA'].get_coord()
|
| 343 |
+
ca_types[i] = aa
|
| 344 |
+
|
| 345 |
+
vec_feats[i] = get_side_chain_vectors(res, aa_atom_index, v_dim)
|
| 346 |
+
if nma_input is not None:
|
| 347 |
+
nma_feats[i] = get_normal_modes(res, nma_dict)
|
| 348 |
+
|
| 349 |
+
# add edges between contiguous CA atoms
|
| 350 |
+
if i > 0 and res.id[1] == last_res_id + 1:
|
| 351 |
+
edges.append((i - 1, i))
|
| 352 |
+
edge_types.append(residue_bond_encoder['CA-CA'])
|
| 353 |
+
|
| 354 |
+
last_res_id = res.id[1]
|
| 355 |
+
|
| 356 |
+
# Coordinates
|
| 357 |
+
pocket_coords = torch.from_numpy(ca_coords)
|
| 358 |
+
|
| 359 |
+
# Features
|
| 360 |
+
pocket_onehot = F.one_hot(torch.from_numpy(ca_types),
|
| 361 |
+
num_classes=len(amino_acid_encoder))
|
| 362 |
+
|
| 363 |
+
vector_features = torch.from_numpy(vec_feats)
|
| 364 |
+
nma_features = torch.from_numpy(nma_feats)
|
| 365 |
+
|
| 366 |
+
# Bonds
|
| 367 |
+
if len(edges) < 1:
|
| 368 |
+
edges = torch.empty(2, 0)
|
| 369 |
+
edge_types = torch.empty(0, len(residue_bond_encoder))
|
| 370 |
+
else:
|
| 371 |
+
edges = torch.tensor(edges).T
|
| 372 |
+
edge_types = F.one_hot(torch.tensor(edge_types),
|
| 373 |
+
num_classes=len(residue_bond_encoder))
|
| 374 |
+
|
| 375 |
+
else:
|
| 376 |
+
raise NotImplementedError(
|
| 377 |
+
f"Pocket representation '{pocket_representation}' not implemented")
|
| 378 |
+
|
| 379 |
+
# pocket_ids = [f'{res.parent.id}:{res.id[1]}' for res in biopython_residues]
|
| 380 |
+
|
| 381 |
+
pocket = {
|
| 382 |
+
'x': pocket_coords.to(dtype=FLOAT_TYPE),
|
| 383 |
+
'one_hot': pocket_onehot.to(dtype=FLOAT_TYPE),
|
| 384 |
+
# 'ids': pocket_ids,
|
| 385 |
+
'size': torch.tensor([len(pocket_coords)], dtype=INT_TYPE),
|
| 386 |
+
'mask': torch.zeros(len(pocket_coords), dtype=INT_TYPE),
|
| 387 |
+
'bonds': edges.to(INT_TYPE),
|
| 388 |
+
'bond_one_hot': edge_types.to(FLOAT_TYPE),
|
| 389 |
+
'bond_mask': torch.zeros(edges.size(1), dtype=INT_TYPE),
|
| 390 |
+
'n_bonds': torch.tensor([len(edge_types)], dtype=INT_TYPE),
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
if vector_features is not None:
|
| 394 |
+
pocket['v'] = vector_features.to(dtype=FLOAT_TYPE)
|
| 395 |
+
|
| 396 |
+
if nma_input is not None:
|
| 397 |
+
pocket['nma_vec'] = nma_features.to(dtype=FLOAT_TYPE)
|
| 398 |
+
|
| 399 |
+
if compute_nerf_params:
|
| 400 |
+
nerf_params = [get_nerf_params(r) for r in biopython_residues]
|
| 401 |
+
nerf_params = {k: torch.stack([x[k] for x in nerf_params], dim=0)
|
| 402 |
+
for k in nerf_params[0].keys()}
|
| 403 |
+
pocket.update(nerf_params)
|
| 404 |
+
|
| 405 |
+
if compute_bb_frames:
|
| 406 |
+
n_xyz = torch.from_numpy(np.stack([r['N'].get_coord() for r in biopython_residues]))
|
| 407 |
+
ca_xyz = torch.from_numpy(np.stack([r['CA'].get_coord() for r in biopython_residues]))
|
| 408 |
+
c_xyz = torch.from_numpy(np.stack([r['C'].get_coord() for r in biopython_residues]))
|
| 409 |
+
pocket['axis_angle'], _ = get_bb_transform(n_xyz, ca_xyz, c_xyz)
|
| 410 |
+
|
| 411 |
+
return pocket, biopython_residues
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def encode_atom(rd_atom, atom_encoder):
|
| 415 |
+
element = rd_atom.GetSymbol().capitalize()
|
| 416 |
+
|
| 417 |
+
explicitHs = rd_atom.GetNumExplicitHs()
|
| 418 |
+
if explicitHs == 1 and f'{element}H' in atom_encoder:
|
| 419 |
+
return atom_encoder[f'{element}H']
|
| 420 |
+
|
| 421 |
+
charge = rd_atom.GetFormalCharge()
|
| 422 |
+
if charge == 1 and f'{element}+' in atom_encoder:
|
| 423 |
+
return atom_encoder[f'{element}+']
|
| 424 |
+
if charge == -1 and f'{element}-' in atom_encoder:
|
| 425 |
+
return atom_encoder[f'{element}-']
|
| 426 |
+
|
| 427 |
+
return atom_encoder[element]
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def prepare_ligand(rdmol, atom_encoder, bond_encoder):
|
| 431 |
+
|
| 432 |
+
# remove H atoms if not in atom_encoder
|
| 433 |
+
if 'H' not in atom_encoder:
|
| 434 |
+
rdmol = Chem.RemoveAllHs(rdmol, sanitize=False)
|
| 435 |
+
|
| 436 |
+
# Coordinates
|
| 437 |
+
ligand_coord = rdmol.GetConformer().GetPositions()
|
| 438 |
+
ligand_coord = torch.from_numpy(ligand_coord)
|
| 439 |
+
|
| 440 |
+
# Features
|
| 441 |
+
ligand_onehot = F.one_hot(
|
| 442 |
+
torch.tensor([encode_atom(a, atom_encoder) for a in rdmol.GetAtoms()]),
|
| 443 |
+
num_classes=len(atom_encoder)
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
# Bonds
|
| 447 |
+
adj = np.ones((rdmol.GetNumAtoms(), rdmol.GetNumAtoms())) * bond_encoder['NOBOND']
|
| 448 |
+
for b in rdmol.GetBonds():
|
| 449 |
+
i = b.GetBeginAtomIdx()
|
| 450 |
+
j = b.GetEndAtomIdx()
|
| 451 |
+
adj[i, j] = bond_encoder[str(b.GetBondType())]
|
| 452 |
+
adj[j, i] = adj[i, j] # undirected graph
|
| 453 |
+
|
| 454 |
+
# molecular graph is undirected -> don't save redundant information
|
| 455 |
+
bonds = np.stack(np.triu_indices(len(ligand_coord), k=1), axis=0)
|
| 456 |
+
# bonds = np.stack(np.ones_like(adj).nonzero(), axis=0)
|
| 457 |
+
bond_types = adj[bonds[0], bonds[1]].astype('int64')
|
| 458 |
+
bonds = torch.from_numpy(bonds)
|
| 459 |
+
bond_types = F.one_hot(torch.from_numpy(bond_types), num_classes=len(bond_encoder))
|
| 460 |
+
|
| 461 |
+
ligand = {
|
| 462 |
+
'x': ligand_coord.to(dtype=FLOAT_TYPE),
|
| 463 |
+
'one_hot': ligand_onehot.to(dtype=FLOAT_TYPE),
|
| 464 |
+
'mask': torch.zeros(len(ligand_coord), dtype=INT_TYPE),
|
| 465 |
+
'bonds': bonds.to(INT_TYPE),
|
| 466 |
+
'bond_one_hot': bond_types.to(FLOAT_TYPE),
|
| 467 |
+
'bond_mask': torch.zeros(bonds.size(1), dtype=INT_TYPE),
|
| 468 |
+
'size': torch.tensor([len(ligand_coord)], dtype=INT_TYPE),
|
| 469 |
+
'n_bonds': torch.tensor([len(bond_types)], dtype=INT_TYPE),
|
| 470 |
+
}
|
| 471 |
+
|
| 472 |
+
return ligand
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def process_raw_molecule_with_empty_pocket(rdmol):
|
| 476 |
+
ligand = prepare_ligand(rdmol, atom_encoder, bond_encoder)
|
| 477 |
+
pocket = {
|
| 478 |
+
'x': torch.tensor([], dtype=FLOAT_TYPE),
|
| 479 |
+
'one_hot': torch.tensor([], dtype=FLOAT_TYPE),
|
| 480 |
+
'size': torch.tensor([], dtype=INT_TYPE),
|
| 481 |
+
'mask': torch.tensor([], dtype=INT_TYPE),
|
| 482 |
+
'bonds': torch.tensor([], dtype=INT_TYPE),
|
| 483 |
+
'bond_one_hot': torch.tensor([], dtype=FLOAT_TYPE),
|
| 484 |
+
'bond_mask': torch.tensor([], dtype=INT_TYPE),
|
| 485 |
+
'n_bonds': torch.tensor([], dtype=INT_TYPE),
|
| 486 |
+
}
|
| 487 |
+
return ligand, pocket
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
def process_raw_pair(biopython_model, rdmol, dist_cutoff=None,
|
| 491 |
+
pocket_representation='side_chain_bead',
|
| 492 |
+
compute_nerf_params=False, compute_bb_frames=False,
|
| 493 |
+
nma_input=None, return_pocket_pdb=False):
|
| 494 |
+
|
| 495 |
+
# Process ligand
|
| 496 |
+
ligand = prepare_ligand(rdmol, atom_encoder, bond_encoder)
|
| 497 |
+
|
| 498 |
+
# Find interacting pocket residues based on distance cutoff
|
| 499 |
+
pocket_residues = []
|
| 500 |
+
for residue in biopython_model.get_residues():
|
| 501 |
+
|
| 502 |
+
# Remove non-standard amino acids and HETATMs
|
| 503 |
+
if not is_aa(residue.get_resname(), standard=True):
|
| 504 |
+
continue
|
| 505 |
+
|
| 506 |
+
res_coords = torch.from_numpy(np.array([a.get_coord() for a in residue.get_atoms()]))
|
| 507 |
+
if dist_cutoff is None or (((res_coords[:, None, :] - ligand['x'][None, :, :]) ** 2).sum(-1) ** 0.5).min() < dist_cutoff:
|
| 508 |
+
pocket_residues.append(residue)
|
| 509 |
+
|
| 510 |
+
pocket, pocket_residues = prepare_pocket(
|
| 511 |
+
pocket_residues, aa_encoder, residue_encoder, residue_bond_encoder,
|
| 512 |
+
pocket_representation, compute_nerf_params, compute_bb_frames, nma_input
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
if return_pocket_pdb:
|
| 516 |
+
builder = StructureBuilder.StructureBuilder()
|
| 517 |
+
builder.init_structure("")
|
| 518 |
+
builder.init_model(0)
|
| 519 |
+
pocket_struct = builder.get_structure()
|
| 520 |
+
for residue in pocket_residues:
|
| 521 |
+
chain = residue.get_parent().get_id()
|
| 522 |
+
|
| 523 |
+
# init chain if necessary
|
| 524 |
+
if not pocket_struct[0].has_id(chain):
|
| 525 |
+
builder.init_chain(chain)
|
| 526 |
+
|
| 527 |
+
# add residue
|
| 528 |
+
pocket_struct[0][chain].add(residue)
|
| 529 |
+
|
| 530 |
+
pocket['pocket_pdb'] = pocket_struct
|
| 531 |
+
# if return_pocket_pdb:
|
| 532 |
+
# pocket['residues'] = [prepare_internal_coord(res) for res in pocket_residues]
|
| 533 |
+
|
| 534 |
+
return ligand, pocket
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
class AppendVirtualNodes:
|
| 538 |
+
def __init__(self, atom_encoder, bond_encoder, max_ligand_size, scale=1.0):
|
| 539 |
+
self.max_size = max_ligand_size
|
| 540 |
+
self.atom_encoder = atom_encoder
|
| 541 |
+
self.bond_encoder = bond_encoder
|
| 542 |
+
self.vidx = atom_encoder['NOATOM']
|
| 543 |
+
self.bidx = bond_encoder['NOBOND']
|
| 544 |
+
self.scale = scale
|
| 545 |
+
|
| 546 |
+
def __call__(self, ligand, max_size=None, eps=1e-6):
|
| 547 |
+
if max_size is None:
|
| 548 |
+
max_size = self.max_size
|
| 549 |
+
|
| 550 |
+
n_virt = max_size - ligand['size']
|
| 551 |
+
|
| 552 |
+
C = torch.cov(ligand['x'].T)
|
| 553 |
+
L = torch.linalg.cholesky(C + torch.eye(3) * eps)
|
| 554 |
+
mu = ligand['x'].mean(0, keepdim=True)
|
| 555 |
+
virt_coords = mu + torch.randn(n_virt, 3) @ L.T * self.scale
|
| 556 |
+
|
| 557 |
+
# insert virtual atom column
|
| 558 |
+
virt_one_hot = F.one_hot(torch.ones(n_virt, dtype=torch.int64) * self.vidx, num_classes=len(self.atom_encoder))
|
| 559 |
+
virt_mask = torch.cat([torch.zeros(ligand['size'], dtype=bool), torch.ones(n_virt, dtype=bool)])
|
| 560 |
+
|
| 561 |
+
ligand['x'] = torch.cat([ligand['x'], virt_coords])
|
| 562 |
+
ligand['one_hot'] = torch.cat(([ligand['one_hot'], virt_one_hot]))
|
| 563 |
+
ligand['virtual_mask'] = virt_mask
|
| 564 |
+
ligand['size'] = max_size
|
| 565 |
+
|
| 566 |
+
# Bonds
|
| 567 |
+
new_bonds = torch.triu_indices(max_size, max_size, offset=1)
|
| 568 |
+
|
| 569 |
+
bond_types = torch.ones(max_size, max_size, dtype=INT_TYPE) * self.bidx
|
| 570 |
+
row, col = ligand['bonds']
|
| 571 |
+
bond_types[row, col] = ligand['bond_one_hot'].argmax(dim=1)
|
| 572 |
+
new_row, new_col = new_bonds
|
| 573 |
+
bond_types = bond_types[new_row, new_col]
|
| 574 |
+
|
| 575 |
+
ligand['bonds'] = new_bonds
|
| 576 |
+
ligand['bond_one_hot'] = F.one_hot(bond_types, num_classes=len(self.bond_encoder)).to(ligand['bond_one_hot'].dtype)
|
| 577 |
+
ligand['n_bonds'] = len(ligand['bond_one_hot'])
|
| 578 |
+
|
| 579 |
+
return ligand
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
class AppendVirtualNodesInCoM:
|
| 583 |
+
def __init__(self, atom_encoder, bond_encoder, add_min=0, add_max=10):
|
| 584 |
+
self.atom_encoder = atom_encoder
|
| 585 |
+
self.bond_encoder = bond_encoder
|
| 586 |
+
self.vidx = atom_encoder['NOATOM']
|
| 587 |
+
self.bidx = bond_encoder['NOBOND']
|
| 588 |
+
self.add_min = add_min
|
| 589 |
+
self.add_max = add_max
|
| 590 |
+
|
| 591 |
+
def __call__(self, ligand):
|
| 592 |
+
|
| 593 |
+
n_virt = random.randint(self.add_min, self.add_max)
|
| 594 |
+
|
| 595 |
+
# all virtual coordinates in the CoM
|
| 596 |
+
virt_coords = ligand['x'].mean(0, keepdim=True).repeat(n_virt, 1)
|
| 597 |
+
|
| 598 |
+
# insert virtual atom column
|
| 599 |
+
virt_one_hot = F.one_hot(torch.ones(n_virt, dtype=torch.int64) * self.vidx, num_classes=len(self.atom_encoder))
|
| 600 |
+
virt_mask = torch.cat([torch.zeros(ligand['size'], dtype=bool), torch.ones(n_virt, dtype=bool)])
|
| 601 |
+
|
| 602 |
+
ligand['x'] = torch.cat([ligand['x'], virt_coords])
|
| 603 |
+
ligand['one_hot'] = torch.cat(([ligand['one_hot'], virt_one_hot]))
|
| 604 |
+
ligand['virtual_mask'] = virt_mask
|
| 605 |
+
ligand['size'] = len(ligand['x'])
|
| 606 |
+
|
| 607 |
+
# Bonds
|
| 608 |
+
new_bonds = torch.triu_indices(ligand['size'], ligand['size'], offset=1)
|
| 609 |
+
|
| 610 |
+
bond_types = torch.ones(ligand['size'], ligand['size'], dtype=INT_TYPE) * self.bidx
|
| 611 |
+
row, col = ligand['bonds']
|
| 612 |
+
bond_types[row, col] = ligand['bond_one_hot'].argmax(dim=1)
|
| 613 |
+
new_row, new_col = new_bonds
|
| 614 |
+
bond_types = bond_types[new_row, new_col]
|
| 615 |
+
|
| 616 |
+
ligand['bonds'] = new_bonds
|
| 617 |
+
ligand['bond_one_hot'] = F.one_hot(bond_types, num_classes=len(self.bond_encoder)).to(ligand['bond_one_hot'].dtype)
|
| 618 |
+
ligand['n_bonds'] = len(ligand['bond_one_hot'])
|
| 619 |
+
|
| 620 |
+
return ligand
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
def rdmol_to_smiles(rdmol):
|
| 624 |
+
mol = Chem.Mol(rdmol)
|
| 625 |
+
Chem.RemoveStereochemistry(mol)
|
| 626 |
+
mol = Chem.RemoveHs(mol)
|
| 627 |
+
return Chem.MolToSmiles(mol)
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
def get_n_nodes(lig_positions, pocket_positions, smooth_sigma=None):
|
| 631 |
+
# Joint distribution of ligand's and pocket's number of nodes
|
| 632 |
+
n_nodes_lig = [len(x) for x in lig_positions]
|
| 633 |
+
n_nodes_pocket = [len(x) for x in pocket_positions]
|
| 634 |
+
|
| 635 |
+
joint_histogram = np.zeros((np.max(n_nodes_lig) + 1,
|
| 636 |
+
np.max(n_nodes_pocket) + 1))
|
| 637 |
+
|
| 638 |
+
for nlig, npocket in zip(n_nodes_lig, n_nodes_pocket):
|
| 639 |
+
joint_histogram[nlig, npocket] += 1
|
| 640 |
+
|
| 641 |
+
print(f'Original histogram: {np.count_nonzero(joint_histogram)}/'
|
| 642 |
+
f'{joint_histogram.shape[0] * joint_histogram.shape[1]} bins filled')
|
| 643 |
+
|
| 644 |
+
# Smooth the histogram
|
| 645 |
+
if smooth_sigma is not None:
|
| 646 |
+
filtered_histogram = gaussian_filter(
|
| 647 |
+
joint_histogram, sigma=smooth_sigma, order=0, mode='constant',
|
| 648 |
+
cval=0.0, truncate=4.0)
|
| 649 |
+
|
| 650 |
+
print(f'Smoothed histogram: {np.count_nonzero(filtered_histogram)}/'
|
| 651 |
+
f'{filtered_histogram.shape[0] * filtered_histogram.shape[1]} bins filled')
|
| 652 |
+
|
| 653 |
+
joint_histogram = filtered_histogram
|
| 654 |
+
|
| 655 |
+
return joint_histogram
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
# def get_type_histograms(lig_one_hot, pocket_one_hot, lig_encoder, pocket_encoder):
|
| 659 |
+
#
|
| 660 |
+
# lig_one_hot = np.concatenate(lig_one_hot, axis=0)
|
| 661 |
+
# pocket_one_hot = np.concatenate(pocket_one_hot, axis=0)
|
| 662 |
+
#
|
| 663 |
+
# atom_decoder = list(lig_encoder.keys())
|
| 664 |
+
# lig_counts = {k: 0 for k in lig_encoder.keys()}
|
| 665 |
+
# for a in [atom_decoder[x] for x in lig_one_hot.argmax(1)]:
|
| 666 |
+
# lig_counts[a] += 1
|
| 667 |
+
#
|
| 668 |
+
# aa_decoder = list(pocket_encoder.keys())
|
| 669 |
+
# pocket_counts = {k: 0 for k in pocket_encoder.keys()}
|
| 670 |
+
# for r in [aa_decoder[x] for x in pocket_one_hot.argmax(1)]:
|
| 671 |
+
# pocket_counts[r] += 1
|
| 672 |
+
#
|
| 673 |
+
# return lig_counts, pocket_counts
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
def get_type_histogram(one_hot, type_encoder):
|
| 677 |
+
|
| 678 |
+
one_hot = np.concatenate(one_hot, axis=0)
|
| 679 |
+
|
| 680 |
+
decoder = list(type_encoder.keys())
|
| 681 |
+
counts = {k: 0 for k in type_encoder.keys()}
|
| 682 |
+
for a in [decoder[x] for x in one_hot.argmax(1)]:
|
| 683 |
+
counts[a] += 1
|
| 684 |
+
|
| 685 |
+
return counts
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
def get_residue_with_resi(pdb_chain, resi):
|
| 689 |
+
res = [x for x in pdb_chain.get_residues() if x.id[1] == resi]
|
| 690 |
+
assert len(res) == 1
|
| 691 |
+
return res[0]
|
| 692 |
+
|
| 693 |
+
|
| 694 |
+
def get_pocket_from_ligand(pdb_model, ligand, dist_cutoff=8.0):
|
| 695 |
+
|
| 696 |
+
if ligand.endswith(".sdf"):
|
| 697 |
+
# ligand as sdf file
|
| 698 |
+
rdmol = Chem.SDMolSupplier(str(ligand))[0]
|
| 699 |
+
ligand_coords = torch.from_numpy(rdmol.GetConformer().GetPositions()).float()
|
| 700 |
+
resi = None
|
| 701 |
+
else:
|
| 702 |
+
# ligand contained in PDB; given in <chain>:<resi> format
|
| 703 |
+
chain, resi = ligand.split(':')
|
| 704 |
+
ligand = get_residue_with_resi(pdb_model[chain], int(resi))
|
| 705 |
+
ligand_coords = torch.from_numpy(
|
| 706 |
+
np.array([a.get_coord() for a in ligand.get_atoms()]))
|
| 707 |
+
|
| 708 |
+
pocket_residues = []
|
| 709 |
+
for residue in pdb_model.get_residues():
|
| 710 |
+
if residue.id[1] == resi:
|
| 711 |
+
continue # skip ligand itself
|
| 712 |
+
|
| 713 |
+
res_coords = torch.from_numpy(
|
| 714 |
+
np.array([a.get_coord() for a in residue.get_atoms()]))
|
| 715 |
+
if is_aa(residue.get_resname(), standard=True) \
|
| 716 |
+
and torch.cdist(res_coords, ligand_coords).min() < dist_cutoff:
|
| 717 |
+
pocket_residues.append(residue)
|
| 718 |
+
|
| 719 |
+
return pocket_residues
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
def encode_residues(biopython_residues, type_encoder, level='atom',
|
| 723 |
+
remove_H=True):
|
| 724 |
+
assert level in {'atom', 'residue'}
|
| 725 |
+
|
| 726 |
+
if level == 'atom':
|
| 727 |
+
entities = [a for res in biopython_residues for a in res.get_atoms()
|
| 728 |
+
if (a.element != 'H' or not remove_H)]
|
| 729 |
+
types = [a.element.capitalize() for a in entities]
|
| 730 |
+
else:
|
| 731 |
+
entities = [res['CA'] for res in biopython_residues]
|
| 732 |
+
types = [protein_letters_3to1[res.get_resname()] for res in biopython_residues]
|
| 733 |
+
|
| 734 |
+
coord = torch.tensor(np.stack([e.get_coord() for e in entities]))
|
| 735 |
+
one_hot = F.one_hot(torch.tensor([type_encoder[t] for t in types]),
|
| 736 |
+
num_classes=len(type_encoder))
|
| 737 |
+
|
| 738 |
+
return coord, one_hot
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
def center_data(ligand, pocket):
|
| 742 |
+
if pocket['x'].numel() > 0:
|
| 743 |
+
pocket_com = pocket.center()
|
| 744 |
+
else:
|
| 745 |
+
pocket_com = scatter_mean(ligand['x'], ligand['mask'], dim=0)
|
| 746 |
+
|
| 747 |
+
ligand['x'] = ligand['x'] - pocket_com[ligand['mask']]
|
| 748 |
+
return ligand, pocket
|
| 749 |
+
|
| 750 |
+
|
| 751 |
+
def get_bb_transform(n_xyz, ca_xyz, c_xyz):
|
| 752 |
+
"""
|
| 753 |
+
Compute translation and rotation of the canoncical backbone frame (triangle N-Ca-C) from a position with
|
| 754 |
+
Ca at the origin, N on the x-axis and C in the xy-plane to the global position of the backbone frame
|
| 755 |
+
|
| 756 |
+
Args:
|
| 757 |
+
n_xyz: (n, 3)
|
| 758 |
+
ca_xyz: (n, 3)
|
| 759 |
+
c_xyz: (n, 3)
|
| 760 |
+
|
| 761 |
+
Returns:
|
| 762 |
+
axis-angle representation of the rotation, shape (n, 3) # rotation matrix of shape (n, 3, 3)
|
| 763 |
+
translation vector of shape (n, 3)
|
| 764 |
+
"""
|
| 765 |
+
|
| 766 |
+
def rotation_matrix(angle, axis):
|
| 767 |
+
axis_mapping = {'x': 0, 'y': 1, 'z': 2}
|
| 768 |
+
axis = axis_mapping[axis]
|
| 769 |
+
vector = torch.zeros(len(angle), 3)
|
| 770 |
+
vector[:, axis] = 1
|
| 771 |
+
# return axis_angle_to_matrix(angle * vector)
|
| 772 |
+
return so3.matrix_from_rotation_vector(angle.view(-1, 1) * vector)
|
| 773 |
+
|
| 774 |
+
translation = ca_xyz
|
| 775 |
+
n_xyz = n_xyz - translation
|
| 776 |
+
c_xyz = c_xyz - translation
|
| 777 |
+
|
| 778 |
+
# Find rotation matrix that aligns the coordinate systems
|
| 779 |
+
|
| 780 |
+
# rotate around y-axis to move N into the xy-plane
|
| 781 |
+
theta_y = torch.arctan2(n_xyz[:, 2], -n_xyz[:, 0])
|
| 782 |
+
Ry = rotation_matrix(theta_y, 'y')
|
| 783 |
+
Ry = Ry.transpose(2, 1)
|
| 784 |
+
n_xyz = torch.einsum('noi,ni->no', Ry, n_xyz)
|
| 785 |
+
|
| 786 |
+
# rotate around z-axis to move N onto the x-axis
|
| 787 |
+
theta_z = torch.arctan2(n_xyz[:, 1], n_xyz[:, 0])
|
| 788 |
+
Rz = rotation_matrix(theta_z, 'z')
|
| 789 |
+
Rz = Rz.transpose(2, 1)
|
| 790 |
+
# print(torch.einsum('noi,ni->no', Rz, n_xyz))
|
| 791 |
+
|
| 792 |
+
# n_xyz = torch.einsum('noi,ni->no', Rz.transpose(0, 2, 1), n_xyz)
|
| 793 |
+
|
| 794 |
+
# rotate around x-axis to move C into the xy-plane
|
| 795 |
+
c_xyz = torch.einsum('noj,nji,ni->no', Rz, Ry, c_xyz)
|
| 796 |
+
theta_x = torch.arctan2(c_xyz[:, 2], c_xyz[:, 1])
|
| 797 |
+
Rx = rotation_matrix(theta_x, 'x')
|
| 798 |
+
Rx = Rx.transpose(2, 1)
|
| 799 |
+
# print(torch.einsum('noi,ni->no', Rx, c_xyz))
|
| 800 |
+
|
| 801 |
+
# Final rotation matrix
|
| 802 |
+
Ry = Ry.transpose(2, 1)
|
| 803 |
+
Rz = Rz.transpose(2, 1)
|
| 804 |
+
Rx = Rx.transpose(2, 1)
|
| 805 |
+
R = torch.einsum('nok,nkj,nji->noi', Ry, Rz, Rx)
|
| 806 |
+
|
| 807 |
+
# return R, translation
|
| 808 |
+
# return matrix_to_axis_angle(R), translation
|
| 809 |
+
return so3.rotation_vector_from_matrix(R), translation
|
| 810 |
+
|
| 811 |
+
|
| 812 |
+
class Residues(TensorDict):
|
| 813 |
+
"""
|
| 814 |
+
Dictionary-like container for residues that supports some basic transformations.
|
| 815 |
+
"""
|
| 816 |
+
|
| 817 |
+
# all keys
|
| 818 |
+
KEYS = {'x', 'one_hot', 'bonds', 'bond_one_hot', 'v', 'nma_vec', 'fixed_coord',
|
| 819 |
+
'atom_mask', 'nerf_indices', 'length', 'theta', 'chi', 'ddihedral',
|
| 820 |
+
'chi_indices', 'axis_angle', 'mask', 'bond_mask'}
|
| 821 |
+
|
| 822 |
+
# coordinate-type values, shape (..., 3)
|
| 823 |
+
COORD_KEYS = {'x', 'fixed_coord'}
|
| 824 |
+
|
| 825 |
+
# vector-type values, shape (n_residues, n_feat, 3)
|
| 826 |
+
VECTOR_KEYS = {'v', 'nma_vec'}
|
| 827 |
+
|
| 828 |
+
# properties that change if the side chains and/or backbones are updated
|
| 829 |
+
MUTABLE_PROPS_SS_AND_BB = {'v'}
|
| 830 |
+
|
| 831 |
+
# properties that only change if the side chains are updated
|
| 832 |
+
MUTABLE_PROPS_SS = {'chi'}
|
| 833 |
+
|
| 834 |
+
# properties that only change if the backbones are updated
|
| 835 |
+
MUTABLE_PROPS_BB = {'x', 'fixed_coord', 'axis_angle', 'nma_vec'}
|
| 836 |
+
|
| 837 |
+
# properties that remain fixed in all cases
|
| 838 |
+
IMMUTABLE_PROPS = {'mask', 'one_hot', 'bonds', 'bond_one_hot', 'bond_mask',
|
| 839 |
+
'atom_mask', 'nerf_indices', 'length', 'theta',
|
| 840 |
+
'ddihedral', 'chi_indices', 'name', 'size', 'n_bonds'}
|
| 841 |
+
|
| 842 |
+
def copy(self):
|
| 843 |
+
data = super().copy()
|
| 844 |
+
return Residues(**data)
|
| 845 |
+
|
| 846 |
+
def deepcopy(self):
|
| 847 |
+
data = {k: v.clone() if torch.is_tensor(v) else deepcopy(v)
|
| 848 |
+
for k, v in self.items()}
|
| 849 |
+
return Residues(**data)
|
| 850 |
+
|
| 851 |
+
def center(self):
|
| 852 |
+
com = scatter_mean(self['x'], self['mask'], dim=0)
|
| 853 |
+
self['x'] = self['x'] - com[self['mask']]
|
| 854 |
+
self['fixed_coord'] = self['fixed_coord'] - com[self['mask']].unsqueeze(1)
|
| 855 |
+
return com
|
| 856 |
+
|
| 857 |
+
def set_empty_v(self):
|
| 858 |
+
self['v'] = torch.tensor([], device=self['x'].device)
|
| 859 |
+
|
| 860 |
+
@torch.no_grad()
|
| 861 |
+
def set_chi(self, chi_angles):
|
| 862 |
+
self['chi'][:, :5] = chi_angles
|
| 863 |
+
nerf_params = {k: self[k] for k in ['fixed_coord', 'atom_mask',
|
| 864 |
+
'nerf_indices', 'length', 'theta',
|
| 865 |
+
'chi', 'ddihedral', 'chi_indices']}
|
| 866 |
+
self['v'] = ic_to_coords(**nerf_params) - self['x'].unsqueeze(1)
|
| 867 |
+
|
| 868 |
+
@torch.no_grad()
|
| 869 |
+
def set_frame(self, new_ca_coord, new_axis_angle):
|
| 870 |
+
bb_coord = self['fixed_coord']
|
| 871 |
+
bb_coord = bb_coord - self['x'].unsqueeze(1)
|
| 872 |
+
rotmat_before = so3.matrix_from_rotation_vector(self['axis_angle'])
|
| 873 |
+
rotmat_after = so3.matrix_from_rotation_vector(new_axis_angle)
|
| 874 |
+
rotmat_diff = rotmat_after @ rotmat_before.transpose(-1, -2)
|
| 875 |
+
bb_coord = torch.einsum('boi,bai->bao', rotmat_diff, bb_coord)
|
| 876 |
+
bb_coord = bb_coord + new_ca_coord.unsqueeze(1)
|
| 877 |
+
|
| 878 |
+
self['x'] = new_ca_coord
|
| 879 |
+
self['axis_angle'] = new_axis_angle
|
| 880 |
+
self['fixed_coord'] = bb_coord
|
| 881 |
+
self['v'] = torch.einsum('boi,bai->bao', rotmat_diff, self['v'])
|
| 882 |
+
|
| 883 |
+
@staticmethod
|
| 884 |
+
def empty(device):
|
| 885 |
+
return Residues(
|
| 886 |
+
x=torch.zeros(1, 3, device=device).float(),
|
| 887 |
+
mask=torch.zeros(1, 1, device=device).long(),
|
| 888 |
+
size=torch.zeros(1, device=device).long(),
|
| 889 |
+
)
|
| 890 |
+
|
| 891 |
+
|
| 892 |
+
def randomize_tensors(tensor_dict, exclude_keys=None):
|
| 893 |
+
"""Replace tensors with random tensors with the same shape."""
|
| 894 |
+
exclude_keys = set() if exclude_keys is None else set(exclude_keys)
|
| 895 |
+
for k, v in tensor_dict.items():
|
| 896 |
+
if isinstance(v, torch.Tensor) and k not in exclude_keys:
|
| 897 |
+
if torch.is_floating_point(v):
|
| 898 |
+
tensor_dict[k] = torch.randn_like(v)
|
| 899 |
+
else:
|
| 900 |
+
tensor_dict[k] = torch.randint_like(v, low=-42, high=42)
|
| 901 |
+
return tensor_dict
|
src/data/dataset.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import random
|
| 3 |
+
import warnings
|
| 4 |
+
import torch
|
| 5 |
+
import webdataset as wds
|
| 6 |
+
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from torch.utils.data import Dataset
|
| 9 |
+
|
| 10 |
+
from src.data.data_utils import TensorDict, collate_entity
|
| 11 |
+
from src.constants import WEBDATASET_SHARD_SIZE, WEBDATASET_VAL_SIZE
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ProcessedLigandPocketDataset(Dataset):
|
| 15 |
+
def __init__(self, pt_path, ligand_transform=None, pocket_transform=None,
|
| 16 |
+
catch_errors=False):
|
| 17 |
+
|
| 18 |
+
self.ligand_transform = ligand_transform
|
| 19 |
+
self.pocket_transform = pocket_transform
|
| 20 |
+
self.catch_errors = catch_errors
|
| 21 |
+
self.pt_path = pt_path
|
| 22 |
+
|
| 23 |
+
self.data = torch.load(pt_path)
|
| 24 |
+
|
| 25 |
+
# add number of nodes for convenience
|
| 26 |
+
for entity in ['ligands', 'pockets']:
|
| 27 |
+
self.data[entity]['size'] = torch.tensor([len(x) for x in self.data[entity]['x']])
|
| 28 |
+
self.data[entity]['n_bonds'] = torch.tensor([len(x) for x in self.data[entity]['bond_one_hot']])
|
| 29 |
+
|
| 30 |
+
def __len__(self):
|
| 31 |
+
return len(self.data['ligands']['name'])
|
| 32 |
+
|
| 33 |
+
def __getitem__(self, idx):
|
| 34 |
+
data = {}
|
| 35 |
+
data['ligand'] = {key: val[idx] for key, val in self.data['ligands'].items()}
|
| 36 |
+
data['pocket'] = {key: val[idx] for key, val in self.data['pockets'].items()}
|
| 37 |
+
try:
|
| 38 |
+
if self.ligand_transform is not None:
|
| 39 |
+
data['ligand'] = self.ligand_transform(data['ligand'])
|
| 40 |
+
if self.pocket_transform is not None:
|
| 41 |
+
data['pocket'] = self.pocket_transform(data['pocket'])
|
| 42 |
+
except (RuntimeError, ValueError) as e:
|
| 43 |
+
if self.catch_errors:
|
| 44 |
+
warnings.warn(f"{type(e).__name__}('{e}') in data transform. "
|
| 45 |
+
f"Returning random item instead")
|
| 46 |
+
# replace bad item with a random one
|
| 47 |
+
rand_idx = random.randint(0, len(self) - 1)
|
| 48 |
+
return self[rand_idx]
|
| 49 |
+
else:
|
| 50 |
+
raise e
|
| 51 |
+
return data
|
| 52 |
+
|
| 53 |
+
@staticmethod
|
| 54 |
+
def collate_fn(batch_pairs, ligand_transform=None):
|
| 55 |
+
|
| 56 |
+
out = {}
|
| 57 |
+
for entity in ['ligand', 'pocket']:
|
| 58 |
+
batch = [x[entity] for x in batch_pairs]
|
| 59 |
+
|
| 60 |
+
if entity == 'ligand' and ligand_transform is not None:
|
| 61 |
+
max_size = max(x['size'].item() for x in batch)
|
| 62 |
+
# TODO: might have to remove elements from batch if processing fails, warn user in that case
|
| 63 |
+
batch = [ligand_transform(x, max_size=max_size) for x in batch]
|
| 64 |
+
|
| 65 |
+
out[entity] = TensorDict(**collate_entity(batch))
|
| 66 |
+
|
| 67 |
+
return out
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class ClusteredDataset(ProcessedLigandPocketDataset):
|
| 71 |
+
def __init__(self, pt_path, ligand_transform=None, pocket_transform=None,
|
| 72 |
+
catch_errors=False):
|
| 73 |
+
super().__init__(pt_path, ligand_transform, pocket_transform, catch_errors)
|
| 74 |
+
self.clusters = list(self.data['clusters'].values())
|
| 75 |
+
|
| 76 |
+
def __len__(self):
|
| 77 |
+
return len(self.clusters)
|
| 78 |
+
|
| 79 |
+
def __getitem__(self, cidx):
|
| 80 |
+
cluster_inds = self.clusters[cidx]
|
| 81 |
+
# idx = cluster_inds[random.randint(0, len(cluster_inds) - 1)]
|
| 82 |
+
idx = random.choice(cluster_inds)
|
| 83 |
+
return super().__getitem__(idx)
|
| 84 |
+
|
| 85 |
+
class DPODataset(ProcessedLigandPocketDataset):
|
| 86 |
+
def __init__(self, pt_path, ligand_transform=None, pocket_transform=None,
|
| 87 |
+
catch_errors=False):
|
| 88 |
+
self.ligand_transform = ligand_transform
|
| 89 |
+
self.pocket_transform = pocket_transform
|
| 90 |
+
self.catch_errors = catch_errors
|
| 91 |
+
self.pt_path = pt_path
|
| 92 |
+
|
| 93 |
+
self.data = torch.load(pt_path)
|
| 94 |
+
|
| 95 |
+
if not 'pockets' in self.data:
|
| 96 |
+
self.data['pockets'] = self.data['pockets_w']
|
| 97 |
+
if not 'ligands' in self.data:
|
| 98 |
+
self.data['ligands'] = self.data['ligands_w']
|
| 99 |
+
|
| 100 |
+
if (
|
| 101 |
+
len(self.data["ligands"]["name"])
|
| 102 |
+
!= len(self.data["ligands_l"]["name"])
|
| 103 |
+
!= len(self.data["pockets"]["name"])
|
| 104 |
+
):
|
| 105 |
+
raise ValueError(
|
| 106 |
+
"Error while importing DPO Dataset: Number of ligands winning, ligands losing and pockets must be the same"
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
# add number of nodes for convenience
|
| 110 |
+
for entity in ['ligands', 'ligands_l', 'pockets']:
|
| 111 |
+
self.data[entity]['size'] = torch.tensor([len(x) for x in self.data[entity]['x']])
|
| 112 |
+
self.data[entity]['n_bonds'] = torch.tensor([len(x) for x in self.data[entity]['bond_one_hot']])
|
| 113 |
+
|
| 114 |
+
def __len__(self):
|
| 115 |
+
return len(self.data["ligands"]["name"])
|
| 116 |
+
|
| 117 |
+
def __getitem__(self, idx):
|
| 118 |
+
data = {}
|
| 119 |
+
data['ligand'] = {key: val[idx] for key, val in self.data['ligands'].items()}
|
| 120 |
+
data['ligand_l'] = {key: val[idx] for key, val in self.data['ligands_l'].items()}
|
| 121 |
+
data['pocket'] = {key: val[idx] for key, val in self.data['pockets'].items()}
|
| 122 |
+
try:
|
| 123 |
+
if self.ligand_transform is not None:
|
| 124 |
+
data['ligand'] = self.ligand_transform(data['ligand'])
|
| 125 |
+
data['ligand_l'] = self.ligand_transform(data['ligand_l'])
|
| 126 |
+
if self.pocket_transform is not None:
|
| 127 |
+
data['pocket'] = self.pocket_transform(data['pocket'])
|
| 128 |
+
except (RuntimeError, ValueError) as e:
|
| 129 |
+
if self.catch_errors:
|
| 130 |
+
warnings.warn(f"{type(e).__name__}('{e}') in data transform. "
|
| 131 |
+
f"Returning random item instead")
|
| 132 |
+
# replace bad item with a random one
|
| 133 |
+
rand_idx = random.randint(0, len(self) - 1)
|
| 134 |
+
return self[rand_idx]
|
| 135 |
+
else:
|
| 136 |
+
raise e
|
| 137 |
+
return data
|
| 138 |
+
|
| 139 |
+
@staticmethod
|
| 140 |
+
def collate_fn(batch_pairs, ligand_transform=None):
|
| 141 |
+
|
| 142 |
+
out = {}
|
| 143 |
+
for entity in ['ligand', 'ligand_l', 'pocket']:
|
| 144 |
+
batch = [x[entity] for x in batch_pairs]
|
| 145 |
+
|
| 146 |
+
if entity in ['ligand', 'ligand_l'] and ligand_transform is not None:
|
| 147 |
+
max_size = max(x['size'].item() for x in batch)
|
| 148 |
+
batch = [ligand_transform(x, max_size=max_size) for x in batch]
|
| 149 |
+
|
| 150 |
+
out[entity] = TensorDict(**collate_entity(batch))
|
| 151 |
+
|
| 152 |
+
return out
|
| 153 |
+
|
| 154 |
+
##########################################
|
| 155 |
+
############### WebDatasets ##############
|
| 156 |
+
##########################################
|
| 157 |
+
|
| 158 |
+
class ProteinLigandWebDataset(wds.WebDataset):
|
| 159 |
+
@staticmethod
|
| 160 |
+
def collate_fn(batch_pairs, ligand_transform=None):
|
| 161 |
+
return ProcessedLigandPocketDataset.collate_fn(batch_pairs, ligand_transform)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def wds_decoder(key, value):
|
| 165 |
+
return torch.load(io.BytesIO(value))
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def preprocess_wds_item(data):
|
| 169 |
+
out = {}
|
| 170 |
+
for entity in ['ligand', 'pocket']:
|
| 171 |
+
out[entity] = data['pt'][entity]
|
| 172 |
+
for attr in ['size', 'n_bonds']:
|
| 173 |
+
if torch.is_tensor(out[entity][attr]):
|
| 174 |
+
assert len(out[entity][attr]) == 0
|
| 175 |
+
out[entity][attr] = 0
|
| 176 |
+
|
| 177 |
+
return out
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def get_wds(data_path, stage, ligand_transform=None, pocket_transform=None):
|
| 181 |
+
current_data_dir = Path(data_path, stage)
|
| 182 |
+
shards = sorted(current_data_dir.glob('shard-?????.tar'), key=lambda s: int(s.name.split('-')[-1].split('.')[0]))
|
| 183 |
+
min_shard = min(shards).name.split('-')[-1].split('.')[0]
|
| 184 |
+
max_shard = max(shards).name.split('-')[-1].split('.')[0]
|
| 185 |
+
total_size = (int(max_shard) - int(min_shard) + 1) * WEBDATASET_SHARD_SIZE if stage == 'train' else WEBDATASET_VAL_SIZE
|
| 186 |
+
|
| 187 |
+
url = f'{data_path}/{stage}/shard-{{{min_shard}..{max_shard}}}.tar'
|
| 188 |
+
ligand_transform_wrapper = lambda _data: _data
|
| 189 |
+
pocket_transform_wrapper = lambda _data: _data
|
| 190 |
+
|
| 191 |
+
if ligand_transform is not None:
|
| 192 |
+
def ligand_transform_wrapper(_data):
|
| 193 |
+
_data['pt']['ligand'] = ligand_transform(_data['pt']['ligand'])
|
| 194 |
+
return _data
|
| 195 |
+
|
| 196 |
+
if pocket_transform is not None:
|
| 197 |
+
def pocket_transform_wrapper(_data):
|
| 198 |
+
_data['pt']['pocket'] = pocket_transform(_data['pt']['pocket'])
|
| 199 |
+
return _data
|
| 200 |
+
|
| 201 |
+
return (
|
| 202 |
+
ProteinLigandWebDataset(url, nodesplitter=wds.split_by_node)
|
| 203 |
+
.decode(wds_decoder)
|
| 204 |
+
.map(ligand_transform_wrapper)
|
| 205 |
+
.map(pocket_transform_wrapper)
|
| 206 |
+
.map(preprocess_wds_item)
|
| 207 |
+
.with_length(total_size)
|
| 208 |
+
)
|
src/data/misc.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# From: https://github.com/biopython/biopython/blob/master/Bio/PDB/Polypeptide.py#L128
|
| 2 |
+
|
| 3 |
+
protein_letters_1to3 = {'A': 'ALA', 'C': 'CYS', 'D': 'ASP', 'E': 'GLU', 'F': 'PHE', 'G': 'GLY', 'H': 'HIS', 'I': 'ILE', 'K': 'LYS', 'L': 'LEU', 'M': 'MET', 'N': 'ASN', 'P': 'PRO', 'Q': 'GLN', 'R': 'ARG', 'S': 'SER', 'T': 'THR', 'V': 'VAL', 'W': 'TRP', 'Y': 'TYR'}
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
protein_letters_3to1 = {'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L', 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R', 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y'}
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
protein_letters_3to1_extended = {'A5N': 'N', 'A8E': 'V', 'A9D': 'S', 'AA3': 'A', 'AA4': 'A', 'AAR': 'R', 'ABA': 'A', 'ACL': 'R', 'AEA': 'C', 'AEI': 'D', 'AFA': 'N', 'AGM': 'R', 'AGQ': 'Y', 'AGT': 'C', 'AHB': 'N', 'AHL': 'R', 'AHO': 'A', 'AHP': 'A', 'AIB': 'A', 'AKL': 'D', 'AKZ': 'D', 'ALA': 'A', 'ALC': 'A', 'ALM': 'A', 'ALN': 'A', 'ALO': 'T', 'ALS': 'A', 'ALT': 'A', 'ALV': 'A', 'ALY': 'K', 'AME': 'M', 'AN6': 'L', 'AN8': 'A', 'API': 'K', 'APK': 'K', 'AR2': 'R', 'AR4': 'E', 'AR7': 'R', 'ARG': 'R', 'ARM': 'R', 'ARO': 'R', 'AS7': 'N', 'ASA': 'D', 'ASB': 'D', 'ASI': 'D', 'ASK': 'D', 'ASL': 'D', 'ASN': 'N', 'ASP': 'D', 'ASQ': 'D', 'AYA': 'A', 'AZH': 'A', 'AZK': 'K', 'AZS': 'S', 'AZY': 'Y', 'AVJ': 'H', 'A30': 'Y', 'A3U': 'F', 'ECC': 'Q', 'ECX': 'C', 'EFC': 'C', 'EHP': 'F', 'ELY': 'K', 'EME': 'E', 'EPM': 'M', 'EPQ': 'Q', 'ESB': 'Y', 'ESC': 'M', 'EXY': 'L', 'EXA': 'K', 'E0Y': 'P', 'E9V': 'H', 'E9M': 'W', 'EJA': 'C', 'EUP': 'T', 'EZY': 'G', 'E9C': 'Y', 'EW6': 'S', 'EXL': 'W', 'I2M': 'I', 'I4G': 'G', 'I58': 'K', 'IAM': 'A', 'IAR': 'R', 'ICY': 'C', 'IEL': 'K', 'IGL': 'G', 'IIL': 'I', 'ILE': 'I', 'ILG': 'E', 'ILM': 'I', 'ILX': 'I', 'ILY': 'K', 'IML': 'I', 'IOR': 'R', 'IPG': 'G', 'IT1': 'K', 'IYR': 'Y', 'IZO': 'M', 'IC0': 'G', 'M0H': 'C', 'M2L': 'K', 'M2S': 'M', 'M30': 'G', 'M3L': 'K', 'M3R': 'K', 'MA ': 'A', 'MAA': 'A', 'MAI': 'R', 'MBQ': 'Y', 'MC1': 'S', 'MCL': 'K', 'MCS': 'C', 'MD3': 'C', 'MD5': 'C', 'MD6': 'G', 'MDF': 'Y', 'ME0': 'M', 'MEA': 'F', 'MEG': 'E', 'MEN': 'N', 'MEQ': 'Q', 'MET': 'M', 'MEU': 'G', 'MFN': 'E', 'MGG': 'R', 'MGN': 'Q', 'MGY': 'G', 'MH1': 'H', 'MH6': 'S', 'MHL': 'L', 'MHO': 'M', 'MHS': 'H', 'MHU': 'F', 'MIR': 'S', 'MIS': 'S', 'MK8': 'L', 'ML3': 'K', 'MLE': 'L', 'MLL': 'L', 'MLY': 'K', 'MLZ': 'K', 'MME': 'M', 'MMO': 'R', 'MNL': 'L', 'MNV': 'V', 'MP8': 'P', 'MPQ': 'G', 'MSA': 'G', 'MSE': 'M', 'MSL': 'M', 'MSO': 'M', 'MT2': 'M', 'MTY': 'Y', 'MVA': 'V', 'MYK': 'K', 'MYN': 'R', 'QCS': 'C', 'QIL': 'I', 'QMM': 'Q', 'QPA': 'C', 'QPH': 'F', 'Q3P': 'K', 'QVA': 'C', 'QX7': 'A', 'Q2E': 'W', 'Q75': 'M', 'Q78': 'F', 'QM8': 'L', 'QMB': 'A', 'QNQ': 'C', 'QNT': 'C', 'QNW': 'C', 'QO2': 'C', 'QO5': 'C', 'QO8': 'C', 'QQ8': 'Q', 'U2X': 'Y', 'U3X': 'F', 'UF0': 'S', 'UGY': 'G', 'UM1': 'A', 'UM2': 'A', 'UMA': 'A', 'UQK': 'A', 'UX8': 'W', 'UXQ': 'F', 'YCM': 'C', 'YOF': 'Y', 'YPR': 'P', 'YPZ': 'Y', 'YTH': 'T', 'Y1V': 'L', 'Y57': 'K', 'YHA': 'K', '200': 'F', '23F': 'F', '23P': 'A', '26B': 'T', '28X': 'T', '2AG': 'A', '2CO': 'C', '2FM': 'M', '2GX': 'F', '2HF': 'H', '2JG': 'S', '2KK': 'K', '2KP': 'K', '2LT': 'Y', '2LU': 'L', '2ML': 'L', '2MR': 'R', '2MT': 'P', '2OR': 'R', '2P0': 'P', '2QZ': 'T', '2R3': 'Y', '2RA': 'A', '2RX': 'S', '2SO': 'H', '2TY': 'Y', '2VA': 'V', '2XA': 'C', '2ZC': 'S', '6CL': 'K', '6CW': 'W', '6GL': 'A', '6HN': 'K', '60F': 'C', '66D': 'I', '6CV': 'A', '6M6': 'C', '6V1': 'C', '6WK': 'C', '6Y9': 'P', '6DN': 'K', 'DA2': 'R', 'DAB': 'A', 'DAH': 'F', 'DBS': 'S', 'DBU': 'T', 'DBY': 'Y', 'DBZ': 'A', 'DC2': 'C', 'DDE': 'H', 'DDZ': 'A', 'DI7': 'Y', 'DHA': 'S', 'DHN': 'V', 'DIR': 'R', 'DLS': 'K', 'DM0': 'K', 'DMH': 'N', 'DMK': 'D', 'DNL': 'K', 'DNP': 'A', 'DNS': 'K', 'DNW': 'A', 'DOH': 'D', 'DON': 'L', 'DP1': 'R', 'DPL': 'P', 'DPP': 'A', 'DPQ': 'Y', 'DYS': 'C', 'D2T': 'D', 'DYA': 'D', 'DJD': 'F', 'DYJ': 'P', 'DV9': 'E', 'H14': 'F', 'H1D': 'M', 'H5M': 'P', 'HAC': 'A', 'HAR': 'R', 'HBN': 'H', 'HCM': 'C', 'HGY': 'G', 'HHI': 'H', 'HIA': 'H', 'HIC': 'H', 'HIP': 'H', 'HIQ': 'H', 'HIS': 'H', 'HL2': 'L', 'HLU': 'L', 'HMR': 'R', 'HNC': 'C', 'HOX': 'F', 'HPC': 'F', 'HPE': 'F', 'HPH': 'F', 'HPQ': 'F', 'HQA': 'A', 'HR7': 'R', 'HRG': 'R', 'HRP': 'W', 'HS8': 'H', 'HS9': 'H', 'HSE': 'S', 'HSK': 'H', 'HSL': 'S', 'HSO': 'H', 'HT7': 'W', 'HTI': 'C', 'HTR': 'W', 'HV5': 'A', 'HVA': 'V', 'HY3': 'P', 'HYI': 'M', 'HYP': 'P', 'HZP': 'P', 'HIX': 'A', 'HSV': 'H', 'HLY': 'K', 'HOO': 'H', 'H7V': 'A', 'L5P': 'K', 'LRK': 'K', 'L3O': 'L', 'LA2': 'K', 'LAA': 'D', 'LAL': 'A', 'LBY': 'K', 'LCK': 'K', 'LCX': 'K', 'LDH': 'K', 'LE1': 'V', 'LED': 'L', 'LEF': 'L', 'LEH': 'L', 'LEM': 'L', 'LEN': 'L', 'LET': 'K', 'LEU': 'L', 'LEX': 'L', 'LGY': 'K', 'LLO': 'K', 'LLP': 'K', 'LLY': 'K', 'LLZ': 'K', 'LME': 'E', 'LMF': 'K', 'LMQ': 'Q', 'LNE': 'L', 'LNM': 'L', 'LP6': 'K', 'LPD': 'P', 'LPG': 'G', 'LPS': 'S', 'LSO': 'K', 'LTR': 'W', 'LVG': 'G', 'LVN': 'V', 'LWY': 'P', 'LYF': 'K', 'LYK': 'K', 'LYM': 'K', 'LYN': 'K', 'LYO': 'K', 'LYP': 'K', 'LYR': 'K', 'LYS': 'K', 'LYU': 'K', 'LYX': 'K', 'LYZ': 'K', 'LAY': 'L', 'LWI': 'F', 'LBZ': 'K', 'P1L': 'C', 'P2Q': 'Y', 'P2Y': 'P', 'P3Q': 'Y', 'PAQ': 'Y', 'PAS': 'D', 'PAT': 'W', 'PBB': 'C', 'PBF': 'F', 'PCA': 'Q', 'PCC': 'P', 'PCS': 'F', 'PE1': 'K', 'PEC': 'C', 'PF5': 'F', 'PFF': 'F', 'PG1': 'S', 'PGY': 'G', 'PHA': 'F', 'PHD': 'D', 'PHE': 'F', 'PHI': 'F', 'PHL': 'F', 'PHM': 'F', 'PKR': 'P', 'PLJ': 'P', 'PM3': 'F', 'POM': 'P', 'PPN': 'F', 'PR3': 'C', 'PR4': 'P', 'PR7': 'P', 'PR9': 'P', 'PRJ': 'P', 'PRK': 'K', 'PRO': 'P', 'PRS': 'P', 'PRV': 'G', 'PSA': 'F', 'PSH': 'H', 'PTH': 'Y', 'PTM': 'Y', 'PTR': 'Y', 'PVH': 'H', 'PXU': 'P', 'PYA': 'A', 'PYH': 'K', 'PYX': 'C', 'PH6': 'P', 'P9S': 'C', 'P5U': 'S', 'POK': 'R', 'T0I': 'Y', 'T11': 'F', 'TAV': 'D', 'TBG': 'V', 'TBM': 'T', 'TCQ': 'Y', 'TCR': 'W', 'TEF': 'F', 'TFQ': 'F', 'TH5': 'T', 'TH6': 'T', 'THC': 'T', 'THR': 'T', 'THZ': 'R', 'TIH': 'A', 'TIS': 'S', 'TLY': 'K', 'TMB': 'T', 'TMD': 'T', 'TNB': 'C', 'TNR': 'S', 'TNY': 'T', 'TOQ': 'W', 'TOX': 'W', 'TPJ': 'P', 'TPK': 'P', 'TPL': 'W', 'TPO': 'T', 'TPQ': 'Y', 'TQI': 'W', 'TQQ': 'W', 'TQZ': 'C', 'TRF': 'W', 'TRG': 'K', 'TRN': 'W', 'TRO': 'W', 'TRP': 'W', 'TRQ': 'W', 'TRW': 'W', 'TRX': 'W', 'TRY': 'W', 'TS9': 'I', 'TSY': 'C', 'TTQ': 'W', 'TTS': 'Y', 'TXY': 'Y', 'TY1': 'Y', 'TY2': 'Y', 'TY3': 'Y', 'TY5': 'Y', 'TY8': 'Y', 'TY9': 'Y', 'TYB': 'Y', 'TYC': 'Y', 'TYE': 'Y', 'TYI': 'Y', 'TYJ': 'Y', 'TYN': 'Y', 'TYO': 'Y', 'TYQ': 'Y', 'TYR': 'Y', 'TYS': 'Y', 'TYT': 'Y', 'TYW': 'Y', 'TYY': 'Y', 'T8L': 'T', 'T9E': 'T', 'TNQ': 'W', 'TSQ': 'F', 'TGH': 'W', 'X2W': 'E', 'XCN': 'C', 'XPR': 'P', 'XSN': 'N', 'XW1': 'A', 'XX1': 'K', 'XYC': 'A', 'XA6': 'F', '11Q': 'P', '11W': 'E', '12L': 'P', '12X': 'P', '12Y': 'P', '143': 'C', '1AC': 'A', '1L1': 'A', '1OP': 'Y', '1PA': 'F', '1PI': 'A', '1TQ': 'W', '1TY': 'Y', '1X6': 'S', '56A': 'H', '5AB': 'A', '5CS': 'C', '5CW': 'W', '5HP': 'E', '5OH': 'A', '5PG': 'G', '51T': 'Y', '54C': 'W', '5CR': 'F', '5CT': 'K', '5FQ': 'A', '5GM': 'I', '5JP': 'S', '5T3': 'K', '5MW': 'K', '5OW': 'K', '5R5': 'S', '5VV': 'N', '5XU': 'A', '55I': 'F', '999': 'D', '9DN': 'N', '9NE': 'E', '9NF': 'F', '9NR': 'R', '9NV': 'V', '9E7': 'K', '9KP': 'K', '9WV': 'A', '9TR': 'K', '9TU': 'K', '9TX': 'K', '9U0': 'K', '9IJ': 'F', 'B1F': 'F', 'B27': 'T', 'B2A': 'A', 'B2F': 'F', 'B2I': 'I', 'B2V': 'V', 'B3A': 'A', 'B3D': 'D', 'B3E': 'E', 'B3K': 'K', 'B3U': 'H', 'B3X': 'N', 'B3Y': 'Y', 'BB6': 'C', 'BB7': 'C', 'BB8': 'F', 'BB9': 'C', 'BBC': 'C', 'BCS': 'C', 'BCX': 'C', 'BFD': 'D', 'BG1': 'S', 'BH2': 'D', 'BHD': 'D', 'BIF': 'F', 'BIU': 'I', 'BL2': 'L', 'BLE': 'L', 'BLY': 'K', 'BMT': 'T', 'BNN': 'F', 'BOR': 'R', 'BP5': 'A', 'BPE': 'C', 'BSE': 'S', 'BTA': 'L', 'BTC': 'C', 'BTK': 'K', 'BTR': 'W', 'BUC': 'C', 'BUG': 'V', 'BYR': 'Y', 'BWV': 'R', 'BWB': 'S', 'BXT': 'S', 'F2F': 'F', 'F2Y': 'Y', 'FAK': 'K', 'FB5': 'A', 'FB6': 'A', 'FC0': 'F', 'FCL': 'F', 'FDL': 'K', 'FFM': 'C', 'FGL': 'G', 'FGP': 'S', 'FH7': 'K', 'FHL': 'K', 'FHO': 'K', 'FIO': 'R', 'FLA': 'A', 'FLE': 'L', 'FLT': 'Y', 'FME': 'M', 'FOE': 'C', 'FP9': 'P', 'FPK': 'P', 'FT6': 'W', 'FTR': 'W', 'FTY': 'Y', 'FVA': 'V', 'FZN': 'K', 'FY3': 'Y', 'F7W': 'W', 'FY2': 'Y', 'FQA': 'K', 'F7Q': 'Y', 'FF9': 'K', 'FL6': 'D', 'JJJ': 'C', 'JJK': 'C', 'JJL': 'C', 'JLP': 'K', 'J3D': 'C', 'J9Y': 'R', 'J8W': 'S', 'JKH': 'P', 'N10': 'S', 'N7P': 'P', 'NA8': 'A', 'NAL': 'A', 'NAM': 'A', 'NBQ': 'Y', 'NC1': 'S', 'NCB': 'A', 'NEM': 'H', 'NEP': 'H', 'NFA': 'F', 'NIY': 'Y', 'NLB': 'L', 'NLE': 'L', 'NLN': 'L', 'NLO': 'L', 'NLP': 'L', 'NLQ': 'Q', 'NLY': 'G', 'NMC': 'G', 'NMM': 'R', 'NNH': 'R', 'NOT': 'L', 'NPH': 'C', 'NPI': 'A', 'NTR': 'Y', 'NTY': 'Y', 'NVA': 'V', 'NWD': 'A', 'NYB': 'C', 'NYS': 'C', 'NZH': 'H', 'N80': 'P', 'NZC': 'T', 'NLW': 'L', 'N0A': 'F', 'N9P': 'A', 'N65': 'K', 'R1A': 'C', 'R4K': 'W', 'RE0': 'W', 'RE3': 'W', 'RGL': 'R', 'RGP': 'E', 'RT0': 'P', 'RVX': 'S', 'RZ4': 'S', 'RPI': 'R', 'RVJ': 'A', 'VAD': 'V', 'VAF': 'V', 'VAH': 'V', 'VAI': 'V', 'VAL': 'V', 'VB1': 'K', 'VH0': 'P', 'VR0': 'R', 'V44': 'C', 'V61': 'F', 'VPV': 'K', 'V5N': 'H', 'V7T': 'K', 'Z01': 'A', 'Z3E': 'T', 'Z70': 'H', 'ZBZ': 'C', 'ZCL': 'F', 'ZU0': 'T', 'ZYJ': 'P', 'ZYK': 'P', 'ZZD': 'C', 'ZZJ': 'A', 'ZIQ': 'W', 'ZPO': 'P', 'ZDJ': 'Y', 'ZT1': 'K', '30V': 'C', '31Q': 'C', '33S': 'F', '33W': 'A', '34E': 'V', '3AH': 'H', '3BY': 'P', '3CF': 'F', '3CT': 'Y', '3GA': 'A', '3GL': 'E', '3MD': 'D', '3MY': 'Y', '3NF': 'Y', '3O3': 'E', '3PX': 'P', '3QN': 'K', '3TT': 'P', '3XH': 'G', '3YM': 'Y', '3WS': 'A', '3WX': 'P', '3X9': 'C', '3ZH': 'H', '7JA': 'I', '73C': 'S', '73N': 'R', '73O': 'Y', '73P': 'K', '74P': 'K', '7N8': 'F', '7O5': 'A', '7XC': 'F', '7ID': 'D', '7OZ': 'A', 'C1S': 'C', 'C1T': 'C', 'C1X': 'K', 'C22': 'A', 'C3Y': 'C', 'C4R': 'C', 'C5C': 'C', 'C6C': 'C', 'CAF': 'C', 'CAS': 'C', 'CAY': 'C', 'CCS': 'C', 'CEA': 'C', 'CGA': 'E', 'CGU': 'E', 'CGV': 'C', 'CHP': 'G', 'CIR': 'R', 'CLE': 'L', 'CLG': 'K', 'CLH': 'K', 'CME': 'C', 'CMH': 'C', 'CML': 'C', 'CMT': 'C', 'CR5': 'G', 'CS0': 'C', 'CS1': 'C', 'CS3': 'C', 'CS4': 'C', 'CSA': 'C', 'CSB': 'C', 'CSD': 'C', 'CSE': 'C', 'CSJ': 'C', 'CSO': 'C', 'CSP': 'C', 'CSR': 'C', 'CSS': 'C', 'CSU': 'C', 'CSW': 'C', 'CSX': 'C', 'CSZ': 'C', 'CTE': 'W', 'CTH': 'T', 'CWD': 'A', 'CWR': 'S', 'CXM': 'M', 'CY0': 'C', 'CY1': 'C', 'CY3': 'C', 'CY4': 'C', 'CYA': 'C', 'CYD': 'C', 'CYF': 'C', 'CYG': 'C', 'CYJ': 'K', 'CYM': 'C', 'CYQ': 'C', 'CYR': 'C', 'CYS': 'C', 'CYW': 'C', 'CZ2': 'C', 'CZZ': 'C', 'CG6': 'C', 'C1J': 'R', 'C4G': 'R', 'C67': 'R', 'C6D': 'R', 'CE7': 'N', 'CZS': 'A', 'G01': 'E', 'G8M': 'E', 'GAU': 'E', 'GEE': 'G', 'GFT': 'S', 'GHC': 'E', 'GHG': 'Q', 'GHW': 'E', 'GL3': 'G', 'GLH': 'Q', 'GLJ': 'E', 'GLK': 'E', 'GLN': 'Q', 'GLQ': 'E', 'GLU': 'E', 'GLY': 'G', 'GLZ': 'G', 'GMA': 'E', 'GME': 'E', 'GNC': 'Q', 'GPL': 'K', 'GSC': 'G', 'GSU': 'E', 'GT9': 'C', 'GVL': 'S', 'G3M': 'R', 'G5G': 'L', 'G1X': 'Y', 'G8X': 'P', 'K1R': 'C', 'KBE': 'K', 'KCX': 'K', 'KFP': 'K', 'KGC': 'K', 'KNB': 'A', 'KOR': 'M', 'KPI': 'K', 'KPY': 'K', 'KST': 'K', 'KYN': 'W', 'KYQ': 'K', 'KCR': 'K', 'KPF': 'K', 'K5L': 'S', 'KEO': 'K', 'KHB': 'K', 'KKD': 'D', 'K5H': 'C', 'K7K': 'S', 'OAR': 'R', 'OAS': 'S', 'OBS': 'K', 'OCS': 'C', 'OCY': 'C', 'OHI': 'H', 'OHS': 'D', 'OLD': 'H', 'OLT': 'T', 'OLZ': 'S', 'OMH': 'S', 'OMT': 'M', 'OMX': 'Y', 'OMY': 'Y', 'ONH': 'A', 'ORN': 'A', 'ORQ': 'R', 'OSE': 'S', 'OTH': 'T', 'OXX': 'D', 'OYL': 'H', 'O7A': 'T', 'O7D': 'W', 'O7G': 'V', 'O2E': 'S', 'O6H': 'W', 'OZW': 'F', 'S12': 'S', 'S1H': 'S', 'S2C': 'C', 'S2P': 'A', 'SAC': 'S', 'SAH': 'C', 'SAR': 'G', 'SBG': 'S', 'SBL': 'S', 'SCH': 'C', 'SCS': 'C', 'SCY': 'C', 'SD4': 'N', 'SDB': 'S', 'SDP': 'S', 'SEB': 'S', 'SEE': 'S', 'SEG': 'A', 'SEL': 'S', 'SEM': 'S', 'SEN': 'S', 'SEP': 'S', 'SER': 'S', 'SET': 'S', 'SGB': 'S', 'SHC': 'C', 'SHP': 'G', 'SHR': 'K', 'SIB': 'C', 'SLL': 'K', 'SLZ': 'K', 'SMC': 'C', 'SME': 'M', 'SMF': 'F', 'SNC': 'C', 'SNN': 'N', 'SOY': 'S', 'SRZ': 'S', 'STY': 'Y', 'SUN': 'S', 'SVA': 'S', 'SVV': 'S', 'SVW': 'S', 'SVX': 'S', 'SVY': 'S', 'SVZ': 'S', 'SXE': 'S', 'SKH': 'K', 'SNM': 'S', 'SNK': 'H', 'SWW': 'S', 'WFP': 'F', 'WLU': 'L', 'WPA': 'F', 'WRP': 'W', 'WVL': 'V', '02K': 'A', '02L': 'N', '02O': 'A', '02Y': 'A', '033': 'V', '037': 'P', '03Y': 'C', '04U': 'P', '04V': 'P', '05N': 'P', '07O': 'C', '0A0': 'D', '0A1': 'Y', '0A2': 'K', '0A8': 'C', '0A9': 'F', '0AA': 'V', '0AB': 'V', '0AC': 'G', '0AF': 'W', '0AG': 'L', '0AH': 'S', '0AK': 'D', '0AR': 'R', '0BN': 'F', '0CS': 'A', '0E5': 'T', '0EA': 'Y', '0FL': 'A', '0LF': 'P', '0NC': 'A', '0PR': 'Y', '0QL': 'C', '0TD': 'D', '0UO': 'W', '0WZ': 'Y', '0X9': 'R', '0Y8': 'P', '4AF': 'F', '4AR': 'R', '4AW': 'W', '4BF': 'F', '4CF': 'F', '4CY': 'M', '4DP': 'W', '4FB': 'P', '4FW': 'W', '4HL': 'Y', '4HT': 'W', '4IN': 'W', '4MM': 'M', '4PH': 'F', '4U7': 'A', '41H': 'F', '41Q': 'N', '42Y': 'S', '432': 'S', '45F': 'P', '4AK': 'K', '4D4': 'R', '4GJ': 'C', '4KY': 'P', '4L0': 'P', '4LZ': 'Y', '4N7': 'P', '4N8': 'P', '4N9': 'P', '4OG': 'W', '4OU': 'F', '4OV': 'S', '4OZ': 'S', '4PQ': 'W', '4SJ': 'F', '4WQ': 'A', '4HH': 'S', '4HJ': 'S', '4J4': 'C', '4J5': 'R', '4II': 'F', '4VI': 'R', '823': 'N', '8SP': 'S', '8AY': 'A'}
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def is_aa(residue, standard=False):
|
| 13 |
+
if not isinstance(residue, str):
|
| 14 |
+
residue = f"{residue.get_resname():<3s}"
|
| 15 |
+
residue = residue.upper()
|
| 16 |
+
if standard:
|
| 17 |
+
return residue in protein_letters_3to1
|
| 18 |
+
else:
|
| 19 |
+
return residue in protein_letters_3to1_extended
|
src/data/molecule_builder.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rdkit import Chem
|
| 2 |
+
|
| 3 |
+
from src import constants
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def remove_dummy_atoms(rdmol, sanitize=False):
|
| 7 |
+
# find exit atoms to be removed
|
| 8 |
+
dummy_inds = []
|
| 9 |
+
for a in rdmol.GetAtoms():
|
| 10 |
+
if a.GetSymbol() == '*':
|
| 11 |
+
dummy_inds.append(a.GetIdx())
|
| 12 |
+
|
| 13 |
+
dummy_inds = sorted(dummy_inds, reverse=True)
|
| 14 |
+
new_mol = Chem.EditableMol(rdmol)
|
| 15 |
+
for idx in dummy_inds:
|
| 16 |
+
new_mol.RemoveAtom(idx)
|
| 17 |
+
new_mol = new_mol.GetMol()
|
| 18 |
+
if sanitize:
|
| 19 |
+
Chem.SanitizeMol(new_mol)
|
| 20 |
+
return new_mol
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def build_molecule(coords, atom_types, bonds=None, bond_types=None,
|
| 24 |
+
atom_props=None, atom_decoder=None, bond_decoder=None):
|
| 25 |
+
"""
|
| 26 |
+
Build RDKit molecule with given bonds
|
| 27 |
+
:param coords: N x 3
|
| 28 |
+
:param atom_types: N
|
| 29 |
+
:param bonds: 2 x N_bonds
|
| 30 |
+
:param bond_types: N_bonds
|
| 31 |
+
:param atom_props: Dict, key: property name, value: list of float values (N,)
|
| 32 |
+
:param atom_decoder: list
|
| 33 |
+
:param bond_decoder: list
|
| 34 |
+
:return: RDKit molecule
|
| 35 |
+
"""
|
| 36 |
+
if atom_decoder is None:
|
| 37 |
+
atom_decoder = constants.atom_decoder
|
| 38 |
+
if bond_decoder is None:
|
| 39 |
+
bond_decoder = constants.bond_decoder
|
| 40 |
+
assert len(coords) == len(atom_types)
|
| 41 |
+
assert bonds is None or bonds.size(1) == len(bond_types)
|
| 42 |
+
|
| 43 |
+
mol = Chem.RWMol()
|
| 44 |
+
for i, atom in enumerate(atom_types):
|
| 45 |
+
element = atom_decoder[atom.item()]
|
| 46 |
+
charge = None
|
| 47 |
+
explicitHs = None
|
| 48 |
+
|
| 49 |
+
if len(element) > 1 and element.endswith('H'):
|
| 50 |
+
explicitHs = 1
|
| 51 |
+
element = element[:-1]
|
| 52 |
+
elif element.endswith('+'):
|
| 53 |
+
charge = 1
|
| 54 |
+
element = element[:-1]
|
| 55 |
+
elif element.endswith('-'):
|
| 56 |
+
charge = -1
|
| 57 |
+
element = element[:-1]
|
| 58 |
+
|
| 59 |
+
if element == 'NOATOM':
|
| 60 |
+
# element = 'Xe' # debug
|
| 61 |
+
element = '*'
|
| 62 |
+
|
| 63 |
+
a = Chem.Atom(element)
|
| 64 |
+
|
| 65 |
+
if explicitHs is not None:
|
| 66 |
+
a.SetNumExplicitHs(explicitHs)
|
| 67 |
+
if charge is not None:
|
| 68 |
+
a.SetFormalCharge(charge)
|
| 69 |
+
|
| 70 |
+
if atom_props is not None:
|
| 71 |
+
for k, vals in atom_props.items():
|
| 72 |
+
a.SetDoubleProp(k, vals[i].item())
|
| 73 |
+
|
| 74 |
+
mol.AddAtom(a)
|
| 75 |
+
|
| 76 |
+
# add coordinates
|
| 77 |
+
conf = Chem.Conformer(mol.GetNumAtoms())
|
| 78 |
+
for i in range(mol.GetNumAtoms()):
|
| 79 |
+
conf.SetAtomPosition(i, (coords[i, 0].item(),
|
| 80 |
+
coords[i, 1].item(),
|
| 81 |
+
coords[i, 2].item()))
|
| 82 |
+
mol.AddConformer(conf)
|
| 83 |
+
|
| 84 |
+
# add bonds
|
| 85 |
+
if bonds is not None:
|
| 86 |
+
for bond, bond_type in zip(bonds.T, bond_types):
|
| 87 |
+
bond_type = bond_decoder[bond_type]
|
| 88 |
+
src = bond[0].item()
|
| 89 |
+
dst = bond[1].item()
|
| 90 |
+
|
| 91 |
+
# try:
|
| 92 |
+
if bond_type == 'NOBOND' or mol.GetAtomWithIdx(src).GetSymbol() == '*' or mol.GetAtomWithIdx(dst).GetSymbol() == '*':
|
| 93 |
+
continue
|
| 94 |
+
# except RuntimeError:
|
| 95 |
+
# from pdb import set_trace; set_trace()
|
| 96 |
+
|
| 97 |
+
if mol.GetBondBetweenAtoms(src, dst) is not None:
|
| 98 |
+
assert mol.GetBondBetweenAtoms(src, dst).GetBondType() == bond_type, \
|
| 99 |
+
"Trying to assign two different types to the same bond."
|
| 100 |
+
continue
|
| 101 |
+
|
| 102 |
+
if bond_type is None or src == dst:
|
| 103 |
+
continue
|
| 104 |
+
mol.AddBond(src, dst, bond_type)
|
| 105 |
+
|
| 106 |
+
mol = remove_dummy_atoms(mol, sanitize=False)
|
| 107 |
+
return mol
|
src/data/nerf.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Natural Extension Reference Frame (NERF)
|
| 3 |
+
|
| 4 |
+
Inspiration for parallel reconstruction:
|
| 5 |
+
https://github.com/EleutherAI/mp_nerf and references therein
|
| 6 |
+
|
| 7 |
+
For atom names, see also:
|
| 8 |
+
https://www.ccpn.ac.uk/manual/v3/NEFAtomNames.html
|
| 9 |
+
|
| 10 |
+
References:
|
| 11 |
+
- https://onlinelibrary.wiley.com/doi/10.1002/jcc.20237 (NERF)
|
| 12 |
+
- https://onlinelibrary.wiley.com/doi/10.1002/jcc.26768 (for code)
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import warnings
|
| 16 |
+
import torch
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
from src.data.misc import protein_letters_3to1
|
| 20 |
+
from src.constants import aa_atom_index, aa_atom_mask, aa_nerf_indices, aa_chi_indices, aa_chi_anchor_atom
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# https://github.com/EleutherAI/mp_nerf/blob/master/mp_nerf/utils.py
|
| 24 |
+
def get_dihedral(c1, c2, c3, c4):
|
| 25 |
+
""" Returns the dihedral angle in radians.
|
| 26 |
+
Will use atan2 formula from:
|
| 27 |
+
https://en.wikipedia.org/wiki/Dihedral_angle#In_polymer_physics
|
| 28 |
+
Inputs:
|
| 29 |
+
* c1: (batch, 3) or (3,)
|
| 30 |
+
* c2: (batch, 3) or (3,)
|
| 31 |
+
* c3: (batch, 3) or (3,)
|
| 32 |
+
* c4: (batch, 3) or (3,)
|
| 33 |
+
"""
|
| 34 |
+
u1 = c2 - c1
|
| 35 |
+
u2 = c3 - c2
|
| 36 |
+
u3 = c4 - c3
|
| 37 |
+
|
| 38 |
+
return torch.atan2( ( (torch.norm(u2, dim=-1, keepdim=True) * u1) * torch.cross(u2,u3, dim=-1) ).sum(dim=-1) ,
|
| 39 |
+
( torch.cross(u1,u2, dim=-1) * torch.cross(u2, u3, dim=-1) ).sum(dim=-1) )
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# https://github.com/EleutherAI/mp_nerf/blob/master/mp_nerf/utils.py
|
| 43 |
+
def get_angle(c1, c2, c3):
|
| 44 |
+
""" Returns the angle in radians.
|
| 45 |
+
Inputs:
|
| 46 |
+
* c1: (batch, 3) or (3,)
|
| 47 |
+
* c2: (batch, 3) or (3,)
|
| 48 |
+
* c3: (batch, 3) or (3,)
|
| 49 |
+
"""
|
| 50 |
+
u1 = c2 - c1
|
| 51 |
+
u2 = c3 - c2
|
| 52 |
+
|
| 53 |
+
# dont use acos since norms involved.
|
| 54 |
+
# better use atan2 formula: atan2(cross, dot) from here:
|
| 55 |
+
# https://johnblackburne.blogspot.com/2012/05/angle-between-two-3d-vectors.html
|
| 56 |
+
|
| 57 |
+
# add a minus since we want the angle in reversed order - sidechainnet issues
|
| 58 |
+
return torch.atan2( torch.norm(torch.cross(u1,u2, dim=-1), dim=-1),
|
| 59 |
+
-(u1*u2).sum(dim=-1) )
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def get_nerf_params(biopython_residue):
|
| 63 |
+
aa = protein_letters_3to1[biopython_residue.get_resname()]
|
| 64 |
+
|
| 65 |
+
# Basic mask and index tensors
|
| 66 |
+
atom_mask = torch.tensor(aa_atom_mask[aa], dtype=bool)
|
| 67 |
+
nerf_indices = torch.tensor(aa_nerf_indices[aa], dtype=int)
|
| 68 |
+
chi_indices = torch.tensor(aa_chi_indices[aa], dtype=int)
|
| 69 |
+
|
| 70 |
+
fixed_coord = torch.zeros((5, 3))
|
| 71 |
+
residue_coords = torch.zeros((14, 3)) # only required to compute internal coordinates during pre-processing
|
| 72 |
+
atom_found = torch.zeros_like(atom_mask)
|
| 73 |
+
for atom in biopython_residue.get_atoms():
|
| 74 |
+
try:
|
| 75 |
+
idx = aa_atom_index[aa][atom.get_name()]
|
| 76 |
+
atom_found[idx] = True
|
| 77 |
+
except KeyError:
|
| 78 |
+
warnings.warn(f"{atom.get_name()} not found")
|
| 79 |
+
continue
|
| 80 |
+
|
| 81 |
+
residue_coords[idx, :] = torch.from_numpy(atom.get_coord())
|
| 82 |
+
|
| 83 |
+
if atom.get_name() in ['N', 'CA', 'C', 'O', 'CB']:
|
| 84 |
+
fixed_coord[idx, :] = torch.from_numpy(atom.get_coord())
|
| 85 |
+
|
| 86 |
+
# Determine chi angles
|
| 87 |
+
chi = torch.zeros(6) # the last chi angle is a dummy and should always be zero
|
| 88 |
+
for chi_idx, anchor in aa_chi_anchor_atom[aa].items():
|
| 89 |
+
idx_a = nerf_indices[anchor, 2]
|
| 90 |
+
idx_b = nerf_indices[anchor, 1]
|
| 91 |
+
idx_c = nerf_indices[anchor, 0]
|
| 92 |
+
|
| 93 |
+
coords_a = residue_coords[idx_a, :]
|
| 94 |
+
coords_b = residue_coords[idx_b, :]
|
| 95 |
+
coords_c = residue_coords[idx_c, :]
|
| 96 |
+
coords_d = residue_coords[anchor, :]
|
| 97 |
+
|
| 98 |
+
chi[chi_idx] = get_dihedral(coords_a, coords_b, coords_c, coords_d)
|
| 99 |
+
|
| 100 |
+
# Compute remaining internal coordinates
|
| 101 |
+
# (parallel version)
|
| 102 |
+
idx_a = nerf_indices[:, 2]
|
| 103 |
+
idx_b = nerf_indices[:, 1]
|
| 104 |
+
idx_c = nerf_indices[:, 0]
|
| 105 |
+
|
| 106 |
+
# update atom mask
|
| 107 |
+
# remove atoms for which one or several parameters are missing/incorrect
|
| 108 |
+
_atom_mask = atom_mask & atom_found & atom_found[idx_a] & atom_found[idx_b] & atom_found[idx_c]
|
| 109 |
+
if not torch.all(_atom_mask == atom_mask):
|
| 110 |
+
warnings.warn("Some atoms are missing for NERF reconstruction")
|
| 111 |
+
atom_mask = _atom_mask
|
| 112 |
+
|
| 113 |
+
coords_a = residue_coords[idx_a]
|
| 114 |
+
coords_b = residue_coords[idx_b]
|
| 115 |
+
coords_c = residue_coords[idx_c]
|
| 116 |
+
coords_d = residue_coords
|
| 117 |
+
|
| 118 |
+
length = torch.norm(coords_d - coords_c, dim=-1)
|
| 119 |
+
theta = get_angle(coords_b, coords_c, coords_d)
|
| 120 |
+
ddihedral = get_dihedral(coords_a, coords_b, coords_c, coords_d)
|
| 121 |
+
|
| 122 |
+
# subtract chi angles from dihedrals
|
| 123 |
+
ddihedral = ddihedral - chi[chi_indices]
|
| 124 |
+
|
| 125 |
+
# # (serial version)
|
| 126 |
+
# length = torch.zeros(14)
|
| 127 |
+
# theta = torch.zeros(14)
|
| 128 |
+
# ddihedral = torch.zeros(14)
|
| 129 |
+
# for i in range(5, 14):
|
| 130 |
+
# if not atom_mask[i]: # atom doesn't exist
|
| 131 |
+
# continue
|
| 132 |
+
|
| 133 |
+
# idx_a = nerf_indices[i, 2]
|
| 134 |
+
# idx_b = nerf_indices[i, 1]
|
| 135 |
+
# idx_c = nerf_indices[i, 0]
|
| 136 |
+
|
| 137 |
+
# coords_a = residue_coords[idx_a]
|
| 138 |
+
# coords_b = residue_coords[idx_b]
|
| 139 |
+
# coords_c = residue_coords[idx_c]
|
| 140 |
+
# coords_d = residue_coords[i]
|
| 141 |
+
|
| 142 |
+
# length[i] = torch.norm(coords_d - coords_c, dim=-1)
|
| 143 |
+
# theta[i] = get_angle(coords_b, coords_c, coords_d)
|
| 144 |
+
# ddihedral[i] = get_dihedral(coords_a, coords_b, coords_c, coords_d)
|
| 145 |
+
|
| 146 |
+
# # subtract chi angles from dihedrals
|
| 147 |
+
# ddihedral[i] = ddihedral[i] - chi[chi_indices[i]]
|
| 148 |
+
|
| 149 |
+
return {
|
| 150 |
+
'fixed_coord': fixed_coord,
|
| 151 |
+
'atom_mask': atom_mask,
|
| 152 |
+
'nerf_indices': nerf_indices,
|
| 153 |
+
'length': length,
|
| 154 |
+
'theta': theta,
|
| 155 |
+
'chi': chi,
|
| 156 |
+
'ddihedral': ddihedral,
|
| 157 |
+
'chi_indices': chi_indices,
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# https://github.com/EleutherAI/mp_nerf/blob/master/mp_nerf/massive_pnerf.py#L38C1-L65C67
|
| 162 |
+
def mp_nerf_torch(a, b, c, l, theta, chi):
|
| 163 |
+
""" Custom Natural extension of Reference Frame.
|
| 164 |
+
Inputs:
|
| 165 |
+
* a: (batch, 3) or (3,). point(s) of the plane, not connected to d
|
| 166 |
+
* b: (batch, 3) or (3,). point(s) of the plane, not connected to d
|
| 167 |
+
* c: (batch, 3) or (3,). point(s) of the plane, connected to d
|
| 168 |
+
* theta: (batch,) or (float). angle(s) between b-c-d
|
| 169 |
+
* chi: (batch,) or float. dihedral angle(s) between the a-b-c and b-c-d planes
|
| 170 |
+
Outputs: d (batch, 3) or (float). the next point in the sequence, linked to c
|
| 171 |
+
"""
|
| 172 |
+
# safety check
|
| 173 |
+
if not ( (-np.pi <= theta) * (theta <= np.pi) ).all().item():
|
| 174 |
+
raise ValueError(f"theta(s) must be in radians and in [-pi, pi]. theta(s) = {theta}")
|
| 175 |
+
# calc vecs
|
| 176 |
+
ba = b-a
|
| 177 |
+
cb = c-b
|
| 178 |
+
# calc rotation matrix. based on plane normals and normalized
|
| 179 |
+
n_plane = torch.cross(ba, cb, dim=-1)
|
| 180 |
+
n_plane_ = torch.cross(n_plane, cb, dim=-1)
|
| 181 |
+
rotate = torch.stack([cb, n_plane_, n_plane], dim=-1)
|
| 182 |
+
rotate /= torch.norm(rotate, dim=-2, keepdim=True)
|
| 183 |
+
# calc proto point, rotate. add (-1 for sidechainnet convention)
|
| 184 |
+
# https://github.com/jonathanking/sidechainnet/issues/14
|
| 185 |
+
d = torch.stack([-torch.cos(theta),
|
| 186 |
+
torch.sin(theta) * torch.cos(chi),
|
| 187 |
+
torch.sin(theta) * torch.sin(chi)], dim=-1).unsqueeze(-1)
|
| 188 |
+
# extend base point, set length
|
| 189 |
+
return c + l.unsqueeze(-1) * torch.matmul(rotate, d).squeeze()
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# inspired by: https://github.com/EleutherAI/mp_nerf/blob/master/mp_nerf/proteins.py#L323C5-L344C65
|
| 193 |
+
def ic_to_coords(fixed_coord, atom_mask, nerf_indices, length, theta, chi, ddihedral, chi_indices):
|
| 194 |
+
"""
|
| 195 |
+
Run NERF in parallel for all residues.
|
| 196 |
+
|
| 197 |
+
:param fixed_coord: (L, 5, 3) coordinates of (N, CA, C, O, CB) atoms, they don't depend on chi angles
|
| 198 |
+
:param atom_mask: (L, 14) indicates whether atom exists in this residue
|
| 199 |
+
:param nerf_indices: (L, 14, 3) indices of the three previous atoms ({c, b, a} for the NERF algorithm)
|
| 200 |
+
:param length: (L, 14) bond length between this and previous atom
|
| 201 |
+
:param theta: (L, 14) angle between this and previous two atoms
|
| 202 |
+
:param chi: (L, 6) values of the 5 rotatable bonds, plus zero in last column
|
| 203 |
+
:param ddihedral: (L, 14) angle offset to which chi is added
|
| 204 |
+
:param chi_indices: (L, 14) indexes into the chi array
|
| 205 |
+
:returns: (L, 14, 3) tensor with all coordinates, non-existing atoms are assigned CA coords
|
| 206 |
+
"""
|
| 207 |
+
|
| 208 |
+
if not torch.all(chi[:, 5] == 0):
|
| 209 |
+
chi[:, 5] = 0.0
|
| 210 |
+
warnings.warn("Last column of 'chi' tensor should be zero. Overriding values.")
|
| 211 |
+
assert torch.all(chi[:, 5] == 0)
|
| 212 |
+
|
| 213 |
+
L, device = fixed_coord.size(0), fixed_coord.device
|
| 214 |
+
coords = torch.zeros((L, 14, 3), device=device)
|
| 215 |
+
coords[:, :5, :] = fixed_coord
|
| 216 |
+
|
| 217 |
+
for i in range(5, 14):
|
| 218 |
+
level_mask = atom_mask[:, i]
|
| 219 |
+
# level_mask = torch.ones(len(atom_mask), dtype=bool)
|
| 220 |
+
|
| 221 |
+
length_i = length[level_mask, i]
|
| 222 |
+
theta_i = theta[level_mask, i]
|
| 223 |
+
|
| 224 |
+
# dihedral_i = dihedral[level_mask, i]
|
| 225 |
+
dihedral_i = chi[level_mask, chi_indices[level_mask, i]] + ddihedral[level_mask, i]
|
| 226 |
+
|
| 227 |
+
idx_a = nerf_indices[level_mask, i, 2]
|
| 228 |
+
idx_b = nerf_indices[level_mask, i, 1]
|
| 229 |
+
idx_c = nerf_indices[level_mask, i, 0]
|
| 230 |
+
|
| 231 |
+
coords[level_mask, i] = mp_nerf_torch(coords[level_mask, idx_a],
|
| 232 |
+
coords[level_mask, idx_b],
|
| 233 |
+
coords[level_mask, idx_c],
|
| 234 |
+
length_i,
|
| 235 |
+
theta_i,
|
| 236 |
+
dihedral_i)
|
| 237 |
+
|
| 238 |
+
if coords.isnan().any():
|
| 239 |
+
warnings.warn("Side chain reconstruction error. Removing affected atoms...")
|
| 240 |
+
|
| 241 |
+
# mask out affected atoms
|
| 242 |
+
m, n, _ = torch.where(coords.isnan())
|
| 243 |
+
atom_mask[m, n] = False
|
| 244 |
+
coords[m, n, :] = 0.0
|
| 245 |
+
|
| 246 |
+
# replace non-existing atom coords with CA coords (TODO: don't hard-code CA index)
|
| 247 |
+
coords = atom_mask.unsqueeze(-1) * coords + \
|
| 248 |
+
(~atom_mask.unsqueeze(2)) * coords[:, 1, :].unsqueeze(1)
|
| 249 |
+
|
| 250 |
+
return coords
|
src/data/normal_modes.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
import numpy as np
|
| 3 |
+
import prody
|
| 4 |
+
prody.confProDy(verbosity='none')
|
| 5 |
+
from prody import parsePDB, ANM
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def pdb_to_normal_modes(pdb_file, num_modes=5, nmax=5000):
|
| 9 |
+
"""
|
| 10 |
+
Compute normal modes for a PDB file using an Anisotropic Network Model (ANM)
|
| 11 |
+
http://prody.csb.pitt.edu/tutorials/enm_analysis/anm.html (accessed 01/11/2023)
|
| 12 |
+
"""
|
| 13 |
+
protein = parsePDB(pdb_file, model=1).select('calpha')
|
| 14 |
+
|
| 15 |
+
if len(protein) > nmax:
|
| 16 |
+
warnings.warn("Protein is too big. Returning zeros...")
|
| 17 |
+
eig_vecs = np.zeros((len(protein), 3, num_modes))
|
| 18 |
+
|
| 19 |
+
else:
|
| 20 |
+
# build Hessian
|
| 21 |
+
anm = ANM('ANM analysis')
|
| 22 |
+
anm.buildHessian(protein, cutoff=15.0, gamma=1.0)
|
| 23 |
+
|
| 24 |
+
# calculate normal modes
|
| 25 |
+
anm.calcModes(num_modes, zeros=False)
|
| 26 |
+
|
| 27 |
+
# only use slowest modes
|
| 28 |
+
eig_vecs = anm.getEigvecs() # shape: (num_atoms * 3, num_modes)
|
| 29 |
+
eig_vecs = eig_vecs.reshape(len(protein), 3, num_modes)
|
| 30 |
+
# eig_vals = anm.getEigvals() # shape: (num_modes,)
|
| 31 |
+
|
| 32 |
+
nm_dict = {}
|
| 33 |
+
for atom, nm_vec in zip(protein, eig_vecs):
|
| 34 |
+
chain = atom.getChid()
|
| 35 |
+
resi = atom.getResnum()
|
| 36 |
+
name = atom.getName()
|
| 37 |
+
nm_dict[(chain, resi, name)] = nm_vec.T
|
| 38 |
+
|
| 39 |
+
return nm_dict
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
if __name__ == "__main__":
|
| 43 |
+
import argparse
|
| 44 |
+
from pathlib import Path
|
| 45 |
+
import torch
|
| 46 |
+
from tqdm import tqdm
|
| 47 |
+
|
| 48 |
+
parser = argparse.ArgumentParser()
|
| 49 |
+
parser.add_argument('basedir', type=Path)
|
| 50 |
+
parser.add_argument('--outfile', type=Path, default=None)
|
| 51 |
+
args = parser.parse_args()
|
| 52 |
+
|
| 53 |
+
# Read data split
|
| 54 |
+
split_path = Path(args.basedir, 'split_by_name.pt')
|
| 55 |
+
data_split = torch.load(split_path)
|
| 56 |
+
|
| 57 |
+
pockets = [x[0] for split in data_split.values() for x in split]
|
| 58 |
+
|
| 59 |
+
all_normal_modes = {}
|
| 60 |
+
for p in tqdm(pockets):
|
| 61 |
+
pdb_file = Path(args.basedir, 'crossdocked_pocket10', p)
|
| 62 |
+
|
| 63 |
+
try:
|
| 64 |
+
nm_dict = pdb_to_normal_modes(str(pdb_file))
|
| 65 |
+
all_normal_modes[p] = nm_dict
|
| 66 |
+
except AttributeError as e:
|
| 67 |
+
warnings.warn(str(e))
|
| 68 |
+
|
| 69 |
+
np.save(args.outfile, all_normal_modes)
|
src/data/postprocessing.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
|
| 3 |
+
from rdkit import Chem
|
| 4 |
+
from rdkit.Chem.rdForceFieldHelpers import UFFOptimizeMolecule, UFFHasAllMoleculeParams
|
| 5 |
+
|
| 6 |
+
from src.data import sanifix
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def uff_relax(mol, max_iter=200):
|
| 10 |
+
"""
|
| 11 |
+
Uses RDKit's universal force field (UFF) implementation to optimize a
|
| 12 |
+
molecule.
|
| 13 |
+
"""
|
| 14 |
+
if not UFFHasAllMoleculeParams(mol):
|
| 15 |
+
warnings.warn('UFF parameters not available for all atoms. '
|
| 16 |
+
'Returning None.')
|
| 17 |
+
return None
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
more_iterations_required = UFFOptimizeMolecule(mol, maxIters=max_iter)
|
| 21 |
+
if more_iterations_required:
|
| 22 |
+
warnings.warn(f'Maximum number of FF iterations reached. '
|
| 23 |
+
f'Returning molecule after {max_iter} relaxation steps.')
|
| 24 |
+
|
| 25 |
+
except RuntimeError:
|
| 26 |
+
return None
|
| 27 |
+
|
| 28 |
+
return mol
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def add_hydrogens(rdmol):
|
| 32 |
+
return Chem.AddHs(rdmol, addCoords=(len(rdmol.GetConformers()) > 0))
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_largest_fragment(rdmol):
|
| 36 |
+
mol_frags = Chem.GetMolFrags(rdmol, asMols=True, sanitizeFrags=False)
|
| 37 |
+
largest_frag = max(mol_frags, default=rdmol, key=lambda m: m.GetNumAtoms())
|
| 38 |
+
|
| 39 |
+
# try:
|
| 40 |
+
# Chem.SanitizeMol(largest_frag)
|
| 41 |
+
# except ValueError:
|
| 42 |
+
# return None
|
| 43 |
+
|
| 44 |
+
return largest_frag
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def process_all(rdmol, largest_frag=True, adjust_aromatic_Ns=True, relax_iter=0):
|
| 48 |
+
"""
|
| 49 |
+
Apply all filters and post-processing steps. Returns a new molecule.
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
RDKit molecule or None if it does not pass the filters or processing
|
| 53 |
+
fails
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
# Only consider non-trivial molecules
|
| 57 |
+
if rdmol.GetNumAtoms() < 1:
|
| 58 |
+
return None
|
| 59 |
+
|
| 60 |
+
# Create a copy
|
| 61 |
+
mol = Chem.Mol(rdmol)
|
| 62 |
+
|
| 63 |
+
# try:
|
| 64 |
+
# Chem.SanitizeMol(mol)
|
| 65 |
+
# except ValueError:
|
| 66 |
+
# warnings.warn('Sanitization failed. Returning None.')
|
| 67 |
+
# return None
|
| 68 |
+
|
| 69 |
+
if largest_frag:
|
| 70 |
+
mol = get_largest_fragment(mol)
|
| 71 |
+
# if mol is None:
|
| 72 |
+
# return None
|
| 73 |
+
|
| 74 |
+
if adjust_aromatic_Ns:
|
| 75 |
+
mol = sanifix.fix_mol(mol)
|
| 76 |
+
if mol is None:
|
| 77 |
+
return None
|
| 78 |
+
|
| 79 |
+
# if add_hydrogens:
|
| 80 |
+
# mol = add_hydrogens(mol)
|
| 81 |
+
|
| 82 |
+
if relax_iter > 0:
|
| 83 |
+
mol = uff_relax(mol, relax_iter)
|
| 84 |
+
if mol is None:
|
| 85 |
+
return None
|
| 86 |
+
|
| 87 |
+
try:
|
| 88 |
+
Chem.SanitizeMol(mol)
|
| 89 |
+
except ValueError:
|
| 90 |
+
warnings.warn('Sanitization failed. Returning None.')
|
| 91 |
+
return None
|
| 92 |
+
|
| 93 |
+
return mol
|
src/data/process_crossdocked.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from time import time
|
| 3 |
+
import argparse
|
| 4 |
+
import shutil
|
| 5 |
+
import random
|
| 6 |
+
import yaml
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
import numpy as np
|
| 12 |
+
from Bio.PDB import PDBParser
|
| 13 |
+
from rdkit import Chem
|
| 14 |
+
|
| 15 |
+
import sys
|
| 16 |
+
basedir = Path(__file__).resolve().parent.parent.parent
|
| 17 |
+
sys.path.append(str(basedir))
|
| 18 |
+
|
| 19 |
+
from src.data.data_utils import process_raw_pair, get_n_nodes, get_type_histogram
|
| 20 |
+
from src.data.data_utils import rdmol_to_smiles
|
| 21 |
+
from src.constants import atom_encoder, bond_encoder
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if __name__ == '__main__':
|
| 25 |
+
parser = argparse.ArgumentParser()
|
| 26 |
+
parser.add_argument('basedir', type=Path)
|
| 27 |
+
parser.add_argument('--outdir', type=Path, default=None)
|
| 28 |
+
parser.add_argument('--split_path', type=Path, default=None)
|
| 29 |
+
parser.add_argument('--pocket', type=str, default='CA+',
|
| 30 |
+
choices=['side_chain_bead', 'CA+'])
|
| 31 |
+
parser.add_argument('--random_seed', type=int, default=42)
|
| 32 |
+
parser.add_argument('--val_size', type=int, default=100)
|
| 33 |
+
parser.add_argument('--normal_modes', action='store_true')
|
| 34 |
+
parser.add_argument('--flex', action='store_true')
|
| 35 |
+
parser.add_argument('--toy', action='store_true')
|
| 36 |
+
args = parser.parse_args()
|
| 37 |
+
|
| 38 |
+
random.seed(args.random_seed)
|
| 39 |
+
|
| 40 |
+
datadir = args.basedir / 'crossdocked_pocket10/'
|
| 41 |
+
|
| 42 |
+
# Make output directory
|
| 43 |
+
dirname = f"processed_crossdocked_{args.pocket}"
|
| 44 |
+
if args.flex:
|
| 45 |
+
dirname += '_flex'
|
| 46 |
+
if args.normal_modes:
|
| 47 |
+
dirname += '_nma'
|
| 48 |
+
if args.toy:
|
| 49 |
+
dirname += '_toy'
|
| 50 |
+
processed_dir = Path(args.basedir, dirname) if args.outdir is None else args.outdir
|
| 51 |
+
processed_dir.mkdir(parents=True)
|
| 52 |
+
|
| 53 |
+
# Read data split
|
| 54 |
+
split_path = Path(args.basedir, 'split_by_name.pt') if args.split_path is None else args.split_path
|
| 55 |
+
data_split = torch.load(split_path)
|
| 56 |
+
|
| 57 |
+
# If there is no validation set, copy training examples (the validation set
|
| 58 |
+
# is not very important in this application)
|
| 59 |
+
if 'val' not in data_split:
|
| 60 |
+
random.shuffle(data_split['train'])
|
| 61 |
+
data_split['val'] = data_split['train'][-args.val_size:]
|
| 62 |
+
data_split['train'] = data_split['train'][:-args.val_size]
|
| 63 |
+
|
| 64 |
+
if args.toy:
|
| 65 |
+
data_split['train'] = random.sample(data_split['train'], 100)
|
| 66 |
+
|
| 67 |
+
failed = {}
|
| 68 |
+
train_smiles = []
|
| 69 |
+
|
| 70 |
+
n_samples_after = {}
|
| 71 |
+
for split in data_split.keys():
|
| 72 |
+
|
| 73 |
+
print(f"Processing {split} dataset...")
|
| 74 |
+
|
| 75 |
+
ligands = defaultdict(list)
|
| 76 |
+
pockets = defaultdict(list)
|
| 77 |
+
|
| 78 |
+
tic = time()
|
| 79 |
+
pbar = tqdm(data_split[split])
|
| 80 |
+
for pocket_fn, ligand_fn in pbar:
|
| 81 |
+
|
| 82 |
+
pbar.set_description(f'#failed: {len(failed)}')
|
| 83 |
+
|
| 84 |
+
sdffile = datadir / f'{ligand_fn}'
|
| 85 |
+
pdbfile = datadir / f'{pocket_fn}'
|
| 86 |
+
|
| 87 |
+
try:
|
| 88 |
+
pdb_model = PDBParser(QUIET=True).get_structure('', pdbfile)[0]
|
| 89 |
+
|
| 90 |
+
rdmol = Chem.SDMolSupplier(str(sdffile))[0]
|
| 91 |
+
|
| 92 |
+
ligand, pocket = process_raw_pair(
|
| 93 |
+
pdb_model, rdmol, pocket_representation=args.pocket,
|
| 94 |
+
compute_nerf_params=args.flex, compute_bb_frames=args.flex,
|
| 95 |
+
nma_input=pdbfile if args.normal_modes else None)
|
| 96 |
+
|
| 97 |
+
except (KeyError, AssertionError, FileNotFoundError, IndexError,
|
| 98 |
+
ValueError, AttributeError) as e:
|
| 99 |
+
failed[(split, sdffile, pdbfile)] = (type(e).__name__, str(e))
|
| 100 |
+
continue
|
| 101 |
+
|
| 102 |
+
nerf_keys = ['fixed_coord', 'atom_mask', 'nerf_indices', 'length', 'theta', 'chi', 'ddihedral', 'chi_indices']
|
| 103 |
+
for k in ['x', 'one_hot', 'bonds', 'bond_one_hot', 'v', 'nma_vec'] + nerf_keys + ['axis_angle']:
|
| 104 |
+
if k in ligand:
|
| 105 |
+
ligands[k].append(ligand[k])
|
| 106 |
+
if k in pocket:
|
| 107 |
+
pockets[k].append(pocket[k])
|
| 108 |
+
|
| 109 |
+
pocket_file = pdbfile.name.replace('_', '-')
|
| 110 |
+
ligand_file = Path(pocket_file).stem + '_' + Path(sdffile).name.replace('_', '-')
|
| 111 |
+
ligands['name'].append(ligand_file)
|
| 112 |
+
pockets['name'].append(pocket_file)
|
| 113 |
+
train_smiles.append(rdmol_to_smiles(rdmol))
|
| 114 |
+
|
| 115 |
+
if split in {'val', 'test'}:
|
| 116 |
+
pdb_sdf_dir = processed_dir / split
|
| 117 |
+
pdb_sdf_dir.mkdir(exist_ok=True)
|
| 118 |
+
|
| 119 |
+
# Copy PDB file
|
| 120 |
+
pdb_file_out = Path(pdb_sdf_dir, pocket_file)
|
| 121 |
+
shutil.copy(pdbfile, pdb_file_out)
|
| 122 |
+
|
| 123 |
+
# Copy SDF file
|
| 124 |
+
sdf_file_out = Path(pdb_sdf_dir, ligand_file)
|
| 125 |
+
shutil.copy(sdffile, sdf_file_out)
|
| 126 |
+
|
| 127 |
+
data = {'ligands': ligands, 'pockets': pockets}
|
| 128 |
+
torch.save(data, Path(processed_dir, f'{split}.pt'))
|
| 129 |
+
|
| 130 |
+
if split == 'train':
|
| 131 |
+
np.save(Path(processed_dir, 'train_smiles.npy'), train_smiles)
|
| 132 |
+
|
| 133 |
+
print(f"Processing {split} set took {(time() - tic) / 60.0:.2f} minutes")
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# --------------------------------------------------------------------------
|
| 137 |
+
# Compute statistics & additional information
|
| 138 |
+
# --------------------------------------------------------------------------
|
| 139 |
+
train_data = torch.load(Path(processed_dir, f'train.pt'))
|
| 140 |
+
|
| 141 |
+
# Maximum molecule size
|
| 142 |
+
max_ligand_size = max([len(x) for x in train_data['ligands']['x']])
|
| 143 |
+
|
| 144 |
+
# Joint histogram of number of ligand and pocket nodes
|
| 145 |
+
pocket_coords = train_data['pockets']['x']
|
| 146 |
+
ligand_coords = train_data['ligands']['x']
|
| 147 |
+
n_nodes = get_n_nodes(ligand_coords, pocket_coords)
|
| 148 |
+
np.save(Path(processed_dir, 'size_distribution.npy'), n_nodes)
|
| 149 |
+
|
| 150 |
+
# Get histograms of ligand node types
|
| 151 |
+
lig_one_hot = [x.numpy() for x in train_data['ligands']['one_hot']]
|
| 152 |
+
ligand_hist = get_type_histogram(lig_one_hot, atom_encoder)
|
| 153 |
+
np.save(Path(processed_dir, 'ligand_type_histogram.npy'), ligand_hist)
|
| 154 |
+
|
| 155 |
+
# Get histograms of ligand edge types
|
| 156 |
+
lig_bond_one_hot = [x.numpy() for x in train_data['ligands']['bond_one_hot']]
|
| 157 |
+
ligand_bond_hist = get_type_histogram(lig_bond_one_hot, bond_encoder)
|
| 158 |
+
np.save(Path(processed_dir, 'ligand_bond_type_histogram.npy'), ligand_bond_hist)
|
| 159 |
+
|
| 160 |
+
# Write error report
|
| 161 |
+
error_str = ""
|
| 162 |
+
for k, v in failed.items():
|
| 163 |
+
error_str += f"{'Split':<15}: {k[0]}\n"
|
| 164 |
+
error_str += f"{'Ligand':<15}: {k[1]}\n"
|
| 165 |
+
error_str += f"{'Pocket':<15}: {k[2]}\n"
|
| 166 |
+
error_str += f"{'Error type':<15}: {v[0]}\n"
|
| 167 |
+
error_str += f"{'Error msg':<15}: {v[1]}\n\n"
|
| 168 |
+
|
| 169 |
+
with open(Path(processed_dir, 'errors.txt'), 'w') as f:
|
| 170 |
+
f.write(error_str)
|
| 171 |
+
|
| 172 |
+
metadata = {
|
| 173 |
+
'max_ligand_size': max_ligand_size
|
| 174 |
+
}
|
| 175 |
+
with open(Path(processed_dir, 'metadata.yml'), 'w') as f:
|
| 176 |
+
yaml.dump(metadata, f, default_flow_style=False)
|
src/data/process_dpo_dataset.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import numpy as np
|
| 4 |
+
import random
|
| 5 |
+
import shutil
|
| 6 |
+
from time import time
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
from Bio.PDB import PDBParser
|
| 9 |
+
from rdkit import Chem
|
| 10 |
+
import torch
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
import pandas as pd
|
| 13 |
+
from itertools import combinations
|
| 14 |
+
|
| 15 |
+
import sys
|
| 16 |
+
basedir = Path(__file__).resolve().parent.parent.parent
|
| 17 |
+
sys.path.append(str(basedir))
|
| 18 |
+
|
| 19 |
+
from src.sbdd_metrics.metrics import REOSEvaluator, MedChemEvaluator, PoseBustersEvaluator, GninaEvalulator
|
| 20 |
+
from src.data.data_utils import process_raw_pair, rdmol_to_smiles
|
| 21 |
+
|
| 22 |
+
def parse_args():
|
| 23 |
+
parser = argparse.ArgumentParser()
|
| 24 |
+
parser.add_argument('--smplsdir', type=Path, required=True)
|
| 25 |
+
parser.add_argument('--metrics-detailed', type=Path, required=False)
|
| 26 |
+
parser.add_argument('--ignore-missing-scores', action='store_true')
|
| 27 |
+
parser.add_argument('--datadir', type=Path, required=True)
|
| 28 |
+
parser.add_argument('--dpo-criterion', type=str, default='reos.all',
|
| 29 |
+
choices=['reos.all', 'medchem.sa', 'medchem.qed', 'gnina.vina_efficiency','combined'])
|
| 30 |
+
parser.add_argument('--basedir', type=Path, default=None)
|
| 31 |
+
parser.add_argument('--pocket', type=str, default='CA+',
|
| 32 |
+
choices=['side_chain_bead', 'CA+'])
|
| 33 |
+
parser.add_argument('--gnina', type=Path, default='gnina')
|
| 34 |
+
parser.add_argument('--random_seed', type=int, default=42)
|
| 35 |
+
parser.add_argument('--normal_modes', action='store_true')
|
| 36 |
+
parser.add_argument('--flex', action='store_true')
|
| 37 |
+
parser.add_argument('--toy', action='store_true')
|
| 38 |
+
parser.add_argument('--toy_size', type=int, default=100)
|
| 39 |
+
parser.add_argument('--n_pairs', type=int, default=5)
|
| 40 |
+
args = parser.parse_args()
|
| 41 |
+
return args
|
| 42 |
+
|
| 43 |
+
def scan_smpl_dir(samples_dir):
|
| 44 |
+
samples_dir = Path(samples_dir)
|
| 45 |
+
subdirs = []
|
| 46 |
+
for subdir in tqdm(samples_dir.iterdir(), desc='Scanning samples'):
|
| 47 |
+
if not subdir.is_dir():
|
| 48 |
+
continue
|
| 49 |
+
if not sample_dir_valid(subdir):
|
| 50 |
+
continue
|
| 51 |
+
subdirs.append(subdir)
|
| 52 |
+
return subdirs
|
| 53 |
+
|
| 54 |
+
def sample_dir_valid(samples_dir):
|
| 55 |
+
pocket = samples_dir / '0_pocket.pdb'
|
| 56 |
+
if not pocket.exists():
|
| 57 |
+
return False
|
| 58 |
+
ligands = list(samples_dir.glob('*_ligand.sdf'))
|
| 59 |
+
if len(ligands) < 2:
|
| 60 |
+
return False
|
| 61 |
+
for ligand in ligands:
|
| 62 |
+
if ligand.stat().st_size == 0:
|
| 63 |
+
return False
|
| 64 |
+
return True
|
| 65 |
+
|
| 66 |
+
def return_winning_losing_smpl(score_1, score_2, criterion):
|
| 67 |
+
if criterion == 'reos.all':
|
| 68 |
+
if score_1 == score_2:
|
| 69 |
+
return None
|
| 70 |
+
return score_1 > score_2
|
| 71 |
+
elif criterion == 'medchem.sa':
|
| 72 |
+
if np.abs(score_1 - score_2) < 0.5:
|
| 73 |
+
return None
|
| 74 |
+
return score_1 < score_2
|
| 75 |
+
elif criterion == 'medchem.qed':
|
| 76 |
+
if np.abs(score_1 - score_2) < 0.1:
|
| 77 |
+
return None
|
| 78 |
+
return score_1 > score_2
|
| 79 |
+
elif criterion == 'gnina.vina_efficiency':
|
| 80 |
+
if np.abs(score_1 - score_2) < 0.1:
|
| 81 |
+
return None
|
| 82 |
+
return score_1 < score_2
|
| 83 |
+
elif criterion == 'combined':
|
| 84 |
+
score_reos_1, score_reos_2 = score_1['reos.all'], score_2['reos.all']
|
| 85 |
+
score_sa_1, score_sa_2 = score_1['medchem.sa'], score_2['medchem.sa']
|
| 86 |
+
score_qed_1, score_qed_2 = score_1['medchem.qed'], score_2['medchem.qed']
|
| 87 |
+
score_vina_1, score_vina_2 = score_1['gnina.vina_efficiency'], score_2['gnina.vina_efficiency']
|
| 88 |
+
if score_reos_1 == score_reos_2: return None
|
| 89 |
+
# checking consistency
|
| 90 |
+
reos_sign = score_reos_1 > score_reos_2
|
| 91 |
+
sa_sign = score_sa_1 < score_sa_2
|
| 92 |
+
qed_sign = score_qed_1 > score_qed_2
|
| 93 |
+
vina_sign = score_vina_1 < score_vina_2
|
| 94 |
+
signs = [reos_sign, sa_sign, qed_sign, vina_sign]
|
| 95 |
+
if all(signs) or not any(signs): return signs[0]
|
| 96 |
+
return None
|
| 97 |
+
|
| 98 |
+
def compute_scores(sample_dirs, evaluator, criterion, n_pairs=5, toy=False, toy_size=100,
|
| 99 |
+
precomp_scores=None, ignore_missing_scores=False):
|
| 100 |
+
samples = []
|
| 101 |
+
pose_evaluator = PoseBustersEvaluator()
|
| 102 |
+
pbar = tqdm(sample_dirs, desc='Computing scores for samples')
|
| 103 |
+
|
| 104 |
+
for dir in pbar:
|
| 105 |
+
pocket = dir / '0_pocket.pdb'
|
| 106 |
+
ligands = list(dir.glob('*_ligand.sdf'))
|
| 107 |
+
|
| 108 |
+
target_samples = []
|
| 109 |
+
for lig_path in ligands:
|
| 110 |
+
try:
|
| 111 |
+
mol = Chem.SDMolSupplier(str(lig_path))[0]
|
| 112 |
+
if mol is None:
|
| 113 |
+
continue
|
| 114 |
+
smiles = rdmol_to_smiles(mol)
|
| 115 |
+
except Exception as e:
|
| 116 |
+
print('Failed to read ligand:', lig_path)
|
| 117 |
+
continue
|
| 118 |
+
|
| 119 |
+
if precomp_scores is not None and str(lig_path) in precomp_scores.index:
|
| 120 |
+
mol_props = precomp_scores.loc[str(lig_path)].to_dict()
|
| 121 |
+
if criterion == 'combined':
|
| 122 |
+
if not 'reos.all' in mol_props or not 'medchem.sa' in mol_props or not 'medchem.qed' in mol_props or not 'gnina.vina_efficiency' in mol_props:
|
| 123 |
+
print(f'Missing combined scores for ligand:', lig_path)
|
| 124 |
+
continue
|
| 125 |
+
mol_props['combined'] = {
|
| 126 |
+
'reos.all': mol_props['reos.all'],
|
| 127 |
+
'medchem.sa': mol_props['medchem.sa'],
|
| 128 |
+
'medchem.qed': mol_props['medchem.qed'],
|
| 129 |
+
'gnina.vina_efficiency': mol_props['gnina.vina_efficiency'],
|
| 130 |
+
'combined': mol_props['gnina.vina_efficiency']
|
| 131 |
+
}
|
| 132 |
+
else:
|
| 133 |
+
mol_props = {}
|
| 134 |
+
if criterion not in mol_props:
|
| 135 |
+
if ignore_missing_scores:
|
| 136 |
+
print(f'Missing {criterion} for ligand:', lig_path)
|
| 137 |
+
continue
|
| 138 |
+
print(f'Recomputing {criterion} for ligand:', lig_path)
|
| 139 |
+
try:
|
| 140 |
+
eval_res = evaluator.evaluate(mol)
|
| 141 |
+
criterion_cat = criterion.split('.')[0]
|
| 142 |
+
eval_res = {f'{criterion_cat}.{k}': v for k, v in eval_res.items()}
|
| 143 |
+
score = eval_res[criterion]
|
| 144 |
+
except:
|
| 145 |
+
continue
|
| 146 |
+
else:
|
| 147 |
+
score = mol_props[criterion]
|
| 148 |
+
|
| 149 |
+
if 'posebusters.all' not in mol_props:
|
| 150 |
+
if ignore_missing_scores:
|
| 151 |
+
print('Missing PoseBusters for ligand:', lig_path)
|
| 152 |
+
continue
|
| 153 |
+
print('Recomputing PoseBusters for ligand:', lig_path)
|
| 154 |
+
try:
|
| 155 |
+
pose_eval_res = pose_evaluator.evaluate(lig_path, pocket)
|
| 156 |
+
except:
|
| 157 |
+
continue
|
| 158 |
+
if 'all' not in pose_eval_res or not pose_eval_res['all']:
|
| 159 |
+
continue
|
| 160 |
+
else:
|
| 161 |
+
pose_eval_res = mol_props['posebusters.all']
|
| 162 |
+
if not pose_eval_res:
|
| 163 |
+
continue
|
| 164 |
+
|
| 165 |
+
target_samples.append({
|
| 166 |
+
'smiles': smiles,
|
| 167 |
+
'score': score,
|
| 168 |
+
'ligand_path': lig_path,
|
| 169 |
+
'pocket_path': pocket
|
| 170 |
+
})
|
| 171 |
+
|
| 172 |
+
# Deduplicate by SMILES
|
| 173 |
+
unique_samples = {}
|
| 174 |
+
for sample in target_samples:
|
| 175 |
+
if sample['smiles'] not in unique_samples:
|
| 176 |
+
unique_samples[sample['smiles']] = sample
|
| 177 |
+
unique_samples = list(unique_samples.values())
|
| 178 |
+
if len(unique_samples) < 2:
|
| 179 |
+
continue
|
| 180 |
+
|
| 181 |
+
# Generate all possible pairs
|
| 182 |
+
all_pairs = list(combinations(unique_samples, 2))
|
| 183 |
+
|
| 184 |
+
# Calculate score differences and filter valid pairs
|
| 185 |
+
valid_pairs = []
|
| 186 |
+
for s1, s2 in all_pairs:
|
| 187 |
+
sign = return_winning_losing_smpl(s1['score'], s2['score'], criterion)
|
| 188 |
+
if sign is None:
|
| 189 |
+
continue
|
| 190 |
+
score_diff = abs(s1['score'] - s2['score']) if not criterion == 'combined' else \
|
| 191 |
+
abs(s1['score']['combined'] - s2['score']['combined'])
|
| 192 |
+
if sign:
|
| 193 |
+
valid_pairs.append((s1, s2, score_diff))
|
| 194 |
+
elif sign is False:
|
| 195 |
+
valid_pairs.append((s2, s1, score_diff))
|
| 196 |
+
|
| 197 |
+
# Sort pairs by score difference (descending) and select top N pairs
|
| 198 |
+
valid_pairs.sort(key=lambda x: x[2], reverse=True)
|
| 199 |
+
used_ligand_paths = set()
|
| 200 |
+
selected_pairs = []
|
| 201 |
+
for winning, losing, score_diff in valid_pairs:
|
| 202 |
+
if winning['ligand_path'] in used_ligand_paths or losing['ligand_path'] in used_ligand_paths:
|
| 203 |
+
continue
|
| 204 |
+
|
| 205 |
+
selected_pairs.append((winning, losing, score_diff))
|
| 206 |
+
used_ligand_paths.add(winning['ligand_path'])
|
| 207 |
+
used_ligand_paths.add(losing['ligand_path'])
|
| 208 |
+
|
| 209 |
+
if len(selected_pairs) == n_pairs:
|
| 210 |
+
break
|
| 211 |
+
for winning, losing, _ in selected_pairs:
|
| 212 |
+
d = {
|
| 213 |
+
'score_w': winning['score'],
|
| 214 |
+
'score_l': losing['score'],
|
| 215 |
+
'pocket_p': winning['pocket_path'],
|
| 216 |
+
'ligand_p_w': winning['ligand_path'],
|
| 217 |
+
'ligand_p_l': losing['ligand_path']
|
| 218 |
+
}
|
| 219 |
+
if isinstance(winning['score'], dict):
|
| 220 |
+
for k, v in winning['score'].items():
|
| 221 |
+
d[f'{k}_w'] = v
|
| 222 |
+
d['score_w'] = winning['score']['combined']
|
| 223 |
+
if isinstance(losing['score'], dict):
|
| 224 |
+
for k, v in losing['score'].items():
|
| 225 |
+
d[f'{k}_l'] = v
|
| 226 |
+
d['score_l'] = losing['score']['combined']
|
| 227 |
+
samples.append(d)
|
| 228 |
+
|
| 229 |
+
pbar.set_postfix({'samples': len(samples)})
|
| 230 |
+
|
| 231 |
+
if toy and len(samples) >= toy_size:
|
| 232 |
+
break
|
| 233 |
+
|
| 234 |
+
return samples
|
| 235 |
+
|
| 236 |
+
def main():
|
| 237 |
+
args = parse_args()
|
| 238 |
+
|
| 239 |
+
if 'reos' in args.dpo_criterion:
|
| 240 |
+
evaluator = REOSEvaluator()
|
| 241 |
+
elif 'medchem' in args.dpo_criterion:
|
| 242 |
+
evaluator = MedChemEvaluator()
|
| 243 |
+
elif 'gnina' in args.dpo_criterion:
|
| 244 |
+
evaluator = GninaEvalulator(gnina=args.gnina)
|
| 245 |
+
elif 'combined' in args.dpo_criterion:
|
| 246 |
+
evaluator = None # for combined criterion, metrics have to be computed separately
|
| 247 |
+
if args.metrics_detailed is None:
|
| 248 |
+
raise ValueError('For combined criterion, detailed metrics file has to be provided')
|
| 249 |
+
if not args.ignore_missing_scores:
|
| 250 |
+
raise ValueError('For combined criterion, --ignore-missing-scores flag has to be set')
|
| 251 |
+
else:
|
| 252 |
+
raise ValueError(f"Unknown DPO criterion: {args.dpo_criterion}")
|
| 253 |
+
|
| 254 |
+
# Make output directory
|
| 255 |
+
dirname = f"dpo_{args.dpo_criterion.replace('.','_')}_{args.pocket}"
|
| 256 |
+
if args.flex:
|
| 257 |
+
dirname += '_flex'
|
| 258 |
+
if args.normal_modes:
|
| 259 |
+
dirname += '_nma'
|
| 260 |
+
if args.toy:
|
| 261 |
+
dirname += '_toy'
|
| 262 |
+
processed_dir = Path(args.basedir, dirname)
|
| 263 |
+
processed_dir.mkdir(parents=True, exist_ok=True)
|
| 264 |
+
|
| 265 |
+
if (processed_dir / f'samples_{args.dpo_criterion}.csv').exists():
|
| 266 |
+
print(f"Samples already computed for criterion {args.dpo_criterion}, loading from file")
|
| 267 |
+
samples = pd.read_csv(processed_dir / f'samples_{args.dpo_criterion}.csv')
|
| 268 |
+
samples = [dict(row) for _, row in samples.iterrows()]
|
| 269 |
+
print(f"Found {len(samples)} winning/losing samples")
|
| 270 |
+
else:
|
| 271 |
+
print('Scanning sample directory...')
|
| 272 |
+
samples_dir = Path(args.smplsdir)
|
| 273 |
+
# scan dir
|
| 274 |
+
sample_dirs = scan_smpl_dir(samples_dir)
|
| 275 |
+
if args.metrics_detailed:
|
| 276 |
+
print(f'Loading precomputed scores from {args.metrics_detailed}')
|
| 277 |
+
precomp_scores = pd.read_csv(args.metrics_detailed)
|
| 278 |
+
precomp_scores = precomp_scores.set_index('sdf_file')
|
| 279 |
+
else:
|
| 280 |
+
precomp_scores = None
|
| 281 |
+
print(f'Found {len(sample_dirs)} valid sample directories')
|
| 282 |
+
print('Computing scores...')
|
| 283 |
+
samples = compute_scores(sample_dirs, evaluator, args.dpo_criterion,
|
| 284 |
+
n_pairs=args.n_pairs, toy=args.toy, toy_size=args.toy_size,
|
| 285 |
+
precomp_scores=precomp_scores,
|
| 286 |
+
ignore_missing_scores=args.ignore_missing_scores)
|
| 287 |
+
print(f'Found {len(samples)} winning/losing samples, saving to file')
|
| 288 |
+
pd.DataFrame(samples).to_csv(Path(processed_dir, f'samples_{args.dpo_criterion}.csv'), index=False)
|
| 289 |
+
|
| 290 |
+
data_split = {}
|
| 291 |
+
data_split['train'] = samples
|
| 292 |
+
if args.toy:
|
| 293 |
+
data_split['train'] = random.sample(samples, min(args.toy_size, len(data_split['train'])))
|
| 294 |
+
|
| 295 |
+
failed = {}
|
| 296 |
+
train_smiles = []
|
| 297 |
+
|
| 298 |
+
for split in data_split.keys():
|
| 299 |
+
|
| 300 |
+
print(f"Processing {split} dataset...")
|
| 301 |
+
|
| 302 |
+
ligands_w = defaultdict(list)
|
| 303 |
+
ligands_l = defaultdict(list)
|
| 304 |
+
pockets = defaultdict(list)
|
| 305 |
+
|
| 306 |
+
tic = time()
|
| 307 |
+
pbar = tqdm(data_split[split])
|
| 308 |
+
for entry in pbar:
|
| 309 |
+
|
| 310 |
+
pbar.set_description(f'#failed: {len(failed)}')
|
| 311 |
+
|
| 312 |
+
pdbfile = Path(entry['pocket_p'])
|
| 313 |
+
entry['ligand_p_w'] = Path(entry['ligand_p_w'])
|
| 314 |
+
entry['ligand_p_l'] = Path(entry['ligand_p_l'])
|
| 315 |
+
entry['ligand_w'] = Chem.SDMolSupplier(str(entry['ligand_p_w']))[0]
|
| 316 |
+
entry['ligand_l'] = Chem.SDMolSupplier(str(entry['ligand_p_l']))[0]
|
| 317 |
+
|
| 318 |
+
try:
|
| 319 |
+
pdb_model = PDBParser(QUIET=True).get_structure('', pdbfile)[0]
|
| 320 |
+
|
| 321 |
+
ligand_w, pocket = process_raw_pair(
|
| 322 |
+
pdb_model, entry['ligand_w'], pocket_representation=args.pocket,
|
| 323 |
+
compute_nerf_params=args.flex, compute_bb_frames=args.flex,
|
| 324 |
+
nma_input=pdbfile if args.normal_modes else None)
|
| 325 |
+
ligand_l, _ = process_raw_pair(
|
| 326 |
+
pdb_model, entry['ligand_l'], pocket_representation=args.pocket,
|
| 327 |
+
compute_nerf_params=args.flex, compute_bb_frames=args.flex,
|
| 328 |
+
nma_input=pdbfile if args.normal_modes else None)
|
| 329 |
+
|
| 330 |
+
except (KeyError, AssertionError, FileNotFoundError, IndexError,
|
| 331 |
+
ValueError, AttributeError) as e:
|
| 332 |
+
failed[(split, entry['ligand_p_w'], entry['ligand_p_l'], pdbfile)] \
|
| 333 |
+
= (type(e).__name__, str(e))
|
| 334 |
+
continue
|
| 335 |
+
|
| 336 |
+
nerf_keys = ['fixed_coord', 'atom_mask', 'nerf_indices', 'length', 'theta', 'chi', 'ddihedral', 'chi_indices']
|
| 337 |
+
for k in ['x', 'one_hot', 'bonds', 'bond_one_hot', 'v', 'nma_vec'] + nerf_keys + ['axis_angle']:
|
| 338 |
+
if k in ligand_w:
|
| 339 |
+
ligands_w[k].append(ligand_w[k])
|
| 340 |
+
ligands_l[k].append(ligand_l[k])
|
| 341 |
+
if k in pocket:
|
| 342 |
+
pockets[k].append(pocket[k])
|
| 343 |
+
|
| 344 |
+
smpl_n = pdbfile.parent.name
|
| 345 |
+
pocket_file = f'{smpl_n}__{pdbfile.stem}.pdb'
|
| 346 |
+
ligand_file_w = f'{smpl_n}__{entry["ligand_p_w"].stem}.sdf'
|
| 347 |
+
ligand_file_l = f'{smpl_n}__{entry["ligand_p_l"].stem}.sdf'
|
| 348 |
+
ligands_w['name'].append(ligand_file_w)
|
| 349 |
+
ligands_l['name'].append(ligand_file_l)
|
| 350 |
+
pockets['name'].append(pocket_file)
|
| 351 |
+
train_smiles.append(rdmol_to_smiles(entry['ligand_w']))
|
| 352 |
+
train_smiles.append(rdmol_to_smiles(entry['ligand_l']))
|
| 353 |
+
|
| 354 |
+
data = {'ligands_w': ligands_w,
|
| 355 |
+
'ligands_l': ligands_l,
|
| 356 |
+
'pockets': pockets}
|
| 357 |
+
torch.save(data, Path(processed_dir, f'{split}.pt'))
|
| 358 |
+
|
| 359 |
+
if split == 'train':
|
| 360 |
+
np.save(Path(processed_dir, 'train_smiles.npy'), train_smiles)
|
| 361 |
+
|
| 362 |
+
print(f"Processing {split} set took {(time() - tic) / 60.0:.2f} minutes")
|
| 363 |
+
|
| 364 |
+
# cp stats from original dataset
|
| 365 |
+
size_distr_p = Path(args.datadir, 'size_distribution.npy')
|
| 366 |
+
type_histo_p = Path(args.datadir, 'ligand_type_histogram.npy')
|
| 367 |
+
bond_histo_p = Path(args.datadir, 'ligand_bond_type_histogram.npy')
|
| 368 |
+
metadata_p = Path(args.datadir, 'metadata.yml')
|
| 369 |
+
shutil.copy(size_distr_p, processed_dir)
|
| 370 |
+
shutil.copy(type_histo_p, processed_dir)
|
| 371 |
+
shutil.copy(bond_histo_p, processed_dir)
|
| 372 |
+
shutil.copy(metadata_p, processed_dir)
|
| 373 |
+
|
| 374 |
+
# cp val and test .pt and dirs
|
| 375 |
+
val_dir = Path(args.datadir, 'val')
|
| 376 |
+
test_dir = Path(args.datadir, 'test')
|
| 377 |
+
val_pt = Path(args.datadir, 'val.pt')
|
| 378 |
+
test_pt = Path(args.datadir, 'test.pt')
|
| 379 |
+
assert val_dir.exists() and test_dir.exists() and val_pt.exists() and test_pt.exists()
|
| 380 |
+
if (processed_dir / 'val').exists():
|
| 381 |
+
shutil.rmtree(processed_dir / 'val')
|
| 382 |
+
if (processed_dir / 'test').exists():
|
| 383 |
+
shutil.rmtree(processed_dir / 'test')
|
| 384 |
+
shutil.copytree(val_dir, processed_dir / 'val')
|
| 385 |
+
shutil.copytree(test_dir, processed_dir / 'test')
|
| 386 |
+
shutil.copy(val_pt, processed_dir)
|
| 387 |
+
shutil.copy(test_pt, processed_dir)
|
| 388 |
+
|
| 389 |
+
# Write error report
|
| 390 |
+
error_str = ""
|
| 391 |
+
for k, v in failed.items():
|
| 392 |
+
error_str += f"{'Split':<15}: {k[0]}\n"
|
| 393 |
+
error_str += f"{'Ligand W':<15}: {k[1]}\n"
|
| 394 |
+
error_str += f"{'Ligand L':<15}: {k[2]}\n"
|
| 395 |
+
error_str += f"{'Pocket':<15}: {k[3]}\n"
|
| 396 |
+
error_str += f"{'Error type':<15}: {v[0]}\n"
|
| 397 |
+
error_str += f"{'Error msg':<15}: {v[1]}\n\n"
|
| 398 |
+
|
| 399 |
+
with open(Path(processed_dir, 'errors.txt'), 'w') as f:
|
| 400 |
+
f.write(error_str)
|
| 401 |
+
|
| 402 |
+
with open(Path(processed_dir, 'dataset_config.txt'), 'w') as f:
|
| 403 |
+
f.write(str(args))
|
| 404 |
+
|
| 405 |
+
if __name__ == '__main__':
|
| 406 |
+
main()
|
src/data/sanifix.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" sanifix4.py
|
| 2 |
+
|
| 3 |
+
Contribution from James Davidson
|
| 4 |
+
adapted from: https://github.com/abradle/rdkitserver/blob/master/MYSITE/src/testproject/mol_parsing/sanifix.py
|
| 5 |
+
"""
|
| 6 |
+
from rdkit import Chem
|
| 7 |
+
from rdkit.Chem import AllChem
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
def _FragIndicesToMol(oMol,indices):
|
| 11 |
+
em = Chem.EditableMol(Chem.Mol())
|
| 12 |
+
|
| 13 |
+
newIndices={}
|
| 14 |
+
for i,idx in enumerate(indices):
|
| 15 |
+
em.AddAtom(oMol.GetAtomWithIdx(idx))
|
| 16 |
+
newIndices[idx]=i
|
| 17 |
+
|
| 18 |
+
for i,idx in enumerate(indices):
|
| 19 |
+
at = oMol.GetAtomWithIdx(idx)
|
| 20 |
+
for bond in at.GetBonds():
|
| 21 |
+
if bond.GetBeginAtomIdx()==idx:
|
| 22 |
+
oidx = bond.GetEndAtomIdx()
|
| 23 |
+
else:
|
| 24 |
+
oidx = bond.GetBeginAtomIdx()
|
| 25 |
+
# make sure every bond only gets added once:
|
| 26 |
+
if oidx<idx:
|
| 27 |
+
continue
|
| 28 |
+
em.AddBond(newIndices[idx],newIndices[oidx],bond.GetBondType())
|
| 29 |
+
res = em.GetMol()
|
| 30 |
+
res.ClearComputedProps()
|
| 31 |
+
Chem.GetSymmSSSR(res)
|
| 32 |
+
res.UpdatePropertyCache(False)
|
| 33 |
+
res._idxMap=newIndices
|
| 34 |
+
return res
|
| 35 |
+
|
| 36 |
+
def _recursivelyModifyNs(mol,matches,indices=None):
|
| 37 |
+
if indices is None:
|
| 38 |
+
indices=[]
|
| 39 |
+
res=None
|
| 40 |
+
while len(matches) and res is None:
|
| 41 |
+
tIndices=indices[:]
|
| 42 |
+
nextIdx = matches.pop(0)
|
| 43 |
+
tIndices.append(nextIdx)
|
| 44 |
+
nm = Chem.Mol(mol)
|
| 45 |
+
nm.GetAtomWithIdx(nextIdx).SetNoImplicit(True)
|
| 46 |
+
nm.GetAtomWithIdx(nextIdx).SetNumExplicitHs(1)
|
| 47 |
+
cp = Chem.Mol(nm)
|
| 48 |
+
try:
|
| 49 |
+
Chem.SanitizeMol(cp)
|
| 50 |
+
except ValueError:
|
| 51 |
+
res,indices = _recursivelyModifyNs(nm,matches,indices=tIndices)
|
| 52 |
+
else:
|
| 53 |
+
indices=tIndices
|
| 54 |
+
res=cp
|
| 55 |
+
return res,indices
|
| 56 |
+
|
| 57 |
+
def AdjustAromaticNs(m,nitrogenPattern='[n&D2&H0;r5,r6]'):
|
| 58 |
+
"""
|
| 59 |
+
default nitrogen pattern matches Ns in 5 rings and 6 rings in order to be able
|
| 60 |
+
to fix: O=c1ccncc1
|
| 61 |
+
"""
|
| 62 |
+
Chem.GetSymmSSSR(m)
|
| 63 |
+
m.UpdatePropertyCache(False)
|
| 64 |
+
|
| 65 |
+
# break non-ring bonds linking rings:
|
| 66 |
+
em = Chem.EditableMol(m)
|
| 67 |
+
linkers = m.GetSubstructMatches(Chem.MolFromSmarts('[r]!@[r]'))
|
| 68 |
+
plsFix=set()
|
| 69 |
+
for a,b in linkers:
|
| 70 |
+
em.RemoveBond(a,b)
|
| 71 |
+
plsFix.add(a)
|
| 72 |
+
plsFix.add(b)
|
| 73 |
+
nm = em.GetMol()
|
| 74 |
+
for at in plsFix:
|
| 75 |
+
at=nm.GetAtomWithIdx(at)
|
| 76 |
+
if at.GetIsAromatic() and at.GetAtomicNum()==7:
|
| 77 |
+
at.SetNumExplicitHs(1)
|
| 78 |
+
at.SetNoImplicit(True)
|
| 79 |
+
|
| 80 |
+
# build molecules from the fragments:
|
| 81 |
+
fragLists = Chem.GetMolFrags(nm)
|
| 82 |
+
frags = [_FragIndicesToMol(nm,x) for x in fragLists]
|
| 83 |
+
|
| 84 |
+
# loop through the fragments in turn and try to aromatize them:
|
| 85 |
+
ok=True
|
| 86 |
+
for i,frag in enumerate(frags):
|
| 87 |
+
cp = Chem.Mol(frag)
|
| 88 |
+
try:
|
| 89 |
+
Chem.SanitizeMol(cp)
|
| 90 |
+
except ValueError:
|
| 91 |
+
matches = [x[0] for x in frag.GetSubstructMatches(Chem.MolFromSmarts(nitrogenPattern))]
|
| 92 |
+
lres,indices=_recursivelyModifyNs(frag,matches)
|
| 93 |
+
if not lres:
|
| 94 |
+
#print 'frag %d failed (%s)'%(i,str(fragLists[i]))
|
| 95 |
+
ok=False
|
| 96 |
+
break
|
| 97 |
+
else:
|
| 98 |
+
revMap={}
|
| 99 |
+
for k,v in frag._idxMap.items():
|
| 100 |
+
revMap[v]=k
|
| 101 |
+
for idx in indices:
|
| 102 |
+
oatom = m.GetAtomWithIdx(revMap[idx])
|
| 103 |
+
oatom.SetNoImplicit(True)
|
| 104 |
+
oatom.SetNumExplicitHs(1)
|
| 105 |
+
if not ok:
|
| 106 |
+
return None
|
| 107 |
+
return m
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def fix_mol(m):
|
| 112 |
+
if m is None:
|
| 113 |
+
return None
|
| 114 |
+
try:
|
| 115 |
+
m.UpdatePropertyCache(False)
|
| 116 |
+
cp = Chem.Mol(m.ToBinary())
|
| 117 |
+
Chem.SanitizeMol(cp)
|
| 118 |
+
m = cp
|
| 119 |
+
# print('fine:',Chem.MolToSmiles(m))
|
| 120 |
+
warnings.warn(f'fine: {Chem.MolToSmiles(m)}')
|
| 121 |
+
return m
|
| 122 |
+
except ValueError:
|
| 123 |
+
# print('adjust')
|
| 124 |
+
warnings.warn('adjust')
|
| 125 |
+
nm=AdjustAromaticNs(m)
|
| 126 |
+
if nm is not None:
|
| 127 |
+
try:
|
| 128 |
+
Chem.SanitizeMol(nm)
|
| 129 |
+
# print('fixed:',Chem.MolToSmiles(nm))
|
| 130 |
+
warnings.warn(f'fixed: {Chem.MolToSmiles(nm)}')
|
| 131 |
+
except ValueError:
|
| 132 |
+
# print('still broken')
|
| 133 |
+
warnings.warn('still broken')
|
| 134 |
+
else:
|
| 135 |
+
# print('still broken')
|
| 136 |
+
warnings.warn('still broken')
|
| 137 |
+
return nm
|
| 138 |
+
|
| 139 |
+
if __name__=='__main__':
|
| 140 |
+
ms = [x for x in open("18.sdf").read().split("$$$$\n")]
|
| 141 |
+
for txt_m in ms:
|
| 142 |
+
if not txt_m:
|
| 143 |
+
continue
|
| 144 |
+
m = Chem.MolFromMolBlock(txt_m, False)
|
| 145 |
+
print('#---------------------')
|
| 146 |
+
try:
|
| 147 |
+
m.UpdatePropertyCache(False)
|
| 148 |
+
cp = Chem.Mol(m.ToBinary())
|
| 149 |
+
Chem.SanitizeMol(cp)
|
| 150 |
+
m = cp
|
| 151 |
+
print('fine:',Chem.MolToSmiles(m))
|
| 152 |
+
except ValueError:
|
| 153 |
+
print('adjust')
|
| 154 |
+
nm=AdjustAromaticNs(m)
|
| 155 |
+
if nm is not None:
|
| 156 |
+
Chem.SanitizeMol(nm)
|
| 157 |
+
print('fixed:',Chem.MolToSmiles(nm))
|
| 158 |
+
else:
|
| 159 |
+
print('still broken')
|
src/data/so3_utils.py
ADDED
|
@@ -0,0 +1,450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def _batch_trace(m):
|
| 6 |
+
return torch.einsum('...ii', m)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def regularize(point, eps=1e-6):
|
| 10 |
+
"""
|
| 11 |
+
Norm of the rotation vector should be between 0 and pi.
|
| 12 |
+
Inverts the direction of the rotation axis if the value is between pi and 2 pi.
|
| 13 |
+
Args:
|
| 14 |
+
point, (n, 3)
|
| 15 |
+
Returns:
|
| 16 |
+
regularized point, (n, 3)
|
| 17 |
+
"""
|
| 18 |
+
theta = torch.linalg.norm(point, axis=-1)
|
| 19 |
+
|
| 20 |
+
# angle in [0, 2pi)
|
| 21 |
+
theta_wrapped = theta % (2 * math.pi)
|
| 22 |
+
inv_mask = theta_wrapped > math.pi
|
| 23 |
+
|
| 24 |
+
# angle in [0, pi) & invert
|
| 25 |
+
theta_wrapped[inv_mask] = -1 * (2 * math.pi - theta_wrapped[inv_mask])
|
| 26 |
+
|
| 27 |
+
# apply
|
| 28 |
+
theta = torch.clamp(theta, min=eps)
|
| 29 |
+
point = point * (theta_wrapped / theta).unsqueeze(-1)
|
| 30 |
+
assert not point.isnan().any()
|
| 31 |
+
return point
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def random_uniform(n_samples, device=None):
|
| 35 |
+
"""
|
| 36 |
+
Follow geomstats implementation:
|
| 37 |
+
https://geomstats.github.io/_modules/geomstats/geometry/special_orthogonal.html
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
n_samples: int
|
| 41 |
+
Returns:
|
| 42 |
+
rotation vectors, (n, 3)
|
| 43 |
+
"""
|
| 44 |
+
random_point = (torch.rand(n_samples, 3, device=device) * 2 - 1) * math.pi
|
| 45 |
+
random_point = regularize(random_point)
|
| 46 |
+
|
| 47 |
+
return random_point
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def hat(rot_vec):
|
| 51 |
+
"""
|
| 52 |
+
Maps R^3 vector to a skew-symmetric matrix r (i.e. r \in R^{3x3} and r^T = -r).
|
| 53 |
+
Since we have the identity rv = rot_vec x v for all v \in R^3, this is
|
| 54 |
+
identical to a cross-product-matrix representation of rot_vec.
|
| 55 |
+
rot_vec x v = hat(rot_vec)^T v
|
| 56 |
+
See also:
|
| 57 |
+
https://en.wikipedia.org/wiki/Cross_product#Conversion_to_matrix_multiplication
|
| 58 |
+
https://en.wikipedia.org/wiki/Hat_notation#Cross_product
|
| 59 |
+
Args:
|
| 60 |
+
rot_vec: (n, 3)
|
| 61 |
+
Returns:
|
| 62 |
+
skew-symmetric matrices (n, 3, 3)
|
| 63 |
+
"""
|
| 64 |
+
basis = torch.tensor([
|
| 65 |
+
[[0., 0., 0.], [0., 0., -1.], [0., 1., 0.]],
|
| 66 |
+
[[0., 0., 1.], [0., 0., 0.], [-1., 0., 0.]],
|
| 67 |
+
[[0., -1., 0.], [1., 0., 0.], [0., 0., 0.]]
|
| 68 |
+
], device=rot_vec.device)
|
| 69 |
+
# basis = torch.tensor([
|
| 70 |
+
# [[0., 0., 0.], [0., 0., 1.], [0., -1., 0.]],
|
| 71 |
+
# [[0., 0., -1.], [0., 0., 0.], [1., 0., 0.]],
|
| 72 |
+
# [[0., 1., 0.], [-1., 0., 0.], [0., 0., 0.]]
|
| 73 |
+
# ], device=rot_vec.device)
|
| 74 |
+
|
| 75 |
+
return torch.einsum('...i,ijk->...jk', rot_vec, basis)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def inv_hat(skew_mat):
|
| 79 |
+
"""
|
| 80 |
+
Inverse of hat operation
|
| 81 |
+
Args:
|
| 82 |
+
skew_mat: skew-symmetric matrices (n, 3, 3)
|
| 83 |
+
Returns:
|
| 84 |
+
rotation vectors, (n, 3)
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
assert torch.allclose(-skew_mat, skew_mat.transpose(-2, -1), atol=1e-4), \
|
| 88 |
+
f"Input not skew-symmetric (err={(-skew_mat - skew_mat.transpose(-2, -1)).abs().max():.4g})"
|
| 89 |
+
|
| 90 |
+
# vec = torch.stack([
|
| 91 |
+
# skew_mat[:, 1, 2],
|
| 92 |
+
# skew_mat[:, 2, 1],
|
| 93 |
+
# skew_mat[:, 0, 1]
|
| 94 |
+
# ], dim=1)
|
| 95 |
+
|
| 96 |
+
vec = torch.stack([
|
| 97 |
+
skew_mat[:, 2, 1],
|
| 98 |
+
skew_mat[:, 0, 2],
|
| 99 |
+
skew_mat[:, 1, 0]
|
| 100 |
+
], dim=1)
|
| 101 |
+
|
| 102 |
+
return vec
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def matrix_from_rotation_vector(axis_angle, eps=1e-6):
|
| 106 |
+
"""
|
| 107 |
+
Args:
|
| 108 |
+
axis_angle: (n, 3)
|
| 109 |
+
Returns:
|
| 110 |
+
rotation matrices, (n, 3, 3)
|
| 111 |
+
"""
|
| 112 |
+
|
| 113 |
+
axis_angle = regularize(axis_angle)
|
| 114 |
+
angle = axis_angle.norm(dim=-1)
|
| 115 |
+
_norm = torch.clamp(angle, min=eps).unsqueeze(-1)
|
| 116 |
+
skew_mat = hat(axis_angle / _norm)
|
| 117 |
+
|
| 118 |
+
# https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula#Matrix_notation
|
| 119 |
+
_id = torch.eye(3, device=axis_angle.device).unsqueeze(0)
|
| 120 |
+
rot_mat = _id + \
|
| 121 |
+
torch.sin(angle)[:, None, None] * skew_mat + \
|
| 122 |
+
(1 - torch.cos(angle))[:, None, None] * torch.bmm(skew_mat, skew_mat)
|
| 123 |
+
|
| 124 |
+
return rot_mat
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class safe_acos(torch.autograd.Function):
|
| 128 |
+
"""
|
| 129 |
+
Implementation of arccos that avoids NaN in backward pass.
|
| 130 |
+
https://github.com/pytorch/pytorch/issues/8069#issuecomment-2041223872
|
| 131 |
+
"""
|
| 132 |
+
EPS = 1e-4
|
| 133 |
+
@classmethod
|
| 134 |
+
def d_acos_dx(cls, x):
|
| 135 |
+
x = torch.clamp(x, min=-1. + cls.EPS, max=1. - cls.EPS)
|
| 136 |
+
return -1.0 / (1 - x**2).sqrt()
|
| 137 |
+
|
| 138 |
+
@staticmethod
|
| 139 |
+
def forward(ctx, input):
|
| 140 |
+
ctx.save_for_backward(input)
|
| 141 |
+
return input.acos()
|
| 142 |
+
|
| 143 |
+
@staticmethod
|
| 144 |
+
def backward(ctx, grad_output):
|
| 145 |
+
input, = ctx.saved_tensors
|
| 146 |
+
return grad_output * safe_acos.d_acos_dx(input)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def rotation_vector_from_matrix(rot_mat, approx=1e-4):
|
| 150 |
+
"""
|
| 151 |
+
Args:
|
| 152 |
+
rot_mat: (n, 3, 3)
|
| 153 |
+
approx: float, minimum angle below which an approximation will be used
|
| 154 |
+
for numerical stability
|
| 155 |
+
Returns:
|
| 156 |
+
rotation vector, (n, 3)
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
+
# https://en.wikipedia.org/wiki/Rotation_matrix#Conversion_from_rotation_matrix_to_axis%E2%80%93angle
|
| 160 |
+
# https://en.wikipedia.org/wiki/Axis%E2%80%93angle_representation#Log_map_from_SO(3)_to_%F0%9D%94%B0%F0%9D%94%AC(3)
|
| 161 |
+
|
| 162 |
+
# determine axis
|
| 163 |
+
skew_mat = rot_mat - rot_mat.transpose(-2, -1)
|
| 164 |
+
|
| 165 |
+
# determine the angle
|
| 166 |
+
cos_angle = 0.5 * (_batch_trace(rot_mat) - 1)
|
| 167 |
+
# arccos is only defined between -1 and 1
|
| 168 |
+
assert torch.all(cos_angle.abs() <= 1 + 1e-6)
|
| 169 |
+
cos_angle = torch.clamp(cos_angle, min=-1., max=1.)
|
| 170 |
+
# abs_angle = torch.arccos(cos_angle)
|
| 171 |
+
abs_angle = safe_acos.apply(cos_angle)
|
| 172 |
+
|
| 173 |
+
# avoid numerical instability; use sin(x) \approx x for small x
|
| 174 |
+
close_to_0 = abs_angle < approx
|
| 175 |
+
_fac = torch.empty_like(abs_angle)
|
| 176 |
+
_fac[close_to_0] = 0.5
|
| 177 |
+
_fac[~close_to_0] = 0.5 * abs_angle[~close_to_0] / torch.sin(abs_angle[~close_to_0])
|
| 178 |
+
|
| 179 |
+
axis_angle = inv_hat(_fac[:, None, None] * skew_mat)
|
| 180 |
+
return regularize(axis_angle)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def get_jacobian(point, left=True, inverse=False, eps=1e-4):
|
| 184 |
+
|
| 185 |
+
# # From Geomstats: https://geomstats.github.io/_modules/geomstats/geometry/special_orthogonal.html
|
| 186 |
+
# jacobian = so3_vector.jacobian_translation(point, left)
|
| 187 |
+
#
|
| 188 |
+
# if inverse:
|
| 189 |
+
# jacobian = torch.linalg.inv(jacobian)
|
| 190 |
+
|
| 191 |
+
# Right Jacobian defined as J_r(theta) = \partial exp([theta]_x) / \partial theta
|
| 192 |
+
# https://math.stackexchange.com/questions/301533/jacobian-involving-so3-exponential-map-logr-expm
|
| 193 |
+
# Source:
|
| 194 |
+
# Chirikjian, Gregory S. Stochastic models, information theory, and Lie
|
| 195 |
+
# groups, volume 2: Analytic methods and modern applications. Vol. 2.
|
| 196 |
+
# Springer Science & Business Media, 2011. (page 40)
|
| 197 |
+
# NOTE: the definitions of 'inverse' and 'left' in the book are the opposite
|
| 198 |
+
# of their meanings in Geomstats, whose functionality we're mimicking here.
|
| 199 |
+
# This explains the differences in the equations.
|
| 200 |
+
angle_squared = point.square().sum(-1)
|
| 201 |
+
angle = angle_squared.sqrt()
|
| 202 |
+
skew_mat = hat(point)
|
| 203 |
+
|
| 204 |
+
assert torch.all(angle <= math.pi)
|
| 205 |
+
close_to_0 = angle < eps
|
| 206 |
+
close_to_pi = (math.pi - angle) < eps
|
| 207 |
+
|
| 208 |
+
angle = angle[:, None, None]
|
| 209 |
+
angle_squared = angle_squared[:, None, None]
|
| 210 |
+
|
| 211 |
+
if inverse:
|
| 212 |
+
# _jacobian = torch.eye(3, device=point.device).unsqueeze(0) + \
|
| 213 |
+
# (1 - torch.cos(angle)) / angle_squared * skew_mat + \
|
| 214 |
+
# (angle - torch.sin(angle)) / angle ** 3 * (skew_mat @ skew_mat)
|
| 215 |
+
|
| 216 |
+
_term1 = torch.empty_like(angle)
|
| 217 |
+
_term1[close_to_0] = 0.5 # approximate with value at zero
|
| 218 |
+
_term1[~close_to_0] = (1 - torch.cos(angle)) / angle_squared
|
| 219 |
+
|
| 220 |
+
_term2 = torch.empty_like(angle)
|
| 221 |
+
_term2[close_to_0] = 1 / 6 # approximate with value at zero
|
| 222 |
+
_term2[~close_to_0] = (angle - torch.sin(angle)) / angle ** 3
|
| 223 |
+
|
| 224 |
+
jacobian = torch.eye(3, device=point.device).unsqueeze(0) + \
|
| 225 |
+
_term1 * skew_mat + _term2 * (skew_mat @ skew_mat)
|
| 226 |
+
# assert torch.allclose(jacobian, _jacobian, atol=1e-4)
|
| 227 |
+
else:
|
| 228 |
+
# _jacobian = torch.eye(3, device=point.device).unsqueeze(0) - 0.5 * skew_mat + \
|
| 229 |
+
# (1 / angle_squared - (1 + torch.cos(angle)) / (2 * angle * torch.sin(angle))) * (skew_mat @ skew_mat)
|
| 230 |
+
|
| 231 |
+
_term1 = torch.empty_like(angle)
|
| 232 |
+
_term1[close_to_0] = 1 / 12 # approximate with value at zero
|
| 233 |
+
_term1[close_to_pi] = 1 / math.pi**2 # approximate with value at pi
|
| 234 |
+
default = ~close_to_0 & ~close_to_pi
|
| 235 |
+
_term1[default] = 1 / angle_squared[default] - \
|
| 236 |
+
(1 + torch.cos(angle[default])) / (2 * angle[default] * torch.sin(angle[default]))
|
| 237 |
+
|
| 238 |
+
jacobian = torch.eye(3, device=point.device).unsqueeze(0) - \
|
| 239 |
+
0.5 * skew_mat + _term1 * (skew_mat @ skew_mat)
|
| 240 |
+
# assert torch.allclose(jacobian, _jacobian, atol=1e-4)
|
| 241 |
+
|
| 242 |
+
if left:
|
| 243 |
+
jacobian = jacobian.transpose(-2, -1)
|
| 244 |
+
|
| 245 |
+
return jacobian
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def compose_rotations(rot_vec_1, rot_vec_2):
|
| 249 |
+
rot_mat_1 = matrix_from_rotation_vector(rot_vec_1)
|
| 250 |
+
rot_mat_2 = matrix_from_rotation_vector(rot_vec_2)
|
| 251 |
+
rot_mat_out = torch.bmm(rot_mat_1, rot_mat_2)
|
| 252 |
+
return rotation_vector_from_matrix(rot_mat_out)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def exp(tangent):
|
| 256 |
+
"""
|
| 257 |
+
Exponential map at identity.
|
| 258 |
+
Args:
|
| 259 |
+
tangent: vector on the tangent space, (n, 3)
|
| 260 |
+
Returns:
|
| 261 |
+
rotation vector on the manifold, (n, 3)
|
| 262 |
+
"""
|
| 263 |
+
# rotations are already represented by rotation vectors
|
| 264 |
+
exp_from_identity = regularize(tangent)
|
| 265 |
+
return exp_from_identity
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def exp_not_from_identity(tangent_vec, base_point):
|
| 269 |
+
"""
|
| 270 |
+
Exponential map at base point.
|
| 271 |
+
Args:
|
| 272 |
+
tangent_vec: vector on the tangent plane, (n, 3)
|
| 273 |
+
base_point: base point on the manifold, (n, 3)
|
| 274 |
+
Returns:
|
| 275 |
+
new point on the manifold, (n, 3)
|
| 276 |
+
"""
|
| 277 |
+
|
| 278 |
+
tangent_vec = regularize(tangent_vec)
|
| 279 |
+
base_point = regularize(base_point)
|
| 280 |
+
|
| 281 |
+
# Lie algebra is the tangent space at the identity element of a Lie group
|
| 282 |
+
# -> to identity
|
| 283 |
+
jacobian = get_jacobian(base_point, left=True, inverse=True)
|
| 284 |
+
tangent_vec_at_id = torch.einsum("...ij,...j->...i", jacobian, tangent_vec)
|
| 285 |
+
|
| 286 |
+
# exponential map from identity
|
| 287 |
+
exp_from_identity = exp(tangent_vec_at_id)
|
| 288 |
+
|
| 289 |
+
# -> back to base point
|
| 290 |
+
return compose_rotations(base_point, exp_from_identity)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def log(rot_vec, as_skew=False):
|
| 294 |
+
"""
|
| 295 |
+
Logarithm map from tangent space at the identity.
|
| 296 |
+
Args:
|
| 297 |
+
rot_vec: point on the manifold, (n, 3)
|
| 298 |
+
Returns:
|
| 299 |
+
vector on the tangent space, (n, 3)
|
| 300 |
+
"""
|
| 301 |
+
# rotations are already represented by rotation vectors
|
| 302 |
+
# log_from_id = regularize(rot_vec)
|
| 303 |
+
log_from_id = rot_vec
|
| 304 |
+
if as_skew:
|
| 305 |
+
log_from_id = hat(log_from_id)
|
| 306 |
+
return log_from_id
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def log_not_from_identity(point, base_point):
|
| 310 |
+
"""
|
| 311 |
+
Logarithm map of point from base point.
|
| 312 |
+
Args:
|
| 313 |
+
point: point on the manifold, (n, 3)
|
| 314 |
+
base_point: base point on the manifold, (n, 3)
|
| 315 |
+
Returns:
|
| 316 |
+
vector on the tangent plane, (n, 3)
|
| 317 |
+
"""
|
| 318 |
+
point = regularize(point)
|
| 319 |
+
base_point = regularize(base_point)
|
| 320 |
+
|
| 321 |
+
inv_base_point = -1 * base_point
|
| 322 |
+
|
| 323 |
+
point_near_id = compose_rotations(inv_base_point, point)
|
| 324 |
+
|
| 325 |
+
# logarithm map from identity
|
| 326 |
+
log_from_id = log(point_near_id)
|
| 327 |
+
|
| 328 |
+
jacobian = get_jacobian(base_point, inverse=False)
|
| 329 |
+
tangent_vec_at_id = torch.einsum("...ij,...j->...i", jacobian, log_from_id)
|
| 330 |
+
|
| 331 |
+
return tangent_vec_at_id
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
if __name__ == "__main__":
|
| 335 |
+
|
| 336 |
+
import os
|
| 337 |
+
os.environ['GEOMSTATS_BACKEND'] = "pytorch"
|
| 338 |
+
import scipy.optimize # does not seem to be imported correctly when just loading geomstats
|
| 339 |
+
default_dtype = torch.get_default_dtype()
|
| 340 |
+
from geomstats.geometry.special_orthogonal import SpecialOrthogonal
|
| 341 |
+
torch.set_default_dtype(default_dtype) # Geomstats changes default type when imported
|
| 342 |
+
|
| 343 |
+
so3_vector = SpecialOrthogonal(n=3, point_type="vector")
|
| 344 |
+
|
| 345 |
+
# decorator
|
| 346 |
+
if torch.__version__ >= '2.0.0':
|
| 347 |
+
GEOMSTATS_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 348 |
+
|
| 349 |
+
def geomstats_tensor_type(func):
|
| 350 |
+
def inner(*args, **kwargs):
|
| 351 |
+
with torch.device(GEOMSTATS_DEVICE):
|
| 352 |
+
out = func(*args, **kwargs)
|
| 353 |
+
return out
|
| 354 |
+
|
| 355 |
+
return inner
|
| 356 |
+
else:
|
| 357 |
+
GEOMSTATS_TENSOR_TYPE = 'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor'
|
| 358 |
+
|
| 359 |
+
# GEOMSTATS_TENSOR_TYPE = 'torch.cuda.DoubleTensor' if torch.cuda.is_available() else 'torch.DoubleTensor'
|
| 360 |
+
def geomstats_tensor_type(func):
|
| 361 |
+
def inner(*args, **kwargs):
|
| 362 |
+
# tensor_type_before = TODO
|
| 363 |
+
torch.set_default_tensor_type(GEOMSTATS_TENSOR_TYPE)
|
| 364 |
+
out = func(*args, **kwargs)
|
| 365 |
+
# torch.set_default_tensor_type(tensor_type_before)
|
| 366 |
+
torch.set_default_tensor_type('torch.FloatTensor')
|
| 367 |
+
return out
|
| 368 |
+
|
| 369 |
+
return inner
|
| 370 |
+
|
| 371 |
+
@geomstats_tensor_type
|
| 372 |
+
def gs_matrix_from_rotation_vector(*args, **kwargs):
|
| 373 |
+
return so3_vector.matrix_from_rotation_vector(*args, **kwargs)
|
| 374 |
+
|
| 375 |
+
@geomstats_tensor_type
|
| 376 |
+
def gs_rotation_vector_from_matrix(*args, **kwargs):
|
| 377 |
+
return so3_vector.rotation_vector_from_matrix(*args, **kwargs)
|
| 378 |
+
|
| 379 |
+
@geomstats_tensor_type
|
| 380 |
+
def gs_exp_not_from_identity(*args, **kwargs):
|
| 381 |
+
return so3_vector.exp_not_from_identity(*args, **kwargs)
|
| 382 |
+
|
| 383 |
+
@geomstats_tensor_type
|
| 384 |
+
def gs_log_not_from_identity(*args, **kwargs):
|
| 385 |
+
# norm of the rotation vector will be between 0 and pi
|
| 386 |
+
return so3_vector.log_not_from_identity(*args, **kwargs)
|
| 387 |
+
|
| 388 |
+
@geomstats_tensor_type
|
| 389 |
+
def compose(*args, **kwargs):
|
| 390 |
+
return so3_vector.compose(*args, **kwargs)
|
| 391 |
+
|
| 392 |
+
@geomstats_tensor_type
|
| 393 |
+
def inverse(*args, **kwargs):
|
| 394 |
+
return so3_vector.inverse(*args, **kwargs)
|
| 395 |
+
|
| 396 |
+
@geomstats_tensor_type
|
| 397 |
+
def gs_random_uniform(*args, **kwargs):
|
| 398 |
+
return so3_vector.random_uniform(*args, **kwargs)
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
#############
|
| 402 |
+
# RUN TESTS #
|
| 403 |
+
#############
|
| 404 |
+
|
| 405 |
+
n = 16
|
| 406 |
+
device = 'cuda' if torch.cuda.is_available() else None
|
| 407 |
+
|
| 408 |
+
### regularize ###
|
| 409 |
+
|
| 410 |
+
# vec = (torch.rand(n, 3) * 2 - 1) * math.pi
|
| 411 |
+
vec = (torch.rand(n, 3) * 4 - 2) * math.pi
|
| 412 |
+
axis_angle = regularize(vec)
|
| 413 |
+
assert torch.all(torch.cross(vec, axis_angle).norm(dim=-1) < 1e-5), "not all vectors collinear"
|
| 414 |
+
assert torch.all(axis_angle.norm(dim=-1) < math.pi) & torch.all(axis_angle.norm(dim=-1) >= 0), "norm not between 0 and pi"
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
### matrix_from_rotation_vector ###
|
| 418 |
+
|
| 419 |
+
rot_vec = random_uniform(16, device=device)
|
| 420 |
+
assert torch.allclose(matrix_from_rotation_vector(rot_vec),
|
| 421 |
+
gs_matrix_from_rotation_vector(rot_vec), atol=1e-06)
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
### rotation_vector_from_matrix ###
|
| 425 |
+
|
| 426 |
+
rot_vec = random_uniform(16, device=device)
|
| 427 |
+
rot_mat = matrix_from_rotation_vector(rot_vec)
|
| 428 |
+
assert torch.allclose(rotation_vector_from_matrix(rot_mat),
|
| 429 |
+
gs_rotation_vector_from_matrix(rot_mat), atol=1e-05)
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
### exp_not_from_identity ###
|
| 433 |
+
|
| 434 |
+
tangent_vec = random_uniform(16, device=device)
|
| 435 |
+
base_pt = random_uniform(16, device=device)
|
| 436 |
+
my_val = exp_not_from_identity(tangent_vec, base_pt)
|
| 437 |
+
gs_val = gs_exp_not_from_identity(tangent_vec, base_pt)
|
| 438 |
+
assert torch.allclose(my_val, gs_val, atol=1e-03), (my_val - gs_val).abs().max()
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
### log_not_from_identity ###
|
| 442 |
+
|
| 443 |
+
pt = random_uniform(16, device=device)
|
| 444 |
+
base_pt = random_uniform(16, device=device)
|
| 445 |
+
my_val = log_not_from_identity(pt, base_pt)
|
| 446 |
+
gs_val = gs_log_not_from_identity(pt, base_pt)
|
| 447 |
+
assert torch.allclose(my_val, gs_val, atol=1e-03), (my_val - gs_val).abs().max()
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
print("All tests successful!")
|
src/default/size_distribution.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d4e677a30c4b972051499bb5577a0de773e4f92ec54c282d432f94873406ec7e
|
| 3 |
+
size 158488
|
src/generate.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
import warnings
|
| 5 |
+
import tempfile
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
from Bio.PDB import PDBParser
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from rdkit import Chem
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from functools import partial
|
| 13 |
+
|
| 14 |
+
basedir = Path(__file__).resolve().parent.parent
|
| 15 |
+
sys.path.append(str(basedir))
|
| 16 |
+
warnings.filterwarnings("ignore")
|
| 17 |
+
|
| 18 |
+
from src import utils
|
| 19 |
+
from src.data.dataset import ProcessedLigandPocketDataset
|
| 20 |
+
from src.data.data_utils import TensorDict, process_raw_pair
|
| 21 |
+
from src.model.lightning import DrugFlow
|
| 22 |
+
from src.sbdd_metrics.metrics import FullEvaluator
|
| 23 |
+
|
| 24 |
+
from tqdm import tqdm
|
| 25 |
+
from pdb import set_trace
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def aggregate_metrics(table):
|
| 29 |
+
agg_col = 'posebusters'
|
| 30 |
+
total = 0
|
| 31 |
+
table[agg_col] = 0
|
| 32 |
+
for column in table.columns:
|
| 33 |
+
if column.startswith(agg_col) and column != agg_col:
|
| 34 |
+
table[agg_col] += table[column].fillna(0).astype(float)
|
| 35 |
+
total += 1
|
| 36 |
+
table[agg_col] = table[agg_col] / total
|
| 37 |
+
|
| 38 |
+
agg_col = 'reos'
|
| 39 |
+
total = 0
|
| 40 |
+
table[agg_col] = 0
|
| 41 |
+
for column in table.columns:
|
| 42 |
+
if column.startswith(agg_col) and column != agg_col:
|
| 43 |
+
table[agg_col] += table[column].fillna(0).astype(float)
|
| 44 |
+
total += 1
|
| 45 |
+
table[agg_col] = table[agg_col] / total
|
| 46 |
+
|
| 47 |
+
agg_col = 'chembl_ring_systems'
|
| 48 |
+
total = 0
|
| 49 |
+
table[agg_col] = 0
|
| 50 |
+
for column in table.columns:
|
| 51 |
+
if column.startswith(agg_col) and column != agg_col and not column.endswith('smi'):
|
| 52 |
+
table[agg_col] += table[column].fillna(0).astype(float)
|
| 53 |
+
total += 1
|
| 54 |
+
table[agg_col] = table[agg_col] / total
|
| 55 |
+
return table
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
if __name__ == "__main__":
|
| 59 |
+
p = argparse.ArgumentParser()
|
| 60 |
+
p.add_argument('--protein', type=str, required=True, help="Input PDB file.")
|
| 61 |
+
p.add_argument('--ref_ligand', type=str, required=True, help="SDF file with reference ligand used to define the pocket.")
|
| 62 |
+
p.add_argument('--checkpoint', type=str, required=True, help="Model checkpoint file.")
|
| 63 |
+
p.add_argument('--molecule_size', type=str, required=False, default=None, help="Maximum number of atoms in the sampled molecules. Can be a single number or a range, e.g. '15,20'. If None, size will be sampled.")
|
| 64 |
+
p.add_argument('--output', type=str, required=False, default='samples.sdf', help="Output file.")
|
| 65 |
+
p.add_argument('--n_samples', type=int, required=False, default=10, help="Number of sampled molecules.")
|
| 66 |
+
p.add_argument('--batch_size', type=int, required=False, default=32, help="Batch size.")
|
| 67 |
+
p.add_argument('--pocket_distance_cutoff', type=float, required=False, default=8.0, help="Distance cutoff to define the pocket around the reference ligand.")
|
| 68 |
+
p.add_argument('--n_steps', type=int, required=False, default=None, help="Number of denoising steps.")
|
| 69 |
+
p.add_argument('--device', type=str, required=False, default='cuda:0', help="Device to use.")
|
| 70 |
+
p.add_argument('--datadir', type=Path, required=False, default=Path(basedir, 'src', 'default'), help="Needs to be specified to sample molecule sizes.")
|
| 71 |
+
p.add_argument('--seed', type=int, required=False, default=42, help="Random seed.")
|
| 72 |
+
p.add_argument('--filter', action='store_true', required=False, default=False, help="Apply basic filters and keep sampling until `n_samples` molecules passing these filters are found.")
|
| 73 |
+
p.add_argument('--metrics_output', type=str, required=False, default=None, help="If provided, metrics will be computed and saved in csv format at this location.")
|
| 74 |
+
p.add_argument('--gnina', type=str, required=False, default=None, help="Path to a gnina executable. Required for computing docking scores.")
|
| 75 |
+
p.add_argument('--reduce', type=str, required=False, default=None, help="Path to a reduce executable. Required for computing interactions.")
|
| 76 |
+
args = p.parse_args()
|
| 77 |
+
|
| 78 |
+
utils.set_deterministic(seed=args.seed)
|
| 79 |
+
utils.disable_rdkit_logging()
|
| 80 |
+
|
| 81 |
+
if args.molecule_size is None and (args.datadir is None or not args.datadir.exists()):
|
| 82 |
+
raise NotImplementedError(
|
| 83 |
+
"Please provide a path to the processed dataset (using `--datadir`) "\
|
| 84 |
+
"to infer the number of nodes. It contains the size distribution histogram."
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
if not args.filter:
|
| 88 |
+
args.batch_size = min(args.batch_size, args.n_samples)
|
| 89 |
+
|
| 90 |
+
# Loading model
|
| 91 |
+
chkpt_path = Path(args.checkpoint)
|
| 92 |
+
chkpt_name = chkpt_path.parts[-1].split('.')[0]
|
| 93 |
+
model = DrugFlow.load_from_checkpoint(args.checkpoint, map_location=args.device, strict=False)
|
| 94 |
+
if args.datadir is not None:
|
| 95 |
+
model.datadir = args.datadir
|
| 96 |
+
|
| 97 |
+
model.setup(stage='generation')
|
| 98 |
+
model.batch_size = model.eval_batch_size = args.batch_size
|
| 99 |
+
model.eval().to(args.device)
|
| 100 |
+
if args.n_steps is not None:
|
| 101 |
+
model.T = args.n_steps
|
| 102 |
+
|
| 103 |
+
# Loading size model
|
| 104 |
+
size_model = None
|
| 105 |
+
molecule_size = None
|
| 106 |
+
molecule_size_boundaries = None
|
| 107 |
+
if args.molecule_size is not None:
|
| 108 |
+
if args.molecule_size.isdigit():
|
| 109 |
+
molecule_size = int(args.molecule_size)
|
| 110 |
+
print(f'Will generate molecules of size {molecule_size}')
|
| 111 |
+
else:
|
| 112 |
+
boundaries = [x.strip() for x in args.molecule_size.split(',')]
|
| 113 |
+
assert len(boundaries) == 2 and boundaries[0].isdigit() and boundaries[1].isdigit()
|
| 114 |
+
left = int(boundaries[0])
|
| 115 |
+
right = int(boundaries[1])
|
| 116 |
+
molecule_size = f"uniform_{left}_{right}"
|
| 117 |
+
print(f'Will generate molecules with numbers of atoms sampled from U({left}, {right})')
|
| 118 |
+
|
| 119 |
+
# Preparing input
|
| 120 |
+
pdb_model = PDBParser(QUIET=True).get_structure('', args.protein)[0]
|
| 121 |
+
rdmol = Chem.SDMolSupplier(str(args.ref_ligand))[0]
|
| 122 |
+
|
| 123 |
+
ligand, pocket = process_raw_pair(
|
| 124 |
+
pdb_model, rdmol,
|
| 125 |
+
dist_cutoff=args.pocket_distance_cutoff,
|
| 126 |
+
pocket_representation=model.pocket_representation,
|
| 127 |
+
compute_nerf_params=True,
|
| 128 |
+
nma_input=args.protein if model.dynamics.add_nma_feat else None
|
| 129 |
+
)
|
| 130 |
+
ligand['name'] = 'ligand'
|
| 131 |
+
dataset = [{'ligand': ligand, 'pocket': pocket} for _ in range(args.batch_size)]
|
| 132 |
+
dataloader = DataLoader(
|
| 133 |
+
dataset=dataset,
|
| 134 |
+
batch_size=args.batch_size,
|
| 135 |
+
collate_fn=partial(ProcessedLigandPocketDataset.collate_fn, ligand_transform=None),
|
| 136 |
+
pin_memory=True
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
# Start sampling
|
| 140 |
+
smiles = set()
|
| 141 |
+
sampled_molecules = []
|
| 142 |
+
metrics = []
|
| 143 |
+
Path(args.output).parent.absolute().mkdir(parents=True, exist_ok=True)
|
| 144 |
+
print(f'Will generate {args.n_samples} samples')
|
| 145 |
+
|
| 146 |
+
evaluator = FullEvaluator(gnina=args.gnina, reduce=args.reduce)
|
| 147 |
+
|
| 148 |
+
with tqdm(total=args.n_samples) as pbar:
|
| 149 |
+
while len(sampled_molecules) < args.n_samples:
|
| 150 |
+
for i, data in enumerate(dataloader):
|
| 151 |
+
new_data = {
|
| 152 |
+
'ligand': TensorDict(**data['ligand']).to(args.device),
|
| 153 |
+
'pocket': TensorDict(**data['pocket']).to(args.device),
|
| 154 |
+
}
|
| 155 |
+
rdmols, rdpockets, _ = model.sample(
|
| 156 |
+
new_data,
|
| 157 |
+
n_samples=1,
|
| 158 |
+
timesteps=args.n_steps,
|
| 159 |
+
num_nodes=molecule_size,
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
if args.filter or (args.metrics_output is not None):
|
| 163 |
+
results = []
|
| 164 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 165 |
+
for mol, receptor in zip(rdmols, rdpockets):
|
| 166 |
+
receptor_path = Path(tmpdir, 'receptor.pdb')
|
| 167 |
+
Chem.MolToPDBFile(receptor, str(receptor_path))
|
| 168 |
+
results.append(evaluator(mol, receptor_path))
|
| 169 |
+
|
| 170 |
+
table = pd.DataFrame(results)
|
| 171 |
+
table['novel'] = ~table['representation.smiles'].isin(smiles)
|
| 172 |
+
table = aggregate_metrics(table)
|
| 173 |
+
|
| 174 |
+
added_molecules = 0
|
| 175 |
+
if args.filter:
|
| 176 |
+
table['passed_filters'] = (
|
| 177 |
+
(table['posebusters'] == 1) &
|
| 178 |
+
# (table['reos'] == 1) &
|
| 179 |
+
(table['chembl_ring_systems'] == 1) &
|
| 180 |
+
(table['novel'] == 1)
|
| 181 |
+
)
|
| 182 |
+
for i, (passed, smi) in enumerate(table[['passed_filters', 'representation.smiles']].values):
|
| 183 |
+
if passed:
|
| 184 |
+
sampled_molecules.append(rdmols[i])
|
| 185 |
+
smiles.add(smi)
|
| 186 |
+
added_molecules += 1
|
| 187 |
+
|
| 188 |
+
if args.metrics_output is not None:
|
| 189 |
+
metrics.append(table[table['passed_filters']])
|
| 190 |
+
|
| 191 |
+
else:
|
| 192 |
+
sampled_molecules.extend(rdmols)
|
| 193 |
+
added_molecules = len(rdmols)
|
| 194 |
+
if args.metrics_output is not None:
|
| 195 |
+
metrics.append(table)
|
| 196 |
+
|
| 197 |
+
pbar.update(added_molecules)
|
| 198 |
+
|
| 199 |
+
# Write results
|
| 200 |
+
utils.write_sdf_file(args.output, sampled_molecules)
|
| 201 |
+
|
| 202 |
+
if args.metrics_output is not None:
|
| 203 |
+
metrics = pd.concat(metrics)
|
| 204 |
+
metrics.to_csv(args.metrics_output, index=False)
|
src/model/diffusion_utils.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class DistributionNodes:
|
| 8 |
+
def __init__(self, histogram):
|
| 9 |
+
|
| 10 |
+
histogram = torch.tensor(histogram).float()
|
| 11 |
+
histogram = histogram + 1e-3 # for numerical stability
|
| 12 |
+
|
| 13 |
+
prob = histogram / histogram.sum()
|
| 14 |
+
|
| 15 |
+
self.idx_to_n_nodes = torch.tensor(
|
| 16 |
+
[[(i, j) for j in range(prob.shape[1])] for i in range(prob.shape[0])]
|
| 17 |
+
).view(-1, 2)
|
| 18 |
+
|
| 19 |
+
self.n_nodes_to_idx = {tuple(x.tolist()): i
|
| 20 |
+
for i, x in enumerate(self.idx_to_n_nodes)}
|
| 21 |
+
|
| 22 |
+
self.prob = prob
|
| 23 |
+
self.m = torch.distributions.Categorical(self.prob.view(-1),
|
| 24 |
+
validate_args=True)
|
| 25 |
+
|
| 26 |
+
self.n1_given_n2 = \
|
| 27 |
+
[torch.distributions.Categorical(prob[:, j], validate_args=True)
|
| 28 |
+
for j in range(prob.shape[1])]
|
| 29 |
+
self.n2_given_n1 = \
|
| 30 |
+
[torch.distributions.Categorical(prob[i, :], validate_args=True)
|
| 31 |
+
for i in range(prob.shape[0])]
|
| 32 |
+
|
| 33 |
+
# entropy = -torch.sum(self.prob.view(-1) * torch.log(self.prob.view(-1) + 1e-30))
|
| 34 |
+
# entropy = self.m.entropy()
|
| 35 |
+
# print("Entropy of n_nodes: H[N]", entropy.item())
|
| 36 |
+
|
| 37 |
+
def sample(self, n_samples=1):
|
| 38 |
+
idx = self.m.sample((n_samples,))
|
| 39 |
+
num_nodes_lig, num_nodes_pocket = self.idx_to_n_nodes[idx].T
|
| 40 |
+
return num_nodes_lig, num_nodes_pocket
|
| 41 |
+
|
| 42 |
+
def sample_conditional(self, n1=None, n2=None):
|
| 43 |
+
assert (n1 is None) ^ (n2 is None), \
|
| 44 |
+
"Exactly one input argument must be None"
|
| 45 |
+
|
| 46 |
+
m = self.n1_given_n2 if n2 is not None else self.n2_given_n1
|
| 47 |
+
c = n2 if n2 is not None else n1
|
| 48 |
+
|
| 49 |
+
return torch.tensor([m[i].sample() for i in c], device=c.device)
|
| 50 |
+
|
| 51 |
+
def log_prob(self, batch_n_nodes_1, batch_n_nodes_2):
|
| 52 |
+
assert len(batch_n_nodes_1.size()) == 1
|
| 53 |
+
assert len(batch_n_nodes_2.size()) == 1
|
| 54 |
+
|
| 55 |
+
idx = torch.tensor(
|
| 56 |
+
[self.n_nodes_to_idx[(n1, n2)]
|
| 57 |
+
for n1, n2 in zip(batch_n_nodes_1.tolist(), batch_n_nodes_2.tolist())]
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# log_probs = torch.log(self.prob.view(-1)[idx] + 1e-30)
|
| 61 |
+
log_probs = self.m.log_prob(idx)
|
| 62 |
+
|
| 63 |
+
return log_probs.to(batch_n_nodes_1.device)
|
| 64 |
+
|
| 65 |
+
def log_prob_n1_given_n2(self, n1, n2):
|
| 66 |
+
assert len(n1.size()) == 1
|
| 67 |
+
assert len(n2.size()) == 1
|
| 68 |
+
log_probs = torch.stack([self.n1_given_n2[c].log_prob(i.cpu())
|
| 69 |
+
for i, c in zip(n1, n2)])
|
| 70 |
+
return log_probs.to(n1.device)
|
| 71 |
+
|
| 72 |
+
def log_prob_n2_given_n1(self, n2, n1):
|
| 73 |
+
assert len(n2.size()) == 1
|
| 74 |
+
assert len(n1.size()) == 1
|
| 75 |
+
log_probs = torch.stack([self.n2_given_n1[c].log_prob(i.cpu())
|
| 76 |
+
for i, c in zip(n2, n1)])
|
| 77 |
+
return log_probs.to(n2.device)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def cosine_beta_schedule_midi(timesteps, s=0.008, nu=1.0, clip=False):
|
| 81 |
+
"""
|
| 82 |
+
Modified cosine schedule as proposed in https://arxiv.org/abs/2302.09048.
|
| 83 |
+
Note: we use (t/T)^\nu not (t/T + s)^\nu as written in the MiDi paper
|
| 84 |
+
We also divide by alphas_cumprod[0] as the original cosine schedule from
|
| 85 |
+
https://arxiv.org/abs/2102.09672
|
| 86 |
+
"""
|
| 87 |
+
device = nu.device if torch.is_tensor(nu) else None
|
| 88 |
+
x = torch.linspace(0, timesteps, timesteps + 1, device=device)
|
| 89 |
+
alphas_cumprod = torch.cos(0.5 * np.pi * ((x / timesteps)**nu + s) / (1 + s)) ** 2
|
| 90 |
+
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
| 91 |
+
|
| 92 |
+
if clip:
|
| 93 |
+
alphas_cumprod = torch.cat([torch.tensor([1.0], device=alphas_cumprod.device), alphas_cumprod])
|
| 94 |
+
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
| 95 |
+
betas = torch.clip(betas, min=0, max=0.999)
|
| 96 |
+
alphas = 1. - betas
|
| 97 |
+
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
| 98 |
+
return alphas_cumprod
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class CosineSchedule(torch.nn.Module):
|
| 102 |
+
"""
|
| 103 |
+
nu=1.0 corresponds to the standard cosine schedule
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
def __init__(self, timesteps, nu=1.0, trainable=False, clip_alpha2_step=0.001):
|
| 107 |
+
super(CosineSchedule, self).__init__()
|
| 108 |
+
self.timesteps = timesteps
|
| 109 |
+
self.trainable = trainable
|
| 110 |
+
self.nu = nu
|
| 111 |
+
assert 0.0 <= clip_alpha2_step < 1.0
|
| 112 |
+
self.clip = clip_alpha2_step
|
| 113 |
+
|
| 114 |
+
if self.trainable:
|
| 115 |
+
self.nu = torch.nn.Parameter(torch.Tensor([nu]), requires_grad=True)
|
| 116 |
+
else:
|
| 117 |
+
self._alpha2 = self.alphas2
|
| 118 |
+
self._gamma = torch.nn.Parameter(self.gammas, requires_grad=False)
|
| 119 |
+
|
| 120 |
+
@property
|
| 121 |
+
def alphas2(self):
|
| 122 |
+
"""
|
| 123 |
+
Cumulative alpha squared.
|
| 124 |
+
Called alpha_bar in: Nichol, Alexander Quinn, and Prafulla Dhariwal.
|
| 125 |
+
"Improved denoising diffusion probabilistic models." PMLR, 2021.
|
| 126 |
+
"""
|
| 127 |
+
if hasattr(self, '_alpha2'):
|
| 128 |
+
return self._alpha2
|
| 129 |
+
|
| 130 |
+
assert isinstance(self.nu, float) or ~self.nu.isnan()
|
| 131 |
+
|
| 132 |
+
# our alpha is eqivalent to sqrt(alpha) from https://arxiv.org/abs/2102.09672, where the cosine schedule was introduced
|
| 133 |
+
alphas2 = cosine_beta_schedule_midi(self.timesteps, nu=self.nu, clip=False)
|
| 134 |
+
|
| 135 |
+
# avoid singularities near t=T
|
| 136 |
+
alphas2 = torch.cat([torch.tensor([1.0], device=alphas2.device), alphas2])
|
| 137 |
+
alphas2_step = alphas2[1:] / alphas2[:-1]
|
| 138 |
+
alphas2_step = torch.clip(alphas2_step, min=self.clip, max=1.0)
|
| 139 |
+
alphas2 = torch.cumprod(alphas2_step, dim=0)
|
| 140 |
+
|
| 141 |
+
return alphas2
|
| 142 |
+
|
| 143 |
+
@property
|
| 144 |
+
def alphas2_t_given_tminus1(self):
|
| 145 |
+
"""
|
| 146 |
+
Alphas for a single transition
|
| 147 |
+
"""
|
| 148 |
+
alphas2 = torch.cat([torch.tensor([1.0]), self.alphas2])
|
| 149 |
+
return alphas2[1:] / alphas2[:-1]
|
| 150 |
+
|
| 151 |
+
@property
|
| 152 |
+
def gammas(self):
|
| 153 |
+
"""
|
| 154 |
+
Gammas as defined in appendix B of the EDM paper
|
| 155 |
+
gamma_t = -(log alpha_t^2 - log sigma_t^2)
|
| 156 |
+
"""
|
| 157 |
+
if hasattr(self, '_gamma'):
|
| 158 |
+
return self._gamma
|
| 159 |
+
|
| 160 |
+
alphas2 = self.alphas2
|
| 161 |
+
sigmas2 = 1 - alphas2
|
| 162 |
+
|
| 163 |
+
gammas = -(torch.log(alphas2) - torch.log(sigmas2))
|
| 164 |
+
|
| 165 |
+
return gammas.float()
|
| 166 |
+
|
| 167 |
+
def forward(self, t):
|
| 168 |
+
t_int = torch.round(t * self.timesteps).long()
|
| 169 |
+
return self.gammas[t_int]
|
| 170 |
+
|
| 171 |
+
@staticmethod
|
| 172 |
+
def alpha(gamma):
|
| 173 |
+
""" Computes alpha given gamma. """
|
| 174 |
+
return torch.sqrt(torch.sigmoid(-gamma))
|
| 175 |
+
|
| 176 |
+
@staticmethod
|
| 177 |
+
def sigma(gamma):
|
| 178 |
+
""" Computes sigma given gamma. """
|
| 179 |
+
return torch.sqrt(torch.sigmoid(gamma))
|
| 180 |
+
|
| 181 |
+
@staticmethod
|
| 182 |
+
def SNR(gamma):
|
| 183 |
+
""" Computes signal to noise ratio (alpha^2/sigma^2) given gamma. """
|
| 184 |
+
return torch.exp(-gamma)
|
| 185 |
+
|
| 186 |
+
def sigma_and_alpha_t_given_s(self, gamma_t: torch.Tensor, gamma_s: torch.Tensor):
|
| 187 |
+
"""
|
| 188 |
+
Computes sigma_t_given_s, using gamma_t and gamma_s. Used during sampling.
|
| 189 |
+
These are defined as:
|
| 190 |
+
alpha_t_given_s = alpha_t / alpha_s,
|
| 191 |
+
sigma_t_given_s = sqrt(1 - (alpha_t_given_s)^2 ).
|
| 192 |
+
"""
|
| 193 |
+
sigma2_t_given_s = -torch.expm1(
|
| 194 |
+
F.softplus(gamma_s) - F.softplus(gamma_t))
|
| 195 |
+
|
| 196 |
+
# alpha_t_given_s = alpha_t / alpha_s
|
| 197 |
+
log_alpha2_t = F.logsigmoid(-gamma_t)
|
| 198 |
+
log_alpha2_s = F.logsigmoid(-gamma_s)
|
| 199 |
+
log_alpha2_t_given_s = log_alpha2_t - log_alpha2_s
|
| 200 |
+
|
| 201 |
+
alpha_t_given_s = torch.exp(0.5 * log_alpha2_t_given_s)
|
| 202 |
+
alpha_t_given_s = torch.clip(alpha_t_given_s, min=self.clip ** 0.5, max=1.0)
|
| 203 |
+
|
| 204 |
+
sigma_t_given_s = torch.sqrt(sigma2_t_given_s)
|
| 205 |
+
|
| 206 |
+
return sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s
|
src/model/dpo.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from contextlib import nullcontext
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch_scatter import scatter_mean
|
| 8 |
+
|
| 9 |
+
from src.constants import atom_encoder, bond_encoder
|
| 10 |
+
from src.model.lightning import DrugFlow, set_default
|
| 11 |
+
from src.data.dataset import ProcessedLigandPocketDataset, DPODataset
|
| 12 |
+
from src.data.data_utils import AppendVirtualNodesInCoM, Residues, center_data
|
| 13 |
+
|
| 14 |
+
class DPO(DrugFlow):
|
| 15 |
+
def __init__(self, dpo_mode, ref_checkpoint_p, **kwargs):
|
| 16 |
+
super(DPO, self).__init__(**kwargs)
|
| 17 |
+
self.dpo_mode = dpo_mode
|
| 18 |
+
self.dpo_beta = kwargs['loss_params'].dpo_beta if 'dpo_beta' in kwargs['loss_params'] else 100.0
|
| 19 |
+
self.dpo_beta_schedule = kwargs['loss_params'].dpo_beta_schedule if 'dpo_beta_schedule' in kwargs['loss_params'] else 't'
|
| 20 |
+
self.clamp_dpo = kwargs['loss_params'].clamp_dpo if 'clamp_dpo' in kwargs['loss_params'] else True
|
| 21 |
+
self.dpo_lambda_dpo = kwargs['loss_params'].dpo_lambda_dpo if 'dpo_lambda_dpo' in kwargs['loss_params'] else 1
|
| 22 |
+
self.dpo_lambda_w = kwargs['loss_params'].dpo_lambda_w if 'dpo_lambda_w' in kwargs['loss_params'] else 1
|
| 23 |
+
self.dpo_lambda_l = kwargs['loss_params'].dpo_lambda_l if 'dpo_lambda_l' in kwargs['loss_params'] else 0.2
|
| 24 |
+
self.dpo_lambda_h = kwargs['loss_params'].dpo_lambda_h if 'dpo_lambda_h' in kwargs['loss_params'] else kwargs['loss_params'].lambda_h
|
| 25 |
+
self.dpo_lambda_e = kwargs['loss_params'].dpo_lambda_e if 'dpo_lambda_e' in kwargs['loss_params'] else kwargs['loss_params'].lambda_e
|
| 26 |
+
self.ref_dynamics = self.init_model(kwargs['predictor_params'])
|
| 27 |
+
state_dict = torch.load(ref_checkpoint_p)['state_dict']
|
| 28 |
+
self.ref_dynamics.load_state_dict({k.replace('dynamics.',''): v for k, v in state_dict.items() if k.startswith('dynamics.')})
|
| 29 |
+
print(f'Loaded reference model from {ref_checkpoint_p}')
|
| 30 |
+
# initializing model params with ref model params
|
| 31 |
+
self.dynamics.load_state_dict(self.ref_dynamics.state_dict())
|
| 32 |
+
|
| 33 |
+
def get_dataset(self, stage, pocket_transform=None):
|
| 34 |
+
|
| 35 |
+
# when sampling we don't append virtual nodes as we might need access to the ground truth size
|
| 36 |
+
if self.virtual_nodes and stage == 'train':
|
| 37 |
+
ligand_transform = AppendVirtualNodesInCoM(
|
| 38 |
+
atom_encoder, bond_encoder, add_min=self.add_virtual_min, add_max=self.add_virtual_max)
|
| 39 |
+
else:
|
| 40 |
+
ligand_transform = None
|
| 41 |
+
|
| 42 |
+
# we want to know if something goes wrong on the validation or test set
|
| 43 |
+
catch_errors = stage == 'train'
|
| 44 |
+
|
| 45 |
+
if self.sharded_dataset:
|
| 46 |
+
raise NotImplementedError('Sharded dataset not implemented for DPO')
|
| 47 |
+
|
| 48 |
+
if self.sample_from_clusters and stage == 'train': # val/test should be deterministic
|
| 49 |
+
raise NotImplementedError('Sampling from clusters not implemented for DPO')
|
| 50 |
+
|
| 51 |
+
if stage == 'train':
|
| 52 |
+
return DPODataset(
|
| 53 |
+
Path(self.datadir, 'train.pt'),
|
| 54 |
+
ligand_transform=None,
|
| 55 |
+
pocket_transform=pocket_transform,
|
| 56 |
+
catch_errors=True,
|
| 57 |
+
)
|
| 58 |
+
else:
|
| 59 |
+
return ProcessedLigandPocketDataset(
|
| 60 |
+
pt_path=Path(self.datadir, 'val.pt' if self.debug else f'{stage}.pt'),
|
| 61 |
+
ligand_transform=ligand_transform,
|
| 62 |
+
pocket_transform=pocket_transform,
|
| 63 |
+
catch_errors=catch_errors,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def training_step(self, data, *args):
|
| 68 |
+
ligand_w, ligand_l, pocket = data['ligand'], data['ligand_l'], data['pocket']
|
| 69 |
+
loss, info = self.compute_dpo_loss(pocket, ligand_w=ligand_w, ligand_l=ligand_l, return_info=True)
|
| 70 |
+
|
| 71 |
+
if torch.isnan(loss):
|
| 72 |
+
print(f'For ligand pair , loss is NaN at epoch {self.current_epoch}. Info: {info}')
|
| 73 |
+
|
| 74 |
+
log_dict = {k: v for k, v in info.items() if isinstance(v, float) or torch.numel(v) <= 1}
|
| 75 |
+
self.log_metrics({'loss': loss, **log_dict}, 'train', batch_size=len(ligand_w['size']))
|
| 76 |
+
|
| 77 |
+
out = {'loss': loss, **info}
|
| 78 |
+
self.training_step_outputs.append(out)
|
| 79 |
+
return out
|
| 80 |
+
|
| 81 |
+
def validation_step(self, data, *args):
|
| 82 |
+
return super().validation_step(data, *args)
|
| 83 |
+
|
| 84 |
+
def compute_dpo_loss(self, pocket, ligand_w, ligand_l, return_info=False):
|
| 85 |
+
t = torch.rand(ligand_w['size'].size(0), device=ligand_w['x'].device).unsqueeze(-1)
|
| 86 |
+
|
| 87 |
+
if self.dpo_beta_schedule == 't':
|
| 88 |
+
# from https://arxiv.org/pdf/2407.13981
|
| 89 |
+
beta_t = (self.dpo_beta * t).squeeze()
|
| 90 |
+
elif self.dpo_beta_schedule == 'const':
|
| 91 |
+
beta_t = self.dpo_beta
|
| 92 |
+
else:
|
| 93 |
+
raise ValueError(f'Unknown DPO beta schedule: {self.dpo_beta_schedule}')
|
| 94 |
+
|
| 95 |
+
loss_dict_w = self.compute_loss_single_pair(ligand_w, pocket, t)
|
| 96 |
+
loss_dict_l = self.compute_loss_single_pair(ligand_l, pocket, t)
|
| 97 |
+
info = {
|
| 98 |
+
'loss_x_w': loss_dict_w['theta']['x'].mean().item(),
|
| 99 |
+
'loss_h_w': loss_dict_w['theta']['h'].mean().item(),
|
| 100 |
+
'loss_e_w': loss_dict_w['theta']['e'].mean().item(),
|
| 101 |
+
'loss_x_l': loss_dict_l['theta']['x'].mean().item(),
|
| 102 |
+
'loss_h_l': loss_dict_l['theta']['h'].mean().item(),
|
| 103 |
+
'loss_e_l': loss_dict_l['theta']['e'].mean().item(),
|
| 104 |
+
}
|
| 105 |
+
if self.dpo_mode == 'single_dpo_comp':
|
| 106 |
+
loss_w_theta = (
|
| 107 |
+
loss_dict_w['theta']['x'] +
|
| 108 |
+
self.dpo_lambda_h * loss_dict_w['theta']['h'] +
|
| 109 |
+
self.dpo_lambda_e * loss_dict_w['theta']['e']
|
| 110 |
+
)
|
| 111 |
+
loss_w_ref = (
|
| 112 |
+
loss_dict_w['ref']['x'] +
|
| 113 |
+
self.dpo_lambda_h * loss_dict_w['ref']['h'] +
|
| 114 |
+
self.dpo_lambda_e * loss_dict_w['ref']['e']
|
| 115 |
+
)
|
| 116 |
+
loss_l_theta = (
|
| 117 |
+
loss_dict_l['theta']['x'] +
|
| 118 |
+
self.dpo_lambda_h * loss_dict_l['theta']['h'] +
|
| 119 |
+
self.dpo_lambda_e * loss_dict_l['theta']['e']
|
| 120 |
+
)
|
| 121 |
+
loss_l_ref = (
|
| 122 |
+
loss_dict_l['ref']['x'] +
|
| 123 |
+
self.dpo_lambda_h * loss_dict_l['ref']['h'] +
|
| 124 |
+
self.dpo_lambda_e * loss_dict_l['ref']['e']
|
| 125 |
+
)
|
| 126 |
+
diff_w = loss_w_theta - loss_w_ref
|
| 127 |
+
diff_l = loss_l_theta - loss_l_ref
|
| 128 |
+
info['diff_w'] = diff_w.mean().item()
|
| 129 |
+
info['diff_l'] = diff_l.mean().item()
|
| 130 |
+
# print(diff)
|
| 131 |
+
diff = -1 * beta_t * (diff_w - diff_l)
|
| 132 |
+
loss = -1 * F.logsigmoid(diff)
|
| 133 |
+
elif self.dpo_mode == 'single_dpo_comp_v3':
|
| 134 |
+
diff_w_x = loss_dict_w['theta']['x'] - loss_dict_w['ref']['x']
|
| 135 |
+
diff_w_h = loss_dict_w['theta']['h'] - loss_dict_w['ref']['h']
|
| 136 |
+
diff_w_e = loss_dict_w['theta']['e'] - loss_dict_w['ref']['e']
|
| 137 |
+
diff_l_x = loss_dict_l['theta']['x'] - loss_dict_l['ref']['x']
|
| 138 |
+
diff_l_h = loss_dict_l['theta']['h'] - loss_dict_l['ref']['h']
|
| 139 |
+
diff_l_e = loss_dict_l['theta']['e'] - loss_dict_l['ref']['e']
|
| 140 |
+
info['diff_w_x'] = diff_w_x.mean().item()
|
| 141 |
+
info['diff_w_h'] = diff_w_h.mean().item()
|
| 142 |
+
info['diff_w_e'] = diff_w_e.mean().item()
|
| 143 |
+
info['diff_l_x'] = diff_l_x.mean().item()
|
| 144 |
+
info['diff_l_h'] = diff_l_h.mean().item()
|
| 145 |
+
info['diff_l_e'] = diff_l_e.mean().item()
|
| 146 |
+
|
| 147 |
+
# not used, just for logging
|
| 148 |
+
_diff_w = diff_w_x + self.dpo_lambda_h * diff_w_h + self.dpo_lambda_e * diff_w_e
|
| 149 |
+
_diff_l = diff_l_x + self.dpo_lambda_h * diff_l_h + self.dpo_lambda_e * diff_l_e
|
| 150 |
+
info['diff_w'] = _diff_w.mean().item()
|
| 151 |
+
info['diff_l'] = _diff_l.mean().item()
|
| 152 |
+
|
| 153 |
+
diff_x = diff_w_x - diff_l_x
|
| 154 |
+
diff_h = diff_w_h - diff_l_h
|
| 155 |
+
diff_e = diff_w_e - diff_l_e
|
| 156 |
+
info['diff_x'] = diff_x.mean().item()
|
| 157 |
+
info['diff_h'] = diff_h.mean().item()
|
| 158 |
+
info['diff_e'] = diff_e.mean().item()
|
| 159 |
+
|
| 160 |
+
diff = -1 * beta_t * (diff_x + self.dpo_lambda_h * diff_h + self.dpo_lambda_e * diff_e)
|
| 161 |
+
if self.clamp_dpo:
|
| 162 |
+
diff = diff.clamp(-10, 10)
|
| 163 |
+
info['dpo_arg_min'] = diff.min().item()
|
| 164 |
+
info['dpo_arg_max'] = diff.max().item()
|
| 165 |
+
info['dpo_arg_mean'] = diff.mean().item()
|
| 166 |
+
dpo_loss = -1 * self.dpo_lambda_dpo * F.logsigmoid(diff)
|
| 167 |
+
info['dpo_loss'] = dpo_loss.mean().item()
|
| 168 |
+
|
| 169 |
+
loss_w_theta_reg = (
|
| 170 |
+
loss_dict_w['theta']['x'] +
|
| 171 |
+
self.lambda_h * loss_dict_w['theta']['h'] +
|
| 172 |
+
self.lambda_e * loss_dict_w['theta']['e']
|
| 173 |
+
)
|
| 174 |
+
info['loss_w_theta_reg'] = loss_w_theta_reg.mean().item()
|
| 175 |
+
loss_l_theta_reg = (
|
| 176 |
+
loss_dict_l['theta']['x'] +
|
| 177 |
+
self.lambda_h * loss_dict_l['theta']['h'] +
|
| 178 |
+
self.lambda_e * loss_dict_l['theta']['e']
|
| 179 |
+
)
|
| 180 |
+
info['loss_l_theta_reg'] = loss_l_theta_reg.mean().item()
|
| 181 |
+
dpo_reg = self.dpo_lambda_w * loss_w_theta_reg + \
|
| 182 |
+
self.dpo_lambda_l * loss_l_theta_reg
|
| 183 |
+
info['dpo_reg'] = dpo_reg.mean().item()
|
| 184 |
+
loss = dpo_loss + dpo_reg
|
| 185 |
+
else:
|
| 186 |
+
raise ValueError(f'Unknown DPO mode: {self.dpo_mode}')
|
| 187 |
+
|
| 188 |
+
if self.timestep_weights is not None:
|
| 189 |
+
w_t = self.timestep_weights(t).squeeze()
|
| 190 |
+
loss = w_t * loss
|
| 191 |
+
|
| 192 |
+
loss = loss.mean(0)
|
| 193 |
+
|
| 194 |
+
print(f'Loss is {loss}, info is {info}')
|
| 195 |
+
|
| 196 |
+
return (loss, info) if return_info else loss
|
| 197 |
+
|
| 198 |
+
def compute_loss_single_pair(self, ligand, pocket, t):
|
| 199 |
+
pocket = Residues(**pocket)
|
| 200 |
+
|
| 201 |
+
# Center sample
|
| 202 |
+
ligand, pocket = center_data(ligand, pocket)
|
| 203 |
+
pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0)
|
| 204 |
+
|
| 205 |
+
# Noise
|
| 206 |
+
z0_x = self.module_x.sample_z0(pocket_com, ligand['mask'])
|
| 207 |
+
z0_h = self.module_h.sample_z0(ligand['mask'])
|
| 208 |
+
z0_e = self.module_e.sample_z0(ligand['bond_mask'])
|
| 209 |
+
zt_x = self.module_x.sample_zt(z0_x, ligand['x'], t, ligand['mask'])
|
| 210 |
+
zt_h = self.module_h.sample_zt(z0_h, ligand['one_hot'], t, ligand['mask'])
|
| 211 |
+
zt_e = self.module_e.sample_zt(z0_e, ligand['bond_one_hot'], t, ligand['bond_mask'])
|
| 212 |
+
|
| 213 |
+
# Predict denoising
|
| 214 |
+
sc_transform = self.get_sc_transform_fn(None, zt_x, t, None, ligand['mask'], pocket)
|
| 215 |
+
|
| 216 |
+
pred_ligand, _ = self.dynamics(
|
| 217 |
+
zt_x, zt_h, ligand['mask'], pocket, t,
|
| 218 |
+
bonds_ligand=(ligand['bonds'], zt_e),
|
| 219 |
+
sc_transform=sc_transform
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
# Reference model
|
| 223 |
+
with torch.no_grad():
|
| 224 |
+
ref_pred_ligand, _ = self.ref_dynamics(
|
| 225 |
+
zt_x, zt_h, ligand['mask'], pocket, t,
|
| 226 |
+
bonds_ligand=(ligand['bonds'], zt_e),
|
| 227 |
+
sc_transform=sc_transform
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
# Compute L2 loss
|
| 231 |
+
loss_x = self.module_x.compute_loss(pred_ligand['vel'], z0_x, ligand['x'], t, ligand['mask'], reduce=self.loss_reduce)
|
| 232 |
+
ref_loss_x = self.module_x.compute_loss(ref_pred_ligand['vel'], z0_x, ligand['x'], t, ligand['mask'], reduce=self.loss_reduce)
|
| 233 |
+
|
| 234 |
+
t_next = torch.clamp(t + self.train_step_size, max=1.0)
|
| 235 |
+
|
| 236 |
+
loss_h = self.module_h.compute_loss(pred_ligand['logits_h'], zt_h, ligand['one_hot'], ligand['mask'], t, t_next, reduce=self.loss_reduce)
|
| 237 |
+
ref_loss_h = self.module_h.compute_loss(ref_pred_ligand['logits_h'], zt_h, ligand['one_hot'], ligand['mask'], t, t_next, reduce=self.loss_reduce)
|
| 238 |
+
loss_e = self.module_e.compute_loss(pred_ligand['logits_e'], zt_e, ligand['bond_one_hot'], ligand['bond_mask'], t, t_next, reduce=self.loss_reduce)
|
| 239 |
+
ref_loss_e = self.module_e.compute_loss(ref_pred_ligand['logits_e'], zt_e, ligand['bond_one_hot'], ligand['bond_mask'], t, t_next, reduce=self.loss_reduce)
|
| 240 |
+
|
| 241 |
+
return {
|
| 242 |
+
'theta': {
|
| 243 |
+
'x': loss_x,
|
| 244 |
+
'h': loss_h,
|
| 245 |
+
'e': loss_e,
|
| 246 |
+
},
|
| 247 |
+
'ref': {
|
| 248 |
+
'x': ref_loss_x,
|
| 249 |
+
'h': ref_loss_h,
|
| 250 |
+
'e': ref_loss_e,
|
| 251 |
+
}
|
| 252 |
+
}
|
src/model/dynamics.py
ADDED
|
@@ -0,0 +1,791 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections.abc import Iterable
|
| 2 |
+
from abc import abstractmethod
|
| 3 |
+
import random
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from src.constants import INT_TYPE
|
| 10 |
+
from src.model.gvp import GVPModel, GVP, LayerNorm
|
| 11 |
+
from src.model.gvp_transformer import GVPTransformerModel
|
| 12 |
+
from src.constants import FLOAT_TYPE
|
| 13 |
+
|
| 14 |
+
from pdb import set_trace
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def binomial_coefficient(n, k):
|
| 18 |
+
# source: https://discuss.pytorch.org/t/n-choose-k-function/121974
|
| 19 |
+
return ((n + 1).lgamma() - (k + 1).lgamma() - ((n - k) + 1).lgamma()).exp()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def cycle_counts(adj):
|
| 23 |
+
assert (adj.diag() == 0).all()
|
| 24 |
+
assert (adj == adj.T).all()
|
| 25 |
+
|
| 26 |
+
A = adj.float()
|
| 27 |
+
d = A.sum(dim=-1)
|
| 28 |
+
|
| 29 |
+
# Compute powers
|
| 30 |
+
A2 = A @ A
|
| 31 |
+
A3 = A2 @ A
|
| 32 |
+
A4 = A3 @ A
|
| 33 |
+
A5 = A4 @ A
|
| 34 |
+
|
| 35 |
+
x3 = A3.diag() / 2
|
| 36 |
+
x4 = (A4.diag() - d * (d - 1) - A @ d) / 2
|
| 37 |
+
|
| 38 |
+
""" New (different from DiGress)
|
| 39 |
+
case where correction is relevant:
|
| 40 |
+
2 o
|
| 41 |
+
|
|
| 42 |
+
1,3 o--o 4
|
| 43 |
+
| /
|
| 44 |
+
0,5 o
|
| 45 |
+
"""
|
| 46 |
+
# Triangle count matrix (indicates for each node i how many triangles it shares with node j)
|
| 47 |
+
T = adj * A2
|
| 48 |
+
x5 = (A5.diag() - 2 * T @ d - 4 * d * x3 - 2 * A @ x3 + 10 * x3) / 2
|
| 49 |
+
|
| 50 |
+
# # TODO
|
| 51 |
+
# A6 = A5 @ A
|
| 52 |
+
#
|
| 53 |
+
# # 4-cycle count matrix (indicates in how many shared 4-cycles i and j are 2 hops apart)
|
| 54 |
+
# Q2 = binomial_coefficient(n=A2 - d.diag(), k=torch.tensor(2))
|
| 55 |
+
#
|
| 56 |
+
# # 4-cycle count matrix (indicates in how many shared 4-cycles i and j are 1 (and 3) hop(s) apart)
|
| 57 |
+
# Q1 = A * (A3 - (d.view(-1, 1) + d.view(1, -1)) + 1) # "+1" because link between i and j is subtracted twice
|
| 58 |
+
#
|
| 59 |
+
# x6 = ...
|
| 60 |
+
# return torch.stack([x3, x4, x5, x6], dim=-1)
|
| 61 |
+
|
| 62 |
+
return torch.stack([x3, x4, x5], dim=-1)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# TODO: also consider directional aggregation as in:
|
| 66 |
+
# Beaini, Dominique, et al. "Directional graph networks."
|
| 67 |
+
# International Conference on Machine Learning. PMLR, 2021.
|
| 68 |
+
def eigenfeatures(A, batch_mask, k=5):
|
| 69 |
+
# TODO, see:
|
| 70 |
+
# - https://github.com/cvignac/DiGress/blob/main/src/diffusion/extra_features.py
|
| 71 |
+
# - https://arxiv.org/pdf/2209.14734.pdf (Appendix B.2)
|
| 72 |
+
|
| 73 |
+
# split adjacency matrix
|
| 74 |
+
batch = []
|
| 75 |
+
for i in torch.unique(batch_mask, sorted=True): # TODO: optimize (try to avoid loop)
|
| 76 |
+
batch_inds = torch.where(batch_mask == i)[0]
|
| 77 |
+
batch.append(A[torch.meshgrid(batch_inds, batch_inds, indexing='ij')])
|
| 78 |
+
|
| 79 |
+
eigenfeats = [get_nontrivial_eigenvectors(adj)[:, :k] for adj in batch]
|
| 80 |
+
# if there are less than k non-trivial eigenvectors
|
| 81 |
+
eigenfeats = [torch.cat([
|
| 82 |
+
x, torch.zeros(x.size(0), max(k - x.size(1), 0), device=x.device)], dim=-1)
|
| 83 |
+
for x in eigenfeats]
|
| 84 |
+
return torch.cat(eigenfeats, dim=0)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_nontrivial_eigenvectors(A, normalize_l=True, thresh=1e-5,
|
| 88 |
+
norm_eps=1e-12):
|
| 89 |
+
"""
|
| 90 |
+
Compute eigenvectors of the graph Laplacian corresponding to non-zero
|
| 91 |
+
eigenvalues.
|
| 92 |
+
"""
|
| 93 |
+
assert (A == A.T).all(), "undirected graph"
|
| 94 |
+
|
| 95 |
+
# Compute laplacian
|
| 96 |
+
d = A.sum(-1)
|
| 97 |
+
D = d.diag()
|
| 98 |
+
L = D - A
|
| 99 |
+
|
| 100 |
+
if normalize_l:
|
| 101 |
+
D_inv_sqrt = (1 / (d.sqrt() + norm_eps)).diag()
|
| 102 |
+
L = D_inv_sqrt @ L @ D_inv_sqrt
|
| 103 |
+
|
| 104 |
+
# Eigendecomposition
|
| 105 |
+
# eigenvalues are sorted in ascending order
|
| 106 |
+
# eigvecs matrix contains eigenvectors as its columns
|
| 107 |
+
eigvals, eigvecs = torch.linalg.eigh(L)
|
| 108 |
+
|
| 109 |
+
# index of first non-trivial eigenvector
|
| 110 |
+
try:
|
| 111 |
+
idx = torch.nonzero(eigvals > thresh)[0].item()
|
| 112 |
+
except IndexError:
|
| 113 |
+
# recover if no non-trivial eigenvectors are found
|
| 114 |
+
idx = eigvecs.size(1)
|
| 115 |
+
|
| 116 |
+
return eigvecs[:, idx:]
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class DynamicsBase(nn.Module):
|
| 120 |
+
"""
|
| 121 |
+
Implements self-conditioning logic and basic functions
|
| 122 |
+
"""
|
| 123 |
+
def __init__(
|
| 124 |
+
self,
|
| 125 |
+
predict_angles=False,
|
| 126 |
+
predict_frames=False,
|
| 127 |
+
add_cycle_counts=False,
|
| 128 |
+
add_spectral_feat=False,
|
| 129 |
+
self_conditioning=False,
|
| 130 |
+
augment_residue_sc=False,
|
| 131 |
+
augment_ligand_sc=False
|
| 132 |
+
):
|
| 133 |
+
super().__init__()
|
| 134 |
+
|
| 135 |
+
if not hasattr(self, 'predict_angles'):
|
| 136 |
+
self.predict_angles = predict_angles
|
| 137 |
+
|
| 138 |
+
if not hasattr(self, 'predict_frames'):
|
| 139 |
+
self.predict_frames = predict_frames
|
| 140 |
+
|
| 141 |
+
if not hasattr(self, 'add_cycle_counts'):
|
| 142 |
+
self.add_cycle_counts = add_cycle_counts
|
| 143 |
+
|
| 144 |
+
if not hasattr(self, 'add_spectral_feat'):
|
| 145 |
+
self.add_spectral_feat = add_spectral_feat
|
| 146 |
+
|
| 147 |
+
if not hasattr(self, 'self_conditioning'):
|
| 148 |
+
self.self_conditioning = self_conditioning
|
| 149 |
+
|
| 150 |
+
if not hasattr(self, 'augment_residue_sc'):
|
| 151 |
+
self.augment_residue_sc = augment_residue_sc
|
| 152 |
+
|
| 153 |
+
if not hasattr(self, 'augment_ligand_sc'):
|
| 154 |
+
self.augment_ligand_sc = augment_ligand_sc
|
| 155 |
+
|
| 156 |
+
if self.self_conditioning:
|
| 157 |
+
self.prev_ligand = None
|
| 158 |
+
self.prev_residues = None
|
| 159 |
+
|
| 160 |
+
@abstractmethod
|
| 161 |
+
def _forward(self, x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand=None,
|
| 162 |
+
h_atoms_sc=None, e_atoms_sc=None, h_residues_sc=None):
|
| 163 |
+
"""
|
| 164 |
+
Implement forward pass.
|
| 165 |
+
Returns:
|
| 166 |
+
- vel
|
| 167 |
+
- h_final_atoms
|
| 168 |
+
- edge_final_atoms
|
| 169 |
+
- residue_angles
|
| 170 |
+
- residue_trans
|
| 171 |
+
- residue_rot
|
| 172 |
+
"""
|
| 173 |
+
pass
|
| 174 |
+
|
| 175 |
+
def make_sc_input(self, pred_ligand, pred_residues, sc_transform):
|
| 176 |
+
|
| 177 |
+
if self.predict_confidence:
|
| 178 |
+
h_atoms_sc = (torch.cat([pred_ligand['logits_h'], pred_ligand['uncertainty_vel'].unsqueeze(1)], dim=-1),
|
| 179 |
+
pred_ligand['vel'].unsqueeze(1))
|
| 180 |
+
else:
|
| 181 |
+
h_atoms_sc = (pred_ligand['logits_h'], pred_ligand['vel'].unsqueeze(1))
|
| 182 |
+
e_atoms_sc = pred_ligand['logits_e']
|
| 183 |
+
|
| 184 |
+
if self.predict_frames:
|
| 185 |
+
h_residues_sc = (torch.cat([pred_residues['chi'], pred_residues['rot']], dim=-1),
|
| 186 |
+
pred_residues['trans'].unsqueeze(1))
|
| 187 |
+
elif self.predict_angles:
|
| 188 |
+
h_residues_sc = pred_residues['chi']
|
| 189 |
+
else:
|
| 190 |
+
h_residues_sc = None
|
| 191 |
+
|
| 192 |
+
if self.augment_residue_sc and h_residues_sc is not None:
|
| 193 |
+
if self.predict_frames:
|
| 194 |
+
h_residues_sc = (h_residues_sc[0], torch.cat(
|
| 195 |
+
[h_residues_sc[1], sc_transform['residues'](pred_residues['chi'], pred_residues['trans'].squeeze(1), pred_residues['rot'])], dim=1))
|
| 196 |
+
|
| 197 |
+
else:
|
| 198 |
+
h_residues_sc = (h_residues_sc, sc_transform['residues'](pred_residues['chi']))
|
| 199 |
+
|
| 200 |
+
if self.augment_ligand_sc:
|
| 201 |
+
h_atoms_sc = (h_atoms_sc[0], torch.cat(
|
| 202 |
+
[h_atoms_sc[1], sc_transform['atoms'](pred_ligand['vel'].unsqueeze(1))], dim=1))
|
| 203 |
+
|
| 204 |
+
return h_atoms_sc, e_atoms_sc, h_residues_sc
|
| 205 |
+
|
| 206 |
+
def forward(self, x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand=None, sc_transform=None):
|
| 207 |
+
"""
|
| 208 |
+
Implements self-conditioning as in https://arxiv.org/abs/2208.04202
|
| 209 |
+
"""
|
| 210 |
+
|
| 211 |
+
h_atoms_sc, e_atoms_sc = None, None
|
| 212 |
+
h_residues_sc = None
|
| 213 |
+
|
| 214 |
+
if self.self_conditioning:
|
| 215 |
+
|
| 216 |
+
# Sampling: use previous prediction in all but the first time step
|
| 217 |
+
if not self.training and t.min() > 0.0:
|
| 218 |
+
assert t.min() == t.max(), "currently only supports sampling at same time steps"
|
| 219 |
+
assert self.prev_ligand is not None
|
| 220 |
+
assert self.prev_residues is not None or not self.predict_frames
|
| 221 |
+
|
| 222 |
+
else:
|
| 223 |
+
# Create zero tensors
|
| 224 |
+
zeros_ligand = {'logits_h': torch.zeros_like(h_atoms),
|
| 225 |
+
'vel': torch.zeros_like(x_atoms),
|
| 226 |
+
'logits_e': torch.zeros_like(bonds_ligand[1])}
|
| 227 |
+
if self.predict_confidence:
|
| 228 |
+
zeros_ligand['uncertainty_vel'] = torch.zeros(
|
| 229 |
+
len(x_atoms), dtype=x_atoms.dtype, device=x_atoms.device)
|
| 230 |
+
|
| 231 |
+
zeros_residues = {}
|
| 232 |
+
if self.predict_angles:
|
| 233 |
+
zeros_residues['chi'] = torch.zeros((pocket['one_hot'].size(0), 5), device=pocket['one_hot'].device)
|
| 234 |
+
if self.predict_frames:
|
| 235 |
+
zeros_residues['trans'] = torch.zeros((pocket['one_hot'].size(0), 3), device=pocket['one_hot'].device)
|
| 236 |
+
zeros_residues['rot'] = torch.zeros((pocket['one_hot'].size(0), 3), device=pocket['one_hot'].device)
|
| 237 |
+
|
| 238 |
+
# Training: use 50% zeros and 50% predictions with detached gradients
|
| 239 |
+
if self.training and random.random() > 0.5:
|
| 240 |
+
with torch.no_grad():
|
| 241 |
+
h_atoms_sc, e_atoms_sc, h_residues_sc = self.make_sc_input(
|
| 242 |
+
zeros_ligand, zeros_residues, sc_transform)
|
| 243 |
+
|
| 244 |
+
self.prev_ligand, self.prev_residues = self._forward(
|
| 245 |
+
x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand,
|
| 246 |
+
h_atoms_sc, e_atoms_sc, h_residues_sc)
|
| 247 |
+
|
| 248 |
+
# use zeros for first sampling step and 50% of training
|
| 249 |
+
else:
|
| 250 |
+
self.prev_ligand = zeros_ligand
|
| 251 |
+
self.prev_residues = zeros_residues
|
| 252 |
+
|
| 253 |
+
h_atoms_sc, e_atoms_sc, h_residues_sc = self.make_sc_input(
|
| 254 |
+
self.prev_ligand, self.prev_residues, sc_transform)
|
| 255 |
+
|
| 256 |
+
pred_ligand, pred_residues = self._forward(
|
| 257 |
+
x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand,
|
| 258 |
+
h_atoms_sc, e_atoms_sc, h_residues_sc
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
if self.self_conditioning and not self.training:
|
| 262 |
+
self.prev_ligand = pred_ligand.copy()
|
| 263 |
+
self.prev_residues = pred_residues.copy()
|
| 264 |
+
|
| 265 |
+
return pred_ligand, pred_residues
|
| 266 |
+
|
| 267 |
+
def compute_extra_features(self, batch_mask, edge_indices, edge_types):
|
| 268 |
+
|
| 269 |
+
feat = torch.zeros(len(batch_mask), 0, device=batch_mask.device)
|
| 270 |
+
|
| 271 |
+
if not (self.add_cycle_counts or self.add_spectral_feat):
|
| 272 |
+
return feat
|
| 273 |
+
|
| 274 |
+
adj = batch_mask[:, None] == batch_mask[None, :]
|
| 275 |
+
|
| 276 |
+
E = torch.zeros_like(adj, dtype=INT_TYPE)
|
| 277 |
+
E[edge_indices[0], edge_indices[1]] = edge_types
|
| 278 |
+
|
| 279 |
+
A = (E > 0).float()
|
| 280 |
+
|
| 281 |
+
if self.add_cycle_counts:
|
| 282 |
+
cycle_features = cycle_counts(A)
|
| 283 |
+
cycle_features[cycle_features > 10] = 10 # avoid large values
|
| 284 |
+
|
| 285 |
+
feat = torch.cat([feat, cycle_features], dim=-1)
|
| 286 |
+
|
| 287 |
+
if self.add_spectral_feat:
|
| 288 |
+
feat = torch.cat([feat, eigenfeatures(A, batch_mask)], dim=-1)
|
| 289 |
+
|
| 290 |
+
return feat
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class Dynamics(DynamicsBase):
|
| 294 |
+
def __init__(self, atom_nf, residue_nf, joint_nf, bond_dict, pocket_bond_dict,
|
| 295 |
+
edge_nf, hidden_nf, act_fn=torch.nn.SiLU(), condition_time=True,
|
| 296 |
+
model='egnn', model_params=None,
|
| 297 |
+
edge_cutoff_ligand=None, edge_cutoff_pocket=None,
|
| 298 |
+
edge_cutoff_interaction=None,
|
| 299 |
+
predict_angles=False, predict_frames=False,
|
| 300 |
+
add_cycle_counts=False, add_spectral_feat=False,
|
| 301 |
+
add_nma_feat=False, self_conditioning=False,
|
| 302 |
+
augment_residue_sc=False, augment_ligand_sc=False,
|
| 303 |
+
add_chi_as_feature=False, angle_act_fn=False):
|
| 304 |
+
super().__init__()
|
| 305 |
+
self.model = model
|
| 306 |
+
self.edge_cutoff_l = edge_cutoff_ligand
|
| 307 |
+
self.edge_cutoff_p = edge_cutoff_pocket
|
| 308 |
+
self.edge_cutoff_i = edge_cutoff_interaction
|
| 309 |
+
self.hidden_nf = hidden_nf
|
| 310 |
+
self.predict_angles = predict_angles
|
| 311 |
+
self.predict_frames = predict_frames
|
| 312 |
+
self.bond_dict = bond_dict
|
| 313 |
+
self.pocket_bond_dict = pocket_bond_dict
|
| 314 |
+
self.bond_nf = len(bond_dict)
|
| 315 |
+
self.pocket_bond_nf = len(pocket_bond_dict)
|
| 316 |
+
self.edge_nf = edge_nf
|
| 317 |
+
self.add_cycle_counts = add_cycle_counts
|
| 318 |
+
self.add_spectral_feat = add_spectral_feat
|
| 319 |
+
self.add_nma_feat = add_nma_feat
|
| 320 |
+
self.self_conditioning = self_conditioning
|
| 321 |
+
self.augment_residue_sc = augment_residue_sc
|
| 322 |
+
self.augment_ligand_sc = augment_ligand_sc
|
| 323 |
+
self.add_chi_as_feature = add_chi_as_feature
|
| 324 |
+
self.predict_confidence = False
|
| 325 |
+
|
| 326 |
+
if self.self_conditioning:
|
| 327 |
+
self.prev_vel = None
|
| 328 |
+
self.prev_h = None
|
| 329 |
+
self.prev_e = None
|
| 330 |
+
self.prev_a = None
|
| 331 |
+
self.prev_ca = None
|
| 332 |
+
self.prev_rot = None
|
| 333 |
+
|
| 334 |
+
lig_nf = atom_nf
|
| 335 |
+
if self.add_cycle_counts:
|
| 336 |
+
lig_nf = lig_nf + 3
|
| 337 |
+
if self.add_spectral_feat:
|
| 338 |
+
lig_nf = lig_nf + 5
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
if not isinstance(joint_nf, Iterable):
|
| 342 |
+
# joint_nf contains only scalars
|
| 343 |
+
joint_nf = (joint_nf, 0)
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
if isinstance(residue_nf, Iterable):
|
| 347 |
+
_atom_in_nf = (lig_nf, 0)
|
| 348 |
+
_residue_atom_dim = residue_nf[1]
|
| 349 |
+
|
| 350 |
+
if self.add_nma_feat:
|
| 351 |
+
residue_nf = (residue_nf[0], residue_nf[1] + 5)
|
| 352 |
+
|
| 353 |
+
if self.self_conditioning:
|
| 354 |
+
_atom_in_nf = (_atom_in_nf[0] + atom_nf, 1)
|
| 355 |
+
|
| 356 |
+
if self.augment_ligand_sc:
|
| 357 |
+
_atom_in_nf = (_atom_in_nf[0], _atom_in_nf[1] + 1)
|
| 358 |
+
|
| 359 |
+
if self.predict_angles:
|
| 360 |
+
residue_nf = (residue_nf[0] + 5, residue_nf[1])
|
| 361 |
+
|
| 362 |
+
if self.predict_frames:
|
| 363 |
+
residue_nf = (residue_nf[0], residue_nf[1] + 2)
|
| 364 |
+
|
| 365 |
+
if self.augment_residue_sc:
|
| 366 |
+
assert self.predict_angles
|
| 367 |
+
residue_nf = (residue_nf[0], residue_nf[1] + _residue_atom_dim)
|
| 368 |
+
|
| 369 |
+
if self.add_chi_as_feature:
|
| 370 |
+
residue_nf = (residue_nf[0] + 5, residue_nf[1])
|
| 371 |
+
|
| 372 |
+
self.atom_encoder = nn.Sequential(
|
| 373 |
+
GVP(_atom_in_nf, joint_nf, activations=(act_fn, torch.sigmoid)),
|
| 374 |
+
LayerNorm(joint_nf, learnable_vector_weight=True),
|
| 375 |
+
GVP(joint_nf, joint_nf, activations=(None, None)),
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
self.residue_encoder = nn.Sequential(
|
| 379 |
+
GVP(residue_nf, joint_nf, activations=(act_fn, torch.sigmoid)),
|
| 380 |
+
LayerNorm(joint_nf, learnable_vector_weight=True),
|
| 381 |
+
GVP(joint_nf, joint_nf, activations=(None, None)),
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
else:
|
| 385 |
+
# No vector-valued input features
|
| 386 |
+
assert joint_nf[1] == 0
|
| 387 |
+
|
| 388 |
+
# self-conditioning not yet supported
|
| 389 |
+
assert not self.self_conditioning
|
| 390 |
+
|
| 391 |
+
# Normal mode features are vectors
|
| 392 |
+
assert not self.add_nma_feat
|
| 393 |
+
|
| 394 |
+
if self.add_chi_as_feature:
|
| 395 |
+
residue_nf += 5
|
| 396 |
+
|
| 397 |
+
self.atom_encoder = nn.Sequential(
|
| 398 |
+
nn.Linear(lig_nf, 2 * atom_nf),
|
| 399 |
+
act_fn,
|
| 400 |
+
nn.Linear(2 * atom_nf, joint_nf[0])
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
self.residue_encoder = nn.Sequential(
|
| 404 |
+
nn.Linear(residue_nf, 2 * residue_nf),
|
| 405 |
+
act_fn,
|
| 406 |
+
nn.Linear(2 * residue_nf, joint_nf[0])
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
self.atom_decoder = nn.Sequential(
|
| 410 |
+
nn.Linear(joint_nf[0], 2 * atom_nf),
|
| 411 |
+
act_fn,
|
| 412 |
+
nn.Linear(2 * atom_nf, atom_nf)
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
self.edge_decoder = nn.Sequential(
|
| 416 |
+
nn.Linear(hidden_nf, hidden_nf),
|
| 417 |
+
act_fn,
|
| 418 |
+
nn.Linear(hidden_nf, self.bond_nf)
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
_atom_bond_nf = 2 * self.bond_nf if self.self_conditioning else self.bond_nf
|
| 422 |
+
self.ligand_bond_encoder = nn.Sequential(
|
| 423 |
+
nn.Linear(_atom_bond_nf, hidden_nf),
|
| 424 |
+
act_fn,
|
| 425 |
+
nn.Linear(hidden_nf, self.edge_nf)
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
self.pocket_bond_encoder = nn.Sequential(
|
| 429 |
+
nn.Linear(self.pocket_bond_nf, hidden_nf),
|
| 430 |
+
act_fn,
|
| 431 |
+
nn.Linear(hidden_nf, self.edge_nf)
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
out_nf = (joint_nf[0], 1)
|
| 435 |
+
res_out_nf = (0, 0)
|
| 436 |
+
if self.predict_angles:
|
| 437 |
+
res_out_nf = (res_out_nf[0] + 5, res_out_nf[1])
|
| 438 |
+
if self.predict_frames:
|
| 439 |
+
res_out_nf = (res_out_nf[0], res_out_nf[1] + 2)
|
| 440 |
+
self.residue_decoder = nn.Sequential(
|
| 441 |
+
GVP(out_nf, out_nf, activations=(act_fn, torch.sigmoid)),
|
| 442 |
+
LayerNorm(out_nf, learnable_vector_weight=True),
|
| 443 |
+
GVP(out_nf, res_out_nf, activations=(None, None)),
|
| 444 |
+
) if res_out_nf != (0, 0) else None
|
| 445 |
+
|
| 446 |
+
if angle_act_fn is None:
|
| 447 |
+
self.angle_act_fn = None
|
| 448 |
+
elif angle_act_fn == 'tanh':
|
| 449 |
+
self.angle_act_fn = lambda x: np.pi * F.tanh(x)
|
| 450 |
+
else:
|
| 451 |
+
raise NotImplementedError(f"Angle activation {angle_act_fn} not available")
|
| 452 |
+
|
| 453 |
+
# self.ligand_nobond_emb = nn.Parameter(torch.zeros(self.edge_nf))
|
| 454 |
+
# self.pocket_nobond_emb = nn.Parameter(torch.zeros(self.edge_nf))
|
| 455 |
+
self.cross_emb = nn.Parameter(torch.zeros(self.edge_nf),
|
| 456 |
+
requires_grad=True)
|
| 457 |
+
|
| 458 |
+
if condition_time:
|
| 459 |
+
dynamics_node_nf = (joint_nf[0] + 1, joint_nf[1])
|
| 460 |
+
else:
|
| 461 |
+
print('Warning: dynamics model is NOT conditioned on time.')
|
| 462 |
+
dynamics_node_nf = (joint_nf[0], joint_nf[1])
|
| 463 |
+
|
| 464 |
+
if model == 'egnn':
|
| 465 |
+
raise NotImplementedError
|
| 466 |
+
# self.net = EGNN(
|
| 467 |
+
# in_node_nf=dynamics_node_nf[0], in_edge_nf=self.edge_nf,
|
| 468 |
+
# hidden_nf=hidden_nf, out_node_nf=joint_nf[0],
|
| 469 |
+
# device=model_params.device, act_fn=act_fn,
|
| 470 |
+
# n_layers=model_params.n_layers,
|
| 471 |
+
# attention=model_params.attention,
|
| 472 |
+
# tanh=model_params.tanh,
|
| 473 |
+
# norm_constant=model_params.norm_constant,
|
| 474 |
+
# inv_sublayers=model_params.inv_sublayers,
|
| 475 |
+
# sin_embedding=model_params.sin_embedding,
|
| 476 |
+
# normalization_factor=model_params.normalization_factor,
|
| 477 |
+
# aggregation_method=model_params.aggregation_method,
|
| 478 |
+
# reflection_equiv=model_params.reflection_equivariant,
|
| 479 |
+
# update_edge_attr=True
|
| 480 |
+
# )
|
| 481 |
+
# self.node_nf = dynamics_node_nf[0]
|
| 482 |
+
|
| 483 |
+
elif model == 'gvp':
|
| 484 |
+
self.net = GVPModel(
|
| 485 |
+
node_in_dim=dynamics_node_nf, node_h_dim=model_params.node_h_dim,
|
| 486 |
+
node_out_nf=joint_nf[0], edge_in_nf=self.edge_nf,
|
| 487 |
+
edge_h_dim=model_params.edge_h_dim, edge_out_nf=hidden_nf,
|
| 488 |
+
num_layers=model_params.n_layers,
|
| 489 |
+
drop_rate=model_params.dropout,
|
| 490 |
+
vector_gate=model_params.vector_gate,
|
| 491 |
+
reflection_equiv=model_params.reflection_equivariant,
|
| 492 |
+
d_max=model_params.d_max,
|
| 493 |
+
num_rbf=model_params.num_rbf,
|
| 494 |
+
update_edge_attr=True
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
elif model == 'gvp_transformer':
|
| 498 |
+
self.net = GVPTransformerModel(
|
| 499 |
+
node_in_dim=dynamics_node_nf,
|
| 500 |
+
node_h_dim=model_params.node_h_dim,
|
| 501 |
+
node_out_nf=joint_nf[0],
|
| 502 |
+
edge_in_nf=self.edge_nf,
|
| 503 |
+
edge_h_dim=model_params.edge_h_dim,
|
| 504 |
+
edge_out_nf=hidden_nf,
|
| 505 |
+
num_layers=model_params.n_layers,
|
| 506 |
+
dk=model_params.dk,
|
| 507 |
+
dv=model_params.dv,
|
| 508 |
+
de=model_params.de,
|
| 509 |
+
db=model_params.db,
|
| 510 |
+
dy=model_params.dy,
|
| 511 |
+
attn_heads=model_params.attn_heads,
|
| 512 |
+
n_feedforward=model_params.n_feedforward,
|
| 513 |
+
drop_rate=model_params.dropout,
|
| 514 |
+
reflection_equiv=model_params.reflection_equivariant,
|
| 515 |
+
d_max=model_params.d_max,
|
| 516 |
+
num_rbf=model_params.num_rbf,
|
| 517 |
+
vector_gate=model_params.vector_gate,
|
| 518 |
+
attention=model_params.attention,
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
elif model == 'gnn':
|
| 522 |
+
raise NotImplementedError
|
| 523 |
+
# n_dims = 3
|
| 524 |
+
# self.net = GNN(
|
| 525 |
+
# in_node_nf=dynamics_node_nf + n_dims, in_edge_nf=self.edge_emb_dim,
|
| 526 |
+
# hidden_nf=hidden_nf, out_node_nf=n_dims + dynamics_node_nf,
|
| 527 |
+
# device=model_params.device, act_fn=act_fn, n_layers=model_params.n_layers,
|
| 528 |
+
# attention=model_params.attention, normalization_factor=model_params.normalization_factor,
|
| 529 |
+
# aggregation_method=model_params.aggregation_method)
|
| 530 |
+
|
| 531 |
+
else:
|
| 532 |
+
raise NotImplementedError(f"{model} is not available")
|
| 533 |
+
|
| 534 |
+
# self.device = device
|
| 535 |
+
# self.n_dims = n_dims
|
| 536 |
+
self.condition_time = condition_time
|
| 537 |
+
|
| 538 |
+
def _forward(self, x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand=None,
|
| 539 |
+
h_atoms_sc=None, e_atoms_sc=None, h_residues_sc=None):
|
| 540 |
+
"""
|
| 541 |
+
:param x_atoms:
|
| 542 |
+
:param h_atoms:
|
| 543 |
+
:param mask_atoms:
|
| 544 |
+
:param pocket: must contain keys: 'x', 'one_hot', 'mask', 'bonds' and 'bond_one_hot'
|
| 545 |
+
:param t:
|
| 546 |
+
:param bonds_ligand: tuple - bond indices (2, n_bonds) & bond types (n_bonds, bond_nf)
|
| 547 |
+
:param h_atoms_sc: additional node feature for self-conditioning, (s, V)
|
| 548 |
+
:param e_atoms_sc: additional edge feature for self-conditioning, only scalar
|
| 549 |
+
:param h_residues_sc: additional node feature for self-conditioning, tensor or tuple
|
| 550 |
+
:return:
|
| 551 |
+
"""
|
| 552 |
+
x_residues, h_residues, mask_residues = pocket['x'], pocket['one_hot'], pocket['mask']
|
| 553 |
+
if 'bonds' in pocket:
|
| 554 |
+
bonds_pocket = (pocket['bonds'], pocket['bond_one_hot'])
|
| 555 |
+
else:
|
| 556 |
+
bonds_pocket = None
|
| 557 |
+
|
| 558 |
+
if self.add_chi_as_feature:
|
| 559 |
+
h_residues = torch.cat([h_residues, pocket['chi'][:, :5]], dim=-1)
|
| 560 |
+
|
| 561 |
+
if 'v' in pocket:
|
| 562 |
+
v_residues = pocket['v']
|
| 563 |
+
if self.add_nma_feat:
|
| 564 |
+
v_residues = torch.cat([v_residues, pocket['nma_vec']], dim=1)
|
| 565 |
+
h_residues = (h_residues, v_residues)
|
| 566 |
+
|
| 567 |
+
if h_residues_sc is not None:
|
| 568 |
+
# if self.augment_residue_sc:
|
| 569 |
+
if isinstance(h_residues_sc, tuple):
|
| 570 |
+
h_residues = (torch.cat([h_residues[0], h_residues_sc[0]], dim=-1),
|
| 571 |
+
torch.cat([h_residues[1], h_residues_sc[1]], dim=1))
|
| 572 |
+
else:
|
| 573 |
+
h_residues = (torch.cat([h_residues[0], h_residues_sc], dim=-1),
|
| 574 |
+
h_residues[1])
|
| 575 |
+
|
| 576 |
+
# get graph edges and edge attributes
|
| 577 |
+
if bonds_ligand is not None:
|
| 578 |
+
# NOTE: 'bond' denotes one-directional edges and 'edge' means bi-directional
|
| 579 |
+
ligand_bond_indices = bonds_ligand[0]
|
| 580 |
+
|
| 581 |
+
# make sure messages are passed both ways
|
| 582 |
+
ligand_edge_indices = torch.cat(
|
| 583 |
+
[bonds_ligand[0], bonds_ligand[0].flip(dims=[0])], dim=1)
|
| 584 |
+
ligand_edge_types = torch.cat([bonds_ligand[1], bonds_ligand[1]], dim=0)
|
| 585 |
+
# edges_ligand = (ligand_edge_indices, ligand_edge_types)
|
| 586 |
+
|
| 587 |
+
# add auxiliary features to ligand nodes
|
| 588 |
+
extra_features = self.compute_extra_features(
|
| 589 |
+
mask_atoms, ligand_edge_indices, ligand_edge_types.argmax(-1))
|
| 590 |
+
h_atoms = torch.cat([h_atoms, extra_features], dim=-1)
|
| 591 |
+
|
| 592 |
+
if bonds_pocket is not None:
|
| 593 |
+
# make sure messages are passed both ways
|
| 594 |
+
pocket_edge_indices = torch.cat(
|
| 595 |
+
[bonds_pocket[0], bonds_pocket[0].flip(dims=[0])], dim=1)
|
| 596 |
+
pocket_edge_types = torch.cat([bonds_pocket[1], bonds_pocket[1]], dim=0)
|
| 597 |
+
# edges_pocket = (pocket_edge_indices, pocket_edge_types)
|
| 598 |
+
|
| 599 |
+
if h_atoms_sc is not None:
|
| 600 |
+
h_atoms = (torch.cat([h_atoms, h_atoms_sc[0]], dim=-1),
|
| 601 |
+
h_atoms_sc[1])
|
| 602 |
+
|
| 603 |
+
if e_atoms_sc is not None:
|
| 604 |
+
e_atoms_sc = torch.cat([e_atoms_sc, e_atoms_sc], dim=0)
|
| 605 |
+
ligand_edge_types = torch.cat([ligand_edge_types, e_atoms_sc], dim=-1)
|
| 606 |
+
|
| 607 |
+
# embed atom features and residue features in a shared space
|
| 608 |
+
h_atoms = self.atom_encoder(h_atoms)
|
| 609 |
+
e_ligand = self.ligand_bond_encoder(ligand_edge_types)
|
| 610 |
+
|
| 611 |
+
if len(h_residues) > 0:
|
| 612 |
+
h_residues = self.residue_encoder(h_residues)
|
| 613 |
+
e_pocket = self.pocket_bond_encoder(pocket_edge_types)
|
| 614 |
+
else:
|
| 615 |
+
e_pocket = pocket_edge_types
|
| 616 |
+
h_residues = (h_residues, h_residues)
|
| 617 |
+
pocket_edge_indices = torch.tensor([[], []], dtype=torch.long, device=h_residues[0].device)
|
| 618 |
+
pocket_edge_types = torch.tensor([[], []], dtype=torch.long, device=h_residues[0].device)
|
| 619 |
+
|
| 620 |
+
if isinstance(h_atoms, tuple):
|
| 621 |
+
h_atoms, v_atoms = h_atoms
|
| 622 |
+
h_residues, v_residues = h_residues
|
| 623 |
+
v = torch.cat((v_atoms, v_residues), dim=0)
|
| 624 |
+
else:
|
| 625 |
+
v = None
|
| 626 |
+
|
| 627 |
+
edges, edge_feat = self.get_edges(
|
| 628 |
+
mask_atoms, mask_residues, x_atoms, x_residues,
|
| 629 |
+
bond_inds_ligand=ligand_edge_indices, bond_inds_pocket=pocket_edge_indices,
|
| 630 |
+
bond_feat_ligand=e_ligand, bond_feat_pocket=e_pocket)
|
| 631 |
+
|
| 632 |
+
# combine the two node types
|
| 633 |
+
x = torch.cat((x_atoms, x_residues), dim=0)
|
| 634 |
+
h = torch.cat((h_atoms, h_residues), dim=0)
|
| 635 |
+
mask = torch.cat([mask_atoms, mask_residues])
|
| 636 |
+
|
| 637 |
+
if self.condition_time:
|
| 638 |
+
if np.prod(t.size()) == 1:
|
| 639 |
+
# t is the same for all elements in batch.
|
| 640 |
+
h_time = torch.empty_like(h[:, 0:1]).fill_(t.item())
|
| 641 |
+
else:
|
| 642 |
+
# t is different over the batch dimension.
|
| 643 |
+
h_time = t[mask]
|
| 644 |
+
h = torch.cat([h, h_time], dim=1)
|
| 645 |
+
|
| 646 |
+
assert torch.all(mask[edges[0]] == mask[edges[1]])
|
| 647 |
+
|
| 648 |
+
if self.model == 'egnn':
|
| 649 |
+
# Don't update pocket coordinates
|
| 650 |
+
update_coords_mask = torch.cat((torch.ones_like(mask_atoms),
|
| 651 |
+
torch.zeros_like(mask_residues))).unsqueeze(1)
|
| 652 |
+
h_final, vel, edge_final = self.net(
|
| 653 |
+
h, x, edges, batch_mask=mask, edge_attr=edge_feat,
|
| 654 |
+
update_coords_mask=update_coords_mask)
|
| 655 |
+
# vel = (x_final - x)
|
| 656 |
+
|
| 657 |
+
elif self.model == 'gvp' or self.model == 'gvp_transformer':
|
| 658 |
+
h_final, vel, edge_final = self.net(
|
| 659 |
+
h, x, edges, v=v, batch_mask=mask, edge_attr=edge_feat)
|
| 660 |
+
|
| 661 |
+
elif self.model == 'gnn':
|
| 662 |
+
xh = torch.cat([x, h], dim=1)
|
| 663 |
+
output = self.net(xh, edges, node_mask=None, edge_attr=edge_feat)
|
| 664 |
+
vel = output[:, :3]
|
| 665 |
+
h_final = output[:, 3:]
|
| 666 |
+
|
| 667 |
+
else:
|
| 668 |
+
raise NotImplementedError(f"Wrong model ({self.model})")
|
| 669 |
+
|
| 670 |
+
# if self.condition_time:
|
| 671 |
+
# # Slice off last dimension which represented time.
|
| 672 |
+
# h_final = h_final[:, :-1]
|
| 673 |
+
|
| 674 |
+
# decode atom and residue features
|
| 675 |
+
h_final_atoms = self.atom_decoder(h_final[:len(mask_atoms)])
|
| 676 |
+
|
| 677 |
+
if torch.any(torch.isnan(vel)) or torch.any(torch.isnan(h_final_atoms)):
|
| 678 |
+
if self.training:
|
| 679 |
+
vel[torch.isnan(vel)] = 0.0
|
| 680 |
+
h_final_atoms[torch.isnan(h_final_atoms)] = 0.0
|
| 681 |
+
else:
|
| 682 |
+
raise ValueError("NaN detected in network output")
|
| 683 |
+
|
| 684 |
+
# predict edge type
|
| 685 |
+
ligand_edge_mask = (edges[0] < len(mask_atoms)) & (edges[1] < len(mask_atoms))
|
| 686 |
+
edge_final = edge_final[ligand_edge_mask]
|
| 687 |
+
edges = edges[:, ligand_edge_mask]
|
| 688 |
+
|
| 689 |
+
# Symmetrize
|
| 690 |
+
edge_logits = torch.zeros(
|
| 691 |
+
(len(mask_atoms), len(mask_atoms), self.hidden_nf),
|
| 692 |
+
device=mask_atoms.device)
|
| 693 |
+
edge_logits[edges[0], edges[1]] = edge_final
|
| 694 |
+
edge_logits = (edge_logits + edge_logits.transpose(0, 1)) * 0.5
|
| 695 |
+
# edge_logits = edge_logits[lig_edge_indices[0], lig_edge_indices[1]]
|
| 696 |
+
|
| 697 |
+
# return upper triangular elements only (matching the input)
|
| 698 |
+
edge_logits = edge_logits[ligand_bond_indices[0], ligand_bond_indices[1]]
|
| 699 |
+
# assert (edge_logits == 0).sum() == 0
|
| 700 |
+
|
| 701 |
+
edge_final_atoms = self.edge_decoder(edge_logits)
|
| 702 |
+
|
| 703 |
+
# Predict torsion angles
|
| 704 |
+
residue_angles = None
|
| 705 |
+
residue_trans, residue_rot = None, None
|
| 706 |
+
if self.residue_decoder is not None:
|
| 707 |
+
h_residues = h_final[len(mask_atoms):]
|
| 708 |
+
vec_residues = vel[len(mask_atoms):].unsqueeze(1)
|
| 709 |
+
residue_angles = self.residue_decoder((h_residues, vec_residues))
|
| 710 |
+
if self.predict_frames:
|
| 711 |
+
residue_angles, residue_frames = residue_angles
|
| 712 |
+
residue_trans = residue_frames[:, 0, :].squeeze(1)
|
| 713 |
+
residue_rot = residue_frames[:, 1, :].squeeze(1)
|
| 714 |
+
if self.angle_act_fn is not None:
|
| 715 |
+
residue_angles = self.angle_act_fn(residue_angles)
|
| 716 |
+
|
| 717 |
+
# return vel[:len(mask_atoms)], h_final_atoms, edge_final_atoms, residue_angles, residue_trans, residue_rot
|
| 718 |
+
pred_ligand = {'vel': vel[:len(mask_atoms)], 'logits_h': h_final_atoms, 'logits_e': edge_final_atoms}
|
| 719 |
+
pred_residues = {'chi': residue_angles, 'trans': residue_trans, 'rot': residue_rot}
|
| 720 |
+
return pred_ligand, pred_residues
|
| 721 |
+
|
| 722 |
+
def get_edges(self, batch_mask_ligand, batch_mask_pocket, x_ligand,
|
| 723 |
+
x_pocket, bond_inds_ligand=None, bond_inds_pocket=None,
|
| 724 |
+
bond_feat_ligand=None, bond_feat_pocket=None, self_edges=False):
|
| 725 |
+
|
| 726 |
+
# Adjacency matrix
|
| 727 |
+
adj_ligand = batch_mask_ligand[:, None] == batch_mask_ligand[None, :]
|
| 728 |
+
adj_pocket = batch_mask_pocket[:, None] == batch_mask_pocket[None, :]
|
| 729 |
+
adj_cross = batch_mask_ligand[:, None] == batch_mask_pocket[None, :]
|
| 730 |
+
|
| 731 |
+
if self.edge_cutoff_l is not None:
|
| 732 |
+
adj_ligand = adj_ligand & (torch.cdist(x_ligand, x_ligand) <= self.edge_cutoff_l)
|
| 733 |
+
|
| 734 |
+
# Add missing bonds if they got removed
|
| 735 |
+
adj_ligand[bond_inds_ligand[0], bond_inds_ligand[1]] = True
|
| 736 |
+
|
| 737 |
+
if self.edge_cutoff_p is not None and len(x_pocket) > 0:
|
| 738 |
+
adj_pocket = adj_pocket & (torch.cdist(x_pocket, x_pocket) <= self.edge_cutoff_p)
|
| 739 |
+
|
| 740 |
+
# Add missing bonds if they got removed
|
| 741 |
+
adj_pocket[bond_inds_pocket[0], bond_inds_pocket[1]] = True
|
| 742 |
+
|
| 743 |
+
if self.edge_cutoff_i is not None and len(x_pocket) > 0:
|
| 744 |
+
adj_cross = adj_cross & (torch.cdist(x_ligand, x_pocket) <= self.edge_cutoff_i)
|
| 745 |
+
|
| 746 |
+
adj = torch.cat((torch.cat((adj_ligand, adj_cross), dim=1),
|
| 747 |
+
torch.cat((adj_cross.T, adj_pocket), dim=1)), dim=0)
|
| 748 |
+
|
| 749 |
+
if not self_edges:
|
| 750 |
+
adj = adj ^ torch.eye(*adj.size(), out=torch.empty_like(adj))
|
| 751 |
+
|
| 752 |
+
# # ensure that edge definition is consistent if bonds are provided (for loss computation)
|
| 753 |
+
# if bond_inds_ligand is not None:
|
| 754 |
+
# # remove ligand edges
|
| 755 |
+
# adj[:adj_ligand.size(0), :adj_ligand.size(1)] = False
|
| 756 |
+
# edges = torch.stack(torch.where(adj), dim=0)
|
| 757 |
+
# # add ligand edges back with original definition
|
| 758 |
+
# edges = torch.cat([bond_inds_ligand, edges], dim=-1)
|
| 759 |
+
# else:
|
| 760 |
+
# edges = torch.stack(torch.where(adj), dim=0)
|
| 761 |
+
|
| 762 |
+
# Feature matrix
|
| 763 |
+
ligand_nobond_onehot = F.one_hot(torch.tensor(
|
| 764 |
+
self.bond_dict['NOBOND'], device=bond_feat_ligand.device),
|
| 765 |
+
num_classes=self.ligand_bond_encoder[0].in_features)
|
| 766 |
+
ligand_nobond_emb = self.ligand_bond_encoder(
|
| 767 |
+
ligand_nobond_onehot.to(FLOAT_TYPE))
|
| 768 |
+
feat_ligand = ligand_nobond_emb.repeat(*adj_ligand.shape, 1)
|
| 769 |
+
feat_ligand[bond_inds_ligand[0], bond_inds_ligand[1]] = bond_feat_ligand
|
| 770 |
+
|
| 771 |
+
if len(adj_pocket) > 0:
|
| 772 |
+
pocket_nobond_onehot = F.one_hot(torch.tensor(
|
| 773 |
+
self.pocket_bond_dict['NOBOND'], device=bond_feat_pocket.device),
|
| 774 |
+
num_classes=self.pocket_bond_nf)
|
| 775 |
+
pocket_nobond_emb = self.pocket_bond_encoder(
|
| 776 |
+
pocket_nobond_onehot.to(FLOAT_TYPE))
|
| 777 |
+
feat_pocket = pocket_nobond_emb.repeat(*adj_pocket.shape, 1)
|
| 778 |
+
feat_pocket[bond_inds_pocket[0], bond_inds_pocket[1]] = bond_feat_pocket
|
| 779 |
+
|
| 780 |
+
feat_cross = self.cross_emb.repeat(*adj_cross.shape, 1)
|
| 781 |
+
|
| 782 |
+
feats = torch.cat((torch.cat((feat_ligand, feat_cross), dim=1),
|
| 783 |
+
torch.cat((feat_cross.transpose(0, 1), feat_pocket), dim=1)), dim=0)
|
| 784 |
+
else:
|
| 785 |
+
feats = feat_ligand
|
| 786 |
+
|
| 787 |
+
# Return results
|
| 788 |
+
edges = torch.stack(torch.where(adj), dim=0)
|
| 789 |
+
edge_feat = feats[edges[0], edges[1]]
|
| 790 |
+
|
| 791 |
+
return edges, edge_feat
|
src/model/dynamics_hetero.py
ADDED
|
@@ -0,0 +1,1008 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections.abc import Iterable
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
from functools import partial
|
| 4 |
+
import functools
|
| 5 |
+
import warnings
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import numpy as np
|
| 10 |
+
from torch_scatter import scatter_mean
|
| 11 |
+
from torch_geometric.nn import MessagePassing
|
| 12 |
+
from torch_geometric.nn.module_dict import ModuleDict
|
| 13 |
+
from torch_geometric.utils.hetero import check_add_self_loops
|
| 14 |
+
try:
|
| 15 |
+
from torch_geometric.nn.conv.hgt_conv import group
|
| 16 |
+
except ImportError as e:
|
| 17 |
+
from torch_geometric.nn.conv.hetero_conv import group
|
| 18 |
+
|
| 19 |
+
from src.model.dynamics import DynamicsBase
|
| 20 |
+
from src.model import gvp
|
| 21 |
+
from src.model.gvp import GVP, _rbf, _normalize, tuple_index, tuple_sum, _split, tuple_cat, _merge
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MyModuleDict(nn.ModuleDict):
|
| 25 |
+
def __init__(self, modules):
|
| 26 |
+
# a mapping (dictionary) of (string: module) or an iterable of key-value pairs of type (string, module)
|
| 27 |
+
if isinstance(modules, dict):
|
| 28 |
+
super().__init__({str(k): v for k, v in modules.items()})
|
| 29 |
+
else:
|
| 30 |
+
raise NotImplementedError
|
| 31 |
+
|
| 32 |
+
def __getitem__(self, key):
|
| 33 |
+
return super().__getitem__(str(key))
|
| 34 |
+
|
| 35 |
+
def __setitem__(self, key, value):
|
| 36 |
+
super().__setitem__(str(key), value)
|
| 37 |
+
|
| 38 |
+
def __delitem__(self, key):
|
| 39 |
+
super().__delitem__(str(key))
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class MyHeteroConv(nn.Module):
|
| 43 |
+
"""
|
| 44 |
+
Implementation from PyG 2.2.0 with minor changes.
|
| 45 |
+
Override forward pass to control the final aggregation
|
| 46 |
+
Ref.: https://pytorch-geometric.readthedocs.io/en/2.2.0/_modules/torch_geometric/nn/conv/hetero_conv.html
|
| 47 |
+
"""
|
| 48 |
+
def __init__(self, convs, aggr="sum"):
|
| 49 |
+
self.vo = {}
|
| 50 |
+
for k, module in convs.items():
|
| 51 |
+
dst = k[-1]
|
| 52 |
+
if dst not in self.vo:
|
| 53 |
+
self.vo[dst] = module.vo
|
| 54 |
+
else:
|
| 55 |
+
assert self.vo[dst] == module.vo
|
| 56 |
+
|
| 57 |
+
# from the original implementation in PyTorch Geometric
|
| 58 |
+
super().__init__()
|
| 59 |
+
|
| 60 |
+
for edge_type, module in convs.items():
|
| 61 |
+
check_add_self_loops(module, [edge_type])
|
| 62 |
+
|
| 63 |
+
src_node_types = set([key[0] for key in convs.keys()])
|
| 64 |
+
dst_node_types = set([key[-1] for key in convs.keys()])
|
| 65 |
+
if len(src_node_types - dst_node_types) > 0:
|
| 66 |
+
warnings.warn(
|
| 67 |
+
f"There exist node types ({src_node_types - dst_node_types}) "
|
| 68 |
+
f"whose representations do not get updated during message "
|
| 69 |
+
f"passing as they do not occur as destination type in any "
|
| 70 |
+
f"edge type. This may lead to unexpected behaviour.")
|
| 71 |
+
|
| 72 |
+
self.convs = ModuleDict({'__'.join(k): v for k, v in convs.items()})
|
| 73 |
+
self.aggr = aggr
|
| 74 |
+
|
| 75 |
+
def reset_parameters(self):
|
| 76 |
+
for conv in self.convs.values():
|
| 77 |
+
conv.reset_parameters()
|
| 78 |
+
|
| 79 |
+
def __repr__(self) -> str:
|
| 80 |
+
return f'{self.__class__.__name__}(num_relations={len(self.convs)})'
|
| 81 |
+
|
| 82 |
+
def forward(
|
| 83 |
+
self,
|
| 84 |
+
x_dict,
|
| 85 |
+
edge_index_dict,
|
| 86 |
+
*args_dict,
|
| 87 |
+
**kwargs_dict,
|
| 88 |
+
):
|
| 89 |
+
r"""
|
| 90 |
+
Args:
|
| 91 |
+
x_dict (Dict[str, Tensor]): A dictionary holding node feature
|
| 92 |
+
information for each individual node type.
|
| 93 |
+
edge_index_dict (Dict[Tuple[str, str, str], Tensor]): A dictionary
|
| 94 |
+
holding graph connectivity information for each individual
|
| 95 |
+
edge type.
|
| 96 |
+
*args_dict (optional): Additional forward arguments of invididual
|
| 97 |
+
:class:`torch_geometric.nn.conv.MessagePassing` layers.
|
| 98 |
+
**kwargs_dict (optional): Additional forward arguments of
|
| 99 |
+
individual :class:`torch_geometric.nn.conv.MessagePassing`
|
| 100 |
+
layers.
|
| 101 |
+
For example, if a specific GNN layer at edge type
|
| 102 |
+
:obj:`edge_type` expects edge attributes :obj:`edge_attr` as a
|
| 103 |
+
forward argument, then you can pass them to
|
| 104 |
+
:meth:`~torch_geometric.nn.conv.HeteroConv.forward` via
|
| 105 |
+
:obj:`edge_attr_dict = { edge_type: edge_attr }`.
|
| 106 |
+
"""
|
| 107 |
+
out_dict = defaultdict(list)
|
| 108 |
+
out_dict_edge = {}
|
| 109 |
+
for edge_type, edge_index in edge_index_dict.items():
|
| 110 |
+
src, rel, dst = edge_type
|
| 111 |
+
|
| 112 |
+
str_edge_type = '__'.join(edge_type)
|
| 113 |
+
if str_edge_type not in self.convs:
|
| 114 |
+
continue
|
| 115 |
+
|
| 116 |
+
args = []
|
| 117 |
+
for value_dict in args_dict:
|
| 118 |
+
if edge_type in value_dict:
|
| 119 |
+
args.append(value_dict[edge_type])
|
| 120 |
+
elif src == dst and src in value_dict:
|
| 121 |
+
args.append(value_dict[src])
|
| 122 |
+
elif src in value_dict or dst in value_dict:
|
| 123 |
+
args.append(
|
| 124 |
+
(value_dict.get(src, None), value_dict.get(dst, None)))
|
| 125 |
+
|
| 126 |
+
kwargs = {}
|
| 127 |
+
for arg, value_dict in kwargs_dict.items():
|
| 128 |
+
arg = arg[:-5] # `{*}_dict`
|
| 129 |
+
if edge_type in value_dict:
|
| 130 |
+
kwargs[arg] = value_dict[edge_type]
|
| 131 |
+
elif src == dst and src in value_dict:
|
| 132 |
+
kwargs[arg] = value_dict[src]
|
| 133 |
+
elif src in value_dict or dst in value_dict:
|
| 134 |
+
kwargs[arg] = (value_dict.get(src, None),
|
| 135 |
+
value_dict.get(dst, None))
|
| 136 |
+
|
| 137 |
+
conv = self.convs[str_edge_type]
|
| 138 |
+
|
| 139 |
+
if src == dst:
|
| 140 |
+
out = conv(x_dict[src], edge_index, *args, **kwargs)
|
| 141 |
+
else:
|
| 142 |
+
out = conv((x_dict[src], x_dict[dst]), edge_index, *args,
|
| 143 |
+
**kwargs)
|
| 144 |
+
|
| 145 |
+
if isinstance(out, (tuple, list)):
|
| 146 |
+
out, out_edge = out
|
| 147 |
+
out_dict_edge[edge_type] = out_edge
|
| 148 |
+
|
| 149 |
+
out_dict[dst].append(out)
|
| 150 |
+
|
| 151 |
+
for key, value in out_dict.items():
|
| 152 |
+
out_dict[key] = group(value, self.aggr)
|
| 153 |
+
out_dict[key] = _split(out_dict[key], self.vo[key])
|
| 154 |
+
|
| 155 |
+
return out_dict if len(out_dict_edge) <= 0 else out_dict, out_dict_edge
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class GVPHeteroConv(MessagePassing):
|
| 159 |
+
'''
|
| 160 |
+
Graph convolution / message passing with Geometric Vector Perceptrons.
|
| 161 |
+
Takes in a graph with node and edge embeddings,
|
| 162 |
+
and returns new node embeddings.
|
| 163 |
+
|
| 164 |
+
This does NOT do residual updates and pointwise feedforward layers
|
| 165 |
+
---see `GVPConvLayer`.
|
| 166 |
+
|
| 167 |
+
:param in_dims: input node embedding dimensions (n_scalar, n_vector)
|
| 168 |
+
:param out_dims: output node embedding dimensions (n_scalar, n_vector)
|
| 169 |
+
:param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
|
| 170 |
+
:param n_layers: number of GVPs in the message function
|
| 171 |
+
:param module_list: preconstructed message function, overrides n_layers
|
| 172 |
+
:param aggr: should be "add" if some incoming edges are masked, as in
|
| 173 |
+
a masked autoregressive decoder architecture, otherwise "mean"
|
| 174 |
+
:param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
|
| 175 |
+
:param vector_gate: whether to use vector gating.
|
| 176 |
+
(vector_act will be used as sigma^+ in vector gating if `True`)
|
| 177 |
+
:param update_edge_attr: whether to compute an updated edge representation
|
| 178 |
+
'''
|
| 179 |
+
|
| 180 |
+
def __init__(self, in_dims, out_dims, edge_dims, in_dims_other=None,
|
| 181 |
+
n_layers=3, module_list=None, aggr="mean",
|
| 182 |
+
activations=(F.relu, torch.sigmoid), vector_gate=False,
|
| 183 |
+
update_edge_attr=False):
|
| 184 |
+
super(GVPHeteroConv, self).__init__(aggr=aggr)
|
| 185 |
+
|
| 186 |
+
if in_dims_other is None:
|
| 187 |
+
in_dims_other = in_dims
|
| 188 |
+
|
| 189 |
+
self.si, self.vi = in_dims
|
| 190 |
+
self.si_other, self.vi_other = in_dims_other
|
| 191 |
+
self.so, self.vo = out_dims
|
| 192 |
+
self.se, self.ve = edge_dims
|
| 193 |
+
self.update_edge_attr = update_edge_attr
|
| 194 |
+
|
| 195 |
+
GVP_ = functools.partial(GVP,
|
| 196 |
+
activations=activations,
|
| 197 |
+
vector_gate=vector_gate)
|
| 198 |
+
|
| 199 |
+
def get_modules(module_list, out_dims):
|
| 200 |
+
module_list = module_list or []
|
| 201 |
+
if not module_list:
|
| 202 |
+
if n_layers == 1:
|
| 203 |
+
module_list.append(
|
| 204 |
+
GVP_((self.si + self.si_other + self.se, self.vi + self.vi_other + self.ve),
|
| 205 |
+
(self.so, self.vo), activations=(None, None)))
|
| 206 |
+
else:
|
| 207 |
+
module_list.append(
|
| 208 |
+
GVP_((self.si + self.si_other + self.se, self.vi + self.vi_other + self.ve),
|
| 209 |
+
out_dims)
|
| 210 |
+
)
|
| 211 |
+
for i in range(n_layers - 2):
|
| 212 |
+
module_list.append(GVP_(out_dims, out_dims))
|
| 213 |
+
module_list.append(GVP_(out_dims, out_dims,
|
| 214 |
+
activations=(None, None)))
|
| 215 |
+
return nn.Sequential(*module_list)
|
| 216 |
+
|
| 217 |
+
self.message_func = get_modules(module_list, out_dims)
|
| 218 |
+
self.edge_func = get_modules(module_list, edge_dims) if self.update_edge_attr else None
|
| 219 |
+
|
| 220 |
+
def forward(self, x, edge_index, edge_attr):
|
| 221 |
+
'''
|
| 222 |
+
:param x: tuple (s, V) of `torch.Tensor`
|
| 223 |
+
:param edge_index: array of shape [2, n_edges]
|
| 224 |
+
:param edge_attr: tuple (s, V) of `torch.Tensor`
|
| 225 |
+
'''
|
| 226 |
+
elem_0, elem_1 = x
|
| 227 |
+
if isinstance(elem_0, (tuple, list)):
|
| 228 |
+
assert isinstance(elem_1, (tuple, list))
|
| 229 |
+
x_s = (elem_0[0], elem_1[0])
|
| 230 |
+
x_v = (elem_0[1].reshape(elem_0[1].shape[0], 3 * elem_0[1].shape[1]),
|
| 231 |
+
elem_1[1].reshape(elem_1[1].shape[0], 3 * elem_1[1].shape[1]))
|
| 232 |
+
else:
|
| 233 |
+
x_s, x_v = elem_0, elem_1
|
| 234 |
+
x_v = x_v.reshape(x_v.shape[0], 3 * x_v.shape[1])
|
| 235 |
+
|
| 236 |
+
message = self.propagate(edge_index, s=x_s, v=x_v, edge_attr=edge_attr)
|
| 237 |
+
|
| 238 |
+
if self.update_edge_attr:
|
| 239 |
+
if isinstance(x_s, (tuple, list)):
|
| 240 |
+
s_i, s_j = x_s[1][edge_index[1]], x_s[0][edge_index[0]]
|
| 241 |
+
else:
|
| 242 |
+
s_i, s_j = x_s[edge_index[1]], x_s[edge_index[0]]
|
| 243 |
+
|
| 244 |
+
if isinstance(x_v, (tuple, list)):
|
| 245 |
+
v_i, v_j = x_v[1][edge_index[1]], x_v[0][edge_index[0]]
|
| 246 |
+
else:
|
| 247 |
+
v_i, v_j = x_v[edge_index[1]], x_v[edge_index[0]]
|
| 248 |
+
|
| 249 |
+
edge_out = self.edge_attr(s_i, v_i, s_j, v_j, edge_attr)
|
| 250 |
+
# return _split(message, self.vo), edge_out
|
| 251 |
+
return message, edge_out
|
| 252 |
+
else:
|
| 253 |
+
# return _split(message, self.vo)
|
| 254 |
+
return message
|
| 255 |
+
|
| 256 |
+
def message(self, s_i, v_i, s_j, v_j, edge_attr):
|
| 257 |
+
v_j = v_j.view(v_j.shape[0], v_j.shape[1] // 3, 3)
|
| 258 |
+
v_i = v_i.view(v_i.shape[0], v_i.shape[1] // 3, 3)
|
| 259 |
+
message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
|
| 260 |
+
message = self.message_func(message)
|
| 261 |
+
return _merge(*message)
|
| 262 |
+
|
| 263 |
+
def edge_attr(self, s_i, v_i, s_j, v_j, edge_attr):
|
| 264 |
+
v_j = v_j.view(v_j.shape[0], v_j.shape[1] // 3, 3)
|
| 265 |
+
v_i = v_i.view(v_i.shape[0], v_i.shape[1] // 3, 3)
|
| 266 |
+
message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
|
| 267 |
+
return self.edge_func(message)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class GVPHeteroConvLayer(nn.Module):
|
| 271 |
+
"""
|
| 272 |
+
Full graph convolution / message passing layer with
|
| 273 |
+
Geometric Vector Perceptrons. Residually updates node embeddings with
|
| 274 |
+
aggregated incoming messages, applies a pointwise feedforward
|
| 275 |
+
network to node embeddings, and returns updated node embeddings.
|
| 276 |
+
|
| 277 |
+
To only compute the aggregated messages, see `GVPConv`.
|
| 278 |
+
|
| 279 |
+
:param conv_dims: dictionary defining (src_dim, dst_dim, edge_dim) for each edge type
|
| 280 |
+
"""
|
| 281 |
+
def __init__(self, conv_dims,
|
| 282 |
+
n_message=3, n_feedforward=2, drop_rate=.1,
|
| 283 |
+
activations=(F.relu, torch.sigmoid), vector_gate=False,
|
| 284 |
+
update_edge_attr=False, ln_vector_weight=False):
|
| 285 |
+
|
| 286 |
+
super(GVPHeteroConvLayer, self).__init__()
|
| 287 |
+
self.update_edge_attr = update_edge_attr
|
| 288 |
+
|
| 289 |
+
gvp_conv = partial(GVPHeteroConv,
|
| 290 |
+
n_layers=n_message,
|
| 291 |
+
aggr="sum",
|
| 292 |
+
activations=activations,
|
| 293 |
+
vector_gate=vector_gate,
|
| 294 |
+
update_edge_attr=update_edge_attr)
|
| 295 |
+
|
| 296 |
+
def get_feedforward(n_dims):
|
| 297 |
+
GVP_ = partial(GVP, activations=activations, vector_gate=vector_gate)
|
| 298 |
+
|
| 299 |
+
ff_func = []
|
| 300 |
+
if n_feedforward == 1:
|
| 301 |
+
ff_func.append(GVP_(n_dims, n_dims, activations=(None, None)))
|
| 302 |
+
else:
|
| 303 |
+
hid_dims = 4 * n_dims[0], 2 * n_dims[1]
|
| 304 |
+
ff_func.append(GVP_(n_dims, hid_dims))
|
| 305 |
+
for i in range(n_feedforward - 2):
|
| 306 |
+
ff_func.append(GVP_(hid_dims, hid_dims))
|
| 307 |
+
ff_func.append(GVP_(hid_dims, n_dims, activations=(None, None)))
|
| 308 |
+
return nn.Sequential(*ff_func)
|
| 309 |
+
|
| 310 |
+
# self.conv = HeteroConv({k: gvp_conv(*dims) for k, dims in conv_dims.items()}, aggr='sum')
|
| 311 |
+
self.conv = MyHeteroConv({k: gvp_conv(*dims) for k, dims in conv_dims.items()}, aggr='sum')
|
| 312 |
+
|
| 313 |
+
node_dims = {k[-1]: dims[1] for k, dims in conv_dims.items()}
|
| 314 |
+
self.norm0 = MyModuleDict({k: gvp.LayerNorm(dims, ln_vector_weight) for k, dims in node_dims.items()})
|
| 315 |
+
self.dropout0 = MyModuleDict({k: gvp.Dropout(drop_rate) for k, dims in node_dims.items()})
|
| 316 |
+
self.ff_func = MyModuleDict({k: get_feedforward(dims) for k, dims in node_dims.items()})
|
| 317 |
+
self.norm1 = MyModuleDict({k: gvp.LayerNorm(dims, ln_vector_weight) for k, dims in node_dims.items()})
|
| 318 |
+
self.dropout1 = MyModuleDict({k: gvp.Dropout(drop_rate) for k, dims in node_dims.items()})
|
| 319 |
+
|
| 320 |
+
if self.update_edge_attr:
|
| 321 |
+
self.edge_norm0 = MyModuleDict({k: gvp.LayerNorm(dims[2], ln_vector_weight) for k, dims in conv_dims.items()})
|
| 322 |
+
self.edge_dropout0 = MyModuleDict({k: gvp.Dropout(drop_rate) for k, dims in conv_dims.items()})
|
| 323 |
+
self.edge_ff = MyModuleDict({k: get_feedforward(dims[2]) for k, dims in conv_dims.items()})
|
| 324 |
+
self.edge_norm1 = MyModuleDict({k: gvp.LayerNorm(dims[2], ln_vector_weight) for k, dims in conv_dims.items()})
|
| 325 |
+
self.edge_dropout1 = MyModuleDict({k: gvp.Dropout(drop_rate) for k, dims in conv_dims.items()})
|
| 326 |
+
|
| 327 |
+
def forward(self, x_dict, edge_index_dict, edge_attr_dict, node_mask_dict=None):
|
| 328 |
+
'''
|
| 329 |
+
:param x: tuple (s, V) of `torch.Tensor`
|
| 330 |
+
:param edge_index: array of shape [2, n_edges]
|
| 331 |
+
:param edge_attr: tuple (s, V) of `torch.Tensor`
|
| 332 |
+
:param node_mask: array of type `bool` to index into the first
|
| 333 |
+
dim of node embeddings (s, V). If not `None`, only
|
| 334 |
+
these nodes will be updated.
|
| 335 |
+
'''
|
| 336 |
+
|
| 337 |
+
dh_dict = self.conv(x_dict, edge_index_dict, edge_attr_dict)
|
| 338 |
+
|
| 339 |
+
if self.update_edge_attr:
|
| 340 |
+
dh_dict, de_dict = dh_dict
|
| 341 |
+
|
| 342 |
+
for k, edge_attr in edge_attr_dict.items():
|
| 343 |
+
de = de_dict[k]
|
| 344 |
+
|
| 345 |
+
edge_attr = self.edge_norm0[k](tuple_sum(edge_attr, self.edge_dropout0[k](de)))
|
| 346 |
+
de = self.edge_ff[k](edge_attr)
|
| 347 |
+
edge_attr = self.edge_norm1[k](tuple_sum(edge_attr, self.edge_dropout1[k](de)))
|
| 348 |
+
|
| 349 |
+
edge_attr_dict[k] = edge_attr
|
| 350 |
+
|
| 351 |
+
for k, x in x_dict.items():
|
| 352 |
+
dh = dh_dict[k]
|
| 353 |
+
node_mask = None if node_mask_dict is None else node_mask_dict[k]
|
| 354 |
+
|
| 355 |
+
if node_mask is not None:
|
| 356 |
+
x_ = x
|
| 357 |
+
x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask)
|
| 358 |
+
|
| 359 |
+
x = self.norm0[k](tuple_sum(x, self.dropout0[k](dh)))
|
| 360 |
+
|
| 361 |
+
dh = self.ff_func[k](x)
|
| 362 |
+
x = self.norm1[k](tuple_sum(x, self.dropout1[k](dh)))
|
| 363 |
+
|
| 364 |
+
if node_mask is not None:
|
| 365 |
+
x_[0][node_mask], x_[1][node_mask] = x[0], x[1]
|
| 366 |
+
x = x_
|
| 367 |
+
|
| 368 |
+
x_dict[k] = x
|
| 369 |
+
|
| 370 |
+
return (x_dict, edge_attr_dict) if self.update_edge_attr else x_dict
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
class GVPModel(torch.nn.Module):
|
| 374 |
+
"""
|
| 375 |
+
GVP-GNN model
|
| 376 |
+
inspired by: https://github.com/drorlab/gvp-pytorch/blob/main/gvp/models.py
|
| 377 |
+
and: https://github.com/drorlab/gvp-pytorch/blob/82af6b22eaf8311c15733117b0071408d24ed877/gvp/atom3d.py#L115
|
| 378 |
+
"""
|
| 379 |
+
def __init__(self,
|
| 380 |
+
node_in_dim_ligand, node_in_dim_pocket,
|
| 381 |
+
edge_in_dim_ligand, edge_in_dim_pocket, edge_in_dim_interaction,
|
| 382 |
+
node_h_dim_ligand, node_h_dim_pocket,
|
| 383 |
+
edge_h_dim_ligand, edge_h_dim_pocket, edge_h_dim_interaction,
|
| 384 |
+
node_out_dim_ligand=None, node_out_dim_pocket=None,
|
| 385 |
+
edge_out_dim_ligand=None, edge_out_dim_pocket=None, edge_out_dim_interaction=None,
|
| 386 |
+
num_layers=3, drop_rate=0.1, vector_gate=False, update_edge_attr=False):
|
| 387 |
+
|
| 388 |
+
super(GVPModel, self).__init__()
|
| 389 |
+
|
| 390 |
+
self.update_edge_attr = update_edge_attr
|
| 391 |
+
|
| 392 |
+
self.node_in = nn.ModuleDict({
|
| 393 |
+
'ligand': GVP(node_in_dim_ligand, node_h_dim_ligand, activations=(None, None), vector_gate=vector_gate),
|
| 394 |
+
'pocket': GVP(node_in_dim_pocket, node_h_dim_pocket, activations=(None, None), vector_gate=vector_gate),
|
| 395 |
+
})
|
| 396 |
+
# self.edge_in = MyModuleDict({
|
| 397 |
+
# ('ligand', 'ligand'): GVP(edge_in_dim_ligand, edge_h_dim_ligand, activations=(None, None), vector_gate=vector_gate),
|
| 398 |
+
# ('pocket', 'pocket'): GVP(edge_in_dim_pocket, edge_h_dim_pocket, activations=(None, None), vector_gate=vector_gate),
|
| 399 |
+
# ('ligand', 'pocket'): GVP(edge_in_dim_interaction, edge_h_dim_interaction, activations=(None, None), vector_gate=vector_gate),
|
| 400 |
+
# ('pocket', 'ligand'): GVP(edge_in_dim_interaction, edge_h_dim_interaction, activations=(None, None), vector_gate=vector_gate),
|
| 401 |
+
# })
|
| 402 |
+
self.edge_in = MyModuleDict({
|
| 403 |
+
('ligand', '', 'ligand'): GVP(edge_in_dim_ligand, edge_h_dim_ligand, activations=(None, None), vector_gate=vector_gate),
|
| 404 |
+
('pocket', '', 'pocket'): GVP(edge_in_dim_pocket, edge_h_dim_pocket, activations=(None, None), vector_gate=vector_gate),
|
| 405 |
+
('ligand', '', 'pocket'): GVP(edge_in_dim_interaction, edge_h_dim_interaction, activations=(None, None), vector_gate=vector_gate),
|
| 406 |
+
('pocket', '', 'ligand'): GVP(edge_in_dim_interaction, edge_h_dim_interaction, activations=(None, None), vector_gate=vector_gate),
|
| 407 |
+
})
|
| 408 |
+
|
| 409 |
+
# conv_dims = {
|
| 410 |
+
# ('ligand', 'ligand'): (node_h_dim_ligand, node_h_dim_ligand, edge_h_dim_ligand),
|
| 411 |
+
# ('pocket', 'pocket'): (node_h_dim_pocket, node_h_dim_pocket, edge_h_dim_pocket),
|
| 412 |
+
# ('ligand', 'pocket'): (node_h_dim_ligand, node_h_dim_pocket, edge_h_dim_interaction),
|
| 413 |
+
# ('pocket', 'ligand'): (node_h_dim_pocket, node_h_dim_ligand, edge_h_dim_interaction),
|
| 414 |
+
# }
|
| 415 |
+
conv_dims = {
|
| 416 |
+
('ligand', '', 'ligand'): (node_h_dim_ligand, node_h_dim_ligand, edge_h_dim_ligand),
|
| 417 |
+
('pocket', '', 'pocket'): (node_h_dim_pocket, node_h_dim_pocket, edge_h_dim_pocket),
|
| 418 |
+
('ligand', '', 'pocket'): (node_h_dim_ligand, node_h_dim_pocket, edge_h_dim_interaction, node_h_dim_pocket),
|
| 419 |
+
('pocket', '', 'ligand'): (node_h_dim_pocket, node_h_dim_ligand, edge_h_dim_interaction, node_h_dim_ligand),
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
self.layers = nn.ModuleList(
|
| 423 |
+
GVPHeteroConvLayer(conv_dims,
|
| 424 |
+
drop_rate=drop_rate,
|
| 425 |
+
update_edge_attr=self.update_edge_attr,
|
| 426 |
+
activations=(F.relu, None),
|
| 427 |
+
vector_gate=vector_gate,
|
| 428 |
+
ln_vector_weight=True)
|
| 429 |
+
for _ in range(num_layers))
|
| 430 |
+
|
| 431 |
+
self.node_out = nn.ModuleDict({
|
| 432 |
+
'ligand': GVP(node_h_dim_ligand, node_out_dim_ligand, activations=(None, None), vector_gate=vector_gate),
|
| 433 |
+
'pocket': GVP(node_h_dim_pocket, node_out_dim_pocket, activations=(None, None), vector_gate=vector_gate) if node_out_dim_pocket is not None else None,
|
| 434 |
+
})
|
| 435 |
+
# self.edge_out = MyModuleDict({
|
| 436 |
+
# ('ligand', 'ligand'): GVP(edge_h_dim_ligand, edge_out_dim_ligand, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_ligand is not None else None,
|
| 437 |
+
# ('pocket', 'pocket'): GVP(edge_h_dim_pocket, edge_out_dim_pocket, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_pocket is not None else None,
|
| 438 |
+
# ('ligand', 'pocket'): GVP(edge_h_dim_interaction, edge_out_dim_interaction, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_interaction is not None else None,
|
| 439 |
+
# ('pocket', 'ligand'): GVP(edge_h_dim_interaction, edge_out_dim_interaction, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_interaction is not None else None,
|
| 440 |
+
# })
|
| 441 |
+
self.edge_out = MyModuleDict({
|
| 442 |
+
('ligand', '', 'ligand'): GVP(edge_h_dim_ligand, edge_out_dim_ligand, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_ligand is not None else None,
|
| 443 |
+
('pocket', '', 'pocket'): GVP(edge_h_dim_pocket, edge_out_dim_pocket, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_pocket is not None else None,
|
| 444 |
+
('ligand', '', 'pocket'): GVP(edge_h_dim_interaction, edge_out_dim_interaction, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_interaction is not None else None,
|
| 445 |
+
('pocket', '', 'ligand'): GVP(edge_h_dim_interaction, edge_out_dim_interaction, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_interaction is not None else None,
|
| 446 |
+
})
|
| 447 |
+
|
| 448 |
+
def forward(self, node_attr, batch_mask, edge_index, edge_attr):
|
| 449 |
+
|
| 450 |
+
# to hidden dimension
|
| 451 |
+
for k in node_attr.keys():
|
| 452 |
+
node_attr[k] = self.node_in[k](node_attr[k])
|
| 453 |
+
|
| 454 |
+
for k in edge_attr.keys():
|
| 455 |
+
edge_attr[k] = self.edge_in[k](edge_attr[k])
|
| 456 |
+
|
| 457 |
+
# convolutions
|
| 458 |
+
for layer in self.layers:
|
| 459 |
+
out = layer(node_attr, edge_index, edge_attr)
|
| 460 |
+
if self.update_edge_attr:
|
| 461 |
+
node_attr, edge_attr = out
|
| 462 |
+
else:
|
| 463 |
+
node_attr = out
|
| 464 |
+
|
| 465 |
+
# to output dimension
|
| 466 |
+
for k in node_attr.keys():
|
| 467 |
+
node_attr[k] = self.node_out[k](node_attr[k]) \
|
| 468 |
+
if self.node_out[k] is not None else None
|
| 469 |
+
|
| 470 |
+
if self.update_edge_attr:
|
| 471 |
+
for k in edge_attr.keys():
|
| 472 |
+
if self.edge_out[k] is not None:
|
| 473 |
+
edge_attr[k] = self.edge_out[k](edge_attr[k])
|
| 474 |
+
|
| 475 |
+
return node_attr, edge_attr
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
class DynamicsHetero(DynamicsBase):
|
| 479 |
+
def __init__(self, atom_nf, residue_nf, bond_dict, pocket_bond_dict,
|
| 480 |
+
condition_time=True,
|
| 481 |
+
num_rbf_time=None,
|
| 482 |
+
model='gvp',
|
| 483 |
+
model_params=None,
|
| 484 |
+
edge_cutoff_ligand=None,
|
| 485 |
+
edge_cutoff_pocket=None,
|
| 486 |
+
edge_cutoff_interaction=None,
|
| 487 |
+
predict_angles=False,
|
| 488 |
+
predict_frames=False,
|
| 489 |
+
add_cycle_counts=False,
|
| 490 |
+
add_spectral_feat=False,
|
| 491 |
+
add_nma_feat=False,
|
| 492 |
+
reflection_equiv=False,
|
| 493 |
+
d_max=15.0,
|
| 494 |
+
num_rbf_dist=16,
|
| 495 |
+
self_conditioning=False,
|
| 496 |
+
augment_residue_sc=False,
|
| 497 |
+
augment_ligand_sc=False,
|
| 498 |
+
add_chi_as_feature=False,
|
| 499 |
+
angle_act_fn=False,
|
| 500 |
+
add_all_atom_diff=False,
|
| 501 |
+
predict_confidence=False):
|
| 502 |
+
|
| 503 |
+
super().__init__(
|
| 504 |
+
predict_angles=predict_angles,
|
| 505 |
+
predict_frames=predict_frames,
|
| 506 |
+
add_cycle_counts=add_cycle_counts,
|
| 507 |
+
add_spectral_feat=add_spectral_feat,
|
| 508 |
+
self_conditioning=self_conditioning,
|
| 509 |
+
augment_residue_sc=augment_residue_sc,
|
| 510 |
+
augment_ligand_sc=augment_ligand_sc
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
self.model = model
|
| 514 |
+
self.edge_cutoff_l = edge_cutoff_ligand
|
| 515 |
+
self.edge_cutoff_p = edge_cutoff_pocket
|
| 516 |
+
self.edge_cutoff_i = edge_cutoff_interaction
|
| 517 |
+
self.bond_dict = bond_dict
|
| 518 |
+
self.pocket_bond_dict = pocket_bond_dict
|
| 519 |
+
self.bond_nf = len(bond_dict)
|
| 520 |
+
self.pocket_bond_nf = len(pocket_bond_dict)
|
| 521 |
+
# self.edge_dim = edge_dim
|
| 522 |
+
self.add_nma_feat = add_nma_feat
|
| 523 |
+
self.add_chi_as_feature = add_chi_as_feature
|
| 524 |
+
self.add_all_atom_diff = add_all_atom_diff
|
| 525 |
+
self.condition_time = condition_time
|
| 526 |
+
self.predict_confidence = predict_confidence
|
| 527 |
+
|
| 528 |
+
# edge encoding params
|
| 529 |
+
self.reflection_equiv = reflection_equiv
|
| 530 |
+
self.d_max = d_max
|
| 531 |
+
self.num_rbf = num_rbf_dist
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
# Output dimensions dimensions, always tuple (scalar, vector)
|
| 535 |
+
_atom_out = (atom_nf[0], 1) if isinstance(atom_nf, Iterable) else (atom_nf, 1)
|
| 536 |
+
_residue_out = (0, 0)
|
| 537 |
+
|
| 538 |
+
if self.predict_confidence:
|
| 539 |
+
_atom_out = tuple_sum(_atom_out, (1, 0))
|
| 540 |
+
|
| 541 |
+
if self.predict_angles:
|
| 542 |
+
_residue_out = tuple_sum(_residue_out, (5, 0))
|
| 543 |
+
|
| 544 |
+
if self.predict_frames:
|
| 545 |
+
_residue_out = tuple_sum(_residue_out, (3, 1))
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
# Input dimensions dimensions, always tuple (scalar, vector)
|
| 549 |
+
assert isinstance(atom_nf, int), "expected: element onehot"
|
| 550 |
+
_atom_in = (atom_nf, 0)
|
| 551 |
+
assert isinstance(residue_nf, Iterable), "expected: (AA-onehot, vectors to atoms)"
|
| 552 |
+
_residue_in = tuple(residue_nf)
|
| 553 |
+
_residue_atom_dim = residue_nf[1]
|
| 554 |
+
|
| 555 |
+
if self.add_cycle_counts:
|
| 556 |
+
_atom_in = tuple_sum(_atom_in, (3, 0))
|
| 557 |
+
if self.add_spectral_feat:
|
| 558 |
+
_atom_in = tuple_sum(_atom_in, (5, 0))
|
| 559 |
+
|
| 560 |
+
if self.add_nma_feat:
|
| 561 |
+
_residue_in = tuple_sum(_residue_in, (0, 5))
|
| 562 |
+
|
| 563 |
+
if self.add_chi_as_feature:
|
| 564 |
+
_residue_in = tuple_sum(_residue_in, (5, 0))
|
| 565 |
+
|
| 566 |
+
if self.condition_time:
|
| 567 |
+
self.embed_time = num_rbf_time is not None
|
| 568 |
+
self.time_dim = num_rbf_time if self.embed_time else 1
|
| 569 |
+
|
| 570 |
+
_atom_in = tuple_sum(_atom_in, (self.time_dim, 0))
|
| 571 |
+
_residue_in = tuple_sum(_residue_in, (self.time_dim, 0))
|
| 572 |
+
else:
|
| 573 |
+
print('Warning: dynamics model is NOT conditioned on time.')
|
| 574 |
+
|
| 575 |
+
if self.self_conditioning:
|
| 576 |
+
_atom_in = tuple_sum(_atom_in, _atom_out)
|
| 577 |
+
_residue_in = tuple_sum(_residue_in, _residue_out)
|
| 578 |
+
|
| 579 |
+
if self.augment_ligand_sc:
|
| 580 |
+
_atom_in = tuple_sum(_atom_in, (0, 1))
|
| 581 |
+
|
| 582 |
+
if self.augment_residue_sc:
|
| 583 |
+
assert self.predict_angles
|
| 584 |
+
_residue_in = tuple_sum(_residue_in, (0, _residue_atom_dim))
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
# Edge output dimensions, always tuple (scalar, vector)
|
| 588 |
+
_edge_ligand_out = (self.bond_nf, 0)
|
| 589 |
+
_edge_ligand_before_symmetrization = (model_params.edge_h_dim[0], 0)
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
# Edge input dimensions dimensions, always tuple (scalar, vector)
|
| 593 |
+
_edge_ligand_in = (self.bond_nf + self.num_rbf, 1 if self.reflection_equiv else 2)
|
| 594 |
+
_edge_ligand_in = tuple_sum(_edge_ligand_in, _atom_in) # src node
|
| 595 |
+
_edge_ligand_in = tuple_sum(_edge_ligand_in, _atom_in) # dst node
|
| 596 |
+
|
| 597 |
+
if self_conditioning:
|
| 598 |
+
_edge_ligand_in = tuple_sum(_edge_ligand_in, _edge_ligand_out)
|
| 599 |
+
|
| 600 |
+
_n_dist_residue = _residue_atom_dim ** 2 if self.add_all_atom_diff else 1
|
| 601 |
+
_edge_pocket_in = (_n_dist_residue * self.num_rbf + self.pocket_bond_nf, _n_dist_residue)
|
| 602 |
+
_edge_pocket_in = tuple_sum(_edge_pocket_in, _residue_in) # src node
|
| 603 |
+
_edge_pocket_in = tuple_sum(_edge_pocket_in, _residue_in) # dst node
|
| 604 |
+
|
| 605 |
+
_n_dist_interaction = _residue_atom_dim if self.add_all_atom_diff else 1
|
| 606 |
+
_edge_interaction_in = (_n_dist_interaction * self.num_rbf, _n_dist_interaction)
|
| 607 |
+
_edge_interaction_in = tuple_sum(_edge_interaction_in, _atom_in) # atom node
|
| 608 |
+
_edge_interaction_in = tuple_sum(_edge_interaction_in, _residue_in) # residue node
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
# Embeddings for newly added edges
|
| 612 |
+
_ligand_nobond_nf = self.bond_nf + _edge_ligand_out[0] if self.self_conditioning else self.bond_nf
|
| 613 |
+
self.ligand_nobond_emb = nn.Parameter(torch.zeros(_ligand_nobond_nf), requires_grad=True)
|
| 614 |
+
self.pocket_nobond_emb = nn.Parameter(torch.zeros(self.pocket_bond_nf), requires_grad=True)
|
| 615 |
+
|
| 616 |
+
# for access in self-conditioning
|
| 617 |
+
self.atom_out_dim = _atom_out
|
| 618 |
+
self.residue_out_dim = _residue_out
|
| 619 |
+
self.edge_out_dim = _edge_ligand_out
|
| 620 |
+
|
| 621 |
+
if model == 'gvp':
|
| 622 |
+
|
| 623 |
+
self.net = GVPModel(
|
| 624 |
+
node_in_dim_ligand=_atom_in,
|
| 625 |
+
node_in_dim_pocket=_residue_in,
|
| 626 |
+
edge_in_dim_ligand=_edge_ligand_in,
|
| 627 |
+
edge_in_dim_pocket=_edge_pocket_in,
|
| 628 |
+
edge_in_dim_interaction=_edge_interaction_in,
|
| 629 |
+
node_h_dim_ligand=model_params.node_h_dim,
|
| 630 |
+
node_h_dim_pocket=model_params.node_h_dim,
|
| 631 |
+
edge_h_dim_ligand=model_params.edge_h_dim,
|
| 632 |
+
edge_h_dim_pocket=model_params.edge_h_dim,
|
| 633 |
+
edge_h_dim_interaction=model_params.edge_h_dim,
|
| 634 |
+
node_out_dim_ligand=_atom_out,
|
| 635 |
+
node_out_dim_pocket=_residue_out,
|
| 636 |
+
edge_out_dim_ligand=_edge_ligand_before_symmetrization,
|
| 637 |
+
edge_out_dim_pocket=None,
|
| 638 |
+
edge_out_dim_interaction=None,
|
| 639 |
+
num_layers=model_params.n_layers,
|
| 640 |
+
drop_rate=model_params.dropout,
|
| 641 |
+
vector_gate=model_params.vector_gate,
|
| 642 |
+
update_edge_attr=True
|
| 643 |
+
)
|
| 644 |
+
|
| 645 |
+
else:
|
| 646 |
+
raise NotImplementedError(f"{model} is not available")
|
| 647 |
+
|
| 648 |
+
assert _edge_ligand_out[1] == 0
|
| 649 |
+
assert _edge_ligand_before_symmetrization[1] == 0
|
| 650 |
+
self.edge_decoder = nn.Sequential(
|
| 651 |
+
nn.Linear(_edge_ligand_before_symmetrization[0], _edge_ligand_before_symmetrization[0]),
|
| 652 |
+
torch.nn.SiLU(),
|
| 653 |
+
nn.Linear(_edge_ligand_before_symmetrization[0], _edge_ligand_out[0])
|
| 654 |
+
)
|
| 655 |
+
|
| 656 |
+
if angle_act_fn is None:
|
| 657 |
+
self.angle_act_fn = None
|
| 658 |
+
elif angle_act_fn == 'tanh':
|
| 659 |
+
self.angle_act_fn = lambda x: np.pi * F.tanh(x)
|
| 660 |
+
else:
|
| 661 |
+
raise NotImplementedError(f"Angle activation {angle_act_fn} not available")
|
| 662 |
+
|
| 663 |
+
def _forward(self, x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand=None,
|
| 664 |
+
h_atoms_sc=None, e_atoms_sc=None, h_residues_sc=None):
|
| 665 |
+
"""
|
| 666 |
+
:param x_atoms:
|
| 667 |
+
:param h_atoms:
|
| 668 |
+
:param mask_atoms:
|
| 669 |
+
:param pocket: must contain keys: 'x', 'one_hot', 'mask', 'bonds' and 'bond_one_hot'
|
| 670 |
+
:param t:
|
| 671 |
+
:param bonds_ligand: tuple - bond indices (2, n_bonds) & bond types (n_bonds, bond_nf)
|
| 672 |
+
:param h_atoms_sc: additional node feature for self-conditioning, (s, V)
|
| 673 |
+
:param e_atoms_sc: additional edge feature for self-conditioning, only scalar
|
| 674 |
+
:param h_residues_sc: additional node feature for self-conditioning, tensor or tuple
|
| 675 |
+
:return:
|
| 676 |
+
"""
|
| 677 |
+
x_residues, h_residues, mask_residues = pocket['x'], pocket['one_hot'], pocket['mask']
|
| 678 |
+
if 'bonds' in pocket:
|
| 679 |
+
bonds_pocket = (pocket['bonds'], pocket['bond_one_hot'])
|
| 680 |
+
else:
|
| 681 |
+
bonds_pocket = None
|
| 682 |
+
|
| 683 |
+
if self.add_chi_as_feature:
|
| 684 |
+
h_residues = torch.cat([h_residues, pocket['chi'][:, :5]], dim=-1)
|
| 685 |
+
|
| 686 |
+
if 'v' in pocket:
|
| 687 |
+
v_residues = pocket['v']
|
| 688 |
+
if self.add_nma_feat:
|
| 689 |
+
v_residues = torch.cat([v_residues, pocket['nma_vec']], dim=1)
|
| 690 |
+
h_residues = (h_residues, v_residues)
|
| 691 |
+
|
| 692 |
+
# NOTE: 'bond' denotes one-directional edges and 'edge' means bi-directional
|
| 693 |
+
# get graph edges and edge attributes
|
| 694 |
+
if bonds_ligand is not None:
|
| 695 |
+
|
| 696 |
+
ligand_bond_indices = bonds_ligand[0]
|
| 697 |
+
|
| 698 |
+
# make sure messages are passed both ways
|
| 699 |
+
ligand_edge_indices = torch.cat(
|
| 700 |
+
[bonds_ligand[0], bonds_ligand[0].flip(dims=[0])], dim=1)
|
| 701 |
+
ligand_edge_types = torch.cat([bonds_ligand[1], bonds_ligand[1]], dim=0)
|
| 702 |
+
if e_atoms_sc is not None:
|
| 703 |
+
e_atoms_sc = torch.cat([e_atoms_sc, e_atoms_sc], dim=0)
|
| 704 |
+
|
| 705 |
+
# add auxiliary features to ligand nodes
|
| 706 |
+
extra_features = self.compute_extra_features(
|
| 707 |
+
mask_atoms, ligand_edge_indices, ligand_edge_types.argmax(-1))
|
| 708 |
+
h_atoms = torch.cat([h_atoms, extra_features], dim=-1)
|
| 709 |
+
|
| 710 |
+
if bonds_pocket is not None:
|
| 711 |
+
# make sure messages are passed both ways
|
| 712 |
+
pocket_edge_indices = torch.cat(
|
| 713 |
+
[bonds_pocket[0], bonds_pocket[0].flip(dims=[0])], dim=1)
|
| 714 |
+
pocket_edge_types = torch.cat([bonds_pocket[1], bonds_pocket[1]], dim=0)
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
# Self-conditioning
|
| 718 |
+
if h_atoms_sc is not None:
|
| 719 |
+
h_atoms = (torch.cat([h_atoms, h_atoms_sc[0]], dim=-1), h_atoms_sc[1])
|
| 720 |
+
|
| 721 |
+
if e_atoms_sc is not None:
|
| 722 |
+
ligand_edge_types = torch.cat([ligand_edge_types, e_atoms_sc], dim=-1)
|
| 723 |
+
|
| 724 |
+
if h_residues_sc is not None:
|
| 725 |
+
# if self.augment_residue_sc:
|
| 726 |
+
if isinstance(h_residues_sc, tuple):
|
| 727 |
+
h_residues = (torch.cat([h_residues[0], h_residues_sc[0]], dim=-1),
|
| 728 |
+
torch.cat([h_residues[1], h_residues_sc[1]], dim=1))
|
| 729 |
+
else:
|
| 730 |
+
h_residues = (torch.cat([h_residues[0], h_residues_sc], dim=-1),
|
| 731 |
+
h_residues[1])
|
| 732 |
+
|
| 733 |
+
if self.condition_time:
|
| 734 |
+
if self.embed_time:
|
| 735 |
+
t = _rbf(t.squeeze(-1), D_min=0.0, D_max=1.0, D_count=self.time_dim, device=t.device)
|
| 736 |
+
if isinstance(h_atoms, tuple) :
|
| 737 |
+
h_atoms = (torch.cat([h_atoms[0], t[mask_atoms]], dim=1), h_atoms[1])
|
| 738 |
+
else:
|
| 739 |
+
h_atoms = torch.cat([h_atoms, t[mask_atoms]], dim=1)
|
| 740 |
+
h_residues = (torch.cat([h_residues[0], t[mask_residues]], dim=1), h_residues[1])
|
| 741 |
+
|
| 742 |
+
empty_pocket = (len(pocket['x']) == 0)
|
| 743 |
+
|
| 744 |
+
# Process edges and encode in shared feature space
|
| 745 |
+
edge_index_dict, edge_attr_dict = self.get_edges(
|
| 746 |
+
x_atoms, h_atoms, mask_atoms, ligand_edge_indices, ligand_edge_types,
|
| 747 |
+
x_residues, h_residues, mask_residues, pocket['v'], pocket_edge_indices, pocket_edge_types,
|
| 748 |
+
empty_pocket=empty_pocket
|
| 749 |
+
)
|
| 750 |
+
|
| 751 |
+
if not empty_pocket:
|
| 752 |
+
node_attr_dict = {
|
| 753 |
+
'ligand': h_atoms,
|
| 754 |
+
'pocket': h_residues,
|
| 755 |
+
}
|
| 756 |
+
batch_mask_dict = {
|
| 757 |
+
'ligand': mask_atoms,
|
| 758 |
+
'pocket': mask_residues,
|
| 759 |
+
}
|
| 760 |
+
else:
|
| 761 |
+
node_attr_dict = {'ligand': h_atoms}
|
| 762 |
+
batch_mask_dict = {'ligand': mask_atoms}
|
| 763 |
+
|
| 764 |
+
if self.model == 'gvp' or self.model == 'gvp_transformer':
|
| 765 |
+
out_node_attr, out_edge_attr = self.net(
|
| 766 |
+
node_attr_dict, batch_mask_dict, edge_index_dict, edge_attr_dict)
|
| 767 |
+
|
| 768 |
+
else:
|
| 769 |
+
raise NotImplementedError(f"Wrong model ({self.model})")
|
| 770 |
+
|
| 771 |
+
h_final_atoms = out_node_attr['ligand'][0]
|
| 772 |
+
vel = out_node_attr['ligand'][1].squeeze(-2)
|
| 773 |
+
|
| 774 |
+
if torch.any(torch.isnan(vel)) or torch.any(torch.isnan(h_final_atoms)):
|
| 775 |
+
if self.training:
|
| 776 |
+
vel[torch.isnan(vel)] = 0.0
|
| 777 |
+
h_final_atoms[torch.isnan(h_final_atoms)] = 0.0
|
| 778 |
+
else:
|
| 779 |
+
raise ValueError("NaN detected in network output")
|
| 780 |
+
|
| 781 |
+
# predict edge type
|
| 782 |
+
edge_final = out_edge_attr[('ligand', '', 'ligand')]
|
| 783 |
+
edges = edge_index_dict[('ligand', '', 'ligand')]
|
| 784 |
+
|
| 785 |
+
# Symmetrize
|
| 786 |
+
edge_logits = torch.zeros(
|
| 787 |
+
(len(mask_atoms), len(mask_atoms), edge_final.size(-1)),
|
| 788 |
+
device=mask_atoms.device)
|
| 789 |
+
edge_logits[edges[0], edges[1]] = edge_final
|
| 790 |
+
edge_logits = (edge_logits + edge_logits.transpose(0, 1)) * 0.5
|
| 791 |
+
|
| 792 |
+
# return upper triangular elements only (matching the input)
|
| 793 |
+
edge_logits = edge_logits[ligand_bond_indices[0], ligand_bond_indices[1]]
|
| 794 |
+
# assert (edge_logits == 0).sum() == 0
|
| 795 |
+
|
| 796 |
+
edge_final_atoms = self.edge_decoder(edge_logits)
|
| 797 |
+
|
| 798 |
+
pred_ligand = {'vel': vel, 'logits_e': edge_final_atoms}
|
| 799 |
+
|
| 800 |
+
if self.predict_confidence:
|
| 801 |
+
pred_ligand['logits_h'] = h_final_atoms[:, :-1]
|
| 802 |
+
pred_ligand['uncertainty_vel'] = F.softplus(h_final_atoms[:, -1])
|
| 803 |
+
else:
|
| 804 |
+
pred_ligand['logits_h'] = h_final_atoms
|
| 805 |
+
|
| 806 |
+
pred_residues = {}
|
| 807 |
+
|
| 808 |
+
# Predict torsion angles
|
| 809 |
+
if self.predict_angles and self.predict_frames:
|
| 810 |
+
residue_s, residue_v = out_node_attr['pocket']
|
| 811 |
+
pred_residues['chi'] = residue_s[:, :5]
|
| 812 |
+
pred_residues['rot'] = residue_s[:, 5:]
|
| 813 |
+
pred_residues['trans'] = residue_v.squeeze(1)
|
| 814 |
+
|
| 815 |
+
elif self.predict_frames:
|
| 816 |
+
pred_residues['rot'], pred_residues['trans'] = out_node_attr['pocket']
|
| 817 |
+
pred_residues['trans'] = pred_residues['trans'].squeeze(1)
|
| 818 |
+
|
| 819 |
+
elif self.predict_angles:
|
| 820 |
+
pred_residues['chi'] = out_node_attr['pocket']
|
| 821 |
+
|
| 822 |
+
if self.angle_act_fn is not None and 'chi' in pred_residues:
|
| 823 |
+
pred_residues['chi'] = self.angle_act_fn(pred_residues['chi'])
|
| 824 |
+
|
| 825 |
+
return pred_ligand, pred_residues
|
| 826 |
+
|
| 827 |
+
def get_edges(self, x_ligand, h_ligand, batch_mask_ligand, edges_ligand, edge_feat_ligand,
|
| 828 |
+
x_pocket, h_pocket, batch_mask_pocket, atom_vectors_pocket, edges_pocket, edge_feat_pocket,
|
| 829 |
+
self_edges=False, empty_pocket=False):
|
| 830 |
+
|
| 831 |
+
# Adjacency matrix
|
| 832 |
+
adj_ligand = batch_mask_ligand[:, None] == batch_mask_ligand[None, :]
|
| 833 |
+
adj_pocket = batch_mask_pocket[:, None] == batch_mask_pocket[None, :]
|
| 834 |
+
adj_cross = batch_mask_ligand[:, None] == batch_mask_pocket[None, :]
|
| 835 |
+
|
| 836 |
+
if self.edge_cutoff_l is not None:
|
| 837 |
+
adj_ligand = adj_ligand & (torch.cdist(x_ligand, x_ligand) <= self.edge_cutoff_l)
|
| 838 |
+
|
| 839 |
+
# Add missing bonds if they got removed
|
| 840 |
+
adj_ligand[edges_ligand[0], edges_ligand[1]] = True
|
| 841 |
+
|
| 842 |
+
if not self_edges:
|
| 843 |
+
adj_ligand = adj_ligand ^ torch.eye(*adj_ligand.size(), out=torch.empty_like(adj_ligand))
|
| 844 |
+
|
| 845 |
+
if self.edge_cutoff_p is not None and not empty_pocket:
|
| 846 |
+
adj_pocket = adj_pocket & (torch.cdist(x_pocket, x_pocket) <= self.edge_cutoff_p)
|
| 847 |
+
|
| 848 |
+
# Add missing bonds if they got removed
|
| 849 |
+
adj_pocket[edges_pocket[0], edges_pocket[1]] = True
|
| 850 |
+
|
| 851 |
+
if not self_edges:
|
| 852 |
+
adj_pocket = adj_pocket ^ torch.eye(*adj_pocket.size(), out=torch.empty_like(adj_pocket))
|
| 853 |
+
|
| 854 |
+
if self.edge_cutoff_i is not None and not empty_pocket:
|
| 855 |
+
adj_cross = adj_cross & (torch.cdist(x_ligand, x_pocket) <= self.edge_cutoff_i)
|
| 856 |
+
|
| 857 |
+
# ligand-ligand edge features
|
| 858 |
+
edges_ligand_updated = torch.stack(torch.where(adj_ligand), dim=0)
|
| 859 |
+
feat_ligand = self.ligand_nobond_emb.repeat(*adj_ligand.shape, 1)
|
| 860 |
+
feat_ligand[edges_ligand[0], edges_ligand[1]] = edge_feat_ligand
|
| 861 |
+
feat_ligand = feat_ligand[edges_ligand_updated[0], edges_ligand_updated[1]]
|
| 862 |
+
feat_ligand = self.ligand_edge_features(h_ligand, x_ligand, edges_ligand_updated, batch_mask_ligand, edge_attr=feat_ligand)
|
| 863 |
+
|
| 864 |
+
if not empty_pocket:
|
| 865 |
+
# residue-residue edge features
|
| 866 |
+
edges_pocket_updated = torch.stack(torch.where(adj_pocket), dim=0)
|
| 867 |
+
feat_pocket = self.pocket_nobond_emb.repeat(*adj_pocket.shape, 1)
|
| 868 |
+
feat_pocket[edges_pocket[0], edges_pocket[1]] = edge_feat_pocket
|
| 869 |
+
feat_pocket = feat_pocket[edges_pocket_updated[0], edges_pocket_updated[1]]
|
| 870 |
+
feat_pocket = self.pocket_edge_features(h_pocket, x_pocket, atom_vectors_pocket, edges_pocket_updated, edge_attr=feat_pocket)
|
| 871 |
+
|
| 872 |
+
# ligand-residue edge features
|
| 873 |
+
edges_cross = torch.stack(torch.where(adj_cross), dim=0)
|
| 874 |
+
feat_cross = self.cross_edge_features(h_ligand, x_ligand, h_pocket, x_pocket, atom_vectors_pocket, edges_cross)
|
| 875 |
+
|
| 876 |
+
edge_index = {
|
| 877 |
+
('ligand', '', 'ligand'): edges_ligand_updated,
|
| 878 |
+
('pocket', '', 'pocket'): edges_pocket_updated,
|
| 879 |
+
('ligand', '', 'pocket'): edges_cross,
|
| 880 |
+
('pocket', '', 'ligand'): edges_cross.flip(dims=[0]),
|
| 881 |
+
}
|
| 882 |
+
|
| 883 |
+
edge_attr = {
|
| 884 |
+
('ligand', '', 'ligand'): feat_ligand,
|
| 885 |
+
('pocket', '', 'pocket'): feat_pocket,
|
| 886 |
+
('ligand', '', 'pocket'): feat_cross,
|
| 887 |
+
('pocket', '', 'ligand'): feat_cross,
|
| 888 |
+
}
|
| 889 |
+
else:
|
| 890 |
+
edge_index = {('ligand', '', 'ligand'): edges_ligand_updated}
|
| 891 |
+
edge_attr = {('ligand', '', 'ligand'): feat_ligand}
|
| 892 |
+
|
| 893 |
+
return edge_index, edge_attr
|
| 894 |
+
|
| 895 |
+
def ligand_edge_features(self, h, x, edge_index, batch_mask=None, edge_attr=None):
|
| 896 |
+
"""
|
| 897 |
+
:param h: (s, V)
|
| 898 |
+
:param x:
|
| 899 |
+
:param edge_index:
|
| 900 |
+
:param batch_mask:
|
| 901 |
+
:param edge_attr:
|
| 902 |
+
:return: scalar and vector-valued edge features
|
| 903 |
+
"""
|
| 904 |
+
row, col = edge_index
|
| 905 |
+
coord_diff = x[row] - x[col]
|
| 906 |
+
dist = coord_diff.norm(dim=-1)
|
| 907 |
+
rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf,
|
| 908 |
+
device=x.device)
|
| 909 |
+
|
| 910 |
+
if isinstance(h, tuple):
|
| 911 |
+
edge_s = torch.cat([h[0][row], h[0][col], rbf], dim=1)
|
| 912 |
+
edge_v = torch.cat([h[1][row], h[1][col], _normalize(coord_diff).unsqueeze(-2)], dim=1)
|
| 913 |
+
else:
|
| 914 |
+
edge_s = torch.cat([h[row], h[col], rbf], dim=1)
|
| 915 |
+
edge_v = _normalize(coord_diff).unsqueeze(-2)
|
| 916 |
+
|
| 917 |
+
# edge_s = rbf
|
| 918 |
+
# edge_v = _normalize(coord_diff).unsqueeze(-2)
|
| 919 |
+
|
| 920 |
+
if edge_attr is not None:
|
| 921 |
+
edge_s = torch.cat([edge_s, edge_attr], dim=1)
|
| 922 |
+
|
| 923 |
+
# self.reflection_equiv: bool, use reflection-sensitive feature based on
|
| 924 |
+
# the cross product if False
|
| 925 |
+
if not self.reflection_equiv:
|
| 926 |
+
mean = scatter_mean(x, batch_mask, dim=0,
|
| 927 |
+
dim_size=batch_mask.max() + 1)
|
| 928 |
+
row, col = edge_index
|
| 929 |
+
cross = torch.cross(x[row] - mean[batch_mask[row]],
|
| 930 |
+
x[col] - mean[batch_mask[col]], dim=1)
|
| 931 |
+
cross = _normalize(cross).unsqueeze(-2)
|
| 932 |
+
|
| 933 |
+
edge_v = torch.cat([edge_v, cross], dim=-2)
|
| 934 |
+
|
| 935 |
+
return torch.nan_to_num(edge_s), torch.nan_to_num(edge_v)
|
| 936 |
+
|
| 937 |
+
def pocket_edge_features(self, h, x, v, edge_index, edge_attr=None):
|
| 938 |
+
"""
|
| 939 |
+
:param h: (s, V)
|
| 940 |
+
:param x:
|
| 941 |
+
:param v:
|
| 942 |
+
:param edge_index:
|
| 943 |
+
:param edge_attr:
|
| 944 |
+
:return: scalar and vector-valued edge features
|
| 945 |
+
"""
|
| 946 |
+
row, col = edge_index
|
| 947 |
+
|
| 948 |
+
if self.add_all_atom_diff:
|
| 949 |
+
all_coord = v + x.unsqueeze(1) # (nR, nA, 3)
|
| 950 |
+
coord_diff = all_coord[row, :, None, :] - all_coord[col, None, :, :] # (nB, nA, nA, 3)
|
| 951 |
+
coord_diff = coord_diff.flatten(1, 2)
|
| 952 |
+
dist = coord_diff.norm(dim=-1) # (nB, nA^2)
|
| 953 |
+
rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf, device=x.device) # (nB, nA^2, rdb_dim)
|
| 954 |
+
rbf = rbf.flatten(1, 2)
|
| 955 |
+
coord_diff = _normalize(coord_diff)
|
| 956 |
+
else:
|
| 957 |
+
coord_diff = x[row] - x[col]
|
| 958 |
+
dist = coord_diff.norm(dim=-1)
|
| 959 |
+
rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf, device=x.device)
|
| 960 |
+
coord_diff = _normalize(coord_diff).unsqueeze(-2)
|
| 961 |
+
|
| 962 |
+
edge_s = torch.cat([h[0][row], h[0][col], rbf], dim=1)
|
| 963 |
+
edge_v = torch.cat([h[1][row], h[1][col], coord_diff], dim=1)
|
| 964 |
+
# edge_s = rbf
|
| 965 |
+
# edge_v = coord_diff
|
| 966 |
+
|
| 967 |
+
if edge_attr is not None:
|
| 968 |
+
edge_s = torch.cat([edge_s, edge_attr], dim=1)
|
| 969 |
+
|
| 970 |
+
return torch.nan_to_num(edge_s), torch.nan_to_num(edge_v)
|
| 971 |
+
|
| 972 |
+
def cross_edge_features(self, h_ligand, x_ligand, h_pocket, x_pocket, v_pocket, edge_index):
|
| 973 |
+
"""
|
| 974 |
+
:param h_ligand: (s, V)
|
| 975 |
+
:param x_ligand:
|
| 976 |
+
:param h_pocket: (s, V)
|
| 977 |
+
:param x_pocket:
|
| 978 |
+
:param v_pocket:
|
| 979 |
+
:param edge_index: first row indexes into the ligand tensors, second row into the pocket tensors
|
| 980 |
+
|
| 981 |
+
:return: scalar and vector-valued edge features
|
| 982 |
+
"""
|
| 983 |
+
ligand_idx, pocket_idx = edge_index
|
| 984 |
+
|
| 985 |
+
if self.add_all_atom_diff:
|
| 986 |
+
all_coord_pocket = v_pocket + x_pocket.unsqueeze(1) # (nR, nA, 3)
|
| 987 |
+
coord_diff = x_ligand[ligand_idx, None, :] - all_coord_pocket[pocket_idx] # (nB, nA, 3)
|
| 988 |
+
dist = coord_diff.norm(dim=-1) # (nB, nA)
|
| 989 |
+
rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf, device=x_ligand.device) # (nB, nA, rdb_dim)
|
| 990 |
+
rbf = rbf.flatten(1, 2)
|
| 991 |
+
coord_diff = _normalize(coord_diff)
|
| 992 |
+
else:
|
| 993 |
+
coord_diff = x_ligand[ligand_idx] - x_pocket[pocket_idx]
|
| 994 |
+
dist = coord_diff.norm(dim=-1) # (nB, nA)
|
| 995 |
+
rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf, device=x_ligand.device)
|
| 996 |
+
coord_diff = _normalize(coord_diff).unsqueeze(-2)
|
| 997 |
+
|
| 998 |
+
if isinstance(h_ligand, tuple):
|
| 999 |
+
edge_s = torch.cat([h_ligand[0][ligand_idx], h_pocket[0][pocket_idx], rbf], dim=1)
|
| 1000 |
+
edge_v = torch.cat([h_ligand[1][ligand_idx], h_pocket[1][pocket_idx], coord_diff], dim=1)
|
| 1001 |
+
else:
|
| 1002 |
+
edge_s = torch.cat([h_ligand[ligand_idx], h_pocket[0][pocket_idx], rbf], dim=1)
|
| 1003 |
+
edge_v = torch.cat([h_pocket[1][pocket_idx], coord_diff], dim=1)
|
| 1004 |
+
|
| 1005 |
+
# edge_s = rbf
|
| 1006 |
+
# edge_v = coord_diff
|
| 1007 |
+
|
| 1008 |
+
return torch.nan_to_num(edge_s), torch.nan_to_num(edge_v)
|
src/model/flows.py
ADDED
|
@@ -0,0 +1,448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC
|
| 2 |
+
from abc import abstractmethod
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
from torch_scatter import scatter_mean, scatter_add
|
| 6 |
+
|
| 7 |
+
import src.data.so3_utils as so3
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ICFM(ABC):
|
| 11 |
+
"""
|
| 12 |
+
Abstract base class for all Independent-coupling CFM classes.
|
| 13 |
+
Defines a common interface.
|
| 14 |
+
Notation:
|
| 15 |
+
- zt is the intermediate representation at time step t \in [0, 1]
|
| 16 |
+
- zs is the noised representation at time step s < t
|
| 17 |
+
|
| 18 |
+
# TODO: add interpolation schedule (not necessrily linear)
|
| 19 |
+
"""
|
| 20 |
+
def __init__(self, sigma):
|
| 21 |
+
self.sigma = sigma
|
| 22 |
+
|
| 23 |
+
@abstractmethod
|
| 24 |
+
def sample_zt(self, z0, z1, t, *args, **kwargs):
|
| 25 |
+
""" TODO. """
|
| 26 |
+
pass
|
| 27 |
+
|
| 28 |
+
@abstractmethod
|
| 29 |
+
def sample_zt_given_zs(self, *args, **kwargs):
|
| 30 |
+
""" Perform update, typically using an explicit Euler step. """
|
| 31 |
+
pass
|
| 32 |
+
|
| 33 |
+
@abstractmethod
|
| 34 |
+
def sample_z0(self, *args, **kwargs):
|
| 35 |
+
""" Prior. """
|
| 36 |
+
pass
|
| 37 |
+
|
| 38 |
+
@abstractmethod
|
| 39 |
+
def compute_loss(self, pred, z0, z1, *args, **kwargs):
|
| 40 |
+
""" Compute loss per sample. """
|
| 41 |
+
pass
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class CoordICFM(ICFM):
|
| 45 |
+
def __init__(self, sigma):
|
| 46 |
+
self.dim = 3
|
| 47 |
+
self.scale = 2.7
|
| 48 |
+
super().__init__(sigma)
|
| 49 |
+
|
| 50 |
+
def sample_zt(self, z0, z1, t, batch_mask):
|
| 51 |
+
zt = t[batch_mask] * z1 + (1 - t)[batch_mask] * z0
|
| 52 |
+
# zt = self.sigma * z0 + t[batch_mask] * z1 + (1 - t)[batch_mask] * z0 # TODO: do we have to compute Psi?
|
| 53 |
+
return zt
|
| 54 |
+
|
| 55 |
+
def sample_zt_given_zs(self, zs, pred, s, t, batch_mask):
|
| 56 |
+
""" Perform an explicit Euler step. """
|
| 57 |
+
step_size = t - s
|
| 58 |
+
zt = zs + step_size[batch_mask] * self.scale * pred
|
| 59 |
+
return zt
|
| 60 |
+
|
| 61 |
+
def sample_z0(self, com, batch_mask):
|
| 62 |
+
""" Prior. """
|
| 63 |
+
z0 = torch.randn((len(batch_mask), self.dim), device=batch_mask.device)
|
| 64 |
+
|
| 65 |
+
# Move center of mass
|
| 66 |
+
z0 = z0 + com[batch_mask]
|
| 67 |
+
|
| 68 |
+
return z0
|
| 69 |
+
|
| 70 |
+
def reduce_loss(self, loss, batch_mask, reduce):
|
| 71 |
+
assert reduce in {'mean', 'sum', 'none'}
|
| 72 |
+
|
| 73 |
+
if reduce == 'mean':
|
| 74 |
+
loss = scatter_mean(loss / self.dim, batch_mask, dim=0)
|
| 75 |
+
elif reduce == 'sum':
|
| 76 |
+
loss = scatter_add(loss, batch_mask, dim=0)
|
| 77 |
+
|
| 78 |
+
return loss
|
| 79 |
+
|
| 80 |
+
def compute_loss(self, pred, z0, z1, t, batch_mask, reduce='mean'):
|
| 81 |
+
""" Compute loss per sample. """
|
| 82 |
+
|
| 83 |
+
loss = torch.sum((pred - (z1 - z0) / self.scale) ** 2, dim=-1)
|
| 84 |
+
|
| 85 |
+
return self.reduce_loss(loss, batch_mask, reduce)
|
| 86 |
+
|
| 87 |
+
def get_z1_given_zt_and_pred(self, zt, pred, z0, t, batch_mask):
|
| 88 |
+
""" Make a best guess on the final state z1 given the current state and
|
| 89 |
+
the network prediction. """
|
| 90 |
+
# z1 = z0 + pred
|
| 91 |
+
z1 = zt + (1 - t)[batch_mask] * pred
|
| 92 |
+
return z1
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class TorusICFM(ICFM):
|
| 96 |
+
"""
|
| 97 |
+
Following:
|
| 98 |
+
Chen, Ricky TQ, and Yaron Lipman.
|
| 99 |
+
"Riemannian flow matching on general geometries."
|
| 100 |
+
arXiv preprint arXiv:2302.03660 (2023).
|
| 101 |
+
"""
|
| 102 |
+
def __init__(self, sigma, dim, scheduler_args=None):
|
| 103 |
+
super().__init__(sigma)
|
| 104 |
+
self.dim = dim
|
| 105 |
+
|
| 106 |
+
# Scheduler that determines the rate at which the geodesic distance decreases
|
| 107 |
+
scheduler_args = scheduler_args or {}
|
| 108 |
+
scheduler_args["type"] = scheduler_args.get("type", "linear") # default
|
| 109 |
+
scheduler_args["learn_scaled"] = scheduler_args.get("learn_scaled", False) # default
|
| 110 |
+
|
| 111 |
+
# linear scheduler: kappa(t) = 1-t (default)
|
| 112 |
+
if scheduler_args["type"] == "linear":
|
| 113 |
+
# equivalent to: 1 - kappa(t)
|
| 114 |
+
self.flow_scaling = lambda t: t
|
| 115 |
+
|
| 116 |
+
# equivalent to: -1 * d/dt kappa(t)
|
| 117 |
+
self.velocity_scaling = lambda t: torch.ones_like(t)
|
| 118 |
+
|
| 119 |
+
# exponential scheduler: kappa(t) = exp(-c*t)
|
| 120 |
+
elif scheduler_args["type"] == "exponential":
|
| 121 |
+
|
| 122 |
+
self.c = scheduler_args["c"]
|
| 123 |
+
assert self.c > 0
|
| 124 |
+
|
| 125 |
+
# equivalent to: 1 - kappa(t)
|
| 126 |
+
self.flow_scaling = lambda t: 1 - torch.exp(-self.c * t)
|
| 127 |
+
|
| 128 |
+
# equivalent to: -1 * d/dt kappa(t)
|
| 129 |
+
self.velocity_scaling = lambda t: self.c * torch.exp(-self.c * t)
|
| 130 |
+
|
| 131 |
+
# polynomial scheduler: kappa(t) = (1-t)^k
|
| 132 |
+
elif scheduler_args["type"] == "polynomial":
|
| 133 |
+
self.k = scheduler_args["k"]
|
| 134 |
+
assert self.k > 0
|
| 135 |
+
|
| 136 |
+
# equivalent to: 1 - kappa(t)
|
| 137 |
+
self.flow_scaling = lambda t: 1 - (1 - t)**self.k
|
| 138 |
+
|
| 139 |
+
# equivalent to: -1 * d/dt kappa(t)
|
| 140 |
+
self.velocity_scaling = lambda t: self.k * (1 - t)**(self.k - 1)
|
| 141 |
+
|
| 142 |
+
else:
|
| 143 |
+
raise NotImplementedError(f"Scheduler {scheduler_args['type']} not implemented.")
|
| 144 |
+
|
| 145 |
+
kappa_interval = self.flow_scaling(torch.tensor([0.0, 1.0]))
|
| 146 |
+
if kappa_interval[0] != 0.0 or kappa_interval[1] != 1.0:
|
| 147 |
+
print(f"Scheduler should satisfy kappa(0)=1 and kappa(1)=0. Found "
|
| 148 |
+
f"interval {kappa_interval.tolist()} instead.")
|
| 149 |
+
|
| 150 |
+
# determines whether the scaled vector field is learned or the scheduler
|
| 151 |
+
# is post-multiplied
|
| 152 |
+
self.learn_scaled = scheduler_args["learn_scaled"]
|
| 153 |
+
|
| 154 |
+
@staticmethod
|
| 155 |
+
def wrap(angle):
|
| 156 |
+
""" Maps angles to range [-\pi, \pi). """
|
| 157 |
+
return ((angle + math.pi) % (2 * math.pi)) - math.pi
|
| 158 |
+
|
| 159 |
+
def exponential_map(self, x, u):
|
| 160 |
+
"""
|
| 161 |
+
:param x: point on the manifold
|
| 162 |
+
:param u: point on the tangent space
|
| 163 |
+
"""
|
| 164 |
+
return self.wrap(x + u)
|
| 165 |
+
|
| 166 |
+
@staticmethod
|
| 167 |
+
def logarithm_map(x, y):
|
| 168 |
+
"""
|
| 169 |
+
:param x, y: points on the manifold
|
| 170 |
+
"""
|
| 171 |
+
return torch.atan2(torch.sin(y - x), torch.cos(y - x))
|
| 172 |
+
|
| 173 |
+
def sample_zt(self, z0, z1, t, batch_mask):
|
| 174 |
+
""" expressed in terms of exponential and logarithm maps """
|
| 175 |
+
|
| 176 |
+
# apply logarithm map
|
| 177 |
+
# zt_tangent = t[batch_mask] * self.logarithm_map(z0, z1)
|
| 178 |
+
zt_tangent = self.flow_scaling(t)[batch_mask] * self.logarithm_map(z0, z1)
|
| 179 |
+
|
| 180 |
+
# apply exponential map
|
| 181 |
+
return self.exponential_map(z0, zt_tangent)
|
| 182 |
+
|
| 183 |
+
def get_z1_given_zt_and_pred(self, zt, pred, z0, t, batch_mask):
|
| 184 |
+
""" Make a best guess on the final state z1 given the current state and
|
| 185 |
+
the network prediction. """
|
| 186 |
+
|
| 187 |
+
# estimate z1_tangent based on zt and pred only
|
| 188 |
+
if self.learn_scaled:
|
| 189 |
+
pred = pred / torch.clamp(self.velocity_scaling(t), min=1e-3)[batch_mask]
|
| 190 |
+
|
| 191 |
+
z1_tangent = (1 - t)[batch_mask] * pred
|
| 192 |
+
|
| 193 |
+
# exponential map
|
| 194 |
+
return self.exponential_map(zt, z1_tangent)
|
| 195 |
+
|
| 196 |
+
def sample_zt_given_zs(self, zs, pred, s, t, batch_mask):
|
| 197 |
+
""" Perform update, typically using an explicit Euler step. """
|
| 198 |
+
|
| 199 |
+
step_size = t - s
|
| 200 |
+
zt_tangent = step_size[batch_mask] * pred
|
| 201 |
+
|
| 202 |
+
if not self.learn_scaled:
|
| 203 |
+
zt_tangent = self.velocity_scaling(t)[batch_mask] * zt_tangent
|
| 204 |
+
|
| 205 |
+
# exponential map
|
| 206 |
+
return self.exponential_map(zs, zt_tangent)
|
| 207 |
+
|
| 208 |
+
def sample_z0(self, batch_mask):
|
| 209 |
+
""" Prior. """
|
| 210 |
+
|
| 211 |
+
# Uniform distribution
|
| 212 |
+
z0 = torch.rand((len(batch_mask), self.dim), device=batch_mask.device)
|
| 213 |
+
|
| 214 |
+
return 2 * math.pi * z0 - math.pi
|
| 215 |
+
|
| 216 |
+
def compute_loss(self, pred, z0, z1, zt, t, batch_mask, reduce='mean'):
|
| 217 |
+
""" Compute loss per sample. """
|
| 218 |
+
assert reduce in {'mean', 'sum', 'none'}
|
| 219 |
+
mask = ~torch.isnan(z1)
|
| 220 |
+
z1 = torch.nan_to_num(z1, nan=0.0)
|
| 221 |
+
|
| 222 |
+
zt_dot = self.logarithm_map(z0, z1)
|
| 223 |
+
if self.learn_scaled:
|
| 224 |
+
# NOTE: potentially requires output magnitude to vary substantially
|
| 225 |
+
zt_dot = self.velocity_scaling(t)[batch_mask] * zt_dot
|
| 226 |
+
loss = mask * (pred - zt_dot) ** 2
|
| 227 |
+
loss = torch.sum(loss, dim=-1)
|
| 228 |
+
|
| 229 |
+
if reduce == 'mean':
|
| 230 |
+
denom = mask.sum(dim=-1) + 1e-6
|
| 231 |
+
loss = scatter_mean(loss / denom, batch_mask, dim=0)
|
| 232 |
+
elif reduce == 'sum':
|
| 233 |
+
loss = scatter_add(loss, batch_mask, dim=0)
|
| 234 |
+
return loss
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class SO3ICFM(ICFM):
|
| 238 |
+
"""
|
| 239 |
+
All rotations are assumed to be in axis-angle format.
|
| 240 |
+
Mostly following descriptions from the FoldFlow paper:
|
| 241 |
+
https://openreview.net/forum?id=kJFIH23hXb
|
| 242 |
+
|
| 243 |
+
See also:
|
| 244 |
+
https://geomstats.github.io/_modules/geomstats/geometry/special_orthogonal.html#SpecialOrthogonal
|
| 245 |
+
https://geomstats.github.io/_modules/geomstats/geometry/lie_group.html#LieGroup
|
| 246 |
+
"""
|
| 247 |
+
def __init__(self, sigma):
|
| 248 |
+
super().__init__(sigma)
|
| 249 |
+
|
| 250 |
+
def exponential_map(self, base, tangent):
|
| 251 |
+
"""
|
| 252 |
+
Args:
|
| 253 |
+
base: base point (rotation vector) on the manifold
|
| 254 |
+
tangent: point in tangent space at identity
|
| 255 |
+
Returns:
|
| 256 |
+
rotation vector on the manifold
|
| 257 |
+
"""
|
| 258 |
+
# return so3.exp_not_from_identity(tangent, base_point=base)
|
| 259 |
+
return so3.compose_rotations(base, so3.exp(tangent))
|
| 260 |
+
|
| 261 |
+
def logarithm_map(self, base, r):
|
| 262 |
+
"""
|
| 263 |
+
Args:
|
| 264 |
+
base: base point (rotation vector) on the manifold
|
| 265 |
+
r: rotation vector on the manifold
|
| 266 |
+
Return:
|
| 267 |
+
point in tangent space at identity
|
| 268 |
+
"""
|
| 269 |
+
# return so3.log_not_from_identity(r, base_point=base)
|
| 270 |
+
return so3.log(so3.compose_rotations(-base, r))
|
| 271 |
+
|
| 272 |
+
def sample_zt(self, z0, z1, t, batch_mask):
|
| 273 |
+
"""
|
| 274 |
+
Expressed in terms of exponential and logarithm maps.
|
| 275 |
+
Corresponds to SLERP interpolation: R(t) = R1 exp( t * log(R1^T R2) )
|
| 276 |
+
(see https://lucaballan.altervista.org/pdfs/IK.pdf, slide 16)
|
| 277 |
+
"""
|
| 278 |
+
|
| 279 |
+
# apply logarithm map
|
| 280 |
+
zt_tangent = t[batch_mask] * self.logarithm_map(z0, z1)
|
| 281 |
+
|
| 282 |
+
# apply exponential map
|
| 283 |
+
return self.exponential_map(z0, zt_tangent)
|
| 284 |
+
|
| 285 |
+
def get_z1_given_zt_and_pred(self, zt, pred, z0, t, batch_mask):
|
| 286 |
+
""" Make a best guess on the final state z1 given the current state and
|
| 287 |
+
the network prediction. """
|
| 288 |
+
|
| 289 |
+
# estimate z1_tangent based on zt and pred only
|
| 290 |
+
z1_tangent = (1 - t)[batch_mask] * pred
|
| 291 |
+
|
| 292 |
+
# exponential map
|
| 293 |
+
return self.exponential_map(zt, z1_tangent)
|
| 294 |
+
|
| 295 |
+
def sample_zt_given_zs(self, zs, pred, s, t, batch_mask):
|
| 296 |
+
""" Perform update, typically using an explicit Euler step. """
|
| 297 |
+
|
| 298 |
+
# # parallel transport vector field to lie algebra so3 (at identity)
|
| 299 |
+
# # (FoldFlow paper, Algorithm 3, line 8)
|
| 300 |
+
# # TODO: is this correct? is it necessary?
|
| 301 |
+
# pred = so3.compose(so3.inverse(zs), pred)
|
| 302 |
+
|
| 303 |
+
step_size = t - s
|
| 304 |
+
zt_tangent = step_size[batch_mask] * pred
|
| 305 |
+
|
| 306 |
+
# exponential map
|
| 307 |
+
return self.exponential_map(zs, zt_tangent)
|
| 308 |
+
|
| 309 |
+
def sample_z0(self, batch_mask):
|
| 310 |
+
""" Prior. """
|
| 311 |
+
return so3.random_uniform(n_samples=len(batch_mask), device=batch_mask.device)
|
| 312 |
+
|
| 313 |
+
@staticmethod
|
| 314 |
+
def d_R_squared_SO3(rot_vec_1, rot_vec_2):
|
| 315 |
+
"""
|
| 316 |
+
Squared Riemannian metric on SO(3).
|
| 317 |
+
Defined as d(R1, R2) = sqrt(0.5) ||log(R1^T R2)||_F
|
| 318 |
+
where R1, R2 are rotation matrices.
|
| 319 |
+
|
| 320 |
+
The following is equivalent if the difference between the rotations is
|
| 321 |
+
expressed as a rotation vector \omega_diff:
|
| 322 |
+
d(r1, r2) = ||\omega_diff||_2
|
| 323 |
+
-----
|
| 324 |
+
With the definition of the Frobenius matrix norm ||A||_F^2 = trace(A^H A):
|
| 325 |
+
d^2(R1, R2) = 1/2 ||log(R1^T R2)||_F^2
|
| 326 |
+
= 1/2 || hat(R_d) ||_F^2
|
| 327 |
+
= 1/2 tr( hat(R_d)^T hat(R_d) )
|
| 328 |
+
= 1/2 * 2 * ||\omega||_2^2
|
| 329 |
+
"""
|
| 330 |
+
|
| 331 |
+
# rot_mat_1 = so3.matrix_from_rotation_vector(rot_vec_1)
|
| 332 |
+
# rot_mat_2 = so3.matrix_from_rotation_vector(rot_vec_2)
|
| 333 |
+
# rot_mat_diff = rot_mat_1.transpose(-2, -1) @ rot_mat_2
|
| 334 |
+
# return torch.norm(so3.log(rot_mat_diff, as_skew=True), p='fro', dim=(-2, -1))
|
| 335 |
+
|
| 336 |
+
diff_rot = so3.compose_rotations(-rot_vec_1, rot_vec_2)
|
| 337 |
+
return diff_rot.square().sum(dim=-1)
|
| 338 |
+
|
| 339 |
+
def compute_loss(self, pred, z0, z1, zt, t, batch_mask, reduce='mean', eps=5e-2):
|
| 340 |
+
""" Compute loss per sample. """
|
| 341 |
+
assert reduce in {'mean', 'sum', 'none'}
|
| 342 |
+
|
| 343 |
+
zt_dot = self.logarithm_map(zt, z1) / torch.clamp(1 - t, min=eps)[batch_mask]
|
| 344 |
+
|
| 345 |
+
# TODO: do I need this?
|
| 346 |
+
# pred_at_id = self.logarithm_map(zt, pred) / torch.clamp(1 - t, min=eps)[batch_mask]
|
| 347 |
+
|
| 348 |
+
loss = torch.sum((pred - zt_dot)**2, dim=-1) # TODO: is this the right loss in SO3?
|
| 349 |
+
# loss = self.d_R_squared_SO3(zt_dot, pred)
|
| 350 |
+
|
| 351 |
+
if reduce == 'mean':
|
| 352 |
+
loss = scatter_mean(loss, batch_mask, dim=0)
|
| 353 |
+
elif reduce == 'sum':
|
| 354 |
+
loss = scatter_add(loss, batch_mask, dim=0)
|
| 355 |
+
|
| 356 |
+
return loss
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
#################
|
| 360 |
+
# Predicting z1 #
|
| 361 |
+
#################
|
| 362 |
+
|
| 363 |
+
class CoordICFMPredictFinal(CoordICFM):
|
| 364 |
+
def __init__(self, sigma):
|
| 365 |
+
self.dim = 3
|
| 366 |
+
super().__init__(sigma)
|
| 367 |
+
|
| 368 |
+
def sample_zt_given_zs(self, zs, z1_minus_zs_pred, s, t, batch_mask):
|
| 369 |
+
""" Perform an explicit Euler step. """
|
| 370 |
+
|
| 371 |
+
# step_size = t - s
|
| 372 |
+
# zt = zs + step_size[batch_mask] * z1_minus_zs_pred / (1.0 - s)[batch_mask]
|
| 373 |
+
|
| 374 |
+
# for numerical stability
|
| 375 |
+
step_size = (t - s) / (1.0 - s)
|
| 376 |
+
assert torch.all(step_size <= 1.0)
|
| 377 |
+
# step_size = torch.clamp(step_size, max=1.0)
|
| 378 |
+
zt = zs + step_size[batch_mask] * z1_minus_zs_pred
|
| 379 |
+
return zt
|
| 380 |
+
|
| 381 |
+
def compute_loss(self, z1_minus_zt_pred, z0, z1, t, batch_mask, reduce='mean'):
|
| 382 |
+
""" Compute loss per sample. """
|
| 383 |
+
assert reduce in {'mean', 'sum', 'none'}
|
| 384 |
+
t = torch.clamp(t, max=0.9)
|
| 385 |
+
zt = self.sample_zt(z0, z1, t, batch_mask)
|
| 386 |
+
loss = torch.sum((z1_minus_zt_pred + zt - z1) ** 2, dim=-1) / torch.square(1 - t)[batch_mask].squeeze()
|
| 387 |
+
|
| 388 |
+
if reduce == 'mean':
|
| 389 |
+
loss = scatter_mean(loss / self.dim, batch_mask, dim=0)
|
| 390 |
+
elif reduce == 'sum':
|
| 391 |
+
loss = scatter_add(loss, batch_mask, dim=0)
|
| 392 |
+
|
| 393 |
+
return loss
|
| 394 |
+
|
| 395 |
+
def get_z1_given_zt_and_pred(self, zt, z1_minus_zt_pred, z0, t, batch_mask):
|
| 396 |
+
return z1_minus_zt_pred + zt
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
class TorusICFMPredictFinal(TorusICFM):
|
| 400 |
+
"""
|
| 401 |
+
Following:
|
| 402 |
+
Chen, Ricky TQ, and Yaron Lipman.
|
| 403 |
+
"Riemannian flow matching on general geometries."
|
| 404 |
+
arXiv preprint arXiv:2302.03660 (2023).
|
| 405 |
+
"""
|
| 406 |
+
def __init__(self, sigma, dim):
|
| 407 |
+
super().__init__(sigma, dim)
|
| 408 |
+
|
| 409 |
+
def get_z1_given_zt_and_pred(self, zt, z1_tangent_pred, z0, t, batch_mask):
|
| 410 |
+
""" Make a best guess on the final state z1 given the current state and
|
| 411 |
+
the network prediction. """
|
| 412 |
+
|
| 413 |
+
# exponential map
|
| 414 |
+
return self.exponential_map(zt, z1_tangent_pred)
|
| 415 |
+
|
| 416 |
+
def sample_zt_given_zs(self, zs, z1_tangent_pred, s, t, batch_mask):
|
| 417 |
+
""" Perform update, typically using an explicit Euler step. """
|
| 418 |
+
|
| 419 |
+
# step_size = t - s
|
| 420 |
+
# zt_tangent = step_size[batch_mask] * z1_tangent_pred / (1.0 - s)[batch_mask]
|
| 421 |
+
|
| 422 |
+
# for numerical stability
|
| 423 |
+
step_size = (t - s) / (1.0 - s)
|
| 424 |
+
assert torch.all(step_size <= 1.0)
|
| 425 |
+
# step_size = torch.clamp(step_size, max=1.0)
|
| 426 |
+
zt_tangent = step_size[batch_mask] * z1_tangent_pred
|
| 427 |
+
|
| 428 |
+
# exponential map
|
| 429 |
+
return self.exponential_map(zs, zt_tangent)
|
| 430 |
+
|
| 431 |
+
def compute_loss(self, z1_tangent_pred, z0, z1, t, batch_mask, reduce='mean'):
|
| 432 |
+
""" Compute loss per sample. """
|
| 433 |
+
assert reduce in {'mean', 'sum', 'none'}
|
| 434 |
+
zt = self.sample_zt(z0, z1, t, batch_mask)
|
| 435 |
+
t = torch.clamp(t, max=0.9)
|
| 436 |
+
|
| 437 |
+
mask = ~torch.isnan(z1)
|
| 438 |
+
z1 = torch.nan_to_num(z1, nan=0.0)
|
| 439 |
+
loss = mask * (z1_tangent_pred - self.logarithm_map(zt, z1)) ** 2
|
| 440 |
+
loss = torch.sum(loss, dim=-1) / torch.square(1 - t)[batch_mask].squeeze()
|
| 441 |
+
|
| 442 |
+
if reduce == 'mean':
|
| 443 |
+
denom = mask.sum(dim=-1) + 1e-6
|
| 444 |
+
loss = scatter_mean(loss / denom, batch_mask, dim=0)
|
| 445 |
+
elif reduce == 'sum':
|
| 446 |
+
loss = scatter_add(loss, batch_mask, dim=0)
|
| 447 |
+
|
| 448 |
+
return loss
|
src/model/gvp.py
ADDED
|
@@ -0,0 +1,650 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Geometric Vector Perceptron implementation taken from:
|
| 3 |
+
https://github.com/drorlab/gvp-pytorch/blob/main/gvp/__init__.py
|
| 4 |
+
"""
|
| 5 |
+
import copy
|
| 6 |
+
import warnings
|
| 7 |
+
|
| 8 |
+
import torch, functools
|
| 9 |
+
from torch import nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch_geometric.nn import MessagePassing
|
| 12 |
+
from torch_scatter import scatter_add, scatter_mean
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def tuple_sum(*args):
|
| 16 |
+
'''
|
| 17 |
+
Sums any number of tuples (s, V) elementwise.
|
| 18 |
+
'''
|
| 19 |
+
return tuple(map(sum, zip(*args)))
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def tuple_cat(*args, dim=-1):
|
| 23 |
+
'''
|
| 24 |
+
Concatenates any number of tuples (s, V) elementwise.
|
| 25 |
+
|
| 26 |
+
:param dim: dimension along which to concatenate when viewed
|
| 27 |
+
as the `dim` index for the scalar-channel tensors.
|
| 28 |
+
This means that `dim=-1` will be applied as
|
| 29 |
+
`dim=-2` for the vector-channel tensors.
|
| 30 |
+
'''
|
| 31 |
+
dim %= len(args[0][0].shape)
|
| 32 |
+
s_args, v_args = list(zip(*args))
|
| 33 |
+
return torch.cat(s_args, dim=dim), torch.cat(v_args, dim=dim)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def tuple_index(x, idx):
|
| 37 |
+
'''
|
| 38 |
+
Indexes into a tuple (s, V) along the first dimension.
|
| 39 |
+
|
| 40 |
+
:param idx: any object which can be used to index into a `torch.Tensor`
|
| 41 |
+
'''
|
| 42 |
+
return x[0][idx], x[1][idx]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def randn(n, dims, device="cpu"):
|
| 46 |
+
'''
|
| 47 |
+
Returns random tuples (s, V) drawn elementwise from a normal distribution.
|
| 48 |
+
|
| 49 |
+
:param n: number of data points
|
| 50 |
+
:param dims: tuple of dimensions (n_scalar, n_vector)
|
| 51 |
+
|
| 52 |
+
:return: (s, V) with s.shape = (n, n_scalar) and
|
| 53 |
+
V.shape = (n, n_vector, 3)
|
| 54 |
+
'''
|
| 55 |
+
return torch.randn(n, dims[0], device=device), \
|
| 56 |
+
torch.randn(n, dims[1], 3, device=device)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True):
|
| 60 |
+
'''
|
| 61 |
+
L2 norm of tensor clamped above a minimum value `eps`.
|
| 62 |
+
|
| 63 |
+
:param sqrt: if `False`, returns the square of the L2 norm
|
| 64 |
+
'''
|
| 65 |
+
out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps)
|
| 66 |
+
return torch.sqrt(out) if sqrt else out
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _split(x, nv):
|
| 70 |
+
'''
|
| 71 |
+
Splits a merged representation of (s, V) back into a tuple.
|
| 72 |
+
Should be used only with `_merge(s, V)` and only if the tuple
|
| 73 |
+
representation cannot be used.
|
| 74 |
+
|
| 75 |
+
:param x: the `torch.Tensor` returned from `_merge`
|
| 76 |
+
:param nv: the number of vector channels in the input to `_merge`
|
| 77 |
+
'''
|
| 78 |
+
v = torch.reshape(x[..., -3 * nv:], x.shape[:-1] + (nv, 3))
|
| 79 |
+
s = x[..., :-3 * nv]
|
| 80 |
+
return s, v
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _merge(s, v):
|
| 84 |
+
'''
|
| 85 |
+
Merges a tuple (s, V) into a single `torch.Tensor`, where the
|
| 86 |
+
vector channels are flattened and appended to the scalar channels.
|
| 87 |
+
Should be used only if the tuple representation cannot be used.
|
| 88 |
+
Use `_split(x, nv)` to reverse.
|
| 89 |
+
'''
|
| 90 |
+
v = torch.reshape(v, v.shape[:-2] + (3 * v.shape[-2],))
|
| 91 |
+
return torch.cat([s, v], -1)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class GVP(nn.Module):
|
| 95 |
+
'''
|
| 96 |
+
Geometric Vector Perceptron. See manuscript and README.md
|
| 97 |
+
for more details.
|
| 98 |
+
|
| 99 |
+
:param in_dims: tuple (n_scalar, n_vector)
|
| 100 |
+
:param out_dims: tuple (n_scalar, n_vector)
|
| 101 |
+
:param h_dim: intermediate number of vector channels, optional
|
| 102 |
+
:param activations: tuple of functions (scalar_act, vector_act)
|
| 103 |
+
:param vector_gate: whether to use vector gating.
|
| 104 |
+
(vector_act will be used as sigma^+ in vector gating if `True`)
|
| 105 |
+
'''
|
| 106 |
+
|
| 107 |
+
def __init__(self, in_dims, out_dims, h_dim=None,
|
| 108 |
+
activations=(F.relu, torch.sigmoid), vector_gate=False):
|
| 109 |
+
super(GVP, self).__init__()
|
| 110 |
+
self.si, self.vi = in_dims
|
| 111 |
+
self.so, self.vo = out_dims
|
| 112 |
+
self.vector_gate = vector_gate
|
| 113 |
+
if self.vi:
|
| 114 |
+
self.h_dim = h_dim or max(self.vi, self.vo)
|
| 115 |
+
self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
|
| 116 |
+
self.ws = nn.Linear(self.h_dim + self.si, self.so)
|
| 117 |
+
if self.vo:
|
| 118 |
+
self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
|
| 119 |
+
if self.vector_gate: self.wsv = nn.Linear(self.so, self.vo)
|
| 120 |
+
else:
|
| 121 |
+
self.ws = nn.Linear(self.si, self.so)
|
| 122 |
+
|
| 123 |
+
self.scalar_act, self.vector_act = activations
|
| 124 |
+
self.dummy_param = nn.Parameter(torch.empty(0))
|
| 125 |
+
|
| 126 |
+
def forward(self, x):
|
| 127 |
+
'''
|
| 128 |
+
:param x: tuple (s, V) of `torch.Tensor`,
|
| 129 |
+
or (if vectors_in is 0), a single `torch.Tensor`
|
| 130 |
+
:return: tuple (s, V) of `torch.Tensor`,
|
| 131 |
+
or (if vectors_out is 0), a single `torch.Tensor`
|
| 132 |
+
'''
|
| 133 |
+
if self.vi:
|
| 134 |
+
s, v = x
|
| 135 |
+
v = torch.transpose(v, -1, -2)
|
| 136 |
+
vh = self.wh(v)
|
| 137 |
+
vn = _norm_no_nan(vh, axis=-2)
|
| 138 |
+
s = self.ws(torch.cat([s, vn], -1))
|
| 139 |
+
if self.vo:
|
| 140 |
+
v = self.wv(vh)
|
| 141 |
+
v = torch.transpose(v, -1, -2)
|
| 142 |
+
if self.vector_gate:
|
| 143 |
+
if self.vector_act:
|
| 144 |
+
gate = self.wsv(self.vector_act(s))
|
| 145 |
+
else:
|
| 146 |
+
gate = self.wsv(s)
|
| 147 |
+
v = v * torch.sigmoid(gate).unsqueeze(-1)
|
| 148 |
+
elif self.vector_act:
|
| 149 |
+
v = v * self.vector_act(
|
| 150 |
+
_norm_no_nan(v, axis=-1, keepdims=True))
|
| 151 |
+
else:
|
| 152 |
+
s = self.ws(x)
|
| 153 |
+
if self.vo:
|
| 154 |
+
v = torch.zeros(s.shape[0], self.vo, 3,
|
| 155 |
+
device=self.dummy_param.device)
|
| 156 |
+
if self.scalar_act:
|
| 157 |
+
s = self.scalar_act(s)
|
| 158 |
+
|
| 159 |
+
return (s, v) if self.vo else s
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class _VDropout(nn.Module):
|
| 163 |
+
'''
|
| 164 |
+
Vector channel dropout where the elements of each
|
| 165 |
+
vector channel are dropped together.
|
| 166 |
+
'''
|
| 167 |
+
|
| 168 |
+
def __init__(self, drop_rate):
|
| 169 |
+
super(_VDropout, self).__init__()
|
| 170 |
+
self.drop_rate = drop_rate
|
| 171 |
+
self.dummy_param = nn.Parameter(torch.empty(0))
|
| 172 |
+
|
| 173 |
+
def forward(self, x):
|
| 174 |
+
'''
|
| 175 |
+
:param x: `torch.Tensor` corresponding to vector channels
|
| 176 |
+
'''
|
| 177 |
+
device = self.dummy_param.device
|
| 178 |
+
if not self.training:
|
| 179 |
+
return x
|
| 180 |
+
mask = torch.bernoulli(
|
| 181 |
+
(1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device)
|
| 182 |
+
).unsqueeze(-1)
|
| 183 |
+
x = mask * x / (1 - self.drop_rate)
|
| 184 |
+
return x
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class Dropout(nn.Module):
|
| 188 |
+
'''
|
| 189 |
+
Combined dropout for tuples (s, V).
|
| 190 |
+
Takes tuples (s, V) as input and as output.
|
| 191 |
+
'''
|
| 192 |
+
|
| 193 |
+
def __init__(self, drop_rate):
|
| 194 |
+
super(Dropout, self).__init__()
|
| 195 |
+
self.sdropout = nn.Dropout(drop_rate)
|
| 196 |
+
self.vdropout = _VDropout(drop_rate)
|
| 197 |
+
|
| 198 |
+
def forward(self, x):
|
| 199 |
+
'''
|
| 200 |
+
:param x: tuple (s, V) of `torch.Tensor`,
|
| 201 |
+
or single `torch.Tensor`
|
| 202 |
+
(will be assumed to be scalar channels)
|
| 203 |
+
'''
|
| 204 |
+
if type(x) is torch.Tensor:
|
| 205 |
+
return self.sdropout(x)
|
| 206 |
+
s, v = x
|
| 207 |
+
return self.sdropout(s), self.vdropout(v)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class LayerNorm(nn.Module):
|
| 211 |
+
'''
|
| 212 |
+
Combined LayerNorm for tuples (s, V).
|
| 213 |
+
Takes tuples (s, V) as input and as output.
|
| 214 |
+
'''
|
| 215 |
+
|
| 216 |
+
def __init__(self, dims, learnable_vector_weight=False):
|
| 217 |
+
super(LayerNorm, self).__init__()
|
| 218 |
+
self.s, self.v = dims
|
| 219 |
+
self.scalar_norm = nn.LayerNorm(self.s)
|
| 220 |
+
self.vector_norm = VectorLayerNorm(self.v, learnable_vector_weight) \
|
| 221 |
+
if self.v > 0 else None
|
| 222 |
+
|
| 223 |
+
def forward(self, x):
|
| 224 |
+
'''
|
| 225 |
+
:param x: tuple (s, V) of `torch.Tensor`,
|
| 226 |
+
or single `torch.Tensor`
|
| 227 |
+
(will be assumed to be scalar channels)
|
| 228 |
+
'''
|
| 229 |
+
if not self.v:
|
| 230 |
+
return self.scalar_norm(x)
|
| 231 |
+
s, v = x
|
| 232 |
+
# vn = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False)
|
| 233 |
+
# vn = torch.sqrt(torch.mean(vn, dim=-2, keepdim=True))
|
| 234 |
+
# return self.scalar_norm(s), v / vn
|
| 235 |
+
return self.scalar_norm(s), self.vector_norm(v)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class VectorLayerNorm(nn.Module):
|
| 239 |
+
"""
|
| 240 |
+
Equivariant normalization of vector-valued features inspired by:
|
| 241 |
+
Liao, Yi-Lun, and Tess Smidt.
|
| 242 |
+
"Equiformer: Equivariant graph attention transformer for 3d atomistic graphs."
|
| 243 |
+
arXiv preprint arXiv:2206.11990 (2022).
|
| 244 |
+
Section 4.1, "Layer Normalization"
|
| 245 |
+
"""
|
| 246 |
+
def __init__(self, n_channels, learnable_weight=True):
|
| 247 |
+
super(VectorLayerNorm, self).__init__()
|
| 248 |
+
self.gamma = nn.Parameter(torch.ones(1, n_channels, 1)) \
|
| 249 |
+
if learnable_weight else None # (1, c, 1)
|
| 250 |
+
|
| 251 |
+
def forward(self, x):
|
| 252 |
+
"""
|
| 253 |
+
Computes LN(x) = ( x / RMS( L2-norm(x) ) ) * gamma
|
| 254 |
+
:param x: input tensor (n, c, 3)
|
| 255 |
+
:return: layer normalized vector feature
|
| 256 |
+
"""
|
| 257 |
+
norm2 = _norm_no_nan(x, axis=-1, keepdims=True, sqrt=False) # (n, c, 1)
|
| 258 |
+
rms = torch.sqrt(torch.mean(norm2, dim=-2, keepdim=True)) # (n, 1, 1)
|
| 259 |
+
x = x / rms # (n, c, 3)
|
| 260 |
+
if self.gamma is not None:
|
| 261 |
+
x = x * self.gamma
|
| 262 |
+
return x
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class GVPConv(MessagePassing):
|
| 266 |
+
'''
|
| 267 |
+
Graph convolution / message passing with Geometric Vector Perceptrons.
|
| 268 |
+
Takes in a graph with node and edge embeddings,
|
| 269 |
+
and returns new node embeddings.
|
| 270 |
+
|
| 271 |
+
This does NOT do residual updates and pointwise feedforward layers
|
| 272 |
+
---see `GVPConvLayer`.
|
| 273 |
+
|
| 274 |
+
:param in_dims: input node embedding dimensions (n_scalar, n_vector)
|
| 275 |
+
:param out_dims: output node embedding dimensions (n_scalar, n_vector)
|
| 276 |
+
:param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
|
| 277 |
+
:param n_layers: number of GVPs in the message function
|
| 278 |
+
:param module_list: preconstructed message function, overrides n_layers
|
| 279 |
+
:param aggr: should be "add" if some incoming edges are masked, as in
|
| 280 |
+
a masked autoregressive decoder architecture, otherwise "mean"
|
| 281 |
+
:param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
|
| 282 |
+
:param vector_gate: whether to use vector gating.
|
| 283 |
+
(vector_act will be used as sigma^+ in vector gating if `True`)
|
| 284 |
+
:param update_edge_attr: whether to compute an updated edge representation
|
| 285 |
+
'''
|
| 286 |
+
|
| 287 |
+
def __init__(self, in_dims, out_dims, edge_dims,
|
| 288 |
+
n_layers=3, module_list=None, aggr="mean",
|
| 289 |
+
activations=(F.relu, torch.sigmoid), vector_gate=False,
|
| 290 |
+
update_edge_attr=False):
|
| 291 |
+
super(GVPConv, self).__init__(aggr=aggr)
|
| 292 |
+
self.si, self.vi = in_dims
|
| 293 |
+
self.so, self.vo = out_dims
|
| 294 |
+
self.se, self.ve = edge_dims
|
| 295 |
+
self.update_edge_attr = update_edge_attr
|
| 296 |
+
|
| 297 |
+
GVP_ = functools.partial(GVP,
|
| 298 |
+
activations=activations,
|
| 299 |
+
vector_gate=vector_gate)
|
| 300 |
+
|
| 301 |
+
module_list = module_list or []
|
| 302 |
+
if not module_list:
|
| 303 |
+
if n_layers == 1:
|
| 304 |
+
module_list.append(
|
| 305 |
+
GVP_((2 * self.si + self.se, 2 * self.vi + self.ve),
|
| 306 |
+
(self.so, self.vo), activations=(None, None)))
|
| 307 |
+
else:
|
| 308 |
+
module_list.append(
|
| 309 |
+
GVP_((2 * self.si + self.se, 2 * self.vi + self.ve),
|
| 310 |
+
out_dims)
|
| 311 |
+
)
|
| 312 |
+
for i in range(n_layers - 2):
|
| 313 |
+
module_list.append(GVP_(out_dims, out_dims))
|
| 314 |
+
module_list.append(GVP_(out_dims, out_dims,
|
| 315 |
+
activations=(None, None)))
|
| 316 |
+
self.message_func = nn.Sequential(*module_list)
|
| 317 |
+
|
| 318 |
+
self.edge_func = copy.deepcopy(self.message_func) \
|
| 319 |
+
if self.update_edge_attr else None
|
| 320 |
+
|
| 321 |
+
def forward(self, x, edge_index, edge_attr):
|
| 322 |
+
'''
|
| 323 |
+
:param x: tuple (s, V) of `torch.Tensor`
|
| 324 |
+
:param edge_index: array of shape [2, n_edges]
|
| 325 |
+
:param edge_attr: tuple (s, V) of `torch.Tensor`
|
| 326 |
+
'''
|
| 327 |
+
x_s, x_v = x
|
| 328 |
+
message = self.propagate(edge_index,
|
| 329 |
+
s=x_s,
|
| 330 |
+
v=x_v.reshape(x_v.shape[0], 3 * x_v.shape[1]),
|
| 331 |
+
edge_attr=edge_attr)
|
| 332 |
+
|
| 333 |
+
if self.update_edge_attr:
|
| 334 |
+
s_i, s_j = x_s[edge_index[0]], x_s[edge_index[1]]
|
| 335 |
+
x_v = x_v.reshape(x_v.shape[0], 3 * x_v.shape[1])
|
| 336 |
+
v_i, v_j = x_v[edge_index[0]], x_v[edge_index[1]]
|
| 337 |
+
|
| 338 |
+
edge_out = self.edge_attr(s_i, v_i, s_j, v_j, edge_attr)
|
| 339 |
+
return _split(message, self.vo), edge_out
|
| 340 |
+
else:
|
| 341 |
+
return _split(message, self.vo)
|
| 342 |
+
|
| 343 |
+
def message(self, s_i, v_i, s_j, v_j, edge_attr):
|
| 344 |
+
v_j = v_j.view(v_j.shape[0], v_j.shape[1] // 3, 3)
|
| 345 |
+
v_i = v_i.view(v_i.shape[0], v_i.shape[1] // 3, 3)
|
| 346 |
+
message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
|
| 347 |
+
message = self.message_func(message)
|
| 348 |
+
return _merge(*message)
|
| 349 |
+
|
| 350 |
+
def edge_attr(self, s_i, v_i, s_j, v_j, edge_attr):
|
| 351 |
+
v_j = v_j.view(v_j.shape[0], v_j.shape[1] // 3, 3)
|
| 352 |
+
v_i = v_i.view(v_i.shape[0], v_i.shape[1] // 3, 3)
|
| 353 |
+
message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
|
| 354 |
+
return self.edge_func(message)
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
class GVPConvLayer(nn.Module):
|
| 358 |
+
'''
|
| 359 |
+
Full graph convolution / message passing layer with
|
| 360 |
+
Geometric Vector Perceptrons. Residually updates node embeddings with
|
| 361 |
+
aggregated incoming messages, applies a pointwise feedforward
|
| 362 |
+
network to node embeddings, and returns updated node embeddings.
|
| 363 |
+
|
| 364 |
+
To only compute the aggregated messages, see `GVPConv`.
|
| 365 |
+
|
| 366 |
+
:param node_dims: node embedding dimensions (n_scalar, n_vector)
|
| 367 |
+
:param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
|
| 368 |
+
:param n_message: number of GVPs to use in message function
|
| 369 |
+
:param n_feedforward: number of GVPs to use in feedforward function
|
| 370 |
+
:param drop_rate: drop probability in all dropout layers
|
| 371 |
+
:param autoregressive: if `True`, this `GVPConvLayer` will be used
|
| 372 |
+
with a different set of input node embeddings for messages
|
| 373 |
+
where src >= dst
|
| 374 |
+
:param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
|
| 375 |
+
:param vector_gate: whether to use vector gating.
|
| 376 |
+
(vector_act will be used as sigma^+ in vector gating if `True`)
|
| 377 |
+
:param update_edge_attr: whether to compute an updated edge representation
|
| 378 |
+
:param ln_vector_weight: whether to include a learnable weight in the vector
|
| 379 |
+
layer norm
|
| 380 |
+
'''
|
| 381 |
+
|
| 382 |
+
def __init__(self, node_dims, edge_dims,
|
| 383 |
+
n_message=3, n_feedforward=2, drop_rate=.1,
|
| 384 |
+
autoregressive=False,
|
| 385 |
+
activations=(F.relu, torch.sigmoid), vector_gate=False,
|
| 386 |
+
update_edge_attr=False, ln_vector_weight=False):
|
| 387 |
+
|
| 388 |
+
super(GVPConvLayer, self).__init__()
|
| 389 |
+
assert not (update_edge_attr and autoregressive), "Not implemented"
|
| 390 |
+
self.update_edge_attr = update_edge_attr
|
| 391 |
+
self.conv = GVPConv(node_dims, node_dims, edge_dims, n_message,
|
| 392 |
+
aggr="add" if autoregressive else "mean",
|
| 393 |
+
activations=activations, vector_gate=vector_gate,
|
| 394 |
+
update_edge_attr=update_edge_attr)
|
| 395 |
+
GVP_ = functools.partial(GVP,
|
| 396 |
+
activations=activations,
|
| 397 |
+
vector_gate=vector_gate)
|
| 398 |
+
self.norm = nn.ModuleList([LayerNorm(node_dims, ln_vector_weight)
|
| 399 |
+
for _ in range(2)])
|
| 400 |
+
self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])
|
| 401 |
+
|
| 402 |
+
def get_feedforward(n_dims):
|
| 403 |
+
ff_func = []
|
| 404 |
+
if n_feedforward == 1:
|
| 405 |
+
ff_func.append(GVP_(n_dims, n_dims, activations=(None, None)))
|
| 406 |
+
else:
|
| 407 |
+
hid_dims = 4 * n_dims[0], 2 * n_dims[1]
|
| 408 |
+
ff_func.append(GVP_(n_dims, hid_dims))
|
| 409 |
+
for i in range(n_feedforward - 2):
|
| 410 |
+
ff_func.append(GVP_(hid_dims, hid_dims))
|
| 411 |
+
ff_func.append(GVP_(hid_dims, n_dims, activations=(None, None)))
|
| 412 |
+
return nn.Sequential(*ff_func)
|
| 413 |
+
|
| 414 |
+
self.ff_func = get_feedforward(node_dims)
|
| 415 |
+
|
| 416 |
+
if self.update_edge_attr:
|
| 417 |
+
self.edge_norm = nn.ModuleList([LayerNorm(edge_dims, ln_vector_weight)
|
| 418 |
+
for _ in range(2)])
|
| 419 |
+
self.edge_dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])
|
| 420 |
+
self.edge_ff = get_feedforward(edge_dims)
|
| 421 |
+
|
| 422 |
+
def forward(self, x, edge_index, edge_attr,
|
| 423 |
+
autoregressive_x=None, node_mask=None):
|
| 424 |
+
'''
|
| 425 |
+
:param x: tuple (s, V) of `torch.Tensor`
|
| 426 |
+
:param edge_index: array of shape [2, n_edges]
|
| 427 |
+
:param edge_attr: tuple (s, V) of `torch.Tensor`
|
| 428 |
+
:param autoregressive_x: tuple (s, V) of `torch.Tensor`.
|
| 429 |
+
If not `None`, will be used as src node embeddings
|
| 430 |
+
for forming messages where src >= dst. The corrent node
|
| 431 |
+
embeddings `x` will still be the base of the update and the
|
| 432 |
+
pointwise feedforward.
|
| 433 |
+
:param node_mask: array of type `bool` to index into the first
|
| 434 |
+
dim of node embeddings (s, V). If not `None`, only
|
| 435 |
+
these nodes will be updated.
|
| 436 |
+
'''
|
| 437 |
+
|
| 438 |
+
if autoregressive_x is not None:
|
| 439 |
+
src, dst = edge_index
|
| 440 |
+
mask = src < dst
|
| 441 |
+
edge_index_forward = edge_index[:, mask]
|
| 442 |
+
edge_index_backward = edge_index[:, ~mask]
|
| 443 |
+
edge_attr_forward = tuple_index(edge_attr, mask)
|
| 444 |
+
edge_attr_backward = tuple_index(edge_attr, ~mask)
|
| 445 |
+
|
| 446 |
+
dh = tuple_sum(
|
| 447 |
+
self.conv(x, edge_index_forward, edge_attr_forward),
|
| 448 |
+
self.conv(autoregressive_x, edge_index_backward,
|
| 449 |
+
edge_attr_backward)
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
count = scatter_add(torch.ones_like(dst), dst,
|
| 453 |
+
dim_size=dh[0].size(0)).clamp(min=1).unsqueeze(
|
| 454 |
+
-1)
|
| 455 |
+
|
| 456 |
+
dh = dh[0] / count, dh[1] / count.unsqueeze(-1)
|
| 457 |
+
|
| 458 |
+
else:
|
| 459 |
+
dh = self.conv(x, edge_index, edge_attr)
|
| 460 |
+
|
| 461 |
+
if self.update_edge_attr:
|
| 462 |
+
dh, de = dh
|
| 463 |
+
edge_attr = self.edge_norm[0](tuple_sum(edge_attr, self.dropout[0](de)))
|
| 464 |
+
de = self.edge_ff(edge_attr)
|
| 465 |
+
edge_attr = self.edge_norm[1](tuple_sum(edge_attr, self.dropout[1](de)))
|
| 466 |
+
|
| 467 |
+
if node_mask is not None:
|
| 468 |
+
x_ = x
|
| 469 |
+
x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask)
|
| 470 |
+
|
| 471 |
+
x = self.norm[0](tuple_sum(x, self.dropout[0](dh)))
|
| 472 |
+
|
| 473 |
+
dh = self.ff_func(x)
|
| 474 |
+
x = self.norm[1](tuple_sum(x, self.dropout[1](dh)))
|
| 475 |
+
|
| 476 |
+
if node_mask is not None:
|
| 477 |
+
x_[0][node_mask], x_[1][node_mask] = x[0], x[1]
|
| 478 |
+
x = x_
|
| 479 |
+
return (x, edge_attr) if self.update_edge_attr else x
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
################################################################################
|
| 483 |
+
def _normalize(tensor, dim=-1, eps=1e-8):
|
| 484 |
+
'''
|
| 485 |
+
Normalizes a `torch.Tensor` along dimension `dim` without `nan`s.
|
| 486 |
+
'''
|
| 487 |
+
return torch.nan_to_num(
|
| 488 |
+
torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True) + eps))
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def _rbf(D, D_min=0., D_max=20., D_count=16, device='cpu'):
|
| 492 |
+
'''
|
| 493 |
+
From https://github.com/jingraham/neurips19-graph-protein-design
|
| 494 |
+
|
| 495 |
+
Returns an RBF embedding of `torch.Tensor` `D` along a new axis=-1.
|
| 496 |
+
That is, if `D` has shape [...dims], then the returned tensor will have
|
| 497 |
+
shape [...dims, D_count].
|
| 498 |
+
'''
|
| 499 |
+
D_mu = torch.linspace(D_min, D_max, D_count, device=device)
|
| 500 |
+
D_mu = D_mu.view([1, -1])
|
| 501 |
+
D_sigma = (D_max - D_min) / D_count
|
| 502 |
+
D_expand = torch.unsqueeze(D, -1)
|
| 503 |
+
|
| 504 |
+
RBF = torch.exp(-((D_expand - D_mu) / D_sigma) ** 2)
|
| 505 |
+
return RBF
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
class GVPModel(torch.nn.Module):
|
| 509 |
+
"""
|
| 510 |
+
GVP-GNN model
|
| 511 |
+
inspired by: https://github.com/drorlab/gvp-pytorch/blob/main/gvp/models.py
|
| 512 |
+
and: https://github.com/drorlab/gvp-pytorch/blob/82af6b22eaf8311c15733117b0071408d24ed877/gvp/atom3d.py#L115
|
| 513 |
+
|
| 514 |
+
:param node_in_dim: node dimension in input graph, scalars or tuple (scalars, vectors)
|
| 515 |
+
:param node_h_dim: node dimensions to use in GVP-GNN layers, tuple (s, V)
|
| 516 |
+
:param node_out_nf: node dimensions in output graph, tuple (s, V)
|
| 517 |
+
:param edge_in_nf: edge dimension in input graph (scalars)
|
| 518 |
+
:param edge_h_dim: edge dimensions to embed to before use in GVP-GNN layers,
|
| 519 |
+
tuple (s, V)
|
| 520 |
+
:param edge_out_nf: edge dimensions in output graph, tuple (s, V)
|
| 521 |
+
:param num_layers: number of GVP-GNN layers
|
| 522 |
+
:param drop_rate: rate to use in all dropout layers
|
| 523 |
+
:param vector_gate: use vector gates in all GVPs
|
| 524 |
+
:param reflection_equiv: bool, use reflection-sensitive feature based on the
|
| 525 |
+
cross product if False
|
| 526 |
+
:param d_max:
|
| 527 |
+
:param num_rbf:
|
| 528 |
+
:param update_edge_attr: bool, update edge attributes at each layer in a
|
| 529 |
+
learnable way
|
| 530 |
+
"""
|
| 531 |
+
def __init__(self, node_in_dim, node_h_dim, node_out_nf,
|
| 532 |
+
edge_in_nf, edge_h_dim, edge_out_nf,
|
| 533 |
+
num_layers=3, drop_rate=0.1, vector_gate=False,
|
| 534 |
+
reflection_equiv=True, d_max=20.0, num_rbf=16,
|
| 535 |
+
update_edge_attr=False):
|
| 536 |
+
|
| 537 |
+
super(GVPModel, self).__init__()
|
| 538 |
+
|
| 539 |
+
self.reflection_equiv = reflection_equiv
|
| 540 |
+
self.update_edge_attr = update_edge_attr
|
| 541 |
+
self.d_max = d_max
|
| 542 |
+
self.num_rbf = num_rbf
|
| 543 |
+
|
| 544 |
+
# node_in_dim = (node_in_dim, 1)
|
| 545 |
+
if not isinstance(node_in_dim, tuple):
|
| 546 |
+
node_in_dim = (node_in_dim, 0)
|
| 547 |
+
|
| 548 |
+
edge_in_dim = (edge_in_nf + 2 * node_in_dim[0] + self.num_rbf, 1)
|
| 549 |
+
if not self.reflection_equiv:
|
| 550 |
+
edge_in_dim = (edge_in_dim[0], edge_in_dim[1] + 1)
|
| 551 |
+
|
| 552 |
+
# self.W_v = nn.Sequential(
|
| 553 |
+
# GVP(node_in_dim, node_h_dim, activations=(None, None), vector_gate=True),
|
| 554 |
+
# LayerNorm(node_h_dim)
|
| 555 |
+
# )
|
| 556 |
+
self.W_v = nn.Sequential(
|
| 557 |
+
LayerNorm(node_in_dim, learnable_vector_weight=True),
|
| 558 |
+
GVP(node_in_dim, node_h_dim, activations=(None, None), vector_gate=vector_gate),
|
| 559 |
+
)
|
| 560 |
+
# self.W_e = nn.Sequential(
|
| 561 |
+
# GVP(edge_in_dim, edge_h_dim, activations=(None, None), vector_gate=True),
|
| 562 |
+
# LayerNorm(edge_h_dim)
|
| 563 |
+
# )
|
| 564 |
+
self.W_e = nn.Sequential(
|
| 565 |
+
LayerNorm(edge_in_dim, learnable_vector_weight=True),
|
| 566 |
+
GVP(edge_in_dim, edge_h_dim, activations=(None, None), vector_gate=vector_gate),
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
self.layers = nn.ModuleList(
|
| 570 |
+
GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate,
|
| 571 |
+
update_edge_attr=self.update_edge_attr,
|
| 572 |
+
activations=(F.relu, None), vector_gate=vector_gate,
|
| 573 |
+
ln_vector_weight=True)
|
| 574 |
+
# activations=(F.relu, torch.sigmoid))
|
| 575 |
+
# GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate,
|
| 576 |
+
# update_edge_attr=self.update_edge_attr,
|
| 577 |
+
# activations=(nn.SiLU(), nn.SiLU()))
|
| 578 |
+
for _ in range(num_layers))
|
| 579 |
+
|
| 580 |
+
# self.W_v_out = GVP(node_h_dim, (node_out_nf, 1),
|
| 581 |
+
# activations=(None, None), vector_gate=True)
|
| 582 |
+
self.W_v_out = nn.Sequential(
|
| 583 |
+
LayerNorm(node_h_dim, learnable_vector_weight=True),
|
| 584 |
+
GVP(node_h_dim, (node_out_nf, 1), activations=(None, None), vector_gate=vector_gate),
|
| 585 |
+
)
|
| 586 |
+
# self.W_e_out = GVP(edge_h_dim, (edge_out_nf, 0),
|
| 587 |
+
# activations=(None, None), vector_gate=True) \
|
| 588 |
+
# if self.update_edge_attr else None
|
| 589 |
+
self.W_e_out = nn.Sequential(
|
| 590 |
+
LayerNorm(edge_h_dim, learnable_vector_weight=True),
|
| 591 |
+
GVP(edge_h_dim, (edge_out_nf, 0), activations=(None, None), vector_gate=vector_gate)
|
| 592 |
+
) if self.update_edge_attr else None
|
| 593 |
+
|
| 594 |
+
def edge_features(self, h, x, edge_index, batch_mask=None, edge_attr=None):
|
| 595 |
+
"""
|
| 596 |
+
:param h:
|
| 597 |
+
:param x:
|
| 598 |
+
:param edge_index:
|
| 599 |
+
:param batch_mask:
|
| 600 |
+
:param edge_attr:
|
| 601 |
+
:return: scalar and vector-valued edge features
|
| 602 |
+
"""
|
| 603 |
+
row, col = edge_index
|
| 604 |
+
coord_diff = x[row] - x[col]
|
| 605 |
+
dist = coord_diff.norm(dim=-1)
|
| 606 |
+
rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf,
|
| 607 |
+
device=x.device)
|
| 608 |
+
|
| 609 |
+
edge_s = torch.cat([h[row], h[col], rbf], dim=1)
|
| 610 |
+
edge_v = _normalize(coord_diff).unsqueeze(-2)
|
| 611 |
+
|
| 612 |
+
if edge_attr is not None:
|
| 613 |
+
edge_s = torch.cat([edge_s, edge_attr], dim=1)
|
| 614 |
+
|
| 615 |
+
if not self.reflection_equiv:
|
| 616 |
+
mean = scatter_mean(x, batch_mask, dim=0,
|
| 617 |
+
dim_size=batch_mask.max() + 1)
|
| 618 |
+
row, col = edge_index
|
| 619 |
+
cross = torch.cross(x[row] - mean[batch_mask[row]],
|
| 620 |
+
x[col] - mean[batch_mask[col]], dim=1)
|
| 621 |
+
cross = _normalize(cross).unsqueeze(-2)
|
| 622 |
+
|
| 623 |
+
edge_v = torch.cat([edge_v, cross], dim=-2)
|
| 624 |
+
|
| 625 |
+
return torch.nan_to_num(edge_s), torch.nan_to_num(edge_v)
|
| 626 |
+
|
| 627 |
+
def forward(self, h, x, edge_index, v=None, batch_mask=None, edge_attr=None):
|
| 628 |
+
|
| 629 |
+
# h_v = (h, x.unsqueeze(-2))
|
| 630 |
+
h_v = h if v is None else (h, v)
|
| 631 |
+
h_e = self.edge_features(h, x, edge_index, batch_mask, edge_attr)
|
| 632 |
+
|
| 633 |
+
h_v = self.W_v(h_v)
|
| 634 |
+
h_e = self.W_e(h_e)
|
| 635 |
+
|
| 636 |
+
for layer in self.layers:
|
| 637 |
+
h_v = layer(h_v, edge_index, edge_attr=h_e)
|
| 638 |
+
if self.update_edge_attr:
|
| 639 |
+
h_v, h_e = h_v
|
| 640 |
+
|
| 641 |
+
# h, x = self.W_v_out(h_v)
|
| 642 |
+
# x = x.squeeze(-2)
|
| 643 |
+
h, vel = self.W_v_out(h_v)
|
| 644 |
+
# x = x + vel.squeeze(-2)
|
| 645 |
+
|
| 646 |
+
if self.update_edge_attr:
|
| 647 |
+
edge_attr = self.W_e_out(h_e)
|
| 648 |
+
|
| 649 |
+
# return h, x, edge_attr
|
| 650 |
+
return h, vel.squeeze(-2), edge_attr
|
src/model/gvp_transformer.py
ADDED
|
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import functools
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch_scatter import scatter_mean, scatter_std, scatter_min, scatter_max, scatter_softmax
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# ## debug
|
| 10 |
+
# import sys
|
| 11 |
+
# from pathlib import Path
|
| 12 |
+
#
|
| 13 |
+
# basedir = Path(__file__).resolve().parent.parent.parent
|
| 14 |
+
# sys.path.append(str(basedir))
|
| 15 |
+
# ###
|
| 16 |
+
|
| 17 |
+
from src.model.gvp import GVP, _norm_no_nan, tuple_sum, Dropout, LayerNorm, \
|
| 18 |
+
tuple_cat, tuple_index, _rbf, _normalize
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def tuple_mul(tup, val):
|
| 22 |
+
if isinstance(val, torch.Tensor):
|
| 23 |
+
return (tup[0] * val, tup[1] * val.unsqueeze(-1))
|
| 24 |
+
return (tup[0] * val, tup[1] * val)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class GVPBlock(nn.Module):
|
| 28 |
+
def __init__(self, in_dims, out_dims, n_layers=1,
|
| 29 |
+
activations=(F.relu, torch.sigmoid), vector_gate=False,
|
| 30 |
+
dropout=0.0, skip=False, layernorm=False):
|
| 31 |
+
super(GVPBlock, self).__init__()
|
| 32 |
+
self.si, self.vi = in_dims
|
| 33 |
+
self.so, self.vo = out_dims
|
| 34 |
+
assert not skip or (self.si == self.so and self.vi == self.vo)
|
| 35 |
+
self.skip = skip
|
| 36 |
+
|
| 37 |
+
GVP_ = functools.partial(GVP, activations=activations, vector_gate=vector_gate)
|
| 38 |
+
|
| 39 |
+
module_list = []
|
| 40 |
+
if n_layers == 1:
|
| 41 |
+
module_list.append(GVP_(in_dims, out_dims, activations=(None, None)))
|
| 42 |
+
else:
|
| 43 |
+
module_list.append(GVP_(in_dims, out_dims))
|
| 44 |
+
for i in range(n_layers - 2):
|
| 45 |
+
module_list.append(GVP_(out_dims, out_dims))
|
| 46 |
+
module_list.append(GVP_(out_dims, out_dims, activations=(None, None)))
|
| 47 |
+
|
| 48 |
+
self.layers = nn.Sequential(*module_list)
|
| 49 |
+
|
| 50 |
+
self.norm = LayerNorm(out_dims, learnable_vector_weight=True) if layernorm else None
|
| 51 |
+
self.dropout = Dropout(dropout) if dropout > 0 else None
|
| 52 |
+
|
| 53 |
+
def forward(self, x):
|
| 54 |
+
"""
|
| 55 |
+
:param x: tuple (s, V) of `torch.Tensor`
|
| 56 |
+
:return: tuple (s, V) of `torch.Tensor`
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
dx = self.layers(x)
|
| 60 |
+
|
| 61 |
+
if self.dropout is not None:
|
| 62 |
+
dx = self.dropout(dx)
|
| 63 |
+
|
| 64 |
+
if self.skip:
|
| 65 |
+
x = tuple_sum(x, dx)
|
| 66 |
+
else:
|
| 67 |
+
x = dx
|
| 68 |
+
|
| 69 |
+
if self.norm is not None:
|
| 70 |
+
x = self.norm(x)
|
| 71 |
+
|
| 72 |
+
return x
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class GeometricPNA(nn.Module):
|
| 76 |
+
def __init__(self, d_in, d_out):
|
| 77 |
+
""" Map features to global features """
|
| 78 |
+
super().__init__()
|
| 79 |
+
si, vi = d_in
|
| 80 |
+
so, vo = d_out
|
| 81 |
+
self.gvp = GVPBlock((4 * si + 3 * vi, vi), d_out)
|
| 82 |
+
|
| 83 |
+
def forward(self, x, batch_mask, batch_size=None):
|
| 84 |
+
""" x: tuple (s, V) """
|
| 85 |
+
s, v = x
|
| 86 |
+
|
| 87 |
+
sm = scatter_mean(s, batch_mask, dim=0, dim_size=batch_size)
|
| 88 |
+
smi = scatter_min(s, batch_mask, dim=0, dim_size=batch_size)[0]
|
| 89 |
+
sma = scatter_max(s, batch_mask, dim=0, dim_size=batch_size)[0]
|
| 90 |
+
sstd = scatter_std(s, batch_mask, dim=0, dim_size=batch_size)
|
| 91 |
+
|
| 92 |
+
vnorm = _norm_no_nan(v)
|
| 93 |
+
vm = scatter_mean(v, batch_mask, dim=0, dim_size=batch_size)
|
| 94 |
+
vmi = scatter_min(vnorm, batch_mask, dim=0, dim_size=batch_size)[0]
|
| 95 |
+
vma = scatter_max(vnorm, batch_mask, dim=0, dim_size=batch_size)[0]
|
| 96 |
+
vstd = scatter_std(vnorm, batch_mask, dim=0, dim_size=batch_size)
|
| 97 |
+
|
| 98 |
+
z = torch.hstack((sm, smi, sma, sstd, vmi, vma, vstd))
|
| 99 |
+
out = self.gvp((z, vm))
|
| 100 |
+
return out
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class TupleLinear(nn.Module):
|
| 104 |
+
def __init__(self, in_dims, out_dims, bias=True):
|
| 105 |
+
super().__init__()
|
| 106 |
+
self.si, self.vi = in_dims
|
| 107 |
+
self.so, self.vo = out_dims
|
| 108 |
+
assert self.si and self.so
|
| 109 |
+
self.ws = nn.Linear(self.si, self.so, bias=bias)
|
| 110 |
+
self.wv = nn.Linear(self.vi, self.vo, bias=bias) if self.vi and self.vo else None
|
| 111 |
+
|
| 112 |
+
def forward(self, x):
|
| 113 |
+
if self.vi:
|
| 114 |
+
s, v = x
|
| 115 |
+
|
| 116 |
+
s = self.ws(s)
|
| 117 |
+
|
| 118 |
+
if self.vo:
|
| 119 |
+
v = v.transpose(-1, -2)
|
| 120 |
+
v = self.wv(v)
|
| 121 |
+
v = v.transpose(-1, -2)
|
| 122 |
+
|
| 123 |
+
else:
|
| 124 |
+
s = self.ws(x)
|
| 125 |
+
|
| 126 |
+
if self.vo:
|
| 127 |
+
v = torch.zeros(s.size(0), self.vo, 3, device=s.device)
|
| 128 |
+
|
| 129 |
+
return (s, v) if self.vo else s
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class GVPTransformerLayer(nn.Module):
|
| 133 |
+
"""
|
| 134 |
+
Full graph transformer layer with Geometric Vector Perceptrons.
|
| 135 |
+
Inspired by
|
| 136 |
+
- GVP: Jing, Bowen, et al. "Learning from protein structure with geometric vector perceptrons." arXiv preprint arXiv:2009.01411 (2020).
|
| 137 |
+
- Transformer architecture: Vignac, Clement, et al. "Digress: Discrete denoising diffusion for graph generation." arXiv preprint arXiv:2209.14734 (2022).
|
| 138 |
+
- Invariant point attention: Jumper, John, et al. "Highly accurate protein structure prediction with AlphaFold." Nature 596.7873 (2021): 583-589.
|
| 139 |
+
|
| 140 |
+
:param node_dims: node embedding dimensions (n_scalar, n_vector)
|
| 141 |
+
:param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
|
| 142 |
+
:param global_dims: global feature dimension (n_scalar, n_vector)
|
| 143 |
+
:param dk: key dimension, (n_scalar, n_vector)
|
| 144 |
+
:param dv: node value dimension, (n_scalar, n_vector)
|
| 145 |
+
:param de: edge value dimension, (n_scalar, n_vector)
|
| 146 |
+
:param db: dimension of edge contribution to attention, int
|
| 147 |
+
:param attn_heads: number of attention heads, int
|
| 148 |
+
:param n_feedforward: number of GVPs to use in feedforward function
|
| 149 |
+
:param drop_rate: drop probability in all dropout layers
|
| 150 |
+
:param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
|
| 151 |
+
:param vector_gate: whether to use vector gating.
|
| 152 |
+
(vector_act will be used as sigma^+ in vector gating if `True`)
|
| 153 |
+
:param attention: can be used to turn off the attention mechanism
|
| 154 |
+
"""
|
| 155 |
+
|
| 156 |
+
def __init__(self, node_dims, edge_dims, global_dims, dk, dv, de, db,
|
| 157 |
+
attn_heads, n_feedforward=1, drop_rate=0.0,
|
| 158 |
+
activations=(F.relu, torch.sigmoid), vector_gate=False,
|
| 159 |
+
attention=True):
|
| 160 |
+
|
| 161 |
+
super(GVPTransformerLayer, self).__init__()
|
| 162 |
+
|
| 163 |
+
self.attention = attention
|
| 164 |
+
|
| 165 |
+
dq = dk
|
| 166 |
+
self.dq = dq
|
| 167 |
+
self.dk = dk
|
| 168 |
+
self.dv = dv
|
| 169 |
+
self.de = de
|
| 170 |
+
self.db = db
|
| 171 |
+
|
| 172 |
+
self.h = attn_heads
|
| 173 |
+
|
| 174 |
+
self.q = TupleLinear(node_dims, tuple_mul(dq, self.h), bias=False) if self.attention else None
|
| 175 |
+
self.k = TupleLinear(node_dims, tuple_mul(dk, self.h), bias=False) if self.attention else None
|
| 176 |
+
self.vx = TupleLinear(node_dims, tuple_mul(dv, self.h), bias=False)
|
| 177 |
+
|
| 178 |
+
self.ve = TupleLinear(edge_dims, tuple_mul(de, self.h), bias=False)
|
| 179 |
+
self.b = TupleLinear(edge_dims, (db * self.h, 0), bias=False) if self.attention else None
|
| 180 |
+
|
| 181 |
+
m_dim = tuple_sum(tuple_mul(dv, self.h), tuple_mul(de, self.h))
|
| 182 |
+
self.msg = GVPBlock(m_dim, m_dim, n_feedforward,
|
| 183 |
+
activations=activations, vector_gate=vector_gate)
|
| 184 |
+
|
| 185 |
+
m_dim = tuple_sum(m_dim, global_dims)
|
| 186 |
+
self.x_out = GVPBlock(m_dim, node_dims, n_feedforward,
|
| 187 |
+
activations=activations, vector_gate=vector_gate)
|
| 188 |
+
self.x_norm = LayerNorm(node_dims, learnable_vector_weight=True)
|
| 189 |
+
self.x_dropout = Dropout(drop_rate)
|
| 190 |
+
|
| 191 |
+
e_dim = tuple_sum(tuple_mul(node_dims, 2), edge_dims, global_dims)
|
| 192 |
+
if self.attention:
|
| 193 |
+
e_dim = (e_dim[0] + 3 * attn_heads, e_dim[1])
|
| 194 |
+
self.e_out = GVPBlock(e_dim, edge_dims, n_feedforward,
|
| 195 |
+
activations=activations, vector_gate=vector_gate)
|
| 196 |
+
self.e_norm = LayerNorm(edge_dims, learnable_vector_weight=True)
|
| 197 |
+
self.e_dropout = Dropout(drop_rate)
|
| 198 |
+
|
| 199 |
+
self.pna_x = GeometricPNA(node_dims, node_dims)
|
| 200 |
+
self.pna_e = GeometricPNA(edge_dims, edge_dims)
|
| 201 |
+
self.y = GVP(global_dims, global_dims, activations=(None, None), vector_gate=vector_gate)
|
| 202 |
+
_dim = tuple_sum(node_dims, edge_dims, global_dims)
|
| 203 |
+
self.y_out = GVPBlock(_dim, global_dims, n_feedforward,
|
| 204 |
+
activations=activations, vector_gate=vector_gate)
|
| 205 |
+
self.y_norm = LayerNorm(global_dims, learnable_vector_weight=True)
|
| 206 |
+
self.y_dropout = Dropout(drop_rate)
|
| 207 |
+
|
| 208 |
+
def forward(self, x, edge_index, batch_mask, edge_attr, global_attr=None,
|
| 209 |
+
node_mask=None):
|
| 210 |
+
"""
|
| 211 |
+
:param x: tuple (s, V) of `torch.Tensor`
|
| 212 |
+
:param edge_index: array of shape [2, n_edges]
|
| 213 |
+
:param batch_mask: array indicating different graphs
|
| 214 |
+
:param edge_attr: tuple (s, V) of `torch.Tensor`
|
| 215 |
+
:param global_attr: tuple (s, V) of `torch.Tensor`
|
| 216 |
+
:param node_mask: array of type `bool` to index into the first
|
| 217 |
+
dim of node embeddings (s, V). If not `None`, only
|
| 218 |
+
these nodes will be updated.
|
| 219 |
+
"""
|
| 220 |
+
|
| 221 |
+
row, col = edge_index
|
| 222 |
+
n = len(x[0])
|
| 223 |
+
batch_size = len(torch.unique(batch_mask))
|
| 224 |
+
|
| 225 |
+
# Compute attention
|
| 226 |
+
if self.attention:
|
| 227 |
+
Q = self.q(x)
|
| 228 |
+
K = self.k(x)
|
| 229 |
+
b = self.b(edge_attr)
|
| 230 |
+
|
| 231 |
+
qs, qv = Q # (n, dq * h), (n, dq * h, 3)
|
| 232 |
+
ks, kv = K # (n, dq * h), (n, dq * h, 3)
|
| 233 |
+
attn_s = (qs[row] * ks[col]).reshape(len(row), self.h, self.dq[0]).sum(dim=-1) # (m, h)
|
| 234 |
+
# NOTE: attn_v is the Frobenius inner product between vector-valued queries and keys of size [dq, 3]
|
| 235 |
+
# (generalizes the dot-product between queries and keys similar to Pocket2Mol)
|
| 236 |
+
# TODO: double-check if this is correctly implemented!
|
| 237 |
+
attn_v = (qv[row] * kv[col]).reshape(len(row), self.h, self.dq[1], 3).sum(dim=(-2, -1)) # (m, h)
|
| 238 |
+
attn_e = b.reshape(b.size(0), self.h, self.db).sum(dim=-1) # (m, h)
|
| 239 |
+
|
| 240 |
+
attn = attn_s / math.sqrt(3 * self.dk[0]) + \
|
| 241 |
+
attn_v / math.sqrt(9 * self.dk[1]) + \
|
| 242 |
+
attn_e / math.sqrt(3 * self.db)
|
| 243 |
+
attn = scatter_softmax(attn, row, dim=0) # (m, h)
|
| 244 |
+
attn = attn.unsqueeze(-1) # (m, h, 1)
|
| 245 |
+
|
| 246 |
+
# Compute new features
|
| 247 |
+
Vx = self.vx(x)
|
| 248 |
+
Ve = self.ve(edge_attr)
|
| 249 |
+
|
| 250 |
+
mx = (Vx[0].reshape(Vx[0].size(0), self.h, self.dv[0]), # (n, h, dv)
|
| 251 |
+
Vx[1].reshape(Vx[1].size(0), self.h, self.dv[1], 3)) # (n, h, dv, 3)
|
| 252 |
+
me = (Ve[0].reshape(Ve[0].size(0), self.h, self.de[0]),
|
| 253 |
+
Ve[1].reshape(Ve[1].size(0), self.h, self.de[1], 3))
|
| 254 |
+
|
| 255 |
+
mx = tuple_index(mx, col)
|
| 256 |
+
if self.attention:
|
| 257 |
+
mx = tuple_mul(mx, attn)
|
| 258 |
+
me = tuple_mul(me, attn)
|
| 259 |
+
|
| 260 |
+
_m = tuple_cat(mx, me)
|
| 261 |
+
_m = (_m[0].flatten(1), _m[1].flatten(1, 2))
|
| 262 |
+
m = self.msg(_m) # (m, h * dv), (m, h * dv, 3)
|
| 263 |
+
m = (scatter_mean(m[0], row, dim=0, dim_size=n), # (n, h * dv)
|
| 264 |
+
scatter_mean(m[1], row, dim=0, dim_size=n)) # (n, h * dv, 3)
|
| 265 |
+
if global_attr is not None:
|
| 266 |
+
m = tuple_cat(m, tuple_index(global_attr, batch_mask))
|
| 267 |
+
X_out = self.x_norm(tuple_sum(x, self.x_dropout(self.x_out(m))))
|
| 268 |
+
|
| 269 |
+
_e = tuple_cat(tuple_index(x, row), tuple_index(x, col), edge_attr)
|
| 270 |
+
if self.attention:
|
| 271 |
+
_e = (torch.cat([_e[0], attn_s, attn_v, attn_e], dim=-1), _e[1])
|
| 272 |
+
if global_attr is not None:
|
| 273 |
+
_e = tuple_cat(_e, tuple_index(global_attr, batch_mask[row]))
|
| 274 |
+
E_out = self.e_norm(tuple_sum(edge_attr, self.e_dropout(self.e_out(_e))))
|
| 275 |
+
|
| 276 |
+
_y = tuple_cat(self.pna_x(x, batch_mask, batch_size),
|
| 277 |
+
self.pna_e(edge_attr, batch_mask[row], batch_size))
|
| 278 |
+
if global_attr is not None:
|
| 279 |
+
_y = tuple_cat(_y, self.y(global_attr))
|
| 280 |
+
y_out = self.y_norm(tuple_sum(global_attr, self.y_dropout(self.y_out(_y))))
|
| 281 |
+
else:
|
| 282 |
+
y_out = self.y_norm(self.y_dropout(self.y_out(_y)))
|
| 283 |
+
|
| 284 |
+
if node_mask is not None:
|
| 285 |
+
X_out[0][~node_mask], X_out[1][~node_mask] = tuple_index(x, ~node_mask)
|
| 286 |
+
|
| 287 |
+
return X_out, E_out, y_out
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class GVPTransformerModel(torch.nn.Module):
|
| 291 |
+
"""
|
| 292 |
+
GVP-Transformer model
|
| 293 |
+
|
| 294 |
+
:param node_in_dim: node dimension in input graph, scalars or tuple (scalars, vectors)
|
| 295 |
+
:param node_h_dim: node dimensions to use in GVP-GNN layers, tuple (s, V)
|
| 296 |
+
:param node_out_nf: node dimensions in output graph, tuple (s, V)
|
| 297 |
+
:param edge_in_nf: edge dimension in input graph (scalars)
|
| 298 |
+
:param edge_h_dim: edge dimensions to embed to before use in GVP-GNN layers,
|
| 299 |
+
tuple (s, V)
|
| 300 |
+
:param edge_out_nf: edge dimensions in output graph, tuple (s, V)
|
| 301 |
+
:param num_layers: number of GVP-GNN layers
|
| 302 |
+
:param drop_rate: rate to use in all dropout layers
|
| 303 |
+
:param reflection_equiv: bool, use reflection-sensitive feature based on the
|
| 304 |
+
cross product if False
|
| 305 |
+
:param d_max:
|
| 306 |
+
:param num_rbf:
|
| 307 |
+
:param vector_gate: use vector gates in all GVPs
|
| 308 |
+
:param attention: can be used to turn off the attention mechanism
|
| 309 |
+
"""
|
| 310 |
+
def __init__(self, node_in_dim, node_h_dim, node_out_nf, edge_in_nf,
|
| 311 |
+
edge_h_dim, edge_out_nf, num_layers, dk, dv, de, db, dy,
|
| 312 |
+
attn_heads, n_feedforward, drop_rate, reflection_equiv=True,
|
| 313 |
+
d_max=20.0, num_rbf=16, vector_gate=False, attention=True):
|
| 314 |
+
|
| 315 |
+
super(GVPTransformerModel, self).__init__()
|
| 316 |
+
|
| 317 |
+
self.reflection_equiv = reflection_equiv
|
| 318 |
+
self.d_max = d_max
|
| 319 |
+
self.num_rbf = num_rbf
|
| 320 |
+
|
| 321 |
+
# node_in_dim = (node_in_dim, 1)
|
| 322 |
+
if not isinstance(node_in_dim, tuple):
|
| 323 |
+
node_in_dim = (node_in_dim, 0)
|
| 324 |
+
|
| 325 |
+
edge_in_dim = (edge_in_nf + 2 * node_in_dim[0] + self.num_rbf, 1)
|
| 326 |
+
if not self.reflection_equiv:
|
| 327 |
+
edge_in_dim = (edge_in_dim[0], edge_in_dim[1] + 1)
|
| 328 |
+
|
| 329 |
+
self.W_v = GVP(node_in_dim, node_h_dim, activations=(None, None), vector_gate=vector_gate)
|
| 330 |
+
self.W_e = GVP(edge_in_dim, edge_h_dim, activations=(None, None), vector_gate=vector_gate)
|
| 331 |
+
# self.W_v = nn.Sequential(
|
| 332 |
+
# LayerNorm(node_in_dim, learnable_vector_weight=True),
|
| 333 |
+
# GVP(node_in_dim, node_h_dim, activations=(None, None)),
|
| 334 |
+
# )
|
| 335 |
+
# self.W_e = nn.Sequential(
|
| 336 |
+
# LayerNorm(edge_in_dim, learnable_vector_weight=True),
|
| 337 |
+
# GVP(edge_in_dim, edge_h_dim, activations=(None, None)),
|
| 338 |
+
# )
|
| 339 |
+
|
| 340 |
+
self.dy = dy
|
| 341 |
+
self.layers = nn.ModuleList(
|
| 342 |
+
GVPTransformerLayer(node_h_dim, edge_h_dim, dy, dk, dv, de, db,
|
| 343 |
+
attn_heads, n_feedforward=n_feedforward,
|
| 344 |
+
drop_rate=drop_rate, vector_gate=vector_gate,
|
| 345 |
+
activations=(F.relu, None), attention=attention)
|
| 346 |
+
for _ in range(num_layers))
|
| 347 |
+
|
| 348 |
+
self.W_v_out = GVP(node_h_dim, (node_out_nf, 1), activations=(None, None), vector_gate=vector_gate)
|
| 349 |
+
self.W_e_out = GVP(edge_h_dim, (edge_out_nf, 0), activations=(None, None), vector_gate=vector_gate)
|
| 350 |
+
# self.W_v_out = nn.Sequential(
|
| 351 |
+
# LayerNorm(node_h_dim, learnable_vector_weight=True),
|
| 352 |
+
# GVP(node_h_dim, (node_out_nf, 1), activations=(None, None)),
|
| 353 |
+
# )
|
| 354 |
+
# self.W_e_out = nn.Sequential(
|
| 355 |
+
# LayerNorm(edge_h_dim, learnable_vector_weight=True),
|
| 356 |
+
# GVP(edge_h_dim, (edge_out_nf, 0), activations=(None, None))
|
| 357 |
+
# )
|
| 358 |
+
|
| 359 |
+
def edge_features(self, h, x, edge_index, batch_mask=None, edge_attr=None):
|
| 360 |
+
"""
|
| 361 |
+
:param h:
|
| 362 |
+
:param x:
|
| 363 |
+
:param edge_index:
|
| 364 |
+
:param batch_mask:
|
| 365 |
+
:param edge_attr:
|
| 366 |
+
:return: scalar and vector-valued edge features
|
| 367 |
+
"""
|
| 368 |
+
row, col = edge_index
|
| 369 |
+
coord_diff = x[row] - x[col]
|
| 370 |
+
dist = coord_diff.norm(dim=-1)
|
| 371 |
+
rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf,
|
| 372 |
+
device=x.device)
|
| 373 |
+
|
| 374 |
+
edge_s = torch.cat([h[row], h[col], rbf], dim=1)
|
| 375 |
+
edge_v = _normalize(coord_diff).unsqueeze(-2)
|
| 376 |
+
|
| 377 |
+
if edge_attr is not None:
|
| 378 |
+
edge_s = torch.cat([edge_s, edge_attr], dim=1)
|
| 379 |
+
|
| 380 |
+
if not self.reflection_equiv:
|
| 381 |
+
mean = scatter_mean(x, batch_mask, dim=0,
|
| 382 |
+
dim_size=batch_mask.max() + 1)
|
| 383 |
+
row, col = edge_index
|
| 384 |
+
cross = torch.cross(x[row] - mean[batch_mask[row]],
|
| 385 |
+
x[col] - mean[batch_mask[col]], dim=1)
|
| 386 |
+
cross = _normalize(cross).unsqueeze(-2)
|
| 387 |
+
|
| 388 |
+
edge_v = torch.cat([edge_v, cross], dim=-2)
|
| 389 |
+
|
| 390 |
+
return torch.nan_to_num(edge_s), torch.nan_to_num(edge_v)
|
| 391 |
+
|
| 392 |
+
def forward(self, h, x, edge_index, v=None, batch_mask=None, edge_attr=None):
|
| 393 |
+
|
| 394 |
+
bs = len(batch_mask.unique())
|
| 395 |
+
|
| 396 |
+
# h_v = (h, x.unsqueeze(-2))
|
| 397 |
+
h_v = h if v is None else (h, v)
|
| 398 |
+
h_e = self.edge_features(h, x, edge_index, batch_mask, edge_attr)
|
| 399 |
+
|
| 400 |
+
h_v = self.W_v(h_v)
|
| 401 |
+
h_e = self.W_e(h_e)
|
| 402 |
+
h_y = (torch.zeros(bs, self.dy[0], device=h.device),
|
| 403 |
+
torch.zeros(bs, self.dy[1], 3, device=h.device))
|
| 404 |
+
|
| 405 |
+
for layer in self.layers:
|
| 406 |
+
h_v, h_e, h_y = layer(h_v, edge_index, batch_mask, h_e, h_y)
|
| 407 |
+
|
| 408 |
+
# h, x = self.W_v_out(h_v)
|
| 409 |
+
# x = x.squeeze(-2)
|
| 410 |
+
h, vel = self.W_v_out(h_v)
|
| 411 |
+
# x = x + vel.squeeze(-2)
|
| 412 |
+
|
| 413 |
+
edge_attr = self.W_e_out(h_e)
|
| 414 |
+
|
| 415 |
+
# return h, x, edge_attr
|
| 416 |
+
return h, vel.squeeze(-2), edge_attr
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
if __name__ == "__main__":
|
| 420 |
+
from src.model.gvp import randn
|
| 421 |
+
from scipy.spatial.transform import Rotation
|
| 422 |
+
|
| 423 |
+
def test_equivariance(model, nodes, edges, glob_feat):
|
| 424 |
+
random = torch.as_tensor(Rotation.random().as_matrix(),
|
| 425 |
+
dtype=torch.float32, device=device)
|
| 426 |
+
|
| 427 |
+
with torch.no_grad():
|
| 428 |
+
X_out, E_out, y_out = model(nodes, edges, glob_feat)
|
| 429 |
+
n_v_rot, e_v_rot, y_v_rot = nodes[1] @ random, edges[1] @ random, glob_feat[1] @ random
|
| 430 |
+
X_out_v_rot = X_out[1] @ random
|
| 431 |
+
E_out_v_rot = E_out[1] @ random
|
| 432 |
+
y_out_v_rot = y_out[1] @ random
|
| 433 |
+
X_out_prime, E_out_prime, y_out_prime = model((nodes[0], n_v_rot), (edges[0], e_v_rot), (glob_feat[0], y_v_rot))
|
| 434 |
+
|
| 435 |
+
assert torch.allclose(X_out[0], X_out_prime[0], atol=1e-5, rtol=1e-4)
|
| 436 |
+
assert torch.allclose(X_out_v_rot, X_out_prime[1], atol=1e-5, rtol=1e-4)
|
| 437 |
+
assert torch.allclose(E_out[0], E_out_prime[0], atol=1e-5, rtol=1e-4)
|
| 438 |
+
assert torch.allclose(E_out_v_rot, E_out_prime[1], atol=1e-5, rtol=1e-4)
|
| 439 |
+
assert torch.allclose(y_out[0], y_out_prime[0], atol=1e-5, rtol=1e-4)
|
| 440 |
+
assert torch.allclose(y_out_v_rot, y_out_prime[1], atol=1e-5, rtol=1e-4)
|
| 441 |
+
print("SUCCESS")
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
n_nodes = 300
|
| 445 |
+
n_edges = 10000
|
| 446 |
+
batch_size = 6
|
| 447 |
+
|
| 448 |
+
node_dim = (16, 8)
|
| 449 |
+
edge_dim = (8, 4)
|
| 450 |
+
global_dim = (4, 2)
|
| 451 |
+
dk = (6, 3)
|
| 452 |
+
dv = (7, 4)
|
| 453 |
+
de = (5, 2)
|
| 454 |
+
db = 10
|
| 455 |
+
attn_heads = 9
|
| 456 |
+
|
| 457 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
nodes = randn(n_nodes, node_dim, device=device)
|
| 461 |
+
edges = randn(n_edges, edge_dim, device=device)
|
| 462 |
+
glob_feat = randn(batch_size, global_dim, device=device)
|
| 463 |
+
edge_index = torch.randint(0, n_nodes, (2, n_edges), device=device)
|
| 464 |
+
batch_idx = torch.randint(0, batch_size, (n_nodes,), device=device)
|
| 465 |
+
|
| 466 |
+
model = GVPTransformerLayer(node_dim, edge_dim, global_dim, dk, dv, de, db,
|
| 467 |
+
attn_heads, n_feedforward = 2,
|
| 468 |
+
drop_rate = 0.1).to(device).eval()
|
| 469 |
+
|
| 470 |
+
model_fn = lambda h_V, h_E, h_y: model(h_V, edge_index, batch_idx, h_E, h_y)
|
| 471 |
+
test_equivariance(model_fn, nodes, edges, glob_feat)
|
src/model/lightning.py
ADDED
|
@@ -0,0 +1,1426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
import tempfile
|
| 3 |
+
from typing import Optional, Union
|
| 4 |
+
from time import time
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from functools import partial
|
| 7 |
+
from itertools import accumulate
|
| 8 |
+
from argparse import Namespace
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from rdkit import Chem
|
| 13 |
+
import torch
|
| 14 |
+
from torch.utils.data import DataLoader, SubsetRandomSampler
|
| 15 |
+
from torch.distributions.categorical import Categorical
|
| 16 |
+
import pytorch_lightning as pl
|
| 17 |
+
from torch_scatter import scatter_mean
|
| 18 |
+
|
| 19 |
+
import src.utils as utils
|
| 20 |
+
from src.constants import atom_encoder, atom_decoder, aa_encoder, aa_decoder, \
|
| 21 |
+
bond_encoder, bond_decoder, residue_encoder, residue_bond_encoder, \
|
| 22 |
+
residue_decoder, residue_bond_decoder, aa_atom_index, aa_atom_mask
|
| 23 |
+
from src.data.dataset import ProcessedLigandPocketDataset, ClusteredDataset, get_wds
|
| 24 |
+
from src.data import data_utils
|
| 25 |
+
from src.data.data_utils import AppendVirtualNodesInCoM, center_data, Residues, TensorDict, randomize_tensors
|
| 26 |
+
from src.model.flows import CoordICFM, TorusICFM, CoordICFMPredictFinal, TorusICFMPredictFinal, SO3ICFM
|
| 27 |
+
from src.model.markov_bridge import UniformPriorMarkovBridge, MarginalPriorMarkovBridge
|
| 28 |
+
from src.model.dynamics import Dynamics
|
| 29 |
+
from src.model.dynamics_hetero import DynamicsHetero
|
| 30 |
+
from src.model.diffusion_utils import DistributionNodes
|
| 31 |
+
from src.model.loss_utils import TimestepWeights, clash_loss
|
| 32 |
+
from src.analysis.visualization_utils import pocket_to_rdkit, mols_to_pdbfile
|
| 33 |
+
from src.analysis.metrics import MoleculeValidity, CategoricalDistribution, MolecularProperties
|
| 34 |
+
from src.data.molecule_builder import build_molecule
|
| 35 |
+
from src.data.postprocessing import process_all
|
| 36 |
+
from src.sbdd_metrics.metrics import FullEvaluator
|
| 37 |
+
from src.sbdd_metrics.evaluation import VALIDITY_METRIC_NAME, aggregated_metrics, collection_metrics
|
| 38 |
+
from tqdm import tqdm
|
| 39 |
+
|
| 40 |
+
# derive additional constants
|
| 41 |
+
aa_atom_mask_tensor = torch.tensor([aa_atom_mask[aa] for aa in aa_decoder])
|
| 42 |
+
aa_atom_decoder = {aa: {v: k for k, v in aa_atom_index[aa].items()} for aa in aa_decoder}
|
| 43 |
+
aa_atom_type_tensor = torch.tensor([[atom_encoder.get(aa_atom_decoder[aa].get(i, '-')[0], -42)
|
| 44 |
+
for i in range(14)] for aa in aa_decoder])
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def set_default(namespace, key, default_val):
|
| 48 |
+
val = vars(namespace).get(key, default_val)
|
| 49 |
+
setattr(namespace, key, val)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class DrugFlow(pl.LightningModule):
|
| 53 |
+
def __init__(
|
| 54 |
+
self,
|
| 55 |
+
pocket_representation: str,
|
| 56 |
+
train_params: Namespace,
|
| 57 |
+
loss_params: Namespace,
|
| 58 |
+
eval_params: Namespace,
|
| 59 |
+
predictor_params: Namespace,
|
| 60 |
+
simulation_params: Namespace,
|
| 61 |
+
virtual_nodes: Union[list, None],
|
| 62 |
+
flexible: bool,
|
| 63 |
+
flexible_bb: bool = False,
|
| 64 |
+
debug: bool = False,
|
| 65 |
+
overfit: bool = False,
|
| 66 |
+
):
|
| 67 |
+
super(DrugFlow, self).__init__()
|
| 68 |
+
self.save_hyperparameters()
|
| 69 |
+
|
| 70 |
+
# Set default parameters
|
| 71 |
+
set_default(train_params, "sharded_dataset", False)
|
| 72 |
+
set_default(train_params, "sample_from_clusters", False)
|
| 73 |
+
set_default(train_params, "lr_step_size", None)
|
| 74 |
+
set_default(train_params, "lr_gamma", None)
|
| 75 |
+
set_default(train_params, "gnina", None)
|
| 76 |
+
set_default(loss_params, "lambda_x", 1.0)
|
| 77 |
+
set_default(loss_params, "lambda_clash", None)
|
| 78 |
+
set_default(loss_params, "reduce", "mean")
|
| 79 |
+
set_default(loss_params, "regularize_uncertainty", None)
|
| 80 |
+
set_default(eval_params, "n_loss_per_sample", 1)
|
| 81 |
+
set_default(eval_params, "n_sampling_steps", simulation_params.n_steps)
|
| 82 |
+
set_default(predictor_params, "transform_sc_pred", False)
|
| 83 |
+
set_default(predictor_params, "add_chi_as_feature", False)
|
| 84 |
+
set_default(predictor_params, "augment_residue_sc", False)
|
| 85 |
+
set_default(predictor_params, "augment_ligand_sc", False)
|
| 86 |
+
set_default(predictor_params, "add_all_atom_diff", False)
|
| 87 |
+
set_default(predictor_params, "angle_act_fn", None)
|
| 88 |
+
set_default(simulation_params, "predict_confidence", False)
|
| 89 |
+
set_default(simulation_params, "predict_final", False)
|
| 90 |
+
set_default(simulation_params, "scheduler_chi", None)
|
| 91 |
+
|
| 92 |
+
# Check for invalid configurations
|
| 93 |
+
assert pocket_representation in {'side_chain_bead', 'CA+'}
|
| 94 |
+
self.pocket_representation = pocket_representation
|
| 95 |
+
|
| 96 |
+
assert flexible or not predictor_params.augment_residue_sc
|
| 97 |
+
self.augment_residue_sc = predictor_params.augment_residue_sc \
|
| 98 |
+
if 'augment_residue_sc' in predictor_params else False
|
| 99 |
+
self.augment_ligand_sc = predictor_params.augment_ligand_sc \
|
| 100 |
+
if 'augment_ligand_sc' in predictor_params else False
|
| 101 |
+
|
| 102 |
+
assert not (flexible_bb and predictor_params.normal_modes), \
|
| 103 |
+
"Normal mode eigenvectors are only meaningful for fixed backbones"
|
| 104 |
+
assert (not flexible_bb) or flexible, \
|
| 105 |
+
"Currently atom vectors aren't updated if flexible=False"
|
| 106 |
+
|
| 107 |
+
assert not (simulation_params.predict_confidence and
|
| 108 |
+
(not predictor_params.heterogeneous_graph or simulation_params.predict_final))
|
| 109 |
+
|
| 110 |
+
# Set parameters
|
| 111 |
+
self.train_dataset = None
|
| 112 |
+
self.val_dataset = None
|
| 113 |
+
self.test_dataset = None
|
| 114 |
+
self.virtual_nodes = virtual_nodes
|
| 115 |
+
self.flexible = flexible
|
| 116 |
+
self.flexible_bb = flexible_bb
|
| 117 |
+
self.debug = debug
|
| 118 |
+
self.overfit = overfit
|
| 119 |
+
self.predict_confidence = simulation_params.predict_confidence
|
| 120 |
+
|
| 121 |
+
if self.virtual_nodes:
|
| 122 |
+
self.add_virtual_min = virtual_nodes[0]
|
| 123 |
+
self.add_virtual_max = virtual_nodes[1]
|
| 124 |
+
|
| 125 |
+
# Training parameters
|
| 126 |
+
self.datadir = train_params.datadir
|
| 127 |
+
self.receptor_dir = train_params.datadir
|
| 128 |
+
self.batch_size = train_params.batch_size
|
| 129 |
+
self.lr = train_params.lr
|
| 130 |
+
self.lr_step_size = train_params.lr_step_size
|
| 131 |
+
self.lr_gamma = train_params.lr_gamma
|
| 132 |
+
self.num_workers = train_params.num_workers
|
| 133 |
+
self.sample_from_clusters = train_params.sample_from_clusters
|
| 134 |
+
self.sharded_dataset = train_params.sharded_dataset
|
| 135 |
+
self.clip_grad = train_params.clip_grad
|
| 136 |
+
if self.clip_grad:
|
| 137 |
+
self.gradnorm_queue = utils.Queue()
|
| 138 |
+
# Add large value that will be flushed.
|
| 139 |
+
self.gradnorm_queue.add(3000)
|
| 140 |
+
|
| 141 |
+
# Evaluation parameters
|
| 142 |
+
self.outdir = eval_params.outdir
|
| 143 |
+
self.eval_batch_size = eval_params.eval_batch_size
|
| 144 |
+
self.eval_epochs = eval_params.eval_epochs
|
| 145 |
+
# assert eval_params.visualize_sample_epoch % self.eval_epochs == 0
|
| 146 |
+
self.visualize_sample_epoch = eval_params.visualize_sample_epoch
|
| 147 |
+
self.visualize_chain_epoch = eval_params.visualize_chain_epoch
|
| 148 |
+
self.sample_with_ground_truth_size = eval_params.sample_with_ground_truth_size
|
| 149 |
+
self.n_loss_per_sample = eval_params.n_loss_per_sample
|
| 150 |
+
self.n_eval_samples = eval_params.n_eval_samples
|
| 151 |
+
self.n_visualize_samples = eval_params.n_visualize_samples
|
| 152 |
+
self.keep_frames = eval_params.keep_frames
|
| 153 |
+
self.gnina = train_params.gnina
|
| 154 |
+
|
| 155 |
+
# Feature encoders/decoders
|
| 156 |
+
self.atom_encoder = atom_encoder
|
| 157 |
+
self.atom_decoder = atom_decoder
|
| 158 |
+
self.bond_encoder = bond_encoder
|
| 159 |
+
self.bond_decoder = bond_decoder
|
| 160 |
+
self.aa_encoder = aa_encoder
|
| 161 |
+
self.aa_decoder = aa_decoder
|
| 162 |
+
self.residue_encoder = residue_encoder
|
| 163 |
+
self.residue_decoder = residue_decoder
|
| 164 |
+
self.residue_bond_encoder = residue_bond_encoder
|
| 165 |
+
self.residue_bond_decoder = residue_bond_decoder
|
| 166 |
+
|
| 167 |
+
self.atom_nf = len(self.atom_decoder)
|
| 168 |
+
self.residue_nf = len(self.aa_decoder)
|
| 169 |
+
if self.pocket_representation == 'side_chain_bead':
|
| 170 |
+
self.residue_nf += len(self.residue_encoder)
|
| 171 |
+
if self.pocket_representation == 'CA+':
|
| 172 |
+
self.aa_atom_index = aa_atom_index
|
| 173 |
+
self.n_atom_aa = max([x for aa in aa_atom_index.values() for x in aa.values()]) + 1
|
| 174 |
+
self.residue_nf = (self.residue_nf, self.n_atom_aa) # (s, V)
|
| 175 |
+
self.bond_nf = len(self.bond_decoder)
|
| 176 |
+
self.pocket_bond_nf = len(self.residue_bond_decoder)
|
| 177 |
+
self.x_dim = 3
|
| 178 |
+
|
| 179 |
+
# Set up the neural network
|
| 180 |
+
self.dynamics = self.init_model(predictor_params)
|
| 181 |
+
|
| 182 |
+
# Initialize objects for each variable type
|
| 183 |
+
if simulation_params.predict_final:
|
| 184 |
+
self.module_x = CoordICFMPredictFinal(None)
|
| 185 |
+
self.module_chi = TorusICFMPredictFinal(None, 5) if self.flexible else None
|
| 186 |
+
if self.flexible_bb:
|
| 187 |
+
raise NotImplementedError()
|
| 188 |
+
else:
|
| 189 |
+
self.module_x = CoordICFM(None)
|
| 190 |
+
# self.module_chi = AngleICFM(None, 5) if self.flexible else None
|
| 191 |
+
scheduler_args = None if simulation_params.scheduler_chi is None else vars(simulation_params.scheduler_chi)
|
| 192 |
+
self.module_chi = TorusICFM(None, 5, scheduler_args) if self.flexible else None
|
| 193 |
+
self.module_trans = CoordICFM(None) if self.flexible_bb else None
|
| 194 |
+
self.module_rot = SO3ICFM(None) if self.flexible_bb else None
|
| 195 |
+
|
| 196 |
+
if simulation_params.prior_h == 'uniform':
|
| 197 |
+
self.module_h = UniformPriorMarkovBridge(self.atom_nf, loss_type=loss_params.discrete_loss)
|
| 198 |
+
elif simulation_params.prior_h == 'marginal':
|
| 199 |
+
self.register_buffer('prior_h', self.get_categorical_prop('atom')) # add to module
|
| 200 |
+
self.module_h = MarginalPriorMarkovBridge(self.atom_nf, self.prior_h, loss_type=loss_params.discrete_loss)
|
| 201 |
+
|
| 202 |
+
if simulation_params.prior_e == 'uniform':
|
| 203 |
+
self.module_e = UniformPriorMarkovBridge(self.bond_nf, loss_type=loss_params.discrete_loss)
|
| 204 |
+
elif simulation_params.prior_e == 'marginal':
|
| 205 |
+
self.register_buffer('prior_e', self.get_categorical_prop('bond')) # add to module
|
| 206 |
+
self.module_e = MarginalPriorMarkovBridge(self.bond_nf, self.prior_e, loss_type=loss_params.discrete_loss)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# Loss parameters
|
| 210 |
+
self.loss_reduce = loss_params.reduce
|
| 211 |
+
self.lambda_x = loss_params.lambda_x
|
| 212 |
+
self.lambda_h = loss_params.lambda_h
|
| 213 |
+
self.lambda_e = loss_params.lambda_e
|
| 214 |
+
self.lambda_chi = loss_params.lambda_chi if self.flexible else None
|
| 215 |
+
self.lambda_trans = loss_params.lambda_trans if self.flexible_bb else None
|
| 216 |
+
self.lambda_rot = loss_params.lambda_rot if self.flexible_bb else None
|
| 217 |
+
self.lambda_clash = loss_params.lambda_clash
|
| 218 |
+
self.regularize_uncertainty = loss_params.regularize_uncertainty
|
| 219 |
+
|
| 220 |
+
if loss_params.timestep_weights is not None:
|
| 221 |
+
weight_type = loss_params.timestep_weights.split('_')[0]
|
| 222 |
+
kwargs = loss_params.timestep_weights.split('_')[1:]
|
| 223 |
+
kwargs = {x.split('=')[0]: float(x.split('=')[1]) for x in kwargs}
|
| 224 |
+
self.timestep_weights = TimestepWeights(weight_type, **kwargs)
|
| 225 |
+
else:
|
| 226 |
+
self.timestep_weights = None
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
# Sampling
|
| 230 |
+
self.T_sampling = eval_params.n_sampling_steps
|
| 231 |
+
self.train_step_size = 1 / simulation_params.n_steps
|
| 232 |
+
self.size_distribution = None # initialized only if needed
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
# Metrics, initialized only if needed
|
| 236 |
+
self.train_smiles = None
|
| 237 |
+
self.ligand_metrics = None
|
| 238 |
+
self.molecule_properties = None
|
| 239 |
+
self.evaluator = None
|
| 240 |
+
self.ligand_atom_type_distribution = None
|
| 241 |
+
self.ligand_bond_type_distribution = None
|
| 242 |
+
|
| 243 |
+
# containers for metric aggregation
|
| 244 |
+
self.training_step_outputs = []
|
| 245 |
+
self.validation_step_outputs = []
|
| 246 |
+
|
| 247 |
+
def on_load_checkpoint(self, checkpoint):
|
| 248 |
+
"""
|
| 249 |
+
This hook is only used for backward compatibility with checkpoints that
|
| 250 |
+
did not save prior_h and prior_e in state_dict in the past
|
| 251 |
+
"""
|
| 252 |
+
if hasattr(self, "prior_h") and "prior_h" not in checkpoint["state_dict"]:
|
| 253 |
+
checkpoint["state_dict"]["prior_h"] = self.get_categorical_prop('atom')
|
| 254 |
+
if hasattr(self, "prior_e") and "prior_e" not in checkpoint["state_dict"]:
|
| 255 |
+
checkpoint["state_dict"]["prior_e"] = self.get_categorical_prop('bond')
|
| 256 |
+
if "prior_e" in checkpoint["state_dict"] and not hasattr(self, "prior_e"):
|
| 257 |
+
# NOTE: a very exotic case that happened to one model. Potentially can be removed in the future
|
| 258 |
+
self.register_buffer("prior_e", self.get_categorical_prop('bond'))
|
| 259 |
+
|
| 260 |
+
def init_model(self, predictor_params):
|
| 261 |
+
|
| 262 |
+
model_type = predictor_params.backbone
|
| 263 |
+
|
| 264 |
+
if 'heterogeneous_graph' in predictor_params and predictor_params.heterogeneous_graph:
|
| 265 |
+
return DynamicsHetero(
|
| 266 |
+
atom_nf=self.atom_nf,
|
| 267 |
+
residue_nf=self.residue_nf,
|
| 268 |
+
bond_dict=self.bond_encoder,
|
| 269 |
+
pocket_bond_dict=self.residue_bond_encoder,
|
| 270 |
+
model=model_type,
|
| 271 |
+
num_rbf_time=predictor_params.__dict__.get('num_rbf_time'),
|
| 272 |
+
model_params=getattr(predictor_params, model_type + '_params'),
|
| 273 |
+
edge_cutoff_ligand=predictor_params.edge_cutoff_ligand,
|
| 274 |
+
edge_cutoff_pocket=predictor_params.edge_cutoff_pocket,
|
| 275 |
+
edge_cutoff_interaction=predictor_params.edge_cutoff_interaction,
|
| 276 |
+
predict_angles=self.flexible,
|
| 277 |
+
predict_frames=self.flexible_bb,
|
| 278 |
+
add_cycle_counts=predictor_params.cycle_counts,
|
| 279 |
+
add_spectral_feat=predictor_params.spectral_feat,
|
| 280 |
+
add_nma_feat=predictor_params.normal_modes,
|
| 281 |
+
reflection_equiv=predictor_params.reflection_equivariant,
|
| 282 |
+
d_max=predictor_params.d_max,
|
| 283 |
+
num_rbf_dist=predictor_params.num_rbf,
|
| 284 |
+
self_conditioning=predictor_params.self_conditioning,
|
| 285 |
+
augment_residue_sc=self.augment_residue_sc,
|
| 286 |
+
augment_ligand_sc=self.augment_ligand_sc,
|
| 287 |
+
add_chi_as_feature=predictor_params.add_chi_as_feature,
|
| 288 |
+
angle_act_fn=predictor_params.angle_act_fn,
|
| 289 |
+
add_all_atom_diff=predictor_params.add_all_atom_diff,
|
| 290 |
+
predict_confidence=self.predict_confidence,
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
else:
|
| 294 |
+
if predictor_params.__dict__.get('num_rbf_time') is not None:
|
| 295 |
+
raise NotImplementedError("RBF time embedding not yet implemented")
|
| 296 |
+
|
| 297 |
+
return Dynamics(
|
| 298 |
+
atom_nf=self.atom_nf,
|
| 299 |
+
residue_nf=self.residue_nf,
|
| 300 |
+
joint_nf=predictor_params.joint_nf,
|
| 301 |
+
bond_dict=self.bond_encoder,
|
| 302 |
+
pocket_bond_dict=self.residue_bond_encoder,
|
| 303 |
+
edge_nf=predictor_params.edge_nf,
|
| 304 |
+
hidden_nf=predictor_params.hidden_nf,
|
| 305 |
+
model=model_type,
|
| 306 |
+
model_params=getattr(predictor_params, model_type + '_params'),
|
| 307 |
+
edge_cutoff_ligand=predictor_params.edge_cutoff_ligand,
|
| 308 |
+
edge_cutoff_pocket=predictor_params.edge_cutoff_pocket,
|
| 309 |
+
edge_cutoff_interaction=predictor_params.edge_cutoff_interaction,
|
| 310 |
+
predict_angles=self.flexible,
|
| 311 |
+
predict_frames=self.flexible_bb,
|
| 312 |
+
add_cycle_counts=predictor_params.cycle_counts,
|
| 313 |
+
add_spectral_feat=predictor_params.spectral_feat,
|
| 314 |
+
add_nma_feat=predictor_params.normal_modes,
|
| 315 |
+
self_conditioning=predictor_params.self_conditioning,
|
| 316 |
+
augment_residue_sc=self.augment_residue_sc,
|
| 317 |
+
augment_ligand_sc=self.augment_ligand_sc,
|
| 318 |
+
add_chi_as_feature=predictor_params.add_chi_as_feature,
|
| 319 |
+
angle_act_fn=predictor_params.angle_act_fn,
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
def _load_histogram(self, type):
|
| 323 |
+
"""
|
| 324 |
+
Load empirical categorical distributions of atom or bond types from disk.
|
| 325 |
+
Returns None if the required file is not found.
|
| 326 |
+
"""
|
| 327 |
+
assert type in {"atom", "bond"}
|
| 328 |
+
filename = 'ligand_type_histogram.npy' if type == 'atom' else 'ligand_bond_type_histogram.npy'
|
| 329 |
+
encoder = self.atom_encoder if type == 'atom' else self.bond_encoder
|
| 330 |
+
hist_file = Path(self.datadir, filename)
|
| 331 |
+
if not hist_file.exists():
|
| 332 |
+
return None
|
| 333 |
+
hist = np.load(hist_file, allow_pickle=True).item()
|
| 334 |
+
return CategoricalDistribution(hist, encoder)
|
| 335 |
+
|
| 336 |
+
def get_categorical_prop(self, type):
|
| 337 |
+
hist = self._load_histogram(type)
|
| 338 |
+
encoder = self.atom_encoder if type == 'atom' else self.bond_encoder
|
| 339 |
+
# Note: default value ensures that code will crash if prior is not
|
| 340 |
+
# read from disk or loaded from checkpoint later on
|
| 341 |
+
return torch.zeros(len(encoder)) * float("nan") if hist is None else torch.tensor(hist.p)
|
| 342 |
+
|
| 343 |
+
def configure_optimizers(self):
|
| 344 |
+
optimizers = [
|
| 345 |
+
torch.optim.AdamW(self.parameters(), lr=self.lr, amsgrad=True, weight_decay=1e-12),
|
| 346 |
+
]
|
| 347 |
+
|
| 348 |
+
if self.lr_step_size is None or self.lr_gamma is None:
|
| 349 |
+
lr_schedulers = []
|
| 350 |
+
else:
|
| 351 |
+
lr_schedulers = [
|
| 352 |
+
torch.optim.lr_scheduler.StepLR(optimizers[0], step_size=self.lr_step_size, gamma=self.lr_gamma),
|
| 353 |
+
]
|
| 354 |
+
return optimizers, lr_schedulers
|
| 355 |
+
|
| 356 |
+
def setup(self, stage: Optional[str] = None):
|
| 357 |
+
|
| 358 |
+
self.setup_sampling()
|
| 359 |
+
|
| 360 |
+
if stage == 'fit':
|
| 361 |
+
self.train_dataset = self.get_dataset(stage='train')
|
| 362 |
+
self.val_dataset = self.get_dataset(stage='val')
|
| 363 |
+
self.setup_metrics()
|
| 364 |
+
elif stage == 'val':
|
| 365 |
+
self.val_dataset = self.get_dataset(stage='val')
|
| 366 |
+
self.setup_metrics()
|
| 367 |
+
elif stage == 'test':
|
| 368 |
+
self.test_dataset = self.get_dataset(stage='test')
|
| 369 |
+
self.setup_metrics()
|
| 370 |
+
elif stage == 'generation':
|
| 371 |
+
pass
|
| 372 |
+
else:
|
| 373 |
+
raise NotImplementedError
|
| 374 |
+
|
| 375 |
+
def get_dataset(self, stage, pocket_transform=None):
|
| 376 |
+
|
| 377 |
+
# when sampling we don't append virtual nodes as we might need access to the ground truth size
|
| 378 |
+
if self.virtual_nodes and stage == "train":
|
| 379 |
+
ligand_transform = AppendVirtualNodesInCoM(
|
| 380 |
+
atom_encoder, bond_encoder, add_min=self.add_virtual_min, add_max=self.add_virtual_max)
|
| 381 |
+
else:
|
| 382 |
+
ligand_transform = None
|
| 383 |
+
|
| 384 |
+
# we want to know if something goes wrong on the validation or test set
|
| 385 |
+
catch_errors = stage == "train"
|
| 386 |
+
|
| 387 |
+
if self.sharded_dataset:
|
| 388 |
+
return get_wds(
|
| 389 |
+
data_path=self.datadir,
|
| 390 |
+
stage='val' if self.debug else stage,
|
| 391 |
+
ligand_transform=ligand_transform,
|
| 392 |
+
pocket_transform=pocket_transform,
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
if self.sample_from_clusters and stage == "train": # val/test should be deterministic
|
| 396 |
+
return ClusteredDataset(
|
| 397 |
+
pt_path=Path(self.datadir, 'val.pt' if self.debug else f'{stage}.pt'),
|
| 398 |
+
ligand_transform=ligand_transform,
|
| 399 |
+
pocket_transform=pocket_transform,
|
| 400 |
+
catch_errors=catch_errors
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
return ProcessedLigandPocketDataset(
|
| 404 |
+
pt_path=Path(self.datadir, 'val.pt' if self.debug else f'{stage}.pt'),
|
| 405 |
+
ligand_transform=ligand_transform,
|
| 406 |
+
pocket_transform=pocket_transform,
|
| 407 |
+
catch_errors=catch_errors
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
def setup_sampling(self):
|
| 411 |
+
# distribution of nodes
|
| 412 |
+
histogram_file = Path(self.datadir, 'size_distribution.npy') # TODO: store this in model checkpoint so that we can sample without this file
|
| 413 |
+
size_histogram = np.load(histogram_file).tolist()
|
| 414 |
+
self.size_distribution = DistributionNodes(size_histogram)
|
| 415 |
+
|
| 416 |
+
def setup_metrics(self):
|
| 417 |
+
# For metrics
|
| 418 |
+
smiles_file = Path(self.datadir, 'train_smiles.npy')
|
| 419 |
+
self.train_smiles = None if not smiles_file.exists() else np.load(smiles_file)
|
| 420 |
+
|
| 421 |
+
self.ligand_metrics = MoleculeValidity()
|
| 422 |
+
self.molecule_properties = MolecularProperties()
|
| 423 |
+
self.evaluator = FullEvaluator(gnina=self.gnina, exclude_evaluators=['geometry', 'ring_count'])
|
| 424 |
+
self.ligand_atom_type_distribution = self._load_histogram('atom')
|
| 425 |
+
self.ligand_bond_type_distribution = self._load_histogram('bond')
|
| 426 |
+
|
| 427 |
+
def train_dataloader(self):
|
| 428 |
+
shuffle = None if self.overfit else False if self.sharded_dataset else True
|
| 429 |
+
return DataLoader(self.train_dataset, self.batch_size, shuffle=shuffle,
|
| 430 |
+
sampler=SubsetRandomSampler([0]) if self.overfit else None,
|
| 431 |
+
num_workers=self.num_workers,
|
| 432 |
+
collate_fn=self.train_dataset.collate_fn,
|
| 433 |
+
# collate_fn=partial(self.train_dataset.collate_fn, ligand_transform=batch_transform),
|
| 434 |
+
pin_memory=True)
|
| 435 |
+
|
| 436 |
+
def val_dataloader(self):
|
| 437 |
+
if self.overfit:
|
| 438 |
+
return self.train_dataloader()
|
| 439 |
+
|
| 440 |
+
return DataLoader(self.val_dataset, self.eval_batch_size,
|
| 441 |
+
shuffle=False, num_workers=self.num_workers,
|
| 442 |
+
collate_fn=self.val_dataset.collate_fn,
|
| 443 |
+
pin_memory=True)
|
| 444 |
+
|
| 445 |
+
def test_dataloader(self):
|
| 446 |
+
return DataLoader(self.test_dataset, self.eval_batch_size, shuffle=False,
|
| 447 |
+
num_workers=self.num_workers,
|
| 448 |
+
collate_fn=self.test_dataset.collate_fn,
|
| 449 |
+
pin_memory=True)
|
| 450 |
+
|
| 451 |
+
def log_metrics(self, metrics_dict, split, batch_size=None, **kwargs):
|
| 452 |
+
for m, value in metrics_dict.items():
|
| 453 |
+
self.log(f'{m}/{split}', value, batch_size=batch_size, **kwargs)
|
| 454 |
+
|
| 455 |
+
def aggregate_metrics(self, step_outputs, prefix):
|
| 456 |
+
if 'timestep' in step_outputs[0]:
|
| 457 |
+
timesteps = torch.cat([x['timestep'] for x in step_outputs]).squeeze()
|
| 458 |
+
|
| 459 |
+
if 'loss_per_sample' in step_outputs[0]:
|
| 460 |
+
losses = torch.cat([x['loss_per_sample'] for x in step_outputs])
|
| 461 |
+
pearson_corr = torch.corrcoef(torch.stack([timesteps, losses], dim=0))[0, 1]
|
| 462 |
+
self.log(f'corr_loss_timestep/{prefix}', pearson_corr, prog_bar=False)
|
| 463 |
+
|
| 464 |
+
if 'eps_hat_norm' in step_outputs[0]:
|
| 465 |
+
eps_norm = torch.cat([x['eps_hat_norm'] for x in step_outputs])
|
| 466 |
+
pearson_corr = torch.corrcoef(torch.stack([timesteps, eps_norm], dim=0))[0, 1]
|
| 467 |
+
self.log(f'corr_eps_timestep/{prefix}', pearson_corr, prog_bar=False)
|
| 468 |
+
|
| 469 |
+
def on_train_epoch_end(self):
|
| 470 |
+
self.aggregate_metrics(self.training_step_outputs, 'train')
|
| 471 |
+
self.training_step_outputs.clear()
|
| 472 |
+
|
| 473 |
+
# TODO: doesn't work in multi-GPU mode
|
| 474 |
+
# def on_before_batch_transfer(self, batch, dataloader_idx):
|
| 475 |
+
# """
|
| 476 |
+
# Performs operations on data before it is transferred to the GPU.
|
| 477 |
+
# Hence, supports multiple dataloaders for speedup.
|
| 478 |
+
# """
|
| 479 |
+
# batch['pocket'] = Residues(**batch['pocket'])
|
| 480 |
+
# return batch
|
| 481 |
+
|
| 482 |
+
# # TODO: try if this is compatible with DDP
|
| 483 |
+
# def on_after_batch_transfer(self, batch, dataloader_idx):
|
| 484 |
+
# """
|
| 485 |
+
# Performs operations on data after it is transferred to the GPU.
|
| 486 |
+
# """
|
| 487 |
+
# batch['pocket'] = Residues(**batch['pocket'])
|
| 488 |
+
# batch['ligand'] = TensorDict(**batch['ligand'])
|
| 489 |
+
# return batch
|
| 490 |
+
|
| 491 |
+
def get_sc_transform_fn(self, zt_chi, zt_x, t, z0_chi, ligand_mask, pocket):
|
| 492 |
+
sc_transform = {}
|
| 493 |
+
|
| 494 |
+
if self.augment_residue_sc:
|
| 495 |
+
def pred_all_atom(pred_chi, pred_trans=None, pred_rot=None):
|
| 496 |
+
temp_pocket = pocket.deepcopy()
|
| 497 |
+
|
| 498 |
+
if pred_trans is not None and pred_rot is not None:
|
| 499 |
+
zt_trans = pocket['x']
|
| 500 |
+
zt_rot = pocket['axis_angle']
|
| 501 |
+
z1_trans_pred = self.module_trans.get_z1_given_zt_and_pred(
|
| 502 |
+
zt_trans, pred_trans, None, t, pocket['mask'])
|
| 503 |
+
z1_rot_pred = self.module_rot.get_z1_given_zt_and_pred(
|
| 504 |
+
zt_rot, pred_rot, None, t, pocket['mask'])
|
| 505 |
+
temp_pocket.set_frame(z1_trans_pred, z1_rot_pred)
|
| 506 |
+
|
| 507 |
+
z1_chi_pred = self.module_chi.get_z1_given_zt_and_pred(
|
| 508 |
+
zt_chi[..., :5], pred_chi, z0_chi, t, pocket['mask'])
|
| 509 |
+
temp_pocket.set_chi(z1_chi_pred)
|
| 510 |
+
|
| 511 |
+
all_coord = temp_pocket['v'] + temp_pocket['x'].unsqueeze(1)
|
| 512 |
+
return all_coord - pocket['x'].unsqueeze(1)
|
| 513 |
+
|
| 514 |
+
sc_transform['residues'] = pred_all_atom
|
| 515 |
+
|
| 516 |
+
if self.augment_ligand_sc:
|
| 517 |
+
# sc_transform['atoms'] = partial(self.module_x.get_z1_given_zt_and_pred, zt=zs_x, z0=None, t=t, batch_mask=lig_mask)
|
| 518 |
+
sc_transform['atoms'] = lambda pred: (self.module_x.get_z1_given_zt_and_pred(
|
| 519 |
+
zt_x, pred.squeeze(1), None, t, ligand_mask) - zt_x).unsqueeze(1)
|
| 520 |
+
|
| 521 |
+
return sc_transform
|
| 522 |
+
|
| 523 |
+
def compute_loss(self, ligand, pocket, return_info=False):
|
| 524 |
+
"""
|
| 525 |
+
Samples time steps and computes network predictions
|
| 526 |
+
"""
|
| 527 |
+
# TODO: move somewhere else (like collate_fn)
|
| 528 |
+
pocket = Residues(**pocket)
|
| 529 |
+
|
| 530 |
+
# Center sample
|
| 531 |
+
ligand, pocket = center_data(ligand, pocket)
|
| 532 |
+
if pocket['x'].numel() > 0:
|
| 533 |
+
pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0)
|
| 534 |
+
else:
|
| 535 |
+
pocket_com = scatter_mean(ligand['x'], ligand['mask'], dim=0)
|
| 536 |
+
|
| 537 |
+
# # Normalize pocket coordinates
|
| 538 |
+
# pocket['x'] = self.module_x.normalize(pocket['x'])
|
| 539 |
+
|
| 540 |
+
# Sample a timestep t for each example in batch
|
| 541 |
+
t = torch.rand(ligand['size'].size(0), device=ligand['x'].device).unsqueeze(-1)
|
| 542 |
+
|
| 543 |
+
# Noise
|
| 544 |
+
z0_x = self.module_x.sample_z0(pocket_com, ligand['mask'])
|
| 545 |
+
z0_h = self.module_h.sample_z0(ligand['mask'])
|
| 546 |
+
z0_e = self.module_e.sample_z0(ligand['bond_mask'])
|
| 547 |
+
zt_x = self.module_x.sample_zt(z0_x, ligand['x'], t, ligand['mask'])
|
| 548 |
+
zt_h = self.module_h.sample_zt(z0_h, ligand['one_hot'], t, ligand['mask'])
|
| 549 |
+
zt_e = self.module_e.sample_zt(z0_e, ligand['bond_one_hot'], t, ligand['bond_mask'])
|
| 550 |
+
|
| 551 |
+
if self.flexible_bb:
|
| 552 |
+
z0_trans = self.module_trans.sample_z0(pocket_com, pocket['mask'])
|
| 553 |
+
z1_trans = pocket['x'].detach().clone()
|
| 554 |
+
zt_trans = self.module_trans.sample_zt(z0_trans, z1_trans, t, pocket['mask'])
|
| 555 |
+
|
| 556 |
+
z0_rot = self.module_rot.sample_z0(pocket['mask'])
|
| 557 |
+
z1_rot = pocket['axis_angle'].detach().clone()
|
| 558 |
+
zt_rot = self.module_rot.sample_zt(z0_rot, z1_rot, t, pocket['mask'])
|
| 559 |
+
|
| 560 |
+
# update pocket
|
| 561 |
+
pocket.set_frame(zt_trans, zt_rot)
|
| 562 |
+
|
| 563 |
+
z0_chi, zt_chi = None, None
|
| 564 |
+
if self.flexible:
|
| 565 |
+
# residues = [data_utils.residue_from_internal_coord(ic) for ic in pocket['residues']]
|
| 566 |
+
# residues = pocket['residues']
|
| 567 |
+
# z1_chi = torch.stack([data_utils.get_torsion_angles(r, device=self.device) for r in residues], dim=0)
|
| 568 |
+
z1_chi = pocket['chi'][:, :5].detach().clone()
|
| 569 |
+
|
| 570 |
+
z0_chi = self.module_chi.sample_z0(pocket['mask'])
|
| 571 |
+
zt_chi = self.module_chi.sample_zt(z0_chi, z1_chi, t, pocket['mask'])
|
| 572 |
+
|
| 573 |
+
# internal to external coordinates
|
| 574 |
+
pocket.set_chi(zt_chi)
|
| 575 |
+
if pocket['x'].numel() == 0:
|
| 576 |
+
pocket.set_empty_v()
|
| 577 |
+
|
| 578 |
+
# Predict denoising
|
| 579 |
+
sc_transform = self.get_sc_transform_fn(zt_chi, zt_x, t, z0_chi, ligand['mask'], pocket)
|
| 580 |
+
# sc_transform = None
|
| 581 |
+
pred_ligand, pred_residues = self.dynamics(
|
| 582 |
+
zt_x, zt_h, ligand['mask'], pocket, t,
|
| 583 |
+
bonds_ligand=(ligand['bonds'], zt_e), sc_transform=sc_transform
|
| 584 |
+
)
|
| 585 |
+
|
| 586 |
+
# Compute L2 loss
|
| 587 |
+
if self.predict_confidence:
|
| 588 |
+
loss_x = self.module_x.compute_loss(pred_ligand['vel'], z0_x, ligand['x'], t, ligand['mask'], reduce='none')
|
| 589 |
+
|
| 590 |
+
# compute confidence regularization
|
| 591 |
+
k = self.module_x.dim # pred.size(-1)
|
| 592 |
+
sigma = pred_ligand['uncertainty_vel']
|
| 593 |
+
loss_x = loss_x / (2 * sigma ** 2) + k * torch.log(sigma)
|
| 594 |
+
|
| 595 |
+
if self.regularize_uncertainty is not None:
|
| 596 |
+
loss_x = loss_x + self.regularize_uncertainty * (pred_ligand['uncertainty_vel'] - 1) ** 2
|
| 597 |
+
|
| 598 |
+
loss_x = self.module_x.reduce_loss(loss_x, ligand['mask'], reduce=self.loss_reduce)
|
| 599 |
+
else:
|
| 600 |
+
loss_x = self.module_x.compute_loss(pred_ligand['vel'], z0_x, ligand['x'], t, ligand['mask'], reduce=self.loss_reduce)
|
| 601 |
+
|
| 602 |
+
# Loss for categorical variables
|
| 603 |
+
t_next = torch.clamp(t + self.train_step_size, max=1.0)
|
| 604 |
+
loss_h = self.module_h.compute_loss(pred_ligand['logits_h'], zt_h, ligand['one_hot'], ligand['mask'], t, t_next, reduce=self.loss_reduce)
|
| 605 |
+
loss_e = self.module_e.compute_loss(pred_ligand['logits_e'], zt_e, ligand['bond_one_hot'], ligand['bond_mask'], t, t_next, reduce=self.loss_reduce)
|
| 606 |
+
|
| 607 |
+
loss = self.lambda_x * loss_x + self.lambda_h * loss_h + self.lambda_e * loss_e
|
| 608 |
+
if self.flexible:
|
| 609 |
+
loss_chi = self.module_chi.compute_loss(pred_residues['chi'], z0_chi, z1_chi, zt_chi, t, pocket['mask'], reduce=self.loss_reduce)
|
| 610 |
+
loss = loss + self.lambda_chi * loss_chi
|
| 611 |
+
|
| 612 |
+
if self.flexible_bb:
|
| 613 |
+
loss_trans = self.module_trans.compute_loss(pred_residues['trans'], z0_trans, z1_trans, t, pocket['mask'], reduce=self.loss_reduce)
|
| 614 |
+
loss_rot = self.module_rot.compute_loss(pred_residues['rot'], z0_rot, z1_rot, zt_rot, t, pocket['mask'], reduce=self.loss_reduce)
|
| 615 |
+
loss = loss + self.lambda_trans * loss_trans + self.lambda_rot * loss_rot
|
| 616 |
+
|
| 617 |
+
if self.lambda_clash is not None and self.lambda_clash > 0:
|
| 618 |
+
|
| 619 |
+
if self.flexible_bb:
|
| 620 |
+
pred_z1_trans = self.module_trans.get_z1_given_zt_and_pred(zt_trans, pred_residues['trans'], z0_trans, t, pocket['mask'])
|
| 621 |
+
pred_z1_rot = self.module_rot.get_z1_given_zt_and_pred(zt_rot, pred_residues['rot'], z0_rot, t, pocket['mask'])
|
| 622 |
+
pocket.set_frame(pred_z1_trans, pred_z1_rot)
|
| 623 |
+
|
| 624 |
+
if self.flexible:
|
| 625 |
+
# internal to external coordinates
|
| 626 |
+
pred_z1_chi = self.module_chi.get_z1_given_zt_and_pred(zt_chi, pred_residues['chi'], z0_chi, t, pocket['mask'])
|
| 627 |
+
pocket.set_chi(pred_z1_chi)
|
| 628 |
+
|
| 629 |
+
pocket_coord = pocket['x'].unsqueeze(1) + pocket['v']
|
| 630 |
+
pocket_types = aa_atom_type_tensor[pocket['one_hot'].argmax(dim=-1)]
|
| 631 |
+
pocket_mask = pocket['mask'].unsqueeze(-1).repeat((1, pocket['v'].size(1)))
|
| 632 |
+
|
| 633 |
+
# Extract only existing atoms
|
| 634 |
+
atom_mask = aa_atom_mask_tensor[pocket['one_hot'].argmax(dim=-1)]
|
| 635 |
+
pocket_coord = pocket_coord[atom_mask]
|
| 636 |
+
pocket_types = pocket_types[atom_mask]
|
| 637 |
+
pocket_mask = pocket_mask[atom_mask]
|
| 638 |
+
|
| 639 |
+
# pred_z1_x = pred_x + z0_x
|
| 640 |
+
pred_z1_x = self.module_x.get_z1_given_zt_and_pred(zt_x, pred_ligand['vel'], z0_x, t, ligand['mask'])
|
| 641 |
+
pred_z1_h = pred_ligand['logits_h'].argmax(dim=-1)
|
| 642 |
+
loss_clash = clash_loss(pred_z1_x, pred_z1_h, ligand['mask'],
|
| 643 |
+
pocket_coord, pocket_types, pocket_mask)
|
| 644 |
+
loss = loss + self.lambda_clash * loss_clash
|
| 645 |
+
|
| 646 |
+
if self.timestep_weights is not None:
|
| 647 |
+
w_t = self.timestep_weights(t).squeeze()
|
| 648 |
+
loss = w_t * loss
|
| 649 |
+
|
| 650 |
+
loss = loss.mean(0)
|
| 651 |
+
|
| 652 |
+
info = {
|
| 653 |
+
'loss_x': loss_x.mean().item(),
|
| 654 |
+
'loss_h': loss_h.mean().item(),
|
| 655 |
+
'loss_e': loss_e.mean().item(),
|
| 656 |
+
}
|
| 657 |
+
if self.flexible:
|
| 658 |
+
info['loss_chi'] = loss_chi.mean().item()
|
| 659 |
+
if self.flexible_bb:
|
| 660 |
+
info['loss_trans'] = loss_trans.mean().item()
|
| 661 |
+
info['loss_rot'] = loss_rot.mean().item()
|
| 662 |
+
if self.lambda_clash is not None:
|
| 663 |
+
info['loss_clash'] = loss_clash.mean().item()
|
| 664 |
+
if self.predict_confidence:
|
| 665 |
+
sigma_x_mol = scatter_mean(pred_ligand['uncertainty_vel'], ligand['mask'], dim=0)
|
| 666 |
+
info['pearson_sigma_x'] = torch.corrcoef(torch.stack([sigma_x_mol.detach(), t.squeeze()]))[0, 1].item()
|
| 667 |
+
info['mean_sigma_x'] = sigma_x_mol.mean().item()
|
| 668 |
+
entropy_h = Categorical(logits=pred_ligand['logits_h']).entropy()
|
| 669 |
+
entropy_h_mol = scatter_mean(entropy_h, ligand['mask'], dim=0)
|
| 670 |
+
info['pearson_entropy_h'] = torch.corrcoef(torch.stack([entropy_h_mol.detach(), t.squeeze()]))[0, 1].item()
|
| 671 |
+
info['mean_entropy_h'] = entropy_h_mol.mean().item()
|
| 672 |
+
entropy_e = Categorical(logits=pred_ligand['logits_e']).entropy()
|
| 673 |
+
entropy_e_mol = scatter_mean(entropy_e, ligand['bond_mask'], dim=0)
|
| 674 |
+
info['pearson_entropy_e'] = torch.corrcoef(torch.stack([entropy_e_mol.detach(), t.squeeze()]))[0, 1].item()
|
| 675 |
+
info['mean_entropy_e'] = entropy_e_mol.mean().item()
|
| 676 |
+
|
| 677 |
+
return (loss, info) if return_info else loss
|
| 678 |
+
|
| 679 |
+
def training_step(self, data, *args):
|
| 680 |
+
ligand, pocket = data['ligand'], data['pocket']
|
| 681 |
+
try:
|
| 682 |
+
loss, info = self.compute_loss(ligand, pocket, return_info=True)
|
| 683 |
+
except RuntimeError as e:
|
| 684 |
+
# this is not supported for multi-GPU
|
| 685 |
+
if self.trainer.num_devices < 2 and 'out of memory' in str(e):
|
| 686 |
+
print('WARNING: ran out of memory, skipping to the next batch')
|
| 687 |
+
return None
|
| 688 |
+
else:
|
| 689 |
+
raise e
|
| 690 |
+
|
| 691 |
+
log_dict = {k: v for k, v in info.items() if isinstance(v, float)
|
| 692 |
+
or torch.numel(v) <= 1}
|
| 693 |
+
# if self.learn_nu:
|
| 694 |
+
# log_dict['nu_x'] = self.noise_schedules['x'].nu.item()
|
| 695 |
+
# log_dict['nu_h'] = self.noise_schedules['h'].nu.item()
|
| 696 |
+
# log_dict['nu_e'] = self.noise_schedules['e'].nu.item()
|
| 697 |
+
|
| 698 |
+
self.log_metrics({'loss': loss, **log_dict}, 'train',
|
| 699 |
+
batch_size=len(ligand['size']))
|
| 700 |
+
|
| 701 |
+
out = {'loss': loss, **info}
|
| 702 |
+
self.training_step_outputs.append(out)
|
| 703 |
+
return out
|
| 704 |
+
|
| 705 |
+
def validation_step(self, data, *args):
|
| 706 |
+
|
| 707 |
+
# Compute the loss N times and average to get a better estimate
|
| 708 |
+
loss_list, info_list = [], []
|
| 709 |
+
self.dynamics.train() # TODO: this is currently necessary to make self-conditioning work
|
| 710 |
+
for _ in range(self.n_loss_per_sample):
|
| 711 |
+
loss, info = self.compute_loss(data['ligand'].copy(),
|
| 712 |
+
data['pocket'].copy(),
|
| 713 |
+
return_info=True)
|
| 714 |
+
loss_list.append(loss.item())
|
| 715 |
+
info_list.append(info)
|
| 716 |
+
self.dynamics.eval()
|
| 717 |
+
if len(loss_list) >= 1:
|
| 718 |
+
loss = np.mean(loss_list)
|
| 719 |
+
info = {k: np.mean([x[k] for x in info_list]) for k in info_list[0]}
|
| 720 |
+
self.log_metrics({'loss': loss, **info}, 'val', batch_size=len(data['ligand']['size']))
|
| 721 |
+
|
| 722 |
+
# Sample
|
| 723 |
+
rdmols, rdpockets, _ = self.sample(
|
| 724 |
+
data=data,
|
| 725 |
+
n_samples=self.n_eval_samples,
|
| 726 |
+
num_nodes="ground_truth" if self.sample_with_ground_truth_size else None,
|
| 727 |
+
)
|
| 728 |
+
|
| 729 |
+
out = {
|
| 730 |
+
'ligands': rdmols,
|
| 731 |
+
'pockets': rdpockets,
|
| 732 |
+
'receptor_files': [Path(self.receptor_dir, 'val', x) for x in data['pocket']['name']]
|
| 733 |
+
}
|
| 734 |
+
self.validation_step_outputs.append(out)
|
| 735 |
+
return out
|
| 736 |
+
|
| 737 |
+
# def test_step(self, data, *args):
|
| 738 |
+
# self._shared_eval(data, 'test', *args)
|
| 739 |
+
|
| 740 |
+
def on_validation_epoch_end(self):
|
| 741 |
+
|
| 742 |
+
outdir = Path(self.outdir, f'epoch_{self.current_epoch}')
|
| 743 |
+
|
| 744 |
+
rdmols = [m for x in self.validation_step_outputs for m in x['ligands']]
|
| 745 |
+
rdpockets = [p for x in self.validation_step_outputs for p in x['pockets']]
|
| 746 |
+
receptors = [r for x in self.validation_step_outputs for r in x['receptor_files']]
|
| 747 |
+
self.validation_step_outputs.clear()
|
| 748 |
+
|
| 749 |
+
ligand_atom_types = [atom_encoder[a.GetSymbol()] for m in rdmols for a in m.GetAtoms()]
|
| 750 |
+
ligand_bond_types = []
|
| 751 |
+
for m in rdmols:
|
| 752 |
+
bonds = m.GetBonds()
|
| 753 |
+
no_bonds = m.GetNumAtoms() * (m.GetNumAtoms() - 1) // 2 - m.GetNumBonds()
|
| 754 |
+
ligand_bond_types += [bond_encoder['NOBOND']] * no_bonds
|
| 755 |
+
for b in bonds:
|
| 756 |
+
ligand_bond_types.append(bond_encoder[b.GetBondType().name])
|
| 757 |
+
|
| 758 |
+
tic = time()
|
| 759 |
+
results = self.analyze_sample(
|
| 760 |
+
rdmols, ligand_atom_types, ligand_bond_types, receptors=(rdpockets if len(rdpockets) != 0 else None)
|
| 761 |
+
)
|
| 762 |
+
self.log_metrics(results, 'val')
|
| 763 |
+
print(f'Evaluation took {time() - tic:.2f} seconds')
|
| 764 |
+
|
| 765 |
+
if (self.current_epoch + 1) % self.visualize_sample_epoch == 0:
|
| 766 |
+
tic = time()
|
| 767 |
+
|
| 768 |
+
outdir.mkdir(exist_ok=True, parents=True)
|
| 769 |
+
|
| 770 |
+
# center for better visualization
|
| 771 |
+
rdmols = rdmols[:self.n_visualize_samples]
|
| 772 |
+
rdpockets = rdpockets[:self.n_visualize_samples]
|
| 773 |
+
for m, p in zip(rdmols, rdpockets):
|
| 774 |
+
center = m.GetConformer().GetPositions().mean(axis=0)
|
| 775 |
+
for i in range(m.GetNumAtoms()):
|
| 776 |
+
x, y, z = m.GetConformer().GetPositions()[i] - center
|
| 777 |
+
m.GetConformer().SetAtomPosition(i, (x, y, z))
|
| 778 |
+
for i in range(p.GetNumAtoms()):
|
| 779 |
+
x, y, z = p.GetConformer().GetPositions()[i] - center
|
| 780 |
+
p.GetConformer().SetAtomPosition(i, (x, y, z))
|
| 781 |
+
|
| 782 |
+
# save molecule
|
| 783 |
+
utils.write_sdf_file(Path(outdir, 'molecules.sdf'), rdmols)
|
| 784 |
+
|
| 785 |
+
# save pocket
|
| 786 |
+
utils.write_sdf_file(Path(outdir, 'pockets.sdf'), rdpockets)
|
| 787 |
+
|
| 788 |
+
print(f'Sample visualization took {time() - tic:.2f} seconds')
|
| 789 |
+
|
| 790 |
+
if (self.current_epoch + 1) % self.visualize_chain_epoch == 0:
|
| 791 |
+
tic = time()
|
| 792 |
+
outdir.mkdir(exist_ok=True, parents=True)
|
| 793 |
+
|
| 794 |
+
if self.sharded_dataset:
|
| 795 |
+
index = torch.randint(len(self.val_dataset), size=(1,)).item()
|
| 796 |
+
for i, x in enumerate(self.val_dataset):
|
| 797 |
+
if i == index:
|
| 798 |
+
break
|
| 799 |
+
batch = self.val_dataset.collate_fn([x])
|
| 800 |
+
else:
|
| 801 |
+
batch = self.val_dataset.collate_fn([self.val_dataset[torch.randint(len(self.val_dataset), size=(1,))]])
|
| 802 |
+
batch['pocket'] = Residues(**batch['pocket']).to(self.device)
|
| 803 |
+
pocket_copy = batch['pocket'].copy()
|
| 804 |
+
|
| 805 |
+
if len(batch['pocket']['x']) > 0:
|
| 806 |
+
ligand_chain, pocket_chain, info = self.sample_chain(batch['pocket'], self.keep_frames)
|
| 807 |
+
else:
|
| 808 |
+
num_nodes, _ = self.size_distribution.sample()
|
| 809 |
+
ligand_chain, pocket_chain, info = self.sample_chain(batch['pocket'], self.keep_frames, num_nodes=num_nodes)
|
| 810 |
+
|
| 811 |
+
# utils.write_sdf_file(Path(outdir, 'chain_pocket.sdf'), pocket_chain)
|
| 812 |
+
# utils.write_chain(Path(outdir, 'chain_pocket.xyz'), pocket_chain)
|
| 813 |
+
if self.flexible or self.flexible_bb:
|
| 814 |
+
# insert ground truth at the beginning so that it's used by PyMOL to determine the connectivity
|
| 815 |
+
ground_truth_pocket = pocket_to_rdkit(
|
| 816 |
+
pocket_copy, self.pocket_representation,
|
| 817 |
+
self.atom_encoder, self.atom_decoder,
|
| 818 |
+
self.aa_decoder, self.residue_decoder,
|
| 819 |
+
self.aa_atom_index
|
| 820 |
+
)[0]
|
| 821 |
+
ground_truth_ligand = build_molecule(
|
| 822 |
+
batch['ligand']['x'], batch['ligand']['one_hot'].argmax(1),
|
| 823 |
+
bonds=batch['ligand']['bonds'],
|
| 824 |
+
bond_types=batch['ligand']['bond_one_hot'].argmax(1),
|
| 825 |
+
atom_decoder=self.atom_decoder,
|
| 826 |
+
bond_decoder=self.bond_decoder
|
| 827 |
+
)
|
| 828 |
+
pocket_chain.insert(0, ground_truth_pocket)
|
| 829 |
+
ligand_chain.insert(0, ground_truth_ligand)
|
| 830 |
+
# pocket_chain.insert(0, pocket_chain[-1])
|
| 831 |
+
# ligand_chain.insert(0, ligand_chain[-1])
|
| 832 |
+
|
| 833 |
+
# save molecules
|
| 834 |
+
utils.write_sdf_file(Path(outdir, 'chain_ligand.sdf'), ligand_chain)
|
| 835 |
+
|
| 836 |
+
# save pocket
|
| 837 |
+
mols_to_pdbfile(pocket_chain, Path(outdir, 'chain_pocket.pdb'))
|
| 838 |
+
|
| 839 |
+
self.log_metrics(info, 'val')
|
| 840 |
+
print(f'Chain visualization took {time() - tic:.2f} seconds')
|
| 841 |
+
|
| 842 |
+
|
| 843 |
+
# NOTE: temporary fix of this Lightning bug:
|
| 844 |
+
# https://github.com/Lightning-AI/pytorch-lightning/discussions/18110
|
| 845 |
+
# Without it resume training has a strange behavior and fails
|
| 846 |
+
@property
|
| 847 |
+
def total_batch_idx(self) -> int:
|
| 848 |
+
"""Returns the current batch index (across epochs)"""
|
| 849 |
+
# use `ready` instead of `completed` in case this is accessed after `completed` has been increased
|
| 850 |
+
# but before the next `ready` increase
|
| 851 |
+
return max(0, self.batch_progress.total.ready - 1)
|
| 852 |
+
|
| 853 |
+
@property
|
| 854 |
+
def batch_idx(self) -> int:
|
| 855 |
+
"""Returns the current batch index (within this epoch)"""
|
| 856 |
+
# use `ready` instead of `completed` in case this is accessed after `completed` has been increased
|
| 857 |
+
# but before the next `ready` increase
|
| 858 |
+
return max(0, self.batch_progress.current.ready - 1)
|
| 859 |
+
|
| 860 |
+
# def analyze_sample(self, rdmols, atom_types, bond_types, aa_types=None, receptors=None):
|
| 861 |
+
# out = {}
|
| 862 |
+
|
| 863 |
+
# # Distribution of node types
|
| 864 |
+
# kl_div_atom = self.ligand_atom_type_distribution.kl_divergence(atom_types) \
|
| 865 |
+
# if self.ligand_atom_type_distribution is not None else -1
|
| 866 |
+
# out['kl_div_atom_types'] = kl_div_atom
|
| 867 |
+
|
| 868 |
+
# # Distribution of edge types
|
| 869 |
+
# kl_div_bond = self.ligand_bond_type_distribution.kl_divergence(bond_types) \
|
| 870 |
+
# if self.ligand_bond_type_distribution is not None else -1
|
| 871 |
+
# out['kl_div_bond_types'] = kl_div_bond
|
| 872 |
+
|
| 873 |
+
# if aa_types is not None:
|
| 874 |
+
# kl_div_aa = self.pocket_type_distribution.kl_divergence(aa_types) \
|
| 875 |
+
# if self.pocket_type_distribution is not None else -1
|
| 876 |
+
# out['kl_div_residue_types'] = kl_div_aa
|
| 877 |
+
|
| 878 |
+
# # Post-process sample
|
| 879 |
+
# processed_mols = [process_all(m) for m in rdmols]
|
| 880 |
+
|
| 881 |
+
# # Other basic metrics
|
| 882 |
+
# results = self.ligand_metrics(rdmols)
|
| 883 |
+
# out['n_samples'] = results['n_total']
|
| 884 |
+
# out['Validity'] = results['validity']
|
| 885 |
+
# out['Connectivity'] = results['connectivity']
|
| 886 |
+
# out['valid_and_connected'] = results['valid_and_connected']
|
| 887 |
+
|
| 888 |
+
# # connected_mols = [get_largest_fragment(m) for m in rdmols]
|
| 889 |
+
# connected_mols = [process_all(m, largest_frag=True, adjust_aromatic_Ns=False, relax_iter=0) for m in rdmols]
|
| 890 |
+
# connected_mols = [m for m in connected_mols if m is not None]
|
| 891 |
+
# out.update(self.molecule_properties(connected_mols))
|
| 892 |
+
|
| 893 |
+
# # Repeat after post-processing
|
| 894 |
+
# results = self.ligand_metrics(processed_mols)
|
| 895 |
+
# out['validity_processed'] = results['validity']
|
| 896 |
+
# out['connectivity_processed'] = results['connectivity']
|
| 897 |
+
# out['valid_and_connected_processed'] = results['valid_and_connected']
|
| 898 |
+
|
| 899 |
+
# processed_mols = [m for m in processed_mols if m is not None]
|
| 900 |
+
# for k, v in self.molecule_properties(processed_mols).items():
|
| 901 |
+
# out[f"{k}_processed"] = v
|
| 902 |
+
|
| 903 |
+
# # Simple docking score
|
| 904 |
+
# if receptors is not None and self.gnina is not None:
|
| 905 |
+
# assert len(receptors) == len(rdmols)
|
| 906 |
+
# docking_results = compute_gnina_scores(rdmols, receptors, gnina=self.gnina)
|
| 907 |
+
# out.update(docking_results)
|
| 908 |
+
|
| 909 |
+
# # Clash score
|
| 910 |
+
# if receptors is not None:
|
| 911 |
+
# assert len(receptors) == len(rdmols)
|
| 912 |
+
# clashes = {
|
| 913 |
+
# 'ligands': [legacy_clash_score(m) for m in rdmols],
|
| 914 |
+
# 'pockets': [legacy_clash_score(p) for p in receptors],
|
| 915 |
+
# 'between': [legacy_clash_score(m, p) for m, p in zip(rdmols, receptors)],
|
| 916 |
+
# 'v2_ligands': [clash_score(m) for m in rdmols],
|
| 917 |
+
# 'v2_pockets': [clash_score(p) for p in receptors],
|
| 918 |
+
# 'v2_between': [clash_score(m, p) for m, p in zip(rdmols, receptors)]
|
| 919 |
+
# }
|
| 920 |
+
# for k, v in clashes.items():
|
| 921 |
+
# out[f'mean_clash_score_{k}'] = np.mean(v)
|
| 922 |
+
# out[f'frac_no_clashes_{k}'] = np.mean(np.array(v) <= 0.0)
|
| 923 |
+
|
| 924 |
+
# return out
|
| 925 |
+
|
| 926 |
+
def analyze_sample(self, rdmols, atom_types, bond_types, aa_types=None, receptors=None):
|
| 927 |
+
out = {}
|
| 928 |
+
|
| 929 |
+
# Distribution of node types
|
| 930 |
+
kl_div_atom = self.ligand_atom_type_distribution.kl_divergence(atom_types) \
|
| 931 |
+
if self.ligand_atom_type_distribution is not None else -1
|
| 932 |
+
out['kl_div_atom_types'] = kl_div_atom
|
| 933 |
+
|
| 934 |
+
# Distribution of edge types
|
| 935 |
+
kl_div_bond = self.ligand_bond_type_distribution.kl_divergence(bond_types) \
|
| 936 |
+
if self.ligand_bond_type_distribution is not None else -1
|
| 937 |
+
out['kl_div_bond_types'] = kl_div_bond
|
| 938 |
+
|
| 939 |
+
if aa_types is not None:
|
| 940 |
+
kl_div_aa = self.pocket_type_distribution.kl_divergence(aa_types) \
|
| 941 |
+
if self.pocket_type_distribution is not None else -1
|
| 942 |
+
out['kl_div_residue_types'] = kl_div_aa
|
| 943 |
+
|
| 944 |
+
# Evaluation
|
| 945 |
+
results = []
|
| 946 |
+
if receptors is not None:
|
| 947 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 948 |
+
for mol, receptor in zip(tqdm(rdmols, desc='FullEvaluator'), receptors):
|
| 949 |
+
receptor_path = Path(tmpdir, 'receptor.pdb')
|
| 950 |
+
Chem.MolToPDBFile(receptor, str(receptor_path))
|
| 951 |
+
results.append(self.evaluator(mol, receptor_path))
|
| 952 |
+
else:
|
| 953 |
+
for mol in tqdm(rdmols, desc='FullEvaluator'):
|
| 954 |
+
self.evaluator = FullEvaluator(pb_conf='mol')
|
| 955 |
+
results.append(self.evaluator(mol))
|
| 956 |
+
|
| 957 |
+
results = pd.DataFrame(results)
|
| 958 |
+
agg_results = aggregated_metrics(results, self.evaluator.dtypes, VALIDITY_METRIC_NAME).fillna(0)
|
| 959 |
+
agg_results['metric'] = agg_results['metric'].str.replace('.', '/')
|
| 960 |
+
|
| 961 |
+
col_results = collection_metrics(results, self.train_smiles, VALIDITY_METRIC_NAME, exclude_evaluators='fcd')
|
| 962 |
+
col_results['metric'] = 'collection/' + col_results['metric']
|
| 963 |
+
|
| 964 |
+
all_results = pd.concat([agg_results, col_results])
|
| 965 |
+
out.update(**dict(all_results[['metric', 'value']].values))
|
| 966 |
+
|
| 967 |
+
return out
|
| 968 |
+
|
| 969 |
+
def sample_zt_given_zs(self, zs_ligand, zs_pocket, s, t, delta_eps_x=None, uncertainty=None):
|
| 970 |
+
|
| 971 |
+
sc_transform = self.get_sc_transform_fn(zs_pocket.get('chi'), zs_ligand['x'], s, None, zs_ligand['mask'], zs_pocket)
|
| 972 |
+
pred_ligand, pred_residues = self.dynamics(
|
| 973 |
+
zs_ligand['x'], zs_ligand['h'], zs_ligand['mask'], zs_pocket, s, bonds_ligand=(zs_ligand['bonds'], zs_ligand['e']),
|
| 974 |
+
sc_transform=sc_transform
|
| 975 |
+
)
|
| 976 |
+
|
| 977 |
+
if delta_eps_x is not None:
|
| 978 |
+
pred_ligand['vel'] = pred_ligand['vel'] + delta_eps_x
|
| 979 |
+
|
| 980 |
+
zt_ligand = zs_ligand.copy()
|
| 981 |
+
zt_ligand['x'] = self.module_x.sample_zt_given_zs(zs_ligand['x'], pred_ligand['vel'], s, t, zs_ligand['mask'])
|
| 982 |
+
|
| 983 |
+
zt_ligand['h'] = self.module_h.sample_zt_given_zs(zs_ligand['h'], pred_ligand['logits_h'], s, t, zs_ligand['mask'])
|
| 984 |
+
zt_ligand['e'] = self.module_e.sample_zt_given_zs(zs_ligand['e'], pred_ligand['logits_e'], s, t, zs_ligand['edge_mask'])
|
| 985 |
+
|
| 986 |
+
zt_pocket = zs_pocket.copy()
|
| 987 |
+
if self.flexible_bb:
|
| 988 |
+
zt_trans_pocket = self.module_trans.sample_zt_given_zs(zs_pocket['x'], pred_residues['trans'], s, t, zs_pocket['mask'])
|
| 989 |
+
zt_rot_pocket = self.module_rot.sample_zt_given_zs(zs_pocket['axis_angle'], pred_residues['rot'], s, t, zs_pocket['mask'])
|
| 990 |
+
|
| 991 |
+
# update pocket in-place
|
| 992 |
+
zt_pocket.set_frame(zt_trans_pocket, zt_rot_pocket)
|
| 993 |
+
|
| 994 |
+
if self.flexible:
|
| 995 |
+
zt_chi_pocket = self.module_chi.sample_zt_given_zs(zs_pocket['chi'][..., :5], pred_residues['chi'], s, t, zs_pocket['mask'])
|
| 996 |
+
|
| 997 |
+
# update pocket in-place
|
| 998 |
+
zt_pocket.set_chi(zt_chi_pocket)
|
| 999 |
+
|
| 1000 |
+
if self.predict_confidence:
|
| 1001 |
+
assert uncertainty is not None
|
| 1002 |
+
dt = (t - s).view(-1)[zt_ligand['mask']]
|
| 1003 |
+
uncertainty['sigma_x_squared'] += (dt * pred_ligand['uncertainty_vel']**2)
|
| 1004 |
+
uncertainty['entropy_h'] += (dt * Categorical(logits=pred_ligand['logits_h']).entropy())
|
| 1005 |
+
|
| 1006 |
+
return zt_ligand, zt_pocket
|
| 1007 |
+
|
| 1008 |
+
def simulate(self, ligand, pocket, timesteps, t_start, t_end=1.0,
|
| 1009 |
+
return_frames=1, guide_log_prob=None):
|
| 1010 |
+
"""
|
| 1011 |
+
Take a version of the ligand and pocket (at any time step t_start) and
|
| 1012 |
+
simulate the generative process from t_start to t_end.
|
| 1013 |
+
"""
|
| 1014 |
+
|
| 1015 |
+
assert 0 < return_frames <= timesteps
|
| 1016 |
+
assert timesteps % return_frames == 0
|
| 1017 |
+
assert 0.0 <= t_start < 1.0
|
| 1018 |
+
assert 0 < t_end <= 1.0
|
| 1019 |
+
assert t_start < t_end
|
| 1020 |
+
|
| 1021 |
+
device = ligand['x'].device
|
| 1022 |
+
n_samples = len(pocket['size'])
|
| 1023 |
+
delta_t = (t_end - t_start) / timesteps
|
| 1024 |
+
|
| 1025 |
+
# Initialize output tensors
|
| 1026 |
+
out_ligand = {
|
| 1027 |
+
'x': torch.zeros((return_frames, len(ligand['mask']), self.x_dim), device=device),
|
| 1028 |
+
'h': torch.zeros((return_frames, len(ligand['mask']), self.atom_nf), device=device),
|
| 1029 |
+
'e': torch.zeros((return_frames, len(ligand['edge_mask']), self.bond_nf), device=device)
|
| 1030 |
+
}
|
| 1031 |
+
if self.predict_confidence:
|
| 1032 |
+
out_ligand['sigma_x'] = torch.zeros((return_frames, len(ligand['mask'])), device=device)
|
| 1033 |
+
out_ligand['entropy_h'] = torch.zeros((return_frames, len(ligand['mask'])), device=device)
|
| 1034 |
+
out_pocket = {
|
| 1035 |
+
'x': torch.zeros((return_frames, len(pocket['mask']), 3), device=device), # CA-coord
|
| 1036 |
+
'v': torch.zeros((return_frames, len(pocket['mask']), self.n_atom_aa, 3), device=device) # difference vectors to all other atoms
|
| 1037 |
+
}
|
| 1038 |
+
|
| 1039 |
+
cumulative_uncertainty = {
|
| 1040 |
+
'sigma_x_squared': torch.zeros(len(ligand['mask']), device=device),
|
| 1041 |
+
'entropy_h': torch.zeros(len(ligand['mask']), device=device)
|
| 1042 |
+
} if self.predict_confidence else None
|
| 1043 |
+
|
| 1044 |
+
for i, t in enumerate(torch.linspace(t_start, t_end - delta_t, timesteps)):
|
| 1045 |
+
t_array = torch.full((n_samples, 1), fill_value=t, device=device)
|
| 1046 |
+
|
| 1047 |
+
if guide_log_prob is not None:
|
| 1048 |
+
raise NotImplementedError('Not yet implemented for flow matching model')
|
| 1049 |
+
alpha_t = self.diffusion_x.schedule.alpha(self.gamma_x(t_array))
|
| 1050 |
+
|
| 1051 |
+
with torch.enable_grad():
|
| 1052 |
+
zt_x_ligand.requires_grad = True
|
| 1053 |
+
g = guide_log_prob(t_array, x=ligand['x'], h=ligand['h'], batch_mask=ligand['mask'],
|
| 1054 |
+
bonds=ligand['bonds'], bond_types=ligand['e'])
|
| 1055 |
+
|
| 1056 |
+
# Compute gradient w.r.t. coordinates
|
| 1057 |
+
grad_x_lig = torch.autograd.grad(g.sum(), inputs=ligand['x'])[0]
|
| 1058 |
+
|
| 1059 |
+
# clip gradients
|
| 1060 |
+
g_max = 1.0
|
| 1061 |
+
clip_mask = (grad_x_lig.norm(dim=-1) > g_max)
|
| 1062 |
+
grad_x_lig[clip_mask] = \
|
| 1063 |
+
grad_x_lig[clip_mask] / grad_x_lig[clip_mask].norm(
|
| 1064 |
+
dim=-1, keepdim=True) * g_max
|
| 1065 |
+
|
| 1066 |
+
delta_eps_lig = -1 * (1 - alpha_t[lig_mask]).sqrt() * grad_x_lig
|
| 1067 |
+
else:
|
| 1068 |
+
delta_eps_lig = None
|
| 1069 |
+
|
| 1070 |
+
ligand, pocket = self.sample_zt_given_zs(
|
| 1071 |
+
ligand, pocket, t_array, t_array + delta_t, delta_eps_lig, cumulative_uncertainty)
|
| 1072 |
+
|
| 1073 |
+
# save frame
|
| 1074 |
+
if (i + 1) % (timesteps // return_frames) == 0:
|
| 1075 |
+
idx = (i + 1) // (timesteps // return_frames)
|
| 1076 |
+
idx = idx - 1
|
| 1077 |
+
|
| 1078 |
+
out_ligand['x'][idx] = ligand['x'].detach()
|
| 1079 |
+
out_ligand['h'][idx] = ligand['h'].detach()
|
| 1080 |
+
out_ligand['e'][idx] = ligand['e'].detach()
|
| 1081 |
+
if pocket['x'].numel() > 0:
|
| 1082 |
+
out_pocket['x'][idx] = pocket['x'].detach()
|
| 1083 |
+
out_pocket['v'][idx] = pocket['v'][:, :self.n_atom_aa, :].detach()
|
| 1084 |
+
if self.predict_confidence:
|
| 1085 |
+
out_ligand['sigma_x'][idx] = cumulative_uncertainty['sigma_x_squared'].sqrt().detach()
|
| 1086 |
+
out_ligand['entropy_h'][idx] = cumulative_uncertainty['entropy_h'].detach()
|
| 1087 |
+
|
| 1088 |
+
# remove frame dimension if only the final molecule is returned
|
| 1089 |
+
out_ligand = {k: v.squeeze(0) for k, v in out_ligand.items()}
|
| 1090 |
+
out_pocket = {k: v.squeeze(0) for k, v in out_pocket.items()}
|
| 1091 |
+
|
| 1092 |
+
return out_ligand, out_pocket
|
| 1093 |
+
|
| 1094 |
+
def init_ligand(self, num_nodes_lig, pocket):
|
| 1095 |
+
device = pocket['x'].device
|
| 1096 |
+
|
| 1097 |
+
n_samples = len(pocket['size'])
|
| 1098 |
+
lig_mask = utils.num_nodes_to_batch_mask(n_samples, num_nodes_lig, device)
|
| 1099 |
+
|
| 1100 |
+
# only consider upper triangular matrix for symmetry
|
| 1101 |
+
lig_bonds = torch.stack(torch.where(torch.triu(
|
| 1102 |
+
lig_mask[:, None] == lig_mask[None, :], diagonal=1)), dim=0)
|
| 1103 |
+
lig_edge_mask = lig_mask[lig_bonds[0]]
|
| 1104 |
+
|
| 1105 |
+
# Sample from Normal distribution in the pocket center
|
| 1106 |
+
pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0)
|
| 1107 |
+
z0_x = self.module_x.sample_z0(pocket_com, lig_mask)
|
| 1108 |
+
z0_h = self.module_h.sample_z0(lig_mask)
|
| 1109 |
+
z0_e = self.module_e.sample_z0(lig_edge_mask)
|
| 1110 |
+
|
| 1111 |
+
return TensorDict(**{
|
| 1112 |
+
'x': z0_x, 'h': z0_h, 'e': z0_e, 'mask': lig_mask,
|
| 1113 |
+
'bonds': lig_bonds, 'edge_mask': lig_edge_mask
|
| 1114 |
+
})
|
| 1115 |
+
|
| 1116 |
+
def init_pocket(self, pocket):
|
| 1117 |
+
|
| 1118 |
+
if self.flexible_bb:
|
| 1119 |
+
pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0)
|
| 1120 |
+
z0_trans = self.module_trans.sample_z0(pocket_com, pocket['mask'])
|
| 1121 |
+
z0_rot = self.module_rot.sample_z0(pocket['mask'])
|
| 1122 |
+
|
| 1123 |
+
# update pocket in-place
|
| 1124 |
+
pocket.set_frame(z0_trans, z0_rot)
|
| 1125 |
+
|
| 1126 |
+
if self.flexible:
|
| 1127 |
+
z0_chi = self.module_chi.sample_z0(pocket['mask'])
|
| 1128 |
+
|
| 1129 |
+
# # DEBUG ##
|
| 1130 |
+
# z0_chi = torch.stack([data_utils.get_torsion_angles(r, device=self.device) for r in pocket['residues']], dim=0)
|
| 1131 |
+
# ####
|
| 1132 |
+
|
| 1133 |
+
# internal to external coordinates
|
| 1134 |
+
pocket.set_chi(z0_chi)
|
| 1135 |
+
|
| 1136 |
+
if pocket['x'].numel() == 0:
|
| 1137 |
+
pocket.set_empty_v()
|
| 1138 |
+
|
| 1139 |
+
return pocket
|
| 1140 |
+
|
| 1141 |
+
def parse_num_nodes_spec(self, batch, spec=None, size_model=None):
|
| 1142 |
+
|
| 1143 |
+
if spec == "2d_histogram" or spec is None: # default option
|
| 1144 |
+
assert "pocket" in batch
|
| 1145 |
+
num_nodes = self.size_distribution.sample_conditional(
|
| 1146 |
+
n1=None, n2=batch['pocket']['size'])
|
| 1147 |
+
|
| 1148 |
+
# make sure there is at least one potential bond
|
| 1149 |
+
num_nodes[num_nodes < 2] = 2
|
| 1150 |
+
|
| 1151 |
+
elif isinstance(spec, (int, torch.Tensor)):
|
| 1152 |
+
num_nodes = spec
|
| 1153 |
+
|
| 1154 |
+
elif spec == "ground_truth":
|
| 1155 |
+
assert "ligand" in batch
|
| 1156 |
+
num_nodes = batch['ligand']['size']
|
| 1157 |
+
|
| 1158 |
+
elif spec == "nn_prediction":
|
| 1159 |
+
assert size_model is not None
|
| 1160 |
+
assert "pocket" in batch
|
| 1161 |
+
predictions = size_model.forward(batch['pocket'])
|
| 1162 |
+
predictions = torch.softmax(predictions, dim=-1)
|
| 1163 |
+
predictions[:, :5] = 0.0
|
| 1164 |
+
probabilities = predictions / predictions.sum(dim=1, keepdims=True)
|
| 1165 |
+
num_nodes = torch.distributions.Categorical(probabilities).sample()
|
| 1166 |
+
|
| 1167 |
+
elif isinstance(spec, str) and spec.startswith("uniform"):
|
| 1168 |
+
# expected format: uniform_low_high
|
| 1169 |
+
assert "pocket" in batch
|
| 1170 |
+
left, right = map(int, spec.split("_")[1:])
|
| 1171 |
+
shape = batch['pocket']['size'].shape
|
| 1172 |
+
num_nodes = torch.randint(left, right + 1, shape, dtype=torch.long)
|
| 1173 |
+
|
| 1174 |
+
else:
|
| 1175 |
+
raise NotImplementedError(f"Invalid size specification {spec}")
|
| 1176 |
+
|
| 1177 |
+
if self.virtual_nodes:
|
| 1178 |
+
num_nodes += self.add_virtual_max
|
| 1179 |
+
|
| 1180 |
+
return num_nodes
|
| 1181 |
+
|
| 1182 |
+
@torch.no_grad()
|
| 1183 |
+
def sample(self, data, n_samples, num_nodes=None, timesteps=None,
|
| 1184 |
+
guide_log_prob=None, size_model=None, **kwargs):
|
| 1185 |
+
|
| 1186 |
+
# TODO: move somewhere else (like collate_fn)
|
| 1187 |
+
data['pocket'] = Residues(**data['pocket'])
|
| 1188 |
+
|
| 1189 |
+
timesteps = self.T_sampling if timesteps is None else timesteps
|
| 1190 |
+
|
| 1191 |
+
if len(data['pocket']['x']) > 0:
|
| 1192 |
+
pocket = data_utils.repeat_items(data['pocket'], n_samples)
|
| 1193 |
+
else:
|
| 1194 |
+
pocket = Residues(**{key: value for key, value in data['pocket'].items()})
|
| 1195 |
+
pocket['name'] = pocket['name'] * n_samples
|
| 1196 |
+
pocket['size'] = pocket['size'].repeat(n_samples)
|
| 1197 |
+
pocket['n_bonds'] = pocket['n_bonds'].repeat(n_samples)
|
| 1198 |
+
|
| 1199 |
+
_ligand = data_utils.repeat_items(data['ligand'], n_samples)
|
| 1200 |
+
# _ligand = randomize_tensors(_ligand, exclude_keys=['size', 'name']) # avoid data leakage
|
| 1201 |
+
|
| 1202 |
+
batch = {"ligand": _ligand, "pocket": pocket}
|
| 1203 |
+
num_nodes = self.parse_num_nodes_spec(batch, spec=num_nodes, size_model=size_model)
|
| 1204 |
+
|
| 1205 |
+
# Sample from prior
|
| 1206 |
+
if pocket['x'].numel() > 0:
|
| 1207 |
+
ligand = self.init_ligand(num_nodes, pocket)
|
| 1208 |
+
else:
|
| 1209 |
+
ligand = self.init_ligand(num_nodes, _ligand)
|
| 1210 |
+
pocket = self.init_pocket(pocket)
|
| 1211 |
+
|
| 1212 |
+
# return prior samples
|
| 1213 |
+
if timesteps == 0:
|
| 1214 |
+
# Convert into rdmols
|
| 1215 |
+
rdmols = [build_molecule(coords=m['x'],
|
| 1216 |
+
atom_types=m['h'].argmax(1),
|
| 1217 |
+
bonds=m['bonds'],
|
| 1218 |
+
bond_types=m['e'].argmax(1),
|
| 1219 |
+
atom_decoder=self.atom_decoder, bond_decoder=self.bond_decoder)
|
| 1220 |
+
for m in data_utils.split_entity(ligand.detach().cpu(), edge_types={"e", "edge_mask"}, edge_mask=ligand["edge_mask"])]
|
| 1221 |
+
|
| 1222 |
+
rdpockets = pocket_to_rdkit(pocket, self.pocket_representation,
|
| 1223 |
+
self.atom_encoder, self.atom_decoder,
|
| 1224 |
+
self.aa_decoder, self.residue_decoder,
|
| 1225 |
+
self.aa_atom_index)
|
| 1226 |
+
|
| 1227 |
+
return rdmols, rdpockets, _ligand['name']
|
| 1228 |
+
|
| 1229 |
+
out_tensors_ligand, out_tensors_pocket = self.simulate(
|
| 1230 |
+
ligand, pocket, timesteps, 0.0, 1.0,
|
| 1231 |
+
guide_log_prob=guide_log_prob
|
| 1232 |
+
)
|
| 1233 |
+
|
| 1234 |
+
# Build mol objects
|
| 1235 |
+
x = out_tensors_ligand['x'].detach().cpu()
|
| 1236 |
+
ligand_type = out_tensors_ligand['h'].argmax(1).detach().cpu()
|
| 1237 |
+
edge_type = out_tensors_ligand['e'].argmax(1).detach().cpu()
|
| 1238 |
+
lig_mask = ligand['mask'].detach().cpu()
|
| 1239 |
+
lig_bonds = ligand['bonds'].detach().cpu()
|
| 1240 |
+
lig_edge_mask = ligand['edge_mask'].detach().cpu()
|
| 1241 |
+
sizes = torch.unique(ligand['mask'], return_counts=True)[1].tolist()
|
| 1242 |
+
offsets = list(accumulate(sizes[:-1], initial=0))
|
| 1243 |
+
mol_kwargs = {
|
| 1244 |
+
'coords': utils.batch_to_list(x, lig_mask),
|
| 1245 |
+
'atom_types': utils.batch_to_list(ligand_type, lig_mask),
|
| 1246 |
+
'bonds': utils.batch_to_list_for_indices(lig_bonds, lig_edge_mask, offsets),
|
| 1247 |
+
'bond_types': utils.batch_to_list(edge_type, lig_edge_mask)
|
| 1248 |
+
}
|
| 1249 |
+
if self.predict_confidence:
|
| 1250 |
+
sigma_x = out_tensors_ligand['sigma_x'].detach().cpu()
|
| 1251 |
+
entropy_h = out_tensors_ligand['entropy_h'].detach().cpu()
|
| 1252 |
+
mol_kwargs['atom_props'] = [
|
| 1253 |
+
{'sigma_x': x[0], 'entropy_h': x[1]}
|
| 1254 |
+
for x in zip(utils.batch_to_list(sigma_x, lig_mask),
|
| 1255 |
+
utils.batch_to_list(entropy_h, lig_mask))
|
| 1256 |
+
]
|
| 1257 |
+
mol_kwargs = [{k: v[i] for k, v in mol_kwargs.items()}
|
| 1258 |
+
for i in range(len(mol_kwargs['coords']))]
|
| 1259 |
+
|
| 1260 |
+
# Convert into rdmols
|
| 1261 |
+
rdmols = [build_molecule(
|
| 1262 |
+
**m, atom_decoder=self.atom_decoder, bond_decoder=self.bond_decoder)
|
| 1263 |
+
for m in mol_kwargs
|
| 1264 |
+
]
|
| 1265 |
+
|
| 1266 |
+
out_pocket = pocket.copy()
|
| 1267 |
+
out_pocket['x'] = out_tensors_pocket['x']
|
| 1268 |
+
out_pocket['v'] = out_tensors_pocket['v']
|
| 1269 |
+
rdpockets = pocket_to_rdkit(out_pocket, self.pocket_representation,
|
| 1270 |
+
self.atom_encoder, self.atom_decoder,
|
| 1271 |
+
self.aa_decoder, self.residue_decoder,
|
| 1272 |
+
self.aa_atom_index)
|
| 1273 |
+
|
| 1274 |
+
return rdmols, rdpockets, _ligand['name']
|
| 1275 |
+
|
| 1276 |
+
@torch.no_grad()
|
| 1277 |
+
def sample_chain(self, pocket, keep_frames, num_nodes=None, timesteps=None,
|
| 1278 |
+
guide_log_prob=None, **kwargs):
|
| 1279 |
+
|
| 1280 |
+
# TODO: move somewhere else (like collate_fn)
|
| 1281 |
+
pocket = Residues(**pocket)
|
| 1282 |
+
|
| 1283 |
+
info = {}
|
| 1284 |
+
|
| 1285 |
+
timesteps = self.T_sampling if timesteps is None else timesteps
|
| 1286 |
+
|
| 1287 |
+
# n_samples = 1
|
| 1288 |
+
# TODO: get batch_size differently
|
| 1289 |
+
assert len(pocket['mask'].unique()) <= 1, "sample_chain only supports a single sample"
|
| 1290 |
+
|
| 1291 |
+
# # Pocket's initial center of mass
|
| 1292 |
+
# pocket_com_before = scatter_mean(pocket['x'], pocket['mask'], dim=0)
|
| 1293 |
+
|
| 1294 |
+
num_nodes = self.parse_num_nodes_spec(batch={"pocket": pocket}, spec=num_nodes)
|
| 1295 |
+
|
| 1296 |
+
# Sample from prior
|
| 1297 |
+
if pocket['x'].numel() > 0:
|
| 1298 |
+
ligand = self.init_ligand(num_nodes, pocket)
|
| 1299 |
+
else:
|
| 1300 |
+
dummy_pocket = Residues.empty(pocket['x'].device)
|
| 1301 |
+
ligand = self.init_ligand(num_nodes, dummy_pocket)
|
| 1302 |
+
|
| 1303 |
+
pocket = self.init_pocket(pocket)
|
| 1304 |
+
|
| 1305 |
+
out_tensors_ligand, out_tensors_pocket = self.simulate(
|
| 1306 |
+
ligand, pocket, timesteps, 0.0, 1.0, guide_log_prob=guide_log_prob, return_frames=keep_frames)
|
| 1307 |
+
|
| 1308 |
+
# chain_lig = utils.reverse_tensor(chain_lig)
|
| 1309 |
+
# chain_pocket = utils.reverse_tensor(chain_pocket)
|
| 1310 |
+
# chain_bond = utils.reverse_tensor(chain_bond)
|
| 1311 |
+
|
| 1312 |
+
info['traj_displacement_lig'] = torch.norm(out_tensors_ligand['x'][-1] - out_tensors_ligand['x'][0], dim=-1).mean()
|
| 1313 |
+
info['traj_rms_lig'] = out_tensors_ligand['x'].std(dim=0).mean()
|
| 1314 |
+
|
| 1315 |
+
# # Repeat last frame to see final sample better.
|
| 1316 |
+
# chain_lig = torch.cat([chain_lig, chain_lig[-1:].repeat(10, 1, 1)], dim=0)
|
| 1317 |
+
# chain_pocket = torch.cat([chain_pocket, chain_pocket[-1:].repeat(10, 1, 1)], dim=0)
|
| 1318 |
+
# chain_bond = torch.cat([chain_bond, chain_bond[-1:].repeat(10, 1, 1)], dim=0)
|
| 1319 |
+
|
| 1320 |
+
# Flatten
|
| 1321 |
+
assert keep_frames == out_tensors_ligand['x'].size(0) == out_tensors_pocket['x'].size(0)
|
| 1322 |
+
n_atoms = out_tensors_ligand['x'].size(1)
|
| 1323 |
+
n_bonds = out_tensors_ligand['e'].size(1)
|
| 1324 |
+
n_residues = out_tensors_pocket['x'].size(1)
|
| 1325 |
+
device = out_tensors_ligand['x'].device
|
| 1326 |
+
|
| 1327 |
+
def flatten_tensor(chain):
|
| 1328 |
+
if len(chain.size()) == 3: # l=0 values
|
| 1329 |
+
return chain.view(-1, chain.size(-1))
|
| 1330 |
+
elif len(chain.size()) == 4: # vectors
|
| 1331 |
+
return chain.view(-1, chain.size(-2), chain.size(-1))
|
| 1332 |
+
else:
|
| 1333 |
+
warnings.warn(f"Could not flatten frame dimension of tensor with shape {list(chain.size())}")
|
| 1334 |
+
return chain
|
| 1335 |
+
|
| 1336 |
+
out_tensors_ligand_flat = {k: flatten_tensor(chain) for k, chain in out_tensors_ligand.items()}
|
| 1337 |
+
out_tensors_pocket_flat = {k: flatten_tensor(chain) for k, chain in out_tensors_pocket.items()}
|
| 1338 |
+
# ligand_flat = chain_lig.view(-1, chain_lig.size(-1))
|
| 1339 |
+
# ligand_mask_flat = torch.arange(chain_lig.size(0)).repeat_interleave(chain_lig.size(1)).to(chain_lig.device)
|
| 1340 |
+
ligand_mask_flat = torch.arange(keep_frames).repeat_interleave(n_atoms).to(device)
|
| 1341 |
+
|
| 1342 |
+
# # pocket_flat = chain_pocket.view(-1, chain_pocket.size(-1))
|
| 1343 |
+
# # pocket_v_flat = pocket['v'].repeat(100, 1, 1)
|
| 1344 |
+
# pocket_flat = chain_pocket.view(-1, chain_pocket.size(-2), chain_pocket.size(-1))
|
| 1345 |
+
# pocket_mask_flat = torch.arange(chain_pocket.size(0)).repeat_interleave(chain_pocket.size(1)).to(chain_pocket.device)
|
| 1346 |
+
pocket_mask_flat = torch.arange(keep_frames).repeat_interleave(n_residues).to(device)
|
| 1347 |
+
|
| 1348 |
+
# bond_flat = chain_bond.view(-1, chain_bond.size(-1))
|
| 1349 |
+
# bond_mask_flat = torch.arange(chain_bond.size(0)).repeat_interleave(chain_bond.size(1)).to(chain_bond.device)
|
| 1350 |
+
bond_mask_flat = torch.arange(keep_frames).repeat_interleave(n_bonds).to(device)
|
| 1351 |
+
edges_flat = ligand['bonds'].repeat(1, keep_frames)
|
| 1352 |
+
|
| 1353 |
+
# # Move generated molecule back to the original pocket position
|
| 1354 |
+
# pocket_com_after = scatter_mean(pocket_flat[:, 0, :], pocket_mask_flat, dim=0)
|
| 1355 |
+
# ligand_flat[:, :self.x_dim] += (pocket_com_before - pocket_com_after)[ligand_mask_flat]
|
| 1356 |
+
#
|
| 1357 |
+
# # Move pocket back as well (for visualization purposes)
|
| 1358 |
+
# pocket_flat[:, 0, :] += (pocket_com_before - pocket_com_after)[pocket_mask_flat]
|
| 1359 |
+
|
| 1360 |
+
# Build ligands
|
| 1361 |
+
x = out_tensors_ligand_flat['x'].detach().cpu()
|
| 1362 |
+
ligand_type = out_tensors_ligand_flat['h'].argmax(1).detach().cpu()
|
| 1363 |
+
ligand_mask_flat = ligand_mask_flat.detach().cpu()
|
| 1364 |
+
bond_mask_flat = bond_mask_flat.detach().cpu()
|
| 1365 |
+
edges_flat = edges_flat.detach().cpu()
|
| 1366 |
+
edge_type = out_tensors_ligand_flat['e'].argmax(1).detach().cpu()
|
| 1367 |
+
offsets = torch.zeros(keep_frames, dtype=int) # edges_flat is already zero-based
|
| 1368 |
+
molecules = list(
|
| 1369 |
+
zip(utils.batch_to_list(x, ligand_mask_flat),
|
| 1370 |
+
utils.batch_to_list(ligand_type, ligand_mask_flat),
|
| 1371 |
+
utils.batch_to_list_for_indices(edges_flat, bond_mask_flat, offsets),
|
| 1372 |
+
utils.batch_to_list(edge_type, bond_mask_flat)
|
| 1373 |
+
)
|
| 1374 |
+
)
|
| 1375 |
+
|
| 1376 |
+
# Convert into rdmols
|
| 1377 |
+
ligand_chain = [build_molecule(
|
| 1378 |
+
*graph, atom_decoder=self.atom_decoder,
|
| 1379 |
+
bond_decoder=self.bond_decoder) for graph in molecules
|
| 1380 |
+
]
|
| 1381 |
+
|
| 1382 |
+
# Build pockets
|
| 1383 |
+
# as long as the pocket does not change during sampling, we can ust
|
| 1384 |
+
# write it once
|
| 1385 |
+
out_pocket = {
|
| 1386 |
+
'x': out_tensors_pocket_flat['x'],
|
| 1387 |
+
'one_hot': pocket['one_hot'].repeat(keep_frames, 1),
|
| 1388 |
+
'mask': pocket_mask_flat,
|
| 1389 |
+
'v': out_tensors_pocket_flat['v'],
|
| 1390 |
+
'atom_mask': pocket['atom_mask'].repeat(keep_frames, 1),
|
| 1391 |
+
} if self.flexible else pocket
|
| 1392 |
+
pocket_chain = pocket_to_rdkit(out_pocket, self.pocket_representation,
|
| 1393 |
+
self.atom_encoder, self.atom_decoder,
|
| 1394 |
+
self.aa_decoder, self.residue_decoder,
|
| 1395 |
+
self.aa_atom_index)
|
| 1396 |
+
|
| 1397 |
+
return ligand_chain, pocket_chain, info
|
| 1398 |
+
|
| 1399 |
+
# def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
|
| 1400 |
+
# def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_clip_algorithm):
|
| 1401 |
+
def configure_gradient_clipping(self, optimizer, *args, **kwargs):
|
| 1402 |
+
|
| 1403 |
+
if not self.clip_grad:
|
| 1404 |
+
return
|
| 1405 |
+
|
| 1406 |
+
# Allow gradient norm to be 150% + 2 * stdev of the recent history.
|
| 1407 |
+
max_grad_norm = 1.5 * self.gradnorm_queue.mean() + \
|
| 1408 |
+
2 * self.gradnorm_queue.std()
|
| 1409 |
+
|
| 1410 |
+
# hard upper limit
|
| 1411 |
+
max_grad_norm = min(max_grad_norm, 10.0)
|
| 1412 |
+
|
| 1413 |
+
# Get current grad_norm
|
| 1414 |
+
params = [p for g in optimizer.param_groups for p in g['params']]
|
| 1415 |
+
grad_norm = utils.get_grad_norm(params)
|
| 1416 |
+
|
| 1417 |
+
# Lightning will handle the gradient clipping
|
| 1418 |
+
self.clip_gradients(optimizer, gradient_clip_val=max_grad_norm,
|
| 1419 |
+
gradient_clip_algorithm='norm')
|
| 1420 |
+
|
| 1421 |
+
if float(grad_norm) > max_grad_norm:
|
| 1422 |
+
print(f'Clipped gradient with value {grad_norm:.1f} '
|
| 1423 |
+
f'while allowed {max_grad_norm:.1f}')
|
| 1424 |
+
grad_norm = max_grad_norm
|
| 1425 |
+
|
| 1426 |
+
self.gradnorm_queue.add(float(grad_norm))
|
src/model/loss_utils.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch_scatter import scatter_add, scatter_mean
|
| 3 |
+
|
| 4 |
+
from src.constants import atom_decoder, vdw_radii
|
| 5 |
+
_vdw_radii = {**vdw_radii}
|
| 6 |
+
_vdw_radii['NH'] = vdw_radii['N']
|
| 7 |
+
_vdw_radii['N+'] = vdw_radii['N']
|
| 8 |
+
_vdw_radii['O-'] = vdw_radii['O']
|
| 9 |
+
_vdw_radii['NOATOM'] = 0
|
| 10 |
+
vdw_radii_array = torch.tensor([_vdw_radii[a] for a in atom_decoder])
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def clash_loss(ligand_coord, ligand_types, ligand_mask, pocket_coord,
|
| 14 |
+
pocket_types, pocket_mask):
|
| 15 |
+
"""
|
| 16 |
+
Computes a clash loss that penalizes interatomic distances smaller than the
|
| 17 |
+
sum of van der Waals radii between atoms.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
ligand_radii = vdw_radii_array[ligand_types].to(ligand_coord.device)
|
| 21 |
+
pocket_radii = vdw_radii_array[pocket_types].to(pocket_coord.device)
|
| 22 |
+
|
| 23 |
+
dist = torch.sqrt(torch.sum((ligand_coord[:, None, :] - pocket_coord[None, :, :]) ** 2, dim=-1))
|
| 24 |
+
# dist[ligand_mask[:, None] != pocket_mask[None, :]] = float('inf')
|
| 25 |
+
|
| 26 |
+
# compute linearly decreasing penalty
|
| 27 |
+
# penalty = max(1 - 1/sum_vdw * d, 0)
|
| 28 |
+
sum_vdw = ligand_radii[:, None] + pocket_radii[None, :]
|
| 29 |
+
loss = torch.clamp(1 - dist / sum_vdw, min=0.0) # (n_ligand, n_pocket)
|
| 30 |
+
|
| 31 |
+
loss = scatter_add(loss, pocket_mask, dim=1)
|
| 32 |
+
loss = scatter_mean(loss, ligand_mask, dim=0)
|
| 33 |
+
loss = loss.diag()
|
| 34 |
+
|
| 35 |
+
# # DEBUG (non-differentiable version)
|
| 36 |
+
# dist = torch.sqrt(torch.sum((ligand_coord[:, None, :] - pocket_coord[None, :, :]) ** 2, dim=-1))
|
| 37 |
+
# dist[ligand_mask[:, None] != pocket_mask[None, :]] = float('inf')
|
| 38 |
+
# _loss = torch.clamp(1 - dist / sum_vdw, min=0.0) # (n_ligand, n_pocket)
|
| 39 |
+
# _loss = _loss.sum(dim=-1)
|
| 40 |
+
# _loss = scatter_mean(_loss, ligand_mask, dim=0)
|
| 41 |
+
# assert torch.allclose(loss, _loss)
|
| 42 |
+
|
| 43 |
+
return loss
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class TimestepSampler:
|
| 47 |
+
def __init__(self, type='uniform', lowest_t=1, highest_t=500):
|
| 48 |
+
assert type in {'uniform', 'sigmoid'}
|
| 49 |
+
self.type = type
|
| 50 |
+
self.lowest_t = lowest_t
|
| 51 |
+
self.highest_t = highest_t
|
| 52 |
+
|
| 53 |
+
def __call__(self, n, device=None):
|
| 54 |
+
if self.type == 'uniform':
|
| 55 |
+
t_int = torch.randint(self.lowest_t, self.highest_t + 1,
|
| 56 |
+
size=(n, 1), device=device)
|
| 57 |
+
|
| 58 |
+
elif self.type == 'sigmoid':
|
| 59 |
+
weight_fun = lambda t: 1.45 * torch.sigmoid(-t * 10 / self.highest_t + 5) + 0.05
|
| 60 |
+
|
| 61 |
+
possible_ts = torch.arange(self.lowest_t, self.highest_t + 1, device=device)
|
| 62 |
+
weights = weight_fun(possible_ts)
|
| 63 |
+
weights = weights / weights.sum()
|
| 64 |
+
t_int = possible_ts[torch.multinomial(weights, n, replacement=True)].unsqueeze(-1)
|
| 65 |
+
|
| 66 |
+
return t_int.float()
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class TimestepWeights:
|
| 70 |
+
def __init__(self, weight_type, a, b):
|
| 71 |
+
if weight_type != 'sigmoid':
|
| 72 |
+
raise NotImplementedError("Only sigmoidal loss weighting is available.")
|
| 73 |
+
# self.weight_fn = lambda t: a * torch.sigmoid((-t + 0.5) * b) + (1 - a / 2)
|
| 74 |
+
self.weight_fn = lambda t: a * torch.sigmoid((t - 0.5) * b) + (1 - a / 2)
|
| 75 |
+
|
| 76 |
+
def __call__(self, t_array):
|
| 77 |
+
# normalized t \in [0, 1]
|
| 78 |
+
# return self.weight_fn(1 - t_array)
|
| 79 |
+
return self.weight_fn(t_array)
|
src/model/markov_bridge.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import reduce
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch_scatter import scatter_mean, scatter_add
|
| 5 |
+
|
| 6 |
+
from src.utils import bvm
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class LinearSchedule:
|
| 10 |
+
"""
|
| 11 |
+
We use the scheduling parameter \beta to linearly remove noise, i.e.
|
| 12 |
+
\bar{\beta}_t = 1 - h (h: step size) with
|
| 13 |
+
\bar{Q}_t = \bar{\beta}_t I + (1 - \bar{\beta}_t) 1_vec z1^T
|
| 14 |
+
|
| 15 |
+
From this, it follows that for each step transition matrix, we have
|
| 16 |
+
\beta_t = \bar{\beta}_t / \bar{\beta}_{t-h} = \frac{1-t}{1-t+h}
|
| 17 |
+
"""
|
| 18 |
+
def __init__(self):
|
| 19 |
+
super().__init__()
|
| 20 |
+
|
| 21 |
+
def beta_bar(self, t):
|
| 22 |
+
return 1 - t
|
| 23 |
+
|
| 24 |
+
def beta(self, t, step_size):
|
| 25 |
+
return (1 - t) / (1 - t + step_size)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class UniformPriorMarkovBridge:
|
| 29 |
+
"""
|
| 30 |
+
Markov bridge model in which z0 is drawn from a uniform prior.
|
| 31 |
+
Transitions are defined as:
|
| 32 |
+
Q_t = \beta_t I + (1 - \beta_t) 1_vec z1^T
|
| 33 |
+
where z1 is a one-hot representation of the final state.
|
| 34 |
+
We follow the notation from [1] and multiply transition matrices from the
|
| 35 |
+
right to one-hot state vectors.
|
| 36 |
+
|
| 37 |
+
We use the scheduling parameter \beta to linearly remove noise, i.e.
|
| 38 |
+
\bar{\beta}_t = 1 - h (h: step size) with
|
| 39 |
+
\bar{Q}_t = \bar{\beta}_t I + (1 - \bar{\beta}_t) 1_vec z1^T
|
| 40 |
+
|
| 41 |
+
From this, it follows that for each step transition matrix, we have
|
| 42 |
+
\beta_t = \bar{\beta}_t / \bar{\beta}_{t-h} = \frac{1-t}{1-t+h}
|
| 43 |
+
|
| 44 |
+
[1] Austin, Jacob, et al.
|
| 45 |
+
"Structured denoising diffusion models in discrete state-spaces."
|
| 46 |
+
Advances in Neural Information Processing Systems 34 (2021): 17981-17993.
|
| 47 |
+
"""
|
| 48 |
+
def __init__(self, dim, loss_type='CE', step_size=None):
|
| 49 |
+
assert loss_type in ['VLB', 'CE']
|
| 50 |
+
self.dim = dim
|
| 51 |
+
self.step_size = step_size # required for VLB
|
| 52 |
+
self.schedule = LinearSchedule()
|
| 53 |
+
self.loss_type = loss_type
|
| 54 |
+
super(UniformPriorMarkovBridge, self).__init__()
|
| 55 |
+
|
| 56 |
+
@staticmethod
|
| 57 |
+
def sample_categorical(p):
|
| 58 |
+
"""
|
| 59 |
+
Sample from categorical distribution defined by probabilities 'p'
|
| 60 |
+
:param p: (n, dim)
|
| 61 |
+
:return: one-hot encoded samples (n, dim)
|
| 62 |
+
"""
|
| 63 |
+
sampled = torch.multinomial(p, 1).squeeze(-1)
|
| 64 |
+
return F.one_hot(sampled, num_classes=p.size(1)).float()
|
| 65 |
+
|
| 66 |
+
def p_z0(self, batch_mask):
|
| 67 |
+
return torch.ones((len(batch_mask), self.dim), device=batch_mask.device) / self.dim
|
| 68 |
+
|
| 69 |
+
def sample_z0(self, batch_mask):
|
| 70 |
+
""" Prior. """
|
| 71 |
+
z0 = self.sample_categorical(self.p_z0(batch_mask))
|
| 72 |
+
return z0
|
| 73 |
+
|
| 74 |
+
def p_zt(self, z0, z1, t, batch_mask):
|
| 75 |
+
Qt_bar = self.get_Qt_bar(t, z1, batch_mask)
|
| 76 |
+
return bvm(z0, Qt_bar)
|
| 77 |
+
|
| 78 |
+
def sample_zt(self, z0, z1, t, batch_mask):
|
| 79 |
+
zt = self.sample_categorical(self.p_zt(z0, z1, t, batch_mask))
|
| 80 |
+
return zt
|
| 81 |
+
|
| 82 |
+
def p_zt_given_zs_and_z1(self, zs, z1, s, t, batch_mask):
|
| 83 |
+
# 'z1' are one-hot "probabilities" for each class
|
| 84 |
+
Qt = self.get_Qt(t, s, z1, batch_mask)
|
| 85 |
+
# from pdb import set_trace; set_trace()
|
| 86 |
+
q_zs_given_zt = bvm(zs, Qt)
|
| 87 |
+
return q_zs_given_zt
|
| 88 |
+
|
| 89 |
+
def p_zt_given_zs(self, zs, p_z1_hat, s, t, batch_mask):
|
| 90 |
+
"""
|
| 91 |
+
Note that x can also represent a categorical distribution to compute
|
| 92 |
+
transitions more efficiently at sampling time:
|
| 93 |
+
p(z_t|z_s) = \sum_{\hat{z}_1} p(z_t | z_s, \hat{z}_1) * p(\hat{z}_1 | z_s)
|
| 94 |
+
= \sum_i z_s (\beta_t I + (1 - \beta_t) 1_vec z1_i^T) * \hat{p}_i
|
| 95 |
+
= \beta_t z_s I + (1 - \beta_t) z_s 1_vec \hat{p}^t
|
| 96 |
+
"""
|
| 97 |
+
return self.p_zt_given_zs_and_z1(zs, p_z1_hat, s, t, batch_mask)
|
| 98 |
+
|
| 99 |
+
def sample_zt_given_zs(self, zs, z1_logits, s, t, batch_mask):
|
| 100 |
+
p_z1 = z1_logits.softmax(dim=-1)
|
| 101 |
+
zt = self.sample_categorical(self.p_zt_given_zs(zs, p_z1, s, t, batch_mask))
|
| 102 |
+
return zt
|
| 103 |
+
|
| 104 |
+
def compute_loss(self, pred_logits, zs, z1, batch_mask, s, t, reduce='mean'):
|
| 105 |
+
""" Compute loss per sample. """
|
| 106 |
+
assert reduce in {'mean', 'sum', 'none'}
|
| 107 |
+
|
| 108 |
+
if self.loss_type == 'CE':
|
| 109 |
+
loss = F.cross_entropy(pred_logits, z1, reduction='none')
|
| 110 |
+
|
| 111 |
+
else: # VLB
|
| 112 |
+
true_p_zs = self.p_zt_given_zs_and_z1(zs, z1, s, t, batch_mask)
|
| 113 |
+
pred_p_zs = self.p_zt_given_zs(zs, pred_logits.softmax(dim=-1), s, t, batch_mask)
|
| 114 |
+
loss = F.kl_div(pred_p_zs.log(), true_p_zs, reduction='none').sum(dim=-1)
|
| 115 |
+
|
| 116 |
+
if reduce == 'mean':
|
| 117 |
+
loss = scatter_mean(loss, batch_mask, dim=0)
|
| 118 |
+
elif reduce == 'sum':
|
| 119 |
+
loss = scatter_add(loss, batch_mask, dim=0)
|
| 120 |
+
|
| 121 |
+
return loss
|
| 122 |
+
|
| 123 |
+
def get_Qt(self, t, s, z1, batch_mask):
|
| 124 |
+
""" Returns one-step transition matrix from step s to step t. """
|
| 125 |
+
|
| 126 |
+
beta_t_given_s = self.schedule.beta(t, t - s)
|
| 127 |
+
beta_t_given_s = beta_t_given_s.unsqueeze(-1)[batch_mask]
|
| 128 |
+
|
| 129 |
+
# Q_t = beta_t * I + (1 - beta_t) * ones (dot) z1^T
|
| 130 |
+
Qt = beta_t_given_s * torch.eye(self.dim, device=t.device).unsqueeze(0) + \
|
| 131 |
+
(1 - beta_t_given_s) * z1.unsqueeze(1)
|
| 132 |
+
# (1 - beta_t_given_s) * (torch.ones(self.dim, 1, device=t.device) @ z1)
|
| 133 |
+
|
| 134 |
+
# assert (Qt.sum(-1) == 1).all()
|
| 135 |
+
|
| 136 |
+
return Qt
|
| 137 |
+
|
| 138 |
+
def get_Qt_bar(self, t, z1, batch_mask):
|
| 139 |
+
""" Returns transition matrix from step 0 to step t. """
|
| 140 |
+
|
| 141 |
+
beta_bar_t = self.schedule.beta_bar(t)
|
| 142 |
+
beta_bar_t = beta_bar_t.unsqueeze(-1)[batch_mask]
|
| 143 |
+
|
| 144 |
+
# Q_t_bar = beta_bar * I + (1 - beta_bar) * ones (dot) z1^T
|
| 145 |
+
Qt_bar = beta_bar_t * torch.eye(self.dim, device=t.device).unsqueeze(0) + \
|
| 146 |
+
(1 - beta_bar_t) * z1.unsqueeze(1)
|
| 147 |
+
# (1 - beta_bar_t) * (torch.ones(self.dim, 1, device=t.device) @ z1)
|
| 148 |
+
|
| 149 |
+
# assert (Qt_bar.sum(-1) == 1).all()
|
| 150 |
+
|
| 151 |
+
return Qt_bar
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class MarginalPriorMarkovBridge(UniformPriorMarkovBridge):
|
| 155 |
+
def __init__(self, dim, prior_p, loss_type='CE', step_size=None):
|
| 156 |
+
self.prior_p = prior_p
|
| 157 |
+
print('Marginal Prior MB')
|
| 158 |
+
super(MarginalPriorMarkovBridge, self).__init__(dim, loss_type, step_size)
|
| 159 |
+
|
| 160 |
+
def p_z0(self, batch_mask):
|
| 161 |
+
device = batch_mask.device
|
| 162 |
+
p = torch.ones((len(batch_mask), self.dim), device=device) * self.prior_p.view(1, -1).to(device)
|
| 163 |
+
return p
|
src/sample_and_evaluate.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import sys
|
| 3 |
+
import yaml
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pickle
|
| 7 |
+
from argparse import Namespace
|
| 8 |
+
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
basedir = Path(__file__).resolve().parent.parent
|
| 12 |
+
sys.path.append(str(basedir))
|
| 13 |
+
|
| 14 |
+
from src import utils
|
| 15 |
+
from src.utils import dict_to_namespace, namespace_to_dict
|
| 16 |
+
from src.analysis.visualization_utils import mols_to_pdbfile, mol_as_pdb
|
| 17 |
+
from src.data.data_utils import TensorDict, Residues
|
| 18 |
+
from src.data.postprocessing import process_all
|
| 19 |
+
from src.model.lightning import DrugFlow
|
| 20 |
+
from src.sbdd_metrics.evaluation import compute_all_metrics_drugflow
|
| 21 |
+
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
from pdb import set_trace
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def combine(base_args, override_args):
|
| 27 |
+
assert not isinstance(base_args, dict)
|
| 28 |
+
assert not isinstance(override_args, dict)
|
| 29 |
+
|
| 30 |
+
arg_dict = base_args.__dict__
|
| 31 |
+
for key, value in override_args.__dict__.items():
|
| 32 |
+
if key not in arg_dict or arg_dict[key] is None: # parameter not provided previously
|
| 33 |
+
print(f"Add parameter {key}: {value}")
|
| 34 |
+
arg_dict[key] = value
|
| 35 |
+
elif isinstance(value, Namespace):
|
| 36 |
+
arg_dict[key] = combine(arg_dict[key], value)
|
| 37 |
+
else:
|
| 38 |
+
print(f"Replace parameter {key}: {arg_dict[key]} -> {value}")
|
| 39 |
+
arg_dict[key] = value
|
| 40 |
+
return base_args
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def path_to_str(input_dict):
|
| 44 |
+
for key, value in input_dict.items():
|
| 45 |
+
if isinstance(value, dict):
|
| 46 |
+
input_dict[key] = path_to_str(value)
|
| 47 |
+
else:
|
| 48 |
+
input_dict[key] = str(value) if isinstance(value, Path) else value
|
| 49 |
+
return input_dict
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def sample(cfg, model_params, samples_dir, job_id=0, n_jobs=1):
|
| 53 |
+
print('Sampling...')
|
| 54 |
+
model = DrugFlow.load_from_checkpoint(cfg.checkpoint, map_location=cfg.device, strict=False,
|
| 55 |
+
**model_params)
|
| 56 |
+
model.setup(stage='fit' if cfg.set == 'train' else cfg.set)
|
| 57 |
+
model.eval().to(cfg.device)
|
| 58 |
+
|
| 59 |
+
dataloader = getattr(model, f'{cfg.set}_dataloader')()
|
| 60 |
+
print(f'Real batch size is {dataloader.batch_size * cfg.n_samples}')
|
| 61 |
+
|
| 62 |
+
name2count = {}
|
| 63 |
+
for i, data in enumerate(tqdm(dataloader)):
|
| 64 |
+
if i % n_jobs != job_id:
|
| 65 |
+
print(f'Skipping batch {i}')
|
| 66 |
+
continue
|
| 67 |
+
|
| 68 |
+
new_data = {
|
| 69 |
+
'ligand': TensorDict(**data['ligand']).to(cfg.device),
|
| 70 |
+
'pocket': Residues(**data['pocket']).to(cfg.device),
|
| 71 |
+
}
|
| 72 |
+
try:
|
| 73 |
+
rdmols, rdpockets, names = model.sample(
|
| 74 |
+
data=new_data,
|
| 75 |
+
n_samples=cfg.n_samples,
|
| 76 |
+
num_nodes=("ground_truth" if cfg.sample_with_ground_truth_size else None)
|
| 77 |
+
)
|
| 78 |
+
except Exception as e:
|
| 79 |
+
if cfg.set == 'train':
|
| 80 |
+
names = data['ligand']['name']
|
| 81 |
+
print(f'Failed to sample for {names}: {e}')
|
| 82 |
+
continue
|
| 83 |
+
else:
|
| 84 |
+
raise e
|
| 85 |
+
|
| 86 |
+
for mol, pocket, name in zip(rdmols, rdpockets, names):
|
| 87 |
+
name = name.replace('.sdf', '')
|
| 88 |
+
idx = name2count.setdefault(name, 0)
|
| 89 |
+
output_dir = Path(samples_dir, name)
|
| 90 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 91 |
+
if cfg.postprocess:
|
| 92 |
+
mol = process_all(mol, largest_frag=True, adjust_aromatic_Ns=True, relax_iter=0)
|
| 93 |
+
|
| 94 |
+
for prop in mol.GetAtoms()[0].GetPropsAsDict().keys():
|
| 95 |
+
# compute avg uncertainty
|
| 96 |
+
mol.SetDoubleProp(prop, np.mean([a.GetDoubleProp(prop) for a in mol.GetAtoms()]))
|
| 97 |
+
|
| 98 |
+
# visualise local differences
|
| 99 |
+
out_pdb_path = Path(output_dir, f'{idx}_ligand_{prop}.pdb')
|
| 100 |
+
mol_as_pdb(mol, out_pdb_path, bfactor=prop)
|
| 101 |
+
|
| 102 |
+
out_sdf_path = Path(output_dir, f'{idx}_ligand.sdf')
|
| 103 |
+
out_pdb_path = Path(output_dir, f'{idx}_pocket.pdb')
|
| 104 |
+
utils.write_sdf_file(out_sdf_path, [mol])
|
| 105 |
+
mols_to_pdbfile([pocket], out_pdb_path)
|
| 106 |
+
|
| 107 |
+
name2count[name] += 1
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def evaluate(cfg, model_params, samples_dir):
|
| 111 |
+
print('Evaluation...')
|
| 112 |
+
data, table_detailed, table_aggregated = compute_all_metrics_drugflow(
|
| 113 |
+
in_dir=samples_dir,
|
| 114 |
+
gnina_path=model_params['train_params'].gnina,
|
| 115 |
+
reduce_path=cfg.reduce,
|
| 116 |
+
reference_smiles_path=Path(model_params['train_params'].datadir, 'train_smiles.npy'),
|
| 117 |
+
n_samples=cfg.n_samples,
|
| 118 |
+
exclude_evaluators=[] if cfg.exclude_evaluators is None else cfg.exclude_evaluators,
|
| 119 |
+
)
|
| 120 |
+
with open(Path(samples_dir, 'metrics_data.pkl'), 'wb') as f:
|
| 121 |
+
pickle.dump(data, f)
|
| 122 |
+
table_detailed.to_csv(Path(samples_dir, 'metrics_detailed.csv'), index=False)
|
| 123 |
+
table_aggregated.to_csv(Path(samples_dir, 'metrics_aggregated.csv'), index=False)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
if __name__ == "__main__":
|
| 127 |
+
p = argparse.ArgumentParser()
|
| 128 |
+
p.add_argument('--config', type=str)
|
| 129 |
+
p.add_argument('--job_id', type=int, default=0, help='Job ID')
|
| 130 |
+
p.add_argument('--n_jobs', type=int, default=1, help='Number of jobs')
|
| 131 |
+
args = p.parse_args()
|
| 132 |
+
|
| 133 |
+
with open(args.config, 'r') as f:
|
| 134 |
+
cfg = yaml.safe_load(f)
|
| 135 |
+
cfg = dict_to_namespace(cfg)
|
| 136 |
+
|
| 137 |
+
utils.set_deterministic(seed=cfg.seed)
|
| 138 |
+
utils.disable_rdkit_logging()
|
| 139 |
+
|
| 140 |
+
model_params = torch.load(cfg.checkpoint, map_location=cfg.device)['hyper_parameters']
|
| 141 |
+
if 'model_args' in cfg:
|
| 142 |
+
ckpt_args = dict_to_namespace(model_params)
|
| 143 |
+
model_params = combine(ckpt_args, cfg.model_args).__dict__
|
| 144 |
+
|
| 145 |
+
ckpt_path = Path(cfg.checkpoint)
|
| 146 |
+
ckpt_name = ckpt_path.parts[-1].split('.')[0]
|
| 147 |
+
n_steps = model_params['simulation_params'].n_steps
|
| 148 |
+
samples_dir = Path(cfg.sample_outdir, cfg.set, f'{ckpt_name}_T={n_steps}') or \
|
| 149 |
+
Path(ckpt_path.parent.parent, 'samples', cfg.set, f'{ckpt_name}_T={n_steps}')
|
| 150 |
+
assert cfg.set in {'val', 'test', 'train'}
|
| 151 |
+
samples_dir.mkdir(parents=True, exist_ok=True)
|
| 152 |
+
|
| 153 |
+
# save configs
|
| 154 |
+
with open(Path(samples_dir, 'model_params.yaml'), 'w') as f:
|
| 155 |
+
yaml.dump(path_to_str(namespace_to_dict(model_params)), f)
|
| 156 |
+
with open(Path(samples_dir, 'sampling_params.yaml'), 'w') as f:
|
| 157 |
+
yaml.dump(path_to_str(namespace_to_dict(cfg)), f)
|
| 158 |
+
|
| 159 |
+
if cfg.sample:
|
| 160 |
+
sample(cfg, model_params, samples_dir, job_id=args.job_id, n_jobs=args.n_jobs)
|
| 161 |
+
|
| 162 |
+
if cfg.evaluate:
|
| 163 |
+
assert args.job_id == 0 and args.n_jobs == 1, 'Evaluation is not parallelised on GPU machines'
|
| 164 |
+
evaluate(cfg, model_params, samples_dir)
|
src/sbdd_metrics/evaluation.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Collection, List, Dict, Type
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from .metrics import FullEvaluator, FullCollectionEvaluator
|
| 13 |
+
|
| 14 |
+
AUXILIARY_COLUMNS = ['sample', 'sdf_file', 'pdb_file', 'subdir']
|
| 15 |
+
VALIDITY_METRIC_NAME = 'medchem.valid'
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_data_type(key: str, data_types: Dict[str, Type], default=float) -> Type:
|
| 19 |
+
found_data_type_key = None
|
| 20 |
+
found_data_type_value = None
|
| 21 |
+
for data_type_key, data_type_value in data_types.items():
|
| 22 |
+
if re.match(data_type_key, key) is not None:
|
| 23 |
+
if found_data_type_key is not None:
|
| 24 |
+
raise ValueError(f'Multiple data type keys match [{key}]: {found_data_type_key}, {data_type_key}')
|
| 25 |
+
|
| 26 |
+
found_data_type_value = data_type_value
|
| 27 |
+
found_data_type_key = data_type_key
|
| 28 |
+
|
| 29 |
+
if found_data_type_key is None:
|
| 30 |
+
if default is None:
|
| 31 |
+
raise KeyError(key)
|
| 32 |
+
else:
|
| 33 |
+
found_data_type_value = default
|
| 34 |
+
|
| 35 |
+
return found_data_type_value
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def convert_data_to_table(data: List[Dict], data_types: Dict[str, Type]) -> pd.DataFrame:
|
| 39 |
+
"""
|
| 40 |
+
Converts data from `evaluate_drugflow` to a detailed table
|
| 41 |
+
"""
|
| 42 |
+
table = []
|
| 43 |
+
for entry in data:
|
| 44 |
+
table_entry = {}
|
| 45 |
+
for key, value in entry.items():
|
| 46 |
+
if key in AUXILIARY_COLUMNS:
|
| 47 |
+
table_entry[key] = value
|
| 48 |
+
continue
|
| 49 |
+
if get_data_type(key, data_types) != list:
|
| 50 |
+
table_entry[key] = value
|
| 51 |
+
table.append(table_entry)
|
| 52 |
+
|
| 53 |
+
return pd.DataFrame(table)
|
| 54 |
+
|
| 55 |
+
def aggregated_metrics(table: pd.DataFrame, data_types: Dict[str, Type], validity_metric_name: str = None):
|
| 56 |
+
"""
|
| 57 |
+
Args:
|
| 58 |
+
table (pd.DataFrame): table with metrics computed for each sample
|
| 59 |
+
data_types (Dict[str, Type]): dictionary with data types for each column
|
| 60 |
+
validity_metric_name (str): name of the column that has validity metric
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
agg_table (pd.DataFrame): table with columns ['metric', 'value', 'std']
|
| 64 |
+
"""
|
| 65 |
+
aggregated_results = []
|
| 66 |
+
|
| 67 |
+
# If validity column name is provided:
|
| 68 |
+
# 1. compute validity on the entire data
|
| 69 |
+
# 2. drop all invalid molecules to compute the rest
|
| 70 |
+
if validity_metric_name is not None:
|
| 71 |
+
aggregated_results.append({
|
| 72 |
+
'metric': validity_metric_name,
|
| 73 |
+
'value': table[validity_metric_name].fillna(False).astype(float).mean(),
|
| 74 |
+
'std': None,
|
| 75 |
+
})
|
| 76 |
+
table = table[table[validity_metric_name]]
|
| 77 |
+
|
| 78 |
+
# Compute aggregated metrics + standard deviations where applicable
|
| 79 |
+
for column in table.columns:
|
| 80 |
+
if column in AUXILIARY_COLUMNS + [validity_metric_name] or get_data_type(column, data_types) == str:
|
| 81 |
+
continue
|
| 82 |
+
with pd.option_context("future.no_silent_downcasting", True):
|
| 83 |
+
if get_data_type(column, data_types) == bool:
|
| 84 |
+
values = table[column].fillna(0).values.astype(float).mean()
|
| 85 |
+
std = None
|
| 86 |
+
else:
|
| 87 |
+
values = table[column].dropna().values.astype(float).mean()
|
| 88 |
+
std = table[column].dropna().values.astype(float).std()
|
| 89 |
+
|
| 90 |
+
aggregated_results.append({
|
| 91 |
+
'metric': column,
|
| 92 |
+
'value': values,
|
| 93 |
+
'std': std,
|
| 94 |
+
})
|
| 95 |
+
|
| 96 |
+
agg_table = pd.DataFrame(aggregated_results)
|
| 97 |
+
return agg_table
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def collection_metrics(
|
| 101 |
+
table: pd.DataFrame,
|
| 102 |
+
reference_smiles: Collection[str],
|
| 103 |
+
validity_metric_name: str = None,
|
| 104 |
+
exclude_evaluators: Collection[str] = [],
|
| 105 |
+
):
|
| 106 |
+
"""
|
| 107 |
+
Args:
|
| 108 |
+
table (pd.DataFrame): table with metrics computed for each sample
|
| 109 |
+
reference_smiles (Collection[str]): list of reference SMILES (e.g. training set)
|
| 110 |
+
validity_metric_name (str): name of the column that has validity metric
|
| 111 |
+
exclude_evaluators (Collection[str]): Evaluator IDs to exclude
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
col_table (pd.DataFrame): table with columns ['metric', 'value']
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
# If validity column name is provided drop all invalid molecules
|
| 118 |
+
if validity_metric_name is not None:
|
| 119 |
+
table = table[table[validity_metric_name]]
|
| 120 |
+
|
| 121 |
+
evaluator = FullCollectionEvaluator(reference_smiles, exclude_evaluators=exclude_evaluators)
|
| 122 |
+
smiles = table['representation.smiles'].values
|
| 123 |
+
if len(smiles) == 0:
|
| 124 |
+
print('No valid input molecules')
|
| 125 |
+
return pd.DataFrame(columns=['metric', 'value'])
|
| 126 |
+
|
| 127 |
+
collection_metrics = evaluator(smiles)
|
| 128 |
+
results = [
|
| 129 |
+
{'metric': key, 'value': value}
|
| 130 |
+
for key, value in collection_metrics.items()
|
| 131 |
+
]
|
| 132 |
+
|
| 133 |
+
col_table = pd.DataFrame(results)
|
| 134 |
+
return col_table
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def evaluate_drugflow_subdir(
|
| 138 |
+
in_dir: Path,
|
| 139 |
+
evaluator: FullEvaluator,
|
| 140 |
+
desc: str = None,
|
| 141 |
+
n_samples: int = None,
|
| 142 |
+
) -> List[Dict]:
|
| 143 |
+
"""
|
| 144 |
+
Computes per-molecule metrics for a single directory of samples for one target
|
| 145 |
+
"""
|
| 146 |
+
results = []
|
| 147 |
+
valid_files = [
|
| 148 |
+
int(fname.split('_')[0])
|
| 149 |
+
for fname in os.listdir(in_dir)
|
| 150 |
+
if fname.endswith('_ligand.sdf') and not fname.startswith('.')
|
| 151 |
+
]
|
| 152 |
+
if len(valid_files) == 0:
|
| 153 |
+
return pd.DataFrame()
|
| 154 |
+
|
| 155 |
+
upper_bound = max(valid_files) + 1
|
| 156 |
+
if n_samples is not None:
|
| 157 |
+
upper_bound = min(upper_bound, n_samples)
|
| 158 |
+
|
| 159 |
+
for i in tqdm(range(upper_bound), desc=desc, file=sys.stdout):
|
| 160 |
+
in_mol = Path(in_dir, f'{i}_ligand.sdf')
|
| 161 |
+
in_prot = Path(in_dir, f'{i}_pocket.pdb')
|
| 162 |
+
res = evaluator(in_mol, in_prot)
|
| 163 |
+
|
| 164 |
+
res['sample'] = i
|
| 165 |
+
res['sdf_file'] = str(in_mol)
|
| 166 |
+
res['pdb_file'] = str(in_prot)
|
| 167 |
+
results.append(res)
|
| 168 |
+
|
| 169 |
+
return results
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def evaluate_drugflow(
|
| 173 |
+
in_dir: Path,
|
| 174 |
+
evaluator: FullEvaluator,
|
| 175 |
+
n_samples: int = None,
|
| 176 |
+
job_id: int = 0,
|
| 177 |
+
n_jobs: int = 1,
|
| 178 |
+
) -> List[Dict]:
|
| 179 |
+
"""
|
| 180 |
+
1. Computes per-molecule metrics for all single directories of samples
|
| 181 |
+
2. Aggregates these metrics
|
| 182 |
+
3. Computes additional collection metrics (if `reference_smiles_path` is provided)
|
| 183 |
+
"""
|
| 184 |
+
data = []
|
| 185 |
+
total_number_of_subdirs = len([path for path in in_dir.glob("[!.]*") if os.path.isdir(path)])
|
| 186 |
+
i = 0
|
| 187 |
+
for subdir in in_dir.glob("[!.]*"):
|
| 188 |
+
if not os.path.isdir(subdir):
|
| 189 |
+
continue
|
| 190 |
+
|
| 191 |
+
i += 1
|
| 192 |
+
if (i - 1) % n_jobs != job_id:
|
| 193 |
+
continue
|
| 194 |
+
|
| 195 |
+
curr_data = evaluate_drugflow_subdir(
|
| 196 |
+
in_dir=subdir,
|
| 197 |
+
evaluator=evaluator,
|
| 198 |
+
desc=f'[{i}/{total_number_of_subdirs}] {str(subdir.name)}',
|
| 199 |
+
n_samples=n_samples,
|
| 200 |
+
)
|
| 201 |
+
for entry in curr_data:
|
| 202 |
+
entry['subdir'] = str(subdir)
|
| 203 |
+
data.append(entry)
|
| 204 |
+
|
| 205 |
+
return data
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def compute_all_metrics_drugflow(
|
| 209 |
+
in_dir: Path,
|
| 210 |
+
gnina_path: Path,
|
| 211 |
+
reduce_path: Path = None,
|
| 212 |
+
reference_smiles_path: Path = None,
|
| 213 |
+
n_samples: int = None,
|
| 214 |
+
validity_metric_name: str = VALIDITY_METRIC_NAME,
|
| 215 |
+
exclude_evaluators: Collection[str] = [],
|
| 216 |
+
job_id: int = 0,
|
| 217 |
+
n_jobs: int = 1,
|
| 218 |
+
):
|
| 219 |
+
evaluator = FullEvaluator(gnina=gnina_path, reduce=reduce_path, exclude_evaluators=exclude_evaluators)
|
| 220 |
+
data = evaluate_drugflow(in_dir=in_dir, evaluator=evaluator, n_samples=n_samples, job_id=job_id, n_jobs=n_jobs)
|
| 221 |
+
table_detailed = convert_data_to_table(data, evaluator.dtypes)
|
| 222 |
+
table_aggregated = aggregated_metrics(
|
| 223 |
+
table_detailed,
|
| 224 |
+
data_types=evaluator.dtypes,
|
| 225 |
+
validity_metric_name=validity_metric_name
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
# Add collection metrics (uniqueness, novelty, FCD, etc.) if reference smiles are provided
|
| 229 |
+
if reference_smiles_path is not None:
|
| 230 |
+
reference_smiles = np.load(reference_smiles_path)
|
| 231 |
+
col_metrics = collection_metrics(
|
| 232 |
+
table=table_detailed,
|
| 233 |
+
reference_smiles=reference_smiles,
|
| 234 |
+
validity_metric_name=validity_metric_name,
|
| 235 |
+
exclude_evaluators=exclude_evaluators
|
| 236 |
+
)
|
| 237 |
+
table_aggregated = pd.concat([table_aggregated, col_metrics])
|
| 238 |
+
|
| 239 |
+
return data, table_detailed, table_aggregated
|
src/sbdd_metrics/fpscores.pkl.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:10dcef9340c873e7b987924461b0af5365eb8dd96be607203debe8ddf80c1e73
|
| 3 |
+
size 3848394
|
src/sbdd_metrics/interactions.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import prody
|
| 2 |
+
import prolif as plf
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import subprocess
|
| 5 |
+
|
| 6 |
+
from io import StringIO
|
| 7 |
+
from prolif.fingerprint import Fingerprint
|
| 8 |
+
from prolif.plotting.complex3d import Complex3D
|
| 9 |
+
from prolif.residue import ResidueId
|
| 10 |
+
from prolif.ifp import IFP
|
| 11 |
+
from rdkit import Chem
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
prody.confProDy(verbosity='none')
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
INTERACTION_LIST = [
|
| 19 |
+
'Anionic', 'Cationic', # Salt Bridges ~400 kJ/mol
|
| 20 |
+
'HBAcceptor', 'HBDonor', # Hydrogen bonds ~10 kJ/mol
|
| 21 |
+
'XBAcceptor', 'XBDonor', # Halogen bonds ~5-30 kJ/mol
|
| 22 |
+
'CationPi', 'PiCation', # 5-10 kJ/mol
|
| 23 |
+
'PiStacking', # ~2-10 kJ/mol
|
| 24 |
+
'Hydrophobic', # 1-10 kJ/mol
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
INTERACTION_ALIASES = {
|
| 28 |
+
'Anionic': 'SaltBridge',
|
| 29 |
+
'Cationic': 'SaltBridge',
|
| 30 |
+
'HBAcceptor': 'HBAcceptor',
|
| 31 |
+
'HBDonor': 'HBDonor',
|
| 32 |
+
'XBAcceptor': 'HalogenBond',
|
| 33 |
+
'XBDonor': 'HalogenBond',
|
| 34 |
+
'CationPi': 'CationPi',
|
| 35 |
+
'PiCation': 'PiCation',
|
| 36 |
+
'PiStacking': 'PiStacking',
|
| 37 |
+
'Hydrophobic': 'Hydrophobic',
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
INTERACTION_COLORS = {
|
| 41 |
+
'SaltBridge': '#eba823',
|
| 42 |
+
'HBDonor': '#3d5dfc',
|
| 43 |
+
'HBAcceptor': '#3d5dfc',
|
| 44 |
+
'HalogenBond': '#53f514',
|
| 45 |
+
'CationPi': '#ff0000',
|
| 46 |
+
'PiCation': '#ff0000',
|
| 47 |
+
'PiStacking': '#e359d8',
|
| 48 |
+
'Hydrophobic': '#c9c5c5',
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
INTERACTION_IMPORTANCE = ['SaltBridge', 'HydrogenBond', 'HBAcceptor', 'HBDonor', 'CationPi', 'PiCation', 'PiStacking', 'Hydrophobic']
|
| 52 |
+
|
| 53 |
+
REDUCE_EXEC = './reduce'
|
| 54 |
+
|
| 55 |
+
def remove_residue_by_atomic_number(structure, resnum, chain_id, icode):
|
| 56 |
+
exclude_selection = f'not (chain {chain_id} and resnum {resnum} and icode {icode})'
|
| 57 |
+
structure = structure.select(exclude_selection)
|
| 58 |
+
return structure
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def read_protein(protein_path, verbose=False, reduce_exec=REDUCE_EXEC):
|
| 62 |
+
structure = prody.parsePDB(protein_path).select('protein')
|
| 63 |
+
hydrogens = structure.select('hydrogen')
|
| 64 |
+
if hydrogens is None or len(hydrogens) < len(set(structure.getResnums())):
|
| 65 |
+
if verbose:
|
| 66 |
+
print('Target structure is not protonated. Adding hydrogens...')
|
| 67 |
+
|
| 68 |
+
reduce_cmd = f'{str(reduce_exec)} {protein_path}'
|
| 69 |
+
reduce_result = subprocess.run(reduce_cmd, shell=True, capture_output=True, text=True)
|
| 70 |
+
if reduce_result.returncode != 0:
|
| 71 |
+
raise RuntimeError('Error during reduce execution:', reduce_result.stderr)
|
| 72 |
+
|
| 73 |
+
pdb_content = reduce_result.stdout
|
| 74 |
+
stream = StringIO()
|
| 75 |
+
stream.write(pdb_content)
|
| 76 |
+
stream.seek(0)
|
| 77 |
+
structure = prody.parsePDBStream(stream).select('protein')
|
| 78 |
+
|
| 79 |
+
# Select only one (largest) altloc
|
| 80 |
+
altlocs = set(structure.getAltlocs())
|
| 81 |
+
try:
|
| 82 |
+
best_altloc = max(altlocs, key=lambda a: structure.select(f'altloc "{a}"').numAtoms())
|
| 83 |
+
structure = structure.select(f'altloc "{best_altloc}"')
|
| 84 |
+
except TypeError:
|
| 85 |
+
# Strange thing that happens only once in the beginning sometimes...
|
| 86 |
+
best_altloc = max(altlocs, key=lambda a: structure.select(f'altloc "{a}"').numAtoms())
|
| 87 |
+
structure = structure.select(f'altloc "{best_altloc}"')
|
| 88 |
+
|
| 89 |
+
return prepare_protein(structure, to_exclude=[], verbose=verbose)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def prepare_protein(input_structure, to_exclude=[], verbose=False):
|
| 93 |
+
structure = input_structure.copy()
|
| 94 |
+
|
| 95 |
+
# Remove residues with bad atoms
|
| 96 |
+
if verbose and len(to_exclude) > 0:
|
| 97 |
+
print(f'Removing {len(to_exclude)} residues...')
|
| 98 |
+
for resnum, chain_id, icode in to_exclude:
|
| 99 |
+
exclude_selection = f'not (chain {chain_id} and resnum {resnum})'
|
| 100 |
+
structure = structure.select(exclude_selection)
|
| 101 |
+
|
| 102 |
+
# Write new PDB content to the stream
|
| 103 |
+
stream = StringIO()
|
| 104 |
+
prody.writePDBStream(stream, structure)
|
| 105 |
+
stream.seek(0)
|
| 106 |
+
|
| 107 |
+
# Sanitize
|
| 108 |
+
rdprot = Chem.MolFromPDBBlock(stream.read(), sanitize=False, removeHs=False)
|
| 109 |
+
try:
|
| 110 |
+
Chem.SanitizeMol(rdprot)
|
| 111 |
+
plfprot = plf.Molecule(rdprot)
|
| 112 |
+
return plfprot
|
| 113 |
+
|
| 114 |
+
except Chem.AtomValenceException as e:
|
| 115 |
+
atom_num = int(e.args[0].replace('Explicit valence for atom # ', '').split()[0])
|
| 116 |
+
info = rdprot.GetAtomWithIdx(atom_num).GetPDBResidueInfo()
|
| 117 |
+
resnum = info.GetResidueNumber()
|
| 118 |
+
chain_id = info.GetChainId()
|
| 119 |
+
icode = f'"{info.GetInsertionCode()}"'
|
| 120 |
+
|
| 121 |
+
to_exclude_next = to_exclude + [(resnum, chain_id, icode)]
|
| 122 |
+
if verbose:
|
| 123 |
+
print(f'[{len(to_exclude_next)}] Removing broken residue with atom={atom_num}, resnum={resnum}, chain_id={chain_id}, icode={icode}')
|
| 124 |
+
return prepare_protein(input_structure, to_exclude=to_exclude_next)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def prepare_ligand(mol):
|
| 128 |
+
Chem.SanitizeMol(mol)
|
| 129 |
+
mol = Chem.AddHs(mol, addCoords=True)
|
| 130 |
+
ligand_plf = plf.Molecule.from_rdkit(mol)
|
| 131 |
+
return ligand_plf
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def sdf_reader(sdf_path, proress_bar=False):
|
| 135 |
+
supp = Chem.SDMolSupplier(sdf_path, removeHs=True, sanitize=False)
|
| 136 |
+
for mol in tqdm(supp) if progress_bar else supp:
|
| 137 |
+
yield prepare_ligand(mol)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def profile_detailed(
|
| 141 |
+
ligand_plf, protein_plf, interaction_list=INTERACTION_LIST, ligand_name='ligand', protein_name='protein'
|
| 142 |
+
):
|
| 143 |
+
|
| 144 |
+
fp = Fingerprint(interactions=interaction_list)
|
| 145 |
+
fp.run_from_iterable(lig_iterable=[ligand_plf], prot_mol=protein_plf, progress=False)
|
| 146 |
+
|
| 147 |
+
profile = []
|
| 148 |
+
|
| 149 |
+
for ligand_residue in ligand_plf.residues:
|
| 150 |
+
for protein_residue in protein_plf.residues:
|
| 151 |
+
metadata = fp.metadata(ligand_plf[ligand_residue], protein_plf[protein_residue])
|
| 152 |
+
for int_name, int_metadata in metadata.items():
|
| 153 |
+
for int_instance in int_metadata:
|
| 154 |
+
profile.append({
|
| 155 |
+
'ligand': ligand_name,
|
| 156 |
+
'protein': protein_name,
|
| 157 |
+
'ligand_residue': str(ligand_residue),
|
| 158 |
+
'protein_residue': str(protein_residue),
|
| 159 |
+
'interaction': int_name,
|
| 160 |
+
'alias': INTERACTION_ALIASES[int_name],
|
| 161 |
+
'ligand_atoms': ','.join(map(str, int_instance['indices']['ligand'])),
|
| 162 |
+
'protein_atoms': ','.join(map(str, int_instance['indices']['protein'])),
|
| 163 |
+
'ligand_orig_atoms': ','.join(map(str, int_instance['parent_indices']['ligand'])),
|
| 164 |
+
'protein_orig_atoms': ','.join(map(str, int_instance['parent_indices']['protein'])),
|
| 165 |
+
'distance': int_instance['distance'],
|
| 166 |
+
'plane_angle': int_instance.get('plane_angle', None),
|
| 167 |
+
'normal_to_centroid_angle': int_instance.get('normal_to_centroid_angle', None),
|
| 168 |
+
'intersect_distance': int_instance.get('intersect_distance', None),
|
| 169 |
+
'intersect_radius': int_instance.get('intersect_radius', None),
|
| 170 |
+
'pi_ring': int_instance.get('pi_ring', None),
|
| 171 |
+
})
|
| 172 |
+
|
| 173 |
+
return pd.DataFrame(profile)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def map_orig_atoms_to_new(atoms, mol):
|
| 177 |
+
orig2new = dict()
|
| 178 |
+
for atom in mol.GetAtoms():
|
| 179 |
+
orig2new[atom.GetUnsignedProp("mapindex")] = atom.GetIdx()
|
| 180 |
+
|
| 181 |
+
atoms = list(map(int, atoms.split(',')))
|
| 182 |
+
new_atoms = ','.join(map(str, [orig2new[atom] for atom in atoms]))
|
| 183 |
+
return new_atoms
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def visualize(profile, ligand_plf, protein_plf):
|
| 187 |
+
metadata = dict()
|
| 188 |
+
|
| 189 |
+
for _, row in profile.iterrows():
|
| 190 |
+
if 'ligand_atoms' not in row:
|
| 191 |
+
row['ligand_atoms'] = map_orig_atoms_to_new(row['ligand_orig_atoms'], ligand_plf)
|
| 192 |
+
if 'protein_atoms' not in row:
|
| 193 |
+
row['protein_atoms'] = map_orig_atoms_to_new(row['protein_orig_atoms'], protein_plf[row['residue']])
|
| 194 |
+
|
| 195 |
+
namenum, chain = row['residue'].split('.')
|
| 196 |
+
name = namenum[:3]
|
| 197 |
+
num = int(namenum[3:])
|
| 198 |
+
protres = ResidueId(name=name, number=num, chain=chain)
|
| 199 |
+
key = (ResidueId(name='UNL', number=1, chain=None), protres)
|
| 200 |
+
|
| 201 |
+
metadata.setdefault(key, dict())
|
| 202 |
+
interaction = {
|
| 203 |
+
'indices': {
|
| 204 |
+
'ligand': tuple(map(int, row['ligand_atoms'].split(','))),
|
| 205 |
+
'protein': tuple(map(int, row['protein_atoms'].split(','))),
|
| 206 |
+
},
|
| 207 |
+
'parent_indices': {
|
| 208 |
+
'ligand': tuple(map(int, row['ligand_atoms'].split(','))),
|
| 209 |
+
'protein': tuple(map(int, row['protein_atoms'].split(','))),
|
| 210 |
+
},
|
| 211 |
+
'distance': row['distance'],
|
| 212 |
+
}
|
| 213 |
+
# if row['plane_angle'] is not None:
|
| 214 |
+
# interaction['plane_angle'] = row['plane_angle']
|
| 215 |
+
# if row['normal_to_centroid_angle'] is not None:
|
| 216 |
+
# interaction['normal_to_centroid_angle'] = row['normal_to_centroid_angle']
|
| 217 |
+
# if row['intersect_distance'] is not None:
|
| 218 |
+
# interaction['intersect_distance'] = row['intersect_distance']
|
| 219 |
+
# if row['intersect_radius'] is not None:
|
| 220 |
+
# interaction['intersect_radius'] = row['intersect_radius']
|
| 221 |
+
# if row['pi_ring'] is not None:
|
| 222 |
+
# interaction['pi_ring'] = row['pi_ring']
|
| 223 |
+
|
| 224 |
+
metadata[key].setdefault(row['alias'], list()).append(interaction)
|
| 225 |
+
|
| 226 |
+
ifp = IFP(metadata)
|
| 227 |
+
fp = Fingerprint(interactions=INTERACTION_LIST, vicinity_cutoff=8.0)
|
| 228 |
+
fp.ifp = {0: ifp}
|
| 229 |
+
Complex3D.COLORS.update(INTERACTION_COLORS)
|
| 230 |
+
v = fp.plot_3d(ligand_mol=ligand_plf, protein_mol=protein_plf, frame=0)
|
| 231 |
+
return v
|
src/sbdd_metrics/metrics.py
ADDED
|
@@ -0,0 +1,929 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import multiprocessing
|
| 2 |
+
import subprocess
|
| 3 |
+
import tempfile
|
| 4 |
+
from abc import abstractmethod
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Union, Dict, Collection, Set, Optional
|
| 8 |
+
import signal
|
| 9 |
+
import numpy as np
|
| 10 |
+
import pandas as pd
|
| 11 |
+
from unittest.mock import patch
|
| 12 |
+
from scipy.spatial.distance import jensenshannon
|
| 13 |
+
from fcd import get_fcd
|
| 14 |
+
from posebusters import PoseBusters
|
| 15 |
+
from posebusters.modules.distance_geometry import _get_bond_atom_indices, _get_angle_atom_indices
|
| 16 |
+
from rdkit import Chem, RDLogger
|
| 17 |
+
from rdkit.Chem import Descriptors, Crippen, Lipinski, QED, KekulizeException, AtomKekulizeException
|
| 18 |
+
from rdkit.Chem.rdForceFieldHelpers import UFFGetMoleculeForceField
|
| 19 |
+
from scipy.spatial.distance import jensenshannon
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
from useful_rdkit_utils import REOS, RingSystemLookup, get_min_ring_frequency, RingSystemFinder
|
| 22 |
+
|
| 23 |
+
from .interactions import INTERACTION_LIST, prepare_ligand, read_protein, profile_detailed
|
| 24 |
+
from .sascorer import calculateScore
|
| 25 |
+
|
| 26 |
+
def timeout_handler(signum, frame):
|
| 27 |
+
raise TimeoutError('Timeout')
|
| 28 |
+
|
| 29 |
+
BOND_SYMBOLS = {
|
| 30 |
+
Chem.rdchem.BondType.SINGLE: '-',
|
| 31 |
+
Chem.rdchem.BondType.DOUBLE: '=',
|
| 32 |
+
Chem.rdchem.BondType.TRIPLE: '#',
|
| 33 |
+
Chem.rdchem.BondType.AROMATIC: ':',
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def is_nan(value):
|
| 38 |
+
return value is None or pd.isna(value) or np.isnan(value)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def safe_run(func, timeout, **kwargs):
|
| 42 |
+
def _run(f, q, **kwargs):
|
| 43 |
+
r = f(**kwargs)
|
| 44 |
+
q.put(r)
|
| 45 |
+
|
| 46 |
+
queue = multiprocessing.Queue()
|
| 47 |
+
process = multiprocessing.Process(target=_run, kwargs={'f': func, 'q': queue, **kwargs})
|
| 48 |
+
process.start()
|
| 49 |
+
process.join(timeout)
|
| 50 |
+
if process.is_alive():
|
| 51 |
+
print(f"Function {func} didn't finish in {timeout} seconds. Terminating it.")
|
| 52 |
+
process.terminate()
|
| 53 |
+
process.join()
|
| 54 |
+
return None
|
| 55 |
+
elif not queue.empty():
|
| 56 |
+
return queue.get()
|
| 57 |
+
return None
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class AbstractEvaluator:
|
| 61 |
+
ID = None
|
| 62 |
+
def __call__(self, molecule: Union[str, Path, Chem.Mol], protein: Union[str, Path] = None,
|
| 63 |
+
timeout=350):
|
| 64 |
+
"""
|
| 65 |
+
Args:
|
| 66 |
+
molecule (Union[str, Path, Chem.Mol]): input molecule
|
| 67 |
+
protein (str): target protein
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
metrics (dict): dictionary of metrics
|
| 71 |
+
"""
|
| 72 |
+
RDLogger.DisableLog('rdApp.*')
|
| 73 |
+
self.check_format(molecule, protein)
|
| 74 |
+
|
| 75 |
+
# timeout handler
|
| 76 |
+
signal.signal(signal.SIGALRM, timeout_handler)
|
| 77 |
+
try:
|
| 78 |
+
signal.alarm(timeout)
|
| 79 |
+
results = self.evaluate(molecule, protein)
|
| 80 |
+
except TimeoutError:
|
| 81 |
+
print(f'Error when evaluating [{self.ID}]: Timeout after {timeout} seconds')
|
| 82 |
+
signal.alarm(0)
|
| 83 |
+
return {}
|
| 84 |
+
except Exception as e:
|
| 85 |
+
print(f'Error when evaluating [{self.ID}]: {e}')
|
| 86 |
+
signal.alarm(0)
|
| 87 |
+
return {}
|
| 88 |
+
finally:
|
| 89 |
+
signal.alarm(0)
|
| 90 |
+
return self.add_id(results)
|
| 91 |
+
|
| 92 |
+
def add_id(self, results):
|
| 93 |
+
if self.ID is not None:
|
| 94 |
+
return {f'{self.ID}.{key}': value for key, value in results.items()}
|
| 95 |
+
else:
|
| 96 |
+
return results
|
| 97 |
+
|
| 98 |
+
@abstractmethod
|
| 99 |
+
def evaluate(self, molecule: Union[str, Path, Chem.Mol], protein: Union[str, Path]) -> Dict[str, Union[int, float, str]]:
|
| 100 |
+
raise NotImplementedError
|
| 101 |
+
|
| 102 |
+
@staticmethod
|
| 103 |
+
def check_format(molecule, protein):
|
| 104 |
+
assert isinstance(molecule, (str, Path, Chem.Mol)), 'Supported molecule types: str, Path, Chem.Mol'
|
| 105 |
+
assert protein is None or isinstance(protein, (str, Path)), 'Supported protein types: str'
|
| 106 |
+
if isinstance(molecule, (str, Path)):
|
| 107 |
+
supp = Chem.SDMolSupplier(str(molecule), sanitize=False)
|
| 108 |
+
assert len(supp) == 1, 'Only one molecule per file is supported'
|
| 109 |
+
|
| 110 |
+
@staticmethod
|
| 111 |
+
def load_molecule(molecule):
|
| 112 |
+
if isinstance(molecule, (str, Path)):
|
| 113 |
+
return Chem.SDMolSupplier(str(molecule), sanitize=False)[0]
|
| 114 |
+
return Chem.Mol(molecule) # create copy to avoid overriding properties of the input molecule
|
| 115 |
+
|
| 116 |
+
@staticmethod
|
| 117 |
+
def save_molecule(molecule, sdf_path):
|
| 118 |
+
if isinstance(molecule, (str, Path)):
|
| 119 |
+
return molecule
|
| 120 |
+
|
| 121 |
+
with Chem.SDWriter(str(sdf_path)) as w:
|
| 122 |
+
try:
|
| 123 |
+
w.write(molecule)
|
| 124 |
+
except (RuntimeError, ValueError) as e:
|
| 125 |
+
if isinstance(e, (KekulizeException, AtomKekulizeException)):
|
| 126 |
+
w.SetKekulize(False)
|
| 127 |
+
w.write(molecule)
|
| 128 |
+
w.SetKekulize(True)
|
| 129 |
+
else:
|
| 130 |
+
w.write(Chem.Mol())
|
| 131 |
+
print('[AbstractEvaluator] Error when saving the molecule')
|
| 132 |
+
|
| 133 |
+
return sdf_path
|
| 134 |
+
|
| 135 |
+
@property
|
| 136 |
+
def dtypes(self):
|
| 137 |
+
return self.add_id(self._dtypes)
|
| 138 |
+
|
| 139 |
+
@property
|
| 140 |
+
@abstractmethod
|
| 141 |
+
def _dtypes(self):
|
| 142 |
+
raise NotImplementedError
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class RepresentationEvaluator(AbstractEvaluator):
|
| 146 |
+
ID = 'representation'
|
| 147 |
+
|
| 148 |
+
def evaluate(self, molecule, protein=None):
|
| 149 |
+
molecule = self.load_molecule(molecule)
|
| 150 |
+
try:
|
| 151 |
+
smiles = Chem.MolToSmiles(molecule)
|
| 152 |
+
except:
|
| 153 |
+
smiles = None
|
| 154 |
+
|
| 155 |
+
return {'smiles': smiles}
|
| 156 |
+
|
| 157 |
+
@property
|
| 158 |
+
def _dtypes(self):
|
| 159 |
+
return {'smiles': str}
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class MolPropertyEvaluator(AbstractEvaluator):
|
| 163 |
+
ID = 'mol_props'
|
| 164 |
+
|
| 165 |
+
def evaluate(self, molecule, protein=None):
|
| 166 |
+
molecule = self.load_molecule(molecule)
|
| 167 |
+
return {k: v for k, v in molecule.GetPropsAsDict().items() if isinstance(v, float)}
|
| 168 |
+
|
| 169 |
+
@property
|
| 170 |
+
def _dtypes(self):
|
| 171 |
+
return {'*': float}
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class PoseBustersEvaluator(AbstractEvaluator):
|
| 175 |
+
ID = 'posebusters'
|
| 176 |
+
def __init__(self, pb_conf: str = 'dock'):
|
| 177 |
+
self.posebusters = PoseBusters(config=pb_conf)
|
| 178 |
+
|
| 179 |
+
@patch('rdkit.RDLogger.EnableLog', lambda x: None)
|
| 180 |
+
@patch('rdkit.RDLogger.DisableLog', lambda x: None)
|
| 181 |
+
def evaluate(self, molecule, protein=None):
|
| 182 |
+
result = safe_run(self.posebusters.bust, timeout=20, mol_pred=molecule, mol_cond=protein)
|
| 183 |
+
if result is None:
|
| 184 |
+
return dict()
|
| 185 |
+
|
| 186 |
+
with pd.option_context("future.no_silent_downcasting", True):
|
| 187 |
+
result = dict(result.fillna(False).iloc[0])
|
| 188 |
+
result['all'] = all([bool(value) if not is_nan(value) else False for value in result.values()])
|
| 189 |
+
return result
|
| 190 |
+
|
| 191 |
+
@property
|
| 192 |
+
def _dtypes(self):
|
| 193 |
+
return {'*': bool}
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class GeometryEvaluator(AbstractEvaluator):
|
| 197 |
+
ID = 'geometry'
|
| 198 |
+
|
| 199 |
+
def evaluate(self, molecule, protein=None):
|
| 200 |
+
mol = self.load_molecule(molecule)
|
| 201 |
+
data = self.get_distances_and_angles(mol)
|
| 202 |
+
return data
|
| 203 |
+
|
| 204 |
+
@staticmethod
|
| 205 |
+
def angle_repr(mol, triplet):
|
| 206 |
+
i = mol.GetAtomWithIdx(triplet[0]).GetSymbol()
|
| 207 |
+
j = mol.GetAtomWithIdx(triplet[1]).GetSymbol()
|
| 208 |
+
k = mol.GetAtomWithIdx(triplet[2]).GetSymbol()
|
| 209 |
+
ij = BOND_SYMBOLS[mol.GetBondBetweenAtoms(triplet[0], triplet[1]).GetBondType()]
|
| 210 |
+
jk = BOND_SYMBOLS[mol.GetBondBetweenAtoms(triplet[1], triplet[2]).GetBondType()]
|
| 211 |
+
|
| 212 |
+
# Unified (sorted) representation
|
| 213 |
+
if i < k:
|
| 214 |
+
return f'{i}{ij}{j}{jk}{k}'
|
| 215 |
+
elif i > j:
|
| 216 |
+
return f'{k}{jk}{j}{ij}{i}'
|
| 217 |
+
elif ij <= jk:
|
| 218 |
+
return f'{i}{ij}{j}{jk}{k}'
|
| 219 |
+
else:
|
| 220 |
+
return f'{k}{jk}{j}{ij}{i}'
|
| 221 |
+
|
| 222 |
+
@staticmethod
|
| 223 |
+
def bond_repr(mol, pair):
|
| 224 |
+
i = mol.GetAtomWithIdx(pair[0]).GetSymbol()
|
| 225 |
+
j = mol.GetAtomWithIdx(pair[1]).GetSymbol()
|
| 226 |
+
ij = BOND_SYMBOLS[mol.GetBondBetweenAtoms(pair[0], pair[1]).GetBondType()]
|
| 227 |
+
# Unified (sorted) representation
|
| 228 |
+
return f'{i}{ij}{j}' if i <= j else f'{j}{ij}{i}'
|
| 229 |
+
|
| 230 |
+
@staticmethod
|
| 231 |
+
def get_bond_distances(mol, bonds):
|
| 232 |
+
i, j = np.array(bonds).T
|
| 233 |
+
x = mol.GetConformer().GetPositions()
|
| 234 |
+
xi = x[i]
|
| 235 |
+
xj = x[j]
|
| 236 |
+
bond_distances = np.linalg.norm(xi - xj, axis=1)
|
| 237 |
+
return bond_distances
|
| 238 |
+
|
| 239 |
+
@staticmethod
|
| 240 |
+
def get_angle_values(mol, triplets):
|
| 241 |
+
i, j, k = np.array(triplets).T
|
| 242 |
+
x = mol.GetConformer().GetPositions()
|
| 243 |
+
xi = x[i]
|
| 244 |
+
xj = x[j]
|
| 245 |
+
xk = x[k]
|
| 246 |
+
vji = xi - xj
|
| 247 |
+
vjk = xk - xj
|
| 248 |
+
angles = np.arccos((vji * vjk).sum(axis=1) / (np.linalg.norm(vji, axis=1) * np.linalg.norm(vjk, axis=1)))
|
| 249 |
+
return np.degrees(angles)
|
| 250 |
+
|
| 251 |
+
@staticmethod
|
| 252 |
+
def get_distances_and_angles(mol):
|
| 253 |
+
data = defaultdict(list)
|
| 254 |
+
bonds = _get_bond_atom_indices(mol)
|
| 255 |
+
distances = GeometryEvaluator.get_bond_distances(mol, bonds)
|
| 256 |
+
for b, d in zip(bonds, distances):
|
| 257 |
+
data[GeometryEvaluator.bond_repr(mol, b)].append(d)
|
| 258 |
+
|
| 259 |
+
triplets = _get_angle_atom_indices(bonds)
|
| 260 |
+
angles = GeometryEvaluator.get_angle_values(mol, triplets)
|
| 261 |
+
for t, a in zip(triplets, angles):
|
| 262 |
+
data[GeometryEvaluator.angle_repr(mol, t)].append(a)
|
| 263 |
+
|
| 264 |
+
return data
|
| 265 |
+
|
| 266 |
+
@property
|
| 267 |
+
def _dtypes(self):
|
| 268 |
+
return {'*': list}
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class EnergyEvaluator(AbstractEvaluator):
|
| 272 |
+
ID = 'energy'
|
| 273 |
+
|
| 274 |
+
def evaluate(self, molecule, protein=None):
|
| 275 |
+
molecule = self.load_molecule(molecule)
|
| 276 |
+
try:
|
| 277 |
+
energy = self.get_energy(molecule)
|
| 278 |
+
except:
|
| 279 |
+
energy = None
|
| 280 |
+
return {'energy': energy}
|
| 281 |
+
|
| 282 |
+
@staticmethod
|
| 283 |
+
def get_energy(mol, conf_id=-1):
|
| 284 |
+
mol = Chem.AddHs(mol, addCoords=True)
|
| 285 |
+
uff = UFFGetMoleculeForceField(mol, confId=conf_id)
|
| 286 |
+
e_uff = uff.CalcEnergy()
|
| 287 |
+
return e_uff
|
| 288 |
+
|
| 289 |
+
@property
|
| 290 |
+
def _dtypes(self):
|
| 291 |
+
return {'energy': float}
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class InteractionsEvaluator(AbstractEvaluator):
|
| 295 |
+
ID = 'interactions'
|
| 296 |
+
|
| 297 |
+
def __init__(self, reduce='./reduce'):
|
| 298 |
+
self.reduce = reduce
|
| 299 |
+
|
| 300 |
+
@property
|
| 301 |
+
def default_profile(self):
|
| 302 |
+
return {i: 0 for i in INTERACTION_LIST}
|
| 303 |
+
|
| 304 |
+
def evaluate(self, molecule, protein=None):
|
| 305 |
+
molecule = self.load_molecule(molecule)
|
| 306 |
+
profile = self.default_profile
|
| 307 |
+
try:
|
| 308 |
+
ligand_plf = prepare_ligand(molecule)
|
| 309 |
+
protein_plf = read_protein(str(protein), reduce_exec=self.reduce)
|
| 310 |
+
interactions = profile_detailed(ligand_plf, protein_plf)
|
| 311 |
+
if not interactions.empty:
|
| 312 |
+
profile.update(dict(interactions.interaction.value_counts()))
|
| 313 |
+
except Exception:
|
| 314 |
+
pass
|
| 315 |
+
return profile
|
| 316 |
+
|
| 317 |
+
@property
|
| 318 |
+
def _dtypes(self):
|
| 319 |
+
return {'*': int}
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class GninaEvalulator(AbstractEvaluator):
|
| 323 |
+
ID = 'gnina'
|
| 324 |
+
def __init__(self, gnina):
|
| 325 |
+
self.gnina = gnina
|
| 326 |
+
|
| 327 |
+
def evaluate(self, molecule, protein=None):
|
| 328 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 329 |
+
molecule = self.save_molecule(molecule, sdf_path=Path(tmpdir, 'molecule.sdf'))
|
| 330 |
+
gnina_cmd = f'{self.gnina} -r {str(protein)} -l {str(molecule)} --minimize --seed 42 --no_gpu'
|
| 331 |
+
gnina_result = subprocess.run(gnina_cmd, shell=True, capture_output=True, text=True)
|
| 332 |
+
n_atoms = self.load_molecule(molecule).GetNumAtoms()
|
| 333 |
+
|
| 334 |
+
gnina_scores = self.read_gnina_results(gnina_result)
|
| 335 |
+
|
| 336 |
+
# Additionally computing ligand efficiency
|
| 337 |
+
gnina_scores['vina_efficiency'] = gnina_scores['vina_score'] / n_atoms if n_atoms > 0 else None
|
| 338 |
+
gnina_scores['gnina_efficiency'] = gnina_scores['gnina_score'] / n_atoms if n_atoms > 0 else None
|
| 339 |
+
return gnina_scores
|
| 340 |
+
|
| 341 |
+
@staticmethod
|
| 342 |
+
def read_gnina_results(gnina_result):
|
| 343 |
+
res = {
|
| 344 |
+
'vina_score': None,
|
| 345 |
+
'gnina_score': None,
|
| 346 |
+
'minimisation_rmsd': None,
|
| 347 |
+
'cnn_score': None,
|
| 348 |
+
}
|
| 349 |
+
if gnina_result.returncode != 0:
|
| 350 |
+
print(gnina_result.stderr)
|
| 351 |
+
return res
|
| 352 |
+
|
| 353 |
+
for line in gnina_result.stdout.split('\n'):
|
| 354 |
+
if line.startswith('Affinity'):
|
| 355 |
+
res['vina_score'] = float(line.split(' ')[1].strip())
|
| 356 |
+
if line.startswith('CNNaffinity'):
|
| 357 |
+
res['gnina_score'] = float(line.split(' ')[1].strip())
|
| 358 |
+
if line.startswith('CNNscore'):
|
| 359 |
+
res['cnn_score'] = float(line.split(' ')[1].strip())
|
| 360 |
+
if line.startswith('RMSD'):
|
| 361 |
+
res['minimisation_rmsd'] = float(line.split(' ')[1].strip())
|
| 362 |
+
|
| 363 |
+
return res
|
| 364 |
+
|
| 365 |
+
@property
|
| 366 |
+
def _dtypes(self):
|
| 367 |
+
return {'*': float}
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
class MedChemEvaluator(AbstractEvaluator):
|
| 371 |
+
ID = 'medchem'
|
| 372 |
+
def __init__(self, connectivity_threshold=1.0):
|
| 373 |
+
self.connectivity_threshold = connectivity_threshold
|
| 374 |
+
|
| 375 |
+
def evaluate(self, molecule, protein=None):
|
| 376 |
+
molecule = self.load_molecule(molecule)
|
| 377 |
+
valid = self.is_valid(molecule)
|
| 378 |
+
|
| 379 |
+
if valid:
|
| 380 |
+
Chem.SanitizeMol(molecule)
|
| 381 |
+
|
| 382 |
+
connected = None if not valid else self.is_connected(molecule)
|
| 383 |
+
qed = None if not valid else self.calculate_qed(molecule)
|
| 384 |
+
sa = None if not valid else self.calculate_sa(molecule)
|
| 385 |
+
logp = None if not valid else self.calculate_logp(molecule)
|
| 386 |
+
lipinski = None if not valid else self.calculate_lipinski(molecule)
|
| 387 |
+
n_rotatable_bonds = None if not valid else self.calculate_rotatable_bonds(molecule)
|
| 388 |
+
size = self.calculate_molecule_size(molecule)
|
| 389 |
+
|
| 390 |
+
return {
|
| 391 |
+
'valid': valid,
|
| 392 |
+
'connected': connected,
|
| 393 |
+
'qed': qed,
|
| 394 |
+
'sa': sa,
|
| 395 |
+
'logp': logp,
|
| 396 |
+
'lipinski': lipinski,
|
| 397 |
+
'size': size,
|
| 398 |
+
'n_rotatable_bonds': n_rotatable_bonds,
|
| 399 |
+
}
|
| 400 |
+
|
| 401 |
+
@staticmethod
|
| 402 |
+
def is_valid(rdmol):
|
| 403 |
+
if rdmol.GetNumAtoms() < 1:
|
| 404 |
+
return False
|
| 405 |
+
|
| 406 |
+
_mol = Chem.Mol(rdmol)
|
| 407 |
+
try:
|
| 408 |
+
Chem.SanitizeMol(_mol)
|
| 409 |
+
except ValueError:
|
| 410 |
+
return False
|
| 411 |
+
|
| 412 |
+
return True
|
| 413 |
+
|
| 414 |
+
def is_connected(self, rdmol):
|
| 415 |
+
if rdmol.GetNumAtoms() < 1:
|
| 416 |
+
return False
|
| 417 |
+
|
| 418 |
+
try:
|
| 419 |
+
mol_frags = Chem.rdmolops.GetMolFrags(rdmol, asMols=True)
|
| 420 |
+
largest_frag = max(mol_frags, default=rdmol, key=lambda m: m.GetNumAtoms())
|
| 421 |
+
return largest_frag.GetNumAtoms() / rdmol.GetNumAtoms() >= self.connectivity_threshold
|
| 422 |
+
except:
|
| 423 |
+
return False
|
| 424 |
+
|
| 425 |
+
@staticmethod
|
| 426 |
+
def calculate_qed(rdmol):
|
| 427 |
+
try:
|
| 428 |
+
return QED.qed(rdmol)
|
| 429 |
+
except:
|
| 430 |
+
return None
|
| 431 |
+
|
| 432 |
+
@staticmethod
|
| 433 |
+
def calculate_sa(rdmol):
|
| 434 |
+
try:
|
| 435 |
+
sa = calculateScore(rdmol)
|
| 436 |
+
return sa
|
| 437 |
+
except:
|
| 438 |
+
return None
|
| 439 |
+
|
| 440 |
+
@staticmethod
|
| 441 |
+
def calculate_logp(rdmol):
|
| 442 |
+
try:
|
| 443 |
+
return Crippen.MolLogP(rdmol)
|
| 444 |
+
except:
|
| 445 |
+
return None
|
| 446 |
+
|
| 447 |
+
@staticmethod
|
| 448 |
+
def calculate_lipinski(rdmol):
|
| 449 |
+
try:
|
| 450 |
+
rule_1 = Descriptors.ExactMolWt(rdmol) < 500
|
| 451 |
+
rule_2 = Lipinski.NumHDonors(rdmol) <= 5
|
| 452 |
+
rule_3 = Lipinski.NumHAcceptors(rdmol) <= 10
|
| 453 |
+
rule_4 = (logp := Crippen.MolLogP(rdmol) >= -2) & (logp <= 5)
|
| 454 |
+
rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(rdmol) <= 10
|
| 455 |
+
return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]])
|
| 456 |
+
except:
|
| 457 |
+
return None
|
| 458 |
+
|
| 459 |
+
@staticmethod
|
| 460 |
+
def calculate_molecule_size(rdmol):
|
| 461 |
+
try:
|
| 462 |
+
return rdmol.GetNumAtoms()
|
| 463 |
+
except:
|
| 464 |
+
return None
|
| 465 |
+
|
| 466 |
+
@staticmethod
|
| 467 |
+
def calculate_rotatable_bonds(rdmol):
|
| 468 |
+
try:
|
| 469 |
+
return Chem.rdMolDescriptors.CalcNumRotatableBonds(rdmol)
|
| 470 |
+
except:
|
| 471 |
+
return None
|
| 472 |
+
|
| 473 |
+
@property
|
| 474 |
+
def _dtypes(self):
|
| 475 |
+
return {
|
| 476 |
+
'valid': bool,
|
| 477 |
+
'connected': bool,
|
| 478 |
+
'qed': float,
|
| 479 |
+
'sa': float,
|
| 480 |
+
'logp': float,
|
| 481 |
+
'lipinski': int,
|
| 482 |
+
'size': int,
|
| 483 |
+
'n_rotatable_bonds': int,
|
| 484 |
+
}
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
class ClashEvaluator(AbstractEvaluator):
|
| 488 |
+
ID = 'clashes'
|
| 489 |
+
def __init__(self, margin=0.75, ignore={'H'}):
|
| 490 |
+
self.margin = margin
|
| 491 |
+
self.ignore = ignore
|
| 492 |
+
|
| 493 |
+
def evaluate(self, molecule=None, protein=None):
|
| 494 |
+
result = {
|
| 495 |
+
'passed_clash_score_ligands': None,
|
| 496 |
+
'passed_clash_score_pockets': None,
|
| 497 |
+
'passed_clash_score_between': None,
|
| 498 |
+
}
|
| 499 |
+
if molecule is not None:
|
| 500 |
+
molecule = self.load_molecule(molecule)
|
| 501 |
+
clash_score = self.clash_score(molecule)
|
| 502 |
+
result['clash_score_ligands'] = clash_score
|
| 503 |
+
result['passed_clash_score_ligands'] = (clash_score == 0)
|
| 504 |
+
|
| 505 |
+
if protein is not None:
|
| 506 |
+
protein = Chem.MolFromPDBFile(str(protein), sanitize=False)
|
| 507 |
+
clash_score = self.clash_score(protein)
|
| 508 |
+
result['clash_score_pockets'] = clash_score
|
| 509 |
+
result['passed_clash_score_pockets'] = (clash_score == 0)
|
| 510 |
+
|
| 511 |
+
if molecule is not None and protein is not None:
|
| 512 |
+
clash_score = self.clash_score(molecule, protein)
|
| 513 |
+
result['clash_score_between'] = clash_score
|
| 514 |
+
result['passed_clash_score_between'] = (clash_score == 0)
|
| 515 |
+
|
| 516 |
+
return result
|
| 517 |
+
|
| 518 |
+
def clash_score(self, rdmol1, rdmol2=None):
|
| 519 |
+
"""
|
| 520 |
+
Computes a clash score as the number of atoms that have at least one
|
| 521 |
+
clash divided by the number of atoms in the molecule.
|
| 522 |
+
|
| 523 |
+
INTERMOLECULAR CLASH SCORE
|
| 524 |
+
If rdmol2 is provided, the score is the percentage of atoms in rdmol1
|
| 525 |
+
that have at least one clash with rdmol2.
|
| 526 |
+
We define a clash if two atoms are closer than "margin times the sum of
|
| 527 |
+
their van der Waals radii".
|
| 528 |
+
|
| 529 |
+
INTRAMOLECULAR CLASH SCORE
|
| 530 |
+
If rdmol2 is not provided, the score is the percentage of atoms in rdmol1
|
| 531 |
+
that have at least one clash with other atoms in rdmol1.
|
| 532 |
+
In this case, a clash is defined by margin times the atoms' smallest
|
| 533 |
+
covalent radii (among single, double and triple bond radii). This is done
|
| 534 |
+
so that this function is applicable even if no connectivity information is
|
| 535 |
+
available.
|
| 536 |
+
"""
|
| 537 |
+
|
| 538 |
+
intramolecular = rdmol2 is None
|
| 539 |
+
if intramolecular:
|
| 540 |
+
rdmol2 = rdmol1
|
| 541 |
+
|
| 542 |
+
coord1, radii1 = self.coord_and_radii(rdmol1, intramolecular=intramolecular)
|
| 543 |
+
coord2, radii2 = self.coord_and_radii(rdmol2, intramolecular=intramolecular)
|
| 544 |
+
|
| 545 |
+
dist = np.sqrt(np.sum((coord1[:, None, :] - coord2[None, :, :]) ** 2, axis=-1))
|
| 546 |
+
if intramolecular:
|
| 547 |
+
np.fill_diagonal(dist, np.inf)
|
| 548 |
+
|
| 549 |
+
clashes = dist < self.margin * (radii1[:, None] + radii2[None, :])
|
| 550 |
+
clashes = np.any(clashes, axis=1)
|
| 551 |
+
return np.mean(clashes)
|
| 552 |
+
|
| 553 |
+
def coord_and_radii(self, rdmol, intramolecular):
|
| 554 |
+
_periodic_table = Chem.GetPeriodicTable()
|
| 555 |
+
_get_radius = _periodic_table.GetRcovalent if intramolecular else _periodic_table.GetRvdw
|
| 556 |
+
|
| 557 |
+
coord = rdmol.GetConformer().GetPositions()
|
| 558 |
+
radii = np.array([_get_radius(a.GetSymbol()) for a in rdmol.GetAtoms()])
|
| 559 |
+
|
| 560 |
+
mask = np.array([a.GetSymbol() not in self.ignore for a in rdmol.GetAtoms()])
|
| 561 |
+
coord = coord[mask]
|
| 562 |
+
radii = radii[mask]
|
| 563 |
+
|
| 564 |
+
assert coord.shape[0] == radii.shape[0]
|
| 565 |
+
return coord, radii
|
| 566 |
+
|
| 567 |
+
@property
|
| 568 |
+
def _dtypes(self):
|
| 569 |
+
return {
|
| 570 |
+
'clash_score_ligands': float,
|
| 571 |
+
'clash_score_pockets': float,
|
| 572 |
+
'clash_score_between': float,
|
| 573 |
+
'passed_clash_score_ligands': bool,
|
| 574 |
+
'passed_clash_score_pockets': bool,
|
| 575 |
+
'passed_clash_score_between': bool,
|
| 576 |
+
}
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
class RingCountEvaluator(AbstractEvaluator):
|
| 580 |
+
ID = 'ring_count'
|
| 581 |
+
|
| 582 |
+
def evaluate(self, molecule, protein=None):
|
| 583 |
+
_mol = self.load_molecule(molecule)
|
| 584 |
+
|
| 585 |
+
# compute ring info if not yet available
|
| 586 |
+
try:
|
| 587 |
+
_mol.UpdatePropertyCache()
|
| 588 |
+
except ValueError:
|
| 589 |
+
return {}
|
| 590 |
+
Chem.GetSymmSSSR(_mol)
|
| 591 |
+
|
| 592 |
+
rings = _mol.GetRingInfo().AtomRings()
|
| 593 |
+
ring_sizes = [len(r) for r in rings]
|
| 594 |
+
|
| 595 |
+
ring_counts = defaultdict(int)
|
| 596 |
+
for k in ring_sizes:
|
| 597 |
+
ring_counts[f"num_{k}_rings"] += 1
|
| 598 |
+
|
| 599 |
+
return ring_counts
|
| 600 |
+
|
| 601 |
+
@property
|
| 602 |
+
def _dtypes(self):
|
| 603 |
+
return {'*': int}
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
class ChemblRingEvaluator(AbstractEvaluator):
|
| 607 |
+
ID = 'chembl_ring_systems'
|
| 608 |
+
|
| 609 |
+
def __init__(self):
|
| 610 |
+
self.ring_system_lookup = RingSystemLookup.default() # ChEMBL
|
| 611 |
+
|
| 612 |
+
def evaluate(self, molecule, protein=None):
|
| 613 |
+
|
| 614 |
+
results = {
|
| 615 |
+
'min_ring_smi': None,
|
| 616 |
+
'min_ring_freq_gt0_': None,
|
| 617 |
+
'min_ring_freq_gt10_': None,
|
| 618 |
+
'min_ring_freq_gt100_': None,
|
| 619 |
+
}
|
| 620 |
+
|
| 621 |
+
molecule = self.load_molecule(molecule)
|
| 622 |
+
|
| 623 |
+
try:
|
| 624 |
+
Chem.SanitizeMol(molecule)
|
| 625 |
+
freq_list = self.ring_system_lookup.process_mol(molecule)
|
| 626 |
+
freq_list = self.ring_system_lookup.process_mol(molecule)
|
| 627 |
+
except ValueError:
|
| 628 |
+
return results
|
| 629 |
+
|
| 630 |
+
min_ring, min_freq = get_min_ring_frequency(freq_list)
|
| 631 |
+
|
| 632 |
+
return {
|
| 633 |
+
'min_ring_smi': min_ring,
|
| 634 |
+
'min_ring_freq_gt0_': min_freq > 0,
|
| 635 |
+
'min_ring_freq_gt10_': min_freq > 10,
|
| 636 |
+
'min_ring_freq_gt100_': min_freq > 100,
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
+
@property
|
| 640 |
+
def _dtypes(self):
|
| 641 |
+
return {
|
| 642 |
+
'min_ring_smi': str,
|
| 643 |
+
'min_ring_freq_gt0_': bool,
|
| 644 |
+
'min_ring_freq_gt10_': bool,
|
| 645 |
+
'min_ring_freq_gt100_': bool,
|
| 646 |
+
}
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
class REOSEvaluator(AbstractEvaluator):
|
| 650 |
+
# Based on https://practicalcheminformatics.blogspot.com/2024/05/generative-molecular-design-isnt-as.html
|
| 651 |
+
ID = 'reos'
|
| 652 |
+
|
| 653 |
+
def __init__(self):
|
| 654 |
+
self.reos = REOS()
|
| 655 |
+
|
| 656 |
+
def evaluate(self, molecule, protein=None):
|
| 657 |
+
|
| 658 |
+
molecule = self.load_molecule(molecule)
|
| 659 |
+
try:
|
| 660 |
+
Chem.SanitizeMol(molecule)
|
| 661 |
+
except ValueError:
|
| 662 |
+
return {rule_set: False for rule_set in self.reos.get_available_rule_sets()}
|
| 663 |
+
|
| 664 |
+
results = {}
|
| 665 |
+
for rule_set in self.reos.get_available_rule_sets():
|
| 666 |
+
self.reos.set_active_rule_sets([rule_set])
|
| 667 |
+
if rule_set == 'PW':
|
| 668 |
+
self.reos.drop_rule('furans')
|
| 669 |
+
|
| 670 |
+
reos_res = self.reos.process_mol(molecule)
|
| 671 |
+
results[rule_set] = reos_res[0] == 'ok'
|
| 672 |
+
|
| 673 |
+
results['all'] = all([bool(value) if not is_nan(value) else False for value in results.values()])
|
| 674 |
+
return results
|
| 675 |
+
|
| 676 |
+
@property
|
| 677 |
+
def _dtypes(self):
|
| 678 |
+
return {'*': bool}
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
class FullEvaluator(AbstractEvaluator):
|
| 682 |
+
def __init__(
|
| 683 |
+
self,
|
| 684 |
+
pb_conf: str = 'dock',
|
| 685 |
+
gnina: Optional[Union[Path, str]] = None,
|
| 686 |
+
reduce: Optional[Union[Path, str]] = None,
|
| 687 |
+
connectivity_threshold: float = 1.0,
|
| 688 |
+
margin: float = 0.75,
|
| 689 |
+
ignore: Set[str] = {'H'},
|
| 690 |
+
exclude_evaluators: Collection[str] = [],
|
| 691 |
+
):
|
| 692 |
+
all_evaluators = [
|
| 693 |
+
RepresentationEvaluator(),
|
| 694 |
+
MolPropertyEvaluator(),
|
| 695 |
+
PoseBustersEvaluator(pb_conf=pb_conf),
|
| 696 |
+
MedChemEvaluator(connectivity_threshold=connectivity_threshold),
|
| 697 |
+
ClashEvaluator(margin=margin, ignore=ignore),
|
| 698 |
+
GeometryEvaluator(),
|
| 699 |
+
RingCountEvaluator(),
|
| 700 |
+
EnergyEvaluator(),
|
| 701 |
+
ChemblRingEvaluator(),
|
| 702 |
+
REOSEvaluator()
|
| 703 |
+
]
|
| 704 |
+
if gnina is not None:
|
| 705 |
+
all_evaluators.append(GninaEvalulator(gnina=gnina))
|
| 706 |
+
else:
|
| 707 |
+
print(f'Evaluator [{GninaEvalulator.ID}] is not included')
|
| 708 |
+
if reduce is not None:
|
| 709 |
+
all_evaluators.append(InteractionsEvaluator(reduce=reduce))
|
| 710 |
+
else:
|
| 711 |
+
print(f'Evaluator [{InteractionsEvaluator.ID}] is not included')
|
| 712 |
+
|
| 713 |
+
self.evaluators = []
|
| 714 |
+
for e in all_evaluators:
|
| 715 |
+
if e.ID in exclude_evaluators:
|
| 716 |
+
print(f'Excluded Evaluator [{e.ID}]')
|
| 717 |
+
else:
|
| 718 |
+
self.evaluators.append(e)
|
| 719 |
+
|
| 720 |
+
print('Will use the following evaluators:')
|
| 721 |
+
for e in self.evaluators:
|
| 722 |
+
print(f'- [{e.ID}]')
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
def evaluate(self, molecule, protein):
|
| 726 |
+
results = {}
|
| 727 |
+
for evaluator in self.evaluators:
|
| 728 |
+
results.update(evaluator(molecule, protein))
|
| 729 |
+
return results
|
| 730 |
+
|
| 731 |
+
@property
|
| 732 |
+
def _dtypes(self):
|
| 733 |
+
all_dtypes = {}
|
| 734 |
+
for evaluator in self.evaluators:
|
| 735 |
+
all_dtypes.update(evaluator.dtypes)
|
| 736 |
+
return all_dtypes
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
########################################################################################
|
| 740 |
+
################################# Collection Metrics ###################################
|
| 741 |
+
########################################################################################
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
class AbstractCollectionEvaluator:
|
| 745 |
+
ID = None
|
| 746 |
+
def __call__(self, smiles: Collection[str], timeout=300):
|
| 747 |
+
"""
|
| 748 |
+
Args:
|
| 749 |
+
smiles (Collection[smiles]): input list of SMILES
|
| 750 |
+
|
| 751 |
+
Returns:
|
| 752 |
+
metrics (dict): dictionary of metrics
|
| 753 |
+
"""
|
| 754 |
+
if self.ID is not None:
|
| 755 |
+
print(f'Running CollectionEvaluator [{self.ID}]')
|
| 756 |
+
|
| 757 |
+
RDLogger.DisableLog('rdApp.*')
|
| 758 |
+
self.check_format(smiles)
|
| 759 |
+
# timeout handler
|
| 760 |
+
signal.signal(signal.SIGALRM, timeout_handler)
|
| 761 |
+
try:
|
| 762 |
+
signal.alarm(timeout)
|
| 763 |
+
results = self.evaluate(smiles)
|
| 764 |
+
except TimeoutError:
|
| 765 |
+
print(f'Error when evaluating [{self.ID}]: Timeout after {timeout} seconds')
|
| 766 |
+
signal.alarm(0)
|
| 767 |
+
return {}
|
| 768 |
+
except Exception as e:
|
| 769 |
+
print(f'Error when evaluating [{self.ID}]: {e}')
|
| 770 |
+
signal.alarm(0)
|
| 771 |
+
return {}
|
| 772 |
+
finally:
|
| 773 |
+
print(f'Finished CollectionEvaluator [{self.ID}]')
|
| 774 |
+
signal.alarm(0)
|
| 775 |
+
return results
|
| 776 |
+
|
| 777 |
+
@staticmethod
|
| 778 |
+
def check_format(smiles):
|
| 779 |
+
assert len(smiles) > 0, 'List of input SMILES cannot be empty'
|
| 780 |
+
assert isinstance(smiles, Collection), 'Only list of SMILES supported'
|
| 781 |
+
assert isinstance(smiles[0], str), 'Only list of SMILES supported'
|
| 782 |
+
|
| 783 |
+
|
| 784 |
+
class UniquenessEvaluator(AbstractCollectionEvaluator):
|
| 785 |
+
ID = 'uniqueness'
|
| 786 |
+
def evaluate(self, smiles: Collection[str]):
|
| 787 |
+
uniqueness = len(set(smiles)) / len(smiles)
|
| 788 |
+
return {'uniqueness': uniqueness}
|
| 789 |
+
|
| 790 |
+
|
| 791 |
+
class NoveltyEvaluator(AbstractCollectionEvaluator):
|
| 792 |
+
ID = 'novelty'
|
| 793 |
+
def __init__(self, reference_smiles: Collection[str]):
|
| 794 |
+
self.reference_smiles = set(list(reference_smiles))
|
| 795 |
+
assert len(self.reference_smiles) > 0, 'List of refernce SMILES cannot be empty'
|
| 796 |
+
|
| 797 |
+
def evaluate(self, smiles: Collection[str]):
|
| 798 |
+
smiles = set(smiles)
|
| 799 |
+
novel = [smi for smi in smiles if smi not in self.reference_smiles]
|
| 800 |
+
novelty = len(novel) / len(smiles)
|
| 801 |
+
return {'novelty': novelty}
|
| 802 |
+
|
| 803 |
+
def canonical_smiles(smiles):
|
| 804 |
+
for smi in smiles:
|
| 805 |
+
try:
|
| 806 |
+
mol = Chem.MolFromSmiles(smi)
|
| 807 |
+
if mol is not None:
|
| 808 |
+
yield Chem.MolToSmiles(mol)
|
| 809 |
+
except:
|
| 810 |
+
yield None
|
| 811 |
+
|
| 812 |
+
class FCDEvaluator(AbstractCollectionEvaluator):
|
| 813 |
+
ID = 'fcd'
|
| 814 |
+
def __init__(self, reference_smiles: Collection[str]):
|
| 815 |
+
self.reference_smiles = list(reference_smiles)
|
| 816 |
+
assert len(self.reference_smiles) > 0, 'List of refernce SMILES cannot be empty'
|
| 817 |
+
|
| 818 |
+
def evaluate(self, smiles: Collection[str]):
|
| 819 |
+
if len(smiles) > len(self.reference_smiles):
|
| 820 |
+
print('Number of reference molecules should be greater than number of input molecules')
|
| 821 |
+
return {'fcd': None}
|
| 822 |
+
|
| 823 |
+
np.random.seed(42)
|
| 824 |
+
reference_smiles = np.random.choice(self.reference_smiles, len(smiles), replace=False).tolist()
|
| 825 |
+
reference_smiles_canonical = [w for w in canonical_smiles(reference_smiles) if w is not None]
|
| 826 |
+
smiles_canonical = [w for w in canonical_smiles(smiles) if w is not None]
|
| 827 |
+
fcd = get_fcd(reference_smiles_canonical, smiles_canonical)
|
| 828 |
+
return {'fcd': fcd}
|
| 829 |
+
|
| 830 |
+
|
| 831 |
+
class RingDistributionEvaluator(AbstractCollectionEvaluator):
|
| 832 |
+
ID = 'ring_system_distribution'
|
| 833 |
+
|
| 834 |
+
def __init__(self, reference_smiles: Collection[str], jsd_on_k_most_freq: Collection[int] = ()):
|
| 835 |
+
self.ring_system_finder = RingSystemFinder()
|
| 836 |
+
self.ref_ring_dict = self.compute_ring_dict(reference_smiles)
|
| 837 |
+
self.jsd_on_k_most_freq = jsd_on_k_most_freq
|
| 838 |
+
|
| 839 |
+
def compute_ring_dict(self, molecules):
|
| 840 |
+
|
| 841 |
+
ring_system_dict = defaultdict(int)
|
| 842 |
+
|
| 843 |
+
for mol in tqdm(molecules, desc="Computing ring systems"):
|
| 844 |
+
|
| 845 |
+
if isinstance(mol, str):
|
| 846 |
+
mol = Chem.MolFromSmiles(mol)
|
| 847 |
+
|
| 848 |
+
try:
|
| 849 |
+
ring_system_list = self.ring_system_finder.find_ring_systems(mol, as_mols=True)
|
| 850 |
+
except ValueError:
|
| 851 |
+
print(f"WARNING[{type(self).__name__}]: error while computing ring systems; skipping molecule.")
|
| 852 |
+
continue
|
| 853 |
+
|
| 854 |
+
for ring in ring_system_list:
|
| 855 |
+
inchi_key = Chem.MolToInchiKey(ring)
|
| 856 |
+
ring_system_dict[inchi_key] += 1
|
| 857 |
+
|
| 858 |
+
return ring_system_dict
|
| 859 |
+
|
| 860 |
+
def precision(self, query_ring_dict):
|
| 861 |
+
query_ring_systems = set(query_ring_dict.keys())
|
| 862 |
+
ref_ring_systems = set(self.ref_ring_dict.keys())
|
| 863 |
+
intersection = ref_ring_systems & query_ring_systems
|
| 864 |
+
return len(intersection) / len(query_ring_systems) if len(query_ring_systems) > 0 else 0
|
| 865 |
+
|
| 866 |
+
def recall(self, query_ring_dict):
|
| 867 |
+
query_ring_systems = set(query_ring_dict.keys())
|
| 868 |
+
ref_ring_systems = set(self.ref_ring_dict.keys())
|
| 869 |
+
intersection = ref_ring_systems & query_ring_systems
|
| 870 |
+
return len(intersection) / len(ref_ring_systems) if len(ref_ring_systems) > 0 else 0
|
| 871 |
+
|
| 872 |
+
def jsd(self, query_ring_dict, k_most_freq=None):
|
| 873 |
+
|
| 874 |
+
if k_most_freq is None:
|
| 875 |
+
# example on the union of all ring systems
|
| 876 |
+
sample_space = set(self.ref_ring_dict.keys()) | set(query_ring_dict.keys())
|
| 877 |
+
else:
|
| 878 |
+
# evaluate only on the k most common rings from the reference set
|
| 879 |
+
sorted_rings = [k for k, v in sorted(self.ref_ring_dict.items(), key=lambda item: item[1], reverse=True)]
|
| 880 |
+
sample_space = sorted_rings[:k_most_freq]
|
| 881 |
+
|
| 882 |
+
p = np.zeros(len(sample_space))
|
| 883 |
+
q = np.zeros(len(sample_space))
|
| 884 |
+
|
| 885 |
+
for i, inchi_key in enumerate(sample_space):
|
| 886 |
+
p[i] = self.ref_ring_dict.get(inchi_key, 0)
|
| 887 |
+
q[i] = query_ring_dict.get(inchi_key, 0)
|
| 888 |
+
|
| 889 |
+
# normalize
|
| 890 |
+
p = p / np.sum(p)
|
| 891 |
+
q = q / np.sum(q)
|
| 892 |
+
|
| 893 |
+
return jensenshannon(p, q)
|
| 894 |
+
|
| 895 |
+
def evaluate(self, smiles: Collection[str]):
|
| 896 |
+
|
| 897 |
+
query_ring_dict = self.compute_ring_dict(smiles)
|
| 898 |
+
|
| 899 |
+
out = {
|
| 900 |
+
"precision": self.precision(query_ring_dict),
|
| 901 |
+
"recall": self.recall(query_ring_dict),
|
| 902 |
+
"jsd": self.jsd(query_ring_dict),
|
| 903 |
+
}
|
| 904 |
+
|
| 905 |
+
out.update(
|
| 906 |
+
{f"jsd_{k}_most_freq": self.jsd(query_ring_dict, k_most_freq=k) for k in self.jsd_on_k_most_freq}
|
| 907 |
+
)
|
| 908 |
+
|
| 909 |
+
return out
|
| 910 |
+
|
| 911 |
+
|
| 912 |
+
class FullCollectionEvaluator(AbstractCollectionEvaluator):
|
| 913 |
+
def __init__(self, reference_smiles: Collection[str], exclude_evaluators: Collection[str] = []):
|
| 914 |
+
self.evaluators = [
|
| 915 |
+
UniquenessEvaluator(),
|
| 916 |
+
NoveltyEvaluator(reference_smiles=reference_smiles),
|
| 917 |
+
FCDEvaluator(reference_smiles=reference_smiles),
|
| 918 |
+
RingDistributionEvaluator(reference_smiles, jsd_on_k_most_freq=[10, 100, 1000, 10000]),
|
| 919 |
+
]
|
| 920 |
+
for e in self.evaluators:
|
| 921 |
+
if e.ID in exclude_evaluators:
|
| 922 |
+
print(f'Excluding CollectionEvaluator [{e.ID}]')
|
| 923 |
+
self.evaluators.remove(e)
|
| 924 |
+
|
| 925 |
+
def evaluate(self, smiles):
|
| 926 |
+
results = {}
|
| 927 |
+
for evaluator in self.evaluators:
|
| 928 |
+
results.update(evaluator(smiles))
|
| 929 |
+
return results
|
src/sbdd_metrics/sascorer.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# calculation of synthetic accessibility score as described in:
|
| 3 |
+
#
|
| 4 |
+
# Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
|
| 5 |
+
# Peter Ertl and Ansgar Schuffenhauer
|
| 6 |
+
# Journal of Cheminformatics 1:8 (2009)
|
| 7 |
+
# http://www.jcheminf.com/content/1/1/8
|
| 8 |
+
#
|
| 9 |
+
# several small modifications to the original paper are included
|
| 10 |
+
# particularly slightly different formula for marocyclic penalty
|
| 11 |
+
# and taking into account also molecule symmetry (fingerprint density)
|
| 12 |
+
#
|
| 13 |
+
# for a set of 10k diverse molecules the agreement between the original method
|
| 14 |
+
# as implemented in PipelinePilot and this implementation is r2 = 0.97
|
| 15 |
+
#
|
| 16 |
+
# peter ertl & greg landrum, september 2013
|
| 17 |
+
#
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
from rdkit import Chem
|
| 21 |
+
from rdkit.Chem import rdMolDescriptors
|
| 22 |
+
import pickle
|
| 23 |
+
|
| 24 |
+
import math
|
| 25 |
+
from collections import defaultdict
|
| 26 |
+
|
| 27 |
+
import os.path as op
|
| 28 |
+
|
| 29 |
+
_fscores = None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def readFragmentScores(name='fpscores'):
|
| 33 |
+
import gzip
|
| 34 |
+
global _fscores
|
| 35 |
+
# generate the full path filename:
|
| 36 |
+
if name == "fpscores":
|
| 37 |
+
name = op.join(op.dirname(__file__), name)
|
| 38 |
+
data = pickle.load(gzip.open('%s.pkl.gz' % name))
|
| 39 |
+
outDict = {}
|
| 40 |
+
for i in data:
|
| 41 |
+
for j in range(1, len(i)):
|
| 42 |
+
outDict[i[j]] = float(i[0])
|
| 43 |
+
_fscores = outDict
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def numBridgeheadsAndSpiro(mol, ri=None):
|
| 47 |
+
nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
|
| 48 |
+
nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
|
| 49 |
+
return nBridgehead, nSpiro
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def calculateScore(m):
|
| 53 |
+
if _fscores is None:
|
| 54 |
+
readFragmentScores()
|
| 55 |
+
|
| 56 |
+
# fragment score
|
| 57 |
+
fp = rdMolDescriptors.GetMorganFingerprint(m,
|
| 58 |
+
2) # <- 2 is the *radius* of the circular fingerprint
|
| 59 |
+
fps = fp.GetNonzeroElements()
|
| 60 |
+
score1 = 0.
|
| 61 |
+
nf = 0
|
| 62 |
+
for bitId, v in fps.items():
|
| 63 |
+
nf += v
|
| 64 |
+
sfp = bitId
|
| 65 |
+
score1 += _fscores.get(sfp, -4) * v
|
| 66 |
+
score1 /= nf
|
| 67 |
+
|
| 68 |
+
# features score
|
| 69 |
+
nAtoms = m.GetNumAtoms()
|
| 70 |
+
nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
|
| 71 |
+
ri = m.GetRingInfo()
|
| 72 |
+
nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
|
| 73 |
+
nMacrocycles = 0
|
| 74 |
+
for x in ri.AtomRings():
|
| 75 |
+
if len(x) > 8:
|
| 76 |
+
nMacrocycles += 1
|
| 77 |
+
|
| 78 |
+
sizePenalty = nAtoms**1.005 - nAtoms
|
| 79 |
+
stereoPenalty = math.log10(nChiralCenters + 1)
|
| 80 |
+
spiroPenalty = math.log10(nSpiro + 1)
|
| 81 |
+
bridgePenalty = math.log10(nBridgeheads + 1)
|
| 82 |
+
macrocyclePenalty = 0.
|
| 83 |
+
# ---------------------------------------
|
| 84 |
+
# This differs from the paper, which defines:
|
| 85 |
+
# macrocyclePenalty = math.log10(nMacrocycles+1)
|
| 86 |
+
# This form generates better results when 2 or more macrocycles are present
|
| 87 |
+
if nMacrocycles > 0:
|
| 88 |
+
macrocyclePenalty = math.log10(2)
|
| 89 |
+
|
| 90 |
+
score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty
|
| 91 |
+
|
| 92 |
+
# correction for the fingerprint density
|
| 93 |
+
# not in the original publication, added in version 1.1
|
| 94 |
+
# to make highly symmetrical molecules easier to synthetise
|
| 95 |
+
score3 = 0.
|
| 96 |
+
if nAtoms > len(fps):
|
| 97 |
+
score3 = math.log(float(nAtoms) / len(fps)) * .5
|
| 98 |
+
|
| 99 |
+
sascore = score1 + score2 + score3
|
| 100 |
+
|
| 101 |
+
# need to transform "raw" value into scale between 1 and 10
|
| 102 |
+
min = -4.0
|
| 103 |
+
max = 2.5
|
| 104 |
+
sascore = 11. - (sascore - min + 1) / (max - min) * 9.
|
| 105 |
+
# smooth the 10-end
|
| 106 |
+
if sascore > 8.:
|
| 107 |
+
sascore = 8. + math.log(sascore + 1. - 9.)
|
| 108 |
+
if sascore > 10.:
|
| 109 |
+
sascore = 10.0
|
| 110 |
+
elif sascore < 1.:
|
| 111 |
+
sascore = 1.0
|
| 112 |
+
|
| 113 |
+
return sascore
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def processMols(mols):
|
| 117 |
+
print('smiles\tName\tsa_score')
|
| 118 |
+
for i, m in enumerate(mols):
|
| 119 |
+
if m is None:
|
| 120 |
+
continue
|
| 121 |
+
|
| 122 |
+
s = calculateScore(m)
|
| 123 |
+
|
| 124 |
+
smiles = Chem.MolToSmiles(m)
|
| 125 |
+
print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
if __name__ == '__main__':
|
| 129 |
+
import sys
|
| 130 |
+
import time
|
| 131 |
+
|
| 132 |
+
t1 = time.time()
|
| 133 |
+
readFragmentScores("fpscores")
|
| 134 |
+
t2 = time.time()
|
| 135 |
+
|
| 136 |
+
suppl = Chem.SmilesMolSupplier(sys.argv[1])
|
| 137 |
+
t3 = time.time()
|
| 138 |
+
processMols(suppl)
|
| 139 |
+
t4 = time.time()
|
| 140 |
+
|
| 141 |
+
print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
|
| 142 |
+
file=sys.stderr)
|
| 143 |
+
|
| 144 |
+
#
|
| 145 |
+
# Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
|
| 146 |
+
# All rights reserved.
|
| 147 |
+
#
|
| 148 |
+
# Redistribution and use in source and binary forms, with or without
|
| 149 |
+
# modification, are permitted provided that the following conditions are
|
| 150 |
+
# met:
|
| 151 |
+
#
|
| 152 |
+
# * Redistributions of source code must retain the above copyright
|
| 153 |
+
# notice, this list of conditions and the following disclaimer.
|
| 154 |
+
# * Redistributions in binary form must reproduce the above
|
| 155 |
+
# copyright notice, this list of conditions and the following
|
| 156 |
+
# disclaimer in the documentation and/or other materials provided
|
| 157 |
+
# with the distribution.
|
| 158 |
+
# * Neither the name of Novartis Institutes for BioMedical Research Inc.
|
| 159 |
+
# nor the names of its contributors may be used to endorse or promote
|
| 160 |
+
# products derived from this software without specific prior written permission.
|
| 161 |
+
#
|
| 162 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
| 163 |
+
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
| 164 |
+
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
| 165 |
+
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
| 166 |
+
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
| 167 |
+
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
| 168 |
+
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
| 169 |
+
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
| 170 |
+
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
| 171 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 172 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 173 |
+
#
|