Commit 6e7d4ba (verified)
mority committed · 1 parent: efff3ad

Upload 53 files

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete changeset.
Files changed (50):
  1. .gitattributes +1 -0
  2. LICENSE +9 -0
  3. configs/sampling/sample_and_maybe_evaluate.yml +25 -0
  4. configs/sampling/sample_train_split.yml +25 -0
  5. configs/training/drugflow.yml +82 -0
  6. configs/training/drugflow_no_virtual_nodes.yml +82 -0
  7. configs/training/drugflow_ood.yml +83 -0
  8. configs/training/flexflow.yml +90 -0
  9. configs/training/preference_alignment.yml +93 -0
  10. docs/drugflow.jpg +3 -0
  11. environment.yaml +30 -0
  12. examples/kras.pdb +0 -0
  13. examples/kras_ref_ligand.sdf +74 -0
  14. scripts/python/evaluate_baselines.py +53 -0
  15. scripts/python/postprocess_metrics.py +271 -0
  16. src/analysis/SA_Score/README.md +1 -0
  17. src/analysis/SA_Score/fpscores.pkl.gz +3 -0
  18. src/analysis/SA_Score/sascorer.py +173 -0
  19. src/analysis/metrics.py +544 -0
  20. src/analysis/visualization_utils.py +192 -0
  21. src/constants.py +256 -0
  22. src/data/data_utils.py +901 -0
  23. src/data/dataset.py +208 -0
  24. src/data/misc.py +19 -0
  25. src/data/molecule_builder.py +107 -0
  26. src/data/nerf.py +250 -0
  27. src/data/normal_modes.py +69 -0
  28. src/data/postprocessing.py +93 -0
  29. src/data/process_crossdocked.py +176 -0
  30. src/data/process_dpo_dataset.py +406 -0
  31. src/data/sanifix.py +159 -0
  32. src/data/so3_utils.py +450 -0
  33. src/default/size_distribution.npy +3 -0
  34. src/generate.py +204 -0
  35. src/model/diffusion_utils.py +206 -0
  36. src/model/dpo.py +252 -0
  37. src/model/dynamics.py +791 -0
  38. src/model/dynamics_hetero.py +1008 -0
  39. src/model/flows.py +448 -0
  40. src/model/gvp.py +650 -0
  41. src/model/gvp_transformer.py +471 -0
  42. src/model/lightning.py +1426 -0
  43. src/model/loss_utils.py +79 -0
  44. src/model/markov_bridge.py +163 -0
  45. src/sample_and_evaluate.py +164 -0
  46. src/sbdd_metrics/evaluation.py +239 -0
  47. src/sbdd_metrics/fpscores.pkl.gz +3 -0
  48. src/sbdd_metrics/interactions.py +231 -0
  49. src/sbdd_metrics/metrics.py +929 -0
  50. src/sbdd_metrics/sascorer.py +173 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ docs/drugflow.jpg filter=lfs diff=lfs merge=lfs -text
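The added rule stores the new figure as a Git LFS pointer rather than a regular blob; lines like this are what `git lfs track "docs/drugflow.jpg"` writes into .gitattributes.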
LICENSE ADDED
@@ -0,0 +1,9 @@
+ MIT License
+
+ Copyright (c) 2025 Arne Schneuing, Ilia Igashov, Adrian Dobbelstein
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
configs/sampling/sample_and_maybe_evaluate.yml ADDED
@@ -0,0 +1,25 @@
+ checkpoint: <TODO>
+ set: test
+ sample_outdir: ./samples
+ n_samples: 100
+ sample_with_ground_truth_size: False
+ device: cuda
+ seed: 42
+ sample: True
+ postprocess: False
+ evaluate: False
+ reduce: reduce
+
+ # Override training config parameters if necessary
+ model_args:
+   virtual_nodes: [0, 5]
+
+ train_params:
+   datadir: ./processed_crossdocked
+   gnina: gnina
+
+ eval_params:
+   n_sampling_steps: 500
+   eval_batch_size: 1
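A minimal sketch of reading this config from Python with PyYAML; how the repo's own entry points (e.g. src/sample_and_evaluate.py from the file list above) consume it is an assumption, not shown in this diff:

    import yaml

    # The 'checkpoint' field still holds the <TODO> placeholder until a real
    # checkpoint path is filled in.
    with open('configs/sampling/sample_and_maybe_evaluate.yml') as f:
        cfg = yaml.safe_load(f)

    print(cfg['n_samples'], cfg['set'], cfg['device'])  # 100 test cuda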
configs/sampling/sample_train_split.yml ADDED
@@ -0,0 +1,25 @@
+ checkpoint: <TODO>
+ set: train
+ sample_outdir: ./samples
+ n_samples: 50
+ sample_with_ground_truth_size: False
+ device: cuda
+ seed: 42
+ sample: True
+ postprocess: False
+ evaluate: False
+ reduce: reduce
+
+ # Override training config parameters if necessary
+ model_args:
+   virtual_nodes: [0, 10]
+
+ train_params:
+   datadir: ./processed_crossdocked
+   gnina: gnina
+   batch_size: 2
+
+ eval_params:
+   n_sampling_steps: 100
configs/training/drugflow.yml ADDED
@@ -0,0 +1,82 @@
+ run_name: drugflow # iclr_drugflow_T5000
+ pocket_representation: CA+
+ virtual_nodes: [0, 10]
+ flexible: False
+ flexible_bb: False
+
+ train_params:
+   logdir: ./runs # symlink to any location you like
+   datadir: ./processed_crossdocked # symlink to the dataset location
+   enable_progress_bar: True
+   num_sanity_val_steps: 0
+   batch_size: 64
+   accumulate_grad_batches: 2
+   lr: 5.0e-4
+   n_epochs: 1000
+   num_workers: 0
+   gpus: 1
+   clip_grad: True
+   gnina: gnina
+   sample_from_clusters: False
+   sharded_dataset: False
+
+ wandb_params:
+   mode: online # disabled, offline, online
+   entity:
+   group: crossdocked
+
+ loss_params:
+   discrete_loss: VLB # VLB or CE
+   lambda_x: 1.0
+   lambda_h: 50.0
+   lambda_e: 50.0
+   lambda_chi: null
+   lambda_trans: null
+   lambda_rot: null
+   lambda_clash: null
+   timestep_weights: null
+
+ simulation_params:
+   n_steps: 5000
+   prior_h: marginal # uniform, marginal
+   prior_e: uniform # uniform, marginal
+   predict_final: False
+   predict_confidence: False
+
+ eval_params:
+   eval_epochs: 100
+   n_eval_samples: 4
+   n_sampling_steps: 500
+   eval_batch_size: 16
+   visualize_sample_epoch: 1
+   n_visualize_samples: 100
+   visualize_chain_epoch: 1
+   keep_frames: 100
+   sample_with_ground_truth_size: True
+
+ predictor_params:
+   heterogeneous_graph: True
+   backbone: gvp
+   num_rbf_time: 16
+   edge_cutoff_ligand: null
+   edge_cutoff_pocket: 10.0
+   edge_cutoff_interaction: 10.0
+   cycle_counts: True
+   spectral_feat: False
+   reflection_equivariant: False
+   num_rbf: 16
+   d_max: 15.0
+   self_conditioning: True
+   augment_residue_sc: False
+   augment_ligand_sc: False
+   normal_modes: False
+   add_chi_as_feature: False
+   angle_act_fn: null
+   add_all_atom_diff: False
+
+ gvp_params:
+   n_layers: 5
+   node_h_dim: [ 128, 32 ] # (s, V)
+   edge_h_dim: [ 128, 32 ]
+   dropout: 0.0
+   vector_gate: True
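Note that with batch_size: 64 and accumulate_grad_batches: 2, gradients are accumulated over two batches per optimizer step, giving an effective batch size of 64 × 2 = 128.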
configs/training/drugflow_no_virtual_nodes.yml ADDED
@@ -0,0 +1,82 @@
+ run_name: drugflow_no_virtual_nodes # iclr_drugflow_T5000_no_virtual_nodes
+ pocket_representation: CA+
+ virtual_nodes: null
+ flexible: False
+ flexible_bb: False
+
+ train_params:
+   logdir: ./runs # symlink to any location you like
+   datadir: ./processed_crossdocked # symlink to the dataset location
+   enable_progress_bar: True
+   num_sanity_val_steps: 0
+   batch_size: 64
+   accumulate_grad_batches: 2
+   lr: 5.0e-4
+   n_epochs: 1000
+   num_workers: 0
+   gpus: 1
+   clip_grad: True
+   gnina: gnina
+   sample_from_clusters: False
+   sharded_dataset: False
+
+ wandb_params:
+   mode: online # disabled, offline, online
+   entity: lpdi
+   group: crossdocked
+
+ loss_params:
+   discrete_loss: VLB # VLB or CE
+   lambda_x: 1.0
+   lambda_h: 50.0
+   lambda_e: 50.0
+   lambda_chi: null
+   lambda_trans: null
+   lambda_rot: null
+   lambda_clash: null
+   timestep_weights: null
+
+ simulation_params:
+   n_steps: 5000
+   prior_h: marginal # uniform, marginal
+   prior_e: uniform # uniform, marginal
+   predict_final: False
+   predict_confidence: False
+
+ eval_params:
+   eval_epochs: 100
+   n_eval_samples: 4
+   n_sampling_steps: 500
+   eval_batch_size: 16
+   visualize_sample_epoch: 1
+   n_visualize_samples: 100
+   visualize_chain_epoch: 1
+   keep_frames: 100
+   sample_with_ground_truth_size: True
+
+ predictor_params:
+   heterogeneous_graph: True
+   backbone: gvp
+   num_rbf_time: 16
+   edge_cutoff_ligand: null
+   edge_cutoff_pocket: 10.0
+   edge_cutoff_interaction: 10.0
+   cycle_counts: True
+   spectral_feat: False
+   reflection_equivariant: False
+   num_rbf: 16
+   d_max: 15.0
+   self_conditioning: True
+   augment_residue_sc: False
+   augment_ligand_sc: False
+   normal_modes: False
+   add_chi_as_feature: False
+   angle_act_fn: null
+   add_all_atom_diff: False
+
+ gvp_params:
+   n_layers: 5
+   node_h_dim: [ 128, 32 ] # (s, V)
+   edge_h_dim: [ 128, 32 ]
+   dropout: 0.0
+   vector_gate: True
configs/training/drugflow_ood.yml ADDED
@@ -0,0 +1,83 @@
+ run_name: drugflow_ood # iclr_drugflow_T5000_confidence_ru10
+ pocket_representation: CA+
+ virtual_nodes: [0, 10]
+ flexible: False
+ flexible_bb: False
+
+ train_params:
+   logdir: ./runs # symlink to any location you like
+   datadir: ./processed_crossdocked # symlink to the dataset location
+   enable_progress_bar: True
+   num_sanity_val_steps: 0
+   batch_size: 64
+   accumulate_grad_batches: 2
+   lr: 5.0e-4
+   n_epochs: 1000
+   num_workers: 0
+   gpus: 1
+   clip_grad: True
+   gnina: gnina
+   sample_from_clusters: False
+   sharded_dataset: False
+
+ wandb_params:
+   mode: online # disabled, offline, online
+   entity: lpdi
+   group: crossdocked
+
+ loss_params:
+   discrete_loss: VLB # VLB or CE
+   lambda_x: 1.0
+   lambda_h: 50.0
+   lambda_e: 50.0
+   lambda_chi: null
+   lambda_trans: null
+   lambda_rot: null
+   lambda_clash: null
+   timestep_weights: null
+   regularize_uncertainty: 10.0
+
+ simulation_params:
+   n_steps: 5000
+   prior_h: marginal # uniform, marginal
+   prior_e: uniform # uniform, marginal
+   predict_final: False
+   predict_confidence: True
+
+ eval_params:
+   eval_epochs: 100
+   n_eval_samples: 4
+   n_sampling_steps: 500
+   eval_batch_size: 16
+   visualize_sample_epoch: 1
+   n_visualize_samples: 100
+   visualize_chain_epoch: 1
+   keep_frames: 100
+   sample_with_ground_truth_size: True
+
+ predictor_params:
+   heterogeneous_graph: True
+   backbone: gvp
+   num_rbf_time: 16
+   edge_cutoff_ligand: null
+   edge_cutoff_pocket: 10.0
+   edge_cutoff_interaction: 10.0
+   cycle_counts: True
+   spectral_feat: False
+   reflection_equivariant: False
+   num_rbf: 16
+   d_max: 15.0
+   self_conditioning: True
+   augment_residue_sc: False
+   augment_ligand_sc: False
+   normal_modes: False
+   add_chi_as_feature: False
+   angle_act_fn: null
+   add_all_atom_diff: False
+
+ gvp_params:
+   n_layers: 5
+   node_h_dim: [ 128, 32 ] # (s, V)
+   edge_h_dim: [ 128, 32 ]
+   dropout: 0.0
+   vector_gate: True
configs/training/flexflow.yml ADDED
@@ -0,0 +1,90 @@
+ run_name: flexflow
+ pocket_representation: CA+
+ virtual_nodes: [0, 10]
+ flexible: True
+ flexible_bb: False
+
+ train_params:
+   logdir: ./runs # symlink to any location you like
+   datadir: ./processed_crossdocked # symlink to the dataset location
+   enable_progress_bar: False
+   num_sanity_val_steps: 0
+   batch_size: 64
+   accumulate_grad_batches: 2
+   lr: 5.0e-4
+   lr_step_size: null
+   lr_gamma: null
+   n_epochs: 700
+   num_workers: 4
+   gpus: 1
+   clip_grad: True
+   gnina: gnina # add gnina location to PATH
+   sample_from_clusters: False
+   sharded_dataset: False
+
+ wandb_params:
+   mode: online # disabled, offline, online
+   entity:
+   group: crossdocked
+
+ loss_params:
+   discrete_loss: VLB # VLB or CE
+   reduce: sum # 'mean' or 'sum'
+   lambda_x: 0.015
+   lambda_h: 2.5
+   lambda_e: 0.25
+   lambda_chi: 0.002
+   lambda_trans: null
+   lambda_rot: null
+   lambda_clash: null
+   regularize_uncertainty: null
+   timestep_weights: null
+
+ simulation_params:
+   n_steps: 5000
+   prior_h: marginal # uniform, marginal
+   prior_e: uniform # uniform, marginal
+   predict_final: False
+   predict_confidence: False
+   scheduler_chi:
+     type: polynomial
+     k: 3 # exponent of the polynomial schedule kappa(t) = (1-t)^k
+
+ eval_params:
+   eval_epochs: 100
+   n_loss_per_sample: 100
+   n_eval_samples: 4
+   n_sampling_steps: 500
+   eval_batch_size: 16
+   visualize_sample_epoch: 1
+   n_visualize_samples: 100
+   visualize_chain_epoch: 1
+   keep_frames: 100
+   sample_with_ground_truth_size: True
+
+ predictor_params:
+   heterogeneous_graph: True
+   backbone: gvp
+   num_rbf_time: 16
+   edge_cutoff_ligand: null
+   edge_cutoff_pocket: 10.0
+   edge_cutoff_interaction: 10.0
+   cycle_counts: True
+   spectral_feat: False
+   reflection_equivariant: False
+   num_rbf: 16
+   d_max: 15.0
+   self_conditioning: True
+   augment_residue_sc: False
+   augment_ligand_sc: False
+   normal_modes: False
+   add_chi_as_feature: False
+   angle_act_fn: null
+   add_all_atom_diff: True
+
+ gvp_params:
+   n_layers: 5
+   node_h_dim: [ 128, 32 ] # (s, V)
+   edge_h_dim: [ 128, 32 ]
+   dropout: 0.0
+   vector_gate: True
configs/training/preference_alignment.yml ADDED
@@ -0,0 +1,93 @@
+ run_name: drugflow_preference_alignment
+
+ checkpoint: ./reference.ckpt # TODO: specify reference checkpoint
+ dpo_mode: single_dpo_comp_v3
+
+ pocket_representation: CA+
+ virtual_nodes: [0, 10]
+ flexible: False
+ flexible_bb: False
+
+ train_params:
+   logdir: ./runs # symlink to any location you like
+   datadir: ./processed_crossdocked # symlink to the dataset location
+   enable_progress_bar: True
+   num_sanity_val_steps: 0
+   batch_size: 64
+   accumulate_grad_batches: 2
+   lr: 5.0e-5
+   n_epochs: 500
+   num_workers: 0
+   gpus: 1
+   clip_grad: True
+   gnina: gnina # path to the gnina binary
+   sample_from_clusters: False
+   sharded_dataset: False
+
+ wandb_params:
+   mode: online # disabled, offline, online
+   entity:
+   group: crossdocked
+
+ loss_params:
+   discrete_loss: VLB # VLB or CE
+   lambda_x: 1.0
+   lambda_h: 500
+   dpo_lambda_h: 2500
+   lambda_e: 500
+   dpo_lambda_e: 2500
+   lambda_chi: 0.5 # only effective if flexible=True
+   lambda_trans: 1.0 # only effective if flexible_bb=True
+   lambda_rot: 0.1 # only effective if flexible_bb=True
+   lambda_clash: null
+   timestep_weights: null # sigmoid_a=1_b=10 # null, sigmoid_a=?_b=?
+   dpo_beta: 100.0
+   dpo_beta_schedule: 't'
+   dpo_lambda_w: 1.0
+   dpo_lambda_l: 0.2
+   clamp_dpo: False
+
+ simulation_params:
+   n_steps: 5000
+   prior_h: marginal # uniform, marginal
+   prior_e: uniform # uniform, marginal
+   predict_final: False
+   predict_confidence: False
+
+ eval_params:
+   eval_epochs: 4
+   n_eval_samples: 1
+   n_sampling_steps: 500
+   eval_batch_size: 16
+   visualize_sample_epoch: 1
+   n_visualize_samples: 10
+   visualize_chain_epoch: 1
+   keep_frames: 100
+   sample_with_ground_truth_size: True
+
+ predictor_params:
+   heterogeneous_graph: True
+   backbone: gvp
+   num_rbf_time: 16
+   edge_cutoff_ligand: null
+   edge_cutoff_pocket: 10.0
+   edge_cutoff_interaction: 10.0
+   cycle_counts: True
+   spectral_feat: False
+   reflection_equivariant: False
+   num_rbf: 16
+   d_max: 15.0
+   self_conditioning: True
+   augment_residue_sc: False
+   augment_ligand_sc: False
+   normal_modes: False
+   add_chi_as_feature: False
+   angle_act_fn: null
+   add_all_atom_diff: False
+
+ gvp_params:
+   n_layers: 5
+   node_h_dim: [ 128, 32 ] # (s, V)
+   edge_h_dim: [ 128, 32 ]
+   dropout: 0.0
+   vector_gate: True
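Relative to the base drugflow config, this setup lowers the learning rate to 5.0e-5, requires a pretrained reference checkpoint, and adds the dpo_* weights, consistent with DPO-style fine-tuning against a frozen reference model.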
docs/drugflow.jpg ADDED

Git LFS Details

  • SHA256: c16f816eafa8e13658526b74f06f8ee5fb3258f51172c711ec1d1d539b48a4ef
  • Pointer size: 131 Bytes
  • Size of remote file: 762 kB
environment.yaml ADDED
@@ -0,0 +1,30 @@
+ name: sbdd
+
+ channels:
+   - pytorch
+   - conda-forge
+   - anaconda
+   - pyg
+   - nvidia
+
+ dependencies:
+   - python=3.11.8
+   - pytorch=2.2.1=*cuda12.1*
+   - pytorch-cuda=12.1
+   - pytorch-lightning=2.2.1
+   - rdkit=2023.09.6
+   - openbabel=3.1.1
+   - biopython=1.83
+   - scipy=1.12.0
+   - pyg=2.5.1
+   - pytorch-scatter=2.1.2
+   - ProDy=2.4.0
+   - wandb=0.16.3
+   - pandas=2.2.2
+   - pip=24.0
+   - pip:
+     - posebusters==0.3.1
+     - useful_rdkit_utils==0.65
+     - fcd==1.2.2
+     - webdataset==0.2.86
+     - prolif==2.0.3
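The environment can be created with `conda env create -f environment.yaml` and activated with `conda activate sbdd` (the name field above).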
examples/kras.pdb ADDED
The diff for this file is too large to render. See raw diff
 
examples/kras_ref_ligand.sdf ADDED
@@ -0,0 +1,74 @@
+ 8AZR
+ PyMOL2.5 3D 0
+
+ 32 36 0 0 0 0 0 0 0 0999 V2000
+ 15.7084 1.6569 4.9428 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 16.2939 1.9182 6.3219 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 17.7757 1.5677 6.3468 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 18.0388 0.0580 6.1328 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 16.1458 0.3026 4.4709 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 17.1748 -0.4207 4.9854 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 17.2894 -1.6945 4.3617 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 16.3332 -1.9132 3.3763 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 15.2948 -0.5437 3.2188 S 0 0 0 0 0 0 0 0 0 0 0 0
+ 17.6856 -0.7371 7.4005 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 19.5008 -0.1084 5.7694 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 19.9420 0.4778 4.6523 O 0 0 0 0 0 0 0 0 0 0 0 0
+ 21.3366 0.1893 4.6052 N 0 0 0 0 0 0 0 0 0 0 0 0
+ 21.5306 -0.5212 5.6843 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 20.3929 -0.7319 6.4483 N 0 0 0 0 0 0 0 0 0 0 0 0
+ 16.1651 -3.0052 2.6033 N 0 0 0 0 0 0 0 0 0 0 0 0
+ 22.8349 -1.0932 6.0768 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 23.9207 -0.6312 5.4365 N 0 0 0 0 0 0 0 0 0 0 0 0
+ 25.1129 -1.1528 5.7755 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 25.2639 -2.1387 6.7500 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 24.1280 -2.5941 7.3940 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 22.8945 -2.0709 7.0591 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 18.2816 -2.6789 4.6625 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 19.0589 -3.4973 4.8688 N 0 0 0 0 0 0 0 0 0 0 0 0
+ 26.1982 -0.6750 5.0820 N 0 0 0 0 0 0 0 0 0 0 0 0
+ 26.0358 0.4071 4.0954 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 26.8978 0.1468 2.8491 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 28.2989 -0.1678 3.2648 N 0 0 0 0 0 0 0 0 0 0 0 0
+ 28.3171 -1.4142 4.0851 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 27.5312 -1.2091 5.3777 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 29.1988 -0.2741 2.0804 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 26.3415 1.7618 4.7132 C 0 0 0 0 0 0 0 0 0 0 0 0
+ 1 2 1 0 0 0 0
+ 2 3 1 0 0 0 0
+ 3 4 1 0 0 0 0
+ 4 6 1 0 0 0 0
+ 4 10 1 0 0 0 0
+ 4 11 1 0 0 0 0
+ 1 5 1 0 0 0 0
+ 5 6 4 0 0 0 0
+ 5 9 4 0 0 0 0
+ 6 7 4 0 0 0 0
+ 7 8 4 0 0 0 0
+ 7 23 1 0 0 0 0
+ 8 9 4 0 0 0 0
+ 8 16 1 0 0 0 0
+ 11 12 4 0 0 0 0
+ 11 15 4 0 0 0 0
+ 12 13 4 0 0 0 0
+ 13 14 4 0 0 0 0
+ 14 15 4 0 0 0 0
+ 14 17 1 0 0 0 0
+ 17 18 4 0 0 0 0
+ 17 22 4 0 0 0 0
+ 18 19 4 0 0 0 0
+ 19 25 1 0 0 0 0
+ 19 20 4 0 0 0 0
+ 20 21 4 0 0 0 0
+ 21 22 4 0 0 0 0
+ 23 24 3 0 0 0 0
+ 25 26 1 0 0 0 0
+ 26 27 1 0 0 0 0
+ 26 32 1 0 0 0 0
+ 27 28 1 0 0 0 0
+ 28 29 1 0 0 0 0
+ 29 30 1 0 0 0 0
+ 25 30 1 0 0 0 0
+ 28 31 1 0 0 0 0
+ M END
+ $$$$
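A minimal sketch of loading this reference ligand with RDKit; sanitize=False mirrors how other scripts in this commit read SDF files:

    from rdkit import Chem

    # Read the first (and only) record of the reference ligand SDF.
    mol = Chem.SDMolSupplier('examples/kras_ref_ligand.sdf', sanitize=False)[0]
    print(mol.GetNumAtoms())  # 32 heavy atoms, matching the V2000 counts line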
scripts/python/evaluate_baselines.py ADDED
@@ -0,0 +1,53 @@
+ import argparse
+ import pickle
+ import sys
+ from pathlib import Path
+
+ basedir = Path(__file__).resolve().parent.parent.parent
+ sys.path.append(str(basedir))
+
+ from src.sbdd_metrics.evaluation import compute_all_metrics_drugflow
+
+ if __name__ == '__main__':
+     p = argparse.ArgumentParser()
+     p.add_argument('--in_dir', type=Path, required=True, help='Directory with samples')
+     p.add_argument('--out_dir', type=str, required=True, help='Output directory')
+     p.add_argument('--reference_smiles', type=str, default=None, help='Path to the .npy file with reference SMILES (optional)')
+     p.add_argument('--gnina', type=str, default=None, help='Path to the gnina binary (optional)')
+     p.add_argument('--reduce', type=str, default=None, help='Path to the reduce binary (optional)')
+     p.add_argument('--n_samples', type=int, default=None, help='Top-N samples to evaluate (optional)')
+     p.add_argument('--exclude', type=str, nargs='+', default=[], help='Evaluator IDs to exclude')
+     p.add_argument('--job_id', type=int, default=0, help='Job ID')
+     p.add_argument('--n_jobs', type=int, default=1, help='Number of jobs')
+     args = p.parse_args()
+
+     Path(args.out_dir).mkdir(exist_ok=True, parents=True)
+     if args.job_id == 0 and args.n_jobs == 1:
+         out_detailed_table = Path(args.out_dir, 'metrics_detailed.csv')
+         out_aggregated_table = Path(args.out_dir, 'metrics_aggregated.csv')
+         out_distributions_file = Path(args.out_dir, 'metrics_data.pkl')
+     else:
+         out_detailed_table = Path(args.out_dir, f'metrics_detailed_{args.job_id}.csv')
+         out_aggregated_table = Path(args.out_dir, f'metrics_aggregated_{args.job_id}.csv')
+         out_distributions_file = Path(args.out_dir, f'metrics_data_{args.job_id}.pkl')
+
+     if out_detailed_table.exists() and out_aggregated_table.exists() and out_distributions_file.exists():
+         print('Data already exists. Terminating.')
+         sys.exit(0)
+
+     print(f'Evaluating: {args.in_dir}')
+     data, detailed, aggregated = compute_all_metrics_drugflow(
+         in_dir=args.in_dir,
+         gnina_path=args.gnina,
+         reduce_path=args.reduce,
+         reference_smiles_path=args.reference_smiles,
+         n_samples=args.n_samples,
+         exclude_evaluators=args.exclude,
+         job_id=args.job_id,
+         n_jobs=args.n_jobs,
+     )
+
+     detailed.to_csv(out_detailed_table, index=False)
+     aggregated.to_csv(out_aggregated_table, index=False)
+     with open(out_distributions_file, 'wb') as f:
+         pickle.dump(data, f)
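For reference, a minimal sketch of invoking the same evaluation directly from Python instead of through the CLI above; the keyword arguments mirror the compute_all_metrics_drugflow call in the script, and 'samples/' is a placeholder path:

    from src.sbdd_metrics.evaluation import compute_all_metrics_drugflow

    # gnina/reduce binaries and reference SMILES are optional and left unset here.
    data, detailed, aggregated = compute_all_metrics_drugflow(
        in_dir='samples/',
        gnina_path=None,
        reduce_path=None,
        reference_smiles_path=None,
        n_samples=None,
        exclude_evaluators=[],
        job_id=0,
        n_jobs=1,
    )
    detailed.to_csv('metrics_detailed.csv', index=False)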
scripts/python/postprocess_metrics.py ADDED
@@ -0,0 +1,271 @@
+ import argparse
+ import os
+ import pickle
+ import sys
+ from collections import Counter, defaultdict
+ from pathlib import Path
+
+ import numpy as np
+ import pandas as pd
+ from rdkit import Chem
+ from scipy.stats import wasserstein_distance
+ from scipy.spatial.distance import jensenshannon
+ from tqdm import tqdm
+
+ basedir = Path(__file__).resolve().parent.parent.parent
+ sys.path.append(str(basedir))
+
+ from src.data.data_utils import atom_encoder, bond_encoder, encode_atom
+ from src.sbdd_metrics.evaluation import VALIDITY_METRIC_NAME, aggregated_metrics, collection_metrics, get_data_type
+ from src.sbdd_metrics.metrics import FullEvaluator
+
+
+ DATA_TYPES = FullEvaluator().dtypes
+
+ MEDCHEM_PROPS = [
+     'medchem.qed',
+     'medchem.sa',
+     'medchem.logp',
+     'medchem.lipinski',
+     'medchem.size',
+     'medchem.n_rotatable_bonds',
+     'energy.energy',
+ ]
+
+ DOCKING_PROPS = [
+     'gnina.vina_score',
+     'gnina.gnina_score',
+     'gnina.vina_efficiency',
+     'gnina.gnina_efficiency',
+ ]
+
+ RELEVANT_INTERACTIONS = [
+     'interactions.HBAcceptor',
+     'interactions.HBDonor',
+     'interactions.HB',
+     'interactions.PiStacking',
+     'interactions.Hydrophobic',
+     #
+     'interactions.HBAcceptor.normalized',
+     'interactions.HBDonor.normalized',
+     'interactions.HB.normalized',
+     'interactions.PiStacking.normalized',
+     'interactions.Hydrophobic.normalized',
+ ]
+
+
+ def compute_discrete_distributions(smiles, name):
+     atom_counter = Counter()
+     bond_counter = Counter()
+
+     for smi in tqdm(smiles, desc=name):
+         mol = Chem.MolFromSmiles(smi)
+         mol = Chem.RemoveAllHs(mol, sanitize=False)
+         for atom in mol.GetAtoms():
+             try:
+                 encoded_atom = encode_atom(atom, atom_encoder=atom_encoder)
+             except KeyError:
+                 continue
+             atom_counter[encoded_atom] += 1
+         for bond in mol.GetBonds():
+             bond_counter[bond_encoder[str(bond.GetBondType())]] += 1
+
+     atom_distribution = np.zeros(len(atom_encoder))
+     bond_distribution = np.zeros(len(bond_encoder))
+
+     for k, v in atom_counter.items():
+         atom_distribution[k] = v
+     for k, v in bond_counter.items():
+         bond_distribution[k] = v
+
+     atom_distribution = atom_distribution / atom_distribution.sum()
+     bond_distribution = bond_distribution / bond_distribution.sum()
+
+     return atom_distribution, bond_distribution
+
+
+ def flatten_distribution(data, name, table):
+     aux = ['sample', 'sdf_file', 'pdb_file']
+     method_distributions = defaultdict(list)
+
+     sdf2sample2size = defaultdict(dict)
+     for _, row in table.iterrows():
+         sdf2sample2size[row['sdf_file']][int(row['sample'])] = row['medchem.size']
+
+     for item in tqdm(data, desc=name):
+         if item['medchem.valid'] is not True:
+             continue
+
+         if 'interactions.HBAcceptor' in item and 'interactions.HBDonor' in item:
+             item['interactions.HB'] = item['interactions.HBAcceptor'] + item['interactions.HBDonor']
+
+         new_entries = {}
+         for key, value in item.items():
+             if key.startswith('interactions'):
+                 size = sdf2sample2size.get(item['sdf_file'], dict()).get(int(item['sample']))
+                 if size is not None:
+                     new_entries[key + '.normalized'] = value / size
+         item.update(new_entries)
+
+         for key, value in item.items():
+             if value is None:
+                 continue
+             if key in aux:
+                 continue
+             if key == 'energy.energy' and abs(value) > 1000:
+                 continue
+
+             if get_data_type(key, DATA_TYPES, default=type(value)) == list:
+                 method_distributions[key] += value
+             else:
+                 method_distributions[key].append(value)
+
+     return method_distributions
+
+
+ def prepare_baseline_data(root_path, baseline_name):
+     metrics_detailed = pd.read_csv(f'{root_path}/metrics_detailed.csv')
+     metrics_detailed = metrics_detailed[metrics_detailed['medchem.valid']]
+     distributions = pickle.load(open(f'{root_path}/metrics_data.pkl', 'rb'))
+     distributions = flatten_distribution(distributions, name=baseline_name, table=metrics_detailed)
+     distributions['energy.energy'] = [v for v in distributions['energy.energy'] if -1000 <= v <= 1000]
+     for prop in MEDCHEM_PROPS + DOCKING_PROPS:
+         distributions[prop] = metrics_detailed[prop].dropna().values.tolist()
+
+     smiles = metrics_detailed['representation.smiles']
+     atom_distribution, bond_distribution = compute_discrete_distributions(smiles, name=baseline_name)
+     discrete_distributions = {
+         'atom_types': atom_distribution,
+         'bond_types': bond_distribution,
+     }
+
+     return distributions, discrete_distributions
+
+
+ if __name__ == '__main__':
+     p = argparse.ArgumentParser()
+     p.add_argument('--in_dir', type=Path, required=True, help='Directory with samples')
+     p.add_argument('--out_dir', type=str, required=True, help='Output directory')
+     p.add_argument('--n_samples', type=int, required=False, default=None, help='N samples per target')
+     p.add_argument('--reference_smiles', type=str, default=None, help='Path to the .npy file with reference SMILES (optional)')
+     p.add_argument('--crossdocked_dir', type=str, required=False, default=None, help='CrossDocked data dir for computing distances between distributions')
+     args = p.parse_args()
+
+     Path(args.out_dir).mkdir(parents=True, exist_ok=True)
+
+     print('Combining data')
+     data = []
+     for file_path in tqdm(Path(args.in_dir).glob('metrics_data_*.pkl')):
+         with open(file_path, 'rb') as f:
+             d = pickle.load(f)
+         if args.n_samples is not None:
+             d = d[:args.n_samples]
+         data += d
+     with open(Path(args.out_dir, 'metrics_data.pkl'), 'wb') as f:
+         pickle.dump(data, f)
+
+     print('Combining detailed metrics')
+     tables = []
+     for file_path in tqdm(Path(args.in_dir).glob('metrics_detailed_*.csv')):
+         table = pd.read_csv(file_path)
+         if args.n_samples is not None:
+             table = table.head(args.n_samples)
+         tables.append(table)
+
+     table_detailed = pd.concat(tables)
+     table_detailed.to_csv(Path(args.out_dir, 'metrics_detailed.csv'), index=False)
+
+     print('Computing aggregated metrics')
+     evaluator = FullEvaluator(gnina='gnina', reduce='reduce')
+     table_aggregated = aggregated_metrics(
+         table_detailed,
+         data_types=evaluator.dtypes,
+         validity_metric_name=VALIDITY_METRIC_NAME
+     )
+
+     if args.reference_smiles is not None:
+         reference_smiles = np.load(args.reference_smiles)
+         col_metrics = collection_metrics(
+             table=table_detailed,
+             reference_smiles=reference_smiles,
+             validity_metric_name=VALIDITY_METRIC_NAME,
+             exclude_evaluators=[],
+         )
+         table_aggregated = pd.concat([table_aggregated, col_metrics])
+
+     table_aggregated.to_csv(Path(args.out_dir, 'metrics_aggregated.csv'), index=False)
+
+     # Computing distribution distances against the training data
+     if args.crossdocked_dir is not None:
+
+         # Loading training data distributions
+         crossdocked_distributions = None
+         crossdocked_discrete_distributions = None
+         precomputed_distr_path = f'{args.crossdocked_dir}/crossdocked_distributions.pkl'
+         precomputed_discrete_distr_path = f'{args.crossdocked_dir}/crossdocked_discrete_distributions.pkl'
+         if os.path.exists(precomputed_distr_path) and os.path.exists(precomputed_discrete_distr_path):
+             # Use precomputed distributions in case they exist
+             with open(precomputed_distr_path, 'rb') as f:
+                 crossdocked_distributions = pickle.load(f)
+             with open(precomputed_discrete_distr_path, 'rb') as f:
+                 crossdocked_discrete_distributions = pickle.load(f)
+         else:
+             assert os.path.exists(f'{args.crossdocked_dir}/metrics_detailed.csv')
+             assert os.path.exists(f'{args.crossdocked_dir}/metrics_data.pkl')
+             crossdocked_distributions, crossdocked_discrete_distributions = prepare_baseline_data(
+                 root_path=args.crossdocked_dir,
+                 baseline_name='crossdocked'
+             )
+             # Save precomputed distributions for faster subsequent runs
+             with open(precomputed_distr_path, 'wb') as f:
+                 pickle.dump(crossdocked_distributions, f)
+             with open(precomputed_discrete_distr_path, 'wb') as f:
+                 pickle.dump(crossdocked_discrete_distributions, f)
+
+         # Selecting the top-5 most frequent bond geometries (two-element keys)
+         # and angle geometries (three-element keys)
+         bonds = sorted([
+             (k, len(v)) for k, v in crossdocked_distributions.items()
+             if k.startswith('geometry.') and sum(s.isalpha() for s in k.split('.')[1]) == 2
+         ], key=lambda t: t[1], reverse=True)[:5]
+         top_5_bonds = [t[0] for t in bonds]
+
+         angles = sorted([
+             (k, len(v)) for k, v in crossdocked_distributions.items()
+             if k.startswith('geometry.') and sum(s.isalpha() for s in k.split('.')[1]) == 3
+         ], key=lambda t: t[1], reverse=True)[:5]
+         top_5_angles = [t[0] for t in angles]
+
+         # Loading distributions of samples
+         distributions, discrete_distributions = prepare_baseline_data(args.out_dir, 'samples')
+
+         # Computing distances between distributions
+         distances = {'method': 'method'}
+         relevant_columns = MEDCHEM_PROPS + DOCKING_PROPS + RELEVANT_INTERACTIONS + top_5_bonds + top_5_angles
+         for metric in distributions.keys():
+             if metric not in relevant_columns:
+                 continue
+
+             ref = crossdocked_distributions.get(metric)
+             cur = [x for x in distributions.get(metric) if not pd.isna(x)]
+
+             if ref is not None and len(cur) > 0:
+                 try:
+                     distance = wasserstein_distance(ref, cur)
+                 except ValueError as e:
+                     print(f'Could not compute Wasserstein distance for {metric}: {e}')
+                     continue
+                 distances[f'WD.{metric}'] = distance
+
+         for metric in crossdocked_discrete_distributions.keys():
+             ref = crossdocked_discrete_distributions.get(metric)
+             cur = discrete_distributions.get(metric)
+             if ref is not None and cur is not None:
+                 distances[f'JS.{metric}'] = jensenshannon(p=ref, q=cur)
+
+         dist_table = pd.DataFrame([distances])
+         dist_table.to_csv(Path(args.out_dir, 'metrics_distances.csv'), index=False)
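As a toy illustration of the two distances computed above (Wasserstein for continuous property distributions, Jensen-Shannon for the discrete atom/bond-type histograms); the arrays are made-up data, not repo outputs:

    import numpy as np
    from scipy.stats import wasserstein_distance
    from scipy.spatial.distance import jensenshannon

    # Continuous distributions are compared sample-to-sample.
    wd = wasserstein_distance([1.0, 2.0, 3.0], [1.5, 2.5, 3.5])  # 0.5

    # Discrete distributions are compared as normalized histograms.
    js = jensenshannon(p=np.array([0.7, 0.3]), q=np.array([0.6, 0.4]))
    print(wd, js)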
src/analysis/SA_Score/README.md ADDED
@@ -0,0 +1 @@
+ Files taken from: https://github.com/rdkit/rdkit/tree/master/Contrib/SA_Score
src/analysis/SA_Score/fpscores.pkl.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:10dcef9340c873e7b987924461b0af5365eb8dd96be607203debe8ddf80c1e73
+ size 3848394
src/analysis/SA_Score/sascorer.py ADDED
@@ -0,0 +1,173 @@
+ #
+ # calculation of synthetic accessibility score as described in:
+ #
+ # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
+ # Peter Ertl and Ansgar Schuffenhauer
+ # Journal of Cheminformatics 1:8 (2009)
+ # http://www.jcheminf.com/content/1/1/8
+ #
+ # several small modifications to the original paper are included,
+ # particularly a slightly different formula for the macrocyclic penalty
+ # and taking into account also molecule symmetry (fingerprint density)
+ #
+ # for a set of 10k diverse molecules the agreement between the original method
+ # as implemented in PipelinePilot and this implementation is r2 = 0.97
+ #
+ # peter ertl & greg landrum, september 2013
+ #
+
+
+ from rdkit import Chem
+ from rdkit.Chem import rdMolDescriptors
+ import pickle
+
+ import math
+ from collections import defaultdict
+
+ import os.path as op
+
+ _fscores = None
+
+
+ def readFragmentScores(name='fpscores'):
+     import gzip
+     global _fscores
+     # generate the full path filename:
+     if name == "fpscores":
+         name = op.join(op.dirname(__file__), name)
+     data = pickle.load(gzip.open('%s.pkl.gz' % name))
+     outDict = {}
+     for i in data:
+         for j in range(1, len(i)):
+             outDict[i[j]] = float(i[0])
+     _fscores = outDict
+
+
+ def numBridgeheadsAndSpiro(mol, ri=None):
+     nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
+     nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
+     return nBridgehead, nSpiro
+
+
+ def calculateScore(m):
+     if _fscores is None:
+         readFragmentScores()
+
+     # fragment score
+     fp = rdMolDescriptors.GetMorganFingerprint(m, 2)  # <- 2 is the *radius* of the circular fingerprint
+     fps = fp.GetNonzeroElements()
+     score1 = 0.
+     nf = 0
+     for bitId, v in fps.items():
+         nf += v
+         sfp = bitId
+         score1 += _fscores.get(sfp, -4) * v
+     score1 /= nf
+
+     # features score
+     nAtoms = m.GetNumAtoms()
+     nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
+     ri = m.GetRingInfo()
+     nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
+     nMacrocycles = 0
+     for x in ri.AtomRings():
+         if len(x) > 8:
+             nMacrocycles += 1
+
+     sizePenalty = nAtoms**1.005 - nAtoms
+     stereoPenalty = math.log10(nChiralCenters + 1)
+     spiroPenalty = math.log10(nSpiro + 1)
+     bridgePenalty = math.log10(nBridgeheads + 1)
+     macrocyclePenalty = 0.
+     # ---------------------------------------
+     # This differs from the paper, which defines:
+     # macrocyclePenalty = math.log10(nMacrocycles+1)
+     # This form generates better results when 2 or more macrocycles are present
+     if nMacrocycles > 0:
+         macrocyclePenalty = math.log10(2)
+
+     score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty
+
+     # correction for the fingerprint density
+     # not in the original publication, added in version 1.1
+     # to make highly symmetrical molecules easier to synthesize
+     score3 = 0.
+     if nAtoms > len(fps):
+         score3 = math.log(float(nAtoms) / len(fps)) * .5
+
+     sascore = score1 + score2 + score3
+
+     # need to transform "raw" value into scale between 1 and 10
+     min = -4.0
+     max = 2.5
+     sascore = 11. - (sascore - min + 1) / (max - min) * 9.
+     # smooth the 10-end
+     if sascore > 8.:
+         sascore = 8. + math.log(sascore + 1. - 9.)
+     if sascore > 10.:
+         sascore = 10.0
+     elif sascore < 1.:
+         sascore = 1.0
+
+     return sascore
+
+
+ def processMols(mols):
+     print('smiles\tName\tsa_score')
+     for i, m in enumerate(mols):
+         if m is None:
+             continue
+
+         s = calculateScore(m)
+
+         smiles = Chem.MolToSmiles(m)
+         print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s)
+
+
+ if __name__ == '__main__':
+     import sys
+     import time
+
+     t1 = time.time()
+     readFragmentScores("fpscores")
+     t2 = time.time()
+
+     suppl = Chem.SmilesMolSupplier(sys.argv[1])
+     t3 = time.time()
+     processMols(suppl)
+     t4 = time.time()
+
+     print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
+           file=sys.stderr)
+
+ #
+ # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
+ # All rights reserved.
+ #
+ # Redistribution and use in source and binary forms, with or without
+ # modification, are permitted provided that the following conditions are
+ # met:
+ #
+ #     * Redistributions of source code must retain the above copyright
+ #       notice, this list of conditions and the following disclaimer.
+ #     * Redistributions in binary form must reproduce the above
+ #       copyright notice, this list of conditions and the following
+ #       disclaimer in the documentation and/or other materials provided
+ #       with the distribution.
+ #     * Neither the name of Novartis Institutes for BioMedical Research Inc.
+ #       nor the names of its contributors may be used to endorse or promote
+ #       products derived from this software without specific prior written permission.
+ #
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ #
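A short usage sketch for the scorer above (requires RDKit and the bundled fpscores.pkl.gz; the aspirin SMILES is just an example molecule):

    from rdkit import Chem
    from src.analysis.SA_Score.sascorer import calculateScore

    mol = Chem.MolFromSmiles('CC(=O)Oc1ccccc1C(=O)O')  # aspirin
    print(calculateScore(mol))  # score in [1, 10]; lower means easier to synthesize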
src/analysis/metrics.py ADDED
@@ -0,0 +1,544 @@
+ import subprocess
+
+ import numpy as np
+ import tempfile
+ from pathlib import Path
+ from tqdm import tqdm
+ from rdkit import Chem, DataStructs
+ from rdkit.Chem import AllChem
+ from rdkit.Chem import Descriptors, Crippen, Lipinski, QED
+ from rdkit.Chem import MolSanitizeException
+ from src.analysis.SA_Score.sascorer import calculateScore
+ from src.utils import write_sdf_file
+
+ from copy import deepcopy
+
+
+ class CategoricalDistribution:
+     EPS = 1e-10
+
+     def __init__(self, histogram_dict, mapping):
+         histogram = np.zeros(len(mapping))
+         for k, v in histogram_dict.items():
+             histogram[mapping[k]] = v
+
+         # Normalize histogram
+         self.p = histogram / histogram.sum()
+         self.mapping = deepcopy(mapping)
+
+     def kl_divergence(self, other_sample):
+         sample_histogram = np.zeros(len(self.mapping))
+         for x in other_sample:
+             # sample_histogram[self.mapping[x]] += 1
+             sample_histogram[x] += 1
+
+         # Normalize
+         q = sample_histogram / sample_histogram.sum()
+
+         return -np.sum(self.p * np.log(q / (self.p + self.EPS) + self.EPS))
+
+
+ def check_mol(rdmol):
+     """
+     See also: https://www.rdkit.org/docs/RDKit_Book.html#molecular-sanitization
+     """
+     if rdmol is None:
+         return 'is_none'
+
+     _rdmol = Chem.Mol(rdmol)
+     try:
+         Chem.SanitizeMol(_rdmol)
+         return 'valid'
+     except ValueError as e:
+         assert isinstance(e, MolSanitizeException)
+         return type(e).__name__
+
+
+ def validity_analysis(rdmol_list):
+     """
+     For explanations, see: https://www.rdkit.org/docs/RDKit_Book.html#molecular-sanitization
+     """
+
+     result = {
+         'AtomValenceException': 0,  # atoms in higher-than-allowed valence states
+         'AtomKekulizeException': 0,
+         'KekulizeException': 0,  # ring cannot be kekulized, or aromatic bonds found outside of rings
+         'other': 0,
+         'valid': 0
+     }
+
+     for rdmol in rdmol_list:
+         flag = check_mol(rdmol)
+
+         try:
+             result[flag] += 1
+         except KeyError:
+             result['other'] += 1
+
+     assert sum(result.values()) == len(rdmol_list)
+
+     return result
+
+
+ class MoleculeValidity:
+     def __init__(self, connectivity_thresh=1.0):
+         self.connectivity_thresh = connectivity_thresh
+
+     def compute_validity(self, generated):
+         """ generated: list of RDKit molecules. """
+         if len(generated) < 1:
+             return [], 0.0
+
+         # Return copies of the valid molecules
+         valid = [Chem.Mol(mol) for mol in generated if check_mol(mol) == 'valid']
+         return valid, len(valid) / len(generated)
+
+     def compute_connectivity(self, valid):
+         """
+         Consider a molecule connected if its largest fragment contains at
+         least <self.connectivity_thresh * 100>% of all atoms.
+         :param valid: list of valid RDKit molecules
+         """
+         if len(valid) < 1:
+             return [], 0.0
+
+         for mol in valid:
+             Chem.SanitizeMol(mol)  # all molecules should be valid
+
+         connected = []
+         for mol in valid:
+
+             if mol.GetNumAtoms() < 1:
+                 continue
+
+             try:
+                 mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True)
+             except MolSanitizeException as e:
+                 print('Error while computing connectivity:', e)
+                 continue
+
+             largest_frag = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
+             if largest_frag.GetNumAtoms() / mol.GetNumAtoms() >= self.connectivity_thresh:
+                 connected.append(largest_frag)
+
+         return connected, len(connected) / len(valid)
+
+     def __call__(self, rdmols, verbose=False):
+         """
+         :param rdmols: list of RDKit molecules
+         """
+
+         results = {}
+         results['n_total'] = len(rdmols)
+
+         valid, validity = self.compute_validity(rdmols)
+         results['n_valid'] = len(valid)
+         results['validity'] = validity
+
+         connected, connectivity = self.compute_connectivity(valid)
+         results['n_connected'] = len(connected)
+         results['connectivity'] = connectivity
+         results['valid_and_connected'] = results['n_connected'] / results['n_total']
+
+         if verbose:
+             print(f"Validity over {results['n_total']} molecules: {validity * 100 :.2f}%")
+             print(f"Connectivity over {results['n_valid']} valid molecules: {connectivity * 100 :.2f}%")
+
+         return results
+
+
+ class MolecularMetrics:
+     def __init__(self, connectivity_thresh=1.0):
+         self.connectivity_thresh = connectivity_thresh
+
+     @staticmethod
+     def is_valid(rdmol):
+         if rdmol.GetNumAtoms() < 1:
+             return False
+
+         _mol = Chem.Mol(rdmol)
+         try:
+             Chem.SanitizeMol(_mol)
+         except ValueError:
+             return False
+
+         return True
+
+     def is_connected(self, rdmol):
+
+         if rdmol.GetNumAtoms() < 1:
+             return False
+
+         mol_frags = Chem.rdmolops.GetMolFrags(rdmol, asMols=True)
+
+         largest_frag = max(mol_frags, default=rdmol, key=lambda m: m.GetNumAtoms())
+         return largest_frag.GetNumAtoms() / rdmol.GetNumAtoms() >= self.connectivity_thresh
+
+     @staticmethod
+     def calculate_qed(rdmol):
+         return QED.qed(rdmol)
+
+     @staticmethod
+     def calculate_sa(rdmol):
+         sa = calculateScore(rdmol)
+         return sa
+
+     @staticmethod
+     def calculate_logp(rdmol):
+         return Crippen.MolLogP(rdmol)
+
+     @staticmethod
+     def calculate_lipinski(rdmol):
+         rule_1 = Descriptors.ExactMolWt(rdmol) < 500
+         rule_2 = Lipinski.NumHDonors(rdmol) <= 5
+         rule_3 = Lipinski.NumHAcceptors(rdmol) <= 10
+         rule_4 = -2 <= Crippen.MolLogP(rdmol) <= 5
+         rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(rdmol) <= 10
+         return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]])
+
+     def __call__(self, rdmol):
+         valid = self.is_valid(rdmol)
+
+         if valid:
+             Chem.SanitizeMol(rdmol)
+
+         connected = None if not valid else self.is_connected(rdmol)
+         qed = None if not valid else self.calculate_qed(rdmol)
+         sa = None if not valid else self.calculate_sa(rdmol)
+         logp = None if not valid else self.calculate_logp(rdmol)
+         lipinski = None if not valid else self.calculate_lipinski(rdmol)
+
+         return {
+             'valid': valid,
+             'connected': connected,
+             'qed': qed,
+             'sa': sa,
+             'logp': logp,
+             'lipinski': lipinski
+         }
+
+
+ class Diversity:
+     @staticmethod
+     def similarity(fp1, fp2):
+         return DataStructs.TanimotoSimilarity(fp1, fp2)
+
+     def get_fingerprint(self, mol):
+         # fp = AllChem.GetMorganFingerprintAsBitVect(
+         #     mol, 2, nBits=2048, useChirality=False)
+         fp = Chem.RDKFingerprint(mol)
+         return fp
+
+     def __call__(self, pocket_mols):
+
+         if len(pocket_mols) < 2:
+             return 0.0
+
+         pocket_fps = [self.get_fingerprint(m) for m in pocket_mols]
+
+         div = 0
+         total = 0
+         for i in range(len(pocket_fps)):
+             for j in range(i + 1, len(pocket_fps)):
+                 div += 1 - self.similarity(pocket_fps[i], pocket_fps[j])
+                 total += 1
+
+         return div / total
+
+
+ class MoleculeUniqueness:
+     def __call__(self, smiles_list):
+         """ smiles_list: list of SMILES strings. """
+         if len(smiles_list) < 1:
+             return 0.0
+
+         return len(set(smiles_list)) / len(smiles_list)
+
+
+ class MoleculeNovelty:
+     def __init__(self, reference_smiles):
+         """
+         :param reference_smiles: list of SMILES strings
+         """
+         self.reference_smiles = set(reference_smiles)
+
+     def __call__(self, smiles_list):
+         if len(smiles_list) < 1:
+             return 0.0
+
+         novel = [smi for smi in smiles_list if smi not in self.reference_smiles]
+         return len(novel) / len(smiles_list)
+
+
+ class MolecularProperties:
+
+     @staticmethod
+     def calculate_qed(rdmol):
+         return QED.qed(rdmol)
+
+     @staticmethod
+     def calculate_sa(rdmol):
+         sa = calculateScore(rdmol)
+         # return round((10 - sa) / 9, 2)  # from pocket2mol
+         return sa
+
+     @staticmethod
+     def calculate_logp(rdmol):
+         return Crippen.MolLogP(rdmol)
+
+     @staticmethod
+     def calculate_lipinski(rdmol):
+         rule_1 = Descriptors.ExactMolWt(rdmol) < 500
+         rule_2 = Lipinski.NumHDonors(rdmol) <= 5
+         rule_3 = Lipinski.NumHAcceptors(rdmol) <= 10
+         rule_4 = -2 <= Crippen.MolLogP(rdmol) <= 5
+         rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(rdmol) <= 10
+         return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]])
+
+     @classmethod
+     def calculate_diversity(cls, pocket_mols):
+         if len(pocket_mols) < 2:
+             return 0.0
+
+         div = 0
+         total = 0
+         for i in range(len(pocket_mols)):
+             for j in range(i + 1, len(pocket_mols)):
+                 div += 1 - cls.similarity(pocket_mols[i], pocket_mols[j])
+                 total += 1
+         return div / total
+
+     @staticmethod
+     def similarity(mol_a, mol_b):
+         # fp1 = AllChem.GetMorganFingerprintAsBitVect(
+         #     mol_a, 2, nBits=2048, useChirality=False)
+         # fp2 = AllChem.GetMorganFingerprintAsBitVect(
+         #     mol_b, 2, nBits=2048, useChirality=False)
+         fp1 = Chem.RDKFingerprint(mol_a)
+         fp2 = Chem.RDKFingerprint(mol_b)
+         return DataStructs.TanimotoSimilarity(fp1, fp2)
+
+     def evaluate_pockets(self, pocket_rdmols, verbose=False):
+         """
+         Run the full evaluation.
+         Args:
+             pocket_rdmols: list of lists; each inner list contains all RDKit
+                 molecules generated for one pocket
+         Returns:
+             QED, SA, LogP, Lipinski (per molecule), and Diversity (per pocket)
+         """
+
+         for pocket in pocket_rdmols:
+             for mol in pocket:
+                 Chem.SanitizeMol(mol)  # only evaluate valid molecules
+
+         all_qed = []
+         all_sa = []
+         all_logp = []
+         all_lipinski = []
+         per_pocket_diversity = []
+         for pocket in tqdm(pocket_rdmols):
+             all_qed.append([self.calculate_qed(mol) for mol in pocket])
+             all_sa.append([self.calculate_sa(mol) for mol in pocket])
+             all_logp.append([self.calculate_logp(mol) for mol in pocket])
+             all_lipinski.append([self.calculate_lipinski(mol) for mol in pocket])
+             per_pocket_diversity.append(self.calculate_diversity(pocket))
+
+         qed_flattened = [x for px in all_qed for x in px]
+         sa_flattened = [x for px in all_sa for x in px]
+         logp_flattened = [x for px in all_logp for x in px]
+         lipinski_flattened = [x for px in all_lipinski for x in px]
+
+         if verbose:
+             print(f"{sum([len(p) for p in pocket_rdmols])} molecules from "
+                   f"{len(pocket_rdmols)} pockets evaluated.")
+             print(f"QED: {np.mean(qed_flattened):.3f} ± {np.std(qed_flattened):.2f}")
+             print(f"SA: {np.mean(sa_flattened):.3f} ± {np.std(sa_flattened):.2f}")
+             print(f"LogP: {np.mean(logp_flattened):.3f} ± {np.std(logp_flattened):.2f}")
+             print(f"Lipinski: {np.mean(lipinski_flattened):.3f} ± {np.std(lipinski_flattened):.2f}")
+             print(f"Diversity: {np.mean(per_pocket_diversity):.3f} ± {np.std(per_pocket_diversity):.2f}")
+
+         return all_qed, all_sa, all_logp, all_lipinski, per_pocket_diversity
+
+     def __call__(self, rdmols):
+         """
+         Run the full evaluation and return the mean of each property.
+         Args:
+             rdmols: list of RDKit molecules
+         Returns:
+             Dictionary with mean QED, SA, LogP, Lipinski, and Diversity values
+         """
+
+         if len(rdmols) < 1:
+             return {'QED': 0.0, 'SA': 0.0, 'LogP': 0.0, 'Lipinski': 0.0,
+                     'Diversity': 0.0}
+
+         _rdmols = []
+         for mol in rdmols:
+             try:
+                 Chem.SanitizeMol(mol)  # only evaluate valid molecules
+                 _rdmols.append(mol)
+             except ValueError:
+                 print("Tried to analyze invalid molecule")
+         rdmols = _rdmols
+
+         qed = np.mean([self.calculate_qed(mol) for mol in rdmols])
+         sa = np.mean([self.calculate_sa(mol) for mol in rdmols])
+         logp = np.mean([self.calculate_logp(mol) for mol in rdmols])
+         lipinski = np.mean([self.calculate_lipinski(mol) for mol in rdmols])
+         diversity = self.calculate_diversity(rdmols)
+
+         return {'QED': qed, 'SA': sa, 'LogP': logp, 'Lipinski': lipinski,
+                 'Diversity': diversity}
+
+
+ def compute_gnina_scores(ligands, receptors, gnina):
+     metrics = ['minimizedAffinity', 'minimizedRMSD', 'CNNscore', 'CNNaffinity', 'CNN_VS', 'CNNaffinity_variance']
+     out = {m: [] for m in metrics}
+     with tempfile.TemporaryDirectory() as tmpdir:
+         for ligand, receptor in zip(tqdm(ligands, desc='Docking'), receptors):
+             in_ligand_path = Path(tmpdir, 'in_ligand.sdf')
+             out_ligand_path = Path(tmpdir, 'out_ligand.sdf')
+             receptor_path = Path(tmpdir, 'receptor.pdb')
+             write_sdf_file(in_ligand_path, [ligand], catch_errors=True)
+             Chem.MolToPDBFile(receptor, str(receptor_path))
+             if (
+                 (not in_ligand_path.exists()) or
+                 (not receptor_path.exists()) or
+                 in_ligand_path.read_text() == '' or
+                 receptor_path.read_text() == ''
+             ):
+                 continue
+
+             cmd = (
+                 f'{gnina} -r {receptor_path} -l {in_ligand_path} '
+                 f'--minimize --seed 42 -o {out_ligand_path} --no_gpu 1> /dev/null'
+             )
+             subprocess.run(cmd, shell=True)
+             if not out_ligand_path.exists() or out_ligand_path.read_text() == '':
+                 continue
+
+             mol = Chem.SDMolSupplier(str(out_ligand_path), sanitize=False)[0]
+             for metric in metrics:
+                 out[metric].append(float(mol.GetProp(metric)))
+
+     for metric in metrics:
+         out[metric] = sum(out[metric]) / len(out[metric]) if len(out[metric]) > 0 else 0
+
+     return out
+
+
+ def legacy_clash_score(rdmol1, rdmol2=None, margin=0.75):
+     """
+     Computes a clash score as the number of atoms that have at least one
+     clash divided by the number of atoms in the molecule.
+
+     INTERMOLECULAR CLASH SCORE
+     If rdmol2 is provided, the score is the percentage of atoms in rdmol1
+     that have at least one clash with rdmol2.
+     We define a clash if two atoms are closer than "margin times the sum of
+     their van der Waals radii".
+
+     INTRAMOLECULAR CLASH SCORE
+     If rdmol2 is not provided, the score is the percentage of atoms in rdmol1
+     that have at least one clash with other atoms in rdmol1.
+     In this case, a clash is defined by margin times the atoms' smallest
+     covalent radii (among single, double and triple bond radii). This is done
+     so that this function is applicable even if no connectivity information is
+     available.
+     """
+     # source: https://en.wikipedia.org/wiki/Van_der_Waals_radius
+     vdw_radii = {'N': 1.55, 'O': 1.52, 'C': 1.70, 'H': 1.10, 'S': 1.80, 'P': 1.80,
+                  'Se': 1.90, 'K': 2.75, 'Na': 2.27, 'Mg': 1.73, 'Zn': 1.39, 'B': 1.92,
+                  'Br': 1.85, 'Cl': 1.75, 'I': 1.98, 'F': 1.47}
+
+     # https://en.wikipedia.org/wiki/Covalent_radius#Radii_for_multiple_bonds
+     covalent_radii = {'H': 0.32, 'C': 0.60, 'N': 0.54, 'O': 0.53, 'F': 0.53, 'B': 0.73,
+                       'Al': 1.11, 'Si': 1.02, 'P': 0.94, 'S': 0.94, 'Cl': 0.93, 'As': 1.06,
+                       'Br': 1.09, 'I': 1.25, 'Hg': 1.33, 'Bi': 1.35}
+
+     coord1 = rdmol1.GetConformer().GetPositions()
+
+     if rdmol2 is None:
+         radii1 = np.array([covalent_radii[a.GetSymbol()] for a in rdmol1.GetAtoms()])
+         assert coord1.shape[0] == radii1.shape[0]
+
+         dist = np.sqrt(np.sum((coord1[:, None, :] - coord1[None, :, :]) ** 2, axis=-1))
+         np.fill_diagonal(dist, np.inf)
+         clashes = dist < margin * (radii1[:, None] + radii1[None, :])
+
+     else:
+         coord2 = rdmol2.GetConformer().GetPositions()
+
+         radii1 = np.array([vdw_radii[a.GetSymbol()] for a in rdmol1.GetAtoms()])
+         assert coord1.shape[0] == radii1.shape[0]
+         radii2 = np.array([vdw_radii[a.GetSymbol()] for a in rdmol2.GetAtoms()])
+         assert coord2.shape[0] == radii2.shape[0]
+
+         dist = np.sqrt(np.sum((coord1[:, None, :] - coord2[None, :, :]) ** 2, axis=-1))
+         clashes = dist < margin * (radii1[:, None] + radii2[None, :])
+
+     clashes = np.any(clashes, axis=1)
+     return np.mean(clashes)
+
+
+ def clash_score(rdmol1, rdmol2=None, margin=0.75, ignore={'H'}):
+     """
+     Computes a clash score as the number of atoms that have at least one
+     clash divided by the number of atoms in the molecule.
+
+     INTERMOLECULAR CLASH SCORE
+     If rdmol2 is provided, the score is the percentage of atoms in rdmol1
+     that have at least one clash with rdmol2.
+     We define a clash if two atoms are closer than "margin times the sum of
+     their van der Waals radii".
+
+     INTRAMOLECULAR CLASH SCORE
+     If rdmol2 is not provided, the score is the percentage of atoms in rdmol1
+     that have at least one clash with other atoms in rdmol1.
+     In this case, a clash is defined by margin times the atoms' smallest
+     covalent radii (among single, double and triple bond radii). This is done
+     so that this function is applicable even if no connectivity information is
+     available.
+     """
+
+     intramolecular = rdmol2 is None
+
+     _periodic_table = AllChem.GetPeriodicTable()
+
+     def _coord_and_radii(rdmol):
+         coord = rdmol.GetConformer().GetPositions()
+         radii = np.array([_get_radius(a.GetSymbol()) for a in rdmol.GetAtoms()])
518
+
519
+ mask = np.array([a.GetSymbol() not in ignore for a in rdmol.GetAtoms()])
520
+ coord = coord[mask]
521
+ radii = radii[mask]
522
+
523
+ assert coord.shape[0] == radii.shape[0]
524
+ return coord, radii
525
+
526
+ # INTRAMOLECULAR CLASH SCORE
527
+ if intramolecular:
528
+ rdmol2 = rdmol1
529
+ _get_radius = _periodic_table.GetRcovalent # covalent radii
530
+
531
+ # INTERMOLECULAR CLASH SCORE
532
+ else:
533
+ _get_radius = _periodic_table.GetRvdw # vdW radii
534
+
535
+ coord1, radii1 = _coord_and_radii(rdmol1)
536
+ coord2, radii2 = _coord_and_radii(rdmol2)
537
+
538
+ dist = np.sqrt(np.sum((coord1[:, None, :] - coord2[None, :, :]) ** 2, axis=-1))
539
+ if intramolecular:
540
+ np.fill_diagonal(dist, np.inf)
541
+
542
+ clashes = dist < margin * (radii1[:, None] + radii2[None, :])
543
+ clashes = np.any(clashes, axis=1)
544
+ return np.mean(clashes)
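Usage sketch for both clash-score variants defined above (file paths are placeholders):

```python
from rdkit import Chem

lig = Chem.SDMolSupplier('ligand.sdf', sanitize=False)[0]
rec = Chem.MolFromPDBFile('receptor.pdb', sanitize=False)

intra = clash_score(lig)        # intramolecular: covalent radii, self-distances masked out
inter = clash_score(lig, rec)   # intermolecular: vdW radii, hydrogens ignored by default
print(f"intra: {intra:.2%}, inter: {inter:.2%}")
```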
src/analysis/visualization_utils.py ADDED
@@ -0,0 +1,192 @@
1
+ import warnings
2
+
3
+ import torch
4
+ from rdkit import Chem
5
+ from rdkit.Chem import Draw, AllChem
6
+ from rdkit.Chem import SanitizeFlags
7
+ from src.analysis.metrics import check_mol
8
+ from src import utils
9
+ from src.data.molecule_builder import build_molecule
10
+ from src.data.misc import protein_letters_1to3
11
+
12
+
13
+ # def pocket_to_rdkit(pocket, pocket_representation, atom_encoder=None,
14
+ # atom_decoder=None, aa_decoder=None, residue_decoder=None,
15
+ # aa_atom_index=None):
16
+ #
17
+ # rdpockets = []
18
+ # for i in torch.unique(pocket['mask']):
19
+ #
20
+ # node_coord = pocket['x'][pocket['mask'] == i]
21
+ # h = pocket['one_hot'][pocket['mask'] == i]
22
+ #
23
+ # if pocket_representation == 'side_chain_bead':
24
+ # coord = node_coord
25
+ #
26
+ # node_types = [residue_decoder[b] for b in h[:, -len(residue_decoder):].argmax(-1)]
27
+ # atom_types = ['C' if r == 'CA' else 'F' for r in node_types]
28
+ #
29
+ # elif pocket_representation == 'CA+':
30
+ # aa_types = [aa_decoder[b] for b in h.argmax(-1)]
31
+ # side_chain_vec = pocket['v'][pocket['mask'] == i]
32
+ #
33
+ # coord = []
34
+ # atom_types = []
35
+ # for xyz, aa, vec in zip(node_coord, aa_types, side_chain_vec):
36
+ # # C_alpha
37
+ # coord.append(xyz)
38
+ # atom_types.append('C')
39
+ #
40
+ # # all other atoms
41
+ # for atom_name, idx in aa_atom_index[aa].items():
42
+ # coord.append(xyz + vec[idx])
43
+ # atom_types.append(atom_name[0])
44
+ #
45
+ # coord = torch.stack(coord, dim=0)
46
+ #
47
+ # else:
48
+ # raise NotImplementedError(f"{pocket_representation} residue representation not supported")
49
+ #
50
+ # atom_types = torch.tensor([atom_encoder[a] for a in atom_types])
51
+ # rdpockets.append(build_molecule(coord, atom_types, atom_decoder=atom_decoder))
52
+ #
53
+ # return rdpockets
54
+ def pocket_to_rdkit(pocket, pocket_representation, atom_encoder=None,
55
+ atom_decoder=None, aa_decoder=None, residue_decoder=None,
56
+ aa_atom_index=None):
57
+
58
+ rdpockets = []
59
+ for i in torch.unique(pocket['mask']):
60
+
61
+ node_coord = pocket['x'][pocket['mask'] == i]
62
+ h = pocket['one_hot'][pocket['mask'] == i]
63
+ atom_mask = pocket['atom_mask'][pocket['mask'] == i]
64
+
65
+ pdb_infos = []
66
+
67
+ if pocket_representation == 'side_chain_bead':
68
+ coord = node_coord
69
+
70
+ node_types = [residue_decoder[b] for b in h[:, -len(residue_decoder):].argmax(-1)]
71
+ atom_types = ['C' if r == 'CA' else 'F' for r in node_types]
72
+
73
+ elif pocket_representation == 'CA+':
74
+ aa_types = [aa_decoder[b] for b in h.argmax(-1)]
75
+ side_chain_vec = pocket['v'][pocket['mask'] == i]
76
+
77
+ coord = []
78
+ atom_types = []
79
+ for resi, (xyz, aa, vec, am) in enumerate(zip(node_coord, aa_types, side_chain_vec, atom_mask)):
80
+
81
+ # with the updated atom dictionary, CA is indexed like any other atom
82
+ for atom_name, idx in aa_atom_index[aa].items():
83
+
84
+ if not am[idx]:
85
+ warnings.warn(f"Missing atom {atom_name} in {aa}:{resi}")
86
+ continue
87
+
88
+ coord.append(xyz + vec[idx])
89
+ atom_types.append(atom_name[0])
90
+
91
+ info = Chem.AtomPDBResidueInfo()
92
+ # info.SetChainId('A')
93
+ info.SetResidueName(protein_letters_1to3[aa])
94
+ info.SetResidueNumber(resi + 1)
95
+ info.SetOccupancy(1.0)
96
+ info.SetTempFactor(0.0)
97
+ info.SetName(f' {atom_name:<3}')
98
+ pdb_infos.append(info)
99
+
100
+ coord = torch.stack(coord, dim=0)
101
+
102
+ else:
103
+ raise NotImplementedError(f"{pocket_representation} residue representation not supported")
104
+
105
+ atom_types = torch.tensor([atom_encoder[a] for a in atom_types])
106
+ rdmol = build_molecule(coord, atom_types, atom_decoder=atom_decoder)
107
+
108
+ if len(pdb_infos) == len(rdmol.GetAtoms()):
109
+ for a, info in zip(rdmol.GetAtoms(), pdb_infos):
110
+ a.SetPDBResidueInfo(info)
111
+
112
+ rdpockets.append(rdmol)
113
+
114
+ return rdpockets
115
+
116
+
117
+ def mols_to_pdbfile(rdmols, filename, flavor=0):
118
+ pdb_str = ""
119
+ for i, mol in enumerate(rdmols):
120
+ pdb_str += f"MODEL{i + 1:>9}\n"
121
+ block = Chem.MolToPDBBlock(mol, flavor=flavor)
122
+ block = "\n".join(block.split("\n")[:-2]) # remove END
123
+ pdb_str += block + "\n"
124
+ pdb_str += f"ENDMDL\n"
125
+ pdb_str += f"END\n"
126
+
127
+ with open(filename, 'w') as f:
128
+ f.write(pdb_str)
129
+
130
+ return pdb_str
131
+
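For example, pockets converted with `pocket_to_rdkit` above can be written as a single multi-model PDB file; the `pocket` dict is assumed to come from the data pipeline:

```python
from src.constants import atom_encoder, atom_decoder, aa_decoder, aa_atom_index

rdpockets = pocket_to_rdkit(pocket, 'CA+',  # `pocket` as produced by the data pipeline (assumed)
                            atom_encoder=atom_encoder, atom_decoder=atom_decoder,
                            aa_decoder=aa_decoder, aa_atom_index=aa_atom_index)
mols_to_pdbfile(rdpockets, 'pockets.pdb')   # one MODEL/ENDMDL block per pocket
```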
132
+
133
+ def mol_as_pdb(rdmol, filename=None, bfactor=None):
134
+
135
+ _rdmol = Chem.Mol(rdmol) # copy
136
+ for a in _rdmol.GetAtoms():
137
+ a.SetIsAromatic(False)
138
+ for b in _rdmol.GetBonds():
139
+ b.SetIsAromatic(False)
140
+
141
+ if bfactor is not None:
142
+ for a in _rdmol.GetAtoms():
143
+ val = a.GetPropsAsDict()[bfactor]
144
+
145
+ info = Chem.AtomPDBResidueInfo()
146
+ info.SetResidueName('UNL')
147
+ info.SetResidueNumber(1)
148
+ info.SetName(f' {a.GetSymbol():<3}')
149
+ info.SetIsHeteroAtom(True)
150
+ info.SetOccupancy(1.0)
151
+ info.SetTempFactor(val)
152
+ a.SetPDBResidueInfo(info)
153
+
154
+ pdb_str = Chem.MolToPDBBlock(_rdmol)
155
+
156
+ if filename is not None:
157
+ with open(filename, 'w') as f:
158
+ f.write(pdb_str)
159
+
160
+ return pdb_str
161
+
162
+
163
+ def draw_grid(molecules, mols_per_row=5, fig_size=(200, 200),
164
+ label=check_mol,
165
+ highlight_atom=lambda atom: False,
166
+ highlight_bond=lambda bond: False):
167
+
168
+ draw_mols = []
169
+ marked_atoms = []
170
+ marked_bonds = []
171
+ for mol in molecules:
172
+ draw_mol = Chem.Mol(mol) # copy
173
+ Chem.SanitizeMol(draw_mol, sanitizeOps=SanitizeFlags.SANITIZE_NONE)
174
+ AllChem.Compute2DCoords(draw_mol)
175
+ draw_mol = Draw.rdMolDraw2D.PrepareMolForDrawing(draw_mol,
176
+ kekulize=False)
177
+ draw_mols.append(draw_mol)
178
+ marked_atoms.append([a.GetIdx() for a in draw_mol.GetAtoms() if highlight_atom(a)])
179
+ marked_bonds.append([b.GetIdx() for b in draw_mol.GetBonds() if highlight_bond(b)])
180
+
181
+ drawOptions = Draw.rdMolDraw2D.MolDrawOptions()
182
+ drawOptions.prepareMolsBeforeDrawing = False
183
+ drawOptions.highlightBondWidthMultiplier = 20
184
+
185
+ return Draw.MolsToGridImage(draw_mols,
186
+ molsPerRow=mols_per_row,
187
+ subImgSize=fig_size,
188
+ drawOptions=drawOptions,
189
+ highlightAtomLists=marked_atoms,
190
+ highlightBondLists=marked_bonds,
191
+ legends=[f'[{i}] {label(mol)}' for
192
+ i, mol in enumerate(draw_mols)])
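Example call highlighting aromatic atoms and bonds (`molecules` is assumed to be a list of RDKit molecules):

```python
img = draw_grid(molecules,
                mols_per_row=4,
                highlight_atom=lambda a: a.GetIsAromatic(),
                highlight_bond=lambda b: b.GetIsAromatic())
img.save('grid.png')  # outside notebooks, Draw.MolsToGridImage returns a PIL image
```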
src/constants.py ADDED
@@ -0,0 +1,256 @@
1
+ import os
2
+ from rdkit import Chem
3
+ import torch
4
+ import numpy as np
5
+
6
+ # ------------------------------------------------------------------------------
7
+ # Computational
8
+ # ------------------------------------------------------------------------------
9
+ FLOAT_TYPE = torch.float32
10
+ INT_TYPE = torch.int64
11
+
12
+
13
+ # ------------------------------------------------------------------------------
14
+ # Type encoding/decoding
15
+ # ------------------------------------------------------------------------------
16
+
17
+ atom_dict = os.environ.get('ATOM_DICT')
18
+ if atom_dict == 'simple':
19
+ atom_encoder = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'B': 4, 'Br': 5, 'Cl': 6, 'P': 7, 'I': 8, 'F': 9, 'NOATOM': 10}
20
+ atom_decoder = ['C', 'N', 'O', 'S', 'B', 'Br', 'Cl', 'P', 'I', 'F', 'NOATOM']
21
+
22
+ else:
23
+ atom_encoder = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'B': 4, 'Br': 5, 'Cl': 6, 'P': 7, 'I': 8, 'F': 9, 'NH': 10, 'N+': 11, 'O-': 12, 'NOATOM': 13}
24
+ atom_decoder = ['C', 'N', 'O', 'S', 'B', 'Br', 'Cl', 'P', 'I', 'F', 'NH', 'N+', 'O-', 'NOATOM']
25
+
26
+ bond_encoder = {"NOBOND": 0, "SINGLE": 1, "DOUBLE": 2, "TRIPLE": 3, 'AROMATIC': 4}
27
+ bond_decoder = [None, Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
28
+
29
+ aa_encoder = {'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4, 'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9, 'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14, 'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19}
30
+ aa_decoder = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']
31
+
32
+ residue_encoder = {'CA': 0, 'SS': 1}
33
+ residue_decoder = ['CA', 'SS']
34
+
35
+ residue_bond_encoder = {'CA-CA': 0, 'CA-SS': 1, 'NOBOND': 2}
36
+ residue_bond_decoder = ['CA-CA', 'CA-SS', None]
37
+
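The encoder dicts and decoder lists are inverses of each other; a quick round trip through the bond vocabulary:

```python
import torch
import torch.nn.functional as F
from rdkit import Chem

one_hot = F.one_hot(torch.tensor(bond_encoder['DOUBLE']), num_classes=len(bond_encoder))
assert bond_decoder[one_hot.argmax().item()] == Chem.rdchem.BondType.DOUBLE
assert bond_decoder[bond_encoder['NOBOND']] is None  # "no bond" decodes to None
```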
38
+ # aa_atom_index = {
39
+ # 'A': {'N': 0, 'C': 1, 'O': 2, 'CB': 3},
40
+ # 'C': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'SG': 4},
41
+ # 'D': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'OD1': 5, 'OD2': 6},
42
+ # 'E': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD': 5, 'OE1': 6, 'OE2': 7},
43
+ # 'F': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD1': 5, 'CD2': 6, 'CE1': 7, 'CE2': 8, 'CZ': 9},
44
+ # 'G': {'N': 0, 'C': 1, 'O': 2},
45
+ # 'H': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'ND1': 5, 'CD2': 6, 'CE1': 7, 'NE2': 8},
46
+ # 'I': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG1': 4, 'CG2': 5, 'CD1': 6},
47
+ # 'K': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD': 5, 'CE': 6, 'NZ': 7},
48
+ # 'L': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD1': 5, 'CD2': 6},
49
+ # 'M': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'SD': 5, 'CE': 6},
50
+ # 'N': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'OD1': 5, 'ND2': 6},
51
+ # 'P': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD': 5},
52
+ # 'Q': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD': 5, 'OE1': 6, 'NE2': 7},
53
+ # 'R': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD': 5, 'NE': 6, 'CZ': 7, 'NH1': 8, 'NH2': 9},
54
+ # 'S': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'OG': 4},
55
+ # 'T': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'OG1': 4, 'CG2': 5},
56
+ # 'V': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG1': 4, 'CG2': 5},
57
+ # 'W': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD1': 5, 'CD2': 6, 'NE1': 7, 'CE2': 8, 'CE3': 9, 'CZ2': 10, 'CZ3': 11, 'CH2': 12},
58
+ # 'Y': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD1': 5, 'CD2': 6, 'CE1': 7, 'CE2': 8, 'CZ': 9, 'OH': 10},
59
+ # }
60
+ aa_atom_index = {
61
+ 'A': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4},
62
+ 'C': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'SG': 5},
63
+ 'D': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'OD1': 6, 'OD2': 7},
64
+ 'E': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD': 6, 'OE1': 7, 'OE2': 8},
65
+ 'F': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD1': 6, 'CD2': 7, 'CE1': 8, 'CE2': 9, 'CZ': 10},
66
+ 'G': {'N': 0, 'CA': 1, 'C': 2, 'O': 3},
67
+ 'H': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'ND1': 6, 'CD2': 7, 'CE1': 8, 'NE2': 9},
68
+ 'I': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG1': 5, 'CG2': 6, 'CD1': 7},
69
+ 'K': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD': 6, 'CE': 7, 'NZ': 8},
70
+ 'L': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD1': 6, 'CD2': 7},
71
+ 'M': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'SD': 6, 'CE': 7},
72
+ 'N': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'OD1': 6, 'ND2': 7},
73
+ 'P': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD': 6},
74
+ 'Q': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD': 6, 'OE1': 7, 'NE2': 8},
75
+ 'R': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD': 6, 'NE': 7, 'CZ': 8, 'NH1': 9, 'NH2': 10},
76
+ 'S': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'OG': 5},
77
+ 'T': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'OG1': 5, 'CG2': 6},
78
+ 'V': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG1': 5, 'CG2': 6},
79
+ 'W': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD1': 6, 'CD2': 7, 'NE1': 8, 'CE2': 9, 'CE3': 10, 'CZ2': 11, 'CZ3': 12, 'CH2': 13},
80
+ 'Y': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD1': 6, 'CD2': 7, 'CE1': 8, 'CE2': 9, 'CZ': 10, 'OH': 11},
81
+ }
82
+
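These slot indices define a fixed-size layout for the per-residue atom features; the feature dimension is the largest slot index plus one:

```python
v_dim = max(i for aa in aa_atom_index.values() for i in aa.values()) + 1
assert v_dim == 14                    # tryptophan's CH2 occupies the last slot (13)
assert aa_atom_index['A']['CB'] == 4  # e.g. alanine's CB lives in slot 4
```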
83
+ # ------------------------------------------------------------------------------
84
+ # NERF
85
+ # ------------------------------------------------------------------------------
86
+
87
+ # indicates whether atom exists
88
+ aa_atom_mask = {
89
+ 'A': [True, True, True, True, True, False, False, False, False, False, False, False, False, False],
90
+ 'C': [True, True, True, True, True, True, False, False, False, False, False, False, False, False],
91
+ 'D': [True, True, True, True, True, True, True, True, False, False, False, False, False, False],
92
+ 'E': [True, True, True, True, True, True, True, True, True, False, False, False, False, False],
93
+ 'F': [True, True, True, True, True, True, True, True, True, True, True, False, False, False],
94
+ 'G': [True, True, True, True, False, False, False, False, False, False, False, False, False, False],
95
+ 'H': [True, True, True, True, True, True, True, True, True, True, False, False, False, False],
96
+ 'I': [True, True, True, True, True, True, True, True, False, False, False, False, False, False],
97
+ 'K': [True, True, True, True, True, True, True, True, True, False, False, False, False, False],
98
+ 'L': [True, True, True, True, True, True, True, True, False, False, False, False, False, False],
99
+ 'M': [True, True, True, True, True, True, True, True, False, False, False, False, False, False],
100
+ 'N': [True, True, True, True, True, True, True, True, False, False, False, False, False, False],
101
+ 'P': [True, True, True, True, True, True, True, False, False, False, False, False, False, False],
102
+ 'Q': [True, True, True, True, True, True, True, True, True, False, False, False, False, False],
103
+ 'R': [True, True, True, True, True, True, True, True, True, True, True, False, False, False],
104
+ 'S': [True, True, True, True, True, True, False, False, False, False, False, False, False, False],
105
+ 'T': [True, True, True, True, True, True, True, False, False, False, False, False, False, False],
106
+ 'V': [True, True, True, True, True, True, True, False, False, False, False, False, False, False],
107
+ 'W': [True, True, True, True, True, True, True, True, True, True, True, True, True, True],
108
+ 'Y': [True, True, True, True, True, True, True, True, True, True, True, True, False, False],
109
+ }
110
+
111
+ # (14, 3) index tensor with atom indices of atoms a, b and c for NERF reconstruction
112
+ # in principle, columns 1 and 2 can be inferred from column 0 (the immediate predecessor) alone
113
+ aa_nerf_indices = {
114
+ 'A': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
115
+ 'C': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
116
+ 'D': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [5, 4, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
117
+ 'E': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [6, 5, 4], [6, 5, 4], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
118
+ 'F': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [5, 4, 1], [6, 5, 4], [7, 5, 4], [8, 6, 5], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
119
+ 'G': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
120
+ 'H': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [5, 4, 1], [6, 5, 4], [7, 5, 4], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
121
+ 'I': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [4, 1, 0], [5, 4, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
122
+ 'K': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [6, 5, 4], [7, 6, 5], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
123
+ 'L': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [5, 4, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
124
+ 'M': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [6, 5, 4], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
125
+ 'N': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [5, 4, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
126
+ 'P': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
127
+ 'Q': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [6, 5, 4], [6, 5, 4], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
128
+ 'R': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [6, 5, 4], [7, 6, 5], [8, 7, 6], [8, 7, 6], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
129
+ 'S': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
130
+ 'T': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [4, 1, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
131
+ 'V': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [4, 1, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
132
+ 'W': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [5, 4, 1], [6, 5, 4], [7, 5, 4], [7, 5, 4], [9, 7, 5], [10, 7, 5], [11, 9, 7]],
133
+ 'Y': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [5, 4, 1], [6, 5, 4], [7, 5, 4], [8, 6, 5], [10, 8, 6], [0, 0, 0], [0, 0, 0]],
134
+ }
135
+
136
+ # unique id for each rotatable bond (0=chi1, 1=chi2, ...)
137
+ aa_bond_to_chi = {
138
+ 'A': {},
139
+ 'C': {('CA', 'CB'): 0},
140
+ 'D': {('CA', 'CB'): 0, ('CB', 'CG'): 1},
141
+ 'E': {('CA', 'CB'): 0, ('CB', 'CG'): 1, ('CG', 'CD'): 2},
142
+ 'F': {('CA', 'CB'): 0, ('CB', 'CG'): 1},
143
+ 'G': {},
144
+ 'H': {('CA', 'CB'): 0, ('CB', 'CG'): 1},
145
+ 'I': {('CA', 'CB'): 0, ('CB', 'CG2'): 1},
146
+ 'K': {('CA', 'CB'): 0, ('CB', 'CG'): 1, ('CG', 'CD'): 2, ('CD', 'CE'): 3},
147
+ 'L': {('CA', 'CB'): 0, ('CB', 'CG'): 1},
148
+ 'M': {('CA', 'CB'): 0, ('CB', 'CG'): 1, ('CG', 'SD'): 2},
149
+ 'N': {('CA', 'CB'): 0, ('CB', 'CG'): 1},
150
+ 'P': {},
151
+ 'Q': {('CA', 'CB'): 0, ('CB', 'CG'): 1, ('CG', 'CD'): 2},
152
+ 'R': {('CA', 'CB'): 0, ('CB', 'CG'): 1, ('CG', 'CD'): 2, ('CD', 'NE'): 3, ('NE', 'CZ'): 4},
153
+ 'S': {('CA', 'CB'): 0},
154
+ 'T': {('CA', 'CB'): 0},
155
+ 'V': {('CA', 'CB'): 0},
156
+ 'W': {('CA', 'CB'): 0, ('CB', 'CG'): 1},
157
+ 'Y': {('CA', 'CB'): 0, ('CB', 'CG'): 1},
158
+ }
159
+
160
+ # index between 0 and 4 to retrieve chi angles, -1 means not a rotatable bond
161
+ aa_chi_indices = {
162
+ 'A': [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
163
+ 'C': [-1, -1, -1, -1, -1, 0, -1, -1, -1, -1, -1, -1, -1, -1],
164
+ 'D': [-1, -1, -1, -1, -1, 0, 1, 1, -1, -1, -1, -1, -1, -1],
165
+ 'E': [-1, -1, -1, -1, -1, 0, 1, 2, 2, -1, -1, -1, -1, -1],
166
+ 'F': [-1, -1, -1, -1, -1, 0, 1, 1, -1, -1, -1, -1, -1, -1],
167
+ 'G': [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
168
+ 'H': [-1, -1, -1, -1, -1, 0, 1, 1, -1, -1, -1, -1, -1, -1],
169
+ 'I': [-1, -1, -1, -1, -1, 0, 0, 1, -1, -1, -1, -1, -1, -1],
170
+ 'K': [-1, -1, -1, -1, -1, 0, 1, 2, 3, -1, -1, -1, -1, -1],
171
+ 'L': [-1, -1, -1, -1, -1, 0, 1, 1, -1, -1, -1, -1, -1, -1],
172
+ 'M': [-1, -1, -1, -1, -1, 0, 1, 2, -1, -1, -1, -1, -1, -1],
173
+ 'N': [-1, -1, -1, -1, -1, 0, 1, 1, -1, -1, -1, -1, -1, -1],
174
+ 'P': [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
175
+ 'Q': [-1, -1, -1, -1, -1, 0, 1, 2, 2, -1, -1, -1, -1, -1],
176
+ 'R': [-1, -1, -1, -1, -1, 0, 1, 2, 3, 4, 4, -1, -1, -1],
177
+ 'S': [-1, -1, -1, -1, -1, 0, -1, -1, -1, -1, -1, -1, -1, -1],
178
+ 'T': [-1, -1, -1, -1, -1, 0, 0, -1, -1, -1, -1, -1, -1, -1],
179
+ 'V': [-1, -1, -1, -1, -1, 0, 0, -1, -1, -1, -1, -1, -1, -1],
180
+ 'W': [-1, -1, -1, -1, -1, 0, 1, 1, -1, -1, -1, -1, -1, -1],
181
+ 'Y': [-1, -1, -1, -1, -1, 0, 1, 1, -1, -1, -1, -1, -1, -1],
182
+ }
183
+
184
+ # key: chi index (0=chi1, 1=chi2, ...); value: index of atom that defines the chi angle (together with its three predecessors)
185
+ aa_chi_anchor_atom = {
186
+ 'A': {},
187
+ 'C': {0: 5},
188
+ 'D': {0: 5, 1: 6},
189
+ 'E': {0: 5, 1: 6, 2: 7},
190
+ 'F': {0: 5, 1: 6},
191
+ 'G': {},
192
+ 'H': {0: 5, 1: 6},
193
+ 'I': {0: 5, 1: 7},
194
+ 'K': {0: 5, 1: 6, 2: 7, 3: 8},
195
+ 'L': {0: 5, 1: 6},
196
+ 'M': {0: 5, 1: 6, 2: 7},
197
+ 'N': {0: 5, 1: 6},
198
+ 'P': {},
199
+ 'Q': {0: 5, 1: 6, 2: 7},
200
+ 'R': {0: 5, 1: 6, 2: 7, 3: 8, 4: 9},
201
+ 'S': {0: 5},
202
+ 'T': {0: 5},
203
+ 'V': {0: 5},
204
+ 'W': {0: 5, 1: 6},
205
+ 'Y': {0: 5, 1: 6},
206
+ }
207
+
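The three chi tables are mutually consistent: the anchor atom of each rotatable bond must map back to that bond's chi index. A quick check, using lysine as an example:

```python
aa = 'K'
for bond, chi in aa_bond_to_chi[aa].items():
    anchor = aa_chi_anchor_atom[aa][chi]      # atom whose dihedral defines this chi
    assert aa_chi_indices[aa][anchor] == chi  # holds for all four lysine chi angles
```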
208
+ # ------------------------------------------------------------------------------
209
+ # Visualization
210
+ # ------------------------------------------------------------------------------
211
+ # PyMOL colors, see: https://pymolwiki.org/index.php/Color_Values#Chemical_element_colours
212
+ colors_dic = ['#33ff33', '#3333ff', '#ff4d4d', '#e6c540', '#ffb5b5', '#A62929', '#1FF01F', '#ff8000', '#940094', '#B3FFFF', '#b3e3f5']
213
+ radius_dic = [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3]
214
+
215
+
216
+ # ------------------------------------------------------------------------------
217
+ # Backbone geometry
218
+ # Taken from: Bhagavan, N. V., and C. E. Ha.
219
+ # "Chapter 4-Three-dimensional structure of proteins and disorders of protein misfolding."
220
+ # Essentials of Medical Biochemistry (2015): 31-51.
221
+ # https://www.sciencedirect.com/science/article/pii/B978012416687500004X
222
+ # ------------------------------------------------------------------------------
223
+ N_CA_DIST = 1.47
224
+ CA_C_DIST = 1.53
225
+ N_CA_C_ANGLE = 110 * np.pi / 180
226
+
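From these idealized parameters, the fixed intra-residue N...C distance follows by the law of cosines (a worked sanity check, not used by the code itself):

```python
import numpy as np

n_c_dist = np.sqrt(N_CA_DIST ** 2 + CA_C_DIST ** 2
                   - 2 * N_CA_DIST * CA_C_DIST * np.cos(N_CA_C_ANGLE))
print(round(n_c_dist, 2))  # ~2.46 Angstrom
```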
227
+ # ------------------------------------------------------------------------------
228
+ # Atom radii
229
+ # ------------------------------------------------------------------------------
230
+ # # https://en.wikipedia.org/wiki/Covalent_radius#Radii_for_multiple_bonds
231
+ # # (2023/04/14)
232
+ # covalent_radii = {'H': [32, None, None],
233
+ # 'C': [75, 67, 60],
234
+ # 'N': [71, 60, 54],
235
+ # 'O': [63, 57, 53],
236
+ # 'F': [64, 59, 53],
237
+ # 'B': [85, 78, 73],
238
+ # 'Al': [126, 113, 111],
239
+ # 'Si': [116, 107, 102],
240
+ # 'P': [111, 102, 94],
241
+ # 'S': [103, 94, 95],
242
+ # 'Cl': [99, 95, 93],
243
+ # 'As': [121, 114, 106],
244
+ # 'Br': [114, 109, 110],
245
+ # 'I': [133, 129, 125],
246
+ # 'Hg': [133, 142, None],
247
+ # 'Bi': [151, 141, 135]}
248
+
249
+ # source: https://en.wikipedia.org/wiki/Van_der_Waals_radius
250
+ vdw_radii = {'N': 1.55, 'O': 1.52, 'C': 1.70, 'H': 1.10, 'S': 1.80, 'P': 1.80,
251
+ 'Se': 1.90, 'K': 2.75, 'Na': 2.27, 'Mg': 1.73, 'Zn': 1.39, 'B': 1.92,
252
+ 'Br': 1.85, 'Cl': 1.75, 'I': 1.98, 'F': 1.47}
253
+
254
+
255
+ WEBDATASET_SHARD_SIZE = 50000
256
+ WEBDATASET_VAL_SIZE = 100
src/data/data_utils.py ADDED
@@ -0,0 +1,901 @@
1
+ import io
2
+ from itertools import accumulate, chain
3
+ from copy import deepcopy
4
+ import random
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ from rdkit import Chem
9
+ from torch_scatter import scatter_mean
10
+ from Bio.PDB import StructureBuilder, Chain, Model, Structure
11
+ from Bio.PDB.PICIO import read_PIC, write_PIC
12
+ from scipy.ndimage import gaussian_filter
13
+ from pdb import set_trace
14
+
15
+ from src.constants import FLOAT_TYPE, INT_TYPE
16
+ from src.constants import atom_encoder, bond_encoder, aa_encoder, residue_encoder, residue_bond_encoder, aa_atom_index
17
+ from src import utils
18
+ from src.data.misc import protein_letters_3to1, is_aa
19
+ from src.data.normal_modes import pdb_to_normal_modes
20
+ from src.data.nerf import get_nerf_params, ic_to_coords
21
+ import src.data.so3_utils as so3
22
+
23
+
24
+ class TensorDict(dict):
25
+ def __init__(self, **kwargs):
26
+ super(TensorDict, self).__init__(**kwargs)
27
+
28
+ def _apply(self, func: str, *args, **kwargs):
29
+ """ Apply function to all tensors. """
30
+ for k, v in self.items():
31
+ if torch.is_tensor(v):
32
+ self[k] = getattr(v, func)(*args, **kwargs)
33
+ return self
34
+
35
+ # def to(self, device):
36
+ # for k, v in self.items():
37
+ # if torch.is_tensor(v):
38
+ # self[k] = v.to(device)
39
+ # return self
40
+
41
+ def cuda(self):
42
+ return self.to('cuda')
43
+
44
+ def cpu(self):
45
+ return self.to('cpu')
46
+
47
+ def to(self, device):
48
+ return self._apply("to", device)
49
+
50
+ def detach(self):
51
+ return self._apply("detach")
52
+
53
+ def __repr__(self):
54
+ def val_to_str(val):
55
+ if isinstance(val, torch.Tensor):
56
+ # if val.isnan().any():
57
+ # return "(!nan)"
58
+ return "%r" % list(val.size())
59
+ if isinstance(val, list):
60
+ return "[%r,]" % len(val)
61
+ else:
62
+ return "?"
63
+
64
+ return f"{type(self).__name__}({', '.join(f'{k}={val_to_str(v)}' for k, v in self.items())})"
65
+
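A minimal sketch of how `TensorDict` behaves: tensor operations are broadcast over all tensor values, while non-tensor values pass through untouched:

```python
import torch

d = TensorDict(x=torch.randn(5, 3),
               mask=torch.zeros(5, dtype=torch.int64),
               name=['mol_0'])
d = d.detach().to('cpu')  # applied to 'x' and 'mask' only
print(d)                  # TensorDict(x=[5, 3], mask=[5], name=[1,])
```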
66
+
67
+ def collate_entity(batch):
68
+
69
+ out = {}
70
+ for prop in batch[0].keys():
71
+
72
+ if prop == 'name':
73
+ out[prop] = [x[prop] for x in batch]
74
+
75
+ elif prop == 'size' or prop == 'n_bonds':
76
+ out[prop] = torch.tensor([x[prop] for x in batch])
77
+
78
+ elif prop == 'bonds':
79
+ # index offset
80
+ offset = list(accumulate([x['size'] for x in batch], initial=0))
81
+ out[prop] = torch.cat([x[prop] + offset[i] for i, x in enumerate(batch)], dim=1)
82
+
83
+ elif prop == 'residues':
84
+ out[prop] = list(chain.from_iterable(x[prop] for x in batch))
85
+
86
+ elif prop in {'mask', 'bond_mask'}:
87
+ pass # batch masks will be written later
88
+
89
+ else:
90
+ out[prop] = torch.cat([x[prop] for x in batch], dim=0)
91
+
92
+ # Create batch masks
93
+ # make sure indices in batch start at zero (needed for torch_scatter)
94
+ if prop == 'x':
95
+ out['mask'] = torch.cat([i * torch.ones(len(x[prop]), dtype=torch.int64, device=x[prop].device)
96
+ for i, x in enumerate(batch)], dim=0)
97
+ if prop == 'bond_one_hot':
98
+ # TODO: this is not necessary as it can be computed on-the-fly as bond_mask = mask[bonds[0]] or bond_mask = mask[bonds[1]]
99
+ out['bond_mask'] = torch.cat([i * torch.ones(len(x[prop]), dtype=torch.int64, device=x[prop].device)
100
+ for i, x in enumerate(batch)], dim=0)
101
+
102
+ return out
103
+
104
+
105
+ def split_entity(
106
+ batch,
107
+ *,
108
+ index_types={'bonds'},
109
+ edge_types={'bond_one_hot', 'bond_mask'},
110
+ no_split={'name', 'size', 'n_bonds'},
111
+ skip={'fragments'},
112
+ batch_mask=None,
113
+ edge_mask=None
114
+ ):
115
+ """ Splits a batch into items and returns a list. """
116
+
117
+ batch_mask = batch["mask"] if batch_mask is None else batch_mask
118
+ edge_mask = batch["bond_mask"] if edge_mask is None else edge_mask
119
+ sizes = batch['size'] if 'size' in batch else torch.unique(batch_mask, return_counts=True)[1].tolist()
120
+
121
+ batch_size = len(torch.unique(batch['mask']))
122
+ out = {}
123
+ for prop in batch.keys():
124
+ if prop in skip:
125
+ continue
126
+ if prop in no_split:
127
+ out[prop] = batch[prop] # already a list
128
+
129
+ elif prop in index_types:
130
+ offsets = list(accumulate(sizes[:-1], initial=0))
131
+ out[prop] = utils.batch_to_list_for_indices(batch[prop], edge_mask, offsets)
132
+
133
+ elif prop in edge_types:
134
+ out[prop] = utils.batch_to_list(batch[prop], edge_mask)
135
+
136
+ else:
137
+ out[prop] = utils.batch_to_list(batch[prop], batch_mask)
138
+
139
+ out = [{k: v[i] for k, v in out.items()} for i in range(batch_size)]
140
+ return out
141
+
142
+
143
+ def repeat_items(batch, repeats):
144
+ batch_list = split_entity(batch)
145
+ out = collate_entity([x for _ in range(repeats) for x in batch_list])
146
+ return type(batch)(**out)
147
+
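`collate_entity`, `split_entity`, and `repeat_items` are designed to round-trip: bond indices are offset when merging and shifted back when splitting (the offset removal happens in the unshown `utils` helpers). A sketch, assuming `ligand_a` and `ligand_b` are dicts as returned by `prepare_ligand` below:

```python
batch = collate_entity([ligand_a, ligand_b])  # concatenates tensors, offsets bond indices
items = split_entity(batch)                   # inverse: list of per-item dicts again
doubled = repeat_items(batch, repeats=2)      # 4-item batch (a, b, a, b)
```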
148
+
149
+ def get_side_chain_bead_coord(biopython_residue):
150
+ """
151
+ Places side chain bead at the location of the farthest side chain atom.
152
+ """
153
+ if biopython_residue.get_resname() == 'GLY':
154
+ return None
155
+ if biopython_residue.get_resname() == 'ALA':
156
+ return biopython_residue['CB'].get_coord()
157
+
158
+ ca_coord = biopython_residue['CA'].get_coord()
159
+ side_chain_atoms = [a for a in biopython_residue.get_atoms() if
160
+ a.id not in {'N', 'CA', 'C', 'O'} and a.element != 'H']
161
+ side_chain_coords = np.stack([a.get_coord() for a in side_chain_atoms])
162
+
163
+ atom_idx = np.argmax(np.sum((side_chain_coords - ca_coord[None, :]) ** 2, axis=-1))
164
+
165
+ return side_chain_coords[atom_idx, :]
166
+
167
+
168
+ def get_side_chain_vectors(res, index_dict, size=None):
169
+ if size is None:
170
+ size = max([x for aa in index_dict.values() for x in aa.values()]) + 1
171
+
172
+ resname = protein_letters_3to1[res.get_resname()]
173
+
174
+ out = np.zeros((size, 3))
175
+ for atom in res.get_atoms():
176
+ if atom.get_name() in index_dict[resname]:
177
+ idx = index_dict[resname][atom.get_name()]
178
+ out[idx] = atom.get_coord() - res['CA'].get_coord()
179
+ # else:
180
+ # if atom.get_name() != 'CA' and not atom.get_name().startswith('H'):
181
+ # print(resname, atom.get_name())
182
+
183
+ return out
184
+
185
+
186
+ def get_normal_modes(res, normal_mode_dict):
187
+ nm = normal_mode_dict[(res.get_parent().id, res.id[1], 'CA')] # (n_modes, 3)
188
+ return nm
189
+
190
+
191
+ def get_torsion_angles(res, device=None):
192
+ """
193
+ Return the five chi angles. Missing angles are filled with NaN.
194
+ """
195
+ ANGLES = ['chi1', 'chi2', 'chi3', 'chi4', 'chi5']
196
+
197
+ ic_res = res.internal_coord
198
+ chi_angles = [ic_res.get_angle(chi) for chi in ANGLES]
199
+ chi_angles = [chi if chi is not None else float('nan') for chi in chi_angles]
200
+
201
+ return torch.tensor(chi_angles, device=device) * np.pi / 180
202
+
203
+
204
+ def apply_torsion_angles(res, chi_angles):
205
+ """
206
+ Set side chain torsion angles of a biopython residue object with
207
+ internal coordinates.
208
+ """
209
+ ANGLES = ['chi1', 'chi2', 'chi3', 'chi4', 'chi5']
210
+
211
+ chi_angles = chi_angles * 180 / np.pi
212
+
213
+ # res.parent.internal_coord.build_atomArray() # rebuild atom pointers
214
+
215
+ ic_res = res.internal_coord
216
+ for chi, angle in zip(ANGLES, chi_angles):
217
+ if ic_res.pick_angle(chi) is None:
218
+ continue
219
+ ic_res.bond_set(chi, angle)
220
+
221
+ res.parent.internal_to_atom_coordinates(verbose=False)
222
+ # res.parent.internal_coord.init_atom_coords()
223
+ # res.internal_coord.assemble()
224
+
225
+ return res
226
+
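Round-trip sketch for the two functions above; `res` is assumed to be a Bio.PDB residue whose chain already has internal coordinates computed (`chain.atom_to_internal_coordinates()`):

```python
import numpy as np

chi = get_torsion_angles(res)                     # radians; undefined angles come back as NaN
res = apply_torsion_angles(res, chi + np.pi / 6)  # rotate every defined chi by 30 degrees
```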
227
+
228
+ def prepare_internal_coord(res):
229
+
230
+ # Make new structure with a single residue
231
+ new_struct = Structure.Structure('X')
232
+ new_struct.header = {}
233
+ new_model = Model.Model(0)
234
+ new_struct.add(new_model)
235
+ new_chain = Chain.Chain('X')
236
+ new_model.add(new_chain)
237
+ new_chain.add(res)
238
+ res.set_parent(new_chain) # update pointer
239
+
240
+ # Compute internal coordinates
241
+ new_chain.atom_to_internal_coordinates()
242
+
243
+ pic_io = io.StringIO()
244
+ write_PIC(new_struct, pic_io)
245
+ return pic_io.getvalue()
246
+
247
+
248
+ def residue_from_internal_coord(ic_string):
249
+ pic_io = io.StringIO(ic_string)
250
+ struct = read_PIC(pic_io, quick=True)
251
+ res = struct.child_list[0].child_list[0].child_list[0]
252
+ res.parent.internal_to_atom_coordinates(verbose=False)
253
+ return res
254
+
255
+
256
+ def prepare_pocket(biopython_residues, amino_acid_encoder, residue_encoder,
257
+ residue_bond_encoder, pocket_representation='side_chain_bead',
258
+ compute_nerf_params=False, compute_bb_frames=False,
259
+ nma_input=None):
260
+
261
+ assert nma_input is None or pocket_representation == 'CA+', \
262
+ "vector features are only supported for CA+ pockets"
263
+
264
+ # sort residues
265
+ biopython_residues = sorted(biopython_residues, key=lambda x: (x.parent.id, x.id[1]))
266
+
267
+ if nma_input is not None:
268
+ # preprocessed normal mode eigenvectors
269
+ if isinstance(nma_input, dict):
270
+ nma_dict = nma_input
271
+
272
+ # PDB file
273
+ else:
274
+ nma_dict = pdb_to_normal_modes(str(nma_input))
275
+
276
+ if pocket_representation == 'side_chain_bead':
277
+ ca_coords = np.zeros((len(biopython_residues), 3))
278
+ ca_types = np.zeros(len(biopython_residues), dtype='int64')
279
+ side_chain_coords = []
280
+ side_chain_aa_types = []
281
+ edges = [] # CA-CA and CA-side_chain
282
+ edge_types = []
283
+ last_res_id = None
284
+ for i, res in enumerate(biopython_residues):
285
+ aa = amino_acid_encoder[protein_letters_3to1[res.get_resname()]]
286
+ ca_coords[i, :] = res['CA'].get_coord()
287
+ ca_types[i] = aa
288
+ side_chain_coord = get_side_chain_bead_coord(res)
289
+ if side_chain_coord is not None:
290
+ side_chain_coords.append(side_chain_coord)
291
+ side_chain_aa_types.append(aa)
292
+ edges.append((i, len(ca_coords) + len(side_chain_coords) - 1))
293
+ edge_types.append(residue_bond_encoder['CA-SS'])
294
+
295
+ # add edges between contiguous CA atoms
296
+ if i > 0 and res.id[1] == last_res_id + 1:
297
+ edges.append((i - 1, i))
298
+ edge_types.append(residue_bond_encoder['CA-CA'])
299
+
300
+ last_res_id = res.id[1]
301
+
302
+ # Coordinates
303
+ side_chain_coords = np.stack(side_chain_coords)
304
+ pocket_coords = np.concatenate([ca_coords, side_chain_coords], axis=0)
305
+ pocket_coords = torch.from_numpy(pocket_coords)
306
+
307
+ # Features
308
+ amino_acid_onehot = F.one_hot(
309
+ torch.cat([torch.from_numpy(ca_types), torch.tensor(side_chain_aa_types, dtype=torch.int64)], dim=0),
310
+ num_classes=len(amino_acid_encoder)
311
+ )
312
+ side_chain_onehot = np.concatenate([
313
+ np.tile(np.eye(1, len(residue_encoder), residue_encoder['CA']),
314
+ [len(ca_coords), 1]),
315
+ np.tile(np.eye(1, len(residue_encoder), residue_encoder['SS']),
316
+ [len(side_chain_coords), 1])
317
+ ], axis=0)
318
+ side_chain_onehot = torch.from_numpy(side_chain_onehot)
319
+ pocket_onehot = torch.cat([amino_acid_onehot, side_chain_onehot], dim=1)
320
+
321
+ vector_features = None
322
+ nma_features = None
323
+
324
+ # Bonds
325
+ edges = torch.tensor(edges).T
326
+ edge_types = F.one_hot(torch.tensor(edge_types), num_classes=len(residue_bond_encoder))
327
+
328
+ elif pocket_representation == 'CA+':
329
+ ca_coords = np.zeros((len(biopython_residues), 3))
330
+ ca_types = np.zeros(len(biopython_residues), dtype='int64')
331
+
332
+ v_dim = max([x for aa in aa_atom_index.values() for x in aa.values()]) + 1
333
+ vec_feats = np.zeros((len(biopython_residues), v_dim, 3), dtype='float32')
334
+ nf_nma = 5
335
+ nma_feats = np.zeros((len(biopython_residues), nf_nma, 3), dtype='float32')
336
+
337
+ edges = [] # CA-CA and CA-side_chain
338
+ edge_types = []
339
+ last_res_id = None
340
+ for i, res in enumerate(biopython_residues):
341
+ aa = amino_acid_encoder[protein_letters_3to1[res.get_resname()]]
342
+ ca_coords[i, :] = res['CA'].get_coord()
343
+ ca_types[i] = aa
344
+
345
+ vec_feats[i] = get_side_chain_vectors(res, aa_atom_index, v_dim)
346
+ if nma_input is not None:
347
+ nma_feats[i] = get_normal_modes(res, nma_dict)
348
+
349
+ # add edges between contiguous CA atoms
350
+ if i > 0 and res.id[1] == last_res_id + 1:
351
+ edges.append((i - 1, i))
352
+ edge_types.append(residue_bond_encoder['CA-CA'])
353
+
354
+ last_res_id = res.id[1]
355
+
356
+ # Coordinates
357
+ pocket_coords = torch.from_numpy(ca_coords)
358
+
359
+ # Features
360
+ pocket_onehot = F.one_hot(torch.from_numpy(ca_types),
361
+ num_classes=len(amino_acid_encoder))
362
+
363
+ vector_features = torch.from_numpy(vec_feats)
364
+ nma_features = torch.from_numpy(nma_feats)
365
+
366
+ # Bonds
367
+ if len(edges) < 1:
368
+ edges = torch.empty(2, 0)
369
+ edge_types = torch.empty(0, len(residue_bond_encoder))
370
+ else:
371
+ edges = torch.tensor(edges).T
372
+ edge_types = F.one_hot(torch.tensor(edge_types),
373
+ num_classes=len(residue_bond_encoder))
374
+
375
+ else:
376
+ raise NotImplementedError(
377
+ f"Pocket representation '{pocket_representation}' not implemented")
378
+
379
+ # pocket_ids = [f'{res.parent.id}:{res.id[1]}' for res in biopython_residues]
380
+
381
+ pocket = {
382
+ 'x': pocket_coords.to(dtype=FLOAT_TYPE),
383
+ 'one_hot': pocket_onehot.to(dtype=FLOAT_TYPE),
384
+ # 'ids': pocket_ids,
385
+ 'size': torch.tensor([len(pocket_coords)], dtype=INT_TYPE),
386
+ 'mask': torch.zeros(len(pocket_coords), dtype=INT_TYPE),
387
+ 'bonds': edges.to(INT_TYPE),
388
+ 'bond_one_hot': edge_types.to(FLOAT_TYPE),
389
+ 'bond_mask': torch.zeros(edges.size(1), dtype=INT_TYPE),
390
+ 'n_bonds': torch.tensor([len(edge_types)], dtype=INT_TYPE),
391
+ }
392
+
393
+ if vector_features is not None:
394
+ pocket['v'] = vector_features.to(dtype=FLOAT_TYPE)
395
+
396
+ if nma_input is not None:
397
+ pocket['nma_vec'] = nma_features.to(dtype=FLOAT_TYPE)
398
+
399
+ if compute_nerf_params:
400
+ nerf_params = [get_nerf_params(r) for r in biopython_residues]
401
+ nerf_params = {k: torch.stack([x[k] for x in nerf_params], dim=0)
402
+ for k in nerf_params[0].keys()}
403
+ pocket.update(nerf_params)
404
+
405
+ if compute_bb_frames:
406
+ n_xyz = torch.from_numpy(np.stack([r['N'].get_coord() for r in biopython_residues]))
407
+ ca_xyz = torch.from_numpy(np.stack([r['CA'].get_coord() for r in biopython_residues]))
408
+ c_xyz = torch.from_numpy(np.stack([r['C'].get_coord() for r in biopython_residues]))
409
+ pocket['axis_angle'], _ = get_bb_transform(n_xyz, ca_xyz, c_xyz)
410
+
411
+ return pocket, biopython_residues
412
+
413
+
414
+ def encode_atom(rd_atom, atom_encoder):
415
+ element = rd_atom.GetSymbol().capitalize()
416
+
417
+ explicitHs = rd_atom.GetNumExplicitHs()
418
+ if explicitHs == 1 and f'{element}H' in atom_encoder:
419
+ return atom_encoder[f'{element}H']
420
+
421
+ charge = rd_atom.GetFormalCharge()
422
+ if charge == 1 and f'{element}+' in atom_encoder:
423
+ return atom_encoder[f'{element}+']
424
+ if charge == -1 and f'{element}-' in atom_encoder:
425
+ return atom_encoder[f'{element}-']
426
+
427
+ return atom_encoder[element]
428
+
429
+
430
+ def prepare_ligand(rdmol, atom_encoder, bond_encoder):
431
+
432
+ # remove H atoms if not in atom_encoder
433
+ if 'H' not in atom_encoder:
434
+ rdmol = Chem.RemoveAllHs(rdmol, sanitize=False)
435
+
436
+ # Coordinates
437
+ ligand_coord = rdmol.GetConformer().GetPositions()
438
+ ligand_coord = torch.from_numpy(ligand_coord)
439
+
440
+ # Features
441
+ ligand_onehot = F.one_hot(
442
+ torch.tensor([encode_atom(a, atom_encoder) for a in rdmol.GetAtoms()]),
443
+ num_classes=len(atom_encoder)
444
+ )
445
+
446
+ # Bonds
447
+ adj = np.ones((rdmol.GetNumAtoms(), rdmol.GetNumAtoms())) * bond_encoder['NOBOND']
448
+ for b in rdmol.GetBonds():
449
+ i = b.GetBeginAtomIdx()
450
+ j = b.GetEndAtomIdx()
451
+ adj[i, j] = bond_encoder[str(b.GetBondType())]
452
+ adj[j, i] = adj[i, j] # undirected graph
453
+
454
+ # molecular graph is undirected -> don't save redundant information
455
+ bonds = np.stack(np.triu_indices(len(ligand_coord), k=1), axis=0)
456
+ # bonds = np.stack(np.ones_like(adj).nonzero(), axis=0)
457
+ bond_types = adj[bonds[0], bonds[1]].astype('int64')
458
+ bonds = torch.from_numpy(bonds)
459
+ bond_types = F.one_hot(torch.from_numpy(bond_types), num_classes=len(bond_encoder))
460
+
461
+ ligand = {
462
+ 'x': ligand_coord.to(dtype=FLOAT_TYPE),
463
+ 'one_hot': ligand_onehot.to(dtype=FLOAT_TYPE),
464
+ 'mask': torch.zeros(len(ligand_coord), dtype=INT_TYPE),
465
+ 'bonds': bonds.to(INT_TYPE),
466
+ 'bond_one_hot': bond_types.to(FLOAT_TYPE),
467
+ 'bond_mask': torch.zeros(bonds.size(1), dtype=INT_TYPE),
468
+ 'size': torch.tensor([len(ligand_coord)], dtype=INT_TYPE),
469
+ 'n_bonds': torch.tensor([len(bond_types)], dtype=INT_TYPE),
470
+ }
471
+
472
+ return ligand
473
+
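Usage sketch for `prepare_ligand` (the SDF path points at the example ligand shipped with the repo):

```python
from rdkit import Chem

rdmol = Chem.SDMolSupplier('examples/kras_ref_ligand.sdf', sanitize=False)[0]
ligand = prepare_ligand(rdmol, atom_encoder, bond_encoder)
print(ligand['x'].shape,        # (n_atoms, 3)
      ligand['one_hot'].shape,  # (n_atoms, len(atom_encoder))
      ligand['bonds'].shape)    # (2, n_atoms * (n_atoms - 1) // 2), upper triangle
```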
474
+
475
+ def process_raw_molecule_with_empty_pocket(rdmol):
476
+ ligand = prepare_ligand(rdmol, atom_encoder, bond_encoder)
477
+ pocket = {
478
+ 'x': torch.tensor([], dtype=FLOAT_TYPE),
479
+ 'one_hot': torch.tensor([], dtype=FLOAT_TYPE),
480
+ 'size': torch.tensor([], dtype=INT_TYPE),
481
+ 'mask': torch.tensor([], dtype=INT_TYPE),
482
+ 'bonds': torch.tensor([], dtype=INT_TYPE),
483
+ 'bond_one_hot': torch.tensor([], dtype=FLOAT_TYPE),
484
+ 'bond_mask': torch.tensor([], dtype=INT_TYPE),
485
+ 'n_bonds': torch.tensor([], dtype=INT_TYPE),
486
+ }
487
+ return ligand, pocket
488
+
489
+
490
+ def process_raw_pair(biopython_model, rdmol, dist_cutoff=None,
491
+ pocket_representation='side_chain_bead',
492
+ compute_nerf_params=False, compute_bb_frames=False,
493
+ nma_input=None, return_pocket_pdb=False):
494
+
495
+ # Process ligand
496
+ ligand = prepare_ligand(rdmol, atom_encoder, bond_encoder)
497
+
498
+ # Find interacting pocket residues based on distance cutoff
499
+ pocket_residues = []
500
+ for residue in biopython_model.get_residues():
501
+
502
+ # Remove non-standard amino acids and HETATMs
503
+ if not is_aa(residue.get_resname(), standard=True):
504
+ continue
505
+
506
+ res_coords = torch.from_numpy(np.array([a.get_coord() for a in residue.get_atoms()]))
507
+ if dist_cutoff is None or (((res_coords[:, None, :] - ligand['x'][None, :, :]) ** 2).sum(-1) ** 0.5).min() < dist_cutoff:
508
+ pocket_residues.append(residue)
509
+
510
+ pocket, pocket_residues = prepare_pocket(
511
+ pocket_residues, aa_encoder, residue_encoder, residue_bond_encoder,
512
+ pocket_representation, compute_nerf_params, compute_bb_frames, nma_input
513
+ )
514
+
515
+ if return_pocket_pdb:
516
+ builder = StructureBuilder.StructureBuilder()
517
+ builder.init_structure("")
518
+ builder.init_model(0)
519
+ pocket_struct = builder.get_structure()
520
+ for residue in pocket_residues:
521
+ chain = residue.get_parent().get_id()
522
+
523
+ # init chain if necessary
524
+ if not pocket_struct[0].has_id(chain):
525
+ builder.init_chain(chain)
526
+
527
+ # add residue
528
+ pocket_struct[0][chain].add(residue)
529
+
530
+ pocket['pocket_pdb'] = pocket_struct
531
+ # if return_pocket_pdb:
532
+ # pocket['residues'] = [prepare_internal_coord(res) for res in pocket_residues]
533
+
534
+ return ligand, pocket
535
+
536
+
537
+ class AppendVirtualNodes:
538
+ def __init__(self, atom_encoder, bond_encoder, max_ligand_size, scale=1.0):
539
+ self.max_size = max_ligand_size
540
+ self.atom_encoder = atom_encoder
541
+ self.bond_encoder = bond_encoder
542
+ self.vidx = atom_encoder['NOATOM']
543
+ self.bidx = bond_encoder['NOBOND']
544
+ self.scale = scale
545
+
546
+ def __call__(self, ligand, max_size=None, eps=1e-6):
547
+ if max_size is None:
548
+ max_size = self.max_size
549
+
550
+ n_virt = max_size - ligand['size']
551
+
552
+ C = torch.cov(ligand['x'].T)
553
+ L = torch.linalg.cholesky(C + torch.eye(3) * eps)
554
+ mu = ligand['x'].mean(0, keepdim=True)
555
+ virt_coords = mu + torch.randn(n_virt, 3) @ L.T * self.scale
556
+
557
+ # insert virtual atom column
558
+ virt_one_hot = F.one_hot(torch.ones(n_virt, dtype=torch.int64) * self.vidx, num_classes=len(self.atom_encoder))
559
+ virt_mask = torch.cat([torch.zeros(ligand['size'], dtype=bool), torch.ones(n_virt, dtype=bool)])
560
+
561
+ ligand['x'] = torch.cat([ligand['x'], virt_coords])
562
+ ligand['one_hot'] = torch.cat(([ligand['one_hot'], virt_one_hot]))
563
+ ligand['virtual_mask'] = virt_mask
564
+ ligand['size'] = max_size
565
+
566
+ # Bonds
567
+ new_bonds = torch.triu_indices(max_size, max_size, offset=1)
568
+
569
+ bond_types = torch.ones(max_size, max_size, dtype=INT_TYPE) * self.bidx
570
+ row, col = ligand['bonds']
571
+ bond_types[row, col] = ligand['bond_one_hot'].argmax(dim=1)
572
+ new_row, new_col = new_bonds
573
+ bond_types = bond_types[new_row, new_col]
574
+
575
+ ligand['bonds'] = new_bonds
576
+ ligand['bond_one_hot'] = F.one_hot(bond_types, num_classes=len(self.bond_encoder)).to(ligand['bond_one_hot'].dtype)
577
+ ligand['n_bonds'] = len(ligand['bond_one_hot'])
578
+
579
+ return ligand
580
+
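The coordinate sampling in `AppendVirtualNodes.__call__` fits an anisotropic Gaussian to the ligand point cloud (mean plus Cholesky factor of the covariance). A standalone sketch of just that step, on toy coordinates:

```python
import torch

x = torch.randn(20, 3) * torch.tensor([3.0, 1.0, 0.5])  # toy "ligand" coordinates
mu = x.mean(0, keepdim=True)
L = torch.linalg.cholesky(torch.cov(x.T) + 1e-6 * torch.eye(3))
virt = mu + torch.randn(8, 3) @ L.T  # samples follow the cloud's spread and orientation
```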
581
+
582
+ class AppendVirtualNodesInCoM:
583
+ def __init__(self, atom_encoder, bond_encoder, add_min=0, add_max=10):
584
+ self.atom_encoder = atom_encoder
585
+ self.bond_encoder = bond_encoder
586
+ self.vidx = atom_encoder['NOATOM']
587
+ self.bidx = bond_encoder['NOBOND']
588
+ self.add_min = add_min
589
+ self.add_max = add_max
590
+
591
+ def __call__(self, ligand):
592
+
593
+ n_virt = random.randint(self.add_min, self.add_max)
594
+
595
+ # all virtual coordinates in the CoM
596
+ virt_coords = ligand['x'].mean(0, keepdim=True).repeat(n_virt, 1)
597
+
598
+ # insert virtual atom column
599
+ virt_one_hot = F.one_hot(torch.ones(n_virt, dtype=torch.int64) * self.vidx, num_classes=len(self.atom_encoder))
600
+ virt_mask = torch.cat([torch.zeros(ligand['size'], dtype=bool), torch.ones(n_virt, dtype=bool)])
601
+
602
+ ligand['x'] = torch.cat([ligand['x'], virt_coords])
603
+ ligand['one_hot'] = torch.cat(([ligand['one_hot'], virt_one_hot]))
604
+ ligand['virtual_mask'] = virt_mask
605
+ ligand['size'] = len(ligand['x'])
606
+
607
+ # Bonds
608
+ new_bonds = torch.triu_indices(ligand['size'], ligand['size'], offset=1)
609
+
610
+ bond_types = torch.ones(ligand['size'], ligand['size'], dtype=INT_TYPE) * self.bidx
611
+ row, col = ligand['bonds']
612
+ bond_types[row, col] = ligand['bond_one_hot'].argmax(dim=1)
613
+ new_row, new_col = new_bonds
614
+ bond_types = bond_types[new_row, new_col]
615
+
616
+ ligand['bonds'] = new_bonds
617
+ ligand['bond_one_hot'] = F.one_hot(bond_types, num_classes=len(self.bond_encoder)).to(ligand['bond_one_hot'].dtype)
618
+ ligand['n_bonds'] = len(ligand['bond_one_hot'])
619
+
620
+ return ligand
621
+
622
+
623
+ def rdmol_to_smiles(rdmol):
624
+ mol = Chem.Mol(rdmol)
625
+ Chem.RemoveStereochemistry(mol)
626
+ mol = Chem.RemoveHs(mol)
627
+ return Chem.MolToSmiles(mol)
628
+
629
+
630
+ def get_n_nodes(lig_positions, pocket_positions, smooth_sigma=None):
631
+ # Joint distribution of ligand's and pocket's number of nodes
632
+ n_nodes_lig = [len(x) for x in lig_positions]
633
+ n_nodes_pocket = [len(x) for x in pocket_positions]
634
+
635
+ joint_histogram = np.zeros((np.max(n_nodes_lig) + 1,
636
+ np.max(n_nodes_pocket) + 1))
637
+
638
+ for nlig, npocket in zip(n_nodes_lig, n_nodes_pocket):
639
+ joint_histogram[nlig, npocket] += 1
640
+
641
+ print(f'Original histogram: {np.count_nonzero(joint_histogram)}/'
642
+ f'{joint_histogram.shape[0] * joint_histogram.shape[1]} bins filled')
643
+
644
+ # Smooth the histogram
645
+ if smooth_sigma is not None:
646
+ filtered_histogram = gaussian_filter(
647
+ joint_histogram, sigma=smooth_sigma, order=0, mode='constant',
648
+ cval=0.0, truncate=4.0)
649
+
650
+ print(f'Smoothed histogram: {np.count_nonzero(filtered_histogram)}/'
651
+ f'{filtered_histogram.shape[0] * filtered_histogram.shape[1]} bins filled')
652
+
653
+ joint_histogram = filtered_histogram
654
+
655
+ return joint_histogram
656
+
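At sampling time, such a joint histogram is typically conditioned on the observed pocket size. A hedged sketch of how it could be used (the actual sampling code lives elsewhere in the repo; this helper is illustrative):

```python
import numpy as np

def sample_ligand_size(joint_histogram, n_pocket_nodes, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    col = joint_histogram[:, n_pocket_nodes]  # unnormalized p(n_lig | n_pocket)
    return int(rng.choice(len(col), p=col / col.sum()))
```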
657
+
658
+ # def get_type_histograms(lig_one_hot, pocket_one_hot, lig_encoder, pocket_encoder):
659
+ #
660
+ # lig_one_hot = np.concatenate(lig_one_hot, axis=0)
661
+ # pocket_one_hot = np.concatenate(pocket_one_hot, axis=0)
662
+ #
663
+ # atom_decoder = list(lig_encoder.keys())
664
+ # lig_counts = {k: 0 for k in lig_encoder.keys()}
665
+ # for a in [atom_decoder[x] for x in lig_one_hot.argmax(1)]:
666
+ # lig_counts[a] += 1
667
+ #
668
+ # aa_decoder = list(pocket_encoder.keys())
669
+ # pocket_counts = {k: 0 for k in pocket_encoder.keys()}
670
+ # for r in [aa_decoder[x] for x in pocket_one_hot.argmax(1)]:
671
+ # pocket_counts[r] += 1
672
+ #
673
+ # return lig_counts, pocket_counts
674
+
675
+
676
+ def get_type_histogram(one_hot, type_encoder):
677
+
678
+ one_hot = np.concatenate(one_hot, axis=0)
679
+
680
+ decoder = list(type_encoder.keys())
681
+ counts = {k: 0 for k in type_encoder.keys()}
682
+ for a in [decoder[x] for x in one_hot.argmax(1)]:
683
+ counts[a] += 1
684
+
685
+ return counts
686
+
687
+
688
+ def get_residue_with_resi(pdb_chain, resi):
689
+ res = [x for x in pdb_chain.get_residues() if x.id[1] == resi]
690
+ assert len(res) == 1
691
+ return res[0]
692
+
693
+
694
+ def get_pocket_from_ligand(pdb_model, ligand, dist_cutoff=8.0):
695
+
696
+ if ligand.endswith(".sdf"):
697
+ # ligand as sdf file
698
+ rdmol = Chem.SDMolSupplier(str(ligand))[0]
699
+ ligand_coords = torch.from_numpy(rdmol.GetConformer().GetPositions()).float()
700
+ resi = None
701
+ else:
702
+ # ligand contained in PDB; given in <chain>:<resi> format
703
+ chain, resi = ligand.split(':')
704
+ ligand = get_residue_with_resi(pdb_model[chain], int(resi))
705
+ ligand_coords = torch.from_numpy(
706
+ np.array([a.get_coord() for a in ligand.get_atoms()]))
707
+
708
+ pocket_residues = []
709
+ for residue in pdb_model.get_residues():
710
+ if resi is not None and residue.id[1] == int(resi):
711
+ continue # skip ligand itself
712
+
713
+ res_coords = torch.from_numpy(
714
+ np.array([a.get_coord() for a in residue.get_atoms()]))
715
+ if is_aa(residue.get_resname(), standard=True) \
716
+ and torch.cdist(res_coords, ligand_coords).min() < dist_cutoff:
717
+ pocket_residues.append(residue)
718
+
719
+ return pocket_residues
720
+
721
+
722
+ def encode_residues(biopython_residues, type_encoder, level='atom',
723
+ remove_H=True):
724
+ assert level in {'atom', 'residue'}
725
+
726
+ if level == 'atom':
727
+ entities = [a for res in biopython_residues for a in res.get_atoms()
728
+ if (a.element != 'H' or not remove_H)]
729
+ types = [a.element.capitalize() for a in entities]
730
+ else:
731
+ entities = [res['CA'] for res in biopython_residues]
732
+ types = [protein_letters_3to1[res.get_resname()] for res in biopython_residues]
733
+
734
+ coord = torch.tensor(np.stack([e.get_coord() for e in entities]))
735
+ one_hot = F.one_hot(torch.tensor([type_encoder[t] for t in types]),
736
+ num_classes=len(type_encoder))
737
+
738
+ return coord, one_hot
739
+
740
+
741
+ def center_data(ligand, pocket):
742
+ if pocket['x'].numel() > 0:
743
+ pocket_com = pocket.center()
744
+ else:
745
+ pocket_com = scatter_mean(ligand['x'], ligand['mask'], dim=0)
746
+
747
+ ligand['x'] = ligand['x'] - pocket_com[ligand['mask']]
748
+ return ligand, pocket
749
+
750
+
751
+ def get_bb_transform(n_xyz, ca_xyz, c_xyz):
752
+ """
753
+ Compute translation and rotation of the canonical backbone frame (triangle N-Ca-C) from a position with
754
+ Ca at the origin, N on the x-axis and C in the xy-plane to the global position of the backbone frame
755
+
756
+ Args:
757
+ n_xyz: (n, 3)
758
+ ca_xyz: (n, 3)
759
+ c_xyz: (n, 3)
760
+
761
+ Returns:
762
+ axis-angle representation of the rotation, shape (n, 3) # rotation matrix of shape (n, 3, 3)
763
+ translation vector of shape (n, 3)
764
+ """
765
+
766
+ def rotation_matrix(angle, axis):
767
+ axis_mapping = {'x': 0, 'y': 1, 'z': 2}
768
+ axis = axis_mapping[axis]
769
+ vector = torch.zeros(len(angle), 3)
770
+ vector[:, axis] = 1
771
+ # return axis_angle_to_matrix(angle * vector)
772
+ return so3.matrix_from_rotation_vector(angle.view(-1, 1) * vector)
773
+
774
+ translation = ca_xyz
775
+ n_xyz = n_xyz - translation
776
+ c_xyz = c_xyz - translation
777
+
778
+ # Find rotation matrix that aligns the coordinate systems
779
+
780
+ # rotate around y-axis to move N into the xy-plane
781
+ theta_y = torch.arctan2(n_xyz[:, 2], -n_xyz[:, 0])
782
+ Ry = rotation_matrix(theta_y, 'y')
783
+ Ry = Ry.transpose(2, 1)
784
+ n_xyz = torch.einsum('noi,ni->no', Ry, n_xyz)
785
+
786
+ # rotate around z-axis to move N onto the x-axis
787
+ theta_z = torch.arctan2(n_xyz[:, 1], n_xyz[:, 0])
788
+ Rz = rotation_matrix(theta_z, 'z')
789
+ Rz = Rz.transpose(2, 1)
790
+ # print(torch.einsum('noi,ni->no', Rz, n_xyz))
791
+
792
+ # n_xyz = torch.einsum('noi,ni->no', Rz.transpose(0, 2, 1), n_xyz)
793
+
794
+ # rotate around x-axis to move C into the xy-plane
795
+ c_xyz = torch.einsum('noj,nji,ni->no', Rz, Ry, c_xyz)
796
+ theta_x = torch.arctan2(c_xyz[:, 2], c_xyz[:, 1])
797
+ Rx = rotation_matrix(theta_x, 'x')
798
+ Rx = Rx.transpose(2, 1)
799
+ # print(torch.einsum('noi,ni->no', Rx, c_xyz))
800
+
801
+ # Final rotation matrix
802
+ Ry = Ry.transpose(2, 1)
803
+ Rz = Rz.transpose(2, 1)
804
+ Rx = Rx.transpose(2, 1)
805
+ R = torch.einsum('nok,nkj,nji->noi', Ry, Rz, Rx)
806
+
807
+ # return R, translation
808
+ # return matrix_to_axis_angle(R), translation
809
+ return so3.rotation_vector_from_matrix(R), translation
810
+
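A small sanity-check sketch for the frame computation above: a residue that already sits in the canonical frame (Ca at the origin, N on the x-axis, C in the xy-plane) should yield a near-zero rotation vector and a translation equal to the Ca position. The coordinates are illustrative:

import torch

n = torch.tensor([[1.46, 0.0, 0.0]])
ca = torch.zeros(1, 3)
c = torch.tensor([[-0.55, 1.42, 0.0]])
axis_angle, translation = get_bb_transform(n, ca, c)
# axis_angle ≈ (0, 0, 0), translation == ca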
811
+
812
+ class Residues(TensorDict):
813
+ """
814
+ Dictionary-like container for residues that supports some basic transformations.
815
+ """
816
+
817
+ # all keys
818
+ KEYS = {'x', 'one_hot', 'bonds', 'bond_one_hot', 'v', 'nma_vec', 'fixed_coord',
819
+ 'atom_mask', 'nerf_indices', 'length', 'theta', 'chi', 'ddihedral',
820
+ 'chi_indices', 'axis_angle', 'mask', 'bond_mask'}
821
+
822
+ # coordinate-type values, shape (..., 3)
823
+ COORD_KEYS = {'x', 'fixed_coord'}
824
+
825
+ # vector-type values, shape (n_residues, n_feat, 3)
826
+ VECTOR_KEYS = {'v', 'nma_vec'}
827
+
828
+ # properties that change if the side chains and/or backbones are updated
829
+ MUTABLE_PROPS_SS_AND_BB = {'v'}
830
+
831
+ # properties that only change if the side chains are updated
832
+ MUTABLE_PROPS_SS = {'chi'}
833
+
834
+ # properties that only change if the backbones are updated
835
+ MUTABLE_PROPS_BB = {'x', 'fixed_coord', 'axis_angle', 'nma_vec'}
836
+
837
+ # properties that remain fixed in all cases
838
+ IMMUTABLE_PROPS = {'mask', 'one_hot', 'bonds', 'bond_one_hot', 'bond_mask',
839
+ 'atom_mask', 'nerf_indices', 'length', 'theta',
840
+ 'ddihedral', 'chi_indices', 'name', 'size', 'n_bonds'}
841
+
842
+ def copy(self):
843
+ data = super().copy()
844
+ return Residues(**data)
845
+
846
+ def deepcopy(self):
847
+ data = {k: v.clone() if torch.is_tensor(v) else deepcopy(v)
848
+ for k, v in self.items()}
849
+ return Residues(**data)
850
+
851
+ def center(self):
852
+ com = scatter_mean(self['x'], self['mask'], dim=0)
853
+ self['x'] = self['x'] - com[self['mask']]
854
+ self['fixed_coord'] = self['fixed_coord'] - com[self['mask']].unsqueeze(1)
855
+ return com
856
+
857
+ def set_empty_v(self):
858
+ self['v'] = torch.tensor([], device=self['x'].device)
859
+
860
+ @torch.no_grad()
861
+ def set_chi(self, chi_angles):
862
+ self['chi'][:, :5] = chi_angles
863
+ nerf_params = {k: self[k] for k in ['fixed_coord', 'atom_mask',
864
+ 'nerf_indices', 'length', 'theta',
865
+ 'chi', 'ddihedral', 'chi_indices']}
866
+ self['v'] = ic_to_coords(**nerf_params) - self['x'].unsqueeze(1)
867
+
868
+ @torch.no_grad()
869
+ def set_frame(self, new_ca_coord, new_axis_angle):
870
+ bb_coord = self['fixed_coord']
871
+ bb_coord = bb_coord - self['x'].unsqueeze(1)
872
+ rotmat_before = so3.matrix_from_rotation_vector(self['axis_angle'])
873
+ rotmat_after = so3.matrix_from_rotation_vector(new_axis_angle)
874
+ rotmat_diff = rotmat_after @ rotmat_before.transpose(-1, -2)
875
+ bb_coord = torch.einsum('boi,bai->bao', rotmat_diff, bb_coord)
876
+ bb_coord = bb_coord + new_ca_coord.unsqueeze(1)
877
+
878
+ self['x'] = new_ca_coord
879
+ self['axis_angle'] = new_axis_angle
880
+ self['fixed_coord'] = bb_coord
881
+ self['v'] = torch.einsum('boi,bai->bao', rotmat_diff, self['v'])
882
+
883
+ @staticmethod
884
+ def empty(device):
885
+ return Residues(
886
+ x=torch.zeros(1, 3, device=device).float(),
887
+ mask=torch.zeros(1, 1, device=device).long(),
888
+ size=torch.zeros(1, device=device).long(),
889
+ )
890
+
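A minimal sketch of the container in use; real instances carry the full key set listed in KEYS, here only the keys touched by center() are populated:

import torch

res = Residues(
    x=torch.randn(4, 3),                    # CA coordinates
    mask=torch.zeros(4, dtype=torch.long),  # all residues belong to sample 0
    fixed_coord=torch.randn(4, 5, 3),       # N, CA, C, O, CB
)
com = res.center()                          # shifts 'x' and 'fixed_coord'
assert torch.allclose(res['x'].mean(0), torch.zeros(3), atol=1e-5)
backup = res.deepcopy()                     # independent tensor copies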
891
+
892
+ def randomize_tensors(tensor_dict, exclude_keys=None):
893
+ """Replace tensors with random tensors with the same shape."""
894
+ exclude_keys = set() if exclude_keys is None else set(exclude_keys)
895
+ for k, v in tensor_dict.items():
896
+ if isinstance(v, torch.Tensor) and k not in exclude_keys:
897
+ if torch.is_floating_point(v):
898
+ tensor_dict[k] = torch.randn_like(v)
899
+ else:
900
+ tensor_dict[k] = torch.randint_like(v, low=-42, high=42)
901
+ return tensor_dict
src/data/dataset.py ADDED
@@ -0,0 +1,208 @@
1
+ import io
2
+ import random
3
+ import warnings
4
+ import torch
5
+ import webdataset as wds
6
+
7
+ from pathlib import Path
8
+ from torch.utils.data import Dataset
9
+
10
+ from src.data.data_utils import TensorDict, collate_entity
11
+ from src.constants import WEBDATASET_SHARD_SIZE, WEBDATASET_VAL_SIZE
12
+
13
+
14
+ class ProcessedLigandPocketDataset(Dataset):
15
+ def __init__(self, pt_path, ligand_transform=None, pocket_transform=None,
16
+ catch_errors=False):
17
+
18
+ self.ligand_transform = ligand_transform
19
+ self.pocket_transform = pocket_transform
20
+ self.catch_errors = catch_errors
21
+ self.pt_path = pt_path
22
+
23
+ self.data = torch.load(pt_path)
24
+
25
+ # add number of nodes for convenience
26
+ for entity in ['ligands', 'pockets']:
27
+ self.data[entity]['size'] = torch.tensor([len(x) for x in self.data[entity]['x']])
28
+ self.data[entity]['n_bonds'] = torch.tensor([len(x) for x in self.data[entity]['bond_one_hot']])
29
+
30
+ def __len__(self):
31
+ return len(self.data['ligands']['name'])
32
+
33
+ def __getitem__(self, idx):
34
+ data = {}
35
+ data['ligand'] = {key: val[idx] for key, val in self.data['ligands'].items()}
36
+ data['pocket'] = {key: val[idx] for key, val in self.data['pockets'].items()}
37
+ try:
38
+ if self.ligand_transform is not None:
39
+ data['ligand'] = self.ligand_transform(data['ligand'])
40
+ if self.pocket_transform is not None:
41
+ data['pocket'] = self.pocket_transform(data['pocket'])
42
+ except (RuntimeError, ValueError) as e:
43
+ if self.catch_errors:
44
+ warnings.warn(f"{type(e).__name__}('{e}') in data transform. "
45
+ f"Returning random item instead")
46
+ # replace bad item with a random one
47
+ rand_idx = random.randint(0, len(self) - 1)
48
+ return self[rand_idx]
49
+ else:
50
+ raise e
51
+ return data
52
+
53
+ @staticmethod
54
+ def collate_fn(batch_pairs, ligand_transform=None):
55
+
56
+ out = {}
57
+ for entity in ['ligand', 'pocket']:
58
+ batch = [x[entity] for x in batch_pairs]
59
+
60
+ if entity == 'ligand' and ligand_transform is not None:
61
+ max_size = max(x['size'].item() for x in batch)
62
+ # TODO: might have to remove elements from batch if processing fails, warn user in that case
63
+ batch = [ligand_transform(x, max_size=max_size) for x in batch]
64
+
65
+ out[entity] = TensorDict(**collate_entity(batch))
66
+
67
+ return out
68
+
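A hedged sketch of wiring the dataset and its collate_fn into a PyTorch DataLoader; the .pt path is a placeholder for a processed split:

from functools import partial
from torch.utils.data import DataLoader

dataset = ProcessedLigandPocketDataset('processed/train.pt', catch_errors=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True,
                    collate_fn=partial(dataset.collate_fn, ligand_transform=None))
batch = next(iter(loader))  # batch['ligand'] / batch['pocket'] are TensorDicts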
69
+
70
+ class ClusteredDataset(ProcessedLigandPocketDataset):
71
+ def __init__(self, pt_path, ligand_transform=None, pocket_transform=None,
72
+ catch_errors=False):
73
+ super().__init__(pt_path, ligand_transform, pocket_transform, catch_errors)
74
+ self.clusters = list(self.data['clusters'].values())
75
+
76
+ def __len__(self):
77
+ return len(self.clusters)
78
+
79
+ def __getitem__(self, cidx):
80
+ cluster_inds = self.clusters[cidx]
81
+ # idx = cluster_inds[random.randint(0, len(cluster_inds) - 1)]
82
+ idx = random.choice(cluster_inds)
83
+ return super().__getitem__(idx)
84
+
85
+ class DPODataset(ProcessedLigandPocketDataset):
86
+ def __init__(self, pt_path, ligand_transform=None, pocket_transform=None,
87
+ catch_errors=False):
88
+ self.ligand_transform = ligand_transform
89
+ self.pocket_transform = pocket_transform
90
+ self.catch_errors = catch_errors
91
+ self.pt_path = pt_path
92
+
93
+ self.data = torch.load(pt_path)
94
+
95
+ if 'pockets' not in self.data:
96
+ self.data['pockets'] = self.data['pockets_w']
97
+ if 'ligands' not in self.data:
98
+ self.data['ligands'] = self.data['ligands_w']
99
+
100
+ if not (
101
+ len(self.data["ligands"]["name"])
102
+ == len(self.data["ligands_l"]["name"])
103
+ == len(self.data["pockets"]["name"])
104
+ ):
105
+ raise ValueError(
106
+ "Error while importing DPO Dataset: Number of ligands winning, ligands losing and pockets must be the same"
107
+ )
108
+
109
+ # add number of nodes for convenience
110
+ for entity in ['ligands', 'ligands_l', 'pockets']:
111
+ self.data[entity]['size'] = torch.tensor([len(x) for x in self.data[entity]['x']])
112
+ self.data[entity]['n_bonds'] = torch.tensor([len(x) for x in self.data[entity]['bond_one_hot']])
113
+
114
+ def __len__(self):
115
+ return len(self.data["ligands"]["name"])
116
+
117
+ def __getitem__(self, idx):
118
+ data = {}
119
+ data['ligand'] = {key: val[idx] for key, val in self.data['ligands'].items()}
120
+ data['ligand_l'] = {key: val[idx] for key, val in self.data['ligands_l'].items()}
121
+ data['pocket'] = {key: val[idx] for key, val in self.data['pockets'].items()}
122
+ try:
123
+ if self.ligand_transform is not None:
124
+ data['ligand'] = self.ligand_transform(data['ligand'])
125
+ data['ligand_l'] = self.ligand_transform(data['ligand_l'])
126
+ if self.pocket_transform is not None:
127
+ data['pocket'] = self.pocket_transform(data['pocket'])
128
+ except (RuntimeError, ValueError) as e:
129
+ if self.catch_errors:
130
+ warnings.warn(f"{type(e).__name__}('{e}') in data transform. "
131
+ f"Returning random item instead")
132
+ # replace bad item with a random one
133
+ rand_idx = random.randint(0, len(self) - 1)
134
+ return self[rand_idx]
135
+ else:
136
+ raise e
137
+ return data
138
+
139
+ @staticmethod
140
+ def collate_fn(batch_pairs, ligand_transform=None):
141
+
142
+ out = {}
143
+ for entity in ['ligand', 'ligand_l', 'pocket']:
144
+ batch = [x[entity] for x in batch_pairs]
145
+
146
+ if entity in ['ligand', 'ligand_l'] and ligand_transform is not None:
147
+ max_size = max(x['size'].item() for x in batch)
148
+ batch = [ligand_transform(x, max_size=max_size) for x in batch]
149
+
150
+ out[entity] = TensorDict(**collate_entity(batch))
151
+
152
+ return out
153
+
154
+ ##########################################
155
+ ############### WebDatasets ##############
156
+ ##########################################
157
+
158
+ class ProteinLigandWebDataset(wds.WebDataset):
159
+ @staticmethod
160
+ def collate_fn(batch_pairs, ligand_transform=None):
161
+ return ProcessedLigandPocketDataset.collate_fn(batch_pairs, ligand_transform)
162
+
163
+
164
+ def wds_decoder(key, value):
165
+ return torch.load(io.BytesIO(value))
166
+
167
+
168
+ def preprocess_wds_item(data):
169
+ out = {}
170
+ for entity in ['ligand', 'pocket']:
171
+ out[entity] = data['pt'][entity]
172
+ for attr in ['size', 'n_bonds']:
173
+ if torch.is_tensor(out[entity][attr]):
174
+ assert len(out[entity][attr]) == 0
175
+ out[entity][attr] = 0
176
+
177
+ return out
178
+
179
+
180
+ def get_wds(data_path, stage, ligand_transform=None, pocket_transform=None):
181
+ current_data_dir = Path(data_path, stage)
182
+ shards = sorted(current_data_dir.glob('shard-?????.tar'), key=lambda s: int(s.name.split('-')[-1].split('.')[0]))
183
+ min_shard = min(shards).name.split('-')[-1].split('.')[0]
184
+ max_shard = max(shards).name.split('-')[-1].split('.')[0]
185
+ total_size = (int(max_shard) - int(min_shard) + 1) * WEBDATASET_SHARD_SIZE if stage == 'train' else WEBDATASET_VAL_SIZE
186
+
187
+ url = f'{data_path}/{stage}/shard-{{{min_shard}..{max_shard}}}.tar'
188
+ ligand_transform_wrapper = lambda _data: _data
189
+ pocket_transform_wrapper = lambda _data: _data
190
+
191
+ if ligand_transform is not None:
192
+ def ligand_transform_wrapper(_data):
193
+ _data['pt']['ligand'] = ligand_transform(_data['pt']['ligand'])
194
+ return _data
195
+
196
+ if pocket_transform is not None:
197
+ def pocket_transform_wrapper(_data):
198
+ _data['pt']['pocket'] = pocket_transform(_data['pt']['pocket'])
199
+ return _data
200
+
201
+ return (
202
+ ProteinLigandWebDataset(url, nodesplitter=wds.split_by_node)
203
+ .decode(wds_decoder)
204
+ .map(ligand_transform_wrapper)
205
+ .map(pocket_transform_wrapper)
206
+ .map(preprocess_wds_item)
207
+ .with_length(total_size)
208
+ )
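A usage sketch for the WebDataset pipeline above, assuming shards live under <data_path>/train/shard-00000.tar etc.; batching happens inside the pipeline, so the DataLoader itself uses batch_size=None:

from torch.utils.data import DataLoader

train_ds = get_wds('webdataset_dir', 'train')
train_ds = train_ds.batched(32, collation_fn=ProteinLigandWebDataset.collate_fn)
loader = DataLoader(train_ds, batch_size=None, num_workers=4)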
src/data/misc.py ADDED
@@ -0,0 +1,19 @@
1
+ # From: https://github.com/biopython/biopython/blob/master/Bio/PDB/Polypeptide.py#L128
2
+
3
+ protein_letters_1to3 = {'A': 'ALA', 'C': 'CYS', 'D': 'ASP', 'E': 'GLU', 'F': 'PHE', 'G': 'GLY', 'H': 'HIS', 'I': 'ILE', 'K': 'LYS', 'L': 'LEU', 'M': 'MET', 'N': 'ASN', 'P': 'PRO', 'Q': 'GLN', 'R': 'ARG', 'S': 'SER', 'T': 'THR', 'V': 'VAL', 'W': 'TRP', 'Y': 'TYR'}
4
+
5
+
6
+ protein_letters_3to1 = {'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L', 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R', 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y'}
7
+
8
+
9
+ protein_letters_3to1_extended = {'A5N': 'N', 'A8E': 'V', 'A9D': 'S', 'AA3': 'A', 'AA4': 'A', 'AAR': 'R', 'ABA': 'A', 'ACL': 'R', 'AEA': 'C', 'AEI': 'D', 'AFA': 'N', 'AGM': 'R', 'AGQ': 'Y', 'AGT': 'C', 'AHB': 'N', 'AHL': 'R', 'AHO': 'A', 'AHP': 'A', 'AIB': 'A', 'AKL': 'D', 'AKZ': 'D', 'ALA': 'A', 'ALC': 'A', 'ALM': 'A', 'ALN': 'A', 'ALO': 'T', 'ALS': 'A', 'ALT': 'A', 'ALV': 'A', 'ALY': 'K', 'AME': 'M', 'AN6': 'L', 'AN8': 'A', 'API': 'K', 'APK': 'K', 'AR2': 'R', 'AR4': 'E', 'AR7': 'R', 'ARG': 'R', 'ARM': 'R', 'ARO': 'R', 'AS7': 'N', 'ASA': 'D', 'ASB': 'D', 'ASI': 'D', 'ASK': 'D', 'ASL': 'D', 'ASN': 'N', 'ASP': 'D', 'ASQ': 'D', 'AYA': 'A', 'AZH': 'A', 'AZK': 'K', 'AZS': 'S', 'AZY': 'Y', 'AVJ': 'H', 'A30': 'Y', 'A3U': 'F', 'ECC': 'Q', 'ECX': 'C', 'EFC': 'C', 'EHP': 'F', 'ELY': 'K', 'EME': 'E', 'EPM': 'M', 'EPQ': 'Q', 'ESB': 'Y', 'ESC': 'M', 'EXY': 'L', 'EXA': 'K', 'E0Y': 'P', 'E9V': 'H', 'E9M': 'W', 'EJA': 'C', 'EUP': 'T', 'EZY': 'G', 'E9C': 'Y', 'EW6': 'S', 'EXL': 'W', 'I2M': 'I', 'I4G': 'G', 'I58': 'K', 'IAM': 'A', 'IAR': 'R', 'ICY': 'C', 'IEL': 'K', 'IGL': 'G', 'IIL': 'I', 'ILE': 'I', 'ILG': 'E', 'ILM': 'I', 'ILX': 'I', 'ILY': 'K', 'IML': 'I', 'IOR': 'R', 'IPG': 'G', 'IT1': 'K', 'IYR': 'Y', 'IZO': 'M', 'IC0': 'G', 'M0H': 'C', 'M2L': 'K', 'M2S': 'M', 'M30': 'G', 'M3L': 'K', 'M3R': 'K', 'MA ': 'A', 'MAA': 'A', 'MAI': 'R', 'MBQ': 'Y', 'MC1': 'S', 'MCL': 'K', 'MCS': 'C', 'MD3': 'C', 'MD5': 'C', 'MD6': 'G', 'MDF': 'Y', 'ME0': 'M', 'MEA': 'F', 'MEG': 'E', 'MEN': 'N', 'MEQ': 'Q', 'MET': 'M', 'MEU': 'G', 'MFN': 'E', 'MGG': 'R', 'MGN': 'Q', 'MGY': 'G', 'MH1': 'H', 'MH6': 'S', 'MHL': 'L', 'MHO': 'M', 'MHS': 'H', 'MHU': 'F', 'MIR': 'S', 'MIS': 'S', 'MK8': 'L', 'ML3': 'K', 'MLE': 'L', 'MLL': 'L', 'MLY': 'K', 'MLZ': 'K', 'MME': 'M', 'MMO': 'R', 'MNL': 'L', 'MNV': 'V', 'MP8': 'P', 'MPQ': 'G', 'MSA': 'G', 'MSE': 'M', 'MSL': 'M', 'MSO': 'M', 'MT2': 'M', 'MTY': 'Y', 'MVA': 'V', 'MYK': 'K', 'MYN': 'R', 'QCS': 'C', 'QIL': 'I', 'QMM': 'Q', 'QPA': 'C', 'QPH': 'F', 'Q3P': 'K', 'QVA': 'C', 'QX7': 'A', 'Q2E': 'W', 'Q75': 'M', 'Q78': 'F', 'QM8': 'L', 'QMB': 'A', 'QNQ': 'C', 'QNT': 'C', 'QNW': 'C', 'QO2': 'C', 'QO5': 'C', 'QO8': 'C', 'QQ8': 'Q', 'U2X': 'Y', 'U3X': 'F', 'UF0': 'S', 'UGY': 'G', 'UM1': 'A', 'UM2': 'A', 'UMA': 'A', 'UQK': 'A', 'UX8': 'W', 'UXQ': 'F', 'YCM': 'C', 'YOF': 'Y', 'YPR': 'P', 'YPZ': 'Y', 'YTH': 'T', 'Y1V': 'L', 'Y57': 'K', 'YHA': 'K', '200': 'F', '23F': 'F', '23P': 'A', '26B': 'T', '28X': 'T', '2AG': 'A', '2CO': 'C', '2FM': 'M', '2GX': 'F', '2HF': 'H', '2JG': 'S', '2KK': 'K', '2KP': 'K', '2LT': 'Y', '2LU': 'L', '2ML': 'L', '2MR': 'R', '2MT': 'P', '2OR': 'R', '2P0': 'P', '2QZ': 'T', '2R3': 'Y', '2RA': 'A', '2RX': 'S', '2SO': 'H', '2TY': 'Y', '2VA': 'V', '2XA': 'C', '2ZC': 'S', '6CL': 'K', '6CW': 'W', '6GL': 'A', '6HN': 'K', '60F': 'C', '66D': 'I', '6CV': 'A', '6M6': 'C', '6V1': 'C', '6WK': 'C', '6Y9': 'P', '6DN': 'K', 'DA2': 'R', 'DAB': 'A', 'DAH': 'F', 'DBS': 'S', 'DBU': 'T', 'DBY': 'Y', 'DBZ': 'A', 'DC2': 'C', 'DDE': 'H', 'DDZ': 'A', 'DI7': 'Y', 'DHA': 'S', 'DHN': 'V', 'DIR': 'R', 'DLS': 'K', 'DM0': 'K', 'DMH': 'N', 'DMK': 'D', 'DNL': 'K', 'DNP': 'A', 'DNS': 'K', 'DNW': 'A', 'DOH': 'D', 'DON': 'L', 'DP1': 'R', 'DPL': 'P', 'DPP': 'A', 'DPQ': 'Y', 'DYS': 'C', 'D2T': 'D', 'DYA': 'D', 'DJD': 'F', 'DYJ': 'P', 'DV9': 'E', 'H14': 'F', 'H1D': 'M', 'H5M': 'P', 'HAC': 'A', 'HAR': 'R', 'HBN': 'H', 'HCM': 'C', 'HGY': 'G', 'HHI': 'H', 'HIA': 'H', 'HIC': 'H', 'HIP': 'H', 'HIQ': 'H', 'HIS': 'H', 'HL2': 'L', 'HLU': 'L', 'HMR': 'R', 'HNC': 'C', 'HOX': 'F', 'HPC': 'F', 'HPE': 'F', 'HPH': 'F', 'HPQ': 'F', 
'HQA': 'A', 'HR7': 'R', 'HRG': 'R', 'HRP': 'W', 'HS8': 'H', 'HS9': 'H', 'HSE': 'S', 'HSK': 'H', 'HSL': 'S', 'HSO': 'H', 'HT7': 'W', 'HTI': 'C', 'HTR': 'W', 'HV5': 'A', 'HVA': 'V', 'HY3': 'P', 'HYI': 'M', 'HYP': 'P', 'HZP': 'P', 'HIX': 'A', 'HSV': 'H', 'HLY': 'K', 'HOO': 'H', 'H7V': 'A', 'L5P': 'K', 'LRK': 'K', 'L3O': 'L', 'LA2': 'K', 'LAA': 'D', 'LAL': 'A', 'LBY': 'K', 'LCK': 'K', 'LCX': 'K', 'LDH': 'K', 'LE1': 'V', 'LED': 'L', 'LEF': 'L', 'LEH': 'L', 'LEM': 'L', 'LEN': 'L', 'LET': 'K', 'LEU': 'L', 'LEX': 'L', 'LGY': 'K', 'LLO': 'K', 'LLP': 'K', 'LLY': 'K', 'LLZ': 'K', 'LME': 'E', 'LMF': 'K', 'LMQ': 'Q', 'LNE': 'L', 'LNM': 'L', 'LP6': 'K', 'LPD': 'P', 'LPG': 'G', 'LPS': 'S', 'LSO': 'K', 'LTR': 'W', 'LVG': 'G', 'LVN': 'V', 'LWY': 'P', 'LYF': 'K', 'LYK': 'K', 'LYM': 'K', 'LYN': 'K', 'LYO': 'K', 'LYP': 'K', 'LYR': 'K', 'LYS': 'K', 'LYU': 'K', 'LYX': 'K', 'LYZ': 'K', 'LAY': 'L', 'LWI': 'F', 'LBZ': 'K', 'P1L': 'C', 'P2Q': 'Y', 'P2Y': 'P', 'P3Q': 'Y', 'PAQ': 'Y', 'PAS': 'D', 'PAT': 'W', 'PBB': 'C', 'PBF': 'F', 'PCA': 'Q', 'PCC': 'P', 'PCS': 'F', 'PE1': 'K', 'PEC': 'C', 'PF5': 'F', 'PFF': 'F', 'PG1': 'S', 'PGY': 'G', 'PHA': 'F', 'PHD': 'D', 'PHE': 'F', 'PHI': 'F', 'PHL': 'F', 'PHM': 'F', 'PKR': 'P', 'PLJ': 'P', 'PM3': 'F', 'POM': 'P', 'PPN': 'F', 'PR3': 'C', 'PR4': 'P', 'PR7': 'P', 'PR9': 'P', 'PRJ': 'P', 'PRK': 'K', 'PRO': 'P', 'PRS': 'P', 'PRV': 'G', 'PSA': 'F', 'PSH': 'H', 'PTH': 'Y', 'PTM': 'Y', 'PTR': 'Y', 'PVH': 'H', 'PXU': 'P', 'PYA': 'A', 'PYH': 'K', 'PYX': 'C', 'PH6': 'P', 'P9S': 'C', 'P5U': 'S', 'POK': 'R', 'T0I': 'Y', 'T11': 'F', 'TAV': 'D', 'TBG': 'V', 'TBM': 'T', 'TCQ': 'Y', 'TCR': 'W', 'TEF': 'F', 'TFQ': 'F', 'TH5': 'T', 'TH6': 'T', 'THC': 'T', 'THR': 'T', 'THZ': 'R', 'TIH': 'A', 'TIS': 'S', 'TLY': 'K', 'TMB': 'T', 'TMD': 'T', 'TNB': 'C', 'TNR': 'S', 'TNY': 'T', 'TOQ': 'W', 'TOX': 'W', 'TPJ': 'P', 'TPK': 'P', 'TPL': 'W', 'TPO': 'T', 'TPQ': 'Y', 'TQI': 'W', 'TQQ': 'W', 'TQZ': 'C', 'TRF': 'W', 'TRG': 'K', 'TRN': 'W', 'TRO': 'W', 'TRP': 'W', 'TRQ': 'W', 'TRW': 'W', 'TRX': 'W', 'TRY': 'W', 'TS9': 'I', 'TSY': 'C', 'TTQ': 'W', 'TTS': 'Y', 'TXY': 'Y', 'TY1': 'Y', 'TY2': 'Y', 'TY3': 'Y', 'TY5': 'Y', 'TY8': 'Y', 'TY9': 'Y', 'TYB': 'Y', 'TYC': 'Y', 'TYE': 'Y', 'TYI': 'Y', 'TYJ': 'Y', 'TYN': 'Y', 'TYO': 'Y', 'TYQ': 'Y', 'TYR': 'Y', 'TYS': 'Y', 'TYT': 'Y', 'TYW': 'Y', 'TYY': 'Y', 'T8L': 'T', 'T9E': 'T', 'TNQ': 'W', 'TSQ': 'F', 'TGH': 'W', 'X2W': 'E', 'XCN': 'C', 'XPR': 'P', 'XSN': 'N', 'XW1': 'A', 'XX1': 'K', 'XYC': 'A', 'XA6': 'F', '11Q': 'P', '11W': 'E', '12L': 'P', '12X': 'P', '12Y': 'P', '143': 'C', '1AC': 'A', '1L1': 'A', '1OP': 'Y', '1PA': 'F', '1PI': 'A', '1TQ': 'W', '1TY': 'Y', '1X6': 'S', '56A': 'H', '5AB': 'A', '5CS': 'C', '5CW': 'W', '5HP': 'E', '5OH': 'A', '5PG': 'G', '51T': 'Y', '54C': 'W', '5CR': 'F', '5CT': 'K', '5FQ': 'A', '5GM': 'I', '5JP': 'S', '5T3': 'K', '5MW': 'K', '5OW': 'K', '5R5': 'S', '5VV': 'N', '5XU': 'A', '55I': 'F', '999': 'D', '9DN': 'N', '9NE': 'E', '9NF': 'F', '9NR': 'R', '9NV': 'V', '9E7': 'K', '9KP': 'K', '9WV': 'A', '9TR': 'K', '9TU': 'K', '9TX': 'K', '9U0': 'K', '9IJ': 'F', 'B1F': 'F', 'B27': 'T', 'B2A': 'A', 'B2F': 'F', 'B2I': 'I', 'B2V': 'V', 'B3A': 'A', 'B3D': 'D', 'B3E': 'E', 'B3K': 'K', 'B3U': 'H', 'B3X': 'N', 'B3Y': 'Y', 'BB6': 'C', 'BB7': 'C', 'BB8': 'F', 'BB9': 'C', 'BBC': 'C', 'BCS': 'C', 'BCX': 'C', 'BFD': 'D', 'BG1': 'S', 'BH2': 'D', 'BHD': 'D', 'BIF': 'F', 'BIU': 'I', 'BL2': 'L', 'BLE': 'L', 'BLY': 'K', 'BMT': 'T', 'BNN': 'F', 'BOR': 'R', 'BP5': 'A', 'BPE': 'C', 'BSE': 'S', 'BTA': 'L', 'BTC': 'C', 'BTK': 'K', 'BTR': 'W', 'BUC': 'C', 'BUG': 'V', 
'BYR': 'Y', 'BWV': 'R', 'BWB': 'S', 'BXT': 'S', 'F2F': 'F', 'F2Y': 'Y', 'FAK': 'K', 'FB5': 'A', 'FB6': 'A', 'FC0': 'F', 'FCL': 'F', 'FDL': 'K', 'FFM': 'C', 'FGL': 'G', 'FGP': 'S', 'FH7': 'K', 'FHL': 'K', 'FHO': 'K', 'FIO': 'R', 'FLA': 'A', 'FLE': 'L', 'FLT': 'Y', 'FME': 'M', 'FOE': 'C', 'FP9': 'P', 'FPK': 'P', 'FT6': 'W', 'FTR': 'W', 'FTY': 'Y', 'FVA': 'V', 'FZN': 'K', 'FY3': 'Y', 'F7W': 'W', 'FY2': 'Y', 'FQA': 'K', 'F7Q': 'Y', 'FF9': 'K', 'FL6': 'D', 'JJJ': 'C', 'JJK': 'C', 'JJL': 'C', 'JLP': 'K', 'J3D': 'C', 'J9Y': 'R', 'J8W': 'S', 'JKH': 'P', 'N10': 'S', 'N7P': 'P', 'NA8': 'A', 'NAL': 'A', 'NAM': 'A', 'NBQ': 'Y', 'NC1': 'S', 'NCB': 'A', 'NEM': 'H', 'NEP': 'H', 'NFA': 'F', 'NIY': 'Y', 'NLB': 'L', 'NLE': 'L', 'NLN': 'L', 'NLO': 'L', 'NLP': 'L', 'NLQ': 'Q', 'NLY': 'G', 'NMC': 'G', 'NMM': 'R', 'NNH': 'R', 'NOT': 'L', 'NPH': 'C', 'NPI': 'A', 'NTR': 'Y', 'NTY': 'Y', 'NVA': 'V', 'NWD': 'A', 'NYB': 'C', 'NYS': 'C', 'NZH': 'H', 'N80': 'P', 'NZC': 'T', 'NLW': 'L', 'N0A': 'F', 'N9P': 'A', 'N65': 'K', 'R1A': 'C', 'R4K': 'W', 'RE0': 'W', 'RE3': 'W', 'RGL': 'R', 'RGP': 'E', 'RT0': 'P', 'RVX': 'S', 'RZ4': 'S', 'RPI': 'R', 'RVJ': 'A', 'VAD': 'V', 'VAF': 'V', 'VAH': 'V', 'VAI': 'V', 'VAL': 'V', 'VB1': 'K', 'VH0': 'P', 'VR0': 'R', 'V44': 'C', 'V61': 'F', 'VPV': 'K', 'V5N': 'H', 'V7T': 'K', 'Z01': 'A', 'Z3E': 'T', 'Z70': 'H', 'ZBZ': 'C', 'ZCL': 'F', 'ZU0': 'T', 'ZYJ': 'P', 'ZYK': 'P', 'ZZD': 'C', 'ZZJ': 'A', 'ZIQ': 'W', 'ZPO': 'P', 'ZDJ': 'Y', 'ZT1': 'K', '30V': 'C', '31Q': 'C', '33S': 'F', '33W': 'A', '34E': 'V', '3AH': 'H', '3BY': 'P', '3CF': 'F', '3CT': 'Y', '3GA': 'A', '3GL': 'E', '3MD': 'D', '3MY': 'Y', '3NF': 'Y', '3O3': 'E', '3PX': 'P', '3QN': 'K', '3TT': 'P', '3XH': 'G', '3YM': 'Y', '3WS': 'A', '3WX': 'P', '3X9': 'C', '3ZH': 'H', '7JA': 'I', '73C': 'S', '73N': 'R', '73O': 'Y', '73P': 'K', '74P': 'K', '7N8': 'F', '7O5': 'A', '7XC': 'F', '7ID': 'D', '7OZ': 'A', 'C1S': 'C', 'C1T': 'C', 'C1X': 'K', 'C22': 'A', 'C3Y': 'C', 'C4R': 'C', 'C5C': 'C', 'C6C': 'C', 'CAF': 'C', 'CAS': 'C', 'CAY': 'C', 'CCS': 'C', 'CEA': 'C', 'CGA': 'E', 'CGU': 'E', 'CGV': 'C', 'CHP': 'G', 'CIR': 'R', 'CLE': 'L', 'CLG': 'K', 'CLH': 'K', 'CME': 'C', 'CMH': 'C', 'CML': 'C', 'CMT': 'C', 'CR5': 'G', 'CS0': 'C', 'CS1': 'C', 'CS3': 'C', 'CS4': 'C', 'CSA': 'C', 'CSB': 'C', 'CSD': 'C', 'CSE': 'C', 'CSJ': 'C', 'CSO': 'C', 'CSP': 'C', 'CSR': 'C', 'CSS': 'C', 'CSU': 'C', 'CSW': 'C', 'CSX': 'C', 'CSZ': 'C', 'CTE': 'W', 'CTH': 'T', 'CWD': 'A', 'CWR': 'S', 'CXM': 'M', 'CY0': 'C', 'CY1': 'C', 'CY3': 'C', 'CY4': 'C', 'CYA': 'C', 'CYD': 'C', 'CYF': 'C', 'CYG': 'C', 'CYJ': 'K', 'CYM': 'C', 'CYQ': 'C', 'CYR': 'C', 'CYS': 'C', 'CYW': 'C', 'CZ2': 'C', 'CZZ': 'C', 'CG6': 'C', 'C1J': 'R', 'C4G': 'R', 'C67': 'R', 'C6D': 'R', 'CE7': 'N', 'CZS': 'A', 'G01': 'E', 'G8M': 'E', 'GAU': 'E', 'GEE': 'G', 'GFT': 'S', 'GHC': 'E', 'GHG': 'Q', 'GHW': 'E', 'GL3': 'G', 'GLH': 'Q', 'GLJ': 'E', 'GLK': 'E', 'GLN': 'Q', 'GLQ': 'E', 'GLU': 'E', 'GLY': 'G', 'GLZ': 'G', 'GMA': 'E', 'GME': 'E', 'GNC': 'Q', 'GPL': 'K', 'GSC': 'G', 'GSU': 'E', 'GT9': 'C', 'GVL': 'S', 'G3M': 'R', 'G5G': 'L', 'G1X': 'Y', 'G8X': 'P', 'K1R': 'C', 'KBE': 'K', 'KCX': 'K', 'KFP': 'K', 'KGC': 'K', 'KNB': 'A', 'KOR': 'M', 'KPI': 'K', 'KPY': 'K', 'KST': 'K', 'KYN': 'W', 'KYQ': 'K', 'KCR': 'K', 'KPF': 'K', 'K5L': 'S', 'KEO': 'K', 'KHB': 'K', 'KKD': 'D', 'K5H': 'C', 'K7K': 'S', 'OAR': 'R', 'OAS': 'S', 'OBS': 'K', 'OCS': 'C', 'OCY': 'C', 'OHI': 'H', 'OHS': 'D', 'OLD': 'H', 'OLT': 'T', 'OLZ': 'S', 'OMH': 'S', 'OMT': 'M', 'OMX': 'Y', 'OMY': 'Y', 'ONH': 'A', 'ORN': 'A', 'ORQ': 'R', 'OSE': 'S', 'OTH': 'T', 
'OXX': 'D', 'OYL': 'H', 'O7A': 'T', 'O7D': 'W', 'O7G': 'V', 'O2E': 'S', 'O6H': 'W', 'OZW': 'F', 'S12': 'S', 'S1H': 'S', 'S2C': 'C', 'S2P': 'A', 'SAC': 'S', 'SAH': 'C', 'SAR': 'G', 'SBG': 'S', 'SBL': 'S', 'SCH': 'C', 'SCS': 'C', 'SCY': 'C', 'SD4': 'N', 'SDB': 'S', 'SDP': 'S', 'SEB': 'S', 'SEE': 'S', 'SEG': 'A', 'SEL': 'S', 'SEM': 'S', 'SEN': 'S', 'SEP': 'S', 'SER': 'S', 'SET': 'S', 'SGB': 'S', 'SHC': 'C', 'SHP': 'G', 'SHR': 'K', 'SIB': 'C', 'SLL': 'K', 'SLZ': 'K', 'SMC': 'C', 'SME': 'M', 'SMF': 'F', 'SNC': 'C', 'SNN': 'N', 'SOY': 'S', 'SRZ': 'S', 'STY': 'Y', 'SUN': 'S', 'SVA': 'S', 'SVV': 'S', 'SVW': 'S', 'SVX': 'S', 'SVY': 'S', 'SVZ': 'S', 'SXE': 'S', 'SKH': 'K', 'SNM': 'S', 'SNK': 'H', 'SWW': 'S', 'WFP': 'F', 'WLU': 'L', 'WPA': 'F', 'WRP': 'W', 'WVL': 'V', '02K': 'A', '02L': 'N', '02O': 'A', '02Y': 'A', '033': 'V', '037': 'P', '03Y': 'C', '04U': 'P', '04V': 'P', '05N': 'P', '07O': 'C', '0A0': 'D', '0A1': 'Y', '0A2': 'K', '0A8': 'C', '0A9': 'F', '0AA': 'V', '0AB': 'V', '0AC': 'G', '0AF': 'W', '0AG': 'L', '0AH': 'S', '0AK': 'D', '0AR': 'R', '0BN': 'F', '0CS': 'A', '0E5': 'T', '0EA': 'Y', '0FL': 'A', '0LF': 'P', '0NC': 'A', '0PR': 'Y', '0QL': 'C', '0TD': 'D', '0UO': 'W', '0WZ': 'Y', '0X9': 'R', '0Y8': 'P', '4AF': 'F', '4AR': 'R', '4AW': 'W', '4BF': 'F', '4CF': 'F', '4CY': 'M', '4DP': 'W', '4FB': 'P', '4FW': 'W', '4HL': 'Y', '4HT': 'W', '4IN': 'W', '4MM': 'M', '4PH': 'F', '4U7': 'A', '41H': 'F', '41Q': 'N', '42Y': 'S', '432': 'S', '45F': 'P', '4AK': 'K', '4D4': 'R', '4GJ': 'C', '4KY': 'P', '4L0': 'P', '4LZ': 'Y', '4N7': 'P', '4N8': 'P', '4N9': 'P', '4OG': 'W', '4OU': 'F', '4OV': 'S', '4OZ': 'S', '4PQ': 'W', '4SJ': 'F', '4WQ': 'A', '4HH': 'S', '4HJ': 'S', '4J4': 'C', '4J5': 'R', '4II': 'F', '4VI': 'R', '823': 'N', '8SP': 'S', '8AY': 'A'}
10
+
11
+
12
+ def is_aa(residue, standard=False):
13
+ if not isinstance(residue, str):
14
+ residue = f"{residue.get_resname():<3s}"
15
+ residue = residue.upper()
16
+ if standard:
17
+ return residue in protein_letters_3to1
18
+ else:
19
+ return residue in protein_letters_3to1_extended
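For instance, with the tables above:

is_aa('ALA')                 # True
is_aa('MSE')                 # True  (selenomethionine, extended table only)
is_aa('MSE', standard=True)  # False
is_aa('HOH')                 # False (water is not an amino acid)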
src/data/molecule_builder.py ADDED
@@ -0,0 +1,107 @@
1
+ from rdkit import Chem
2
+
3
+ from src import constants
4
+
5
+
6
+ def remove_dummy_atoms(rdmol, sanitize=False):
7
+ # find exit atoms to be removed
8
+ dummy_inds = []
9
+ for a in rdmol.GetAtoms():
10
+ if a.GetSymbol() == '*':
11
+ dummy_inds.append(a.GetIdx())
12
+
13
+ dummy_inds = sorted(dummy_inds, reverse=True)
14
+ new_mol = Chem.EditableMol(rdmol)
15
+ for idx in dummy_inds:
16
+ new_mol.RemoveAtom(idx)
17
+ new_mol = new_mol.GetMol()
18
+ if sanitize:
19
+ Chem.SanitizeMol(new_mol)
20
+ return new_mol
21
+
22
+
23
+ def build_molecule(coords, atom_types, bonds=None, bond_types=None,
24
+ atom_props=None, atom_decoder=None, bond_decoder=None):
25
+ """
26
+ Build RDKit molecule with given bonds
27
+ :param coords: N x 3
28
+ :param atom_types: N
29
+ :param bonds: 2 x N_bonds
30
+ :param bond_types: N_bonds
31
+ :param atom_props: Dict, key: property name, value: list of float values (N,)
32
+ :param atom_decoder: list
33
+ :param bond_decoder: list
34
+ :return: RDKit molecule
35
+ """
36
+ if atom_decoder is None:
37
+ atom_decoder = constants.atom_decoder
38
+ if bond_decoder is None:
39
+ bond_decoder = constants.bond_decoder
40
+ assert len(coords) == len(atom_types)
41
+ assert bonds is None or bonds.size(1) == len(bond_types)
42
+
43
+ mol = Chem.RWMol()
44
+ for i, atom in enumerate(atom_types):
45
+ element = atom_decoder[atom.item()]
46
+ charge = None
47
+ explicitHs = None
48
+
49
+ if len(element) > 1 and element.endswith('H'):
50
+ explicitHs = 1
51
+ element = element[:-1]
52
+ elif element.endswith('+'):
53
+ charge = 1
54
+ element = element[:-1]
55
+ elif element.endswith('-'):
56
+ charge = -1
57
+ element = element[:-1]
58
+
59
+ if element == 'NOATOM':
60
+ # element = 'Xe' # debug
61
+ element = '*'
62
+
63
+ a = Chem.Atom(element)
64
+
65
+ if explicitHs is not None:
66
+ a.SetNumExplicitHs(explicitHs)
67
+ if charge is not None:
68
+ a.SetFormalCharge(charge)
69
+
70
+ if atom_props is not None:
71
+ for k, vals in atom_props.items():
72
+ a.SetDoubleProp(k, vals[i].item())
73
+
74
+ mol.AddAtom(a)
75
+
76
+ # add coordinates
77
+ conf = Chem.Conformer(mol.GetNumAtoms())
78
+ for i in range(mol.GetNumAtoms()):
79
+ conf.SetAtomPosition(i, (coords[i, 0].item(),
80
+ coords[i, 1].item(),
81
+ coords[i, 2].item()))
82
+ mol.AddConformer(conf)
83
+
84
+ # add bonds
85
+ if bonds is not None:
86
+ for bond, bond_type in zip(bonds.T, bond_types):
87
+ bond_type = bond_decoder[bond_type]
88
+ src = bond[0].item()
89
+ dst = bond[1].item()
90
+
91
+ # try:
92
+ if bond_type == 'NOBOND' or mol.GetAtomWithIdx(src).GetSymbol() == '*' or mol.GetAtomWithIdx(dst).GetSymbol() == '*':
93
+ continue
94
+ # except RuntimeError:
95
+ # from pdb import set_trace; set_trace()
96
+
97
+ if mol.GetBondBetweenAtoms(src, dst) is not None:
98
+ assert mol.GetBondBetweenAtoms(src, dst).GetBondType() == bond_type, \
99
+ "Trying to assign two different types to the same bond."
100
+ continue
101
+
102
+ if bond_type is None or src == dst:
103
+ continue
104
+ mol.AddBond(src, dst, bond_type)
105
+
106
+ mol = remove_dummy_atoms(mol, sanitize=False)
107
+ return mol
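A hedged sketch of building a tiny molecule with explicit decoders (the real defaults come from src.constants); the geometry is approximate:

import torch
from rdkit import Chem

atom_decoder = ['C', 'N', 'O']                    # hypothetical decoders
bond_decoder = [None, Chem.BondType.SINGLE, Chem.BondType.DOUBLE]
coords = torch.tensor([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [2.2, 1.2, 0.0]])
atom_types = torch.tensor([0, 0, 2])              # C, C, O
bonds = torch.tensor([[0, 1], [1, 2]]).T          # 2 x N_bonds
bond_types = torch.tensor([1, 1])                 # both single bonds
mol = build_molecule(coords, atom_types, bonds, bond_types,
                     atom_decoder=atom_decoder, bond_decoder=bond_decoder)
Chem.SanitizeMol(mol)
print(Chem.MolToSmiles(mol))  # CCO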
src/data/nerf.py ADDED
@@ -0,0 +1,250 @@
1
+ """
2
+ Natural Extension Reference Frame (NERF)
3
+
4
+ Inspiration for parallel reconstruction:
5
+ https://github.com/EleutherAI/mp_nerf and references therein
6
+
7
+ For atom names, see also:
8
+ https://www.ccpn.ac.uk/manual/v3/NEFAtomNames.html
9
+
10
+ References:
11
+ - https://onlinelibrary.wiley.com/doi/10.1002/jcc.20237 (NERF)
12
+ - https://onlinelibrary.wiley.com/doi/10.1002/jcc.26768 (for code)
13
+ """
14
+
15
+ import warnings
16
+ import torch
17
+ import numpy as np
18
+
19
+ from src.data.misc import protein_letters_3to1
20
+ from src.constants import aa_atom_index, aa_atom_mask, aa_nerf_indices, aa_chi_indices, aa_chi_anchor_atom
21
+
22
+
23
+ # https://github.com/EleutherAI/mp_nerf/blob/master/mp_nerf/utils.py
24
+ def get_dihedral(c1, c2, c3, c4):
25
+ """ Returns the dihedral angle in radians.
26
+ Will use atan2 formula from:
27
+ https://en.wikipedia.org/wiki/Dihedral_angle#In_polymer_physics
28
+ Inputs:
29
+ * c1: (batch, 3) or (3,)
30
+ * c2: (batch, 3) or (3,)
31
+ * c3: (batch, 3) or (3,)
32
+ * c4: (batch, 3) or (3,)
33
+ """
34
+ u1 = c2 - c1
35
+ u2 = c3 - c2
36
+ u3 = c4 - c3
37
+
38
+ return torch.atan2( ( (torch.norm(u2, dim=-1, keepdim=True) * u1) * torch.cross(u2,u3, dim=-1) ).sum(dim=-1) ,
39
+ ( torch.cross(u1,u2, dim=-1) * torch.cross(u2, u3, dim=-1) ).sum(dim=-1) )
40
+
41
+
42
+ # https://github.com/EleutherAI/mp_nerf/blob/master/mp_nerf/utils.py
43
+ def get_angle(c1, c2, c3):
44
+ """ Returns the angle in radians.
45
+ Inputs:
46
+ * c1: (batch, 3) or (3,)
47
+ * c2: (batch, 3) or (3,)
48
+ * c3: (batch, 3) or (3,)
49
+ """
50
+ u1 = c2 - c1
51
+ u2 = c3 - c2
52
+
53
+ # dont use acos since norms involved.
54
+ # better use atan2 formula: atan2(cross, dot) from here:
55
+ # https://johnblackburne.blogspot.com/2012/05/angle-between-two-3d-vectors.html
56
+
57
+ # add a minus since we want the angle in reversed order - sidechainnet issues
58
+ return torch.atan2( torch.norm(torch.cross(u1,u2, dim=-1), dim=-1),
59
+ -(u1*u2).sum(dim=-1) )
60
+
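A quick numeric check of the two helpers, with four points whose torsion about the c2-c3 bond is a quarter turn:

import torch

c1 = torch.tensor([1.0, 0.0, 0.0])
c2 = torch.tensor([0.0, 0.0, 0.0])
c3 = torch.tensor([0.0, 1.0, 0.0])
c4 = torch.tensor([0.0, 1.0, 1.0])
get_dihedral(c1, c2, c3, c4)  # ≈ -pi/2 with this sign convention
get_angle(c2, c3, c4)         # ≈ pi/2 (the c2-c3-c4 bond angle)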
61
+
62
+ def get_nerf_params(biopython_residue):
63
+ aa = protein_letters_3to1[biopython_residue.get_resname()]
64
+
65
+ # Basic mask and index tensors
66
+ atom_mask = torch.tensor(aa_atom_mask[aa], dtype=bool)
67
+ nerf_indices = torch.tensor(aa_nerf_indices[aa], dtype=int)
68
+ chi_indices = torch.tensor(aa_chi_indices[aa], dtype=int)
69
+
70
+ fixed_coord = torch.zeros((5, 3))
71
+ residue_coords = torch.zeros((14, 3)) # only required to compute internal coordinates during pre-processing
72
+ atom_found = torch.zeros_like(atom_mask)
73
+ for atom in biopython_residue.get_atoms():
74
+ try:
75
+ idx = aa_atom_index[aa][atom.get_name()]
76
+ atom_found[idx] = True
77
+ except KeyError:
78
+ warnings.warn(f"{atom.get_name()} not found")
79
+ continue
80
+
81
+ residue_coords[idx, :] = torch.from_numpy(atom.get_coord())
82
+
83
+ if atom.get_name() in ['N', 'CA', 'C', 'O', 'CB']:
84
+ fixed_coord[idx, :] = torch.from_numpy(atom.get_coord())
85
+
86
+ # Determine chi angles
87
+ chi = torch.zeros(6) # the last chi angle is a dummy and should always be zero
88
+ for chi_idx, anchor in aa_chi_anchor_atom[aa].items():
89
+ idx_a = nerf_indices[anchor, 2]
90
+ idx_b = nerf_indices[anchor, 1]
91
+ idx_c = nerf_indices[anchor, 0]
92
+
93
+ coords_a = residue_coords[idx_a, :]
94
+ coords_b = residue_coords[idx_b, :]
95
+ coords_c = residue_coords[idx_c, :]
96
+ coords_d = residue_coords[anchor, :]
97
+
98
+ chi[chi_idx] = get_dihedral(coords_a, coords_b, coords_c, coords_d)
99
+
100
+ # Compute remaining internal coordinates
101
+ # (parallel version)
102
+ idx_a = nerf_indices[:, 2]
103
+ idx_b = nerf_indices[:, 1]
104
+ idx_c = nerf_indices[:, 0]
105
+
106
+ # update atom mask
107
+ # remove atoms for which one or several parameters are missing/incorrect
108
+ _atom_mask = atom_mask & atom_found & atom_found[idx_a] & atom_found[idx_b] & atom_found[idx_c]
109
+ if not torch.all(_atom_mask == atom_mask):
110
+ warnings.warn("Some atoms are missing for NERF reconstruction")
111
+ atom_mask = _atom_mask
112
+
113
+ coords_a = residue_coords[idx_a]
114
+ coords_b = residue_coords[idx_b]
115
+ coords_c = residue_coords[idx_c]
116
+ coords_d = residue_coords
117
+
118
+ length = torch.norm(coords_d - coords_c, dim=-1)
119
+ theta = get_angle(coords_b, coords_c, coords_d)
120
+ ddihedral = get_dihedral(coords_a, coords_b, coords_c, coords_d)
121
+
122
+ # subtract chi angles from dihedrals
123
+ ddihedral = ddihedral - chi[chi_indices]
124
+
125
+ # # (serial version)
126
+ # length = torch.zeros(14)
127
+ # theta = torch.zeros(14)
128
+ # ddihedral = torch.zeros(14)
129
+ # for i in range(5, 14):
130
+ # if not atom_mask[i]: # atom doesn't exist
131
+ # continue
132
+
133
+ # idx_a = nerf_indices[i, 2]
134
+ # idx_b = nerf_indices[i, 1]
135
+ # idx_c = nerf_indices[i, 0]
136
+
137
+ # coords_a = residue_coords[idx_a]
138
+ # coords_b = residue_coords[idx_b]
139
+ # coords_c = residue_coords[idx_c]
140
+ # coords_d = residue_coords[i]
141
+
142
+ # length[i] = torch.norm(coords_d - coords_c, dim=-1)
143
+ # theta[i] = get_angle(coords_b, coords_c, coords_d)
144
+ # ddihedral[i] = get_dihedral(coords_a, coords_b, coords_c, coords_d)
145
+
146
+ # # subtract chi angles from dihedrals
147
+ # ddihedral[i] = ddihedral[i] - chi[chi_indices[i]]
148
+
149
+ return {
150
+ 'fixed_coord': fixed_coord,
151
+ 'atom_mask': atom_mask,
152
+ 'nerf_indices': nerf_indices,
153
+ 'length': length,
154
+ 'theta': theta,
155
+ 'chi': chi,
156
+ 'ddihedral': ddihedral,
157
+ 'chi_indices': chi_indices,
158
+ }
159
+
160
+
161
+ # https://github.com/EleutherAI/mp_nerf/blob/master/mp_nerf/massive_pnerf.py#L38C1-L65C67
162
+ def mp_nerf_torch(a, b, c, l, theta, chi):
163
+ """ Custom Natural extension of Reference Frame.
164
+ Inputs:
165
+ * a: (batch, 3) or (3,). point(s) of the plane, not connected to d
166
+ * b: (batch, 3) or (3,). point(s) of the plane, not connected to d
167
+ * c: (batch, 3) or (3,). point(s) of the plane, connected to d
168
+ * theta: (batch,) or (float). angle(s) between b-c-d
169
+ * chi: (batch,) or float. dihedral angle(s) between the a-b-c and b-c-d planes
170
+ Outputs: d (batch, 3) or (float). the next point in the sequence, linked to c
171
+ """
172
+ # safety check
173
+ if not ( (-np.pi <= theta) * (theta <= np.pi) ).all().item():
174
+ raise ValueError(f"theta(s) must be in radians and in [-pi, pi]. theta(s) = {theta}")
175
+ # calc vecs
176
+ ba = b-a
177
+ cb = c-b
178
+ # calc rotation matrix. based on plane normals and normalized
179
+ n_plane = torch.cross(ba, cb, dim=-1)
180
+ n_plane_ = torch.cross(n_plane, cb, dim=-1)
181
+ rotate = torch.stack([cb, n_plane_, n_plane], dim=-1)
182
+ rotate /= torch.norm(rotate, dim=-2, keepdim=True)
183
+ # calc proto point, rotate. add (-1 for sidechainnet convention)
184
+ # https://github.com/jonathanking/sidechainnet/issues/14
185
+ d = torch.stack([-torch.cos(theta),
186
+ torch.sin(theta) * torch.cos(chi),
187
+ torch.sin(theta) * torch.sin(chi)], dim=-1).unsqueeze(-1)
188
+ # extend base point, set length
189
+ return c + l.unsqueeze(-1) * torch.matmul(rotate, d).squeeze()
190
+
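A hedged single-atom placement sketch: the returned point should sit exactly l away from c, at the requested angle and torsion (the values here are illustrative, roughly tetrahedral parameters):

import math
import torch

a = torch.tensor([0.0, 1.0, 1.0])
b = torch.tensor([0.0, 1.0, 0.0])
c = torch.tensor([0.0, 0.0, 0.0])
d = mp_nerf_torch(a, b, c,
                  l=torch.tensor(1.5),
                  theta=torch.tensor(math.radians(109.5)),
                  chi=torch.tensor(math.radians(60.0)))
assert torch.isclose(torch.norm(d - c), torch.tensor(1.5))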
191
+
192
+ # inspired by: https://github.com/EleutherAI/mp_nerf/blob/master/mp_nerf/proteins.py#L323C5-L344C65
193
+ def ic_to_coords(fixed_coord, atom_mask, nerf_indices, length, theta, chi, ddihedral, chi_indices):
194
+ """
195
+ Run NERF in parallel for all residues.
196
+
197
+ :param fixed_coord: (L, 5, 3) coordinates of (N, CA, C, O, CB) atoms, they don't depend on chi angles
198
+ :param atom_mask: (L, 14) indicates whether atom exists in this residue
199
+ :param nerf_indices: (L, 14, 3) indices of the three previous atoms ({c, b, a} for the NERF algorithm)
200
+ :param length: (L, 14) bond length between this and previous atom
201
+ :param theta: (L, 14) angle between this and previous two atoms
202
+ :param chi: (L, 6) values of the 5 rotatable bonds, plus zero in last column
203
+ :param ddihedral: (L, 14) angle offset to which chi is added
204
+ :param chi_indices: (L, 14) indexes into the chi array
205
+ :returns: (L, 14, 3) tensor with all coordinates, non-existing atoms are assigned CA coords
206
+ """
207
+
208
+ if not torch.all(chi[:, 5] == 0):
209
+ chi[:, 5] = 0.0
210
+ warnings.warn("Last column of 'chi' tensor should be zero. Overriding values.")
211
+ assert torch.all(chi[:, 5] == 0)
212
+
213
+ L, device = fixed_coord.size(0), fixed_coord.device
214
+ coords = torch.zeros((L, 14, 3), device=device)
215
+ coords[:, :5, :] = fixed_coord
216
+
217
+ for i in range(5, 14):
218
+ level_mask = atom_mask[:, i]
219
+ # level_mask = torch.ones(len(atom_mask), dtype=bool)
220
+
221
+ length_i = length[level_mask, i]
222
+ theta_i = theta[level_mask, i]
223
+
224
+ # dihedral_i = dihedral[level_mask, i]
225
+ dihedral_i = chi[level_mask, chi_indices[level_mask, i]] + ddihedral[level_mask, i]
226
+
227
+ idx_a = nerf_indices[level_mask, i, 2]
228
+ idx_b = nerf_indices[level_mask, i, 1]
229
+ idx_c = nerf_indices[level_mask, i, 0]
230
+
231
+ coords[level_mask, i] = mp_nerf_torch(coords[level_mask, idx_a],
232
+ coords[level_mask, idx_b],
233
+ coords[level_mask, idx_c],
234
+ length_i,
235
+ theta_i,
236
+ dihedral_i)
237
+
238
+ if coords.isnan().any():
239
+ warnings.warn("Side chain reconstruction error. Removing affected atoms...")
240
+
241
+ # mask out affected atoms
242
+ m, n, _ = torch.where(coords.isnan())
243
+ atom_mask[m, n] = False
244
+ coords[m, n, :] = 0.0
245
+
246
+ # replace non-existing atom coords with CA coords (TODO: don't hard-code CA index)
247
+ coords = atom_mask.unsqueeze(-1) * coords + \
248
+ (~atom_mask.unsqueeze(2)) * coords[:, 1, :].unsqueeze(1)
249
+
250
+ return coords
src/data/normal_modes.py ADDED
@@ -0,0 +1,69 @@
1
+ import warnings
2
+ import numpy as np
3
+ import prody
4
+ prody.confProDy(verbosity='none')
5
+ from prody import parsePDB, ANM
6
+
7
+
8
+ def pdb_to_normal_modes(pdb_file, num_modes=5, nmax=5000):
9
+ """
10
+ Compute normal modes for a PDB file using an Anisotropic Network Model (ANM)
11
+ http://prody.csb.pitt.edu/tutorials/enm_analysis/anm.html (accessed 01/11/2023)
12
+ """
13
+ protein = parsePDB(pdb_file, model=1).select('calpha')
14
+
15
+ if len(protein) > nmax:
16
+ warnings.warn("Protein is too big. Returning zeros...")
17
+ eig_vecs = np.zeros((len(protein), 3, num_modes))
18
+
19
+ else:
20
+ # build Hessian
21
+ anm = ANM('ANM analysis')
22
+ anm.buildHessian(protein, cutoff=15.0, gamma=1.0)
23
+
24
+ # calculate normal modes
25
+ anm.calcModes(num_modes, zeros=False)
26
+
27
+ # only use slowest modes
28
+ eig_vecs = anm.getEigvecs() # shape: (num_atoms * 3, num_modes)
29
+ eig_vecs = eig_vecs.reshape(len(protein), 3, num_modes)
30
+ # eig_vals = anm.getEigvals() # shape: (num_modes,)
31
+
32
+ nm_dict = {}
33
+ for atom, nm_vec in zip(protein, eig_vecs):
34
+ chain = atom.getChid()
35
+ resi = atom.getResnum()
36
+ name = atom.getName()
37
+ nm_dict[(chain, resi, name)] = nm_vec.T
38
+
39
+ return nm_dict
40
+
41
+
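A hedged usage sketch; any PDB with C-alpha atoms works, and the returned dictionary is keyed by (chain, residue number, atom name):

nm = pdb_to_normal_modes('protein.pdb', num_modes=5)  # placeholder path
vec = nm[('A', 42, 'CA')]  # (num_modes, 3) eigenvector rows for this CA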
42
+ if __name__ == "__main__":
43
+ import argparse
44
+ from pathlib import Path
45
+ import torch
46
+ from tqdm import tqdm
47
+
48
+ parser = argparse.ArgumentParser()
49
+ parser.add_argument('basedir', type=Path)
50
+ parser.add_argument('--outfile', type=Path, default=None)
51
+ args = parser.parse_args()
52
+
53
+ # Read data split
54
+ split_path = Path(args.basedir, 'split_by_name.pt')
55
+ data_split = torch.load(split_path)
56
+
57
+ pockets = [x[0] for split in data_split.values() for x in split]
58
+
59
+ all_normal_modes = {}
60
+ for p in tqdm(pockets):
61
+ pdb_file = Path(args.basedir, 'crossdocked_pocket10', p)
62
+
63
+ try:
64
+ nm_dict = pdb_to_normal_modes(str(pdb_file))
65
+ all_normal_modes[p] = nm_dict
66
+ except AttributeError as e:
67
+ warnings.warn(str(e))
68
+
69
+ np.save(args.outfile, all_normal_modes)
src/data/postprocessing.py ADDED
@@ -0,0 +1,93 @@
1
+ import warnings
2
+
3
+ from rdkit import Chem
4
+ from rdkit.Chem.rdForceFieldHelpers import UFFOptimizeMolecule, UFFHasAllMoleculeParams
5
+
6
+ from src.data import sanifix
7
+
8
+
9
+ def uff_relax(mol, max_iter=200):
10
+ """
11
+ Uses RDKit's universal force field (UFF) implementation to optimize a
12
+ molecule.
13
+ """
14
+ if not UFFHasAllMoleculeParams(mol):
15
+ warnings.warn('UFF parameters not available for all atoms. '
16
+ 'Returning None.')
17
+ return None
18
+
19
+ try:
20
+ more_iterations_required = UFFOptimizeMolecule(mol, maxIters=max_iter)
21
+ if more_iterations_required:
22
+ warnings.warn(f'Maximum number of FF iterations reached. '
23
+ f'Returning molecule after {max_iter} relaxation steps.')
24
+
25
+ except RuntimeError:
26
+ return None
27
+
28
+ return mol
29
+
30
+
31
+ def add_hydrogens(rdmol):
32
+ return Chem.AddHs(rdmol, addCoords=(len(rdmol.GetConformers()) > 0))
33
+
34
+
35
+ def get_largest_fragment(rdmol):
36
+ mol_frags = Chem.GetMolFrags(rdmol, asMols=True, sanitizeFrags=False)
37
+ largest_frag = max(mol_frags, default=rdmol, key=lambda m: m.GetNumAtoms())
38
+
39
+ # try:
40
+ # Chem.SanitizeMol(largest_frag)
41
+ # except ValueError:
42
+ # return None
43
+
44
+ return largest_frag
45
+
46
+
47
+ def process_all(rdmol, largest_frag=True, adjust_aromatic_Ns=True, relax_iter=0):
48
+ """
49
+ Apply all filters and post-processing steps. Returns a new molecule.
50
+
51
+ Returns:
52
+ RDKit molecule or None if it does not pass the filters or processing
53
+ fails
54
+ """
55
+
56
+ # Only consider non-trivial molecules
57
+ if rdmol.GetNumAtoms() < 1:
58
+ return None
59
+
60
+ # Create a copy
61
+ mol = Chem.Mol(rdmol)
62
+
63
+ # try:
64
+ # Chem.SanitizeMol(mol)
65
+ # except ValueError:
66
+ # warnings.warn('Sanitization failed. Returning None.')
67
+ # return None
68
+
69
+ if largest_frag:
70
+ mol = get_largest_fragment(mol)
71
+ # if mol is None:
72
+ # return None
73
+
74
+ if adjust_aromatic_Ns:
75
+ mol = sanifix.fix_mol(mol)
76
+ if mol is None:
77
+ return None
78
+
79
+ # if add_hydrogens:
80
+ # mol = add_hydrogens(mol)
81
+
82
+ if relax_iter > 0:
83
+ mol = uff_relax(mol, relax_iter)
84
+ if mol is None:
85
+ return None
86
+
87
+ try:
88
+ Chem.SanitizeMol(mol)
89
+ except ValueError:
90
+ warnings.warn('Sanitization failed. Returning None.')
91
+ return None
92
+
93
+ return mol
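A hedged end-to-end sketch (relax_iter=0 skips the UFF step, so no 3D conformer is needed; sanifix.fix_mol is assumed to pass valid molecules through unchanged):

from rdkit import Chem

raw = Chem.MolFromSmiles('CCO.C')  # two fragments; the largest is kept
processed = process_all(raw, largest_frag=True, adjust_aromatic_Ns=True,
                        relax_iter=0)
if processed is not None:
    print(Chem.MolToSmiles(processed))  # CCO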
src/data/process_crossdocked.py ADDED
@@ -0,0 +1,176 @@
1
+ from pathlib import Path
2
+ from time import time
3
+ import argparse
4
+ import shutil
5
+ import random
6
+ import yaml
7
+ from collections import defaultdict
8
+
9
+ import torch
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+ from Bio.PDB import PDBParser
13
+ from rdkit import Chem
14
+
15
+ import sys
16
+ basedir = Path(__file__).resolve().parent.parent.parent
17
+ sys.path.append(str(basedir))
18
+
19
+ from src.data.data_utils import process_raw_pair, get_n_nodes, get_type_histogram
20
+ from src.data.data_utils import rdmol_to_smiles
21
+ from src.constants import atom_encoder, bond_encoder
22
+
23
+
24
+ if __name__ == '__main__':
25
+ parser = argparse.ArgumentParser()
26
+ parser.add_argument('basedir', type=Path)
27
+ parser.add_argument('--outdir', type=Path, default=None)
28
+ parser.add_argument('--split_path', type=Path, default=None)
29
+ parser.add_argument('--pocket', type=str, default='CA+',
30
+ choices=['side_chain_bead', 'CA+'])
31
+ parser.add_argument('--random_seed', type=int, default=42)
32
+ parser.add_argument('--val_size', type=int, default=100)
33
+ parser.add_argument('--normal_modes', action='store_true')
34
+ parser.add_argument('--flex', action='store_true')
35
+ parser.add_argument('--toy', action='store_true')
36
+ args = parser.parse_args()
37
+
38
+ random.seed(args.random_seed)
39
+
40
+ datadir = args.basedir / 'crossdocked_pocket10/'
41
+
42
+ # Make output directory
43
+ dirname = f"processed_crossdocked_{args.pocket}"
44
+ if args.flex:
45
+ dirname += '_flex'
46
+ if args.normal_modes:
47
+ dirname += '_nma'
48
+ if args.toy:
49
+ dirname += '_toy'
50
+ processed_dir = Path(args.basedir, dirname) if args.outdir is None else args.outdir
51
+ processed_dir.mkdir(parents=True)
52
+
53
+ # Read data split
54
+ split_path = Path(args.basedir, 'split_by_name.pt') if args.split_path is None else args.split_path
55
+ data_split = torch.load(split_path)
56
+
57
+ # If there is no validation set, copy training examples (the validation set
58
+ # is not very important in this application)
59
+ if 'val' not in data_split:
60
+ random.shuffle(data_split['train'])
61
+ data_split['val'] = data_split['train'][-args.val_size:]
62
+ data_split['train'] = data_split['train'][:-args.val_size]
63
+
64
+ if args.toy:
65
+ data_split['train'] = random.sample(data_split['train'], 100)
66
+
67
+ failed = {}
68
+ train_smiles = []
69
+
70
+ n_samples_after = {}
71
+ for split in data_split.keys():
72
+
73
+ print(f"Processing {split} dataset...")
74
+
75
+ ligands = defaultdict(list)
76
+ pockets = defaultdict(list)
77
+
78
+ tic = time()
79
+ pbar = tqdm(data_split[split])
80
+ for pocket_fn, ligand_fn in pbar:
81
+
82
+ pbar.set_description(f'#failed: {len(failed)}')
83
+
84
+ sdffile = datadir / f'{ligand_fn}'
85
+ pdbfile = datadir / f'{pocket_fn}'
86
+
87
+ try:
88
+ pdb_model = PDBParser(QUIET=True).get_structure('', pdbfile)[0]
89
+
90
+ rdmol = Chem.SDMolSupplier(str(sdffile))[0]
91
+
92
+ ligand, pocket = process_raw_pair(
93
+ pdb_model, rdmol, pocket_representation=args.pocket,
94
+ compute_nerf_params=args.flex, compute_bb_frames=args.flex,
95
+ nma_input=pdbfile if args.normal_modes else None)
96
+
97
+ except (KeyError, AssertionError, FileNotFoundError, IndexError,
98
+ ValueError, AttributeError) as e:
99
+ failed[(split, sdffile, pdbfile)] = (type(e).__name__, str(e))
100
+ continue
101
+
102
+ nerf_keys = ['fixed_coord', 'atom_mask', 'nerf_indices', 'length', 'theta', 'chi', 'ddihedral', 'chi_indices']
103
+ for k in ['x', 'one_hot', 'bonds', 'bond_one_hot', 'v', 'nma_vec'] + nerf_keys + ['axis_angle']:
104
+ if k in ligand:
105
+ ligands[k].append(ligand[k])
106
+ if k in pocket:
107
+ pockets[k].append(pocket[k])
108
+
109
+ pocket_file = pdbfile.name.replace('_', '-')
110
+ ligand_file = Path(pocket_file).stem + '_' + Path(sdffile).name.replace('_', '-')
111
+ ligands['name'].append(ligand_file)
112
+ pockets['name'].append(pocket_file)
113
+ train_smiles.append(rdmol_to_smiles(rdmol))
114
+
115
+ if split in {'val', 'test'}:
116
+ pdb_sdf_dir = processed_dir / split
117
+ pdb_sdf_dir.mkdir(exist_ok=True)
118
+
119
+ # Copy PDB file
120
+ pdb_file_out = Path(pdb_sdf_dir, pocket_file)
121
+ shutil.copy(pdbfile, pdb_file_out)
122
+
123
+ # Copy SDF file
124
+ sdf_file_out = Path(pdb_sdf_dir, ligand_file)
125
+ shutil.copy(sdffile, sdf_file_out)
126
+
127
+ data = {'ligands': ligands, 'pockets': pockets}
128
+ torch.save(data, Path(processed_dir, f'{split}.pt'))
129
+
130
+ if split == 'train':
131
+ np.save(Path(processed_dir, 'train_smiles.npy'), train_smiles)
132
+
133
+ print(f"Processing {split} set took {(time() - tic) / 60.0:.2f} minutes")
134
+
135
+
136
+ # --------------------------------------------------------------------------
137
+ # Compute statistics & additional information
138
+ # --------------------------------------------------------------------------
139
+ train_data = torch.load(Path(processed_dir, 'train.pt'))
140
+
141
+ # Maximum molecule size
142
+ max_ligand_size = max([len(x) for x in train_data['ligands']['x']])
143
+
144
+ # Joint histogram of number of ligand and pocket nodes
145
+ pocket_coords = train_data['pockets']['x']
146
+ ligand_coords = train_data['ligands']['x']
147
+ n_nodes = get_n_nodes(ligand_coords, pocket_coords)
148
+ np.save(Path(processed_dir, 'size_distribution.npy'), n_nodes)
149
+
150
+ # Get histograms of ligand node types
151
+ lig_one_hot = [x.numpy() for x in train_data['ligands']['one_hot']]
152
+ ligand_hist = get_type_histogram(lig_one_hot, atom_encoder)
153
+ np.save(Path(processed_dir, 'ligand_type_histogram.npy'), ligand_hist)
154
+
155
+ # Get histograms of ligand edge types
156
+ lig_bond_one_hot = [x.numpy() for x in train_data['ligands']['bond_one_hot']]
157
+ ligand_bond_hist = get_type_histogram(lig_bond_one_hot, bond_encoder)
158
+ np.save(Path(processed_dir, 'ligand_bond_type_histogram.npy'), ligand_bond_hist)
159
+
160
+ # Write error report
161
+ error_str = ""
162
+ for k, v in failed.items():
163
+ error_str += f"{'Split':<15}: {k[0]}\n"
164
+ error_str += f"{'Ligand':<15}: {k[1]}\n"
165
+ error_str += f"{'Pocket':<15}: {k[2]}\n"
166
+ error_str += f"{'Error type':<15}: {v[0]}\n"
167
+ error_str += f"{'Error msg':<15}: {v[1]}\n\n"
168
+
169
+ with open(Path(processed_dir, 'errors.txt'), 'w') as f:
170
+ f.write(error_str)
171
+
172
+ metadata = {
173
+ 'max_ligand_size': max_ligand_size
174
+ }
175
+ with open(Path(processed_dir, 'metadata.yml'), 'w') as f:
176
+ yaml.dump(metadata, f, default_flow_style=False)
src/data/process_dpo_dataset.py ADDED
@@ -0,0 +1,406 @@
+ import argparse
+ from pathlib import Path
+ import numpy as np
+ import random
+ import shutil
+ from time import time
+ from collections import defaultdict
+ from Bio.PDB import PDBParser
+ from rdkit import Chem
+ import torch
+ from tqdm import tqdm
+ import pandas as pd
+ from itertools import combinations
+
+ import sys
+ basedir = Path(__file__).resolve().parent.parent.parent
+ sys.path.append(str(basedir))
+
+ from src.sbdd_metrics.metrics import REOSEvaluator, MedChemEvaluator, PoseBustersEvaluator, GninaEvalulator
+ from src.data.data_utils import process_raw_pair, rdmol_to_smiles
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--smplsdir', type=Path, required=True)
+     parser.add_argument('--metrics-detailed', type=Path, required=False)
+     parser.add_argument('--ignore-missing-scores', action='store_true')
+     parser.add_argument('--datadir', type=Path, required=True)
+     parser.add_argument('--dpo-criterion', type=str, default='reos.all',
+                         choices=['reos.all', 'medchem.sa', 'medchem.qed', 'gnina.vina_efficiency', 'combined'])
+     parser.add_argument('--basedir', type=Path, default=None)
+     parser.add_argument('--pocket', type=str, default='CA+',
+                         choices=['side_chain_bead', 'CA+'])
+     parser.add_argument('--gnina', type=Path, default='gnina')
+     parser.add_argument('--random_seed', type=int, default=42)
+     parser.add_argument('--normal_modes', action='store_true')
+     parser.add_argument('--flex', action='store_true')
+     parser.add_argument('--toy', action='store_true')
+     parser.add_argument('--toy_size', type=int, default=100)
+     parser.add_argument('--n_pairs', type=int, default=5)
+     args = parser.parse_args()
+     return args
+
+
+ def scan_smpl_dir(samples_dir):
+     samples_dir = Path(samples_dir)
+     subdirs = []
+     for subdir in tqdm(samples_dir.iterdir(), desc='Scanning samples'):
+         if not subdir.is_dir():
+             continue
+         if not sample_dir_valid(subdir):
+             continue
+         subdirs.append(subdir)
+     return subdirs
+
+
+ def sample_dir_valid(sample_dir):
+     pocket = sample_dir / '0_pocket.pdb'
+     if not pocket.exists():
+         return False
+     ligands = list(sample_dir.glob('*_ligand.sdf'))
+     if len(ligands) < 2:
+         return False
+     for ligand in ligands:
+         if ligand.stat().st_size == 0:
+             return False
+     return True
+
+
+ def return_winning_losing_smpl(score_1, score_2, criterion):
+     """Return True if sample 1 wins, False if sample 2 wins, and None if the
+     difference is below the criterion-specific significance threshold."""
+     if criterion == 'reos.all':
+         if score_1 == score_2:
+             return None
+         return score_1 > score_2
+     elif criterion == 'medchem.sa':
+         if np.abs(score_1 - score_2) < 0.5:
+             return None
+         return score_1 < score_2  # lower SA score is better
+     elif criterion == 'medchem.qed':
+         if np.abs(score_1 - score_2) < 0.1:
+             return None
+         return score_1 > score_2
+     elif criterion == 'gnina.vina_efficiency':
+         if np.abs(score_1 - score_2) < 0.1:
+             return None
+         return score_1 < score_2  # lower (more negative) Vina score is better
+     elif criterion == 'combined':
+         score_reos_1, score_reos_2 = score_1['reos.all'], score_2['reos.all']
+         score_sa_1, score_sa_2 = score_1['medchem.sa'], score_2['medchem.sa']
+         score_qed_1, score_qed_2 = score_1['medchem.qed'], score_2['medchem.qed']
+         score_vina_1, score_vina_2 = score_1['gnina.vina_efficiency'], score_2['gnina.vina_efficiency']
+         if score_reos_1 == score_reos_2:
+             return None
+         # only keep the pair if all individual criteria agree on the winner
+         reos_sign = score_reos_1 > score_reos_2
+         sa_sign = score_sa_1 < score_sa_2
+         qed_sign = score_qed_1 > score_qed_2
+         vina_sign = score_vina_1 < score_vina_2
+         signs = [reos_sign, sa_sign, qed_sign, vina_sign]
+         if all(signs) or not any(signs):
+             return signs[0]
+         return None
+
+
+ def compute_scores(sample_dirs, evaluator, criterion, n_pairs=5, toy=False, toy_size=100,
+                    precomp_scores=None, ignore_missing_scores=False):
+     samples = []
+     pose_evaluator = PoseBustersEvaluator()
+     pbar = tqdm(sample_dirs, desc='Computing scores for samples')
+
+     for sample_dir in pbar:
+         pocket = sample_dir / '0_pocket.pdb'
+         ligands = list(sample_dir.glob('*_ligand.sdf'))
+
+         target_samples = []
+         for lig_path in ligands:
+             try:
+                 mol = Chem.SDMolSupplier(str(lig_path))[0]
+                 if mol is None:
+                     continue
+                 smiles = rdmol_to_smiles(mol)
+             except Exception as e:
+                 print(f'Failed to read ligand {lig_path}: {e}')
+                 continue
+
+             if precomp_scores is not None and str(lig_path) in precomp_scores.index:
+                 mol_props = precomp_scores.loc[str(lig_path)].to_dict()
+                 if criterion == 'combined':
+                     combined_keys = ('reos.all', 'medchem.sa', 'medchem.qed', 'gnina.vina_efficiency')
+                     if any(k not in mol_props for k in combined_keys):
+                         print('Missing combined scores for ligand:', lig_path)
+                         continue
+                     mol_props['combined'] = {
+                         'reos.all': mol_props['reos.all'],
+                         'medchem.sa': mol_props['medchem.sa'],
+                         'medchem.qed': mol_props['medchem.qed'],
+                         'gnina.vina_efficiency': mol_props['gnina.vina_efficiency'],
+                         # score used for ranking pairs under the combined criterion
+                         'combined': mol_props['gnina.vina_efficiency']
+                     }
+             else:
+                 mol_props = {}
+             if criterion not in mol_props:
+                 if ignore_missing_scores:
+                     print(f'Missing {criterion} for ligand:', lig_path)
+                     continue
+                 print(f'Recomputing {criterion} for ligand:', lig_path)
+                 try:
+                     eval_res = evaluator.evaluate(mol)
+                     criterion_cat = criterion.split('.')[0]
+                     eval_res = {f'{criterion_cat}.{k}': v for k, v in eval_res.items()}
+                     score = eval_res[criterion]
+                 except Exception:
+                     continue
+             else:
+                 score = mol_props[criterion]
+
+             if 'posebusters.all' not in mol_props:
+                 if ignore_missing_scores:
+                     print('Missing PoseBusters for ligand:', lig_path)
+                     continue
+                 print('Recomputing PoseBusters for ligand:', lig_path)
+                 try:
+                     pose_eval_res = pose_evaluator.evaluate(lig_path, pocket)
+                 except Exception:
+                     continue
+                 if 'all' not in pose_eval_res or not pose_eval_res['all']:
+                     continue
+             else:
+                 pose_eval_res = mol_props['posebusters.all']
+                 if not pose_eval_res:
+                     continue
+
+             target_samples.append({
+                 'smiles': smiles,
+                 'score': score,
+                 'ligand_path': lig_path,
+                 'pocket_path': pocket
+             })
+
+         # Deduplicate by SMILES
+         unique_samples = {}
+         for sample in target_samples:
+             if sample['smiles'] not in unique_samples:
+                 unique_samples[sample['smiles']] = sample
+         unique_samples = list(unique_samples.values())
+         if len(unique_samples) < 2:
+             continue
+
+         # Generate all possible pairs
+         all_pairs = list(combinations(unique_samples, 2))
+
+         # Calculate score differences and keep only pairs with a clear winner
+         valid_pairs = []
+         for s1, s2 in all_pairs:
+             sign = return_winning_losing_smpl(s1['score'], s2['score'], criterion)
+             if sign is None:
+                 continue
+             score_diff = abs(s1['score'] - s2['score']) if criterion != 'combined' else \
+                 abs(s1['score']['combined'] - s2['score']['combined'])
+             if sign:
+                 valid_pairs.append((s1, s2, score_diff))
+             else:
+                 valid_pairs.append((s2, s1, score_diff))
+
+         # Sort pairs by score difference (descending) and select the top n_pairs,
+         # using every ligand at most once
+         valid_pairs.sort(key=lambda x: x[2], reverse=True)
+         used_ligand_paths = set()
+         selected_pairs = []
+         for winning, losing, score_diff in valid_pairs:
+             if winning['ligand_path'] in used_ligand_paths or losing['ligand_path'] in used_ligand_paths:
+                 continue
+
+             selected_pairs.append((winning, losing, score_diff))
+             used_ligand_paths.add(winning['ligand_path'])
+             used_ligand_paths.add(losing['ligand_path'])
+
+             if len(selected_pairs) == n_pairs:
+                 break
+
+         for winning, losing, _ in selected_pairs:
+             d = {
+                 'score_w': winning['score'],
+                 'score_l': losing['score'],
+                 'pocket_p': winning['pocket_path'],
+                 'ligand_p_w': winning['ligand_path'],
+                 'ligand_p_l': losing['ligand_path']
+             }
+             if isinstance(winning['score'], dict):
+                 for k, v in winning['score'].items():
+                     d[f'{k}_w'] = v
+                 d['score_w'] = winning['score']['combined']
+             if isinstance(losing['score'], dict):
+                 for k, v in losing['score'].items():
+                     d[f'{k}_l'] = v
+                 d['score_l'] = losing['score']['combined']
+             samples.append(d)
+
+         pbar.set_postfix({'samples': len(samples)})
+
+         if toy and len(samples) >= toy_size:
+             break
+
+     return samples
+
+
+ def main():
+     args = parse_args()
+
+     if 'reos' in args.dpo_criterion:
+         evaluator = REOSEvaluator()
+     elif 'medchem' in args.dpo_criterion:
+         evaluator = MedChemEvaluator()
+     elif 'gnina' in args.dpo_criterion:
+         evaluator = GninaEvalulator(gnina=args.gnina)
+     elif 'combined' in args.dpo_criterion:
+         evaluator = None  # for the combined criterion, metrics have to be computed separately
+         if args.metrics_detailed is None:
+             raise ValueError('For the combined criterion, a detailed metrics file has to be provided')
+         if not args.ignore_missing_scores:
+             raise ValueError('For the combined criterion, the --ignore-missing-scores flag has to be set')
+     else:
+         raise ValueError(f"Unknown DPO criterion: {args.dpo_criterion}")
+
+     # Make output directory
+     dirname = f"dpo_{args.dpo_criterion.replace('.', '_')}_{args.pocket}"
+     if args.flex:
+         dirname += '_flex'
+     if args.normal_modes:
+         dirname += '_nma'
+     if args.toy:
+         dirname += '_toy'
+     processed_dir = Path(args.basedir, dirname)
+     processed_dir.mkdir(parents=True, exist_ok=True)
+
+     if (processed_dir / f'samples_{args.dpo_criterion}.csv').exists():
+         print(f"Samples already computed for criterion {args.dpo_criterion}, loading from file")
+         samples = pd.read_csv(processed_dir / f'samples_{args.dpo_criterion}.csv')
+         samples = [dict(row) for _, row in samples.iterrows()]
+         print(f"Found {len(samples)} winning/losing samples")
+     else:
+         print('Scanning sample directory...')
+         samples_dir = Path(args.smplsdir)
+         sample_dirs = scan_smpl_dir(samples_dir)
+         if args.metrics_detailed:
+             print(f'Loading precomputed scores from {args.metrics_detailed}')
+             precomp_scores = pd.read_csv(args.metrics_detailed)
+             precomp_scores = precomp_scores.set_index('sdf_file')
+         else:
+             precomp_scores = None
+         print(f'Found {len(sample_dirs)} valid sample directories')
+         print('Computing scores...')
+         samples = compute_scores(sample_dirs, evaluator, args.dpo_criterion,
+                                  n_pairs=args.n_pairs, toy=args.toy, toy_size=args.toy_size,
+                                  precomp_scores=precomp_scores,
+                                  ignore_missing_scores=args.ignore_missing_scores)
+         print(f'Found {len(samples)} winning/losing samples, saving to file')
+         pd.DataFrame(samples).to_csv(Path(processed_dir, f'samples_{args.dpo_criterion}.csv'), index=False)
+
+     data_split = {}
+     data_split['train'] = samples
+     if args.toy:
+         data_split['train'] = random.sample(samples, min(args.toy_size, len(data_split['train'])))
+
+     failed = {}
+     train_smiles = []
+
+     for split in data_split.keys():
+
+         print(f"Processing {split} dataset...")
+
+         ligands_w = defaultdict(list)
+         ligands_l = defaultdict(list)
+         pockets = defaultdict(list)
+
+         tic = time()
+         pbar = tqdm(data_split[split])
+         for entry in pbar:
+
+             pbar.set_description(f'#failed: {len(failed)}')
+
+             pdbfile = Path(entry['pocket_p'])
+             entry['ligand_p_w'] = Path(entry['ligand_p_w'])
+             entry['ligand_p_l'] = Path(entry['ligand_p_l'])
+             entry['ligand_w'] = Chem.SDMolSupplier(str(entry['ligand_p_w']))[0]
+             entry['ligand_l'] = Chem.SDMolSupplier(str(entry['ligand_p_l']))[0]
+
+             try:
+                 pdb_model = PDBParser(QUIET=True).get_structure('', pdbfile)[0]
+
+                 ligand_w, pocket = process_raw_pair(
+                     pdb_model, entry['ligand_w'], pocket_representation=args.pocket,
+                     compute_nerf_params=args.flex, compute_bb_frames=args.flex,
+                     nma_input=pdbfile if args.normal_modes else None)
+                 ligand_l, _ = process_raw_pair(
+                     pdb_model, entry['ligand_l'], pocket_representation=args.pocket,
+                     compute_nerf_params=args.flex, compute_bb_frames=args.flex,
+                     nma_input=pdbfile if args.normal_modes else None)
+
+             except (KeyError, AssertionError, FileNotFoundError, IndexError,
+                     ValueError, AttributeError) as e:
+                 failed[(split, entry['ligand_p_w'], entry['ligand_p_l'], pdbfile)] \
+                     = (type(e).__name__, str(e))
+                 continue
+
+             nerf_keys = ['fixed_coord', 'atom_mask', 'nerf_indices', 'length', 'theta', 'chi', 'ddihedral', 'chi_indices']
+             for k in ['x', 'one_hot', 'bonds', 'bond_one_hot', 'v', 'nma_vec'] + nerf_keys + ['axis_angle']:
+                 if k in ligand_w:
+                     ligands_w[k].append(ligand_w[k])
+                     ligands_l[k].append(ligand_l[k])
+                 if k in pocket:
+                     pockets[k].append(pocket[k])
+
+             smpl_n = pdbfile.parent.name
+             pocket_file = f'{smpl_n}__{pdbfile.stem}.pdb'
+             ligand_file_w = f'{smpl_n}__{entry["ligand_p_w"].stem}.sdf'
+             ligand_file_l = f'{smpl_n}__{entry["ligand_p_l"].stem}.sdf'
+             ligands_w['name'].append(ligand_file_w)
+             ligands_l['name'].append(ligand_file_l)
+             pockets['name'].append(pocket_file)
+             train_smiles.append(rdmol_to_smiles(entry['ligand_w']))
+             train_smiles.append(rdmol_to_smiles(entry['ligand_l']))
+
+         data = {'ligands_w': ligands_w,
+                 'ligands_l': ligands_l,
+                 'pockets': pockets}
+         torch.save(data, Path(processed_dir, f'{split}.pt'))
+
+         if split == 'train':
+             np.save(Path(processed_dir, 'train_smiles.npy'), train_smiles)
+
+         print(f"Processing {split} set took {(time() - tic) / 60.0:.2f} minutes")
+
+     # copy dataset statistics from the original dataset
+     size_distr_p = Path(args.datadir, 'size_distribution.npy')
+     type_histo_p = Path(args.datadir, 'ligand_type_histogram.npy')
+     bond_histo_p = Path(args.datadir, 'ligand_bond_type_histogram.npy')
+     metadata_p = Path(args.datadir, 'metadata.yml')
+     shutil.copy(size_distr_p, processed_dir)
+     shutil.copy(type_histo_p, processed_dir)
+     shutil.copy(bond_histo_p, processed_dir)
+     shutil.copy(metadata_p, processed_dir)
+
+     # copy val and test .pt files and directories
+     val_dir = Path(args.datadir, 'val')
+     test_dir = Path(args.datadir, 'test')
+     val_pt = Path(args.datadir, 'val.pt')
+     test_pt = Path(args.datadir, 'test.pt')
+     assert val_dir.exists() and test_dir.exists() and val_pt.exists() and test_pt.exists()
+     if (processed_dir / 'val').exists():
+         shutil.rmtree(processed_dir / 'val')
+     if (processed_dir / 'test').exists():
+         shutil.rmtree(processed_dir / 'test')
+     shutil.copytree(val_dir, processed_dir / 'val')
+     shutil.copytree(test_dir, processed_dir / 'test')
+     shutil.copy(val_pt, processed_dir)
+     shutil.copy(test_pt, processed_dir)
+
+     # Write error report
+     error_str = ""
+     for k, v in failed.items():
+         error_str += f"{'Split':<15}: {k[0]}\n"
+         error_str += f"{'Ligand W':<15}: {k[1]}\n"
+         error_str += f"{'Ligand L':<15}: {k[2]}\n"
+         error_str += f"{'Pocket':<15}: {k[3]}\n"
+         error_str += f"{'Error type':<15}: {v[0]}\n"
+         error_str += f"{'Error msg':<15}: {v[1]}\n\n"
+
+     with open(Path(processed_dir, 'errors.txt'), 'w') as f:
+         f.write(error_str)
+
+     with open(Path(processed_dir, 'dataset_config.txt'), 'w') as f:
+         f.write(str(args))
+
+
+ if __name__ == '__main__':
+     main()
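
A quick note on the pair-construction rule above: `return_winning_losing_smpl` only emits a preference when the gap exceeds a criterion-specific threshold (exact match for REOS flags, 0.5 for SA, 0.1 for QED and Vina efficiency). A minimal sketch of the expected behavior, with made-up scores (assuming the function is imported from src/data/process_dpo_dataset.py):

    from src.data.process_dpo_dataset import return_winning_losing_smpl

    assert return_winning_losing_smpl(0.9, 0.5, 'medchem.qed') is True    # sample 1 wins
    assert return_winning_losing_smpl(0.55, 0.5, 'medchem.qed') is None   # |diff| < 0.1 -> no preference
    assert return_winning_losing_smpl(2.1, 3.4, 'medchem.sa') is True     # lower SA score wins
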
src/data/sanifix.py ADDED
@@ -0,0 +1,159 @@
+ """ sanifix4.py
+
+ Contribution from James Davidson
+ adapted from: https://github.com/abradle/rdkitserver/blob/master/MYSITE/src/testproject/mol_parsing/sanifix.py
+ """
+ from rdkit import Chem
+ import warnings
+
+
+ def _FragIndicesToMol(oMol, indices):
+     em = Chem.EditableMol(Chem.Mol())
+
+     newIndices = {}
+     for i, idx in enumerate(indices):
+         em.AddAtom(oMol.GetAtomWithIdx(idx))
+         newIndices[idx] = i
+
+     for i, idx in enumerate(indices):
+         at = oMol.GetAtomWithIdx(idx)
+         for bond in at.GetBonds():
+             if bond.GetBeginAtomIdx() == idx:
+                 oidx = bond.GetEndAtomIdx()
+             else:
+                 oidx = bond.GetBeginAtomIdx()
+             # make sure every bond only gets added once:
+             if oidx < idx:
+                 continue
+             em.AddBond(newIndices[idx], newIndices[oidx], bond.GetBondType())
+     res = em.GetMol()
+     res.ClearComputedProps()
+     Chem.GetSymmSSSR(res)
+     res.UpdatePropertyCache(False)
+     res._idxMap = newIndices
+     return res
+
+
+ def _recursivelyModifyNs(mol, matches, indices=None):
+     if indices is None:
+         indices = []
+     res = None
+     while len(matches) and res is None:
+         tIndices = indices[:]
+         nextIdx = matches.pop(0)
+         tIndices.append(nextIdx)
+         nm = Chem.Mol(mol)
+         nm.GetAtomWithIdx(nextIdx).SetNoImplicit(True)
+         nm.GetAtomWithIdx(nextIdx).SetNumExplicitHs(1)
+         cp = Chem.Mol(nm)
+         try:
+             Chem.SanitizeMol(cp)
+         except ValueError:
+             res, indices = _recursivelyModifyNs(nm, matches, indices=tIndices)
+         else:
+             indices = tIndices
+             res = cp
+     return res, indices
+
+
+ def AdjustAromaticNs(m, nitrogenPattern='[n&D2&H0;r5,r6]'):
+     """
+     The default nitrogen pattern matches Ns in 5- and 6-membered rings
+     in order to be able to fix: O=c1ccncc1
+     """
+     Chem.GetSymmSSSR(m)
+     m.UpdatePropertyCache(False)
+
+     # break non-ring bonds linking rings:
+     em = Chem.EditableMol(m)
+     linkers = m.GetSubstructMatches(Chem.MolFromSmarts('[r]!@[r]'))
+     plsFix = set()
+     for a, b in linkers:
+         em.RemoveBond(a, b)
+         plsFix.add(a)
+         plsFix.add(b)
+     nm = em.GetMol()
+     for at in plsFix:
+         at = nm.GetAtomWithIdx(at)
+         if at.GetIsAromatic() and at.GetAtomicNum() == 7:
+             at.SetNumExplicitHs(1)
+             at.SetNoImplicit(True)
+
+     # build molecules from the fragments:
+     fragLists = Chem.GetMolFrags(nm)
+     frags = [_FragIndicesToMol(nm, x) for x in fragLists]
+
+     # loop through the fragments in turn and try to aromatize them:
+     ok = True
+     for i, frag in enumerate(frags):
+         cp = Chem.Mol(frag)
+         try:
+             Chem.SanitizeMol(cp)
+         except ValueError:
+             matches = [x[0] for x in frag.GetSubstructMatches(Chem.MolFromSmarts(nitrogenPattern))]
+             lres, indices = _recursivelyModifyNs(frag, matches)
+             if not lres:
+                 # fragment i could not be fixed
+                 ok = False
+                 break
+             else:
+                 revMap = {}
+                 for k, v in frag._idxMap.items():
+                     revMap[v] = k
+                 for idx in indices:
+                     oatom = m.GetAtomWithIdx(revMap[idx])
+                     oatom.SetNoImplicit(True)
+                     oatom.SetNumExplicitHs(1)
+     if not ok:
+         return None
+     return m
+
+
+ def fix_mol(m):
+     if m is None:
+         return None
+     try:
+         m.UpdatePropertyCache(False)
+         cp = Chem.Mol(m.ToBinary())
+         Chem.SanitizeMol(cp)
+         m = cp
+         warnings.warn(f'fine: {Chem.MolToSmiles(m)}')
+         return m
+     except ValueError:
+         warnings.warn('adjust')
+         nm = AdjustAromaticNs(m)
+         if nm is not None:
+             try:
+                 Chem.SanitizeMol(nm)
+                 warnings.warn(f'fixed: {Chem.MolToSmiles(nm)}')
+             except ValueError:
+                 warnings.warn('still broken')
+         else:
+             warnings.warn('still broken')
+         return nm
+
+
+ if __name__ == '__main__':
+     ms = [x for x in open("18.sdf").read().split("$$$$\n")]
+     for txt_m in ms:
+         if not txt_m:
+             continue
+         m = Chem.MolFromMolBlock(txt_m, False)
+         print('#---------------------')
+         try:
+             m.UpdatePropertyCache(False)
+             cp = Chem.Mol(m.ToBinary())
+             Chem.SanitizeMol(cp)
+             m = cp
+             print('fine:', Chem.MolToSmiles(m))
+         except ValueError:
+             print('adjust')
+             nm = AdjustAromaticNs(m)
+             if nm is not None:
+                 Chem.SanitizeMol(nm)
+                 print('fixed:', Chem.MolToSmiles(nm))
+             else:
+                 print('still broken')
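
The sanifix routine above exists to rescue molecules whose aromatic nitrogens come out of generation without the explicit hydrogens RDKit needs for kekulization. A minimal usage sketch (O=c1ccncc1, the example from the AdjustAromaticNs docstring, fails default sanitization until the ring nitrogen is given an explicit H):

    from rdkit import Chem
    from src.data.sanifix import AdjustAromaticNs

    m = Chem.MolFromSmiles('O=c1ccncc1', sanitize=False)
    fixed = AdjustAromaticNs(m)
    if fixed is not None:
        Chem.SanitizeMol(fixed)
        print(Chem.MolToSmiles(fixed))  # the ring N now carries an explicit H
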
src/data/so3_utils.py ADDED
@@ -0,0 +1,450 @@
+ import math
+ import torch
+
+
+ def _batch_trace(m):
+     return torch.einsum('...ii', m)
+
+
+ def regularize(point, eps=1e-6):
+     """
+     The norm of the rotation vector should be between 0 and pi.
+     Inverts the direction of the rotation axis if the angle is between pi and 2 pi.
+     Args:
+         point, (n, 3)
+     Returns:
+         regularized point, (n, 3)
+     """
+     theta = torch.linalg.norm(point, dim=-1)
+
+     # angle in [0, 2pi)
+     theta_wrapped = theta % (2 * math.pi)
+     inv_mask = theta_wrapped > math.pi
+
+     # angle in [0, pi) & invert
+     theta_wrapped[inv_mask] = -1 * (2 * math.pi - theta_wrapped[inv_mask])
+
+     # apply
+     theta = torch.clamp(theta, min=eps)
+     point = point * (theta_wrapped / theta).unsqueeze(-1)
+     assert not point.isnan().any()
+     return point
+
+
+ def random_uniform(n_samples, device=None):
+     """
+     Follows the geomstats implementation:
+     https://geomstats.github.io/_modules/geomstats/geometry/special_orthogonal.html
+
+     Args:
+         n_samples: int
+     Returns:
+         rotation vectors, (n, 3)
+     """
+     random_point = (torch.rand(n_samples, 3, device=device) * 2 - 1) * math.pi
+     random_point = regularize(random_point)
+
+     return random_point
+
+
+ def hat(rot_vec):
+     r"""
+     Maps an R^3 vector to a skew-symmetric matrix r (i.e. r \in R^{3x3} and r^T = -r).
+     Since we have the identity hat(rot_vec) v = rot_vec x v for all v \in R^3, this is
+     identical to a cross-product-matrix representation of rot_vec.
+     See also:
+     https://en.wikipedia.org/wiki/Cross_product#Conversion_to_matrix_multiplication
+     https://en.wikipedia.org/wiki/Hat_notation#Cross_product
+     Args:
+         rot_vec: (n, 3)
+     Returns:
+         skew-symmetric matrices (n, 3, 3)
+     """
+     basis = torch.tensor([
+         [[0., 0., 0.], [0., 0., -1.], [0., 1., 0.]],
+         [[0., 0., 1.], [0., 0., 0.], [-1., 0., 0.]],
+         [[0., -1., 0.], [1., 0., 0.], [0., 0., 0.]]
+     ], device=rot_vec.device)
+
+     return torch.einsum('...i,ijk->...jk', rot_vec, basis)
+
+
+ def inv_hat(skew_mat):
+     """
+     Inverse of the hat operation.
+     Args:
+         skew_mat: skew-symmetric matrices (n, 3, 3)
+     Returns:
+         rotation vectors, (n, 3)
+     """
+     assert torch.allclose(-skew_mat, skew_mat.transpose(-2, -1), atol=1e-4), \
+         f"Input not skew-symmetric (err={(-skew_mat - skew_mat.transpose(-2, -1)).abs().max():.4g})"
+
+     vec = torch.stack([
+         skew_mat[:, 2, 1],
+         skew_mat[:, 0, 2],
+         skew_mat[:, 1, 0]
+     ], dim=1)
+
+     return vec
+
+
+ def matrix_from_rotation_vector(axis_angle, eps=1e-6):
+     """
+     Args:
+         axis_angle: (n, 3)
+     Returns:
+         rotation matrices, (n, 3, 3)
+     """
+     axis_angle = regularize(axis_angle)
+     angle = axis_angle.norm(dim=-1)
+     _norm = torch.clamp(angle, min=eps).unsqueeze(-1)
+     skew_mat = hat(axis_angle / _norm)
+
+     # Rodrigues' rotation formula:
+     # https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula#Matrix_notation
+     _id = torch.eye(3, device=axis_angle.device).unsqueeze(0)
+     rot_mat = _id + \
+         torch.sin(angle)[:, None, None] * skew_mat + \
+         (1 - torch.cos(angle))[:, None, None] * torch.bmm(skew_mat, skew_mat)
+
+     return rot_mat
+
+
+ class safe_acos(torch.autograd.Function):
+     """
+     Implementation of arccos that avoids NaNs in the backward pass.
+     https://github.com/pytorch/pytorch/issues/8069#issuecomment-2041223872
+     """
+     EPS = 1e-4
+
+     @classmethod
+     def d_acos_dx(cls, x):
+         x = torch.clamp(x, min=-1. + cls.EPS, max=1. - cls.EPS)
+         return -1.0 / (1 - x**2).sqrt()
+
+     @staticmethod
+     def forward(ctx, input):
+         ctx.save_for_backward(input)
+         return input.acos()
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         input, = ctx.saved_tensors
+         return grad_output * safe_acos.d_acos_dx(input)
+
+
+ def rotation_vector_from_matrix(rot_mat, approx=1e-4):
+     """
+     Args:
+         rot_mat: (n, 3, 3)
+         approx: float, minimum angle below which an approximation will be used
+             for numerical stability
+     Returns:
+         rotation vector, (n, 3)
+     """
+     # https://en.wikipedia.org/wiki/Rotation_matrix#Conversion_from_rotation_matrix_to_axis%E2%80%93angle
+     # https://en.wikipedia.org/wiki/Axis%E2%80%93angle_representation#Log_map_from_SO(3)_to_%F0%9D%94%B0%F0%9D%94%AC(3)
+
+     # determine the axis
+     skew_mat = rot_mat - rot_mat.transpose(-2, -1)
+
+     # determine the angle
+     cos_angle = 0.5 * (_batch_trace(rot_mat) - 1)
+     # arccos is only defined between -1 and 1
+     assert torch.all(cos_angle.abs() <= 1 + 1e-6)
+     cos_angle = torch.clamp(cos_angle, min=-1., max=1.)
+     abs_angle = safe_acos.apply(cos_angle)
+
+     # avoid numerical instability; use sin(x) \approx x for small x
+     close_to_0 = abs_angle < approx
+     _fac = torch.empty_like(abs_angle)
+     _fac[close_to_0] = 0.5
+     _fac[~close_to_0] = 0.5 * abs_angle[~close_to_0] / torch.sin(abs_angle[~close_to_0])
+
+     axis_angle = inv_hat(_fac[:, None, None] * skew_mat)
+     return regularize(axis_angle)
+
+
+ def get_jacobian(point, left=True, inverse=False, eps=1e-4):
+     # Mirrors geomstats' jacobian_translation(point, left), optionally inverted:
+     # https://geomstats.github.io/_modules/geomstats/geometry/special_orthogonal.html
+     #
+     # Right Jacobian defined as J_r(theta) = \partial exp([theta]_x) / \partial theta
+     # https://math.stackexchange.com/questions/301533/jacobian-involving-so3-exponential-map-logr-expm
+     # Source:
+     # Chirikjian, Gregory S. Stochastic models, information theory, and Lie
+     # groups, volume 2: Analytic methods and modern applications. Vol. 2.
+     # Springer Science & Business Media, 2011. (page 40)
+     # NOTE: the definitions of 'inverse' and 'left' in the book are the opposite
+     # of their meanings in geomstats, whose functionality we're mimicking here.
+     # This explains the differences in the equations.
+     angle_squared = point.square().sum(-1)
+     angle = angle_squared.sqrt()
+     skew_mat = hat(point)
+
+     assert torch.all(angle <= math.pi)
+     close_to_0 = angle < eps
+     close_to_pi = (math.pi - angle) < eps
+
+     angle = angle[:, None, None]
+     angle_squared = angle_squared[:, None, None]
+
+     if inverse:
+         # J = I + (1 - cos(theta)) / theta^2 * [theta]_x
+         #       + (theta - sin(theta)) / theta^3 * [theta]_x^2
+         # (the right-hand sides are masked so that the shapes match on assignment)
+         _term1 = torch.empty_like(angle)
+         _term1[close_to_0] = 0.5  # approximate with the value at zero
+         _term1[~close_to_0] = (1 - torch.cos(angle[~close_to_0])) / angle_squared[~close_to_0]
+
+         _term2 = torch.empty_like(angle)
+         _term2[close_to_0] = 1 / 6  # approximate with the value at zero
+         _term2[~close_to_0] = (angle[~close_to_0] - torch.sin(angle[~close_to_0])) / angle[~close_to_0] ** 3
+
+         jacobian = torch.eye(3, device=point.device).unsqueeze(0) + \
+             _term1 * skew_mat + _term2 * (skew_mat @ skew_mat)
+     else:
+         # J = I - 0.5 * [theta]_x
+         #       + (1 / theta^2 - (1 + cos(theta)) / (2 theta sin(theta))) * [theta]_x^2
+         _term1 = torch.empty_like(angle)
+         _term1[close_to_0] = 1 / 12  # approximate with the value at zero
+         _term1[close_to_pi] = 1 / math.pi**2  # approximate with the value at pi
+         default = ~close_to_0 & ~close_to_pi
+         _term1[default] = 1 / angle_squared[default] - \
+             (1 + torch.cos(angle[default])) / (2 * angle[default] * torch.sin(angle[default]))
+
+         jacobian = torch.eye(3, device=point.device).unsqueeze(0) - \
+             0.5 * skew_mat + _term1 * (skew_mat @ skew_mat)
+
+     if left:
+         jacobian = jacobian.transpose(-2, -1)
+
+     return jacobian
+
+
+ def compose_rotations(rot_vec_1, rot_vec_2):
+     rot_mat_1 = matrix_from_rotation_vector(rot_vec_1)
+     rot_mat_2 = matrix_from_rotation_vector(rot_vec_2)
+     rot_mat_out = torch.bmm(rot_mat_1, rot_mat_2)
+     return rotation_vector_from_matrix(rot_mat_out)
+
+
+ def exp(tangent):
+     """
+     Exponential map at the identity.
+     Args:
+         tangent: vector on the tangent space, (n, 3)
+     Returns:
+         rotation vector on the manifold, (n, 3)
+     """
+     # rotations are already represented by rotation vectors
+     exp_from_identity = regularize(tangent)
+     return exp_from_identity
+
+
+ def exp_not_from_identity(tangent_vec, base_point):
+     """
+     Exponential map at a base point.
+     Args:
+         tangent_vec: vector on the tangent plane, (n, 3)
+         base_point: base point on the manifold, (n, 3)
+     Returns:
+         new point on the manifold, (n, 3)
+     """
+     tangent_vec = regularize(tangent_vec)
+     base_point = regularize(base_point)
+
+     # the Lie algebra is the tangent space at the identity element of a Lie group
+     # -> move to the identity
+     jacobian = get_jacobian(base_point, left=True, inverse=True)
+     tangent_vec_at_id = torch.einsum("...ij,...j->...i", jacobian, tangent_vec)
+
+     # exponential map from the identity
+     exp_from_identity = exp(tangent_vec_at_id)
+
+     # -> back to the base point
+     return compose_rotations(base_point, exp_from_identity)
+
+
+ def log(rot_vec, as_skew=False):
+     """
+     Logarithm map to the tangent space at the identity.
+     Args:
+         rot_vec: point on the manifold, (n, 3)
+     Returns:
+         vector on the tangent space, (n, 3)
+     """
+     # rotations are already represented by rotation vectors
+     log_from_id = rot_vec
+     if as_skew:
+         log_from_id = hat(log_from_id)
+     return log_from_id
+
+
+ def log_not_from_identity(point, base_point):
+     """
+     Logarithm map of a point from a base point.
+     Args:
+         point: point on the manifold, (n, 3)
+         base_point: base point on the manifold, (n, 3)
+     Returns:
+         vector on the tangent plane, (n, 3)
+     """
+     point = regularize(point)
+     base_point = regularize(base_point)
+
+     inv_base_point = -1 * base_point
+
+     point_near_id = compose_rotations(inv_base_point, point)
+
+     # logarithm map from the identity
+     log_from_id = log(point_near_id)
+
+     jacobian = get_jacobian(base_point, inverse=False)
+     tangent_vec_at_id = torch.einsum("...ij,...j->...i", jacobian, log_from_id)
+
+     return tangent_vec_at_id
+
+
+ if __name__ == "__main__":
+
+     import os
+     os.environ['GEOMSTATS_BACKEND'] = "pytorch"
+     import scipy.optimize  # does not seem to be imported correctly when just loading geomstats
+     default_dtype = torch.get_default_dtype()
+     from geomstats.geometry.special_orthogonal import SpecialOrthogonal
+     torch.set_default_dtype(default_dtype)  # geomstats changes the default dtype when imported
+
+     so3_vector = SpecialOrthogonal(n=3, point_type="vector")
+
+     # decorator that makes geomstats run with the correct device/tensor type
+     if torch.__version__ >= '2.0.0':
+         GEOMSTATS_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+         def geomstats_tensor_type(func):
+             def inner(*args, **kwargs):
+                 with torch.device(GEOMSTATS_DEVICE):
+                     out = func(*args, **kwargs)
+                 return out
+
+             return inner
+     else:
+         GEOMSTATS_TENSOR_TYPE = 'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor'
+
+         def geomstats_tensor_type(func):
+             def inner(*args, **kwargs):
+                 torch.set_default_tensor_type(GEOMSTATS_TENSOR_TYPE)
+                 out = func(*args, **kwargs)
+                 torch.set_default_tensor_type('torch.FloatTensor')
+                 return out
+
+             return inner
+
+     @geomstats_tensor_type
+     def gs_matrix_from_rotation_vector(*args, **kwargs):
+         return so3_vector.matrix_from_rotation_vector(*args, **kwargs)
+
+     @geomstats_tensor_type
+     def gs_rotation_vector_from_matrix(*args, **kwargs):
+         return so3_vector.rotation_vector_from_matrix(*args, **kwargs)
+
+     @geomstats_tensor_type
+     def gs_exp_not_from_identity(*args, **kwargs):
+         return so3_vector.exp_not_from_identity(*args, **kwargs)
+
+     @geomstats_tensor_type
+     def gs_log_not_from_identity(*args, **kwargs):
+         # the norm of the rotation vector will be between 0 and pi
+         return so3_vector.log_not_from_identity(*args, **kwargs)
+
+     @geomstats_tensor_type
+     def compose(*args, **kwargs):
+         return so3_vector.compose(*args, **kwargs)
+
+     @geomstats_tensor_type
+     def inverse(*args, **kwargs):
+         return so3_vector.inverse(*args, **kwargs)
+
+     @geomstats_tensor_type
+     def gs_random_uniform(*args, **kwargs):
+         return so3_vector.random_uniform(*args, **kwargs)
+
+     #############
+     # RUN TESTS #
+     #############
+
+     n = 16
+     device = 'cuda' if torch.cuda.is_available() else None
+
+     ### regularize ###
+     vec = (torch.rand(n, 3) * 4 - 2) * math.pi
+     axis_angle = regularize(vec)
+     assert torch.all(torch.cross(vec, axis_angle, dim=-1).norm(dim=-1) < 1e-5), "not all vectors collinear"
+     assert torch.all(axis_angle.norm(dim=-1) < math.pi) & torch.all(axis_angle.norm(dim=-1) >= 0), "norm not between 0 and pi"
+
+     ### matrix_from_rotation_vector ###
+     rot_vec = random_uniform(n, device=device)
+     assert torch.allclose(matrix_from_rotation_vector(rot_vec),
+                           gs_matrix_from_rotation_vector(rot_vec), atol=1e-06)
+
+     ### rotation_vector_from_matrix ###
+     rot_vec = random_uniform(n, device=device)
+     rot_mat = matrix_from_rotation_vector(rot_vec)
+     assert torch.allclose(rotation_vector_from_matrix(rot_mat),
+                           gs_rotation_vector_from_matrix(rot_mat), atol=1e-05)
+
+     ### exp_not_from_identity ###
+     tangent_vec = random_uniform(n, device=device)
+     base_pt = random_uniform(n, device=device)
+     my_val = exp_not_from_identity(tangent_vec, base_pt)
+     gs_val = gs_exp_not_from_identity(tangent_vec, base_pt)
+     assert torch.allclose(my_val, gs_val, atol=1e-03), (my_val - gs_val).abs().max()
+
+     ### log_not_from_identity ###
+     pt = random_uniform(n, device=device)
+     base_pt = random_uniform(n, device=device)
+     my_val = log_not_from_identity(pt, base_pt)
+     gs_val = gs_log_not_from_identity(pt, base_pt)
+     assert torch.allclose(my_val, gs_val, atol=1e-03), (my_val - gs_val).abs().max()
+
+     print("All tests successful!")
src/default/size_distribution.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d4e677a30c4b972051499bb5577a0de773e4f92ec54c282d432f94873406ec7e
+ size 158488
src/generate.py ADDED
@@ -0,0 +1,204 @@
+ import argparse
+ import sys
+ import warnings
+ import tempfile
+ import pandas as pd
+
+ from Bio.PDB import PDBParser
+ from pathlib import Path
+ from rdkit import Chem
+ from torch.utils.data import DataLoader
+ from functools import partial
+
+ basedir = Path(__file__).resolve().parent.parent
+ sys.path.append(str(basedir))
+ warnings.filterwarnings("ignore")
+
+ from src import utils
+ from src.data.dataset import ProcessedLigandPocketDataset
+ from src.data.data_utils import TensorDict, process_raw_pair
+ from src.model.lightning import DrugFlow
+ from src.sbdd_metrics.metrics import FullEvaluator
+
+ from tqdm import tqdm
+
+
+ def aggregate_metrics(table):
+     """Add one aggregate column per metric family, averaging all of its sub-columns."""
+     families = [
+         ('posebusters', None),
+         ('reos', None),
+         ('chembl_ring_systems', 'smi'),  # skip the raw SMILES columns
+     ]
+     for agg_col, exclude_suffix in families:
+         total = 0
+         table[agg_col] = 0
+         for column in list(table.columns):
+             if not column.startswith(agg_col) or column == agg_col:
+                 continue
+             if exclude_suffix is not None and column.endswith(exclude_suffix):
+                 continue
+             table[agg_col] += table[column].fillna(0).astype(float)
+             total += 1
+         table[agg_col] = table[agg_col] / total
+     return table
+
+
+ if __name__ == "__main__":
+     p = argparse.ArgumentParser()
+     p.add_argument('--protein', type=str, required=True, help="Input PDB file.")
+     p.add_argument('--ref_ligand', type=str, required=True, help="SDF file with the reference ligand used to define the pocket.")
+     p.add_argument('--checkpoint', type=str, required=True, help="Model checkpoint file.")
+     p.add_argument('--molecule_size', type=str, required=False, default=None, help="Number of atoms in the sampled molecules. Can be a single number or a range, e.g. '15,20'. If None, sizes will be sampled.")
+     p.add_argument('--output', type=str, required=False, default='samples.sdf', help="Output file.")
+     p.add_argument('--n_samples', type=int, required=False, default=10, help="Number of sampled molecules.")
+     p.add_argument('--batch_size', type=int, required=False, default=32, help="Batch size.")
+     p.add_argument('--pocket_distance_cutoff', type=float, required=False, default=8.0, help="Distance cutoff to define the pocket around the reference ligand.")
+     p.add_argument('--n_steps', type=int, required=False, default=None, help="Number of denoising steps.")
+     p.add_argument('--device', type=str, required=False, default='cuda:0', help="Device to use.")
+     p.add_argument('--datadir', type=Path, required=False, default=Path(basedir, 'src', 'default'), help="Needs to be specified to sample molecule sizes.")
+     p.add_argument('--seed', type=int, required=False, default=42, help="Random seed.")
+     p.add_argument('--filter', action='store_true', required=False, default=False, help="Apply basic filters and keep sampling until `n_samples` molecules passing these filters are found.")
+     p.add_argument('--metrics_output', type=str, required=False, default=None, help="If provided, metrics will be computed and saved in csv format at this location.")
+     p.add_argument('--gnina', type=str, required=False, default=None, help="Path to a gnina executable. Required for computing docking scores.")
+     p.add_argument('--reduce', type=str, required=False, default=None, help="Path to a reduce executable. Required for computing interactions.")
+     args = p.parse_args()
+
+     utils.set_deterministic(seed=args.seed)
+     utils.disable_rdkit_logging()
+
+     if args.molecule_size is None and (args.datadir is None or not args.datadir.exists()):
+         raise NotImplementedError(
+             "Please provide a path to the processed dataset (using `--datadir`) "
+             "to infer the number of nodes. It contains the size distribution histogram."
+         )
+
+     if not args.filter:
+         args.batch_size = min(args.batch_size, args.n_samples)
+
+     # Load the model
+     model = DrugFlow.load_from_checkpoint(args.checkpoint, map_location=args.device, strict=False)
+     if args.datadir is not None:
+         model.datadir = args.datadir
+
+     model.setup(stage='generation')
+     model.batch_size = model.eval_batch_size = args.batch_size
+     model.eval().to(args.device)
+     if args.n_steps is not None:
+         model.T = args.n_steps
+
+     # Determine the molecule size (a fixed number, a uniform range, or sampled by the model)
+     molecule_size = None
+     if args.molecule_size is not None:
+         if args.molecule_size.isdigit():
+             molecule_size = int(args.molecule_size)
+             print(f'Will generate molecules of size {molecule_size}')
+         else:
+             boundaries = [x.strip() for x in args.molecule_size.split(',')]
+             assert len(boundaries) == 2 and boundaries[0].isdigit() and boundaries[1].isdigit()
+             left = int(boundaries[0])
+             right = int(boundaries[1])
+             molecule_size = f"uniform_{left}_{right}"
+             print(f'Will generate molecules with numbers of atoms sampled from U({left}, {right})')
+
+     # Prepare the input
+     pdb_model = PDBParser(QUIET=True).get_structure('', args.protein)[0]
+     rdmol = Chem.SDMolSupplier(str(args.ref_ligand))[0]
+
+     ligand, pocket = process_raw_pair(
+         pdb_model, rdmol,
+         dist_cutoff=args.pocket_distance_cutoff,
+         pocket_representation=model.pocket_representation,
+         compute_nerf_params=True,
+         nma_input=args.protein if model.dynamics.add_nma_feat else None
+     )
+     ligand['name'] = 'ligand'
+     dataset = [{'ligand': ligand, 'pocket': pocket} for _ in range(args.batch_size)]
+     dataloader = DataLoader(
+         dataset=dataset,
+         batch_size=args.batch_size,
+         collate_fn=partial(ProcessedLigandPocketDataset.collate_fn, ligand_transform=None),
+         pin_memory=True
+     )
+
+     # Start sampling
+     smiles = set()
+     sampled_molecules = []
+     metrics = []
+     Path(args.output).parent.absolute().mkdir(parents=True, exist_ok=True)
+     print(f'Will generate {args.n_samples} samples')
+
+     evaluator = FullEvaluator(gnina=args.gnina, reduce=args.reduce)
+
+     with tqdm(total=args.n_samples) as pbar:
+         while len(sampled_molecules) < args.n_samples:
+             for data in dataloader:
+                 new_data = {
+                     'ligand': TensorDict(**data['ligand']).to(args.device),
+                     'pocket': TensorDict(**data['pocket']).to(args.device),
+                 }
+                 rdmols, rdpockets, _ = model.sample(
+                     new_data,
+                     n_samples=1,
+                     timesteps=args.n_steps,
+                     num_nodes=molecule_size,
+                 )
+
+                 if args.filter or (args.metrics_output is not None):
+                     results = []
+                     with tempfile.TemporaryDirectory() as tmpdir:
+                         for mol, receptor in zip(rdmols, rdpockets):
+                             receptor_path = Path(tmpdir, 'receptor.pdb')
+                             Chem.MolToPDBFile(receptor, str(receptor_path))
+                             results.append(evaluator(mol, receptor_path))
+
+                     table = pd.DataFrame(results)
+                     table['novel'] = ~table['representation.smiles'].isin(smiles)
+                     table = aggregate_metrics(table)
+
+                 added_molecules = 0
+                 if args.filter:
+                     table['passed_filters'] = (
+                         (table['posebusters'] == 1) &
+                         # (table['reos'] == 1) &
+                         (table['chembl_ring_systems'] == 1) &
+                         (table['novel'] == 1)
+                     )
+                     for i, (passed, smi) in enumerate(table[['passed_filters', 'representation.smiles']].values):
+                         if passed:
+                             sampled_molecules.append(rdmols[i])
+                             smiles.add(smi)
+                             added_molecules += 1
+
+                     if args.metrics_output is not None:
+                         metrics.append(table[table['passed_filters']])
+
+                 else:
+                     sampled_molecules.extend(rdmols)
+                     added_molecules = len(rdmols)
+                     if args.metrics_output is not None:
+                         metrics.append(table)
+
+                 pbar.update(added_molecules)
+
+     # Write the results
+     utils.write_sdf_file(args.output, sampled_molecules)
+
+     if args.metrics_output is not None:
+         metrics = pd.concat(metrics)
+         metrics.to_csv(args.metrics_output, index=False)
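
For reference, this is roughly what the filtering above sees after aggregation (the column names below are illustrative; the real FullEvaluator emits many more per family). aggregate_metrics adds one column per metric family holding the fraction of its checks that passed:

    import pandas as pd

    df = pd.DataFrame([{
        'posebusters.sanitization': True,
        'posebusters.all_atoms_connected': False,
        'reos.pains': True,
        'chembl_ring_systems.ring_in_chembl': True,
    }])
    df = aggregate_metrics(df)
    print(df['posebusters'].iloc[0])  # 0.5 -> half of the PoseBusters checks passed
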
src/model/diffusion_utils.py ADDED
@@ -0,0 +1,206 @@
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+
+
+ class DistributionNodes:
+     def __init__(self, histogram):
+
+         histogram = torch.tensor(histogram).float()
+         histogram = histogram + 1e-3  # for numerical stability
+
+         prob = histogram / histogram.sum()
+
+         self.idx_to_n_nodes = torch.tensor(
+             [[(i, j) for j in range(prob.shape[1])] for i in range(prob.shape[0])]
+         ).view(-1, 2)
+
+         self.n_nodes_to_idx = {tuple(x.tolist()): i
+                                for i, x in enumerate(self.idx_to_n_nodes)}
+
+         self.prob = prob
+         self.m = torch.distributions.Categorical(self.prob.view(-1),
+                                                  validate_args=True)
+
+         self.n1_given_n2 = \
+             [torch.distributions.Categorical(prob[:, j], validate_args=True)
+              for j in range(prob.shape[1])]
+         self.n2_given_n1 = \
+             [torch.distributions.Categorical(prob[i, :], validate_args=True)
+              for i in range(prob.shape[0])]
+
+     def sample(self, n_samples=1):
+         idx = self.m.sample((n_samples,))
+         num_nodes_lig, num_nodes_pocket = self.idx_to_n_nodes[idx].T
+         return num_nodes_lig, num_nodes_pocket
+
+     def sample_conditional(self, n1=None, n2=None):
+         assert (n1 is None) ^ (n2 is None), \
+             "Exactly one input argument must be None"
+
+         m = self.n1_given_n2 if n2 is not None else self.n2_given_n1
+         c = n2 if n2 is not None else n1
+
+         return torch.tensor([m[i].sample() for i in c], device=c.device)
+
+     def log_prob(self, batch_n_nodes_1, batch_n_nodes_2):
+         assert len(batch_n_nodes_1.size()) == 1
+         assert len(batch_n_nodes_2.size()) == 1
+
+         idx = torch.tensor(
+             [self.n_nodes_to_idx[(n1, n2)]
+              for n1, n2 in zip(batch_n_nodes_1.tolist(), batch_n_nodes_2.tolist())]
+         )
+
+         log_probs = self.m.log_prob(idx)
+
+         return log_probs.to(batch_n_nodes_1.device)
+
+     def log_prob_n1_given_n2(self, n1, n2):
+         assert len(n1.size()) == 1
+         assert len(n2.size()) == 1
+         log_probs = torch.stack([self.n1_given_n2[c].log_prob(i.cpu())
+                                  for i, c in zip(n1, n2)])
+         return log_probs.to(n1.device)
+
+     def log_prob_n2_given_n1(self, n2, n1):
+         assert len(n2.size()) == 1
+         assert len(n1.size()) == 1
+         log_probs = torch.stack([self.n2_given_n1[c].log_prob(i.cpu())
+                                  for i, c in zip(n2, n1)])
+         return log_probs.to(n2.device)
+
+
+ def cosine_beta_schedule_midi(timesteps, s=0.008, nu=1.0, clip=False):
+     r"""
+     Modified cosine schedule as proposed in https://arxiv.org/abs/2302.09048.
+     Note: we use (t/T)^\nu, not (t/T + s)^\nu as written in the MiDi paper.
+     We also divide by alphas_cumprod[0] as in the original cosine schedule from
+     https://arxiv.org/abs/2102.09672
+     """
+     device = nu.device if torch.is_tensor(nu) else None
+     x = torch.linspace(0, timesteps, timesteps + 1, device=device)
+     alphas_cumprod = torch.cos(0.5 * np.pi * ((x / timesteps)**nu + s) / (1 + s)) ** 2
+     alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+
+     if clip:
+         alphas_cumprod = torch.cat([torch.tensor([1.0], device=alphas_cumprod.device), alphas_cumprod])
+         betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+         betas = torch.clip(betas, min=0, max=0.999)
+         alphas = 1. - betas
+         alphas_cumprod = torch.cumprod(alphas, dim=0)
+     return alphas_cumprod
+
+
+ class CosineSchedule(torch.nn.Module):
+     """
+     nu=1.0 corresponds to the standard cosine schedule.
+     """
+
+     def __init__(self, timesteps, nu=1.0, trainable=False, clip_alpha2_step=0.001):
+         super(CosineSchedule, self).__init__()
+         self.timesteps = timesteps
+         self.trainable = trainable
+         self.nu = nu
+         assert 0.0 <= clip_alpha2_step < 1.0
+         self.clip = clip_alpha2_step
+
+         if self.trainable:
+             self.nu = torch.nn.Parameter(torch.Tensor([nu]), requires_grad=True)
+         else:
+             self._alpha2 = self.alphas2
+             self._gamma = torch.nn.Parameter(self.gammas, requires_grad=False)
+
+     @property
+     def alphas2(self):
+         """
+         Cumulative alpha squared.
+         Called alpha_bar in: Nichol, Alexander Quinn, and Prafulla Dhariwal.
+         "Improved denoising diffusion probabilistic models." PMLR, 2021.
+         """
+         if hasattr(self, '_alpha2'):
+             return self._alpha2
+
+         assert isinstance(self.nu, float) or not self.nu.isnan().any()
+
+         # our alpha is equivalent to sqrt(alpha) from https://arxiv.org/abs/2102.09672,
+         # where the cosine schedule was introduced
+         alphas2 = cosine_beta_schedule_midi(self.timesteps, nu=self.nu, clip=False)
+
+         # avoid singularities near t=T
+         alphas2 = torch.cat([torch.tensor([1.0], device=alphas2.device), alphas2])
+         alphas2_step = alphas2[1:] / alphas2[:-1]
+         alphas2_step = torch.clip(alphas2_step, min=self.clip, max=1.0)
+         alphas2 = torch.cumprod(alphas2_step, dim=0)
+
+         return alphas2
+
+     @property
+     def alphas2_t_given_tminus1(self):
+         """
+         Alphas for a single transition.
+         """
+         alphas2 = torch.cat([torch.tensor([1.0]), self.alphas2])
+         return alphas2[1:] / alphas2[:-1]
+
+     @property
+     def gammas(self):
+         """
+         Gammas as defined in appendix B of the EDM paper:
+         gamma_t = -(log alpha_t^2 - log sigma_t^2)
+         """
+         if hasattr(self, '_gamma'):
+             return self._gamma
+
+         alphas2 = self.alphas2
+         sigmas2 = 1 - alphas2
+
+         gammas = -(torch.log(alphas2) - torch.log(sigmas2))
+
+         return gammas.float()
+
+     def forward(self, t):
+         t_int = torch.round(t * self.timesteps).long()
+         return self.gammas[t_int]
+
+     @staticmethod
+     def alpha(gamma):
+         """ Computes alpha given gamma. """
+         return torch.sqrt(torch.sigmoid(-gamma))
+
+     @staticmethod
+     def sigma(gamma):
+         """ Computes sigma given gamma. """
+         return torch.sqrt(torch.sigmoid(gamma))
+
+     @staticmethod
+     def SNR(gamma):
+         """ Computes the signal-to-noise ratio (alpha^2/sigma^2) given gamma. """
+         return torch.exp(-gamma)
+
+     def sigma_and_alpha_t_given_s(self, gamma_t: torch.Tensor, gamma_s: torch.Tensor):
+         """
+         Computes sigma_t_given_s, using gamma_t and gamma_s. Used during sampling.
+         These are defined as:
+             alpha_t_given_s = alpha_t / alpha_s,
+             sigma_t_given_s = sqrt(1 - (alpha_t_given_s)^2).
+         """
+         sigma2_t_given_s = -torch.expm1(
+             F.softplus(gamma_s) - F.softplus(gamma_t))
+
+         # alpha_t_given_s = alpha_t / alpha_s
+         log_alpha2_t = F.logsigmoid(-gamma_t)
+         log_alpha2_s = F.logsigmoid(-gamma_s)
+         log_alpha2_t_given_s = log_alpha2_t - log_alpha2_s
+
+         alpha_t_given_s = torch.exp(0.5 * log_alpha2_t_given_s)
+         alpha_t_given_s = torch.clip(alpha_t_given_s, min=self.clip ** 0.5, max=1.0)
+
+         sigma_t_given_s = torch.sqrt(sigma2_t_given_s)
+
+         return sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s
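
A minimal sanity-check sketch for the schedule above (alpha(gamma)^2 + sigma(gamma)^2 = 1 holds by construction, since sigmoid(-g) + sigmoid(g) = 1, and gamma grows monotonically with t as the noise level increases):

    import torch

    schedule = CosineSchedule(timesteps=500, nu=1.0)
    t = torch.linspace(0.1, 1.0, 10)  # avoid t=0, where sigma is exactly zero
    gamma = schedule(t)
    a, s = CosineSchedule.alpha(gamma), CosineSchedule.sigma(gamma)
    assert torch.allclose(a ** 2 + s ** 2, torch.ones_like(a))
    assert torch.all(gamma[1:] >= gamma[:-1])  # noise level never decreases
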
src/model/dpo.py ADDED
@@ -0,0 +1,252 @@
+ from typing import Optional
+ from pathlib import Path
+ from contextlib import nullcontext
+
+ import torch
+ import torch.nn.functional as F
+ from torch_scatter import scatter_mean
+
+ from src.constants import atom_encoder, bond_encoder
+ from src.model.lightning import DrugFlow, set_default
+ from src.data.dataset import ProcessedLigandPocketDataset, DPODataset
+ from src.data.data_utils import AppendVirtualNodesInCoM, Residues, center_data
+
+
+ class DPO(DrugFlow):
+     def __init__(self, dpo_mode, ref_checkpoint_p, **kwargs):
+         super(DPO, self).__init__(**kwargs)
+         self.dpo_mode = dpo_mode
+         self.dpo_beta = kwargs['loss_params'].dpo_beta if 'dpo_beta' in kwargs['loss_params'] else 100.0
+         self.dpo_beta_schedule = kwargs['loss_params'].dpo_beta_schedule if 'dpo_beta_schedule' in kwargs['loss_params'] else 't'
+         self.clamp_dpo = kwargs['loss_params'].clamp_dpo if 'clamp_dpo' in kwargs['loss_params'] else True
+         self.dpo_lambda_dpo = kwargs['loss_params'].dpo_lambda_dpo if 'dpo_lambda_dpo' in kwargs['loss_params'] else 1
+         self.dpo_lambda_w = kwargs['loss_params'].dpo_lambda_w if 'dpo_lambda_w' in kwargs['loss_params'] else 1
+         self.dpo_lambda_l = kwargs['loss_params'].dpo_lambda_l if 'dpo_lambda_l' in kwargs['loss_params'] else 0.2
+         self.dpo_lambda_h = kwargs['loss_params'].dpo_lambda_h if 'dpo_lambda_h' in kwargs['loss_params'] else kwargs['loss_params'].lambda_h
+         self.dpo_lambda_e = kwargs['loss_params'].dpo_lambda_e if 'dpo_lambda_e' in kwargs['loss_params'] else kwargs['loss_params'].lambda_e
+
+         # frozen reference model for the DPO objective
+         self.ref_dynamics = self.init_model(kwargs['predictor_params'])
+         state_dict = torch.load(ref_checkpoint_p)['state_dict']
+         self.ref_dynamics.load_state_dict({k.replace('dynamics.', ''): v for k, v in state_dict.items() if k.startswith('dynamics.')})
+         print(f'Loaded reference model from {ref_checkpoint_p}')
+         # initialize the trainable model with the reference model parameters
+         self.dynamics.load_state_dict(self.ref_dynamics.state_dict())
+
+     def get_dataset(self, stage, pocket_transform=None):
+
+         # when sampling, we don't append virtual nodes as we might need access to the ground truth size
+         if self.virtual_nodes and stage == 'train':
+             ligand_transform = AppendVirtualNodesInCoM(
+                 atom_encoder, bond_encoder, add_min=self.add_virtual_min, add_max=self.add_virtual_max)
+         else:
+             ligand_transform = None
+
+         # we want to know if something goes wrong on the validation or test set
+         catch_errors = stage == 'train'
+
+         if self.sharded_dataset:
+             raise NotImplementedError('Sharded dataset not implemented for DPO')
+
+         if self.sample_from_clusters and stage == 'train':  # val/test should be deterministic
+             raise NotImplementedError('Sampling from clusters not implemented for DPO')
+
+         if stage == 'train':
+             return DPODataset(
+                 Path(self.datadir, 'train.pt'),
+                 ligand_transform=None,
+                 pocket_transform=pocket_transform,
+                 catch_errors=True,
+             )
+         else:
+             return ProcessedLigandPocketDataset(
+                 pt_path=Path(self.datadir, 'val.pt' if self.debug else f'{stage}.pt'),
+                 ligand_transform=ligand_transform,
+                 pocket_transform=pocket_transform,
+                 catch_errors=catch_errors,
+             )
+
+     def training_step(self, data, *args):
+         ligand_w, ligand_l, pocket = data['ligand'], data['ligand_l'], data['pocket']
+         loss, info = self.compute_dpo_loss(pocket, ligand_w=ligand_w, ligand_l=ligand_l, return_info=True)
+
+         if torch.isnan(loss):
+             print(f'Loss is NaN at epoch {self.current_epoch}. Info: {info}')
+
+         log_dict = {k: v for k, v in info.items() if isinstance(v, float) or torch.numel(v) <= 1}
+         self.log_metrics({'loss': loss, **log_dict}, 'train', batch_size=len(ligand_w['size']))
+
+         out = {'loss': loss, **info}
+         self.training_step_outputs.append(out)
+         return out
+
+     def validation_step(self, data, *args):
+         return super().validation_step(data, *args)
+
+     def compute_dpo_loss(self, pocket, ligand_w, ligand_l, return_info=False):
+         t = torch.rand(ligand_w['size'].size(0), device=ligand_w['x'].device).unsqueeze(-1)
+
+         if self.dpo_beta_schedule == 't':
+             # from https://arxiv.org/pdf/2407.13981
+             beta_t = (self.dpo_beta * t).squeeze()
+         elif self.dpo_beta_schedule == 'const':
+             beta_t = self.dpo_beta
+         else:
+             raise ValueError(f'Unknown DPO beta schedule: {self.dpo_beta_schedule}')
+
+         loss_dict_w = self.compute_loss_single_pair(ligand_w, pocket, t)
+         loss_dict_l = self.compute_loss_single_pair(ligand_l, pocket, t)
+         info = {
+             'loss_x_w': loss_dict_w['theta']['x'].mean().item(),
+             'loss_h_w': loss_dict_w['theta']['h'].mean().item(),
+             'loss_e_w': loss_dict_w['theta']['e'].mean().item(),
+             'loss_x_l': loss_dict_l['theta']['x'].mean().item(),
+             'loss_h_l': loss_dict_l['theta']['h'].mean().item(),
+             'loss_e_l': loss_dict_l['theta']['e'].mean().item(),
+         }
+         if self.dpo_mode == 'single_dpo_comp':
+             loss_w_theta = (
+                 loss_dict_w['theta']['x'] +
+                 self.dpo_lambda_h * loss_dict_w['theta']['h'] +
+                 self.dpo_lambda_e * loss_dict_w['theta']['e']
+             )
+             loss_w_ref = (
+                 loss_dict_w['ref']['x'] +
+                 self.dpo_lambda_h * loss_dict_w['ref']['h'] +
+                 self.dpo_lambda_e * loss_dict_w['ref']['e']
+             )
+             loss_l_theta = (
+                 loss_dict_l['theta']['x'] +
+                 self.dpo_lambda_h * loss_dict_l['theta']['h'] +
+                 self.dpo_lambda_e * loss_dict_l['theta']['e']
+             )
+             loss_l_ref = (
+                 loss_dict_l['ref']['x'] +
+                 self.dpo_lambda_h * loss_dict_l['ref']['h'] +
+                 self.dpo_lambda_e * loss_dict_l['ref']['e']
+             )
+             diff_w = loss_w_theta - loss_w_ref
+             diff_l = loss_l_theta - loss_l_ref
+             info['diff_w'] = diff_w.mean().item()
+             info['diff_l'] = diff_l.mean().item()
+             diff = -1 * beta_t * (diff_w - diff_l)
+             loss = -1 * F.logsigmoid(diff)
+         elif self.dpo_mode == 'single_dpo_comp_v3':
+             diff_w_x = loss_dict_w['theta']['x'] - loss_dict_w['ref']['x']
+             diff_w_h = loss_dict_w['theta']['h'] - loss_dict_w['ref']['h']
+             diff_w_e = loss_dict_w['theta']['e'] - loss_dict_w['ref']['e']
+             diff_l_x = loss_dict_l['theta']['x'] - loss_dict_l['ref']['x']
+             diff_l_h = loss_dict_l['theta']['h'] - loss_dict_l['ref']['h']
+             diff_l_e = loss_dict_l['theta']['e'] - loss_dict_l['ref']['e']
+             info['diff_w_x'] = diff_w_x.mean().item()
+             info['diff_w_h'] = diff_w_h.mean().item()
+             info['diff_w_e'] = diff_w_e.mean().item()
+             info['diff_l_x'] = diff_l_x.mean().item()
+             info['diff_l_h'] = diff_l_h.mean().item()
+             info['diff_l_e'] = diff_l_e.mean().item()
+
+             # not used for the loss, just for logging
+             _diff_w = diff_w_x + self.dpo_lambda_h * diff_w_h + self.dpo_lambda_e * diff_w_e
+             _diff_l = diff_l_x + self.dpo_lambda_h * diff_l_h + self.dpo_lambda_e * diff_l_e
+             info['diff_w'] = _diff_w.mean().item()
+             info['diff_l'] = _diff_l.mean().item()
+
+             diff_x = diff_w_x - diff_l_x
+             diff_h = diff_w_h - diff_l_h
+             diff_e = diff_w_e - diff_l_e
+             info['diff_x'] = diff_x.mean().item()
+             info['diff_h'] = diff_h.mean().item()
+             info['diff_e'] = diff_e.mean().item()
+
+             diff = -1 * beta_t * (diff_x + self.dpo_lambda_h * diff_h + self.dpo_lambda_e * diff_e)
+             if self.clamp_dpo:
+                 diff = diff.clamp(-10, 10)
+             info['dpo_arg_min'] = diff.min().item()
+             info['dpo_arg_max'] = diff.max().item()
+             info['dpo_arg_mean'] = diff.mean().item()
+             dpo_loss = -1 * self.dpo_lambda_dpo * F.logsigmoid(diff)
+             info['dpo_loss'] = dpo_loss.mean().item()
+
+             loss_w_theta_reg = (
+                 loss_dict_w['theta']['x'] +
+                 self.lambda_h * loss_dict_w['theta']['h'] +
+                 self.lambda_e * loss_dict_w['theta']['e']
+             )
+             info['loss_w_theta_reg'] = loss_w_theta_reg.mean().item()
+             loss_l_theta_reg = (
+                 loss_dict_l['theta']['x'] +
+                 self.lambda_h * loss_dict_l['theta']['h'] +
+                 self.lambda_e * loss_dict_l['theta']['e']
+             )
+             info['loss_l_theta_reg'] = loss_l_theta_reg.mean().item()
+             dpo_reg = self.dpo_lambda_w * loss_w_theta_reg + \
+                 self.dpo_lambda_l * loss_l_theta_reg
+             info['dpo_reg'] = dpo_reg.mean().item()
+             loss = dpo_loss + dpo_reg
+         else:
+             raise ValueError(f'Unknown DPO mode: {self.dpo_mode}')
+
+         if self.timestep_weights is not None:
+             w_t = self.timestep_weights(t).squeeze()
+             loss = w_t * loss
+
+         loss = loss.mean(0)
+
+         print(f'Loss is {loss}, info is {info}')
195
+
196
+ return (loss, info) if return_info else loss
197
+
198
+ def compute_loss_single_pair(self, ligand, pocket, t):
199
+ pocket = Residues(**pocket)
200
+
201
+ # Center sample
202
+ ligand, pocket = center_data(ligand, pocket)
203
+ pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0)
204
+
205
+ # Noise
206
+ z0_x = self.module_x.sample_z0(pocket_com, ligand['mask'])
207
+ z0_h = self.module_h.sample_z0(ligand['mask'])
208
+ z0_e = self.module_e.sample_z0(ligand['bond_mask'])
209
+ zt_x = self.module_x.sample_zt(z0_x, ligand['x'], t, ligand['mask'])
210
+ zt_h = self.module_h.sample_zt(z0_h, ligand['one_hot'], t, ligand['mask'])
211
+ zt_e = self.module_e.sample_zt(z0_e, ligand['bond_one_hot'], t, ligand['bond_mask'])
212
+
213
+ # Predict denoising
214
+ sc_transform = self.get_sc_transform_fn(None, zt_x, t, None, ligand['mask'], pocket)
215
+
216
+ pred_ligand, _ = self.dynamics(
217
+ zt_x, zt_h, ligand['mask'], pocket, t,
218
+ bonds_ligand=(ligand['bonds'], zt_e),
219
+ sc_transform=sc_transform
220
+ )
221
+
222
+ # Reference model
223
+ with torch.no_grad():
224
+ ref_pred_ligand, _ = self.ref_dynamics(
225
+ zt_x, zt_h, ligand['mask'], pocket, t,
226
+ bonds_ligand=(ligand['bonds'], zt_e),
227
+ sc_transform=sc_transform
228
+ )
229
+
230
+ # Compute L2 loss
231
+ loss_x = self.module_x.compute_loss(pred_ligand['vel'], z0_x, ligand['x'], t, ligand['mask'], reduce=self.loss_reduce)
232
+ ref_loss_x = self.module_x.compute_loss(ref_pred_ligand['vel'], z0_x, ligand['x'], t, ligand['mask'], reduce=self.loss_reduce)
233
+
234
+ t_next = torch.clamp(t + self.train_step_size, max=1.0)
235
+
236
+ loss_h = self.module_h.compute_loss(pred_ligand['logits_h'], zt_h, ligand['one_hot'], ligand['mask'], t, t_next, reduce=self.loss_reduce)
237
+ ref_loss_h = self.module_h.compute_loss(ref_pred_ligand['logits_h'], zt_h, ligand['one_hot'], ligand['mask'], t, t_next, reduce=self.loss_reduce)
238
+ loss_e = self.module_e.compute_loss(pred_ligand['logits_e'], zt_e, ligand['bond_one_hot'], ligand['bond_mask'], t, t_next, reduce=self.loss_reduce)
239
+ ref_loss_e = self.module_e.compute_loss(ref_pred_ligand['logits_e'], zt_e, ligand['bond_one_hot'], ligand['bond_mask'], t, t_next, reduce=self.loss_reduce)
240
+
241
+ return {
242
+ 'theta': {
243
+ 'x': loss_x,
244
+ 'h': loss_h,
245
+ 'e': loss_e,
246
+ },
247
+ 'ref': {
248
+ 'x': ref_loss_x,
249
+ 'h': ref_loss_h,
250
+ 'e': ref_loss_e,
251
+ }
252
+ }
src/model/dynamics.py ADDED
@@ -0,0 +1,791 @@
1
+ from collections.abc import Iterable
2
+ from abc import abstractmethod
3
+ import random
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+
9
+ from src.constants import INT_TYPE
10
+ from src.model.gvp import GVPModel, GVP, LayerNorm
11
+ from src.model.gvp_transformer import GVPTransformerModel
12
+ from src.constants import FLOAT_TYPE
13
+
14
+ from pdb import set_trace
15
+
16
+
17
+ def binomial_coefficient(n, k):
18
+ # source: https://discuss.pytorch.org/t/n-choose-k-function/121974
19
+ return ((n + 1).lgamma() - (k + 1).lgamma() - ((n - k) + 1).lgamma()).exp()
20
+
21
+
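
A quick sanity check of the lgamma-based n-choose-k above (floating point, so expect approximate values; assumes the `torch` import from this module):

n = torch.tensor([5.0, 6.0])
k = torch.tensor([2.0, 3.0])
print(binomial_coefficient(n, k))  # ~tensor([10., 20.])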
22
+ def cycle_counts(adj):
23
+ assert (adj.diag() == 0).all()
24
+ assert (adj == adj.T).all()
25
+
26
+ A = adj.float()
27
+ d = A.sum(dim=-1)
28
+
29
+ # Compute powers
30
+ A2 = A @ A
31
+ A3 = A2 @ A
32
+ A4 = A3 @ A
33
+ A5 = A4 @ A
34
+
35
+ x3 = A3.diag() / 2
36
+ x4 = (A4.diag() - d * (d - 1) - A @ d) / 2
37
+
38
+ """ New (different from DiGress)
39
+ case where correction is relevant:
40
+ 2 o
41
+ |
42
+ 1,3 o--o 4
43
+ | /
44
+ 0,5 o
45
+ """
46
+ # Triangle count matrix (indicates for each node i how many triangles it shares with node j)
47
+ T = adj * A2
48
+ x5 = (A5.diag() - 2 * T @ d - 4 * d * x3 - 2 * A @ x3 + 10 * x3) / 2
49
+
50
+ # # TODO
51
+ # A6 = A5 @ A
52
+ #
53
+ # # 4-cycle count matrix (indicates in how many shared 4-cycles i and j are 2 hops apart)
54
+ # Q2 = binomial_coefficient(n=A2 - d.diag(), k=torch.tensor(2))
55
+ #
56
+ # # 4-cycle count matrix (indicates in how many shared 4-cycles i and j are 1 (and 3) hop(s) apart)
57
+ # Q1 = A * (A3 - (d.view(-1, 1) + d.view(1, -1)) + 1) # "+1" because link between i and j is subtracted twice
58
+ #
59
+ # x6 = ...
60
+ # return torch.stack([x3, x4, x5, x6], dim=-1)
61
+
62
+ return torch.stack([x3, x4, x5], dim=-1)
63
+
64
+
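
Worked example for `cycle_counts` on a 4-cycle (square) graph: each node sits in zero triangles, exactly one 4-cycle, and zero 5-cycles:

adj = torch.tensor([[0, 1, 0, 1],
                    [1, 0, 1, 0],
                    [0, 1, 0, 1],
                    [1, 0, 1, 0]])
print(cycle_counts(adj))  # per-node [x3, x4, x5] -> [0., 1., 0.] for all four nodes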
65
+ # TODO: also consider directional aggregation as in:
66
+ # Beaini, Dominique, et al. "Directional graph networks."
67
+ # International Conference on Machine Learning. PMLR, 2021.
68
+ def eigenfeatures(A, batch_mask, k=5):
69
+ # TODO, see:
70
+ # - https://github.com/cvignac/DiGress/blob/main/src/diffusion/extra_features.py
71
+ # - https://arxiv.org/pdf/2209.14734.pdf (Appendix B.2)
72
+
73
+ # split adjacency matrix
74
+ batch = []
75
+ for i in torch.unique(batch_mask, sorted=True): # TODO: optimize (try to avoid loop)
76
+ batch_inds = torch.where(batch_mask == i)[0]
77
+ batch.append(A[torch.meshgrid(batch_inds, batch_inds, indexing='ij')])
78
+
79
+ eigenfeats = [get_nontrivial_eigenvectors(adj)[:, :k] for adj in batch]
80
+ # if there are less than k non-trivial eigenvectors
81
+ eigenfeats = [torch.cat([
82
+ x, torch.zeros(x.size(0), max(k - x.size(1), 0), device=x.device)], dim=-1)
83
+ for x in eigenfeats]
84
+ return torch.cat(eigenfeats, dim=0)
85
+
86
+
87
+ def get_nontrivial_eigenvectors(A, normalize_l=True, thresh=1e-5,
88
+ norm_eps=1e-12):
89
+ """
90
+ Compute eigenvectors of the graph Laplacian corresponding to non-zero
91
+ eigenvalues.
92
+ """
93
+ assert (A == A.T).all(), "undirected graph"
94
+
95
+ # Compute laplacian
96
+ d = A.sum(-1)
97
+ D = d.diag()
98
+ L = D - A
99
+
100
+ if normalize_l:
101
+ D_inv_sqrt = (1 / (d.sqrt() + norm_eps)).diag()
102
+ L = D_inv_sqrt @ L @ D_inv_sqrt
103
+
104
+ # Eigendecomposition
105
+ # eigenvalues are sorted in ascending order
106
+ # eigvecs matrix contains eigenvectors as its columns
107
+ eigvals, eigvecs = torch.linalg.eigh(L)
108
+
109
+ # index of first non-trivial eigenvector
110
+ try:
111
+ idx = torch.nonzero(eigvals > thresh)[0].item()
112
+ except IndexError:
113
+ # recover if no non-trivial eigenvectors are found
114
+ idx = eigvecs.size(1)
115
+
116
+ return eigvecs[:, idx:]
117
+
118
+
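
Minimal usage sketch for the two spectral helpers above, on a 3-node path graph (the constant eigenvector is dropped, and `eigenfeatures` zero-pads to k columns):

A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
print(get_nontrivial_eigenvectors(A).shape)              # (3, 2)
print(eigenfeatures(A, torch.tensor([0, 0, 0])).shape)   # (3, 5) after padding to k=5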
119
+ class DynamicsBase(nn.Module):
120
+ """
121
+ Implements self-conditioning logic and basic functions
122
+ """
123
+ def __init__(
124
+ self,
125
+ predict_angles=False,
126
+ predict_frames=False,
127
+ add_cycle_counts=False,
128
+ add_spectral_feat=False,
129
+ self_conditioning=False,
130
+ augment_residue_sc=False,
131
+ augment_ligand_sc=False
132
+ ):
133
+ super().__init__()
134
+
135
+ if not hasattr(self, 'predict_angles'):
136
+ self.predict_angles = predict_angles
137
+
138
+ if not hasattr(self, 'predict_frames'):
139
+ self.predict_frames = predict_frames
140
+
141
+ if not hasattr(self, 'add_cycle_counts'):
142
+ self.add_cycle_counts = add_cycle_counts
143
+
144
+ if not hasattr(self, 'add_spectral_feat'):
145
+ self.add_spectral_feat = add_spectral_feat
146
+
147
+ if not hasattr(self, 'self_conditioning'):
148
+ self.self_conditioning = self_conditioning
149
+
150
+ if not hasattr(self, 'augment_residue_sc'):
151
+ self.augment_residue_sc = augment_residue_sc
152
+
153
+ if not hasattr(self, 'augment_ligand_sc'):
154
+ self.augment_ligand_sc = augment_ligand_sc
155
+
156
+ if self.self_conditioning:
157
+ self.prev_ligand = None
158
+ self.prev_residues = None
159
+
160
+ @abstractmethod
161
+ def _forward(self, x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand=None,
162
+ h_atoms_sc=None, e_atoms_sc=None, h_residues_sc=None):
163
+ """
164
+ Implement forward pass.
165
+ Returns:
166
+ - vel
167
+ - h_final_atoms
168
+ - edge_final_atoms
169
+ - residue_angles
170
+ - residue_trans
171
+ - residue_rot
172
+ """
173
+ pass
174
+
175
+ def make_sc_input(self, pred_ligand, pred_residues, sc_transform):
176
+
177
+ if self.predict_confidence:
178
+ h_atoms_sc = (torch.cat([pred_ligand['logits_h'], pred_ligand['uncertainty_vel'].unsqueeze(1)], dim=-1),
179
+ pred_ligand['vel'].unsqueeze(1))
180
+ else:
181
+ h_atoms_sc = (pred_ligand['logits_h'], pred_ligand['vel'].unsqueeze(1))
182
+ e_atoms_sc = pred_ligand['logits_e']
183
+
184
+ if self.predict_frames:
185
+ h_residues_sc = (torch.cat([pred_residues['chi'], pred_residues['rot']], dim=-1),
186
+ pred_residues['trans'].unsqueeze(1))
187
+ elif self.predict_angles:
188
+ h_residues_sc = pred_residues['chi']
189
+ else:
190
+ h_residues_sc = None
191
+
192
+ if self.augment_residue_sc and h_residues_sc is not None:
193
+ if self.predict_frames:
194
+ h_residues_sc = (h_residues_sc[0], torch.cat(
195
+ [h_residues_sc[1], sc_transform['residues'](pred_residues['chi'], pred_residues['trans'].squeeze(1), pred_residues['rot'])], dim=1))
196
+
197
+ else:
198
+ h_residues_sc = (h_residues_sc, sc_transform['residues'](pred_residues['chi']))
199
+
200
+ if self.augment_ligand_sc:
201
+ h_atoms_sc = (h_atoms_sc[0], torch.cat(
202
+ [h_atoms_sc[1], sc_transform['atoms'](pred_ligand['vel'].unsqueeze(1))], dim=1))
203
+
204
+ return h_atoms_sc, e_atoms_sc, h_residues_sc
205
+
206
+ def forward(self, x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand=None, sc_transform=None):
207
+ """
208
+ Implements self-conditioning as in https://arxiv.org/abs/2208.04202
209
+ """
210
+
211
+ h_atoms_sc, e_atoms_sc = None, None
212
+ h_residues_sc = None
213
+
214
+ if self.self_conditioning:
215
+
216
+ # Sampling: use previous prediction in all but the first time step
217
+ if not self.training and t.min() > 0.0:
218
+ assert t.min() == t.max(), "currently only supports sampling at same time steps"
219
+ assert self.prev_ligand is not None
220
+ assert self.prev_residues is not None or not self.predict_frames
221
+
222
+ else:
223
+ # Create zero tensors
224
+ zeros_ligand = {'logits_h': torch.zeros_like(h_atoms),
225
+ 'vel': torch.zeros_like(x_atoms),
226
+ 'logits_e': torch.zeros_like(bonds_ligand[1])}
227
+ if self.predict_confidence:
228
+ zeros_ligand['uncertainty_vel'] = torch.zeros(
229
+ len(x_atoms), dtype=x_atoms.dtype, device=x_atoms.device)
230
+
231
+ zeros_residues = {}
232
+ if self.predict_angles:
233
+ zeros_residues['chi'] = torch.zeros((pocket['one_hot'].size(0), 5), device=pocket['one_hot'].device)
234
+ if self.predict_frames:
235
+ zeros_residues['trans'] = torch.zeros((pocket['one_hot'].size(0), 3), device=pocket['one_hot'].device)
236
+ zeros_residues['rot'] = torch.zeros((pocket['one_hot'].size(0), 3), device=pocket['one_hot'].device)
237
+
238
+ # Training: use 50% zeros and 50% predictions with detached gradients
239
+ if self.training and random.random() > 0.5:
240
+ with torch.no_grad():
241
+ h_atoms_sc, e_atoms_sc, h_residues_sc = self.make_sc_input(
242
+ zeros_ligand, zeros_residues, sc_transform)
243
+
244
+ self.prev_ligand, self.prev_residues = self._forward(
245
+ x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand,
246
+ h_atoms_sc, e_atoms_sc, h_residues_sc)
247
+
248
+ # use zeros for first sampling step and 50% of training
249
+ else:
250
+ self.prev_ligand = zeros_ligand
251
+ self.prev_residues = zeros_residues
252
+
253
+ h_atoms_sc, e_atoms_sc, h_residues_sc = self.make_sc_input(
254
+ self.prev_ligand, self.prev_residues, sc_transform)
255
+
256
+ pred_ligand, pred_residues = self._forward(
257
+ x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand,
258
+ h_atoms_sc, e_atoms_sc, h_residues_sc
259
+ )
260
+
261
+ if self.self_conditioning and not self.training:
262
+ self.prev_ligand = pred_ligand.copy()
263
+ self.prev_residues = pred_residues.copy()
264
+
265
+ return pred_ligand, pred_residues
266
+
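
Stripped of the ligand/pocket plumbing, the 50/50 training recipe above (following Chen et al., 2022) is simply the following sketch; `model`, its `cond` keyword, and `zeros` are placeholders, not this class's API:

import random
import torch

def self_conditioned_training_forward(model, x, t, zeros):
    if random.random() > 0.5:
        # First pass (no gradients) produces the conditioning signal.
        with torch.no_grad():
            cond = model(x, t, cond=zeros)
    else:
        cond = zeros
    # Second, gradient-carrying pass consumes the previous prediction.
    return model(x, t, cond=cond)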
267
+ def compute_extra_features(self, batch_mask, edge_indices, edge_types):
268
+
269
+ feat = torch.zeros(len(batch_mask), 0, device=batch_mask.device)
270
+
271
+ if not (self.add_cycle_counts or self.add_spectral_feat):
272
+ return feat
273
+
274
+ adj = batch_mask[:, None] == batch_mask[None, :]
275
+
276
+ E = torch.zeros_like(adj, dtype=INT_TYPE)
277
+ E[edge_indices[0], edge_indices[1]] = edge_types
278
+
279
+ A = (E > 0).float()
280
+
281
+ if self.add_cycle_counts:
282
+ cycle_features = cycle_counts(A)
283
+ cycle_features[cycle_features > 10] = 10 # avoid large values
284
+
285
+ feat = torch.cat([feat, cycle_features], dim=-1)
286
+
287
+ if self.add_spectral_feat:
288
+ feat = torch.cat([feat, eigenfeatures(A, batch_mask)], dim=-1)
289
+
290
+ return feat
291
+
292
+
293
+ class Dynamics(DynamicsBase):
294
+ def __init__(self, atom_nf, residue_nf, joint_nf, bond_dict, pocket_bond_dict,
295
+ edge_nf, hidden_nf, act_fn=torch.nn.SiLU(), condition_time=True,
296
+ model='egnn', model_params=None,
297
+ edge_cutoff_ligand=None, edge_cutoff_pocket=None,
298
+ edge_cutoff_interaction=None,
299
+ predict_angles=False, predict_frames=False,
300
+ add_cycle_counts=False, add_spectral_feat=False,
301
+ add_nma_feat=False, self_conditioning=False,
302
+ augment_residue_sc=False, augment_ligand_sc=False,
303
+ add_chi_as_feature=False, angle_act_fn=False):
304
+ super().__init__()
305
+ self.model = model
306
+ self.edge_cutoff_l = edge_cutoff_ligand
307
+ self.edge_cutoff_p = edge_cutoff_pocket
308
+ self.edge_cutoff_i = edge_cutoff_interaction
309
+ self.hidden_nf = hidden_nf
310
+ self.predict_angles = predict_angles
311
+ self.predict_frames = predict_frames
312
+ self.bond_dict = bond_dict
313
+ self.pocket_bond_dict = pocket_bond_dict
314
+ self.bond_nf = len(bond_dict)
315
+ self.pocket_bond_nf = len(pocket_bond_dict)
316
+ self.edge_nf = edge_nf
317
+ self.add_cycle_counts = add_cycle_counts
318
+ self.add_spectral_feat = add_spectral_feat
319
+ self.add_nma_feat = add_nma_feat
320
+ self.self_conditioning = self_conditioning
321
+ self.augment_residue_sc = augment_residue_sc
322
+ self.augment_ligand_sc = augment_ligand_sc
323
+ self.add_chi_as_feature = add_chi_as_feature
324
+ self.predict_confidence = False
325
+
326
+ if self.self_conditioning:
327
+ self.prev_vel = None
328
+ self.prev_h = None
329
+ self.prev_e = None
330
+ self.prev_a = None
331
+ self.prev_ca = None
332
+ self.prev_rot = None
333
+
334
+ lig_nf = atom_nf
335
+ if self.add_cycle_counts:
336
+ lig_nf = lig_nf + 3
337
+ if self.add_spectral_feat:
338
+ lig_nf = lig_nf + 5
339
+
340
+
341
+ if not isinstance(joint_nf, Iterable):
342
+ # joint_nf contains only scalars
343
+ joint_nf = (joint_nf, 0)
344
+
345
+
346
+ if isinstance(residue_nf, Iterable):
347
+ _atom_in_nf = (lig_nf, 0)
348
+ _residue_atom_dim = residue_nf[1]
349
+
350
+ if self.add_nma_feat:
351
+ residue_nf = (residue_nf[0], residue_nf[1] + 5)
352
+
353
+ if self.self_conditioning:
354
+ _atom_in_nf = (_atom_in_nf[0] + atom_nf, 1)
355
+
356
+ if self.augment_ligand_sc:
357
+ _atom_in_nf = (_atom_in_nf[0], _atom_in_nf[1] + 1)
358
+
359
+ if self.predict_angles:
360
+ residue_nf = (residue_nf[0] + 5, residue_nf[1])
361
+
362
+ if self.predict_frames:
363
+ residue_nf = (residue_nf[0], residue_nf[1] + 2)
364
+
365
+ if self.augment_residue_sc:
366
+ assert self.predict_angles
367
+ residue_nf = (residue_nf[0], residue_nf[1] + _residue_atom_dim)
368
+
369
+ if self.add_chi_as_feature:
370
+ residue_nf = (residue_nf[0] + 5, residue_nf[1])
371
+
372
+ self.atom_encoder = nn.Sequential(
373
+ GVP(_atom_in_nf, joint_nf, activations=(act_fn, torch.sigmoid)),
374
+ LayerNorm(joint_nf, learnable_vector_weight=True),
375
+ GVP(joint_nf, joint_nf, activations=(None, None)),
376
+ )
377
+
378
+ self.residue_encoder = nn.Sequential(
379
+ GVP(residue_nf, joint_nf, activations=(act_fn, torch.sigmoid)),
380
+ LayerNorm(joint_nf, learnable_vector_weight=True),
381
+ GVP(joint_nf, joint_nf, activations=(None, None)),
382
+ )
383
+
384
+ else:
385
+ # No vector-valued input features
386
+ assert joint_nf[1] == 0
387
+
388
+ # self-conditioning not yet supported
389
+ assert not self.self_conditioning
390
+
391
+ # Normal mode features are vectors
392
+ assert not self.add_nma_feat
393
+
394
+ if self.add_chi_as_feature:
395
+ residue_nf += 5
396
+
397
+ self.atom_encoder = nn.Sequential(
398
+ nn.Linear(lig_nf, 2 * atom_nf),
399
+ act_fn,
400
+ nn.Linear(2 * atom_nf, joint_nf[0])
401
+ )
402
+
403
+ self.residue_encoder = nn.Sequential(
404
+ nn.Linear(residue_nf, 2 * residue_nf),
405
+ act_fn,
406
+ nn.Linear(2 * residue_nf, joint_nf[0])
407
+ )
408
+
409
+ self.atom_decoder = nn.Sequential(
410
+ nn.Linear(joint_nf[0], 2 * atom_nf),
411
+ act_fn,
412
+ nn.Linear(2 * atom_nf, atom_nf)
413
+ )
414
+
415
+ self.edge_decoder = nn.Sequential(
416
+ nn.Linear(hidden_nf, hidden_nf),
417
+ act_fn,
418
+ nn.Linear(hidden_nf, self.bond_nf)
419
+ )
420
+
421
+ _atom_bond_nf = 2 * self.bond_nf if self.self_conditioning else self.bond_nf
422
+ self.ligand_bond_encoder = nn.Sequential(
423
+ nn.Linear(_atom_bond_nf, hidden_nf),
424
+ act_fn,
425
+ nn.Linear(hidden_nf, self.edge_nf)
426
+ )
427
+
428
+ self.pocket_bond_encoder = nn.Sequential(
429
+ nn.Linear(self.pocket_bond_nf, hidden_nf),
430
+ act_fn,
431
+ nn.Linear(hidden_nf, self.edge_nf)
432
+ )
433
+
434
+ out_nf = (joint_nf[0], 1)
435
+ res_out_nf = (0, 0)
436
+ if self.predict_angles:
437
+ res_out_nf = (res_out_nf[0] + 5, res_out_nf[1])
438
+ if self.predict_frames:
439
+ res_out_nf = (res_out_nf[0], res_out_nf[1] + 2)
440
+ self.residue_decoder = nn.Sequential(
441
+ GVP(out_nf, out_nf, activations=(act_fn, torch.sigmoid)),
442
+ LayerNorm(out_nf, learnable_vector_weight=True),
443
+ GVP(out_nf, res_out_nf, activations=(None, None)),
444
+ ) if res_out_nf != (0, 0) else None
445
+
446
+ if not angle_act_fn:  # covers both None and the default value False
447
+ self.angle_act_fn = None
448
+ elif angle_act_fn == 'tanh':
449
+ self.angle_act_fn = lambda x: np.pi * F.tanh(x)
450
+ else:
451
+ raise NotImplementedError(f"Angle activation {angle_act_fn} not available")
452
+
453
+ # self.ligand_nobond_emb = nn.Parameter(torch.zeros(self.edge_nf))
454
+ # self.pocket_nobond_emb = nn.Parameter(torch.zeros(self.edge_nf))
455
+ self.cross_emb = nn.Parameter(torch.zeros(self.edge_nf),
456
+ requires_grad=True)
457
+
458
+ if condition_time:
459
+ dynamics_node_nf = (joint_nf[0] + 1, joint_nf[1])
460
+ else:
461
+ print('Warning: dynamics model is NOT conditioned on time.')
462
+ dynamics_node_nf = (joint_nf[0], joint_nf[1])
463
+
464
+ if model == 'egnn':
465
+ raise NotImplementedError
466
+ # self.net = EGNN(
467
+ # in_node_nf=dynamics_node_nf[0], in_edge_nf=self.edge_nf,
468
+ # hidden_nf=hidden_nf, out_node_nf=joint_nf[0],
469
+ # device=model_params.device, act_fn=act_fn,
470
+ # n_layers=model_params.n_layers,
471
+ # attention=model_params.attention,
472
+ # tanh=model_params.tanh,
473
+ # norm_constant=model_params.norm_constant,
474
+ # inv_sublayers=model_params.inv_sublayers,
475
+ # sin_embedding=model_params.sin_embedding,
476
+ # normalization_factor=model_params.normalization_factor,
477
+ # aggregation_method=model_params.aggregation_method,
478
+ # reflection_equiv=model_params.reflection_equivariant,
479
+ # update_edge_attr=True
480
+ # )
481
+ # self.node_nf = dynamics_node_nf[0]
482
+
483
+ elif model == 'gvp':
484
+ self.net = GVPModel(
485
+ node_in_dim=dynamics_node_nf, node_h_dim=model_params.node_h_dim,
486
+ node_out_nf=joint_nf[0], edge_in_nf=self.edge_nf,
487
+ edge_h_dim=model_params.edge_h_dim, edge_out_nf=hidden_nf,
488
+ num_layers=model_params.n_layers,
489
+ drop_rate=model_params.dropout,
490
+ vector_gate=model_params.vector_gate,
491
+ reflection_equiv=model_params.reflection_equivariant,
492
+ d_max=model_params.d_max,
493
+ num_rbf=model_params.num_rbf,
494
+ update_edge_attr=True
495
+ )
496
+
497
+ elif model == 'gvp_transformer':
498
+ self.net = GVPTransformerModel(
499
+ node_in_dim=dynamics_node_nf,
500
+ node_h_dim=model_params.node_h_dim,
501
+ node_out_nf=joint_nf[0],
502
+ edge_in_nf=self.edge_nf,
503
+ edge_h_dim=model_params.edge_h_dim,
504
+ edge_out_nf=hidden_nf,
505
+ num_layers=model_params.n_layers,
506
+ dk=model_params.dk,
507
+ dv=model_params.dv,
508
+ de=model_params.de,
509
+ db=model_params.db,
510
+ dy=model_params.dy,
511
+ attn_heads=model_params.attn_heads,
512
+ n_feedforward=model_params.n_feedforward,
513
+ drop_rate=model_params.dropout,
514
+ reflection_equiv=model_params.reflection_equivariant,
515
+ d_max=model_params.d_max,
516
+ num_rbf=model_params.num_rbf,
517
+ vector_gate=model_params.vector_gate,
518
+ attention=model_params.attention,
519
+ )
520
+
521
+ elif model == 'gnn':
522
+ raise NotImplementedError
523
+ # n_dims = 3
524
+ # self.net = GNN(
525
+ # in_node_nf=dynamics_node_nf + n_dims, in_edge_nf=self.edge_emb_dim,
526
+ # hidden_nf=hidden_nf, out_node_nf=n_dims + dynamics_node_nf,
527
+ # device=model_params.device, act_fn=act_fn, n_layers=model_params.n_layers,
528
+ # attention=model_params.attention, normalization_factor=model_params.normalization_factor,
529
+ # aggregation_method=model_params.aggregation_method)
530
+
531
+ else:
532
+ raise NotImplementedError(f"{model} is not available")
533
+
534
+ # self.device = device
535
+ # self.n_dims = n_dims
536
+ self.condition_time = condition_time
537
+
538
+ def _forward(self, x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand=None,
539
+ h_atoms_sc=None, e_atoms_sc=None, h_residues_sc=None):
540
+ """
541
+ :param x_atoms: ligand atom coordinates, (n_atoms, 3)
542
+ :param h_atoms: ligand atom features (one-hot types plus extras), (n_atoms, nf)
543
+ :param mask_atoms: batch assignment of each ligand atom, (n_atoms,)
544
+ :param pocket: must contain keys: 'x', 'one_hot', 'mask', 'bonds' and 'bond_one_hot'
545
+ :param t: time step(s), scalar or one value per graph in the batch
546
+ :param bonds_ligand: tuple - bond indices (2, n_bonds) & bond types (n_bonds, bond_nf)
547
+ :param h_atoms_sc: additional node feature for self-conditioning, (s, V)
548
+ :param e_atoms_sc: additional edge feature for self-conditioning, only scalar
549
+ :param h_residues_sc: additional node feature for self-conditioning, tensor or tuple
550
+ :return: tuple (pred_ligand, pred_residues) of output dictionaries
551
+ """
552
+ x_residues, h_residues, mask_residues = pocket['x'], pocket['one_hot'], pocket['mask']
553
+ if 'bonds' in pocket:
554
+ bonds_pocket = (pocket['bonds'], pocket['bond_one_hot'])
555
+ else:
556
+ bonds_pocket = None
557
+
558
+ if self.add_chi_as_feature:
559
+ h_residues = torch.cat([h_residues, pocket['chi'][:, :5]], dim=-1)
560
+
561
+ if 'v' in pocket:
562
+ v_residues = pocket['v']
563
+ if self.add_nma_feat:
564
+ v_residues = torch.cat([v_residues, pocket['nma_vec']], dim=1)
565
+ h_residues = (h_residues, v_residues)
566
+
567
+ if h_residues_sc is not None:
568
+ # if self.augment_residue_sc:
569
+ if isinstance(h_residues_sc, tuple):
570
+ h_residues = (torch.cat([h_residues[0], h_residues_sc[0]], dim=-1),
571
+ torch.cat([h_residues[1], h_residues_sc[1]], dim=1))
572
+ else:
573
+ h_residues = (torch.cat([h_residues[0], h_residues_sc], dim=-1),
574
+ h_residues[1])
575
+
576
+ # get graph edges and edge attributes
577
+ if bonds_ligand is not None:
578
+ # NOTE: 'bond' denotes one-directional edges and 'edge' means bi-directional
579
+ ligand_bond_indices = bonds_ligand[0]
580
+
581
+ # make sure messages are passed both ways
582
+ ligand_edge_indices = torch.cat(
583
+ [bonds_ligand[0], bonds_ligand[0].flip(dims=[0])], dim=1)
584
+ ligand_edge_types = torch.cat([bonds_ligand[1], bonds_ligand[1]], dim=0)
585
+ # edges_ligand = (ligand_edge_indices, ligand_edge_types)
586
+
587
+ # add auxiliary features to ligand nodes
588
+ extra_features = self.compute_extra_features(
589
+ mask_atoms, ligand_edge_indices, ligand_edge_types.argmax(-1))
590
+ h_atoms = torch.cat([h_atoms, extra_features], dim=-1)
591
+
592
+ if bonds_pocket is not None:
593
+ # make sure messages are passed both ways
594
+ pocket_edge_indices = torch.cat(
595
+ [bonds_pocket[0], bonds_pocket[0].flip(dims=[0])], dim=1)
596
+ pocket_edge_types = torch.cat([bonds_pocket[1], bonds_pocket[1]], dim=0)
597
+ # edges_pocket = (pocket_edge_indices, pocket_edge_types)
598
+
599
+ if h_atoms_sc is not None:
600
+ h_atoms = (torch.cat([h_atoms, h_atoms_sc[0]], dim=-1),
601
+ h_atoms_sc[1])
602
+
603
+ if e_atoms_sc is not None:
604
+ e_atoms_sc = torch.cat([e_atoms_sc, e_atoms_sc], dim=0)
605
+ ligand_edge_types = torch.cat([ligand_edge_types, e_atoms_sc], dim=-1)
606
+
607
+ # embed atom features and residue features in a shared space
608
+ h_atoms = self.atom_encoder(h_atoms)
609
+ e_ligand = self.ligand_bond_encoder(ligand_edge_types)
610
+
611
+ if len(h_residues) > 0:
612
+ h_residues = self.residue_encoder(h_residues)
613
+ e_pocket = self.pocket_bond_encoder(pocket_edge_types)
614
+ else:
615
+ e_pocket = pocket_edge_types
616
+ h_residues = (h_residues, h_residues)
617
+ pocket_edge_indices = torch.tensor([[], []], dtype=torch.long, device=h_residues[0].device)
618
+ pocket_edge_types = torch.tensor([[], []], dtype=torch.long, device=h_residues[0].device)
619
+
620
+ if isinstance(h_atoms, tuple):
621
+ h_atoms, v_atoms = h_atoms
622
+ h_residues, v_residues = h_residues
623
+ v = torch.cat((v_atoms, v_residues), dim=0)
624
+ else:
625
+ v = None
626
+
627
+ edges, edge_feat = self.get_edges(
628
+ mask_atoms, mask_residues, x_atoms, x_residues,
629
+ bond_inds_ligand=ligand_edge_indices, bond_inds_pocket=pocket_edge_indices,
630
+ bond_feat_ligand=e_ligand, bond_feat_pocket=e_pocket)
631
+
632
+ # combine the two node types
633
+ x = torch.cat((x_atoms, x_residues), dim=0)
634
+ h = torch.cat((h_atoms, h_residues), dim=0)
635
+ mask = torch.cat([mask_atoms, mask_residues])
636
+
637
+ if self.condition_time:
638
+ if np.prod(t.size()) == 1:
639
+ # t is the same for all elements in batch.
640
+ h_time = torch.empty_like(h[:, 0:1]).fill_(t.item())
641
+ else:
642
+ # t is different over the batch dimension.
643
+ h_time = t[mask]
644
+ h = torch.cat([h, h_time], dim=1)
645
+
646
+ assert torch.all(mask[edges[0]] == mask[edges[1]])
647
+
648
+ if self.model == 'egnn':
649
+ # Don't update pocket coordinates
650
+ update_coords_mask = torch.cat((torch.ones_like(mask_atoms),
651
+ torch.zeros_like(mask_residues))).unsqueeze(1)
652
+ h_final, vel, edge_final = self.net(
653
+ h, x, edges, batch_mask=mask, edge_attr=edge_feat,
654
+ update_coords_mask=update_coords_mask)
655
+ # vel = (x_final - x)
656
+
657
+ elif self.model == 'gvp' or self.model == 'gvp_transformer':
658
+ h_final, vel, edge_final = self.net(
659
+ h, x, edges, v=v, batch_mask=mask, edge_attr=edge_feat)
660
+
661
+ elif self.model == 'gnn':
662
+ xh = torch.cat([x, h], dim=1)
663
+ output = self.net(xh, edges, node_mask=None, edge_attr=edge_feat)
664
+ vel = output[:, :3]
665
+ h_final = output[:, 3:]
666
+
667
+ else:
668
+ raise NotImplementedError(f"Wrong model ({self.model})")
669
+
670
+ # if self.condition_time:
671
+ # # Slice off last dimension which represented time.
672
+ # h_final = h_final[:, :-1]
673
+
674
+ # decode atom and residue features
675
+ h_final_atoms = self.atom_decoder(h_final[:len(mask_atoms)])
676
+
677
+ if torch.any(torch.isnan(vel)) or torch.any(torch.isnan(h_final_atoms)):
678
+ if self.training:
679
+ vel[torch.isnan(vel)] = 0.0
680
+ h_final_atoms[torch.isnan(h_final_atoms)] = 0.0
681
+ else:
682
+ raise ValueError("NaN detected in network output")
683
+
684
+ # predict edge type
685
+ ligand_edge_mask = (edges[0] < len(mask_atoms)) & (edges[1] < len(mask_atoms))
686
+ edge_final = edge_final[ligand_edge_mask]
687
+ edges = edges[:, ligand_edge_mask]
688
+
689
+ # Symmetrize
690
+ edge_logits = torch.zeros(
691
+ (len(mask_atoms), len(mask_atoms), self.hidden_nf),
692
+ device=mask_atoms.device)
693
+ edge_logits[edges[0], edges[1]] = edge_final
694
+ edge_logits = (edge_logits + edge_logits.transpose(0, 1)) * 0.5
695
+ # edge_logits = edge_logits[lig_edge_indices[0], lig_edge_indices[1]]
696
+
697
+ # return upper triangular elements only (matching the input)
698
+ edge_logits = edge_logits[ligand_bond_indices[0], ligand_bond_indices[1]]
699
+ # assert (edge_logits == 0).sum() == 0
700
+
701
+ edge_final_atoms = self.edge_decoder(edge_logits)
702
+
703
+ # Predict torsion angles
704
+ residue_angles = None
705
+ residue_trans, residue_rot = None, None
706
+ if self.residue_decoder is not None:
707
+ h_residues = h_final[len(mask_atoms):]
708
+ vec_residues = vel[len(mask_atoms):].unsqueeze(1)
709
+ residue_angles = self.residue_decoder((h_residues, vec_residues))
710
+ if self.predict_frames:
711
+ residue_angles, residue_frames = residue_angles
712
+ residue_trans = residue_frames[:, 0, :].squeeze(1)
713
+ residue_rot = residue_frames[:, 1, :].squeeze(1)
714
+ if self.angle_act_fn is not None:
715
+ residue_angles = self.angle_act_fn(residue_angles)
716
+
717
+ # return vel[:len(mask_atoms)], h_final_atoms, edge_final_atoms, residue_angles, residue_trans, residue_rot
718
+ pred_ligand = {'vel': vel[:len(mask_atoms)], 'logits_h': h_final_atoms, 'logits_e': edge_final_atoms}
719
+ pred_residues = {'chi': residue_angles, 'trans': residue_trans, 'rot': residue_rot}
720
+ return pred_ligand, pred_residues
721
+
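
On the time conditioning in `_forward` above: a per-graph timestep is broadcast to nodes by indexing with the node-to-graph mask, for example:

import torch

t = torch.tensor([[0.2], [0.7]])       # one timestep per graph in the batch
mask = torch.tensor([0, 0, 0, 1, 1])   # node -> graph assignment
h_time = t[mask]                       # (5, 1), ready to concatenate onto h
print(h_time.squeeze(-1))              # tensor([0.2000, 0.2000, 0.2000, 0.7000, 0.7000])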
722
+ def get_edges(self, batch_mask_ligand, batch_mask_pocket, x_ligand,
723
+ x_pocket, bond_inds_ligand=None, bond_inds_pocket=None,
724
+ bond_feat_ligand=None, bond_feat_pocket=None, self_edges=False):
725
+
726
+ # Adjacency matrix
727
+ adj_ligand = batch_mask_ligand[:, None] == batch_mask_ligand[None, :]
728
+ adj_pocket = batch_mask_pocket[:, None] == batch_mask_pocket[None, :]
729
+ adj_cross = batch_mask_ligand[:, None] == batch_mask_pocket[None, :]
730
+
731
+ if self.edge_cutoff_l is not None:
732
+ adj_ligand = adj_ligand & (torch.cdist(x_ligand, x_ligand) <= self.edge_cutoff_l)
733
+
734
+ # Add missing bonds if they got removed
735
+ adj_ligand[bond_inds_ligand[0], bond_inds_ligand[1]] = True
736
+
737
+ if self.edge_cutoff_p is not None and len(x_pocket) > 0:
738
+ adj_pocket = adj_pocket & (torch.cdist(x_pocket, x_pocket) <= self.edge_cutoff_p)
739
+
740
+ # Add missing bonds if they got removed
741
+ adj_pocket[bond_inds_pocket[0], bond_inds_pocket[1]] = True
742
+
743
+ if self.edge_cutoff_i is not None and len(x_pocket) > 0:
744
+ adj_cross = adj_cross & (torch.cdist(x_ligand, x_pocket) <= self.edge_cutoff_i)
745
+
746
+ adj = torch.cat((torch.cat((adj_ligand, adj_cross), dim=1),
747
+ torch.cat((adj_cross.T, adj_pocket), dim=1)), dim=0)
748
+
749
+ if not self_edges:
750
+ adj = adj ^ torch.eye(*adj.size(), out=torch.empty_like(adj))
751
+
752
+ # # ensure that edge definition is consistent if bonds are provided (for loss computation)
753
+ # if bond_inds_ligand is not None:
754
+ # # remove ligand edges
755
+ # adj[:adj_ligand.size(0), :adj_ligand.size(1)] = False
756
+ # edges = torch.stack(torch.where(adj), dim=0)
757
+ # # add ligand edges back with original definition
758
+ # edges = torch.cat([bond_inds_ligand, edges], dim=-1)
759
+ # else:
760
+ # edges = torch.stack(torch.where(adj), dim=0)
761
+
762
+ # Feature matrix
763
+ ligand_nobond_onehot = F.one_hot(torch.tensor(
764
+ self.bond_dict['NOBOND'], device=bond_feat_ligand.device),
765
+ num_classes=self.ligand_bond_encoder[0].in_features)
766
+ ligand_nobond_emb = self.ligand_bond_encoder(
767
+ ligand_nobond_onehot.to(FLOAT_TYPE))
768
+ feat_ligand = ligand_nobond_emb.repeat(*adj_ligand.shape, 1)
769
+ feat_ligand[bond_inds_ligand[0], bond_inds_ligand[1]] = bond_feat_ligand
770
+
771
+ if len(adj_pocket) > 0:
772
+ pocket_nobond_onehot = F.one_hot(torch.tensor(
773
+ self.pocket_bond_dict['NOBOND'], device=bond_feat_pocket.device),
774
+ num_classes=self.pocket_bond_nf)
775
+ pocket_nobond_emb = self.pocket_bond_encoder(
776
+ pocket_nobond_onehot.to(FLOAT_TYPE))
777
+ feat_pocket = pocket_nobond_emb.repeat(*adj_pocket.shape, 1)
778
+ feat_pocket[bond_inds_pocket[0], bond_inds_pocket[1]] = bond_feat_pocket
779
+
780
+ feat_cross = self.cross_emb.repeat(*adj_cross.shape, 1)
781
+
782
+ feats = torch.cat((torch.cat((feat_ligand, feat_cross), dim=1),
783
+ torch.cat((feat_cross.transpose(0, 1), feat_pocket), dim=1)), dim=0)
784
+ else:
785
+ feats = feat_ligand
786
+
787
+ # Return results
788
+ edges = torch.stack(torch.where(adj), dim=0)
789
+ edge_feat = feats[edges[0], edges[1]]
790
+
791
+ return edges, edge_feat
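
The core of `get_edges` is a batched radius graph; a minimal standalone sketch of that step (names illustrative):

import torch

def radius_edges(x, batch_mask, cutoff):
    adj = batch_mask[:, None] == batch_mask[None, :]   # same-graph pairs only
    adj = adj & (torch.cdist(x, x) <= cutoff)          # within the distance cutoff
    adj.fill_diagonal_(False)                          # no self-edges
    return torch.stack(torch.where(adj), dim=0)        # (2, n_edges), both directions

x = torch.randn(6, 3)
batch_mask = torch.tensor([0, 0, 0, 1, 1, 1])
print(radius_edges(x, batch_mask, cutoff=5.0).shape)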
src/model/dynamics_hetero.py ADDED
@@ -0,0 +1,1008 @@
1
+ from collections.abc import Iterable
2
+ from collections import defaultdict
3
+ from functools import partial
4
+ import functools
5
+ import warnings
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ from torch_scatter import scatter_mean
11
+ from torch_geometric.nn import MessagePassing
12
+ from torch_geometric.nn.module_dict import ModuleDict
13
+ from torch_geometric.utils.hetero import check_add_self_loops
14
+ try:
15
+ from torch_geometric.nn.conv.hgt_conv import group
16
+ except ImportError as e:
17
+ from torch_geometric.nn.conv.hetero_conv import group
18
+
19
+ from src.model.dynamics import DynamicsBase
20
+ from src.model import gvp
21
+ from src.model.gvp import GVP, _rbf, _normalize, tuple_index, tuple_sum, _split, tuple_cat, _merge
22
+
23
+
24
+ class MyModuleDict(nn.ModuleDict):
25
+ def __init__(self, modules):
26
+ # a mapping (dictionary) of (string: module) or an iterable of key-value pairs of type (string, module)
27
+ if isinstance(modules, dict):
28
+ super().__init__({str(k): v for k, v in modules.items()})
29
+ else:
30
+ raise NotImplementedError
31
+
32
+ def __getitem__(self, key):
33
+ return super().__getitem__(str(key))
34
+
35
+ def __setitem__(self, key, value):
36
+ super().__setitem__(str(key), value)
37
+
38
+ def __delitem__(self, key):
39
+ super().__delitem__(str(key))
40
+
41
+
42
+ class MyHeteroConv(nn.Module):
43
+ """
44
+ Implementation from PyG 2.2.0 with minor changes.
45
+ Override forward pass to control the final aggregation
46
+ Ref.: https://pytorch-geometric.readthedocs.io/en/2.2.0/_modules/torch_geometric/nn/conv/hetero_conv.html
47
+ """
48
+ def __init__(self, convs, aggr="sum"):
49
+ self.vo = {}
50
+ for k, module in convs.items():
51
+ dst = k[-1]
52
+ if dst not in self.vo:
53
+ self.vo[dst] = module.vo
54
+ else:
55
+ assert self.vo[dst] == module.vo
56
+
57
+ # from the original implementation in PyTorch Geometric
58
+ super().__init__()
59
+
60
+ for edge_type, module in convs.items():
61
+ check_add_self_loops(module, [edge_type])
62
+
63
+ src_node_types = set([key[0] for key in convs.keys()])
64
+ dst_node_types = set([key[-1] for key in convs.keys()])
65
+ if len(src_node_types - dst_node_types) > 0:
66
+ warnings.warn(
67
+ f"There exist node types ({src_node_types - dst_node_types}) "
68
+ f"whose representations do not get updated during message "
69
+ f"passing as they do not occur as destination type in any "
70
+ f"edge type. This may lead to unexpected behaviour.")
71
+
72
+ self.convs = ModuleDict({'__'.join(k): v for k, v in convs.items()})
73
+ self.aggr = aggr
74
+
75
+ def reset_parameters(self):
76
+ for conv in self.convs.values():
77
+ conv.reset_parameters()
78
+
79
+ def __repr__(self) -> str:
80
+ return f'{self.__class__.__name__}(num_relations={len(self.convs)})'
81
+
82
+ def forward(
83
+ self,
84
+ x_dict,
85
+ edge_index_dict,
86
+ *args_dict,
87
+ **kwargs_dict,
88
+ ):
89
+ r"""
90
+ Args:
91
+ x_dict (Dict[str, Tensor]): A dictionary holding node feature
92
+ information for each individual node type.
93
+ edge_index_dict (Dict[Tuple[str, str, str], Tensor]): A dictionary
94
+ holding graph connectivity information for each individual
95
+ edge type.
96
+ *args_dict (optional): Additional forward arguments of individual
97
+ :class:`torch_geometric.nn.conv.MessagePassing` layers.
98
+ **kwargs_dict (optional): Additional forward arguments of
99
+ individual :class:`torch_geometric.nn.conv.MessagePassing`
100
+ layers.
101
+ For example, if a specific GNN layer at edge type
102
+ :obj:`edge_type` expects edge attributes :obj:`edge_attr` as a
103
+ forward argument, then you can pass them to
104
+ :meth:`~torch_geometric.nn.conv.HeteroConv.forward` via
105
+ :obj:`edge_attr_dict = { edge_type: edge_attr }`.
106
+ """
107
+ out_dict = defaultdict(list)
108
+ out_dict_edge = {}
109
+ for edge_type, edge_index in edge_index_dict.items():
110
+ src, rel, dst = edge_type
111
+
112
+ str_edge_type = '__'.join(edge_type)
113
+ if str_edge_type not in self.convs:
114
+ continue
115
+
116
+ args = []
117
+ for value_dict in args_dict:
118
+ if edge_type in value_dict:
119
+ args.append(value_dict[edge_type])
120
+ elif src == dst and src in value_dict:
121
+ args.append(value_dict[src])
122
+ elif src in value_dict or dst in value_dict:
123
+ args.append(
124
+ (value_dict.get(src, None), value_dict.get(dst, None)))
125
+
126
+ kwargs = {}
127
+ for arg, value_dict in kwargs_dict.items():
128
+ arg = arg[:-5] # `{*}_dict`
129
+ if edge_type in value_dict:
130
+ kwargs[arg] = value_dict[edge_type]
131
+ elif src == dst and src in value_dict:
132
+ kwargs[arg] = value_dict[src]
133
+ elif src in value_dict or dst in value_dict:
134
+ kwargs[arg] = (value_dict.get(src, None),
135
+ value_dict.get(dst, None))
136
+
137
+ conv = self.convs[str_edge_type]
138
+
139
+ if src == dst:
140
+ out = conv(x_dict[src], edge_index, *args, **kwargs)
141
+ else:
142
+ out = conv((x_dict[src], x_dict[dst]), edge_index, *args,
143
+ **kwargs)
144
+
145
+ if isinstance(out, (tuple, list)):
146
+ out, out_edge = out
147
+ out_dict_edge[edge_type] = out_edge
148
+
149
+ out_dict[dst].append(out)
150
+
151
+ for key, value in out_dict.items():
152
+ out_dict[key] = group(value, self.aggr)
153
+ out_dict[key] = _split(out_dict[key], self.vo[key])
154
+
155
+ return (out_dict, out_dict_edge) if len(out_dict_edge) > 0 else out_dict
156
+
157
+
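
Per destination node type, `forward` sums the messages produced by all incoming edge types (`aggr='sum'`); schematically, PyG's `group` does the following on the merged tensors:

import torch

def group_sum(msgs):
    # one aggregated tensor per destination node type
    return torch.stack(msgs, dim=0).sum(0)

m1 = torch.randn(5, 8)  # e.g. output of ('ligand', '', 'pocket')
m2 = torch.randn(5, 8)  # e.g. output of ('pocket', '', 'pocket')
print(group_sum([m1, m2]).shape)  # (5, 8)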
158
+ class GVPHeteroConv(MessagePassing):
159
+ '''
160
+ Graph convolution / message passing with Geometric Vector Perceptrons.
161
+ Takes in a graph with node and edge embeddings,
162
+ and returns new node embeddings.
163
+
164
+ This does NOT do residual updates and pointwise feedforward layers
165
+ ---see `GVPConvLayer`.
166
+
167
+ :param in_dims: input node embedding dimensions (n_scalar, n_vector)
168
+ :param out_dims: output node embedding dimensions (n_scalar, n_vector)
169
+ :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
170
+ :param n_layers: number of GVPs in the message function
171
+ :param module_list: preconstructed message function, overrides n_layers
172
+ :param aggr: should be "add" if some incoming edges are masked, as in
173
+ a masked autoregressive decoder architecture, otherwise "mean"
174
+ :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
175
+ :param vector_gate: whether to use vector gating.
176
+ (vector_act will be used as sigma^+ in vector gating if `True`)
177
+ :param update_edge_attr: whether to compute an updated edge representation
178
+ '''
179
+
180
+ def __init__(self, in_dims, out_dims, edge_dims, in_dims_other=None,
181
+ n_layers=3, module_list=None, aggr="mean",
182
+ activations=(F.relu, torch.sigmoid), vector_gate=False,
183
+ update_edge_attr=False):
184
+ super(GVPHeteroConv, self).__init__(aggr=aggr)
185
+
186
+ if in_dims_other is None:
187
+ in_dims_other = in_dims
188
+
189
+ self.si, self.vi = in_dims
190
+ self.si_other, self.vi_other = in_dims_other
191
+ self.so, self.vo = out_dims
192
+ self.se, self.ve = edge_dims
193
+ self.update_edge_attr = update_edge_attr
194
+
195
+ GVP_ = functools.partial(GVP,
196
+ activations=activations,
197
+ vector_gate=vector_gate)
198
+
199
+ def get_modules(module_list, out_dims):
200
+ module_list = module_list or []
201
+ if not module_list:
202
+ if n_layers == 1:
203
+ module_list.append(
204
+ GVP_((self.si + self.si_other + self.se, self.vi + self.vi_other + self.ve),
205
+ (self.so, self.vo), activations=(None, None)))
206
+ else:
207
+ module_list.append(
208
+ GVP_((self.si + self.si_other + self.se, self.vi + self.vi_other + self.ve),
209
+ out_dims)
210
+ )
211
+ for i in range(n_layers - 2):
212
+ module_list.append(GVP_(out_dims, out_dims))
213
+ module_list.append(GVP_(out_dims, out_dims,
214
+ activations=(None, None)))
215
+ return nn.Sequential(*module_list)
216
+
217
+ self.message_func = get_modules(module_list, out_dims)
218
+ self.edge_func = get_modules(module_list, edge_dims) if self.update_edge_attr else None
219
+
220
+ def forward(self, x, edge_index, edge_attr):
221
+ '''
222
+ :param x: tuple (s, V) of `torch.Tensor`, or a pair of such tuples for bipartite (src, dst) inputs
223
+ :param edge_index: array of shape [2, n_edges]
224
+ :param edge_attr: tuple (s, V) of `torch.Tensor`
225
+ '''
226
+ elem_0, elem_1 = x
227
+ if isinstance(elem_0, (tuple, list)):
228
+ assert isinstance(elem_1, (tuple, list))
229
+ x_s = (elem_0[0], elem_1[0])
230
+ x_v = (elem_0[1].reshape(elem_0[1].shape[0], 3 * elem_0[1].shape[1]),
231
+ elem_1[1].reshape(elem_1[1].shape[0], 3 * elem_1[1].shape[1]))
232
+ else:
233
+ x_s, x_v = elem_0, elem_1
234
+ x_v = x_v.reshape(x_v.shape[0], 3 * x_v.shape[1])
235
+
236
+ message = self.propagate(edge_index, s=x_s, v=x_v, edge_attr=edge_attr)
237
+
238
+ if self.update_edge_attr:
239
+ if isinstance(x_s, (tuple, list)):
240
+ s_i, s_j = x_s[1][edge_index[1]], x_s[0][edge_index[0]]
241
+ else:
242
+ s_i, s_j = x_s[edge_index[1]], x_s[edge_index[0]]
243
+
244
+ if isinstance(x_v, (tuple, list)):
245
+ v_i, v_j = x_v[1][edge_index[1]], x_v[0][edge_index[0]]
246
+ else:
247
+ v_i, v_j = x_v[edge_index[1]], x_v[edge_index[0]]
248
+
249
+ edge_out = self.edge_attr(s_i, v_i, s_j, v_j, edge_attr)
250
+ # return _split(message, self.vo), edge_out
251
+ return message, edge_out
252
+ else:
253
+ # return _split(message, self.vo)
254
+ return message
255
+
256
+ def message(self, s_i, v_i, s_j, v_j, edge_attr):
257
+ v_j = v_j.view(v_j.shape[0], v_j.shape[1] // 3, 3)
258
+ v_i = v_i.view(v_i.shape[0], v_i.shape[1] // 3, 3)
259
+ message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
260
+ message = self.message_func(message)
261
+ return _merge(*message)
262
+
263
+ def edge_attr(self, s_i, v_i, s_j, v_j, edge_attr):
264
+ v_j = v_j.view(v_j.shape[0], v_j.shape[1] // 3, 3)
265
+ v_i = v_i.view(v_i.shape[0], v_i.shape[1] // 3, 3)
266
+ message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
267
+ return self.edge_func(message)
268
+
269
+
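
Note the flatten/restore trick in `forward`/`message` above: PyG's propagate() only routes 2-D node tensors, so vector channels of shape (n, v, 3) are flattened to (n, 3*v) before message passing and reshaped back inside the message function:

import torch

v = torch.randn(10, 4, 3)                        # 4 vector channels per node
flat = v.reshape(v.shape[0], 3 * v.shape[1])     # (10, 12), safe to propagate
restored = flat.view(flat.shape[0], flat.shape[1] // 3, 3)
assert torch.equal(v, restored)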
270
+ class GVPHeteroConvLayer(nn.Module):
271
+ """
272
+ Full graph convolution / message passing layer with
273
+ Geometric Vector Perceptrons. Residually updates node embeddings with
274
+ aggregated incoming messages, applies a pointwise feedforward
275
+ network to node embeddings, and returns updated node embeddings.
276
+
277
+ To only compute the aggregated messages, see `GVPConv`.
278
+
279
+ :param conv_dims: dictionary mapping each edge type to (src_dim, dst_dim, edge_dim) or (src_dim, dst_dim, edge_dim, dst_in_dim)
280
+ """
281
+ def __init__(self, conv_dims,
282
+ n_message=3, n_feedforward=2, drop_rate=.1,
283
+ activations=(F.relu, torch.sigmoid), vector_gate=False,
284
+ update_edge_attr=False, ln_vector_weight=False):
285
+
286
+ super(GVPHeteroConvLayer, self).__init__()
287
+ self.update_edge_attr = update_edge_attr
288
+
289
+ gvp_conv = partial(GVPHeteroConv,
290
+ n_layers=n_message,
291
+ aggr="sum",
292
+ activations=activations,
293
+ vector_gate=vector_gate,
294
+ update_edge_attr=update_edge_attr)
295
+
296
+ def get_feedforward(n_dims):
297
+ GVP_ = partial(GVP, activations=activations, vector_gate=vector_gate)
298
+
299
+ ff_func = []
300
+ if n_feedforward == 1:
301
+ ff_func.append(GVP_(n_dims, n_dims, activations=(None, None)))
302
+ else:
303
+ hid_dims = 4 * n_dims[0], 2 * n_dims[1]
304
+ ff_func.append(GVP_(n_dims, hid_dims))
305
+ for i in range(n_feedforward - 2):
306
+ ff_func.append(GVP_(hid_dims, hid_dims))
307
+ ff_func.append(GVP_(hid_dims, n_dims, activations=(None, None)))
308
+ return nn.Sequential(*ff_func)
309
+
310
+ # self.conv = HeteroConv({k: gvp_conv(*dims) for k, dims in conv_dims.items()}, aggr='sum')
311
+ self.conv = MyHeteroConv({k: gvp_conv(*dims) for k, dims in conv_dims.items()}, aggr='sum')
312
+
313
+ node_dims = {k[-1]: dims[1] for k, dims in conv_dims.items()}
314
+ self.norm0 = MyModuleDict({k: gvp.LayerNorm(dims, ln_vector_weight) for k, dims in node_dims.items()})
315
+ self.dropout0 = MyModuleDict({k: gvp.Dropout(drop_rate) for k, dims in node_dims.items()})
316
+ self.ff_func = MyModuleDict({k: get_feedforward(dims) for k, dims in node_dims.items()})
317
+ self.norm1 = MyModuleDict({k: gvp.LayerNorm(dims, ln_vector_weight) for k, dims in node_dims.items()})
318
+ self.dropout1 = MyModuleDict({k: gvp.Dropout(drop_rate) for k, dims in node_dims.items()})
319
+
320
+ if self.update_edge_attr:
321
+ self.edge_norm0 = MyModuleDict({k: gvp.LayerNorm(dims[2], ln_vector_weight) for k, dims in conv_dims.items()})
322
+ self.edge_dropout0 = MyModuleDict({k: gvp.Dropout(drop_rate) for k, dims in conv_dims.items()})
323
+ self.edge_ff = MyModuleDict({k: get_feedforward(dims[2]) for k, dims in conv_dims.items()})
324
+ self.edge_norm1 = MyModuleDict({k: gvp.LayerNorm(dims[2], ln_vector_weight) for k, dims in conv_dims.items()})
325
+ self.edge_dropout1 = MyModuleDict({k: gvp.Dropout(drop_rate) for k, dims in conv_dims.items()})
326
+
327
+ def forward(self, x_dict, edge_index_dict, edge_attr_dict, node_mask_dict=None):
328
+ '''
329
+ :param x: tuple (s, V) of `torch.Tensor`
330
+ :param edge_index: array of shape [2, n_edges]
331
+ :param edge_attr: tuple (s, V) of `torch.Tensor`
332
+ :param node_mask: array of type `bool` to index into the first
333
+ dim of node embeddings (s, V). If not `None`, only
334
+ these nodes will be updated.
335
+ '''
336
+
337
+ dh_dict = self.conv(x_dict, edge_index_dict, edge_attr_dict)
338
+
339
+ if self.update_edge_attr:
340
+ dh_dict, de_dict = dh_dict
341
+
342
+ for k, edge_attr in edge_attr_dict.items():
343
+ de = de_dict[k]
344
+
345
+ edge_attr = self.edge_norm0[k](tuple_sum(edge_attr, self.edge_dropout0[k](de)))
346
+ de = self.edge_ff[k](edge_attr)
347
+ edge_attr = self.edge_norm1[k](tuple_sum(edge_attr, self.edge_dropout1[k](de)))
348
+
349
+ edge_attr_dict[k] = edge_attr
350
+
351
+ for k, x in x_dict.items():
352
+ dh = dh_dict[k]
353
+ node_mask = None if node_mask_dict is None else node_mask_dict[k]
354
+
355
+ if node_mask is not None:
356
+ x_ = x
357
+ x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask)
358
+
359
+ x = self.norm0[k](tuple_sum(x, self.dropout0[k](dh)))
360
+
361
+ dh = self.ff_func[k](x)
362
+ x = self.norm1[k](tuple_sum(x, self.dropout1[k](dh)))
363
+
364
+ if node_mask is not None:
365
+ x_[0][node_mask], x_[1][node_mask] = x[0], x[1]
366
+ x = x_
367
+
368
+ x_dict[k] = x
369
+
370
+ return (x_dict, edge_attr_dict) if self.update_edge_attr else x_dict
371
+
372
+
373
+ class GVPModel(torch.nn.Module):
374
+ """
375
+ GVP-GNN model
376
+ inspired by: https://github.com/drorlab/gvp-pytorch/blob/main/gvp/models.py
377
+ and: https://github.com/drorlab/gvp-pytorch/blob/82af6b22eaf8311c15733117b0071408d24ed877/gvp/atom3d.py#L115
378
+ """
379
+ def __init__(self,
380
+ node_in_dim_ligand, node_in_dim_pocket,
381
+ edge_in_dim_ligand, edge_in_dim_pocket, edge_in_dim_interaction,
382
+ node_h_dim_ligand, node_h_dim_pocket,
383
+ edge_h_dim_ligand, edge_h_dim_pocket, edge_h_dim_interaction,
384
+ node_out_dim_ligand=None, node_out_dim_pocket=None,
385
+ edge_out_dim_ligand=None, edge_out_dim_pocket=None, edge_out_dim_interaction=None,
386
+ num_layers=3, drop_rate=0.1, vector_gate=False, update_edge_attr=False):
387
+
388
+ super(GVPModel, self).__init__()
389
+
390
+ self.update_edge_attr = update_edge_attr
391
+
392
+ self.node_in = nn.ModuleDict({
393
+ 'ligand': GVP(node_in_dim_ligand, node_h_dim_ligand, activations=(None, None), vector_gate=vector_gate),
394
+ 'pocket': GVP(node_in_dim_pocket, node_h_dim_pocket, activations=(None, None), vector_gate=vector_gate),
395
+ })
396
+ # self.edge_in = MyModuleDict({
397
+ # ('ligand', 'ligand'): GVP(edge_in_dim_ligand, edge_h_dim_ligand, activations=(None, None), vector_gate=vector_gate),
398
+ # ('pocket', 'pocket'): GVP(edge_in_dim_pocket, edge_h_dim_pocket, activations=(None, None), vector_gate=vector_gate),
399
+ # ('ligand', 'pocket'): GVP(edge_in_dim_interaction, edge_h_dim_interaction, activations=(None, None), vector_gate=vector_gate),
400
+ # ('pocket', 'ligand'): GVP(edge_in_dim_interaction, edge_h_dim_interaction, activations=(None, None), vector_gate=vector_gate),
401
+ # })
402
+ self.edge_in = MyModuleDict({
403
+ ('ligand', '', 'ligand'): GVP(edge_in_dim_ligand, edge_h_dim_ligand, activations=(None, None), vector_gate=vector_gate),
404
+ ('pocket', '', 'pocket'): GVP(edge_in_dim_pocket, edge_h_dim_pocket, activations=(None, None), vector_gate=vector_gate),
405
+ ('ligand', '', 'pocket'): GVP(edge_in_dim_interaction, edge_h_dim_interaction, activations=(None, None), vector_gate=vector_gate),
406
+ ('pocket', '', 'ligand'): GVP(edge_in_dim_interaction, edge_h_dim_interaction, activations=(None, None), vector_gate=vector_gate),
407
+ })
408
+
409
+ # conv_dims = {
410
+ # ('ligand', 'ligand'): (node_h_dim_ligand, node_h_dim_ligand, edge_h_dim_ligand),
411
+ # ('pocket', 'pocket'): (node_h_dim_pocket, node_h_dim_pocket, edge_h_dim_pocket),
412
+ # ('ligand', 'pocket'): (node_h_dim_ligand, node_h_dim_pocket, edge_h_dim_interaction),
413
+ # ('pocket', 'ligand'): (node_h_dim_pocket, node_h_dim_ligand, edge_h_dim_interaction),
414
+ # }
415
+ conv_dims = {
416
+ ('ligand', '', 'ligand'): (node_h_dim_ligand, node_h_dim_ligand, edge_h_dim_ligand),
417
+ ('pocket', '', 'pocket'): (node_h_dim_pocket, node_h_dim_pocket, edge_h_dim_pocket),
418
+ ('ligand', '', 'pocket'): (node_h_dim_ligand, node_h_dim_pocket, edge_h_dim_interaction, node_h_dim_pocket),
419
+ ('pocket', '', 'ligand'): (node_h_dim_pocket, node_h_dim_ligand, edge_h_dim_interaction, node_h_dim_ligand),
420
+ }
421
+
422
+ self.layers = nn.ModuleList(
423
+ GVPHeteroConvLayer(conv_dims,
424
+ drop_rate=drop_rate,
425
+ update_edge_attr=self.update_edge_attr,
426
+ activations=(F.relu, None),
427
+ vector_gate=vector_gate,
428
+ ln_vector_weight=True)
429
+ for _ in range(num_layers))
430
+
431
+ self.node_out = nn.ModuleDict({
432
+ 'ligand': GVP(node_h_dim_ligand, node_out_dim_ligand, activations=(None, None), vector_gate=vector_gate),
433
+ 'pocket': GVP(node_h_dim_pocket, node_out_dim_pocket, activations=(None, None), vector_gate=vector_gate) if node_out_dim_pocket is not None else None,
434
+ })
435
+ # self.edge_out = MyModuleDict({
436
+ # ('ligand', 'ligand'): GVP(edge_h_dim_ligand, edge_out_dim_ligand, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_ligand is not None else None,
437
+ # ('pocket', 'pocket'): GVP(edge_h_dim_pocket, edge_out_dim_pocket, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_pocket is not None else None,
438
+ # ('ligand', 'pocket'): GVP(edge_h_dim_interaction, edge_out_dim_interaction, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_interaction is not None else None,
439
+ # ('pocket', 'ligand'): GVP(edge_h_dim_interaction, edge_out_dim_interaction, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_interaction is not None else None,
440
+ # })
441
+ self.edge_out = MyModuleDict({
442
+ ('ligand', '', 'ligand'): GVP(edge_h_dim_ligand, edge_out_dim_ligand, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_ligand is not None else None,
443
+ ('pocket', '', 'pocket'): GVP(edge_h_dim_pocket, edge_out_dim_pocket, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_pocket is not None else None,
444
+ ('ligand', '', 'pocket'): GVP(edge_h_dim_interaction, edge_out_dim_interaction, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_interaction is not None else None,
445
+ ('pocket', '', 'ligand'): GVP(edge_h_dim_interaction, edge_out_dim_interaction, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_interaction is not None else None,
446
+ })
447
+
448
+ def forward(self, node_attr, batch_mask, edge_index, edge_attr):
449
+
450
+ # to hidden dimension
451
+ for k in node_attr.keys():
452
+ node_attr[k] = self.node_in[k](node_attr[k])
453
+
454
+ for k in edge_attr.keys():
455
+ edge_attr[k] = self.edge_in[k](edge_attr[k])
456
+
457
+ # convolutions
458
+ for layer in self.layers:
459
+ out = layer(node_attr, edge_index, edge_attr)
460
+ if self.update_edge_attr:
461
+ node_attr, edge_attr = out
462
+ else:
463
+ node_attr = out
464
+
465
+ # to output dimension
466
+ for k in node_attr.keys():
467
+ node_attr[k] = self.node_out[k](node_attr[k]) \
468
+ if self.node_out[k] is not None else None
469
+
470
+ if self.update_edge_attr:
471
+ for k in edge_attr.keys():
472
+ if self.edge_out[k] is not None:
473
+ edge_attr[k] = self.edge_out[k](edge_attr[k])
474
+
475
+ return node_attr, edge_attr
476
+
477
+
478
+ class DynamicsHetero(DynamicsBase):
479
+ def __init__(self, atom_nf, residue_nf, bond_dict, pocket_bond_dict,
480
+ condition_time=True,
481
+ num_rbf_time=None,
482
+ model='gvp',
483
+ model_params=None,
484
+ edge_cutoff_ligand=None,
485
+ edge_cutoff_pocket=None,
486
+ edge_cutoff_interaction=None,
487
+ predict_angles=False,
488
+ predict_frames=False,
489
+ add_cycle_counts=False,
490
+ add_spectral_feat=False,
491
+ add_nma_feat=False,
492
+ reflection_equiv=False,
493
+ d_max=15.0,
494
+ num_rbf_dist=16,
495
+ self_conditioning=False,
496
+ augment_residue_sc=False,
497
+ augment_ligand_sc=False,
498
+ add_chi_as_feature=False,
499
+ angle_act_fn=False,
500
+ add_all_atom_diff=False,
501
+ predict_confidence=False):
502
+
503
+ super().__init__(
504
+ predict_angles=predict_angles,
505
+ predict_frames=predict_frames,
506
+ add_cycle_counts=add_cycle_counts,
507
+ add_spectral_feat=add_spectral_feat,
508
+ self_conditioning=self_conditioning,
509
+ augment_residue_sc=augment_residue_sc,
510
+ augment_ligand_sc=augment_ligand_sc
511
+ )
512
+
513
+ self.model = model
514
+ self.edge_cutoff_l = edge_cutoff_ligand
515
+ self.edge_cutoff_p = edge_cutoff_pocket
516
+ self.edge_cutoff_i = edge_cutoff_interaction
517
+ self.bond_dict = bond_dict
518
+ self.pocket_bond_dict = pocket_bond_dict
519
+ self.bond_nf = len(bond_dict)
520
+ self.pocket_bond_nf = len(pocket_bond_dict)
521
+ # self.edge_dim = edge_dim
522
+ self.add_nma_feat = add_nma_feat
523
+ self.add_chi_as_feature = add_chi_as_feature
524
+ self.add_all_atom_diff = add_all_atom_diff
525
+ self.condition_time = condition_time
526
+ self.predict_confidence = predict_confidence
527
+
528
+ # edge encoding params
529
+ self.reflection_equiv = reflection_equiv
530
+ self.d_max = d_max
531
+ self.num_rbf = num_rbf_dist
532
+
533
+
534
+ # Output dimensions, always tuple (scalar, vector)
535
+ _atom_out = (atom_nf[0], 1) if isinstance(atom_nf, Iterable) else (atom_nf, 1)
536
+ _residue_out = (0, 0)
537
+
538
+ if self.predict_confidence:
539
+ _atom_out = tuple_sum(_atom_out, (1, 0))
540
+
541
+ if self.predict_angles:
542
+ _residue_out = tuple_sum(_residue_out, (5, 0))
543
+
544
+ if self.predict_frames:
545
+ _residue_out = tuple_sum(_residue_out, (3, 1))
546
+
547
+
548
+ # Input dimensions, always tuple (scalar, vector)
549
+ assert isinstance(atom_nf, int), "expected: element onehot"
550
+ _atom_in = (atom_nf, 0)
551
+ assert isinstance(residue_nf, Iterable), "expected: (AA-onehot, vectors to atoms)"
552
+ _residue_in = tuple(residue_nf)
553
+ _residue_atom_dim = residue_nf[1]
554
+
555
+ if self.add_cycle_counts:
556
+ _atom_in = tuple_sum(_atom_in, (3, 0))
557
+ if self.add_spectral_feat:
558
+ _atom_in = tuple_sum(_atom_in, (5, 0))
559
+
560
+ if self.add_nma_feat:
561
+ _residue_in = tuple_sum(_residue_in, (0, 5))
562
+
563
+ if self.add_chi_as_feature:
564
+ _residue_in = tuple_sum(_residue_in, (5, 0))
565
+
566
+ if self.condition_time:
567
+ self.embed_time = num_rbf_time is not None
568
+ self.time_dim = num_rbf_time if self.embed_time else 1
569
+
570
+ _atom_in = tuple_sum(_atom_in, (self.time_dim, 0))
571
+ _residue_in = tuple_sum(_residue_in, (self.time_dim, 0))
572
+ else:
573
+ print('Warning: dynamics model is NOT conditioned on time.')
574
+
575
+ if self.self_conditioning:
576
+ _atom_in = tuple_sum(_atom_in, _atom_out)
577
+ _residue_in = tuple_sum(_residue_in, _residue_out)
578
+
579
+ if self.augment_ligand_sc:
580
+ _atom_in = tuple_sum(_atom_in, (0, 1))
581
+
582
+ if self.augment_residue_sc:
583
+ assert self.predict_angles
584
+ _residue_in = tuple_sum(_residue_in, (0, _residue_atom_dim))
585
+
586
+
587
+ # Edge output dimensions, always tuple (scalar, vector)
588
+ _edge_ligand_out = (self.bond_nf, 0)
589
+ _edge_ligand_before_symmetrization = (model_params.edge_h_dim[0], 0)
590
+
591
+
592
+ # Edge input dimensions, always tuple (scalar, vector)
593
+ _edge_ligand_in = (self.bond_nf + self.num_rbf, 1 if self.reflection_equiv else 2)
594
+ _edge_ligand_in = tuple_sum(_edge_ligand_in, _atom_in) # src node
595
+ _edge_ligand_in = tuple_sum(_edge_ligand_in, _atom_in) # dst node
596
+
597
+ if self_conditioning:
598
+ _edge_ligand_in = tuple_sum(_edge_ligand_in, _edge_ligand_out)
599
+
600
+ _n_dist_residue = _residue_atom_dim ** 2 if self.add_all_atom_diff else 1
601
+ _edge_pocket_in = (_n_dist_residue * self.num_rbf + self.pocket_bond_nf, _n_dist_residue)
602
+ _edge_pocket_in = tuple_sum(_edge_pocket_in, _residue_in) # src node
603
+ _edge_pocket_in = tuple_sum(_edge_pocket_in, _residue_in) # dst node
604
+
605
+ _n_dist_interaction = _residue_atom_dim if self.add_all_atom_diff else 1
606
+ _edge_interaction_in = (_n_dist_interaction * self.num_rbf, _n_dist_interaction)
607
+ _edge_interaction_in = tuple_sum(_edge_interaction_in, _atom_in) # atom node
608
+ _edge_interaction_in = tuple_sum(_edge_interaction_in, _residue_in) # residue node
609
+
610
+
611
+ # Embeddings for newly added edges
612
+ _ligand_nobond_nf = self.bond_nf + _edge_ligand_out[0] if self.self_conditioning else self.bond_nf
613
+ self.ligand_nobond_emb = nn.Parameter(torch.zeros(_ligand_nobond_nf), requires_grad=True)
614
+ self.pocket_nobond_emb = nn.Parameter(torch.zeros(self.pocket_bond_nf), requires_grad=True)
615
+
616
+ # for access in self-conditioning
617
+ self.atom_out_dim = _atom_out
618
+ self.residue_out_dim = _residue_out
619
+ self.edge_out_dim = _edge_ligand_out
620
+
621
+ if model == 'gvp':
622
+
623
+ self.net = GVPModel(
624
+ node_in_dim_ligand=_atom_in,
625
+ node_in_dim_pocket=_residue_in,
626
+ edge_in_dim_ligand=_edge_ligand_in,
627
+ edge_in_dim_pocket=_edge_pocket_in,
628
+ edge_in_dim_interaction=_edge_interaction_in,
629
+ node_h_dim_ligand=model_params.node_h_dim,
630
+ node_h_dim_pocket=model_params.node_h_dim,
631
+ edge_h_dim_ligand=model_params.edge_h_dim,
632
+ edge_h_dim_pocket=model_params.edge_h_dim,
633
+ edge_h_dim_interaction=model_params.edge_h_dim,
634
+ node_out_dim_ligand=_atom_out,
635
+ node_out_dim_pocket=_residue_out,
636
+ edge_out_dim_ligand=_edge_ligand_before_symmetrization,
637
+ edge_out_dim_pocket=None,
638
+ edge_out_dim_interaction=None,
639
+ num_layers=model_params.n_layers,
640
+ drop_rate=model_params.dropout,
641
+ vector_gate=model_params.vector_gate,
642
+ update_edge_attr=True
643
+ )
644
+
645
+ else:
646
+ raise NotImplementedError(f"{model} is not available")
647
+
648
+ assert _edge_ligand_out[1] == 0
649
+ assert _edge_ligand_before_symmetrization[1] == 0
650
+ self.edge_decoder = nn.Sequential(
651
+ nn.Linear(_edge_ligand_before_symmetrization[0], _edge_ligand_before_symmetrization[0]),
652
+ torch.nn.SiLU(),
653
+ nn.Linear(_edge_ligand_before_symmetrization[0], _edge_ligand_out[0])
654
+ )
655
+
656
+ if angle_act_fn is None:
657
+ self.angle_act_fn = None
658
+ elif angle_act_fn == 'tanh':
659
+ self.angle_act_fn = lambda x: np.pi * F.tanh(x)
660
+ else:
661
+ raise NotImplementedError(f"Angle activation {angle_act_fn} not available")
662
+
663
+ def _forward(self, x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand=None,
664
+ h_atoms_sc=None, e_atoms_sc=None, h_residues_sc=None):
665
+ """
666
+ :param x_atoms:
667
+ :param h_atoms:
668
+ :param mask_atoms:
669
+ :param pocket: must contain keys: 'x', 'one_hot', 'mask', 'bonds' and 'bond_one_hot'
670
+ :param t:
671
+ :param bonds_ligand: tuple - bond indices (2, n_bonds) & bond types (n_bonds, bond_nf)
672
+ :param h_atoms_sc: additional node feature for self-conditioning, (s, V)
673
+ :param e_atoms_sc: additional edge feature for self-conditioning, only scalar
674
+ :param h_residues_sc: additional node feature for self-conditioning, tensor or tuple
675
+ :return:
676
+ """
677
+ x_residues, h_residues, mask_residues = pocket['x'], pocket['one_hot'], pocket['mask']
678
+ if 'bonds' in pocket:
679
+ bonds_pocket = (pocket['bonds'], pocket['bond_one_hot'])
680
+ else:
681
+ bonds_pocket = None
682
+
683
+ if self.add_chi_as_feature:
684
+ h_residues = torch.cat([h_residues, pocket['chi'][:, :5]], dim=-1)
685
+
686
+ if 'v' in pocket:
687
+ v_residues = pocket['v']
688
+ if self.add_nma_feat:
689
+ v_residues = torch.cat([v_residues, pocket['nma_vec']], dim=1)
690
+ h_residues = (h_residues, v_residues)
691
+
692
+ # NOTE: 'bond' denotes one-directional edges and 'edge' means bi-directional
693
+ # get graph edges and edge attributes
694
+ if bonds_ligand is not None:
695
+
696
+ ligand_bond_indices = bonds_ligand[0]
697
+
698
+ # make sure messages are passed both ways
699
+ ligand_edge_indices = torch.cat(
700
+ [bonds_ligand[0], bonds_ligand[0].flip(dims=[0])], dim=1)
701
+ ligand_edge_types = torch.cat([bonds_ligand[1], bonds_ligand[1]], dim=0)
702
+ if e_atoms_sc is not None:
703
+ e_atoms_sc = torch.cat([e_atoms_sc, e_atoms_sc], dim=0)
704
+
705
+ # add auxiliary features to ligand nodes
706
+ extra_features = self.compute_extra_features(
707
+ mask_atoms, ligand_edge_indices, ligand_edge_types.argmax(-1))
708
+ h_atoms = torch.cat([h_atoms, extra_features], dim=-1)
709
+
710
+ if bonds_pocket is not None:
711
+ # make sure messages are passed both ways
712
+ pocket_edge_indices = torch.cat(
713
+ [bonds_pocket[0], bonds_pocket[0].flip(dims=[0])], dim=1)
714
+ pocket_edge_types = torch.cat([bonds_pocket[1], bonds_pocket[1]], dim=0)
715
+
716
+
717
+ # Self-conditioning
718
+ if h_atoms_sc is not None:
719
+ h_atoms = (torch.cat([h_atoms, h_atoms_sc[0]], dim=-1), h_atoms_sc[1])
720
+
721
+ if e_atoms_sc is not None:
722
+ ligand_edge_types = torch.cat([ligand_edge_types, e_atoms_sc], dim=-1)
723
+
724
+ if h_residues_sc is not None:
725
+ # if self.augment_residue_sc:
726
+ if isinstance(h_residues_sc, tuple):
727
+ h_residues = (torch.cat([h_residues[0], h_residues_sc[0]], dim=-1),
728
+ torch.cat([h_residues[1], h_residues_sc[1]], dim=1))
729
+ else:
730
+ h_residues = (torch.cat([h_residues[0], h_residues_sc], dim=-1),
731
+ h_residues[1])
732
+
733
+ if self.condition_time:
734
+ if self.embed_time:
735
+ t = _rbf(t.squeeze(-1), D_min=0.0, D_max=1.0, D_count=self.time_dim, device=t.device)
736
+ if isinstance(h_atoms, tuple):
737
+ h_atoms = (torch.cat([h_atoms[0], t[mask_atoms]], dim=1), h_atoms[1])
738
+ else:
739
+ h_atoms = torch.cat([h_atoms, t[mask_atoms]], dim=1)
740
+ h_residues = (torch.cat([h_residues[0], t[mask_residues]], dim=1), h_residues[1])
741
+
742
+ empty_pocket = (len(pocket['x']) == 0)
743
+
744
+ # Process edges and encode in shared feature space
745
+ edge_index_dict, edge_attr_dict = self.get_edges(
746
+ x_atoms, h_atoms, mask_atoms, ligand_edge_indices, ligand_edge_types,
747
+ x_residues, h_residues, mask_residues, pocket['v'], pocket_edge_indices, pocket_edge_types,
748
+ empty_pocket=empty_pocket
749
+ )
750
+
751
+ if not empty_pocket:
752
+ node_attr_dict = {
753
+ 'ligand': h_atoms,
754
+ 'pocket': h_residues,
755
+ }
756
+ batch_mask_dict = {
757
+ 'ligand': mask_atoms,
758
+ 'pocket': mask_residues,
759
+ }
760
+ else:
761
+ node_attr_dict = {'ligand': h_atoms}
762
+ batch_mask_dict = {'ligand': mask_atoms}
763
+
764
+ if self.model == 'gvp' or self.model == 'gvp_transformer':
765
+ out_node_attr, out_edge_attr = self.net(
766
+ node_attr_dict, batch_mask_dict, edge_index_dict, edge_attr_dict)
767
+
768
+ else:
769
+ raise NotImplementedError(f"Wrong model ({self.model})")
770
+
771
+ h_final_atoms = out_node_attr['ligand'][0]
772
+ vel = out_node_attr['ligand'][1].squeeze(-2)
773
+
774
+ if torch.any(torch.isnan(vel)) or torch.any(torch.isnan(h_final_atoms)):
775
+ if self.training:
776
+ vel[torch.isnan(vel)] = 0.0
777
+ h_final_atoms[torch.isnan(h_final_atoms)] = 0.0
778
+ else:
779
+ raise ValueError("NaN detected in network output")
780
+
781
+ # predict edge type
782
+ edge_final = out_edge_attr[('ligand', '', 'ligand')]
783
+ edges = edge_index_dict[('ligand', '', 'ligand')]
784
+
785
+ # Symmetrize
786
+ edge_logits = torch.zeros(
787
+ (len(mask_atoms), len(mask_atoms), edge_final.size(-1)),
788
+ device=mask_atoms.device)
789
+ edge_logits[edges[0], edges[1]] = edge_final
790
+ edge_logits = (edge_logits + edge_logits.transpose(0, 1)) * 0.5
791
+
792
+ # return upper triangular elements only (matching the input)
793
+ edge_logits = edge_logits[ligand_bond_indices[0], ligand_bond_indices[1]]
794
+ # assert (edge_logits == 0).sum() == 0
795
+
796
+ edge_final_atoms = self.edge_decoder(edge_logits)
797
+
798
+ pred_ligand = {'vel': vel, 'logits_e': edge_final_atoms}
799
+
800
+ if self.predict_confidence:
801
+ pred_ligand['logits_h'] = h_final_atoms[:, :-1]
802
+ pred_ligand['uncertainty_vel'] = F.softplus(h_final_atoms[:, -1])
803
+ else:
804
+ pred_ligand['logits_h'] = h_final_atoms
805
+
806
+ pred_residues = {}
807
+
808
+ # Predict torsion angles
809
+ if self.predict_angles and self.predict_frames:
810
+ residue_s, residue_v = out_node_attr['pocket']
811
+ pred_residues['chi'] = residue_s[:, :5]
812
+ pred_residues['rot'] = residue_s[:, 5:]
813
+ pred_residues['trans'] = residue_v.squeeze(1)
814
+
815
+ elif self.predict_frames:
816
+ pred_residues['rot'], pred_residues['trans'] = out_node_attr['pocket']
817
+ pred_residues['trans'] = pred_residues['trans'].squeeze(1)
818
+
819
+ elif self.predict_angles:
820
+ pred_residues['chi'] = out_node_attr['pocket']
821
+
822
+ if self.angle_act_fn is not None and 'chi' in pred_residues:
823
+ pred_residues['chi'] = self.angle_act_fn(pred_residues['chi'])
824
+
825
+ return pred_ligand, pred_residues
826
+
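The edge-type readout in `_forward` above uses a small symmetrization trick: directed edge features are scattered into a dense (N, N, C) tensor, averaged with their transpose, and read back at the one-directional bond indices. A minimal standalone sketch with toy sizes (not part of the module):

```python
import torch

N, C = 4, 3                                        # toy graph: 4 atoms, 3 bond classes
edges = torch.tensor([[0, 1, 2], [1, 2, 3]])       # directed bond indices i -> j
edge_final = torch.randn(edges.shape[1], C)        # per-edge logits from the network

logits = torch.zeros(N, N, C)
logits[edges[0], edges[1]] = edge_final
logits = 0.5 * (logits + logits.transpose(0, 1))   # enforce f(i, j) == f(j, i)

bond_logits = logits[edges[0], edges[1]]           # back to per-bond features
assert torch.allclose(bond_logits, logits[edges[1], edges[0]])
```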
827
+ def get_edges(self, x_ligand, h_ligand, batch_mask_ligand, edges_ligand, edge_feat_ligand,
828
+ x_pocket, h_pocket, batch_mask_pocket, atom_vectors_pocket, edges_pocket, edge_feat_pocket,
829
+ self_edges=False, empty_pocket=False):
830
+
831
+ # Adjacency matrix
832
+ adj_ligand = batch_mask_ligand[:, None] == batch_mask_ligand[None, :]
833
+ adj_pocket = batch_mask_pocket[:, None] == batch_mask_pocket[None, :]
834
+ adj_cross = batch_mask_ligand[:, None] == batch_mask_pocket[None, :]
835
+
836
+ if self.edge_cutoff_l is not None:
837
+ adj_ligand = adj_ligand & (torch.cdist(x_ligand, x_ligand) <= self.edge_cutoff_l)
838
+
839
+ # Add missing bonds if they got removed
840
+ adj_ligand[edges_ligand[0], edges_ligand[1]] = True
841
+
842
+ if not self_edges:
843
+ adj_ligand = adj_ligand ^ torch.eye(*adj_ligand.size(), out=torch.empty_like(adj_ligand))
844
+
845
+ if self.edge_cutoff_p is not None and not empty_pocket:
846
+ adj_pocket = adj_pocket & (torch.cdist(x_pocket, x_pocket) <= self.edge_cutoff_p)
847
+
848
+ # Add missing bonds if they got removed
849
+ adj_pocket[edges_pocket[0], edges_pocket[1]] = True
850
+
851
+ if not self_edges:
852
+ adj_pocket = adj_pocket ^ torch.eye(*adj_pocket.size(), out=torch.empty_like(adj_pocket))
853
+
854
+ if self.edge_cutoff_i is not None and not empty_pocket:
855
+ adj_cross = adj_cross & (torch.cdist(x_ligand, x_pocket) <= self.edge_cutoff_i)
856
+
857
+ # ligand-ligand edge features
858
+ edges_ligand_updated = torch.stack(torch.where(adj_ligand), dim=0)
859
+ feat_ligand = self.ligand_nobond_emb.repeat(*adj_ligand.shape, 1)
860
+ feat_ligand[edges_ligand[0], edges_ligand[1]] = edge_feat_ligand
861
+ feat_ligand = feat_ligand[edges_ligand_updated[0], edges_ligand_updated[1]]
862
+ feat_ligand = self.ligand_edge_features(h_ligand, x_ligand, edges_ligand_updated, batch_mask_ligand, edge_attr=feat_ligand)
863
+
864
+ if not empty_pocket:
865
+ # residue-residue edge features
866
+ edges_pocket_updated = torch.stack(torch.where(adj_pocket), dim=0)
867
+ feat_pocket = self.pocket_nobond_emb.repeat(*adj_pocket.shape, 1)
868
+ feat_pocket[edges_pocket[0], edges_pocket[1]] = edge_feat_pocket
869
+ feat_pocket = feat_pocket[edges_pocket_updated[0], edges_pocket_updated[1]]
870
+ feat_pocket = self.pocket_edge_features(h_pocket, x_pocket, atom_vectors_pocket, edges_pocket_updated, edge_attr=feat_pocket)
871
+
872
+ # ligand-residue edge features
873
+ edges_cross = torch.stack(torch.where(adj_cross), dim=0)
874
+ feat_cross = self.cross_edge_features(h_ligand, x_ligand, h_pocket, x_pocket, atom_vectors_pocket, edges_cross)
875
+
876
+ edge_index = {
877
+ ('ligand', '', 'ligand'): edges_ligand_updated,
878
+ ('pocket', '', 'pocket'): edges_pocket_updated,
879
+ ('ligand', '', 'pocket'): edges_cross,
880
+ ('pocket', '', 'ligand'): edges_cross.flip(dims=[0]),
881
+ }
882
+
883
+ edge_attr = {
884
+ ('ligand', '', 'ligand'): feat_ligand,
885
+ ('pocket', '', 'pocket'): feat_pocket,
886
+ ('ligand', '', 'pocket'): feat_cross,
887
+ ('pocket', '', 'ligand'): feat_cross,
888
+ }
889
+ else:
890
+ edge_index = {('ligand', '', 'ligand'): edges_ligand_updated}
891
+ edge_attr = {('ligand', '', 'ligand'): feat_ligand}
892
+
893
+ return edge_index, edge_attr
894
+
895
+ def ligand_edge_features(self, h, x, edge_index, batch_mask=None, edge_attr=None):
896
+ """
897
+ :param h: (s, V)
898
+ :param x:
899
+ :param edge_index:
900
+ :param batch_mask:
901
+ :param edge_attr:
902
+ :return: scalar and vector-valued edge features
903
+ """
904
+ row, col = edge_index
905
+ coord_diff = x[row] - x[col]
906
+ dist = coord_diff.norm(dim=-1)
907
+ rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf,
908
+ device=x.device)
909
+
910
+ if isinstance(h, tuple):
911
+ edge_s = torch.cat([h[0][row], h[0][col], rbf], dim=1)
912
+ edge_v = torch.cat([h[1][row], h[1][col], _normalize(coord_diff).unsqueeze(-2)], dim=1)
913
+ else:
914
+ edge_s = torch.cat([h[row], h[col], rbf], dim=1)
915
+ edge_v = _normalize(coord_diff).unsqueeze(-2)
916
+
917
+ # edge_s = rbf
918
+ # edge_v = _normalize(coord_diff).unsqueeze(-2)
919
+
920
+ if edge_attr is not None:
921
+ edge_s = torch.cat([edge_s, edge_attr], dim=1)
922
+
923
+ # self.reflection_equiv: bool, use reflection-sensitive feature based on
924
+ # the cross product if False
925
+ if not self.reflection_equiv:
926
+ mean = scatter_mean(x, batch_mask, dim=0,
927
+ dim_size=batch_mask.max() + 1)
928
+ row, col = edge_index
929
+ cross = torch.cross(x[row] - mean[batch_mask[row]],
930
+ x[col] - mean[batch_mask[col]], dim=1)
931
+ cross = _normalize(cross).unsqueeze(-2)
932
+
933
+ edge_v = torch.cat([edge_v, cross], dim=-2)
934
+
935
+ return torch.nan_to_num(edge_s), torch.nan_to_num(edge_v)
936
+
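The cross-product feature added when `reflection_equiv` is `False` is a pseudo-vector, which is what makes the model sensitive to mirror images. A quick standalone check of the sign flip under a reflection (illustration only):

```python
import torch

a, b = torch.randn(3), torch.randn(3)
P = torch.diag(torch.tensor([1.0, 1.0, -1.0]))  # reflection through the z=0 plane

# Under a reflection P, cross products pick up an extra factor det(P) = -1,
# unlike a plain difference vector, which transforms as P (x_i - x_j).
lhs = torch.cross(P @ a, P @ b, dim=0)
rhs = -(P @ torch.cross(a, b, dim=0))
assert torch.allclose(lhs, rhs, atol=1e-6)
```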
937
+ def pocket_edge_features(self, h, x, v, edge_index, edge_attr=None):
938
+ """
939
+ :param h: (s, V)
940
+ :param x:
941
+ :param v:
942
+ :param edge_index:
943
+ :param edge_attr:
944
+ :return: scalar and vector-valued edge features
945
+ """
946
+ row, col = edge_index
947
+
948
+ if self.add_all_atom_diff:
949
+ all_coord = v + x.unsqueeze(1) # (nR, nA, 3)
950
+ coord_diff = all_coord[row, :, None, :] - all_coord[col, None, :, :] # (nB, nA, nA, 3)
951
+ coord_diff = coord_diff.flatten(1, 2)
952
+ dist = coord_diff.norm(dim=-1) # (nB, nA^2)
953
+ rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf, device=x.device) # (nB, nA^2, rbf_dim)
954
+ rbf = rbf.flatten(1, 2)
955
+ coord_diff = _normalize(coord_diff)
956
+ else:
957
+ coord_diff = x[row] - x[col]
958
+ dist = coord_diff.norm(dim=-1)
959
+ rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf, device=x.device)
960
+ coord_diff = _normalize(coord_diff).unsqueeze(-2)
961
+
962
+ edge_s = torch.cat([h[0][row], h[0][col], rbf], dim=1)
963
+ edge_v = torch.cat([h[1][row], h[1][col], coord_diff], dim=1)
964
+ # edge_s = rbf
965
+ # edge_v = coord_diff
966
+
967
+ if edge_attr is not None:
968
+ edge_s = torch.cat([edge_s, edge_attr], dim=1)
969
+
970
+ return torch.nan_to_num(edge_s), torch.nan_to_num(edge_v)
971
+
972
+ def cross_edge_features(self, h_ligand, x_ligand, h_pocket, x_pocket, v_pocket, edge_index):
973
+ """
974
+ :param h_ligand: (s, V)
975
+ :param x_ligand:
976
+ :param h_pocket: (s, V)
977
+ :param x_pocket:
978
+ :param v_pocket:
979
+ :param edge_index: first row indexes into the ligand tensors, second row into the pocket tensors
980
+
981
+ :return: scalar and vector-valued edge features
982
+ """
983
+ ligand_idx, pocket_idx = edge_index
984
+
985
+ if self.add_all_atom_diff:
986
+ all_coord_pocket = v_pocket + x_pocket.unsqueeze(1) # (nR, nA, 3)
987
+ coord_diff = x_ligand[ligand_idx, None, :] - all_coord_pocket[pocket_idx] # (nB, nA, 3)
988
+ dist = coord_diff.norm(dim=-1) # (nB, nA)
989
+ rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf, device=x_ligand.device) # (nB, nA, rbf_dim)
990
+ rbf = rbf.flatten(1, 2)
991
+ coord_diff = _normalize(coord_diff)
992
+ else:
993
+ coord_diff = x_ligand[ligand_idx] - x_pocket[pocket_idx]
994
+ dist = coord_diff.norm(dim=-1) # (nB, nA)
995
+ rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf, device=x_ligand.device)
996
+ coord_diff = _normalize(coord_diff).unsqueeze(-2)
997
+
998
+ if isinstance(h_ligand, tuple):
999
+ edge_s = torch.cat([h_ligand[0][ligand_idx], h_pocket[0][pocket_idx], rbf], dim=1)
1000
+ edge_v = torch.cat([h_ligand[1][ligand_idx], h_pocket[1][pocket_idx], coord_diff], dim=1)
1001
+ else:
1002
+ edge_s = torch.cat([h_ligand[ligand_idx], h_pocket[0][pocket_idx], rbf], dim=1)
1003
+ edge_v = torch.cat([h_pocket[1][pocket_idx], coord_diff], dim=1)
1004
+
1005
+ # edge_s = rbf
1006
+ # edge_v = coord_diff
1007
+
1008
+ return torch.nan_to_num(edge_s), torch.nan_to_num(edge_v)
src/model/flows.py ADDED
@@ -0,0 +1,448 @@
1
+ from abc import ABC
2
+ from abc import abstractmethod
3
+ import math
4
+ import torch
5
+ from torch_scatter import scatter_mean, scatter_add
6
+
7
+ import src.data.so3_utils as so3
8
+
9
+
10
+ class ICFM(ABC):
11
+ """
12
+ Abstract base class for all Independent-coupling CFM classes.
13
+ Defines a common interface.
14
+ Notation:
15
+ - zt is the intermediate representation at time step t \in [0, 1]
16
+ - zs is the noised representation at time step s < t
17
+
18
+ # TODO: add interpolation schedule (not necessarily linear)
19
+ """
20
+ def __init__(self, sigma):
21
+ self.sigma = sigma
22
+
23
+ @abstractmethod
24
+ def sample_zt(self, z0, z1, t, *args, **kwargs):
25
+ """ TODO. """
26
+ pass
27
+
28
+ @abstractmethod
29
+ def sample_zt_given_zs(self, *args, **kwargs):
30
+ """ Perform update, typically using an explicit Euler step. """
31
+ pass
32
+
33
+ @abstractmethod
34
+ def sample_z0(self, *args, **kwargs):
35
+ """ Prior. """
36
+ pass
37
+
38
+ @abstractmethod
39
+ def compute_loss(self, pred, z0, z1, *args, **kwargs):
40
+ """ Compute loss per sample. """
41
+ pass
42
+
43
+
44
+ class CoordICFM(ICFM):
45
+ def __init__(self, sigma):
46
+ self.dim = 3
47
+ self.scale = 2.7
48
+ super().__init__(sigma)
49
+
50
+ def sample_zt(self, z0, z1, t, batch_mask):
51
+ zt = t[batch_mask] * z1 + (1 - t)[batch_mask] * z0
52
+ # zt = self.sigma * z0 + t[batch_mask] * z1 + (1 - t)[batch_mask] * z0 # TODO: do we have to compute Psi?
53
+ return zt
54
+
55
+ def sample_zt_given_zs(self, zs, pred, s, t, batch_mask):
56
+ """ Perform an explicit Euler step. """
57
+ step_size = t - s
58
+ zt = zs + step_size[batch_mask] * self.scale * pred
59
+ return zt
60
+
61
+ def sample_z0(self, com, batch_mask):
62
+ """ Prior. """
63
+ z0 = torch.randn((len(batch_mask), self.dim), device=batch_mask.device)
64
+
65
+ # Move center of mass
66
+ z0 = z0 + com[batch_mask]
67
+
68
+ return z0
69
+
70
+ def reduce_loss(self, loss, batch_mask, reduce):
71
+ assert reduce in {'mean', 'sum', 'none'}
72
+
73
+ if reduce == 'mean':
74
+ loss = scatter_mean(loss / self.dim, batch_mask, dim=0)
75
+ elif reduce == 'sum':
76
+ loss = scatter_add(loss, batch_mask, dim=0)
77
+
78
+ return loss
79
+
80
+ def compute_loss(self, pred, z0, z1, t, batch_mask, reduce='mean'):
81
+ """ Compute loss per sample. """
82
+
83
+ loss = torch.sum((pred - (z1 - z0) / self.scale) ** 2, dim=-1)
84
+
85
+ return self.reduce_loss(loss, batch_mask, reduce)
86
+
87
+ def get_z1_given_zt_and_pred(self, zt, pred, z0, t, batch_mask):
88
+ """ Make a best guess on the final state z1 given the current state and
89
+ the network prediction. """
90
+ # z1 = z0 + pred
91
+ z1 = zt + (1 - t)[batch_mask] * pred
92
+ return z1
93
+
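As a consistency check on `CoordICFM` (a sketch, not part of the module): with an ideal prediction `pred = (z1 - z0) / scale`, which is the regression target of `compute_loss`, one Euler step from s to t lands exactly on the linear interpolant.

```python
import torch

icfm = CoordICFM(sigma=0.0)
batch = torch.zeros(5, dtype=torch.long)              # one sample with 5 points
z0, z1 = torch.randn(5, 3), torch.randn(5, 3)
s, t = torch.tensor([[0.2]]), torch.tensor([[0.7]])

zs = icfm.sample_zt(z0, z1, s, batch)                 # point on the interpolant at s
pred = (z1 - z0) / icfm.scale                         # ideal network output
zt = icfm.sample_zt_given_zs(zs, pred, s, t, batch)
assert torch.allclose(zt, icfm.sample_zt(z0, z1, t, batch), atol=1e-6)
```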
94
+
95
+ class TorusICFM(ICFM):
96
+ """
97
+ Following:
98
+ Chen, Ricky TQ, and Yaron Lipman.
99
+ "Riemannian flow matching on general geometries."
100
+ arXiv preprint arXiv:2302.03660 (2023).
101
+ """
102
+ def __init__(self, sigma, dim, scheduler_args=None):
103
+ super().__init__(sigma)
104
+ self.dim = dim
105
+
106
+ # Scheduler that determines the rate at which the geodesic distance decreases
107
+ scheduler_args = scheduler_args or {}
108
+ scheduler_args["type"] = scheduler_args.get("type", "linear") # default
109
+ scheduler_args["learn_scaled"] = scheduler_args.get("learn_scaled", False) # default
110
+
111
+ # linear scheduler: kappa(t) = 1-t (default)
112
+ if scheduler_args["type"] == "linear":
113
+ # equivalent to: 1 - kappa(t)
114
+ self.flow_scaling = lambda t: t
115
+
116
+ # equivalent to: -1 * d/dt kappa(t)
117
+ self.velocity_scaling = lambda t: torch.ones_like(t)
118
+
119
+ # exponential scheduler: kappa(t) = exp(-c*t)
120
+ elif scheduler_args["type"] == "exponential":
121
+
122
+ self.c = scheduler_args["c"]
123
+ assert self.c > 0
124
+
125
+ # equivalent to: 1 - kappa(t)
126
+ self.flow_scaling = lambda t: 1 - torch.exp(-self.c * t)
127
+
128
+ # equivalent to: -1 * d/dt kappa(t)
129
+ self.velocity_scaling = lambda t: self.c * torch.exp(-self.c * t)
130
+
131
+ # polynomial scheduler: kappa(t) = (1-t)^k
132
+ elif scheduler_args["type"] == "polynomial":
133
+ self.k = scheduler_args["k"]
134
+ assert self.k > 0
135
+
136
+ # equivalent to: 1 - kappa(t)
137
+ self.flow_scaling = lambda t: 1 - (1 - t)**self.k
138
+
139
+ # equivalent to: -1 * d/dt kappa(t)
140
+ self.velocity_scaling = lambda t: self.k * (1 - t)**(self.k - 1)
141
+
142
+ else:
143
+ raise NotImplementedError(f"Scheduler {scheduler_args['type']} not implemented.")
144
+
145
+ kappa_interval = self.flow_scaling(torch.tensor([0.0, 1.0]))
146
+ if kappa_interval[0] != 0.0 or kappa_interval[1] != 1.0:
147
+ print(f"Scheduler should satisfy kappa(0)=1 and kappa(1)=0. Found "
148
+ f"interval {kappa_interval.tolist()} instead.")
149
+
150
+ # determines whether the scaled vector field is learned or the scheduler
151
+ # is post-multiplied
152
+ self.learn_scaled = scheduler_args["learn_scaled"]
153
+
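For intuition, here are the three schedulers above written directly as kappa(t), with the endpoint conditions made explicit (a sketch; the values of c and k are arbitrary illustrations):

```python
import torch

def kappa_linear(t):
    return 1 - t

def kappa_exponential(t, c=2.0):   # c > 0, arbitrary example value
    return torch.exp(-c * t)

def kappa_polynomial(t, k=2.0):    # k > 0, arbitrary example value
    return (1 - t) ** k

t = torch.tensor([0.0, 1.0])
for kappa in (kappa_linear, kappa_exponential, kappa_polynomial):
    print(kappa.__name__, kappa(t).tolist())
# linear and polynomial hit [1, 0] exactly; the exponential only reaches
# exp(-c) > 0 at t=1, which is what the endpoint check above warns about.
```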
154
+ @staticmethod
155
+ def wrap(angle):
156
+ """ Maps angles to range [-\pi, \pi). """
157
+ return ((angle + math.pi) % (2 * math.pi)) - math.pi
158
+
159
+ def exponential_map(self, x, u):
160
+ """
161
+ :param x: point on the manifold
162
+ :param u: point on the tangent space
163
+ """
164
+ return self.wrap(x + u)
165
+
166
+ @staticmethod
167
+ def logarithm_map(x, y):
168
+ """
169
+ :param x, y: points on the manifold
170
+ """
171
+ return torch.atan2(torch.sin(y - x), torch.cos(y - x))
172
+
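A small sanity check on the flat-torus geometry defined by `wrap`, `exponential_map` and `logarithm_map` above: exp_x(log_x(y)) recovers y up to wrapping into [-pi, pi).

```python
import math
import torch

x = torch.tensor([3.0, -2.5])
y = torch.tensor([-3.0, 1.0])

u = torch.atan2(torch.sin(y - x), torch.cos(y - x))     # logarithm_map(x, y)
y_rec = ((x + u + math.pi) % (2 * math.pi)) - math.pi   # exponential_map(x, u)
y_wrapped = ((y + math.pi) % (2 * math.pi)) - math.pi   # wrap(y)
assert torch.allclose(y_rec, y_wrapped, atol=1e-6)
```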
173
+ def sample_zt(self, z0, z1, t, batch_mask):
174
+ """ expressed in terms of exponential and logarithm maps """
175
+
176
+ # apply logarithm map
177
+ # zt_tangent = t[batch_mask] * self.logarithm_map(z0, z1)
178
+ zt_tangent = self.flow_scaling(t)[batch_mask] * self.logarithm_map(z0, z1)
179
+
180
+ # apply exponential map
181
+ return self.exponential_map(z0, zt_tangent)
182
+
183
+ def get_z1_given_zt_and_pred(self, zt, pred, z0, t, batch_mask):
184
+ """ Make a best guess on the final state z1 given the current state and
185
+ the network prediction. """
186
+
187
+ # estimate z1_tangent based on zt and pred only
188
+ if self.learn_scaled:
189
+ pred = pred / torch.clamp(self.velocity_scaling(t), min=1e-3)[batch_mask]
190
+
191
+ z1_tangent = (1 - t)[batch_mask] * pred
192
+
193
+ # exponential map
194
+ return self.exponential_map(zt, z1_tangent)
195
+
196
+ def sample_zt_given_zs(self, zs, pred, s, t, batch_mask):
197
+ """ Perform update, typically using an explicit Euler step. """
198
+
199
+ step_size = t - s
200
+ zt_tangent = step_size[batch_mask] * pred
201
+
202
+ if not self.learn_scaled:
203
+ zt_tangent = self.velocity_scaling(t)[batch_mask] * zt_tangent
204
+
205
+ # exponential map
206
+ return self.exponential_map(zs, zt_tangent)
207
+
208
+ def sample_z0(self, batch_mask):
209
+ """ Prior. """
210
+
211
+ # Uniform distribution
212
+ z0 = torch.rand((len(batch_mask), self.dim), device=batch_mask.device)
213
+
214
+ return 2 * math.pi * z0 - math.pi
215
+
216
+ def compute_loss(self, pred, z0, z1, zt, t, batch_mask, reduce='mean'):
217
+ """ Compute loss per sample. """
218
+ assert reduce in {'mean', 'sum', 'none'}
219
+ mask = ~torch.isnan(z1)
220
+ z1 = torch.nan_to_num(z1, nan=0.0)
221
+
222
+ zt_dot = self.logarithm_map(z0, z1)
223
+ if self.learn_scaled:
224
+ # NOTE: potentially requires output magnitude to vary substantially
225
+ zt_dot = self.velocity_scaling(t)[batch_mask] * zt_dot
226
+ loss = mask * (pred - zt_dot) ** 2
227
+ loss = torch.sum(loss, dim=-1)
228
+
229
+ if reduce == 'mean':
230
+ denom = mask.sum(dim=-1) + 1e-6
231
+ loss = scatter_mean(loss / denom, batch_mask, dim=0)
232
+ elif reduce == 'sum':
233
+ loss = scatter_add(loss, batch_mask, dim=0)
234
+ return loss
235
+
236
+
237
+ class SO3ICFM(ICFM):
238
+ """
239
+ All rotations are assumed to be in axis-angle format.
240
+ Mostly following descriptions from the FoldFlow paper:
241
+ https://openreview.net/forum?id=kJFIH23hXb
242
+
243
+ See also:
244
+ https://geomstats.github.io/_modules/geomstats/geometry/special_orthogonal.html#SpecialOrthogonal
245
+ https://geomstats.github.io/_modules/geomstats/geometry/lie_group.html#LieGroup
246
+ """
247
+ def __init__(self, sigma):
248
+ super().__init__(sigma)
249
+
250
+ def exponential_map(self, base, tangent):
251
+ """
252
+ Args:
253
+ base: base point (rotation vector) on the manifold
254
+ tangent: point in tangent space at identity
255
+ Returns:
256
+ rotation vector on the manifold
257
+ """
258
+ # return so3.exp_not_from_identity(tangent, base_point=base)
259
+ return so3.compose_rotations(base, so3.exp(tangent))
260
+
261
+ def logarithm_map(self, base, r):
262
+ """
263
+ Args:
264
+ base: base point (rotation vector) on the manifold
265
+ r: rotation vector on the manifold
266
+ Return:
267
+ point in tangent space at identity
268
+ """
269
+ # return so3.log_not_from_identity(r, base_point=base)
270
+ return so3.log(so3.compose_rotations(-base, r))
271
+
272
+ def sample_zt(self, z0, z1, t, batch_mask):
273
+ """
274
+ Expressed in terms of exponential and logarithm maps.
275
+ Corresponds to SLERP interpolation: R(t) = R1 exp( t * log(R1^T R2) )
276
+ (see https://lucaballan.altervista.org/pdfs/IK.pdf, slide 16)
277
+ """
278
+
279
+ # apply logarithm map
280
+ zt_tangent = t[batch_mask] * self.logarithm_map(z0, z1)
281
+
282
+ # apply exponential map
283
+ return self.exponential_map(z0, zt_tangent)
284
+
285
+ def get_z1_given_zt_and_pred(self, zt, pred, z0, t, batch_mask):
286
+ """ Make a best guess on the final state z1 given the current state and
287
+ the network prediction. """
288
+
289
+ # estimate z1_tangent based on zt and pred only
290
+ z1_tangent = (1 - t)[batch_mask] * pred
291
+
292
+ # exponential map
293
+ return self.exponential_map(zt, z1_tangent)
294
+
295
+ def sample_zt_given_zs(self, zs, pred, s, t, batch_mask):
296
+ """ Perform update, typically using an explicit Euler step. """
297
+
298
+ # # parallel transport vector field to lie algebra so3 (at identity)
299
+ # # (FoldFlow paper, Algorithm 3, line 8)
300
+ # # TODO: is this correct? is it necessary?
301
+ # pred = so3.compose(so3.inverse(zs), pred)
302
+
303
+ step_size = t - s
304
+ zt_tangent = step_size[batch_mask] * pred
305
+
306
+ # exponential map
307
+ return self.exponential_map(zs, zt_tangent)
308
+
309
+ def sample_z0(self, batch_mask):
310
+ """ Prior. """
311
+ return so3.random_uniform(n_samples=len(batch_mask), device=batch_mask.device)
312
+
313
+ @staticmethod
314
+ def d_R_squared_SO3(rot_vec_1, rot_vec_2):
315
+ """
316
+ Squared Riemannian metric on SO(3).
317
+ Defined as d(R1, R2) = sqrt(0.5) ||log(R1^T R2)||_F
318
+ where R1, R2 are rotation matrices.
319
+
320
+ The following is equivalent if the difference between the rotations is
321
+ expressed as a rotation vector \omega_diff:
322
+ d(r1, r2) = ||\omega_diff||_2
323
+ -----
324
+ With the definition of the Frobenius matrix norm ||A||_F^2 = trace(A^H A):
325
+ d^2(R1, R2) = 1/2 ||log(R1^T R2)||_F^2
326
+ = 1/2 || hat(R_d) ||_F^2
327
+ = 1/2 tr( hat(R_d)^T hat(R_d) )
328
+ = 1/2 * 2 * ||\omega||_2^2
329
+ """
330
+
331
+ # rot_mat_1 = so3.matrix_from_rotation_vector(rot_vec_1)
332
+ # rot_mat_2 = so3.matrix_from_rotation_vector(rot_vec_2)
333
+ # rot_mat_diff = rot_mat_1.transpose(-2, -1) @ rot_mat_2
334
+ # return torch.norm(so3.log(rot_mat_diff, as_skew=True), p='fro', dim=(-2, -1))
335
+
336
+ diff_rot = so3.compose_rotations(-rot_vec_1, rot_vec_2)
337
+ return diff_rot.square().sum(dim=-1)
338
+
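A numerical illustration of the identity in the docstring, using rotations about the z-axis where log(R1^T R2) is available in closed form (pure torch sketch, no dependency on `src.data.so3_utils`):

```python
import math
import torch

def Rz(a):  # rotation by angle a about the z-axis
    c, s = math.cos(a), math.sin(a)
    return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

a, b = 0.3, 1.1
theta = b - a                                  # relative rotation angle
# hat(omega) for omega = (0, 0, theta): the skew-symmetric matrix log(R1^T R2)
hat = torch.tensor([[0.0, -theta, 0.0], [theta, 0.0, 0.0], [0.0, 0.0, 0.0]])

d2_frobenius = 0.5 * (hat ** 2).sum()          # 1/2 ||log(R1^T R2)||_F^2
d2_vector = theta ** 2                         # ||omega||_2^2
assert abs(d2_frobenius.item() - d2_vector) < 1e-6

# sanity check that hat really is the log of the relative rotation
assert torch.allclose(torch.matrix_exp(hat), Rz(a).T @ Rz(b), atol=1e-6)
```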
339
+ def compute_loss(self, pred, z0, z1, zt, t, batch_mask, reduce='mean', eps=5e-2):
340
+ """ Compute loss per sample. """
341
+ assert reduce in {'mean', 'sum', 'none'}
342
+
343
+ zt_dot = self.logarithm_map(zt, z1) / torch.clamp(1 - t, min=eps)[batch_mask]
344
+
345
+ # TODO: do I need this?
346
+ # pred_at_id = self.logarithm_map(zt, pred) / torch.clamp(1 - t, min=eps)[batch_mask]
347
+
348
+ loss = torch.sum((pred - zt_dot)**2, dim=-1) # TODO: is this the right loss in SO3?
349
+ # loss = self.d_R_squared_SO3(zt_dot, pred)
350
+
351
+ if reduce == 'mean':
352
+ loss = scatter_mean(loss, batch_mask, dim=0)
353
+ elif reduce == 'sum':
354
+ loss = scatter_add(loss, batch_mask, dim=0)
355
+
356
+ return loss
357
+
358
+
359
+ #################
360
+ # Predicting z1 #
361
+ #################
362
+
363
+ class CoordICFMPredictFinal(CoordICFM):
364
+ def __init__(self, sigma):
365
+ self.dim = 3
366
+ super().__init__(sigma)
367
+
368
+ def sample_zt_given_zs(self, zs, z1_minus_zs_pred, s, t, batch_mask):
369
+ """ Perform an explicit Euler step. """
370
+
371
+ # step_size = t - s
372
+ # zt = zs + step_size[batch_mask] * z1_minus_zs_pred / (1.0 - s)[batch_mask]
373
+
374
+ # for numerical stability
375
+ step_size = (t - s) / (1.0 - s)
376
+ assert torch.all(step_size <= 1.0)
377
+ # step_size = torch.clamp(step_size, max=1.0)
378
+ zt = zs + step_size[batch_mask] * z1_minus_zs_pred
379
+ return zt
380
+
381
+ def compute_loss(self, z1_minus_zt_pred, z0, z1, t, batch_mask, reduce='mean'):
382
+ """ Compute loss per sample. """
383
+ assert reduce in {'mean', 'sum', 'none'}
384
+ t = torch.clamp(t, max=0.9)
385
+ zt = self.sample_zt(z0, z1, t, batch_mask)
386
+ loss = torch.sum((z1_minus_zt_pred + zt - z1) ** 2, dim=-1) / torch.square(1 - t)[batch_mask].squeeze()
387
+
388
+ if reduce == 'mean':
389
+ loss = scatter_mean(loss / self.dim, batch_mask, dim=0)
390
+ elif reduce == 'sum':
391
+ loss = scatter_add(loss, batch_mask, dim=0)
392
+
393
+ return loss
394
+
395
+ def get_z1_given_zt_and_pred(self, zt, z1_minus_zt_pred, z0, t, batch_mask):
396
+ return z1_minus_zt_pred + zt
397
+
398
+
399
+ class TorusICFMPredictFinal(TorusICFM):
400
+ """
401
+ Following:
402
+ Chen, Ricky TQ, and Yaron Lipman.
403
+ "Riemannian flow matching on general geometries."
404
+ arXiv preprint arXiv:2302.03660 (2023).
405
+ """
406
+ def __init__(self, sigma, dim):
407
+ super().__init__(sigma, dim)
408
+
409
+ def get_z1_given_zt_and_pred(self, zt, z1_tangent_pred, z0, t, batch_mask):
410
+ """ Make a best guess on the final state z1 given the current state and
411
+ the network prediction. """
412
+
413
+ # exponential map
414
+ return self.exponential_map(zt, z1_tangent_pred)
415
+
416
+ def sample_zt_given_zs(self, zs, z1_tangent_pred, s, t, batch_mask):
417
+ """ Perform update, typically using an explicit Euler step. """
418
+
419
+ # step_size = t - s
420
+ # zt_tangent = step_size[batch_mask] * z1_tangent_pred / (1.0 - s)[batch_mask]
421
+
422
+ # for numerical stability
423
+ step_size = (t - s) / (1.0 - s)
424
+ assert torch.all(step_size <= 1.0)
425
+ # step_size = torch.clamp(step_size, max=1.0)
426
+ zt_tangent = step_size[batch_mask] * z1_tangent_pred
427
+
428
+ # exponential map
429
+ return self.exponential_map(zs, zt_tangent)
430
+
431
+ def compute_loss(self, z1_tangent_pred, z0, z1, t, batch_mask, reduce='mean'):
432
+ """ Compute loss per sample. """
433
+ assert reduce in {'mean', 'sum', 'none'}
434
+ zt = self.sample_zt(z0, z1, t, batch_mask)
435
+ t = torch.clamp(t, max=0.9)
436
+
437
+ mask = ~torch.isnan(z1)
438
+ z1 = torch.nan_to_num(z1, nan=0.0)
439
+ loss = mask * (z1_tangent_pred - self.logarithm_map(zt, z1)) ** 2
440
+ loss = torch.sum(loss, dim=-1) / torch.square(1 - t)[batch_mask].squeeze()
441
+
442
+ if reduce == 'mean':
443
+ denom = mask.sum(dim=-1) + 1e-6
444
+ loss = scatter_mean(loss / denom, batch_mask, dim=0)
445
+ elif reduce == 'sum':
446
+ loss = scatter_add(loss, batch_mask, dim=0)
447
+
448
+ return loss
src/model/gvp.py ADDED
@@ -0,0 +1,650 @@
1
+ """
2
+ Geometric Vector Perceptron implementation taken from:
3
+ https://github.com/drorlab/gvp-pytorch/blob/main/gvp/__init__.py
4
+ """
5
+ import copy
6
+ import warnings
7
+
8
+ import torch, functools
9
+ from torch import nn
10
+ import torch.nn.functional as F
11
+ from torch_geometric.nn import MessagePassing
12
+ from torch_scatter import scatter_add, scatter_mean
13
+
14
+
15
+ def tuple_sum(*args):
16
+ '''
17
+ Sums any number of tuples (s, V) elementwise.
18
+ '''
19
+ return tuple(map(sum, zip(*args)))
20
+
21
+
22
+ def tuple_cat(*args, dim=-1):
23
+ '''
24
+ Concatenates any number of tuples (s, V) elementwise.
25
+
26
+ :param dim: dimension along which to concatenate when viewed
27
+ as the `dim` index for the scalar-channel tensors.
28
+ This means that `dim=-1` will be applied as
29
+ `dim=-2` for the vector-channel tensors.
30
+ '''
31
+ dim %= len(args[0][0].shape)
32
+ s_args, v_args = list(zip(*args))
33
+ return torch.cat(s_args, dim=dim), torch.cat(v_args, dim=dim)
34
+
35
+
36
+ def tuple_index(x, idx):
37
+ '''
38
+ Indexes into a tuple (s, V) along the first dimension.
39
+
40
+ :param idx: any object which can be used to index into a `torch.Tensor`
41
+ '''
42
+ return x[0][idx], x[1][idx]
43
+
44
+
45
+ def randn(n, dims, device="cpu"):
46
+ '''
47
+ Returns random tuples (s, V) drawn elementwise from a normal distribution.
48
+
49
+ :param n: number of data points
50
+ :param dims: tuple of dimensions (n_scalar, n_vector)
51
+
52
+ :return: (s, V) with s.shape = (n, n_scalar) and
53
+ V.shape = (n, n_vector, 3)
54
+ '''
55
+ return torch.randn(n, dims[0], device=device), \
56
+ torch.randn(n, dims[1], 3, device=device)
57
+
58
+
59
+ def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True):
60
+ '''
61
+ L2 norm of tensor clamped above a minimum value `eps`.
62
+
63
+ :param sqrt: if `False`, returns the square of the L2 norm
64
+ '''
65
+ out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps)
66
+ return torch.sqrt(out) if sqrt else out
67
+
68
+
69
+ def _split(x, nv):
70
+ '''
71
+ Splits a merged representation of (s, V) back into a tuple.
72
+ Should be used only with `_merge(s, V)` and only if the tuple
73
+ representation cannot be used.
74
+
75
+ :param x: the `torch.Tensor` returned from `_merge`
76
+ :param nv: the number of vector channels in the input to `_merge`
77
+ '''
78
+ v = torch.reshape(x[..., -3 * nv:], x.shape[:-1] + (nv, 3))
79
+ s = x[..., :-3 * nv]
80
+ return s, v
81
+
82
+
83
+ def _merge(s, v):
84
+ '''
85
+ Merges a tuple (s, V) into a single `torch.Tensor`, where the
86
+ vector channels are flattened and appended to the scalar channels.
87
+ Should be used only if the tuple representation cannot be used.
88
+ Use `_split(x, nv)` to reverse.
89
+ '''
90
+ v = torch.reshape(v, v.shape[:-2] + (3 * v.shape[-2],))
91
+ return torch.cat([s, v], -1)
92
+
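Roundtrip sketch for the (s, V) tuple convention used throughout this file: `_split(_merge(s, V), nv)` recovers the original pair.

```python
import torch

s = torch.randn(8, 5)       # 5 scalar channels per node
V = torch.randn(8, 3, 3)    # 3 vector channels per node
x = _merge(s, V)            # shape (8, 5 + 3 * 3)
s2, V2 = _split(x, nv=3)
assert torch.equal(s, s2) and torch.equal(V, V2)
```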
93
+
94
+ class GVP(nn.Module):
95
+ '''
96
+ Geometric Vector Perceptron. See manuscript and README.md
97
+ for more details.
98
+
99
+ :param in_dims: tuple (n_scalar, n_vector)
100
+ :param out_dims: tuple (n_scalar, n_vector)
101
+ :param h_dim: intermediate number of vector channels, optional
102
+ :param activations: tuple of functions (scalar_act, vector_act)
103
+ :param vector_gate: whether to use vector gating.
104
+ (vector_act will be used as sigma^+ in vector gating if `True`)
105
+ '''
106
+
107
+ def __init__(self, in_dims, out_dims, h_dim=None,
108
+ activations=(F.relu, torch.sigmoid), vector_gate=False):
109
+ super(GVP, self).__init__()
110
+ self.si, self.vi = in_dims
111
+ self.so, self.vo = out_dims
112
+ self.vector_gate = vector_gate
113
+ if self.vi:
114
+ self.h_dim = h_dim or max(self.vi, self.vo)
115
+ self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
116
+ self.ws = nn.Linear(self.h_dim + self.si, self.so)
117
+ if self.vo:
118
+ self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
119
+ if self.vector_gate: self.wsv = nn.Linear(self.so, self.vo)
120
+ else:
121
+ self.ws = nn.Linear(self.si, self.so)
122
+
123
+ self.scalar_act, self.vector_act = activations
124
+ self.dummy_param = nn.Parameter(torch.empty(0))
125
+
126
+ def forward(self, x):
127
+ '''
128
+ :param x: tuple (s, V) of `torch.Tensor`,
129
+ or (if vectors_in is 0), a single `torch.Tensor`
130
+ :return: tuple (s, V) of `torch.Tensor`,
131
+ or (if vectors_out is 0), a single `torch.Tensor`
132
+ '''
133
+ if self.vi:
134
+ s, v = x
135
+ v = torch.transpose(v, -1, -2)
136
+ vh = self.wh(v)
137
+ vn = _norm_no_nan(vh, axis=-2)
138
+ s = self.ws(torch.cat([s, vn], -1))
139
+ if self.vo:
140
+ v = self.wv(vh)
141
+ v = torch.transpose(v, -1, -2)
142
+ if self.vector_gate:
143
+ if self.vector_act:
144
+ gate = self.wsv(self.vector_act(s))
145
+ else:
146
+ gate = self.wsv(s)
147
+ v = v * torch.sigmoid(gate).unsqueeze(-1)
148
+ elif self.vector_act:
149
+ v = v * self.vector_act(
150
+ _norm_no_nan(v, axis=-1, keepdims=True))
151
+ else:
152
+ s = self.ws(x)
153
+ if self.vo:
154
+ v = torch.zeros(s.shape[0], self.vo, 3,
155
+ device=self.dummy_param.device)
156
+ if self.scalar_act:
157
+ s = self.scalar_act(s)
158
+
159
+ return (s, v) if self.vo else s
160
+
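Usage sketch: a `GVP` maps an (s, V) tuple to an (s, V) tuple of the requested output dimensions (the dimensions below are arbitrary).

```python
import torch

gvp = GVP(in_dims=(16, 4), out_dims=(32, 8), vector_gate=True)
s, V = torch.randn(10, 16), torch.randn(10, 4, 3)
s_out, V_out = gvp((s, V))
print(s_out.shape, V_out.shape)   # torch.Size([10, 32]) torch.Size([10, 8, 3])
```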
161
+
162
+ class _VDropout(nn.Module):
163
+ '''
164
+ Vector channel dropout where the elements of each
165
+ vector channel are dropped together.
166
+ '''
167
+
168
+ def __init__(self, drop_rate):
169
+ super(_VDropout, self).__init__()
170
+ self.drop_rate = drop_rate
171
+ self.dummy_param = nn.Parameter(torch.empty(0))
172
+
173
+ def forward(self, x):
174
+ '''
175
+ :param x: `torch.Tensor` corresponding to vector channels
176
+ '''
177
+ device = self.dummy_param.device
178
+ if not self.training:
179
+ return x
180
+ mask = torch.bernoulli(
181
+ (1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device)
182
+ ).unsqueeze(-1)
183
+ x = mask * x / (1 - self.drop_rate)
184
+ return x
185
+
186
+
187
+ class Dropout(nn.Module):
188
+ '''
189
+ Combined dropout for tuples (s, V).
190
+ Takes tuples (s, V) as input and as output.
191
+ '''
192
+
193
+ def __init__(self, drop_rate):
194
+ super(Dropout, self).__init__()
195
+ self.sdropout = nn.Dropout(drop_rate)
196
+ self.vdropout = _VDropout(drop_rate)
197
+
198
+ def forward(self, x):
199
+ '''
200
+ :param x: tuple (s, V) of `torch.Tensor`,
201
+ or single `torch.Tensor`
202
+ (will be assumed to be scalar channels)
203
+ '''
204
+ if type(x) is torch.Tensor:
205
+ return self.sdropout(x)
206
+ s, v = x
207
+ return self.sdropout(s), self.vdropout(v)
208
+
209
+
210
+ class LayerNorm(nn.Module):
211
+ '''
212
+ Combined LayerNorm for tuples (s, V).
213
+ Takes tuples (s, V) as input and as output.
214
+ '''
215
+
216
+ def __init__(self, dims, learnable_vector_weight=False):
217
+ super(LayerNorm, self).__init__()
218
+ self.s, self.v = dims
219
+ self.scalar_norm = nn.LayerNorm(self.s)
220
+ self.vector_norm = VectorLayerNorm(self.v, learnable_vector_weight) \
221
+ if self.v > 0 else None
222
+
223
+ def forward(self, x):
224
+ '''
225
+ :param x: tuple (s, V) of `torch.Tensor`,
226
+ or single `torch.Tensor`
227
+ (will be assumed to be scalar channels)
228
+ '''
229
+ if not self.v:
230
+ return self.scalar_norm(x)
231
+ s, v = x
232
+ # vn = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False)
233
+ # vn = torch.sqrt(torch.mean(vn, dim=-2, keepdim=True))
234
+ # return self.scalar_norm(s), v / vn
235
+ return self.scalar_norm(s), self.vector_norm(v)
236
+
237
+
238
+ class VectorLayerNorm(nn.Module):
239
+ """
240
+ Equivariant normalization of vector-valued features inspired by:
241
+ Liao, Yi-Lun, and Tess Smidt.
242
+ "Equiformer: Equivariant graph attention transformer for 3d atomistic graphs."
243
+ arXiv preprint arXiv:2206.11990 (2022).
244
+ Section 4.1, "Layer Normalization"
245
+ """
246
+ def __init__(self, n_channels, learnable_weight=True):
247
+ super(VectorLayerNorm, self).__init__()
248
+ self.gamma = nn.Parameter(torch.ones(1, n_channels, 1)) \
249
+ if learnable_weight else None # (1, c, 1)
250
+
251
+ def forward(self, x):
252
+ """
253
+ Computes LN(x) = ( x / RMS( L2-norm(x) ) ) * gamma
254
+ :param x: input tensor (n, c, 3)
255
+ :return: layer normalized vector feature
256
+ """
257
+ norm2 = _norm_no_nan(x, axis=-1, keepdims=True, sqrt=False) # (n, c, 1)
258
+ rms = torch.sqrt(torch.mean(norm2, dim=-2, keepdim=True)) # (n, 1, 1)
259
+ x = x / rms # (n, c, 3)
260
+ if self.gamma is not None:
261
+ x = x * self.gamma
262
+ return x
263
+
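Because only per-channel norms enter the normalization and gamma scales whole channels, the layer is rotation-equivariant; a quick check (sketch):

```python
import torch

ln = VectorLayerNorm(n_channels=4, learnable_weight=True)
x = torch.randn(6, 4, 3)
R = torch.linalg.qr(torch.randn(3, 3)).Q     # random orthogonal matrix

out_of_rotated = ln(x @ R.T)                 # rotate vectors, then normalize
rotated_output = ln(x) @ R.T                 # normalize, then rotate
assert torch.allclose(out_of_rotated, rotated_output, atol=1e-5)
```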
264
+
265
+ class GVPConv(MessagePassing):
266
+ '''
267
+ Graph convolution / message passing with Geometric Vector Perceptrons.
268
+ Takes in a graph with node and edge embeddings,
269
+ and returns new node embeddings.
270
+
271
+ This does NOT do residual updates and pointwise feedforward layers
272
+ ---see `GVPConvLayer`.
273
+
274
+ :param in_dims: input node embedding dimensions (n_scalar, n_vector)
275
+ :param out_dims: output node embedding dimensions (n_scalar, n_vector)
276
+ :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
277
+ :param n_layers: number of GVPs in the message function
278
+ :param module_list: preconstructed message function, overrides n_layers
279
+ :param aggr: should be "add" if some incoming edges are masked, as in
280
+ a masked autoregressive decoder architecture, otherwise "mean"
281
+ :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
282
+ :param vector_gate: whether to use vector gating.
283
+ (vector_act will be used as sigma^+ in vector gating if `True`)
284
+ :param update_edge_attr: whether to compute an updated edge representation
285
+ '''
286
+
287
+ def __init__(self, in_dims, out_dims, edge_dims,
288
+ n_layers=3, module_list=None, aggr="mean",
289
+ activations=(F.relu, torch.sigmoid), vector_gate=False,
290
+ update_edge_attr=False):
291
+ super(GVPConv, self).__init__(aggr=aggr)
292
+ self.si, self.vi = in_dims
293
+ self.so, self.vo = out_dims
294
+ self.se, self.ve = edge_dims
295
+ self.update_edge_attr = update_edge_attr
296
+
297
+ GVP_ = functools.partial(GVP,
298
+ activations=activations,
299
+ vector_gate=vector_gate)
300
+
301
+ module_list = module_list or []
302
+ if not module_list:
303
+ if n_layers == 1:
304
+ module_list.append(
305
+ GVP_((2 * self.si + self.se, 2 * self.vi + self.ve),
306
+ (self.so, self.vo), activations=(None, None)))
307
+ else:
308
+ module_list.append(
309
+ GVP_((2 * self.si + self.se, 2 * self.vi + self.ve),
310
+ out_dims)
311
+ )
312
+ for i in range(n_layers - 2):
313
+ module_list.append(GVP_(out_dims, out_dims))
314
+ module_list.append(GVP_(out_dims, out_dims,
315
+ activations=(None, None)))
316
+ self.message_func = nn.Sequential(*module_list)
317
+
318
+ self.edge_func = copy.deepcopy(self.message_func) \
319
+ if self.update_edge_attr else None
320
+
321
+ def forward(self, x, edge_index, edge_attr):
322
+ '''
323
+ :param x: tuple (s, V) of `torch.Tensor`
324
+ :param edge_index: array of shape [2, n_edges]
325
+ :param edge_attr: tuple (s, V) of `torch.Tensor`
326
+ '''
327
+ x_s, x_v = x
328
+ message = self.propagate(edge_index,
329
+ s=x_s,
330
+ v=x_v.reshape(x_v.shape[0], 3 * x_v.shape[1]),
331
+ edge_attr=edge_attr)
332
+
333
+ if self.update_edge_attr:
334
+ s_i, s_j = x_s[edge_index[0]], x_s[edge_index[1]]
335
+ x_v = x_v.reshape(x_v.shape[0], 3 * x_v.shape[1])
336
+ v_i, v_j = x_v[edge_index[0]], x_v[edge_index[1]]
337
+
338
+ edge_out = self.edge_attr(s_i, v_i, s_j, v_j, edge_attr)
339
+ return _split(message, self.vo), edge_out
340
+ else:
341
+ return _split(message, self.vo)
342
+
343
+ def message(self, s_i, v_i, s_j, v_j, edge_attr):
344
+ v_j = v_j.view(v_j.shape[0], v_j.shape[1] // 3, 3)
345
+ v_i = v_i.view(v_i.shape[0], v_i.shape[1] // 3, 3)
346
+ message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
347
+ message = self.message_func(message)
348
+ return _merge(*message)
349
+
350
+ def edge_attr(self, s_i, v_i, s_j, v_j, edge_attr):
351
+ v_j = v_j.view(v_j.shape[0], v_j.shape[1] // 3, 3)
352
+ v_i = v_i.view(v_i.shape[0], v_i.shape[1] // 3, 3)
353
+ message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
354
+ return self.edge_func(message)
355
+
356
+
357
+ class GVPConvLayer(nn.Module):
358
+ '''
359
+ Full graph convolution / message passing layer with
360
+ Geometric Vector Perceptrons. Residually updates node embeddings with
361
+ aggregated incoming messages, applies a pointwise feedforward
362
+ network to node embeddings, and returns updated node embeddings.
363
+
364
+ To only compute the aggregated messages, see `GVPConv`.
365
+
366
+ :param node_dims: node embedding dimensions (n_scalar, n_vector)
367
+ :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
368
+ :param n_message: number of GVPs to use in message function
369
+ :param n_feedforward: number of GVPs to use in feedforward function
370
+ :param drop_rate: drop probability in all dropout layers
371
+ :param autoregressive: if `True`, this `GVPConvLayer` will be used
372
+ with a different set of input node embeddings for messages
373
+ where src >= dst
374
+ :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
375
+ :param vector_gate: whether to use vector gating.
376
+ (vector_act will be used as sigma^+ in vector gating if `True`)
377
+ :param update_edge_attr: whether to compute an updated edge representation
378
+ :param ln_vector_weight: whether to include a learnable weight in the vector
379
+ layer norm
380
+ '''
381
+
382
+ def __init__(self, node_dims, edge_dims,
383
+ n_message=3, n_feedforward=2, drop_rate=.1,
384
+ autoregressive=False,
385
+ activations=(F.relu, torch.sigmoid), vector_gate=False,
386
+ update_edge_attr=False, ln_vector_weight=False):
387
+
388
+ super(GVPConvLayer, self).__init__()
389
+ assert not (update_edge_attr and autoregressive), "Not implemented"
390
+ self.update_edge_attr = update_edge_attr
391
+ self.conv = GVPConv(node_dims, node_dims, edge_dims, n_message,
392
+ aggr="add" if autoregressive else "mean",
393
+ activations=activations, vector_gate=vector_gate,
394
+ update_edge_attr=update_edge_attr)
395
+ GVP_ = functools.partial(GVP,
396
+ activations=activations,
397
+ vector_gate=vector_gate)
398
+ self.norm = nn.ModuleList([LayerNorm(node_dims, ln_vector_weight)
399
+ for _ in range(2)])
400
+ self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])
401
+
402
+ def get_feedforward(n_dims):
403
+ ff_func = []
404
+ if n_feedforward == 1:
405
+ ff_func.append(GVP_(n_dims, n_dims, activations=(None, None)))
406
+ else:
407
+ hid_dims = 4 * n_dims[0], 2 * n_dims[1]
408
+ ff_func.append(GVP_(n_dims, hid_dims))
409
+ for i in range(n_feedforward - 2):
410
+ ff_func.append(GVP_(hid_dims, hid_dims))
411
+ ff_func.append(GVP_(hid_dims, n_dims, activations=(None, None)))
412
+ return nn.Sequential(*ff_func)
413
+
414
+ self.ff_func = get_feedforward(node_dims)
415
+
416
+ if self.update_edge_attr:
417
+ self.edge_norm = nn.ModuleList([LayerNorm(edge_dims, ln_vector_weight)
418
+ for _ in range(2)])
419
+ self.edge_dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])
420
+ self.edge_ff = get_feedforward(edge_dims)
421
+
422
+ def forward(self, x, edge_index, edge_attr,
423
+ autoregressive_x=None, node_mask=None):
424
+ '''
425
+ :param x: tuple (s, V) of `torch.Tensor`
426
+ :param edge_index: array of shape [2, n_edges]
427
+ :param edge_attr: tuple (s, V) of `torch.Tensor`
428
+ :param autoregressive_x: tuple (s, V) of `torch.Tensor`.
429
+ If not `None`, will be used as src node embeddings
430
+ for forming messages where src >= dst. The current node
431
+ embeddings `x` will still be the base of the update and the
432
+ pointwise feedforward.
433
+ :param node_mask: array of type `bool` to index into the first
434
+ dim of node embeddings (s, V). If not `None`, only
435
+ these nodes will be updated.
436
+ '''
437
+
438
+ if autoregressive_x is not None:
439
+ src, dst = edge_index
440
+ mask = src < dst
441
+ edge_index_forward = edge_index[:, mask]
442
+ edge_index_backward = edge_index[:, ~mask]
443
+ edge_attr_forward = tuple_index(edge_attr, mask)
444
+ edge_attr_backward = tuple_index(edge_attr, ~mask)
445
+
446
+ dh = tuple_sum(
447
+ self.conv(x, edge_index_forward, edge_attr_forward),
448
+ self.conv(autoregressive_x, edge_index_backward,
449
+ edge_attr_backward)
450
+ )
451
+
452
+ count = scatter_add(torch.ones_like(dst), dst,
453
+ dim_size=dh[0].size(0)).clamp(min=1).unsqueeze(
454
+ -1)
455
+
456
+ dh = dh[0] / count, dh[1] / count.unsqueeze(-1)
457
+
458
+ else:
459
+ dh = self.conv(x, edge_index, edge_attr)
460
+
461
+ if self.update_edge_attr:
462
+ dh, de = dh
463
+ edge_attr = self.edge_norm[0](tuple_sum(edge_attr, self.dropout[0](de)))
464
+ de = self.edge_ff(edge_attr)
465
+ edge_attr = self.edge_norm[1](tuple_sum(edge_attr, self.dropout[1](de)))
466
+
467
+ if node_mask is not None:
468
+ x_ = x
469
+ x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask)
470
+
471
+ x = self.norm[0](tuple_sum(x, self.dropout[0](dh)))
472
+
473
+ dh = self.ff_func(x)
474
+ x = self.norm[1](tuple_sum(x, self.dropout[1](dh)))
475
+
476
+ if node_mask is not None:
477
+ x_[0][node_mask], x_[1][node_mask] = x[0], x[1]
478
+ x = x_
479
+ return (x, edge_attr) if self.update_edge_attr else x
480
+
481
+
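The autoregressive branch above aggregates with `aggr="add"` and then rescales each node's summed message by its in-degree. A minimal standalone sketch of that normalization (assuming `torch_scatter` is installed):

```python
import torch
from torch_scatter import scatter_add

# Summed messages are divided by the number of incoming edges per node,
# clamped to 1 so that isolated nodes do not divide by zero.
dst = torch.tensor([0, 0, 2])  # edge destinations in a 4-node graph
count = scatter_add(torch.ones_like(dst), dst, dim_size=4).clamp(min=1)
print(count)  # tensor([2, 1, 1, 1]): nodes 1 and 3 have no incoming edges
```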
482
+ ################################################################################
483
+ def _normalize(tensor, dim=-1, eps=1e-8):
484
+ '''
485
+ Normalizes a `torch.Tensor` along dimension `dim` without `nan`s.
486
+ '''
487
+ return torch.nan_to_num(
488
+ torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True) + eps))
489
+
490
+
491
+ def _rbf(D, D_min=0., D_max=20., D_count=16, device='cpu'):
492
+ '''
493
+ From https://github.com/jingraham/neurips19-graph-protein-design
494
+
495
+ Returns an RBF embedding of `torch.Tensor` `D` along a new axis=-1.
496
+ That is, if `D` has shape [...dims], then the returned tensor will have
497
+ shape [...dims, D_count].
498
+ '''
499
+ D_mu = torch.linspace(D_min, D_max, D_count, device=device)
500
+ D_mu = D_mu.view([1, -1])
501
+ D_sigma = (D_max - D_min) / D_count
502
+ D_expand = torch.unsqueeze(D, -1)
503
+
504
+ RBF = torch.exp(-((D_expand - D_mu) / D_sigma) ** 2)
505
+ return RBF
506
+
507
+
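As a quick sanity check of the shape contract stated in the `_rbf` docstring, a minimal sketch (assuming `_rbf` from this module is in scope):

```python
import torch

# Each distance is expanded into D_count Gaussian basis functions centred on
# an evenly spaced grid over [D_min, D_max], along a new trailing axis.
D = torch.rand(5) * 20.0
emb = _rbf(D, D_min=0., D_max=20., D_count=16)
assert emb.shape == (5, 16)  # [...dims, D_count]
```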
508
+ class GVPModel(torch.nn.Module):
509
+ """
510
+ GVP-GNN model
511
+ inspired by: https://github.com/drorlab/gvp-pytorch/blob/main/gvp/models.py
512
+ and: https://github.com/drorlab/gvp-pytorch/blob/82af6b22eaf8311c15733117b0071408d24ed877/gvp/atom3d.py#L115
513
+
514
+ :param node_in_dim: node dimension in input graph, scalars or tuple (scalars, vectors)
515
+ :param node_h_dim: node dimensions to use in GVP-GNN layers, tuple (s, V)
516
+ :param node_out_nf: number of scalar node features in the output graph; the model additionally outputs one vector channel per node
517
+ :param edge_in_nf: edge dimension in input graph (scalars)
518
+ :param edge_h_dim: edge dimensions to embed to before use in GVP-GNN layers,
519
+ tuple (s, V)
520
+ :param edge_out_nf: number of scalar edge features in the output graph
521
+ :param num_layers: number of GVP-GNN layers
522
+ :param drop_rate: rate to use in all dropout layers
523
+ :param vector_gate: use vector gates in all GVPs
524
+ :param reflection_equiv: bool, use reflection-sensitive feature based on the
525
+ cross product if False
526
+ :param d_max: maximum distance for the RBF distance expansion
527
+ :param num_rbf: number of radial basis functions used to embed pairwise distances
528
+ :param update_edge_attr: bool, update edge attributes at each layer in a
529
+ learnable way
530
+ """
531
+ def __init__(self, node_in_dim, node_h_dim, node_out_nf,
532
+ edge_in_nf, edge_h_dim, edge_out_nf,
533
+ num_layers=3, drop_rate=0.1, vector_gate=False,
534
+ reflection_equiv=True, d_max=20.0, num_rbf=16,
535
+ update_edge_attr=False):
536
+
537
+ super(GVPModel, self).__init__()
538
+
539
+ self.reflection_equiv = reflection_equiv
540
+ self.update_edge_attr = update_edge_attr
541
+ self.d_max = d_max
542
+ self.num_rbf = num_rbf
543
+
544
+ # node_in_dim = (node_in_dim, 1)
545
+ if not isinstance(node_in_dim, tuple):
546
+ node_in_dim = (node_in_dim, 0)
547
+
548
+ edge_in_dim = (edge_in_nf + 2 * node_in_dim[0] + self.num_rbf, 1)
549
+ if not self.reflection_equiv:
550
+ edge_in_dim = (edge_in_dim[0], edge_in_dim[1] + 1)
551
+
552
+ # self.W_v = nn.Sequential(
553
+ # GVP(node_in_dim, node_h_dim, activations=(None, None), vector_gate=True),
554
+ # LayerNorm(node_h_dim)
555
+ # )
556
+ self.W_v = nn.Sequential(
557
+ LayerNorm(node_in_dim, learnable_vector_weight=True),
558
+ GVP(node_in_dim, node_h_dim, activations=(None, None), vector_gate=vector_gate),
559
+ )
560
+ # self.W_e = nn.Sequential(
561
+ # GVP(edge_in_dim, edge_h_dim, activations=(None, None), vector_gate=True),
562
+ # LayerNorm(edge_h_dim)
563
+ # )
564
+ self.W_e = nn.Sequential(
565
+ LayerNorm(edge_in_dim, learnable_vector_weight=True),
566
+ GVP(edge_in_dim, edge_h_dim, activations=(None, None), vector_gate=vector_gate),
567
+ )
568
+
569
+ self.layers = nn.ModuleList(
570
+ GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate,
571
+ update_edge_attr=self.update_edge_attr,
572
+ activations=(F.relu, None), vector_gate=vector_gate,
573
+ ln_vector_weight=True)
574
+ # activations=(F.relu, torch.sigmoid))
575
+ # GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate,
576
+ # update_edge_attr=self.update_edge_attr,
577
+ # activations=(nn.SiLU(), nn.SiLU()))
578
+ for _ in range(num_layers))
579
+
580
+ # self.W_v_out = GVP(node_h_dim, (node_out_nf, 1),
581
+ # activations=(None, None), vector_gate=True)
582
+ self.W_v_out = nn.Sequential(
583
+ LayerNorm(node_h_dim, learnable_vector_weight=True),
584
+ GVP(node_h_dim, (node_out_nf, 1), activations=(None, None), vector_gate=vector_gate),
585
+ )
586
+ # self.W_e_out = GVP(edge_h_dim, (edge_out_nf, 0),
587
+ # activations=(None, None), vector_gate=True) \
588
+ # if self.update_edge_attr else None
589
+ self.W_e_out = nn.Sequential(
590
+ LayerNorm(edge_h_dim, learnable_vector_weight=True),
591
+ GVP(edge_h_dim, (edge_out_nf, 0), activations=(None, None), vector_gate=vector_gate)
592
+ ) if self.update_edge_attr else None
593
+
594
+ def edge_features(self, h, x, edge_index, batch_mask=None, edge_attr=None):
595
+ """
596
+ :param h: scalar node features, shape (n_nodes, n_node_features)
597
+ :param x: node coordinates, shape (n_nodes, 3)
598
+ :param edge_index: array of shape [2, n_edges]
599
+ :param batch_mask: graph index per node; only used if reflection_equiv is False
600
+ :param edge_attr: optional scalar edge features, shape (n_edges, n_edge_features)
601
+ :return: scalar and vector-valued edge features
602
+ """
603
+ row, col = edge_index
604
+ coord_diff = x[row] - x[col]
605
+ dist = coord_diff.norm(dim=-1)
606
+ rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf,
607
+ device=x.device)
608
+
609
+ edge_s = torch.cat([h[row], h[col], rbf], dim=1)
610
+ edge_v = _normalize(coord_diff).unsqueeze(-2)
611
+
612
+ if edge_attr is not None:
613
+ edge_s = torch.cat([edge_s, edge_attr], dim=1)
614
+
615
+ if not self.reflection_equiv:
616
+ mean = scatter_mean(x, batch_mask, dim=0,
617
+ dim_size=batch_mask.max() + 1)
618
+ row, col = edge_index
619
+ cross = torch.cross(x[row] - mean[batch_mask[row]],
620
+ x[col] - mean[batch_mask[col]], dim=1)
621
+ cross = _normalize(cross).unsqueeze(-2)
622
+
623
+ edge_v = torch.cat([edge_v, cross], dim=-2)
624
+
625
+ return torch.nan_to_num(edge_s), torch.nan_to_num(edge_v)
626
+
627
+ def forward(self, h, x, edge_index, v=None, batch_mask=None, edge_attr=None):
628
+
629
+ # h_v = (h, x.unsqueeze(-2))
630
+ h_v = h if v is None else (h, v)
631
+ h_e = self.edge_features(h, x, edge_index, batch_mask, edge_attr)
632
+
633
+ h_v = self.W_v(h_v)
634
+ h_e = self.W_e(h_e)
635
+
636
+ for layer in self.layers:
637
+ h_v = layer(h_v, edge_index, edge_attr=h_e)
638
+ if self.update_edge_attr:
639
+ h_v, h_e = h_v
640
+
641
+ # h, x = self.W_v_out(h_v)
642
+ # x = x.squeeze(-2)
643
+ h, vel = self.W_v_out(h_v)
644
+ # x = x + vel.squeeze(-2)
645
+
646
+ if self.update_edge_attr:
647
+ edge_attr = self.W_e_out(h_e)
648
+
649
+ # return h, x, edge_attr
650
+ return h, vel.squeeze(-2), edge_attr
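For orientation, a minimal smoke test of `GVPModel` (a sketch only: the dimensions are arbitrary, and `update_edge_attr=True` is assumed so updated scalar edge features are returned):

```python
import torch

model = GVPModel(node_in_dim=16, node_h_dim=(32, 8), node_out_nf=16,
                 edge_in_nf=4, edge_h_dim=(16, 4), edge_out_nf=4,
                 num_layers=2, update_edge_attr=True)

n, m = 30, 100
h = torch.randn(n, 16)                    # scalar node features
x = torch.randn(n, 3)                     # node coordinates
edge_index = torch.randint(0, n, (2, m))
edge_attr = torch.randn(m, 4)             # scalar edge features
batch = torch.zeros(n, dtype=torch.long)  # a single graph

h_out, vel, e_out = model(h, x, edge_index, batch_mask=batch, edge_attr=edge_attr)
assert h_out.shape == (n, 16) and vel.shape == (n, 3) and e_out.shape == (m, 4)
```

The returned `vel` is the single output vector channel per node (squeezed from `W_v_out`), which the callers treat as a velocity rather than as updated coordinates.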
src/model/gvp_transformer.py ADDED
@@ -0,0 +1,471 @@
1
+ import math
2
+ import functools
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+ from torch_scatter import scatter_mean, scatter_std, scatter_min, scatter_max, scatter_softmax
7
+
8
+
9
+ # ## debug
10
+ # import sys
11
+ # from pathlib import Path
12
+ #
13
+ # basedir = Path(__file__).resolve().parent.parent.parent
14
+ # sys.path.append(str(basedir))
15
+ # ###
16
+
17
+ from src.model.gvp import GVP, _norm_no_nan, tuple_sum, Dropout, LayerNorm, \
18
+ tuple_cat, tuple_index, _rbf, _normalize
19
+
20
+
21
+ def tuple_mul(tup, val):
22
+ if isinstance(val, torch.Tensor):
23
+ return (tup[0] * val, tup[1] * val.unsqueeze(-1))
24
+ return (tup[0] * val, tup[1] * val)
25
+
26
+
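`tuple_mul` is used in two ways below: with a tensor it broadcasts a weight over both channels of an (s, V) tuple, adding a trailing axis so the 3D vector dimension broadcasts; with a plain int it scales a dimension tuple for the multi-head projections. A minimal sketch:

```python
import torch

s = torch.randn(100, 9, 7)     # (edges, heads, scalar features)
v = torch.randn(100, 9, 4, 3)  # (edges, heads, vector features, 3)
w = torch.rand(100, 9, 1)      # e.g. per-head attention weights
ws, wv = tuple_mul((s, v), w)
assert ws.shape == s.shape and wv.shape == v.shape

assert tuple_mul((6, 3), 4) == (24, 12)  # widening (s, V) dims for 4 heads
```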
27
+ class GVPBlock(nn.Module):
28
+ def __init__(self, in_dims, out_dims, n_layers=1,
29
+ activations=(F.relu, torch.sigmoid), vector_gate=False,
30
+ dropout=0.0, skip=False, layernorm=False):
31
+ super(GVPBlock, self).__init__()
32
+ self.si, self.vi = in_dims
33
+ self.so, self.vo = out_dims
34
+ assert not skip or (self.si == self.so and self.vi == self.vo)
35
+ self.skip = skip
36
+
37
+ GVP_ = functools.partial(GVP, activations=activations, vector_gate=vector_gate)
38
+
39
+ module_list = []
40
+ if n_layers == 1:
41
+ module_list.append(GVP_(in_dims, out_dims, activations=(None, None)))
42
+ else:
43
+ module_list.append(GVP_(in_dims, out_dims))
44
+ for i in range(n_layers - 2):
45
+ module_list.append(GVP_(out_dims, out_dims))
46
+ module_list.append(GVP_(out_dims, out_dims, activations=(None, None)))
47
+
48
+ self.layers = nn.Sequential(*module_list)
49
+
50
+ self.norm = LayerNorm(out_dims, learnable_vector_weight=True) if layernorm else None
51
+ self.dropout = Dropout(dropout) if dropout > 0 else None
52
+
53
+ def forward(self, x):
54
+ """
55
+ :param x: tuple (s, V) of `torch.Tensor`
56
+ :return: tuple (s, V) of `torch.Tensor`
57
+ """
58
+
59
+ dx = self.layers(x)
60
+
61
+ if self.dropout is not None:
62
+ dx = self.dropout(dx)
63
+
64
+ if self.skip:
65
+ x = tuple_sum(x, dx)
66
+ else:
67
+ x = dx
68
+
69
+ if self.norm is not None:
70
+ x = self.norm(x)
71
+
72
+ return x
73
+
74
+
75
+ class GeometricPNA(nn.Module):
76
+ def __init__(self, d_in, d_out):
77
+ """ Map features to global features """
78
+ super().__init__()
79
+ si, vi = d_in
80
+ so, vo = d_out
81
+ self.gvp = GVPBlock((4 * si + 3 * vi, vi), d_out)
82
+
83
+ def forward(self, x, batch_mask, batch_size=None):
84
+ """ x: tuple (s, V) """
85
+ s, v = x
86
+
87
+ sm = scatter_mean(s, batch_mask, dim=0, dim_size=batch_size)
88
+ smi = scatter_min(s, batch_mask, dim=0, dim_size=batch_size)[0]
89
+ sma = scatter_max(s, batch_mask, dim=0, dim_size=batch_size)[0]
90
+ sstd = scatter_std(s, batch_mask, dim=0, dim_size=batch_size)
91
+
92
+ vnorm = _norm_no_nan(v)
93
+ vm = scatter_mean(v, batch_mask, dim=0, dim_size=batch_size)
94
+ vmi = scatter_min(vnorm, batch_mask, dim=0, dim_size=batch_size)[0]
95
+ vma = scatter_max(vnorm, batch_mask, dim=0, dim_size=batch_size)[0]
96
+ vstd = scatter_std(vnorm, batch_mask, dim=0, dim_size=batch_size)
97
+
98
+ z = torch.hstack((sm, smi, sma, sstd, vmi, vma, vstd))
99
+ out = self.gvp((z, vm))
100
+ return out
101
+
102
+
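A shape sketch for `GeometricPNA` (dimensions are arbitrary): per-graph mean/min/max/std statistics of the scalar features and of the vector norms are concatenated, the mean vector is kept as the vector channel, and one `GVPBlock` maps the result to the global feature size.

```python
import torch

pna = GeometricPNA(d_in=(16, 8), d_out=(4, 2))
s = torch.randn(30, 16)
v = torch.randn(30, 8, 3)
batch = torch.arange(30) % 3  # three graphs with ten nodes each
gs, gv = pna((s, v), batch, batch_size=3)
assert gs.shape == (3, 4) and gv.shape == (3, 2, 3)
```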
103
+ class TupleLinear(nn.Module):
104
+ def __init__(self, in_dims, out_dims, bias=True):
105
+ super().__init__()
106
+ self.si, self.vi = in_dims
107
+ self.so, self.vo = out_dims
108
+ assert self.si and self.so
109
+ self.ws = nn.Linear(self.si, self.so, bias=bias)
110
+ self.wv = nn.Linear(self.vi, self.vo, bias=bias) if self.vi and self.vo else None
111
+
112
+ def forward(self, x):
113
+ if self.vi:
114
+ s, v = x
115
+
116
+ s = self.ws(s)
117
+
118
+ if self.vo:
119
+ v = v.transpose(-1, -2)
120
+ v = self.wv(v)
121
+ v = v.transpose(-1, -2)
122
+
123
+ else:
124
+ s = self.ws(x)
125
+
126
+ if self.vo:
127
+ v = torch.zeros(s.size(0), self.vo, 3, device=s.device)
128
+
129
+ return (s, v) if self.vo else s
130
+
131
+
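`TupleLinear` applies independent linear maps to the scalar channel and, via a transpose, to the vector-channel axis; the 3D axis itself is never mixed, which preserves equivariance (the attention projections below accordingly instantiate it with `bias=False`). A minimal sketch:

```python
import torch

lin = TupleLinear((16, 8), (32, 4), bias=False)
s, v = lin((torch.randn(10, 16), torch.randn(10, 8, 3)))
assert s.shape == (10, 32) and v.shape == (10, 4, 3)
```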
132
+ class GVPTransformerLayer(nn.Module):
133
+ """
134
+ Full graph transformer layer with Geometric Vector Perceptrons.
135
+ Inspired by
136
+ - GVP: Jing, Bowen, et al. "Learning from protein structure with geometric vector perceptrons." arXiv preprint arXiv:2009.01411 (2020).
137
+ - Transformer architecture: Vignac, Clement, et al. "Digress: Discrete denoising diffusion for graph generation." arXiv preprint arXiv:2209.14734 (2022).
138
+ - Invariant point attention: Jumper, John, et al. "Highly accurate protein structure prediction with AlphaFold." Nature 596.7873 (2021): 583-589.
139
+
140
+ :param node_dims: node embedding dimensions (n_scalar, n_vector)
141
+ :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
142
+ :param global_dims: global feature dimension (n_scalar, n_vector)
143
+ :param dk: key dimension, (n_scalar, n_vector)
144
+ :param dv: node value dimension, (n_scalar, n_vector)
145
+ :param de: edge value dimension, (n_scalar, n_vector)
146
+ :param db: dimension of edge contribution to attention, int
147
+ :param attn_heads: number of attention heads, int
148
+ :param n_feedforward: number of GVPs to use in feedforward function
149
+ :param drop_rate: drop probability in all dropout layers
150
+ :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
151
+ :param vector_gate: whether to use vector gating.
152
+ (vector_act will be used as sigma^+ in vector gating if `True`)
153
+ :param attention: can be used to turn off the attention mechanism
154
+ """
155
+
156
+ def __init__(self, node_dims, edge_dims, global_dims, dk, dv, de, db,
157
+ attn_heads, n_feedforward=1, drop_rate=0.0,
158
+ activations=(F.relu, torch.sigmoid), vector_gate=False,
159
+ attention=True):
160
+
161
+ super(GVPTransformerLayer, self).__init__()
162
+
163
+ self.attention = attention
164
+
165
+ dq = dk
166
+ self.dq = dq
167
+ self.dk = dk
168
+ self.dv = dv
169
+ self.de = de
170
+ self.db = db
171
+
172
+ self.h = attn_heads
173
+
174
+ self.q = TupleLinear(node_dims, tuple_mul(dq, self.h), bias=False) if self.attention else None
175
+ self.k = TupleLinear(node_dims, tuple_mul(dk, self.h), bias=False) if self.attention else None
176
+ self.vx = TupleLinear(node_dims, tuple_mul(dv, self.h), bias=False)
177
+
178
+ self.ve = TupleLinear(edge_dims, tuple_mul(de, self.h), bias=False)
179
+ self.b = TupleLinear(edge_dims, (db * self.h, 0), bias=False) if self.attention else None
180
+
181
+ m_dim = tuple_sum(tuple_mul(dv, self.h), tuple_mul(de, self.h))
182
+ self.msg = GVPBlock(m_dim, m_dim, n_feedforward,
183
+ activations=activations, vector_gate=vector_gate)
184
+
185
+ m_dim = tuple_sum(m_dim, global_dims)
186
+ self.x_out = GVPBlock(m_dim, node_dims, n_feedforward,
187
+ activations=activations, vector_gate=vector_gate)
188
+ self.x_norm = LayerNorm(node_dims, learnable_vector_weight=True)
189
+ self.x_dropout = Dropout(drop_rate)
190
+
191
+ e_dim = tuple_sum(tuple_mul(node_dims, 2), edge_dims, global_dims)
192
+ if self.attention:
193
+ e_dim = (e_dim[0] + 3 * attn_heads, e_dim[1])
194
+ self.e_out = GVPBlock(e_dim, edge_dims, n_feedforward,
195
+ activations=activations, vector_gate=vector_gate)
196
+ self.e_norm = LayerNorm(edge_dims, learnable_vector_weight=True)
197
+ self.e_dropout = Dropout(drop_rate)
198
+
199
+ self.pna_x = GeometricPNA(node_dims, node_dims)
200
+ self.pna_e = GeometricPNA(edge_dims, edge_dims)
201
+ self.y = GVP(global_dims, global_dims, activations=(None, None), vector_gate=vector_gate)
202
+ _dim = tuple_sum(node_dims, edge_dims, global_dims)
203
+ self.y_out = GVPBlock(_dim, global_dims, n_feedforward,
204
+ activations=activations, vector_gate=vector_gate)
205
+ self.y_norm = LayerNorm(global_dims, learnable_vector_weight=True)
206
+ self.y_dropout = Dropout(drop_rate)
207
+
208
+ def forward(self, x, edge_index, batch_mask, edge_attr, global_attr=None,
209
+ node_mask=None):
210
+ """
211
+ :param x: tuple (s, V) of `torch.Tensor`
212
+ :param edge_index: array of shape [2, n_edges]
213
+ :param batch_mask: array indicating different graphs
214
+ :param edge_attr: tuple (s, V) of `torch.Tensor`
215
+ :param global_attr: tuple (s, V) of `torch.Tensor`
216
+ :param node_mask: array of type `bool` to index into the first
217
+ dim of node embeddings (s, V). If not `None`, only
218
+ these nodes will be updated.
219
+ """
220
+
221
+ row, col = edge_index
222
+ n = len(x[0])
223
+ batch_size = len(torch.unique(batch_mask))
224
+
225
+ # Compute attention
226
+ if self.attention:
227
+ Q = self.q(x)
228
+ K = self.k(x)
229
+ b = self.b(edge_attr)
230
+
231
+ qs, qv = Q # (n, dq * h), (n, dq * h, 3)
232
+ ks, kv = K # (n, dq * h), (n, dq * h, 3)
233
+ attn_s = (qs[row] * ks[col]).reshape(len(row), self.h, self.dq[0]).sum(dim=-1) # (m, h)
234
+ # NOTE: attn_v is the Frobenius inner product between vector-valued queries and keys of size [dq, 3]
235
+ # (generalizes the dot-product between queries and keys similar to Pocket2Mol)
236
+ # TODO: double-check if this is correctly implemented!
237
+ attn_v = (qv[row] * kv[col]).reshape(len(row), self.h, self.dq[1], 3).sum(dim=(-2, -1)) # (m, h)
238
+ attn_e = b.reshape(b.size(0), self.h, self.db).sum(dim=-1) # (m, h)
239
+
240
+ attn = attn_s / math.sqrt(3 * self.dk[0]) + \
241
+ attn_v / math.sqrt(9 * self.dk[1]) + \
242
+ attn_e / math.sqrt(3 * self.db)
243
+ attn = scatter_softmax(attn, row, dim=0) # (m, h)
244
+ attn = attn.unsqueeze(-1) # (m, h, 1)
245
+
246
+ # Compute new features
247
+ Vx = self.vx(x)
248
+ Ve = self.ve(edge_attr)
249
+
250
+ mx = (Vx[0].reshape(Vx[0].size(0), self.h, self.dv[0]), # (n, h, dv)
251
+ Vx[1].reshape(Vx[1].size(0), self.h, self.dv[1], 3)) # (n, h, dv, 3)
252
+ me = (Ve[0].reshape(Ve[0].size(0), self.h, self.de[0]),
253
+ Ve[1].reshape(Ve[1].size(0), self.h, self.de[1], 3))
254
+
255
+ mx = tuple_index(mx, col)
256
+ if self.attention:
257
+ mx = tuple_mul(mx, attn)
258
+ me = tuple_mul(me, attn)
259
+
260
+ _m = tuple_cat(mx, me)
261
+ _m = (_m[0].flatten(1), _m[1].flatten(1, 2))
262
+ m = self.msg(_m) # (m, h * dv), (m, h * dv, 3)
263
+ m = (scatter_mean(m[0], row, dim=0, dim_size=n), # (n, h * dv)
264
+ scatter_mean(m[1], row, dim=0, dim_size=n)) # (n, h * dv, 3)
265
+ if global_attr is not None:
266
+ m = tuple_cat(m, tuple_index(global_attr, batch_mask))
267
+ X_out = self.x_norm(tuple_sum(x, self.x_dropout(self.x_out(m))))
268
+
269
+ _e = tuple_cat(tuple_index(x, row), tuple_index(x, col), edge_attr)
270
+ if self.attention:
271
+ _e = (torch.cat([_e[0], attn_s, attn_v, attn_e], dim=-1), _e[1])
272
+ if global_attr is not None:
273
+ _e = tuple_cat(_e, tuple_index(global_attr, batch_mask[row]))
274
+ E_out = self.e_norm(tuple_sum(edge_attr, self.e_dropout(self.e_out(_e))))
275
+
276
+ _y = tuple_cat(self.pna_x(x, batch_mask, batch_size),
277
+ self.pna_e(edge_attr, batch_mask[row], batch_size))
278
+ if global_attr is not None:
279
+ _y = tuple_cat(_y, self.y(global_attr))
280
+ y_out = self.y_norm(tuple_sum(global_attr, self.y_dropout(self.y_out(_y))))
281
+ else:
282
+ y_out = self.y_norm(self.y_dropout(self.y_out(_y)))
283
+
284
+ if node_mask is not None:
285
+ X_out[0][~node_mask], X_out[1][~node_mask] = tuple_index(x, ~node_mask)
286
+
287
+ return X_out, E_out, y_out
288
+
289
+
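The attention above is normalised with a scatter softmax over the edges that share the same `row` node, the same index later used to aggregate messages. A minimal standalone sketch of that primitive (assuming `torch_scatter` is installed):

```python
import torch
from torch_scatter import scatter_softmax

logits = torch.randn(100, 9)                # (n_edges, heads), cf. attn_s + attn_v + attn_e
row = torch.randint(0, 30, (100,))          # aggregation node index per edge
attn = scatter_softmax(logits, row, dim=0)  # per node, weights over its edges sum to 1
```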
290
+ class GVPTransformerModel(torch.nn.Module):
291
+ """
292
+ GVP-Transformer model
293
+
294
+ :param node_in_dim: node dimension in input graph, scalars or tuple (scalars, vectors)
295
+ :param node_h_dim: node dimensions to use in GVP-GNN layers, tuple (s, V)
296
+ :param node_out_nf: number of scalar node features in the output graph; the model additionally outputs one vector channel per node
297
+ :param edge_in_nf: edge dimension in input graph (scalars)
298
+ :param edge_h_dim: edge dimensions to embed to before use in GVP-GNN layers,
299
+ tuple (s, V)
300
+ :param edge_out_nf: number of scalar edge features in the output graph
301
+ :param num_layers: number of GVP transformer layers
302
+ :param drop_rate: rate to use in all dropout layers
303
+ :param reflection_equiv: bool, use reflection-sensitive feature based on the
304
+ cross product if False
305
+ :param d_max: maximum distance for the RBF distance expansion
306
+ :param num_rbf: number of radial basis functions used to embed pairwise distances
307
+ :param vector_gate: use vector gates in all GVPs
308
+ :param attention: can be used to turn off the attention mechanism
309
+ """
310
+ def __init__(self, node_in_dim, node_h_dim, node_out_nf, edge_in_nf,
311
+ edge_h_dim, edge_out_nf, num_layers, dk, dv, de, db, dy,
312
+ attn_heads, n_feedforward, drop_rate, reflection_equiv=True,
313
+ d_max=20.0, num_rbf=16, vector_gate=False, attention=True):
314
+
315
+ super(GVPTransformerModel, self).__init__()
316
+
317
+ self.reflection_equiv = reflection_equiv
318
+ self.d_max = d_max
319
+ self.num_rbf = num_rbf
320
+
321
+ # node_in_dim = (node_in_dim, 1)
322
+ if not isinstance(node_in_dim, tuple):
323
+ node_in_dim = (node_in_dim, 0)
324
+
325
+ edge_in_dim = (edge_in_nf + 2 * node_in_dim[0] + self.num_rbf, 1)
326
+ if not self.reflection_equiv:
327
+ edge_in_dim = (edge_in_dim[0], edge_in_dim[1] + 1)
328
+
329
+ self.W_v = GVP(node_in_dim, node_h_dim, activations=(None, None), vector_gate=vector_gate)
330
+ self.W_e = GVP(edge_in_dim, edge_h_dim, activations=(None, None), vector_gate=vector_gate)
331
+ # self.W_v = nn.Sequential(
332
+ # LayerNorm(node_in_dim, learnable_vector_weight=True),
333
+ # GVP(node_in_dim, node_h_dim, activations=(None, None)),
334
+ # )
335
+ # self.W_e = nn.Sequential(
336
+ # LayerNorm(edge_in_dim, learnable_vector_weight=True),
337
+ # GVP(edge_in_dim, edge_h_dim, activations=(None, None)),
338
+ # )
339
+
340
+ self.dy = dy
341
+ self.layers = nn.ModuleList(
342
+ GVPTransformerLayer(node_h_dim, edge_h_dim, dy, dk, dv, de, db,
343
+ attn_heads, n_feedforward=n_feedforward,
344
+ drop_rate=drop_rate, vector_gate=vector_gate,
345
+ activations=(F.relu, None), attention=attention)
346
+ for _ in range(num_layers))
347
+
348
+ self.W_v_out = GVP(node_h_dim, (node_out_nf, 1), activations=(None, None), vector_gate=vector_gate)
349
+ self.W_e_out = GVP(edge_h_dim, (edge_out_nf, 0), activations=(None, None), vector_gate=vector_gate)
350
+ # self.W_v_out = nn.Sequential(
351
+ # LayerNorm(node_h_dim, learnable_vector_weight=True),
352
+ # GVP(node_h_dim, (node_out_nf, 1), activations=(None, None)),
353
+ # )
354
+ # self.W_e_out = nn.Sequential(
355
+ # LayerNorm(edge_h_dim, learnable_vector_weight=True),
356
+ # GVP(edge_h_dim, (edge_out_nf, 0), activations=(None, None))
357
+ # )
358
+
359
+ def edge_features(self, h, x, edge_index, batch_mask=None, edge_attr=None):
360
+ """
361
+ :param h: scalar node features, shape (n_nodes, n_node_features)
362
+ :param x: node coordinates, shape (n_nodes, 3)
363
+ :param edge_index: array of shape [2, n_edges]
364
+ :param batch_mask: graph index per node; only used if reflection_equiv is False
365
+ :param edge_attr: optional scalar edge features, shape (n_edges, n_edge_features)
366
+ :return: scalar and vector-valued edge features
367
+ """
368
+ row, col = edge_index
369
+ coord_diff = x[row] - x[col]
370
+ dist = coord_diff.norm(dim=-1)
371
+ rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf,
372
+ device=x.device)
373
+
374
+ edge_s = torch.cat([h[row], h[col], rbf], dim=1)
375
+ edge_v = _normalize(coord_diff).unsqueeze(-2)
376
+
377
+ if edge_attr is not None:
378
+ edge_s = torch.cat([edge_s, edge_attr], dim=1)
379
+
380
+ if not self.reflection_equiv:
381
+ mean = scatter_mean(x, batch_mask, dim=0,
382
+ dim_size=batch_mask.max() + 1)
383
+ row, col = edge_index
384
+ cross = torch.cross(x[row] - mean[batch_mask[row]],
385
+ x[col] - mean[batch_mask[col]], dim=1)
386
+ cross = _normalize(cross).unsqueeze(-2)
387
+
388
+ edge_v = torch.cat([edge_v, cross], dim=-2)
389
+
390
+ return torch.nan_to_num(edge_s), torch.nan_to_num(edge_v)
391
+
392
+ def forward(self, h, x, edge_index, v=None, batch_mask=None, edge_attr=None):
393
+
394
+ bs = len(batch_mask.unique())
395
+
396
+ # h_v = (h, x.unsqueeze(-2))
397
+ h_v = h if v is None else (h, v)
398
+ h_e = self.edge_features(h, x, edge_index, batch_mask, edge_attr)
399
+
400
+ h_v = self.W_v(h_v)
401
+ h_e = self.W_e(h_e)
402
+ h_y = (torch.zeros(bs, self.dy[0], device=h.device),
403
+ torch.zeros(bs, self.dy[1], 3, device=h.device))
404
+
405
+ for layer in self.layers:
406
+ h_v, h_e, h_y = layer(h_v, edge_index, batch_mask, h_e, h_y)
407
+
408
+ # h, x = self.W_v_out(h_v)
409
+ # x = x.squeeze(-2)
410
+ h, vel = self.W_v_out(h_v)
411
+ # x = x + vel.squeeze(-2)
412
+
413
+ edge_attr = self.W_e_out(h_e)
414
+
415
+ # return h, x, edge_attr
416
+ return h, vel.squeeze(-2), edge_attr
417
+
418
+
419
+ if __name__ == "__main__":
420
+ from src.model.gvp import randn
421
+ from scipy.spatial.transform import Rotation
422
+
423
+ def test_equivariance(model, nodes, edges, glob_feat):
424
+ random = torch.as_tensor(Rotation.random().as_matrix(),
425
+ dtype=torch.float32, device=device)
426
+
427
+ with torch.no_grad():
428
+ X_out, E_out, y_out = model(nodes, edges, glob_feat)
429
+ n_v_rot, e_v_rot, y_v_rot = nodes[1] @ random, edges[1] @ random, glob_feat[1] @ random
430
+ X_out_v_rot = X_out[1] @ random
431
+ E_out_v_rot = E_out[1] @ random
432
+ y_out_v_rot = y_out[1] @ random
433
+ X_out_prime, E_out_prime, y_out_prime = model((nodes[0], n_v_rot), (edges[0], e_v_rot), (glob_feat[0], y_v_rot))
434
+
435
+ assert torch.allclose(X_out[0], X_out_prime[0], atol=1e-5, rtol=1e-4)
436
+ assert torch.allclose(X_out_v_rot, X_out_prime[1], atol=1e-5, rtol=1e-4)
437
+ assert torch.allclose(E_out[0], E_out_prime[0], atol=1e-5, rtol=1e-4)
438
+ assert torch.allclose(E_out_v_rot, E_out_prime[1], atol=1e-5, rtol=1e-4)
439
+ assert torch.allclose(y_out[0], y_out_prime[0], atol=1e-5, rtol=1e-4)
440
+ assert torch.allclose(y_out_v_rot, y_out_prime[1], atol=1e-5, rtol=1e-4)
441
+ print("SUCCESS")
442
+
443
+
444
+ n_nodes = 300
445
+ n_edges = 10000
446
+ batch_size = 6
447
+
448
+ node_dim = (16, 8)
449
+ edge_dim = (8, 4)
450
+ global_dim = (4, 2)
451
+ dk = (6, 3)
452
+ dv = (7, 4)
453
+ de = (5, 2)
454
+ db = 10
455
+ attn_heads = 9
456
+
457
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
458
+
459
+
460
+ nodes = randn(n_nodes, node_dim, device=device)
461
+ edges = randn(n_edges, edge_dim, device=device)
462
+ glob_feat = randn(batch_size, global_dim, device=device)
463
+ edge_index = torch.randint(0, n_nodes, (2, n_edges), device=device)
464
+ batch_idx = torch.randint(0, batch_size, (n_nodes,), device=device)
465
+
466
+ model = GVPTransformerLayer(node_dim, edge_dim, global_dim, dk, dv, de, db,
467
+ attn_heads, n_feedforward=2,
468
+ drop_rate=0.1).to(device).eval()
469
+
470
+ model_fn = lambda h_V, h_E, h_y: model(h_V, edge_index, batch_idx, h_E, h_y)
471
+ test_equivariance(model_fn, nodes, edges, glob_feat)
src/model/lightning.py ADDED
@@ -0,0 +1,1426 @@
1
+ import warnings
2
+ import tempfile
3
+ from typing import Optional, Union
4
+ from time import time
5
+ from pathlib import Path
6
+ from functools import partial
7
+ from itertools import accumulate
8
+ from argparse import Namespace
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ from rdkit import Chem
13
+ import torch
14
+ from torch.utils.data import DataLoader, SubsetRandomSampler
15
+ from torch.distributions.categorical import Categorical
16
+ import pytorch_lightning as pl
17
+ from torch_scatter import scatter_mean
18
+
19
+ import src.utils as utils
20
+ from src.constants import atom_encoder, atom_decoder, aa_encoder, aa_decoder, \
21
+ bond_encoder, bond_decoder, residue_encoder, residue_bond_encoder, \
22
+ residue_decoder, residue_bond_decoder, aa_atom_index, aa_atom_mask
23
+ from src.data.dataset import ProcessedLigandPocketDataset, ClusteredDataset, get_wds
24
+ from src.data import data_utils
25
+ from src.data.data_utils import AppendVirtualNodesInCoM, center_data, Residues, TensorDict, randomize_tensors
26
+ from src.model.flows import CoordICFM, TorusICFM, CoordICFMPredictFinal, TorusICFMPredictFinal, SO3ICFM
27
+ from src.model.markov_bridge import UniformPriorMarkovBridge, MarginalPriorMarkovBridge
28
+ from src.model.dynamics import Dynamics
29
+ from src.model.dynamics_hetero import DynamicsHetero
30
+ from src.model.diffusion_utils import DistributionNodes
31
+ from src.model.loss_utils import TimestepWeights, clash_loss
32
+ from src.analysis.visualization_utils import pocket_to_rdkit, mols_to_pdbfile
33
+ from src.analysis.metrics import MoleculeValidity, CategoricalDistribution, MolecularProperties
34
+ from src.data.molecule_builder import build_molecule
35
+ from src.data.postprocessing import process_all
36
+ from src.sbdd_metrics.metrics import FullEvaluator
37
+ from src.sbdd_metrics.evaluation import VALIDITY_METRIC_NAME, aggregated_metrics, collection_metrics
38
+ from tqdm import tqdm
39
+
40
+ # derive additional constants
41
+ aa_atom_mask_tensor = torch.tensor([aa_atom_mask[aa] for aa in aa_decoder])
42
+ aa_atom_decoder = {aa: {v: k for k, v in aa_atom_index[aa].items()} for aa in aa_decoder}
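+ # per-residue heavy-atom types padded to 14 slots; slots without an atom fall
+ # back to -42 (an assumed sentinel that can never match a real atom type)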
43
+ aa_atom_type_tensor = torch.tensor([[atom_encoder.get(aa_atom_decoder[aa].get(i, '-')[0], -42)
44
+ for i in range(14)] for aa in aa_decoder])
45
+
46
+
47
+ def set_default(namespace, key, default_val):
48
+ val = vars(namespace).get(key, default_val)
49
+ setattr(namespace, key, val)
50
+
51
+
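`set_default` backfills missing hyperparameters on an argparse-style `Namespace`, which keeps older configs and checkpoints loadable when new options are introduced. A minimal sketch:

```python
from argparse import Namespace

params = Namespace(lr=1e-3)
set_default(params, "lr_gamma", None)  # added because it was absent
set_default(params, "lr", 42.0)        # existing value is kept
assert params.lr == 1e-3 and params.lr_gamma is None
```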
52
+ class DrugFlow(pl.LightningModule):
53
+ def __init__(
54
+ self,
55
+ pocket_representation: str,
56
+ train_params: Namespace,
57
+ loss_params: Namespace,
58
+ eval_params: Namespace,
59
+ predictor_params: Namespace,
60
+ simulation_params: Namespace,
61
+ virtual_nodes: Union[list, None],
62
+ flexible: bool,
63
+ flexible_bb: bool = False,
64
+ debug: bool = False,
65
+ overfit: bool = False,
66
+ ):
67
+ super(DrugFlow, self).__init__()
68
+ self.save_hyperparameters()
69
+
70
+ # Set default parameters
71
+ set_default(train_params, "sharded_dataset", False)
72
+ set_default(train_params, "sample_from_clusters", False)
73
+ set_default(train_params, "lr_step_size", None)
74
+ set_default(train_params, "lr_gamma", None)
75
+ set_default(train_params, "gnina", None)
76
+ set_default(loss_params, "lambda_x", 1.0)
77
+ set_default(loss_params, "lambda_clash", None)
78
+ set_default(loss_params, "reduce", "mean")
79
+ set_default(loss_params, "regularize_uncertainty", None)
80
+ set_default(eval_params, "n_loss_per_sample", 1)
81
+ set_default(eval_params, "n_sampling_steps", simulation_params.n_steps)
82
+ set_default(predictor_params, "transform_sc_pred", False)
83
+ set_default(predictor_params, "add_chi_as_feature", False)
84
+ set_default(predictor_params, "augment_residue_sc", False)
85
+ set_default(predictor_params, "augment_ligand_sc", False)
86
+ set_default(predictor_params, "add_all_atom_diff", False)
87
+ set_default(predictor_params, "angle_act_fn", None)
88
+ set_default(simulation_params, "predict_confidence", False)
89
+ set_default(simulation_params, "predict_final", False)
90
+ set_default(simulation_params, "scheduler_chi", None)
91
+
92
+ # Check for invalid configurations
93
+ assert pocket_representation in {'side_chain_bead', 'CA+'}
94
+ self.pocket_representation = pocket_representation
95
+
96
+ assert flexible or not predictor_params.augment_residue_sc
97
+ self.augment_residue_sc = predictor_params.augment_residue_sc \
98
+ if 'augment_residue_sc' in predictor_params else False
99
+ self.augment_ligand_sc = predictor_params.augment_ligand_sc \
100
+ if 'augment_ligand_sc' in predictor_params else False
101
+
102
+ assert not (flexible_bb and predictor_params.normal_modes), \
103
+ "Normal mode eigenvectors are only meaningful for fixed backbones"
104
+ assert (not flexible_bb) or flexible, \
105
+ "Currently atom vectors aren't updated if flexible=False"
106
+
107
+ assert not (simulation_params.predict_confidence and
108
+ (not predictor_params.heterogeneous_graph or simulation_params.predict_final))
109
+
110
+ # Set parameters
111
+ self.train_dataset = None
112
+ self.val_dataset = None
113
+ self.test_dataset = None
114
+ self.virtual_nodes = virtual_nodes
115
+ self.flexible = flexible
116
+ self.flexible_bb = flexible_bb
117
+ self.debug = debug
118
+ self.overfit = overfit
119
+ self.predict_confidence = simulation_params.predict_confidence
120
+
121
+ if self.virtual_nodes:
122
+ self.add_virtual_min = virtual_nodes[0]
123
+ self.add_virtual_max = virtual_nodes[1]
124
+
125
+ # Training parameters
126
+ self.datadir = train_params.datadir
127
+ self.receptor_dir = train_params.datadir
128
+ self.batch_size = train_params.batch_size
129
+ self.lr = train_params.lr
130
+ self.lr_step_size = train_params.lr_step_size
131
+ self.lr_gamma = train_params.lr_gamma
132
+ self.num_workers = train_params.num_workers
133
+ self.sample_from_clusters = train_params.sample_from_clusters
134
+ self.sharded_dataset = train_params.sharded_dataset
135
+ self.clip_grad = train_params.clip_grad
136
+ if self.clip_grad:
137
+ self.gradnorm_queue = utils.Queue()
138
+ # Add large value that will be flushed.
139
+ self.gradnorm_queue.add(3000)
140
+
141
+ # Evaluation parameters
142
+ self.outdir = eval_params.outdir
143
+ self.eval_batch_size = eval_params.eval_batch_size
144
+ self.eval_epochs = eval_params.eval_epochs
145
+ # assert eval_params.visualize_sample_epoch % self.eval_epochs == 0
146
+ self.visualize_sample_epoch = eval_params.visualize_sample_epoch
147
+ self.visualize_chain_epoch = eval_params.visualize_chain_epoch
148
+ self.sample_with_ground_truth_size = eval_params.sample_with_ground_truth_size
149
+ self.n_loss_per_sample = eval_params.n_loss_per_sample
150
+ self.n_eval_samples = eval_params.n_eval_samples
151
+ self.n_visualize_samples = eval_params.n_visualize_samples
152
+ self.keep_frames = eval_params.keep_frames
153
+ self.gnina = train_params.gnina
154
+
155
+ # Feature encoders/decoders
156
+ self.atom_encoder = atom_encoder
157
+ self.atom_decoder = atom_decoder
158
+ self.bond_encoder = bond_encoder
159
+ self.bond_decoder = bond_decoder
160
+ self.aa_encoder = aa_encoder
161
+ self.aa_decoder = aa_decoder
162
+ self.residue_encoder = residue_encoder
163
+ self.residue_decoder = residue_decoder
164
+ self.residue_bond_encoder = residue_bond_encoder
165
+ self.residue_bond_decoder = residue_bond_decoder
166
+
167
+ self.atom_nf = len(self.atom_decoder)
168
+ self.residue_nf = len(self.aa_decoder)
169
+ if self.pocket_representation == 'side_chain_bead':
170
+ self.residue_nf += len(self.residue_encoder)
171
+ if self.pocket_representation == 'CA+':
172
+ self.aa_atom_index = aa_atom_index
173
+ self.n_atom_aa = max([x for aa in aa_atom_index.values() for x in aa.values()]) + 1
174
+ self.residue_nf = (self.residue_nf, self.n_atom_aa) # (s, V)
175
+ self.bond_nf = len(self.bond_decoder)
176
+ self.pocket_bond_nf = len(self.residue_bond_decoder)
177
+ self.x_dim = 3
178
+
179
+ # Set up the neural network
180
+ self.dynamics = self.init_model(predictor_params)
181
+
182
+ # Initialize objects for each variable type
183
+ if simulation_params.predict_final:
184
+ self.module_x = CoordICFMPredictFinal(None)
185
+ self.module_chi = TorusICFMPredictFinal(None, 5) if self.flexible else None
186
+ if self.flexible_bb:
187
+ raise NotImplementedError()
188
+ else:
189
+ self.module_x = CoordICFM(None)
190
+ # self.module_chi = AngleICFM(None, 5) if self.flexible else None
191
+ scheduler_args = None if simulation_params.scheduler_chi is None else vars(simulation_params.scheduler_chi)
192
+ self.module_chi = TorusICFM(None, 5, scheduler_args) if self.flexible else None
193
+ self.module_trans = CoordICFM(None) if self.flexible_bb else None
194
+ self.module_rot = SO3ICFM(None) if self.flexible_bb else None
195
+
196
+ if simulation_params.prior_h == 'uniform':
197
+ self.module_h = UniformPriorMarkovBridge(self.atom_nf, loss_type=loss_params.discrete_loss)
198
+ elif simulation_params.prior_h == 'marginal':
199
+ self.register_buffer('prior_h', self.get_categorical_prop('atom')) # add to module
200
+ self.module_h = MarginalPriorMarkovBridge(self.atom_nf, self.prior_h, loss_type=loss_params.discrete_loss)
201
+
202
+ if simulation_params.prior_e == 'uniform':
203
+ self.module_e = UniformPriorMarkovBridge(self.bond_nf, loss_type=loss_params.discrete_loss)
204
+ elif simulation_params.prior_e == 'marginal':
205
+ self.register_buffer('prior_e', self.get_categorical_prop('bond')) # add to module
206
+ self.module_e = MarginalPriorMarkovBridge(self.bond_nf, self.prior_e, loss_type=loss_params.discrete_loss)
207
+
208
+
209
+ # Loss parameters
210
+ self.loss_reduce = loss_params.reduce
211
+ self.lambda_x = loss_params.lambda_x
212
+ self.lambda_h = loss_params.lambda_h
213
+ self.lambda_e = loss_params.lambda_e
214
+ self.lambda_chi = loss_params.lambda_chi if self.flexible else None
215
+ self.lambda_trans = loss_params.lambda_trans if self.flexible_bb else None
216
+ self.lambda_rot = loss_params.lambda_rot if self.flexible_bb else None
217
+ self.lambda_clash = loss_params.lambda_clash
218
+ self.regularize_uncertainty = loss_params.regularize_uncertainty
219
+
220
+ if loss_params.timestep_weights is not None:
221
+ weight_type = loss_params.timestep_weights.split('_')[0]
222
+ kwargs = loss_params.timestep_weights.split('_')[1:]
223
+ kwargs = {x.split('=')[0]: float(x.split('=')[1]) for x in kwargs}
224
+ self.timestep_weights = TimestepWeights(weight_type, **kwargs)
225
+ else:
226
+ self.timestep_weights = None
227
+
228
+
229
+ # Sampling
230
+ self.T_sampling = eval_params.n_sampling_steps
231
+ self.train_step_size = 1 / simulation_params.n_steps
232
+ self.size_distribution = None # initialized only if needed
233
+
234
+
235
+ # Metrics, initialized only if needed
236
+ self.train_smiles = None
237
+ self.ligand_metrics = None
238
+ self.molecule_properties = None
239
+ self.evaluator = None
240
+ self.ligand_atom_type_distribution = None
241
+ self.ligand_bond_type_distribution = None
242
+
243
+ # containers for metric aggregation
244
+ self.training_step_outputs = []
245
+ self.validation_step_outputs = []
246
+
247
+ def on_load_checkpoint(self, checkpoint):
248
+ """
249
+ This hook is only used for backward compatibility with checkpoints that
250
+ did not save prior_h and prior_e in state_dict in the past
251
+ """
252
+ if hasattr(self, "prior_h") and "prior_h" not in checkpoint["state_dict"]:
253
+ checkpoint["state_dict"]["prior_h"] = self.get_categorical_prop('atom')
254
+ if hasattr(self, "prior_e") and "prior_e" not in checkpoint["state_dict"]:
255
+ checkpoint["state_dict"]["prior_e"] = self.get_categorical_prop('bond')
256
+ if "prior_e" in checkpoint["state_dict"] and not hasattr(self, "prior_e"):
257
+ # NOTE: a very exotic case that happened to one model. Potentially can be removed in the future
258
+ self.register_buffer("prior_e", self.get_categorical_prop('bond'))
259
+
260
+ def init_model(self, predictor_params):
261
+
262
+ model_type = predictor_params.backbone
263
+
264
+ if 'heterogeneous_graph' in predictor_params and predictor_params.heterogeneous_graph:
265
+ return DynamicsHetero(
266
+ atom_nf=self.atom_nf,
267
+ residue_nf=self.residue_nf,
268
+ bond_dict=self.bond_encoder,
269
+ pocket_bond_dict=self.residue_bond_encoder,
270
+ model=model_type,
271
+ num_rbf_time=predictor_params.__dict__.get('num_rbf_time'),
272
+ model_params=getattr(predictor_params, model_type + '_params'),
273
+ edge_cutoff_ligand=predictor_params.edge_cutoff_ligand,
274
+ edge_cutoff_pocket=predictor_params.edge_cutoff_pocket,
275
+ edge_cutoff_interaction=predictor_params.edge_cutoff_interaction,
276
+ predict_angles=self.flexible,
277
+ predict_frames=self.flexible_bb,
278
+ add_cycle_counts=predictor_params.cycle_counts,
279
+ add_spectral_feat=predictor_params.spectral_feat,
280
+ add_nma_feat=predictor_params.normal_modes,
281
+ reflection_equiv=predictor_params.reflection_equivariant,
282
+ d_max=predictor_params.d_max,
283
+ num_rbf_dist=predictor_params.num_rbf,
284
+ self_conditioning=predictor_params.self_conditioning,
285
+ augment_residue_sc=self.augment_residue_sc,
286
+ augment_ligand_sc=self.augment_ligand_sc,
287
+ add_chi_as_feature=predictor_params.add_chi_as_feature,
288
+ angle_act_fn=predictor_params.angle_act_fn,
289
+ add_all_atom_diff=predictor_params.add_all_atom_diff,
290
+ predict_confidence=self.predict_confidence,
291
+ )
292
+
293
+ else:
294
+ if predictor_params.__dict__.get('num_rbf_time') is not None:
295
+ raise NotImplementedError("RBF time embedding not yet implemented")
296
+
297
+ return Dynamics(
298
+ atom_nf=self.atom_nf,
299
+ residue_nf=self.residue_nf,
300
+ joint_nf=predictor_params.joint_nf,
301
+ bond_dict=self.bond_encoder,
302
+ pocket_bond_dict=self.residue_bond_encoder,
303
+ edge_nf=predictor_params.edge_nf,
304
+ hidden_nf=predictor_params.hidden_nf,
305
+ model=model_type,
306
+ model_params=getattr(predictor_params, model_type + '_params'),
307
+ edge_cutoff_ligand=predictor_params.edge_cutoff_ligand,
308
+ edge_cutoff_pocket=predictor_params.edge_cutoff_pocket,
309
+ edge_cutoff_interaction=predictor_params.edge_cutoff_interaction,
310
+ predict_angles=self.flexible,
311
+ predict_frames=self.flexible_bb,
312
+ add_cycle_counts=predictor_params.cycle_counts,
313
+ add_spectral_feat=predictor_params.spectral_feat,
314
+ add_nma_feat=predictor_params.normal_modes,
315
+ self_conditioning=predictor_params.self_conditioning,
316
+ augment_residue_sc=self.augment_residue_sc,
317
+ augment_ligand_sc=self.augment_ligand_sc,
318
+ add_chi_as_feature=predictor_params.add_chi_as_feature,
319
+ angle_act_fn=predictor_params.angle_act_fn,
320
+ )
321
+
322
+ def _load_histogram(self, type):
323
+ """
324
+ Load empirical categorical distributions of atom or bond types from disk.
325
+ Returns None if the required file is not found.
326
+ """
327
+ assert type in {"atom", "bond"}
328
+ filename = 'ligand_type_histogram.npy' if type == 'atom' else 'ligand_bond_type_histogram.npy'
329
+ encoder = self.atom_encoder if type == 'atom' else self.bond_encoder
330
+ hist_file = Path(self.datadir, filename)
331
+ if not hist_file.exists():
332
+ return None
333
+ hist = np.load(hist_file, allow_pickle=True).item()
334
+ return CategoricalDistribution(hist, encoder)
335
+
336
+ def get_categorical_prop(self, type):
337
+ hist = self._load_histogram(type)
338
+ encoder = self.atom_encoder if type == 'atom' else self.bond_encoder
339
+ # Note: default value ensures that code will crash if prior is not
340
+ # read from disk or loaded from checkpoint later on
341
+ return torch.zeros(len(encoder)) * float("nan") if hist is None else torch.tensor(hist.p)
342
+
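The histograms consumed by `_load_histogram` are pickled dicts stored inside `.npy` files; a sketch of inspecting one outside the model (the filename follows the code above, and the key format is assumed to match the encoders, e.g. atom symbols mapped to counts):

```python
import numpy as np

hist = np.load('ligand_type_histogram.npy', allow_pickle=True).item()
# e.g. {'C': 120000, 'N': 31000, ...} (assumed layout: category -> count)
```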
343
+ def configure_optimizers(self):
344
+ optimizers = [
345
+ torch.optim.AdamW(self.parameters(), lr=self.lr, amsgrad=True, weight_decay=1e-12),
346
+ ]
347
+
348
+ if self.lr_step_size is None or self.lr_gamma is None:
349
+ lr_schedulers = []
350
+ else:
351
+ lr_schedulers = [
352
+ torch.optim.lr_scheduler.StepLR(optimizers[0], step_size=self.lr_step_size, gamma=self.lr_gamma),
353
+ ]
354
+ return optimizers, lr_schedulers
355
+
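With both `lr_step_size` and `lr_gamma` set, the learning rate follows a standard step decay, `lr_t = lr * lr_gamma ** (epoch // lr_step_size)`; leaving either unset disables the scheduler. A standalone sketch of the schedule:

```python
import torch

p = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([p], lr=1e-3)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)
for _ in range(10):
    opt.step()
    sched.step()
assert abs(opt.param_groups[0]['lr'] - 5e-4) < 1e-12  # halved after 10 epochs
```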
356
+ def setup(self, stage: Optional[str] = None):
357
+
358
+ self.setup_sampling()
359
+
360
+ if stage == 'fit':
361
+ self.train_dataset = self.get_dataset(stage='train')
362
+ self.val_dataset = self.get_dataset(stage='val')
363
+ self.setup_metrics()
364
+ elif stage == 'val':
365
+ self.val_dataset = self.get_dataset(stage='val')
366
+ self.setup_metrics()
367
+ elif stage == 'test':
368
+ self.test_dataset = self.get_dataset(stage='test')
369
+ self.setup_metrics()
370
+ elif stage == 'generation':
371
+ pass
372
+ else:
373
+ raise NotImplementedError
374
+
375
+ def get_dataset(self, stage, pocket_transform=None):
376
+
377
+ # when sampling we don't append virtual nodes as we might need access to the ground truth size
378
+ if self.virtual_nodes and stage == "train":
379
+ ligand_transform = AppendVirtualNodesInCoM(
380
+ atom_encoder, bond_encoder, add_min=self.add_virtual_min, add_max=self.add_virtual_max)
381
+ else:
382
+ ligand_transform = None
383
+
384
+ # we want to know if something goes wrong on the validation or test set
385
+ catch_errors = stage == "train"
386
+
387
+ if self.sharded_dataset:
388
+ return get_wds(
389
+ data_path=self.datadir,
390
+ stage='val' if self.debug else stage,
391
+ ligand_transform=ligand_transform,
392
+ pocket_transform=pocket_transform,
393
+ )
394
+
395
+ if self.sample_from_clusters and stage == "train": # val/test should be deterministic
396
+ return ClusteredDataset(
397
+ pt_path=Path(self.datadir, 'val.pt' if self.debug else f'{stage}.pt'),
398
+ ligand_transform=ligand_transform,
399
+ pocket_transform=pocket_transform,
400
+ catch_errors=catch_errors
401
+ )
402
+
403
+ return ProcessedLigandPocketDataset(
404
+ pt_path=Path(self.datadir, 'val.pt' if self.debug else f'{stage}.pt'),
405
+ ligand_transform=ligand_transform,
406
+ pocket_transform=pocket_transform,
407
+ catch_errors=catch_errors
408
+ )
409
+
410
+ def setup_sampling(self):
411
+ # distribution of nodes
412
+ histogram_file = Path(self.datadir, 'size_distribution.npy') # TODO: store this in model checkpoint so that we can sample without this file
413
+ size_histogram = np.load(histogram_file).tolist()
414
+ self.size_distribution = DistributionNodes(size_histogram)
415
+
416
+ def setup_metrics(self):
417
+ # For metrics
418
+ smiles_file = Path(self.datadir, 'train_smiles.npy')
419
+ self.train_smiles = None if not smiles_file.exists() else np.load(smiles_file)
420
+
421
+ self.ligand_metrics = MoleculeValidity()
422
+ self.molecule_properties = MolecularProperties()
423
+ self.evaluator = FullEvaluator(gnina=self.gnina, exclude_evaluators=['geometry', 'ring_count'])
424
+ self.ligand_atom_type_distribution = self._load_histogram('atom')
425
+ self.ligand_bond_type_distribution = self._load_histogram('bond')
426
+
427
+ def train_dataloader(self):
428
+ shuffle = None if self.overfit else False if self.sharded_dataset else True
429
+ return DataLoader(self.train_dataset, self.batch_size, shuffle=shuffle,
430
+ sampler=SubsetRandomSampler([0]) if self.overfit else None,
431
+ num_workers=self.num_workers,
432
+ collate_fn=self.train_dataset.collate_fn,
433
+ # collate_fn=partial(self.train_dataset.collate_fn, ligand_transform=batch_transform),
434
+ pin_memory=True)
435
+
436
+ def val_dataloader(self):
437
+ if self.overfit:
438
+ return self.train_dataloader()
439
+
440
+ return DataLoader(self.val_dataset, self.eval_batch_size,
441
+ shuffle=False, num_workers=self.num_workers,
442
+ collate_fn=self.val_dataset.collate_fn,
443
+ pin_memory=True)
444
+
445
+ def test_dataloader(self):
446
+ return DataLoader(self.test_dataset, self.eval_batch_size, shuffle=False,
447
+ num_workers=self.num_workers,
448
+ collate_fn=self.test_dataset.collate_fn,
449
+ pin_memory=True)
450
+
451
+ def log_metrics(self, metrics_dict, split, batch_size=None, **kwargs):
452
+ for m, value in metrics_dict.items():
453
+ self.log(f'{m}/{split}', value, batch_size=batch_size, **kwargs)
454
+
455
+ def aggregate_metrics(self, step_outputs, prefix):
456
+ if 'timestep' in step_outputs[0]:
457
+ timesteps = torch.cat([x['timestep'] for x in step_outputs]).squeeze()
458
+
459
+ if 'loss_per_sample' in step_outputs[0]:
460
+ losses = torch.cat([x['loss_per_sample'] for x in step_outputs])
461
+ pearson_corr = torch.corrcoef(torch.stack([timesteps, losses], dim=0))[0, 1]
462
+ self.log(f'corr_loss_timestep/{prefix}', pearson_corr, prog_bar=False)
463
+
464
+ if 'eps_hat_norm' in step_outputs[0]:
465
+ eps_norm = torch.cat([x['eps_hat_norm'] for x in step_outputs])
466
+ pearson_corr = torch.corrcoef(torch.stack([timesteps, eps_norm], dim=0))[0, 1]
467
+ self.log(f'corr_eps_timestep/{prefix}', pearson_corr, prog_bar=False)
468
+
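`torch.corrcoef` on a stacked `(2, N)` tensor returns the 2x2 correlation matrix, so the `[0, 1]` entry is exactly the Pearson correlation that gets logged. A minimal sketch:

```python
import torch

a = torch.randn(1000)
b = 0.8 * a + 0.2 * torch.randn(1000)
r = torch.corrcoef(torch.stack([a, b], dim=0))[0, 1]  # scalar in [-1, 1]
```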
469
+ def on_train_epoch_end(self):
470
+ self.aggregate_metrics(self.training_step_outputs, 'train')
471
+ self.training_step_outputs.clear()
472
+
473
+ # TODO: doesn't work in multi-GPU mode
474
+ # def on_before_batch_transfer(self, batch, dataloader_idx):
475
+ # """
476
+ # Performs operations on data before it is transferred to the GPU.
477
+ # Hence, supports multiple dataloaders for speedup.
478
+ # """
479
+ # batch['pocket'] = Residues(**batch['pocket'])
480
+ # return batch
481
+
482
+ # # TODO: try if this is compatible with DDP
483
+ # def on_after_batch_transfer(self, batch, dataloader_idx):
484
+ # """
485
+ # Performs operations on data after it is transferred to the GPU.
486
+ # """
487
+ # batch['pocket'] = Residues(**batch['pocket'])
488
+ # batch['ligand'] = TensorDict(**batch['ligand'])
489
+ # return batch
490
+
491
+ def get_sc_transform_fn(self, zt_chi, zt_x, t, z0_chi, ligand_mask, pocket):
492
+ sc_transform = {}
493
+
494
+ if self.augment_residue_sc:
495
+ def pred_all_atom(pred_chi, pred_trans=None, pred_rot=None):
496
+ temp_pocket = pocket.deepcopy()
497
+
498
+ if pred_trans is not None and pred_rot is not None:
499
+ zt_trans = pocket['x']
500
+ zt_rot = pocket['axis_angle']
501
+ z1_trans_pred = self.module_trans.get_z1_given_zt_and_pred(
502
+ zt_trans, pred_trans, None, t, pocket['mask'])
503
+ z1_rot_pred = self.module_rot.get_z1_given_zt_and_pred(
504
+ zt_rot, pred_rot, None, t, pocket['mask'])
505
+ temp_pocket.set_frame(z1_trans_pred, z1_rot_pred)
506
+
507
+ z1_chi_pred = self.module_chi.get_z1_given_zt_and_pred(
508
+ zt_chi[..., :5], pred_chi, z0_chi, t, pocket['mask'])
509
+ temp_pocket.set_chi(z1_chi_pred)
510
+
511
+ all_coord = temp_pocket['v'] + temp_pocket['x'].unsqueeze(1)
512
+ return all_coord - pocket['x'].unsqueeze(1)
513
+
514
+ sc_transform['residues'] = pred_all_atom
515
+
516
+ if self.augment_ligand_sc:
517
+ # sc_transform['atoms'] = partial(self.module_x.get_z1_given_zt_and_pred, zt=zs_x, z0=None, t=t, batch_mask=lig_mask)
518
+ sc_transform['atoms'] = lambda pred: (self.module_x.get_z1_given_zt_and_pred(
519
+ zt_x, pred.squeeze(1), None, t, ligand_mask) - zt_x).unsqueeze(1)
520
+
521
+ return sc_transform
522
+
523
+ def compute_loss(self, ligand, pocket, return_info=False):
524
+ """
525
+ Samples a timestep per example, builds the noisy intermediate state, runs the network, and computes the weighted training loss
526
+ """
527
+ # TODO: move somewhere else (like collate_fn)
528
+ pocket = Residues(**pocket)
529
+
530
+ # Center sample
531
+ ligand, pocket = center_data(ligand, pocket)
532
+ if pocket['x'].numel() > 0:
533
+ pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0)
534
+ else:
535
+ pocket_com = scatter_mean(ligand['x'], ligand['mask'], dim=0)
536
+
537
+ # # Normalize pocket coordinates
538
+ # pocket['x'] = self.module_x.normalize(pocket['x'])
539
+
540
+ # Sample a timestep t for each example in batch
541
+ t = torch.rand(ligand['size'].size(0), device=ligand['x'].device).unsqueeze(-1)
542
+
543
+ # Noise
544
+ z0_x = self.module_x.sample_z0(pocket_com, ligand['mask'])
545
+ z0_h = self.module_h.sample_z0(ligand['mask'])
546
+ z0_e = self.module_e.sample_z0(ligand['bond_mask'])
547
+ zt_x = self.module_x.sample_zt(z0_x, ligand['x'], t, ligand['mask'])
548
+ zt_h = self.module_h.sample_zt(z0_h, ligand['one_hot'], t, ligand['mask'])
549
+ zt_e = self.module_e.sample_zt(z0_e, ligand['bond_one_hot'], t, ligand['bond_mask'])
550
+
551
+ if self.flexible_bb:
552
+ z0_trans = self.module_trans.sample_z0(pocket_com, pocket['mask'])
553
+ z1_trans = pocket['x'].detach().clone()
554
+ zt_trans = self.module_trans.sample_zt(z0_trans, z1_trans, t, pocket['mask'])
555
+
556
+ z0_rot = self.module_rot.sample_z0(pocket['mask'])
557
+ z1_rot = pocket['axis_angle'].detach().clone()
558
+ zt_rot = self.module_rot.sample_zt(z0_rot, z1_rot, t, pocket['mask'])
559
+
560
+ # update pocket
561
+ pocket.set_frame(zt_trans, zt_rot)
562
+
563
+ z0_chi, zt_chi = None, None
564
+ if self.flexible:
565
+ # residues = [data_utils.residue_from_internal_coord(ic) for ic in pocket['residues']]
566
+ # residues = pocket['residues']
567
+ # z1_chi = torch.stack([data_utils.get_torsion_angles(r, device=self.device) for r in residues], dim=0)
568
+ z1_chi = pocket['chi'][:, :5].detach().clone()
569
+
570
+ z0_chi = self.module_chi.sample_z0(pocket['mask'])
571
+ zt_chi = self.module_chi.sample_zt(z0_chi, z1_chi, t, pocket['mask'])
572
+
573
+ # internal to external coordinates
574
+ pocket.set_chi(zt_chi)
575
+ if pocket['x'].numel() == 0:
576
+ pocket.set_empty_v()
577
+
578
+ # Predict denoising
579
+ sc_transform = self.get_sc_transform_fn(zt_chi, zt_x, t, z0_chi, ligand['mask'], pocket)
580
+ # sc_transform = None
581
+ pred_ligand, pred_residues = self.dynamics(
582
+ zt_x, zt_h, ligand['mask'], pocket, t,
583
+ bonds_ligand=(ligand['bonds'], zt_e), sc_transform=sc_transform
584
+ )
585
+
586
+ # Compute L2 loss
587
+ if self.predict_confidence:
588
+ loss_x = self.module_x.compute_loss(pred_ligand['vel'], z0_x, ligand['x'], t, ligand['mask'], reduce='none')
589
+
590
+ # compute confidence regularization
591
+ k = self.module_x.dim # pred.size(-1)
592
+ sigma = pred_ligand['uncertainty_vel']
593
+ loss_x = loss_x / (2 * sigma ** 2) + k * torch.log(sigma)
594
+
595
+ if self.regularize_uncertainty is not None:
596
+ loss_x = loss_x + self.regularize_uncertainty * (pred_ligand['uncertainty_vel'] - 1) ** 2
597
+
598
+ loss_x = self.module_x.reduce_loss(loss_x, ligand['mask'], reduce=self.loss_reduce)
599
+ else:
600
+ loss_x = self.module_x.compute_loss(pred_ligand['vel'], z0_x, ligand['x'], t, ligand['mask'], reduce=self.loss_reduce)
601
+
602
+ # Loss for categorical variables
603
+ t_next = torch.clamp(t + self.train_step_size, max=1.0)
604
+ loss_h = self.module_h.compute_loss(pred_ligand['logits_h'], zt_h, ligand['one_hot'], ligand['mask'], t, t_next, reduce=self.loss_reduce)
605
+ loss_e = self.module_e.compute_loss(pred_ligand['logits_e'], zt_e, ligand['bond_one_hot'], ligand['bond_mask'], t, t_next, reduce=self.loss_reduce)
606
+
607
+ loss = self.lambda_x * loss_x + self.lambda_h * loss_h + self.lambda_e * loss_e
608
+ if self.flexible:
609
+ loss_chi = self.module_chi.compute_loss(pred_residues['chi'], z0_chi, z1_chi, zt_chi, t, pocket['mask'], reduce=self.loss_reduce)
610
+ loss = loss + self.lambda_chi * loss_chi
611
+
612
+ if self.flexible_bb:
613
+ loss_trans = self.module_trans.compute_loss(pred_residues['trans'], z0_trans, z1_trans, t, pocket['mask'], reduce=self.loss_reduce)
614
+ loss_rot = self.module_rot.compute_loss(pred_residues['rot'], z0_rot, z1_rot, zt_rot, t, pocket['mask'], reduce=self.loss_reduce)
615
+ loss = loss + self.lambda_trans * loss_trans + self.lambda_rot * loss_rot
616
+
617
+ if self.lambda_clash is not None and self.lambda_clash > 0:
618
+
619
+ if self.flexible_bb:
620
+ pred_z1_trans = self.module_trans.get_z1_given_zt_and_pred(zt_trans, pred_residues['trans'], z0_trans, t, pocket['mask'])
621
+ pred_z1_rot = self.module_rot.get_z1_given_zt_and_pred(zt_rot, pred_residues['rot'], z0_rot, t, pocket['mask'])
622
+ pocket.set_frame(pred_z1_trans, pred_z1_rot)
623
+
624
+ if self.flexible:
625
+ # internal to external coordinates
626
+ pred_z1_chi = self.module_chi.get_z1_given_zt_and_pred(zt_chi, pred_residues['chi'], z0_chi, t, pocket['mask'])
627
+ pocket.set_chi(pred_z1_chi)
628
+
629
+ pocket_coord = pocket['x'].unsqueeze(1) + pocket['v']
630
+ pocket_types = aa_atom_type_tensor[pocket['one_hot'].argmax(dim=-1)]
631
+ pocket_mask = pocket['mask'].unsqueeze(-1).repeat((1, pocket['v'].size(1)))
632
+
633
+ # Extract only existing atoms
634
+ atom_mask = aa_atom_mask_tensor[pocket['one_hot'].argmax(dim=-1)]
635
+ pocket_coord = pocket_coord[atom_mask]
636
+ pocket_types = pocket_types[atom_mask]
637
+ pocket_mask = pocket_mask[atom_mask]
638
+
639
+ # pred_z1_x = pred_x + z0_x
640
+ pred_z1_x = self.module_x.get_z1_given_zt_and_pred(zt_x, pred_ligand['vel'], z0_x, t, ligand['mask'])
641
+ pred_z1_h = pred_ligand['logits_h'].argmax(dim=-1)
642
+ loss_clash = clash_loss(pred_z1_x, pred_z1_h, ligand['mask'],
643
+ pocket_coord, pocket_types, pocket_mask)
644
+ loss = loss + self.lambda_clash * loss_clash
645
+
646
+ if self.timestep_weights is not None:
647
+ w_t = self.timestep_weights(t).squeeze()
648
+ loss = w_t * loss
649
+
650
+ loss = loss.mean(0)
651
+
652
+ info = {
653
+ 'loss_x': loss_x.mean().item(),
654
+ 'loss_h': loss_h.mean().item(),
655
+ 'loss_e': loss_e.mean().item(),
656
+ }
657
+ if self.flexible:
658
+ info['loss_chi'] = loss_chi.mean().item()
659
+ if self.flexible_bb:
660
+ info['loss_trans'] = loss_trans.mean().item()
661
+ info['loss_rot'] = loss_rot.mean().item()
662
+ if self.lambda_clash is not None:
663
+ info['loss_clash'] = loss_clash.mean().item()
664
+ if self.predict_confidence:
665
+ sigma_x_mol = scatter_mean(pred_ligand['uncertainty_vel'], ligand['mask'], dim=0)
666
+ info['pearson_sigma_x'] = torch.corrcoef(torch.stack([sigma_x_mol.detach(), t.squeeze()]))[0, 1].item()
667
+ info['mean_sigma_x'] = sigma_x_mol.mean().item()
668
+ entropy_h = Categorical(logits=pred_ligand['logits_h']).entropy()
669
+ entropy_h_mol = scatter_mean(entropy_h, ligand['mask'], dim=0)
670
+ info['pearson_entropy_h'] = torch.corrcoef(torch.stack([entropy_h_mol.detach(), t.squeeze()]))[0, 1].item()
671
+ info['mean_entropy_h'] = entropy_h_mol.mean().item()
672
+ entropy_e = Categorical(logits=pred_ligand['logits_e']).entropy()
673
+ entropy_e_mol = scatter_mean(entropy_e, ligand['bond_mask'], dim=0)
674
+ info['pearson_entropy_e'] = torch.corrcoef(torch.stack([entropy_e_mol.detach(), t.squeeze()]))[0, 1].item()
675
+ info['mean_entropy_e'] = entropy_e_mol.mean().item()
676
+
677
+ return (loss, info) if return_info else loss
678
+
679
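A minimal, self-contained sketch of the confidence-weighted coordinate loss used in `compute_loss` above: the per-atom L2 error is turned into a Gaussian negative log-likelihood with predicted scale `sigma`, plus the optional pull of `sigma` towards 1. All tensors below are dummy data, not the model's real inputs.

import torch

n_atoms, k = 6, 3                        # k: coordinate dimensionality (module_x.dim)
loss_x = torch.rand(n_atoms)             # per-atom squared error (reduce='none')
sigma = torch.rand(n_atoms) + 0.5        # predicted per-atom uncertainty
nll = loss_x / (2 * sigma ** 2) + k * torch.log(sigma)

reg_weight = 0.1                         # plays the role of `regularize_uncertainty`
nll = nll + reg_weight * (sigma - 1) ** 2
print(nll.shape)                         # torch.Size([6]); reduced per molecule afterwards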
+     def training_step(self, data, *args):
+         ligand, pocket = data['ligand'], data['pocket']
+         try:
+             loss, info = self.compute_loss(ligand, pocket, return_info=True)
+         except RuntimeError as e:
+             # skipping the batch on OOM is not supported for multi-GPU training
+             if self.trainer.num_devices < 2 and 'out of memory' in str(e):
+                 print('WARNING: ran out of memory, skipping to the next batch')
+                 return None
+             else:
+                 raise e
+
+         log_dict = {k: v for k, v in info.items() if isinstance(v, float)
+                     or torch.numel(v) <= 1}
+         # if self.learn_nu:
+         #     log_dict['nu_x'] = self.noise_schedules['x'].nu.item()
+         #     log_dict['nu_h'] = self.noise_schedules['h'].nu.item()
+         #     log_dict['nu_e'] = self.noise_schedules['e'].nu.item()
+
+         self.log_metrics({'loss': loss, **log_dict}, 'train',
+                          batch_size=len(ligand['size']))
+
+         out = {'loss': loss, **info}
+         self.training_step_outputs.append(out)
+         return out
+
+     def validation_step(self, data, *args):
+
+         # Compute the loss N times and average to get a better estimate
+         loss_list, info_list = [], []
+         self.dynamics.train()  # TODO: this is currently necessary to make self-conditioning work
+         for _ in range(self.n_loss_per_sample):
+             loss, info = self.compute_loss(data['ligand'].copy(),
+                                            data['pocket'].copy(),
+                                            return_info=True)
+             loss_list.append(loss.item())
+             info_list.append(info)
+         self.dynamics.eval()
+         if len(loss_list) >= 1:
+             loss = np.mean(loss_list)
+             info = {k: np.mean([x[k] for x in info_list]) for k in info_list[0]}
+             self.log_metrics({'loss': loss, **info}, 'val', batch_size=len(data['ligand']['size']))
+
+         # Sample
+         rdmols, rdpockets, _ = self.sample(
+             data=data,
+             n_samples=self.n_eval_samples,
+             num_nodes="ground_truth" if self.sample_with_ground_truth_size else None,
+         )
+
+         out = {
+             'ligands': rdmols,
+             'pockets': rdpockets,
+             'receptor_files': [Path(self.receptor_dir, 'val', x) for x in data['pocket']['name']]
+         }
+         self.validation_step_outputs.append(out)
+         return out
+
+     # def test_step(self, data, *args):
+     #     self._shared_eval(data, 'test', *args)
+
+     def on_validation_epoch_end(self):
+
+         outdir = Path(self.outdir, f'epoch_{self.current_epoch}')
+
+         rdmols = [m for x in self.validation_step_outputs for m in x['ligands']]
+         rdpockets = [p for x in self.validation_step_outputs for p in x['pockets']]
+         receptors = [r for x in self.validation_step_outputs for r in x['receptor_files']]
+         self.validation_step_outputs.clear()
+
+         ligand_atom_types = [atom_encoder[a.GetSymbol()] for m in rdmols for a in m.GetAtoms()]
+         ligand_bond_types = []
+         for m in rdmols:
+             bonds = m.GetBonds()
+             no_bonds = m.GetNumAtoms() * (m.GetNumAtoms() - 1) // 2 - m.GetNumBonds()
+             ligand_bond_types += [bond_encoder['NOBOND']] * no_bonds
+             for b in bonds:
+                 ligand_bond_types.append(bond_encoder[b.GetBondType().name])
+
+         tic = time()
+         results = self.analyze_sample(
+             rdmols, ligand_atom_types, ligand_bond_types, receptors=(rdpockets if len(rdpockets) != 0 else None)
+         )
+         self.log_metrics(results, 'val')
+         print(f'Evaluation took {time() - tic:.2f} seconds')
+
+         if (self.current_epoch + 1) % self.visualize_sample_epoch == 0:
+             tic = time()
+
+             outdir.mkdir(exist_ok=True, parents=True)
+
+             # center for better visualization
+             rdmols = rdmols[:self.n_visualize_samples]
+             rdpockets = rdpockets[:self.n_visualize_samples]
+             for m, p in zip(rdmols, rdpockets):
+                 center = m.GetConformer().GetPositions().mean(axis=0)
+                 for i in range(m.GetNumAtoms()):
+                     x, y, z = m.GetConformer().GetPositions()[i] - center
+                     m.GetConformer().SetAtomPosition(i, (x, y, z))
+                 for i in range(p.GetNumAtoms()):
+                     x, y, z = p.GetConformer().GetPositions()[i] - center
+                     p.GetConformer().SetAtomPosition(i, (x, y, z))
+
+             # save molecules
+             utils.write_sdf_file(Path(outdir, 'molecules.sdf'), rdmols)
+
+             # save pockets
+             utils.write_sdf_file(Path(outdir, 'pockets.sdf'), rdpockets)
+
+             print(f'Sample visualization took {time() - tic:.2f} seconds')
+
+         if (self.current_epoch + 1) % self.visualize_chain_epoch == 0:
+             tic = time()
+             outdir.mkdir(exist_ok=True, parents=True)
+
+             if self.sharded_dataset:
+                 index = torch.randint(len(self.val_dataset), size=(1,)).item()
+                 for i, x in enumerate(self.val_dataset):
+                     if i == index:
+                         break
+                 batch = self.val_dataset.collate_fn([x])
+             else:
+                 batch = self.val_dataset.collate_fn([self.val_dataset[torch.randint(len(self.val_dataset), size=(1,))]])
+             batch['pocket'] = Residues(**batch['pocket']).to(self.device)
+             pocket_copy = batch['pocket'].copy()
+
+             if len(batch['pocket']['x']) > 0:
+                 ligand_chain, pocket_chain, info = self.sample_chain(batch['pocket'], self.keep_frames)
+             else:
+                 num_nodes, _ = self.size_distribution.sample()
+                 ligand_chain, pocket_chain, info = self.sample_chain(batch['pocket'], self.keep_frames, num_nodes=num_nodes)
+
+             # utils.write_sdf_file(Path(outdir, 'chain_pocket.sdf'), pocket_chain)
+             # utils.write_chain(Path(outdir, 'chain_pocket.xyz'), pocket_chain)
+             if self.flexible or self.flexible_bb:
+                 # insert ground truth at the beginning so that it's used by PyMOL to determine the connectivity
+                 ground_truth_pocket = pocket_to_rdkit(
+                     pocket_copy, self.pocket_representation,
+                     self.atom_encoder, self.atom_decoder,
+                     self.aa_decoder, self.residue_decoder,
+                     self.aa_atom_index
+                 )[0]
+                 ground_truth_ligand = build_molecule(
+                     batch['ligand']['x'], batch['ligand']['one_hot'].argmax(1),
+                     bonds=batch['ligand']['bonds'],
+                     bond_types=batch['ligand']['bond_one_hot'].argmax(1),
+                     atom_decoder=self.atom_decoder,
+                     bond_decoder=self.bond_decoder
+                 )
+                 pocket_chain.insert(0, ground_truth_pocket)
+                 ligand_chain.insert(0, ground_truth_ligand)
+                 # pocket_chain.insert(0, pocket_chain[-1])
+                 # ligand_chain.insert(0, ligand_chain[-1])
+
+             # save molecules
+             utils.write_sdf_file(Path(outdir, 'chain_ligand.sdf'), ligand_chain)
+
+             # save pocket
+             mols_to_pdbfile(pocket_chain, Path(outdir, 'chain_pocket.pdb'))
+
+             self.log_metrics(info, 'val')
+             print(f'Chain visualization took {time() - tic:.2f} seconds')
+
+
+     # NOTE: temporary fix of this Lightning bug:
+     # https://github.com/Lightning-AI/pytorch-lightning/discussions/18110
+     # Without it, resuming training behaves strangely and fails
+     @property
+     def total_batch_idx(self) -> int:
+         """Returns the current batch index (across epochs)"""
+         # use `ready` instead of `completed` in case this is accessed after `completed` has been increased
+         # but before the next `ready` increase
+         return max(0, self.batch_progress.total.ready - 1)
+
+     @property
+     def batch_idx(self) -> int:
+         """Returns the current batch index (within this epoch)"""
+         # use `ready` instead of `completed` in case this is accessed after `completed` has been increased
+         # but before the next `ready` increase
+         return max(0, self.batch_progress.current.ready - 1)
+
+     # def analyze_sample(self, rdmols, atom_types, bond_types, aa_types=None, receptors=None):
+     #     out = {}
+
+     #     # Distribution of node types
+     #     kl_div_atom = self.ligand_atom_type_distribution.kl_divergence(atom_types) \
+     #         if self.ligand_atom_type_distribution is not None else -1
+     #     out['kl_div_atom_types'] = kl_div_atom
+
+     #     # Distribution of edge types
+     #     kl_div_bond = self.ligand_bond_type_distribution.kl_divergence(bond_types) \
+     #         if self.ligand_bond_type_distribution is not None else -1
+     #     out['kl_div_bond_types'] = kl_div_bond
+
+     #     if aa_types is not None:
+     #         kl_div_aa = self.pocket_type_distribution.kl_divergence(aa_types) \
+     #             if self.pocket_type_distribution is not None else -1
+     #         out['kl_div_residue_types'] = kl_div_aa
+
+     #     # Post-process sample
+     #     processed_mols = [process_all(m) for m in rdmols]
+
+     #     # Other basic metrics
+     #     results = self.ligand_metrics(rdmols)
+     #     out['n_samples'] = results['n_total']
+     #     out['Validity'] = results['validity']
+     #     out['Connectivity'] = results['connectivity']
+     #     out['valid_and_connected'] = results['valid_and_connected']
+
+     #     # connected_mols = [get_largest_fragment(m) for m in rdmols]
+     #     connected_mols = [process_all(m, largest_frag=True, adjust_aromatic_Ns=False, relax_iter=0) for m in rdmols]
+     #     connected_mols = [m for m in connected_mols if m is not None]
+     #     out.update(self.molecule_properties(connected_mols))
+
+     #     # Repeat after post-processing
+     #     results = self.ligand_metrics(processed_mols)
+     #     out['validity_processed'] = results['validity']
+     #     out['connectivity_processed'] = results['connectivity']
+     #     out['valid_and_connected_processed'] = results['valid_and_connected']
+
+     #     processed_mols = [m for m in processed_mols if m is not None]
+     #     for k, v in self.molecule_properties(processed_mols).items():
+     #         out[f"{k}_processed"] = v
+
+     #     # Simple docking score
+     #     if receptors is not None and self.gnina is not None:
+     #         assert len(receptors) == len(rdmols)
+     #         docking_results = compute_gnina_scores(rdmols, receptors, gnina=self.gnina)
+     #         out.update(docking_results)
+
+     #     # Clash score
+     #     if receptors is not None:
+     #         assert len(receptors) == len(rdmols)
+     #         clashes = {
+     #             'ligands': [legacy_clash_score(m) for m in rdmols],
+     #             'pockets': [legacy_clash_score(p) for p in receptors],
+     #             'between': [legacy_clash_score(m, p) for m, p in zip(rdmols, receptors)],
+     #             'v2_ligands': [clash_score(m) for m in rdmols],
+     #             'v2_pockets': [clash_score(p) for p in receptors],
+     #             'v2_between': [clash_score(m, p) for m, p in zip(rdmols, receptors)]
+     #         }
+     #         for k, v in clashes.items():
+     #             out[f'mean_clash_score_{k}'] = np.mean(v)
+     #             out[f'frac_no_clashes_{k}'] = np.mean(np.array(v) <= 0.0)
+
+     #     return out
+
+     def analyze_sample(self, rdmols, atom_types, bond_types, aa_types=None, receptors=None):
+         out = {}
+
+         # Distribution of node types
+         kl_div_atom = self.ligand_atom_type_distribution.kl_divergence(atom_types) \
+             if self.ligand_atom_type_distribution is not None else -1
+         out['kl_div_atom_types'] = kl_div_atom
+
+         # Distribution of edge types
+         kl_div_bond = self.ligand_bond_type_distribution.kl_divergence(bond_types) \
+             if self.ligand_bond_type_distribution is not None else -1
+         out['kl_div_bond_types'] = kl_div_bond
+
+         if aa_types is not None:
+             kl_div_aa = self.pocket_type_distribution.kl_divergence(aa_types) \
+                 if self.pocket_type_distribution is not None else -1
+             out['kl_div_residue_types'] = kl_div_aa
+
+         # Evaluation
+         results = []
+         if receptors is not None:
+             with tempfile.TemporaryDirectory() as tmpdir:
+                 for mol, receptor in zip(tqdm(rdmols, desc='FullEvaluator'), receptors):
+                     receptor_path = Path(tmpdir, 'receptor.pdb')
+                     Chem.MolToPDBFile(receptor, str(receptor_path))
+                     results.append(self.evaluator(mol, receptor_path))
+         else:
+             # instantiate the ligand-only evaluator once, not per molecule
+             self.evaluator = FullEvaluator(pb_conf='mol')
+             for mol in tqdm(rdmols, desc='FullEvaluator'):
+                 results.append(self.evaluator(mol))
+
+         results = pd.DataFrame(results)
+         agg_results = aggregated_metrics(results, self.evaluator.dtypes, VALIDITY_METRIC_NAME).fillna(0)
+         # replace literal dots (regex=False avoids '.' matching every character)
+         agg_results['metric'] = agg_results['metric'].str.replace('.', '/', regex=False)
+
+         col_results = collection_metrics(results, self.train_smiles, VALIDITY_METRIC_NAME, exclude_evaluators='fcd')
+         col_results['metric'] = 'collection/' + col_results['metric']
+
+         all_results = pd.concat([agg_results, col_results])
+         out.update(**dict(all_results[['metric', 'value']].values))
+
+         return out
+
+     def sample_zt_given_zs(self, zs_ligand, zs_pocket, s, t, delta_eps_x=None, uncertainty=None):
+
+         sc_transform = self.get_sc_transform_fn(zs_pocket.get('chi'), zs_ligand['x'], s, None, zs_ligand['mask'], zs_pocket)
+         pred_ligand, pred_residues = self.dynamics(
+             zs_ligand['x'], zs_ligand['h'], zs_ligand['mask'], zs_pocket, s, bonds_ligand=(zs_ligand['bonds'], zs_ligand['e']),
+             sc_transform=sc_transform
+         )
+
+         if delta_eps_x is not None:
+             pred_ligand['vel'] = pred_ligand['vel'] + delta_eps_x
+
+         zt_ligand = zs_ligand.copy()
+         zt_ligand['x'] = self.module_x.sample_zt_given_zs(zs_ligand['x'], pred_ligand['vel'], s, t, zs_ligand['mask'])
+         zt_ligand['h'] = self.module_h.sample_zt_given_zs(zs_ligand['h'], pred_ligand['logits_h'], s, t, zs_ligand['mask'])
+         zt_ligand['e'] = self.module_e.sample_zt_given_zs(zs_ligand['e'], pred_ligand['logits_e'], s, t, zs_ligand['edge_mask'])
+
+         zt_pocket = zs_pocket.copy()
+         if self.flexible_bb:
+             zt_trans_pocket = self.module_trans.sample_zt_given_zs(zs_pocket['x'], pred_residues['trans'], s, t, zs_pocket['mask'])
+             zt_rot_pocket = self.module_rot.sample_zt_given_zs(zs_pocket['axis_angle'], pred_residues['rot'], s, t, zs_pocket['mask'])
+
+             # update pocket in-place
+             zt_pocket.set_frame(zt_trans_pocket, zt_rot_pocket)
+
+         if self.flexible:
+             zt_chi_pocket = self.module_chi.sample_zt_given_zs(zs_pocket['chi'][..., :5], pred_residues['chi'], s, t, zs_pocket['mask'])
+
+             # update pocket in-place
+             zt_pocket.set_chi(zt_chi_pocket)
+
+         if self.predict_confidence:
+             assert uncertainty is not None
+             dt = (t - s).view(-1)[zt_ligand['mask']]
+             uncertainty['sigma_x_squared'] += (dt * pred_ligand['uncertainty_vel'] ** 2)
+             uncertainty['entropy_h'] += (dt * Categorical(logits=pred_ligand['logits_h']).entropy())
+
+         return zt_ligand, zt_pocket
+
+     def simulate(self, ligand, pocket, timesteps, t_start, t_end=1.0,
+                  return_frames=1, guide_log_prob=None):
+         """
+         Take a version of the ligand and pocket (at any time step t_start) and
+         simulate the generative process from t_start to t_end.
+         """
+
+         assert 0 < return_frames <= timesteps
+         assert timesteps % return_frames == 0
+         assert 0.0 <= t_start < 1.0
+         assert 0 < t_end <= 1.0
+         assert t_start < t_end
+
+         device = ligand['x'].device
+         n_samples = len(pocket['size'])
+         delta_t = (t_end - t_start) / timesteps
+
+         # Initialize output tensors
+         out_ligand = {
+             'x': torch.zeros((return_frames, len(ligand['mask']), self.x_dim), device=device),
+             'h': torch.zeros((return_frames, len(ligand['mask']), self.atom_nf), device=device),
+             'e': torch.zeros((return_frames, len(ligand['edge_mask']), self.bond_nf), device=device)
+         }
+         if self.predict_confidence:
+             out_ligand['sigma_x'] = torch.zeros((return_frames, len(ligand['mask'])), device=device)
+             out_ligand['entropy_h'] = torch.zeros((return_frames, len(ligand['mask'])), device=device)
+         out_pocket = {
+             'x': torch.zeros((return_frames, len(pocket['mask']), 3), device=device),  # CA coordinates
+             'v': torch.zeros((return_frames, len(pocket['mask']), self.n_atom_aa, 3), device=device)  # difference vectors to all other atoms
+         }
+
+         cumulative_uncertainty = {
+             'sigma_x_squared': torch.zeros(len(ligand['mask']), device=device),
+             'entropy_h': torch.zeros(len(ligand['mask']), device=device)
+         } if self.predict_confidence else None
+
+         for i, t in enumerate(torch.linspace(t_start, t_end - delta_t, timesteps)):
+             t_array = torch.full((n_samples, 1), fill_value=t, device=device)
+
+             if guide_log_prob is not None:
+                 raise NotImplementedError('Not yet implemented for flow matching model')
+                 # NOTE: everything below is unreachable until guidance is implemented
+                 alpha_t = self.diffusion_x.schedule.alpha(self.gamma_x(t_array))
+
+                 with torch.enable_grad():
+                     zt_x_ligand.requires_grad = True
+                     g = guide_log_prob(t_array, x=ligand['x'], h=ligand['h'], batch_mask=ligand['mask'],
+                                        bonds=ligand['bonds'], bond_types=ligand['e'])
+
+                     # Compute gradient w.r.t. coordinates
+                     grad_x_lig = torch.autograd.grad(g.sum(), inputs=ligand['x'])[0]
+
+                     # clip gradients
+                     g_max = 1.0
+                     clip_mask = (grad_x_lig.norm(dim=-1) > g_max)
+                     grad_x_lig[clip_mask] = \
+                         grad_x_lig[clip_mask] / grad_x_lig[clip_mask].norm(
+                             dim=-1, keepdim=True) * g_max
+
+                 delta_eps_lig = -1 * (1 - alpha_t[lig_mask]).sqrt() * grad_x_lig
+             else:
+                 delta_eps_lig = None
+
+             ligand, pocket = self.sample_zt_given_zs(
+                 ligand, pocket, t_array, t_array + delta_t, delta_eps_lig, cumulative_uncertainty)
+
+             # save frame
+             if (i + 1) % (timesteps // return_frames) == 0:
+                 idx = (i + 1) // (timesteps // return_frames)
+                 idx = idx - 1
+
+                 out_ligand['x'][idx] = ligand['x'].detach()
+                 out_ligand['h'][idx] = ligand['h'].detach()
+                 out_ligand['e'][idx] = ligand['e'].detach()
+                 if pocket['x'].numel() > 0:
+                     out_pocket['x'][idx] = pocket['x'].detach()
+                     out_pocket['v'][idx] = pocket['v'][:, :self.n_atom_aa, :].detach()
+                 if self.predict_confidence:
+                     out_ligand['sigma_x'][idx] = cumulative_uncertainty['sigma_x_squared'].sqrt().detach()
+                     out_ligand['entropy_h'][idx] = cumulative_uncertainty['entropy_h'].detach()
+
+         # remove frame dimension if only the final molecule is returned
+         out_ligand = {k: v.squeeze(0) for k, v in out_ligand.items()}
+         out_pocket = {k: v.squeeze(0) for k, v in out_pocket.items()}
+
+         return out_ligand, out_pocket
+
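A quick standalone check of the frame bookkeeping in `simulate`: with `timesteps` integration steps and `return_frames` stored frames, a frame is written every `timesteps // return_frames` steps, filling slots 0 to return_frames - 1. Dummy numbers only.

timesteps, return_frames = 100, 4
saved = [(i + 1) // (timesteps // return_frames) - 1
         for i in range(timesteps)
         if (i + 1) % (timesteps // return_frames) == 0]
print(saved)  # [0, 1, 2, 3] -- evenly spaced frame slots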
+     def init_ligand(self, num_nodes_lig, pocket):
+         device = pocket['x'].device
+
+         n_samples = len(pocket['size'])
+         lig_mask = utils.num_nodes_to_batch_mask(n_samples, num_nodes_lig, device)
+
+         # only consider the upper triangular matrix for symmetry
+         lig_bonds = torch.stack(torch.where(torch.triu(
+             lig_mask[:, None] == lig_mask[None, :], diagonal=1)), dim=0)
+         lig_edge_mask = lig_mask[lig_bonds[0]]
+
+         # Sample from a Normal distribution centered at the pocket
+         pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0)
+         z0_x = self.module_x.sample_z0(pocket_com, lig_mask)
+         z0_h = self.module_h.sample_z0(lig_mask)
+         z0_e = self.module_e.sample_z0(lig_edge_mask)
+
+         return TensorDict(**{
+             'x': z0_x, 'h': z0_h, 'e': z0_e, 'mask': lig_mask,
+             'bonds': lig_bonds, 'edge_mask': lig_edge_mask
+         })
+
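Standalone sketch of the upper-triangular bond enumeration in `init_ligand`: for every pair of atoms in the same graph, exactly one pair (i, j) with i < j is kept. The toy batch mask below describes two graphs with 2 and 3 atoms.

import torch

lig_mask = torch.tensor([0, 0, 1, 1, 1])
bonds = torch.stack(torch.where(torch.triu(
    lig_mask[:, None] == lig_mask[None, :], diagonal=1)), dim=0)
print(bonds)
# tensor([[0, 2, 2, 3],
#         [1, 3, 4, 4]])  -> 1 pair in graph 0, 3 pairs in graph 1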
+     def init_pocket(self, pocket):
+
+         if self.flexible_bb:
+             pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0)
+             z0_trans = self.module_trans.sample_z0(pocket_com, pocket['mask'])
+             z0_rot = self.module_rot.sample_z0(pocket['mask'])
+
+             # update pocket in-place
+             pocket.set_frame(z0_trans, z0_rot)
+
+         if self.flexible:
+             z0_chi = self.module_chi.sample_z0(pocket['mask'])
+
+             # # DEBUG ##
+             # z0_chi = torch.stack([data_utils.get_torsion_angles(r, device=self.device) for r in pocket['residues']], dim=0)
+             # ####
+
+             # internal to external coordinates
+             pocket.set_chi(z0_chi)
+
+         if pocket['x'].numel() == 0:
+             pocket.set_empty_v()
+
+         return pocket
+
+     def parse_num_nodes_spec(self, batch, spec=None, size_model=None):
+
+         if spec == "2d_histogram" or spec is None:  # default option
+             assert "pocket" in batch
+             num_nodes = self.size_distribution.sample_conditional(
+                 n1=None, n2=batch['pocket']['size'])
+
+             # make sure there is at least one potential bond
+             num_nodes[num_nodes < 2] = 2
+
+         elif isinstance(spec, (int, torch.Tensor)):
+             num_nodes = spec
+
+         elif spec == "ground_truth":
+             assert "ligand" in batch
+             num_nodes = batch['ligand']['size']
+
+         elif spec == "nn_prediction":
+             assert size_model is not None
+             assert "pocket" in batch
+             predictions = size_model.forward(batch['pocket'])
+             predictions = torch.softmax(predictions, dim=-1)
+             predictions[:, :5] = 0.0
+             probabilities = predictions / predictions.sum(dim=1, keepdims=True)
+             num_nodes = torch.distributions.Categorical(probabilities).sample()
+
+         elif isinstance(spec, str) and spec.startswith("uniform"):
+             # expected format: uniform_low_high
+             assert "pocket" in batch
+             left, right = map(int, spec.split("_")[1:])
+             shape = batch['pocket']['size'].shape
+             num_nodes = torch.randint(left, right + 1, shape, dtype=torch.long)
+
+         else:
+             raise NotImplementedError(f"Invalid size specification {spec}")
+
+         if self.virtual_nodes:
+             num_nodes += self.add_virtual_max
+
+         return num_nodes
+
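A hedged sketch of the only `spec` string that `parse_num_nodes_spec` actually parses, "uniform_low_high"; the other options dispatch to the size histogram, the ground-truth sizes, or a size model. The pocket batch below is dummy data.

import torch

batch = {'pocket': {'size': torch.tensor([12, 30])}}
spec = "uniform_10_30"
left, right = map(int, spec.split("_")[1:])       # -> 10, 30
num_nodes = torch.randint(left, right + 1, batch['pocket']['size'].shape, dtype=torch.long)
print(num_nodes.shape)                            # torch.Size([2]); one ligand size per pocket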
+     @torch.no_grad()
+     def sample(self, data, n_samples, num_nodes=None, timesteps=None,
+                guide_log_prob=None, size_model=None, **kwargs):
+
+         # TODO: move somewhere else (like collate_fn)
+         data['pocket'] = Residues(**data['pocket'])
+
+         timesteps = self.T_sampling if timesteps is None else timesteps
+
+         if len(data['pocket']['x']) > 0:
+             pocket = data_utils.repeat_items(data['pocket'], n_samples)
+         else:
+             pocket = Residues(**{key: value for key, value in data['pocket'].items()})
+             pocket['name'] = pocket['name'] * n_samples
+             pocket['size'] = pocket['size'].repeat(n_samples)
+             pocket['n_bonds'] = pocket['n_bonds'].repeat(n_samples)
+
+         _ligand = data_utils.repeat_items(data['ligand'], n_samples)
+         # _ligand = randomize_tensors(_ligand, exclude_keys=['size', 'name'])  # avoid data leakage
+
+         batch = {"ligand": _ligand, "pocket": pocket}
+         num_nodes = self.parse_num_nodes_spec(batch, spec=num_nodes, size_model=size_model)
+
+         # Sample from prior
+         if pocket['x'].numel() > 0:
+             ligand = self.init_ligand(num_nodes, pocket)
+         else:
+             ligand = self.init_ligand(num_nodes, _ligand)
+         pocket = self.init_pocket(pocket)
+
+         # return prior samples
+         if timesteps == 0:
+             # Convert into rdmols
+             rdmols = [build_molecule(coords=m['x'],
+                                      atom_types=m['h'].argmax(1),
+                                      bonds=m['bonds'],
+                                      bond_types=m['e'].argmax(1),
+                                      atom_decoder=self.atom_decoder, bond_decoder=self.bond_decoder)
+                       for m in data_utils.split_entity(ligand.detach().cpu(), edge_types={"e", "edge_mask"}, edge_mask=ligand["edge_mask"])]
+
+             rdpockets = pocket_to_rdkit(pocket, self.pocket_representation,
+                                         self.atom_encoder, self.atom_decoder,
+                                         self.aa_decoder, self.residue_decoder,
+                                         self.aa_atom_index)
+
+             return rdmols, rdpockets, _ligand['name']
+
+         out_tensors_ligand, out_tensors_pocket = self.simulate(
+             ligand, pocket, timesteps, 0.0, 1.0,
+             guide_log_prob=guide_log_prob
+         )
+
+         # Build mol objects
+         x = out_tensors_ligand['x'].detach().cpu()
+         ligand_type = out_tensors_ligand['h'].argmax(1).detach().cpu()
+         edge_type = out_tensors_ligand['e'].argmax(1).detach().cpu()
+         lig_mask = ligand['mask'].detach().cpu()
+         lig_bonds = ligand['bonds'].detach().cpu()
+         lig_edge_mask = ligand['edge_mask'].detach().cpu()
+         sizes = torch.unique(ligand['mask'], return_counts=True)[1].tolist()
+         offsets = list(accumulate(sizes[:-1], initial=0))
+         mol_kwargs = {
+             'coords': utils.batch_to_list(x, lig_mask),
+             'atom_types': utils.batch_to_list(ligand_type, lig_mask),
+             'bonds': utils.batch_to_list_for_indices(lig_bonds, lig_edge_mask, offsets),
+             'bond_types': utils.batch_to_list(edge_type, lig_edge_mask)
+         }
+         if self.predict_confidence:
+             sigma_x = out_tensors_ligand['sigma_x'].detach().cpu()
+             entropy_h = out_tensors_ligand['entropy_h'].detach().cpu()
+             mol_kwargs['atom_props'] = [
+                 {'sigma_x': x[0], 'entropy_h': x[1]}
+                 for x in zip(utils.batch_to_list(sigma_x, lig_mask),
+                              utils.batch_to_list(entropy_h, lig_mask))
+             ]
+         mol_kwargs = [{k: v[i] for k, v in mol_kwargs.items()}
+                       for i in range(len(mol_kwargs['coords']))]
+
+         # Convert into rdmols
+         rdmols = [build_molecule(
+             **m, atom_decoder=self.atom_decoder, bond_decoder=self.bond_decoder)
+             for m in mol_kwargs
+         ]
+
+         out_pocket = pocket.copy()
+         out_pocket['x'] = out_tensors_pocket['x']
+         out_pocket['v'] = out_tensors_pocket['v']
+         rdpockets = pocket_to_rdkit(out_pocket, self.pocket_representation,
+                                     self.atom_encoder, self.atom_decoder,
+                                     self.aa_decoder, self.residue_decoder,
+                                     self.aa_atom_index)
+
+         return rdmols, rdpockets, _ligand['name']
+
+     @torch.no_grad()
+     def sample_chain(self, pocket, keep_frames, num_nodes=None, timesteps=None,
+                      guide_log_prob=None, **kwargs):
+
+         # TODO: move somewhere else (like collate_fn)
+         pocket = Residues(**pocket)
+
+         info = {}
+
+         timesteps = self.T_sampling if timesteps is None else timesteps
+
+         # n_samples = 1
+         # TODO: get batch_size differently
+         assert len(pocket['mask'].unique()) <= 1, "sample_chain only supports a single sample"
+
+         # # Pocket's initial center of mass
+         # pocket_com_before = scatter_mean(pocket['x'], pocket['mask'], dim=0)
+
+         num_nodes = self.parse_num_nodes_spec(batch={"pocket": pocket}, spec=num_nodes)
+
+         # Sample from prior
+         if pocket['x'].numel() > 0:
+             ligand = self.init_ligand(num_nodes, pocket)
+         else:
+             dummy_pocket = Residues.empty(pocket['x'].device)
+             ligand = self.init_ligand(num_nodes, dummy_pocket)
+
+         pocket = self.init_pocket(pocket)
+
+         out_tensors_ligand, out_tensors_pocket = self.simulate(
+             ligand, pocket, timesteps, 0.0, 1.0, guide_log_prob=guide_log_prob, return_frames=keep_frames)
+
+         # chain_lig = utils.reverse_tensor(chain_lig)
+         # chain_pocket = utils.reverse_tensor(chain_pocket)
+         # chain_bond = utils.reverse_tensor(chain_bond)
+
+         info['traj_displacement_lig'] = torch.norm(out_tensors_ligand['x'][-1] - out_tensors_ligand['x'][0], dim=-1).mean()
+         info['traj_rms_lig'] = out_tensors_ligand['x'].std(dim=0).mean()
+
+         # # Repeat last frame to see final sample better.
+         # chain_lig = torch.cat([chain_lig, chain_lig[-1:].repeat(10, 1, 1)], dim=0)
+         # chain_pocket = torch.cat([chain_pocket, chain_pocket[-1:].repeat(10, 1, 1)], dim=0)
+         # chain_bond = torch.cat([chain_bond, chain_bond[-1:].repeat(10, 1, 1)], dim=0)
+
+         # Flatten
+         assert keep_frames == out_tensors_ligand['x'].size(0) == out_tensors_pocket['x'].size(0)
+         n_atoms = out_tensors_ligand['x'].size(1)
+         n_bonds = out_tensors_ligand['e'].size(1)
+         n_residues = out_tensors_pocket['x'].size(1)
+         device = out_tensors_ligand['x'].device
+
+         def flatten_tensor(chain):
+             if len(chain.size()) == 3:  # l=0 values
+                 return chain.view(-1, chain.size(-1))
+             elif len(chain.size()) == 4:  # vectors
+                 return chain.view(-1, chain.size(-2), chain.size(-1))
+             else:
+                 warnings.warn(f"Could not flatten frame dimension of tensor with shape {list(chain.size())}")
+                 return chain
+
+         out_tensors_ligand_flat = {k: flatten_tensor(chain) for k, chain in out_tensors_ligand.items()}
+         out_tensors_pocket_flat = {k: flatten_tensor(chain) for k, chain in out_tensors_pocket.items()}
+         # ligand_flat = chain_lig.view(-1, chain_lig.size(-1))
+         # ligand_mask_flat = torch.arange(chain_lig.size(0)).repeat_interleave(chain_lig.size(1)).to(chain_lig.device)
+         ligand_mask_flat = torch.arange(keep_frames).repeat_interleave(n_atoms).to(device)
+
+         # # pocket_flat = chain_pocket.view(-1, chain_pocket.size(-1))
+         # # pocket_v_flat = pocket['v'].repeat(100, 1, 1)
+         # pocket_flat = chain_pocket.view(-1, chain_pocket.size(-2), chain_pocket.size(-1))
+         # pocket_mask_flat = torch.arange(chain_pocket.size(0)).repeat_interleave(chain_pocket.size(1)).to(chain_pocket.device)
+         pocket_mask_flat = torch.arange(keep_frames).repeat_interleave(n_residues).to(device)
+
+         # bond_flat = chain_bond.view(-1, chain_bond.size(-1))
+         # bond_mask_flat = torch.arange(chain_bond.size(0)).repeat_interleave(chain_bond.size(1)).to(chain_bond.device)
+         bond_mask_flat = torch.arange(keep_frames).repeat_interleave(n_bonds).to(device)
+         edges_flat = ligand['bonds'].repeat(1, keep_frames)
+
+         # # Move generated molecule back to the original pocket position
+         # pocket_com_after = scatter_mean(pocket_flat[:, 0, :], pocket_mask_flat, dim=0)
+         # ligand_flat[:, :self.x_dim] += (pocket_com_before - pocket_com_after)[ligand_mask_flat]
+         #
+         # # Move pocket back as well (for visualization purposes)
+         # pocket_flat[:, 0, :] += (pocket_com_before - pocket_com_after)[pocket_mask_flat]
+
+         # Build ligands
+         x = out_tensors_ligand_flat['x'].detach().cpu()
+         ligand_type = out_tensors_ligand_flat['h'].argmax(1).detach().cpu()
+         ligand_mask_flat = ligand_mask_flat.detach().cpu()
+         bond_mask_flat = bond_mask_flat.detach().cpu()
+         edges_flat = edges_flat.detach().cpu()
+         edge_type = out_tensors_ligand_flat['e'].argmax(1).detach().cpu()
+         offsets = torch.zeros(keep_frames, dtype=int)  # edges_flat is already zero-based
+         molecules = list(
+             zip(utils.batch_to_list(x, ligand_mask_flat),
+                 utils.batch_to_list(ligand_type, ligand_mask_flat),
+                 utils.batch_to_list_for_indices(edges_flat, bond_mask_flat, offsets),
+                 utils.batch_to_list(edge_type, bond_mask_flat)
+                 )
+         )
+
+         # Convert into rdmols
+         ligand_chain = [build_molecule(
+             *graph, atom_decoder=self.atom_decoder,
+             bond_decoder=self.bond_decoder) for graph in molecules
+         ]
+
+         # Build pockets
+         # as long as the pocket does not change during sampling, we can just
+         # write it once
+         out_pocket = {
+             'x': out_tensors_pocket_flat['x'],
+             'one_hot': pocket['one_hot'].repeat(keep_frames, 1),
+             'mask': pocket_mask_flat,
+             'v': out_tensors_pocket_flat['v'],
+             'atom_mask': pocket['atom_mask'].repeat(keep_frames, 1),
+         } if self.flexible else pocket
+         pocket_chain = pocket_to_rdkit(out_pocket, self.pocket_representation,
+                                        self.atom_encoder, self.atom_decoder,
+                                        self.aa_decoder, self.residue_decoder,
+                                        self.aa_atom_index)
+
+         return ligand_chain, pocket_chain, info
+
+     # def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
+     # def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_clip_algorithm):
+     def configure_gradient_clipping(self, optimizer, *args, **kwargs):
+
+         if not self.clip_grad:
+             return
+
+         # Allow gradient norm to be 150% + 2 * stdev of the recent history.
+         max_grad_norm = 1.5 * self.gradnorm_queue.mean() + \
+                         2 * self.gradnorm_queue.std()
+
+         # hard upper limit
+         max_grad_norm = min(max_grad_norm, 10.0)
+
+         # Get current grad_norm
+         params = [p for g in optimizer.param_groups for p in g['params']]
+         grad_norm = utils.get_grad_norm(params)
+
+         # Lightning will handle the gradient clipping
+         self.clip_gradients(optimizer, gradient_clip_val=max_grad_norm,
+                             gradient_clip_algorithm='norm')
+
+         if float(grad_norm) > max_grad_norm:
+             print(f'Clipped gradient with value {grad_norm:.1f} '
+                   f'while allowed {max_grad_norm:.1f}')
+             grad_norm = max_grad_norm
+
+         self.gradnorm_queue.add(float(grad_norm))
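A minimal sketch of the adaptive rule in `configure_gradient_clipping`: the allowed norm tracks the recent history (mean plus spread) with a hard cap. The `Queue` below is only a stand-in for the model's `gradnorm_queue` (assumed interface: `mean()`, `std()`, `add()`), not the repository's implementation.

from collections import deque
import statistics

class Queue:
    def __init__(self, maxlen=50):
        self.items = deque(maxlen=maxlen)
    def add(self, x): self.items.append(x)
    def mean(self): return statistics.fmean(self.items)
    def std(self): return statistics.pstdev(self.items)

q = Queue()
for g in [1.0, 1.2, 0.9, 5.0]:                    # incoming gradient norms
    if len(q.items) >= 2:
        allowed = min(1.5 * q.mean() + 2 * q.std(), 10.0)
        g = min(g, allowed)                       # clip, then record the clipped value
    q.add(g)
print(round(q.mean(), 3))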
src/model/loss_utils.py ADDED
@@ -0,0 +1,79 @@
+ import torch
+ from torch_scatter import scatter_add, scatter_mean
+
+ from src.constants import atom_decoder, vdw_radii
+
+ _vdw_radii = {**vdw_radii}
+ _vdw_radii['NH'] = vdw_radii['N']
+ _vdw_radii['N+'] = vdw_radii['N']
+ _vdw_radii['O-'] = vdw_radii['O']
+ _vdw_radii['NOATOM'] = 0
+ vdw_radii_array = torch.tensor([_vdw_radii[a] for a in atom_decoder])
+
+
+ def clash_loss(ligand_coord, ligand_types, ligand_mask, pocket_coord,
+                pocket_types, pocket_mask):
+     """
+     Computes a clash loss that penalizes interatomic distances smaller than
+     the sum of the atoms' van der Waals radii.
+     """
+
+     ligand_radii = vdw_radii_array[ligand_types].to(ligand_coord.device)
+     pocket_radii = vdw_radii_array[pocket_types].to(pocket_coord.device)
+
+     dist = torch.sqrt(torch.sum((ligand_coord[:, None, :] - pocket_coord[None, :, :]) ** 2, dim=-1))
+     # dist[ligand_mask[:, None] != pocket_mask[None, :]] = float('inf')
+
+     # compute linearly decreasing penalty
+     # penalty = max(1 - d / sum_vdw, 0)
+     sum_vdw = ligand_radii[:, None] + pocket_radii[None, :]
+     loss = torch.clamp(1 - dist / sum_vdw, min=0.0)  # (n_ligand, n_pocket)
+
+     loss = scatter_add(loss, pocket_mask, dim=1)
+     loss = scatter_mean(loss, ligand_mask, dim=0)
+     loss = loss.diag()
+
+     # # DEBUG (non-differentiable version)
+     # dist = torch.sqrt(torch.sum((ligand_coord[:, None, :] - pocket_coord[None, :, :]) ** 2, dim=-1))
+     # dist[ligand_mask[:, None] != pocket_mask[None, :]] = float('inf')
+     # _loss = torch.clamp(1 - dist / sum_vdw, min=0.0)  # (n_ligand, n_pocket)
+     # _loss = _loss.sum(dim=-1)
+     # _loss = scatter_mean(_loss, ligand_mask, dim=0)
+     # assert torch.allclose(loss, _loss)
+
+     return loss
+
+
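A hedged usage sketch for `clash_loss` with toy coordinates: one ligand atom overlaps a pocket atom's van der Waals sphere, the other is far away, so only the first pair is penalized. It assumes 'C' is among the atom types in `atom_decoder`; the mask tensors put everything into a single graph.

import torch

c = list(atom_decoder).index('C')
ligand_coord = torch.tensor([[0.0, 0.0, 0.0], [9.0, 0.0, 0.0]])
ligand_types = torch.tensor([c, c])
ligand_mask = torch.tensor([0, 0])                # both atoms in graph 0
pocket_coord = torch.tensor([[0.5, 0.0, 0.0]])
pocket_types = torch.tensor([c])
pocket_mask = torch.tensor([0])

loss = clash_loss(ligand_coord, ligand_types, ligand_mask,
                  pocket_coord, pocket_types, pocket_mask)
print(loss)                                       # one value per graph, > 0 here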
+ class TimestepSampler:
+     def __init__(self, type='uniform', lowest_t=1, highest_t=500):
+         assert type in {'uniform', 'sigmoid'}
+         self.type = type
+         self.lowest_t = lowest_t
+         self.highest_t = highest_t
+
+     def __call__(self, n, device=None):
+         if self.type == 'uniform':
+             t_int = torch.randint(self.lowest_t, self.highest_t + 1,
+                                   size=(n, 1), device=device)
+
+         elif self.type == 'sigmoid':
+             weight_fun = lambda t: 1.45 * torch.sigmoid(-t * 10 / self.highest_t + 5) + 0.05
+
+             possible_ts = torch.arange(self.lowest_t, self.highest_t + 1, device=device)
+             weights = weight_fun(possible_ts)
+             weights = weights / weights.sum()
+             t_int = possible_ts[torch.multinomial(weights, n, replacement=True)].unsqueeze(-1)
+
+         return t_int.float()
+
+
+ class TimestepWeights:
+     def __init__(self, weight_type, a, b):
+         if weight_type != 'sigmoid':
+             raise NotImplementedError("Only sigmoidal loss weighting is available.")
+         # self.weight_fn = lambda t: a * torch.sigmoid((-t + 0.5) * b) + (1 - a / 2)
+         self.weight_fn = lambda t: a * torch.sigmoid((t - 0.5) * b) + (1 - a / 2)
+
+     def __call__(self, t_array):
+         # normalized t \in [0, 1]
+         # return self.weight_fn(1 - t_array)
+         return self.weight_fn(t_array)
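A numeric sanity check for `TimestepWeights`: with the sigmoidal weighting, w(t) increases monotonically in t and is centred so that w(0.5) = 1. The values of a and b below are arbitrary.

import torch

w = TimestepWeights('sigmoid', a=1.0, b=10.0)
t = torch.tensor([0.0, 0.5, 1.0])
print(w(t))  # approx. [0.507, 1.000, 1.493] for a=1, b=10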
src/model/markov_bridge.py ADDED
@@ -0,0 +1,163 @@
+ import torch
+ import torch.nn.functional as F
+ from torch_scatter import scatter_mean, scatter_add
+
+ from src.utils import bvm
+
+
+ class LinearSchedule:
+     r"""
+     We use the scheduling parameter \beta to linearly remove noise, i.e.
+     \bar{\beta}_t = 1 - t with
+         \bar{Q}_t = \bar{\beta}_t I + (1 - \bar{\beta}_t) 1_vec z1^T
+
+     From this, it follows that each one-step transition matrix satisfies
+     \beta_t = \bar{\beta}_t / \bar{\beta}_{t-h} = \frac{1-t}{1-t+h} (h: step size)
+     """
+     def __init__(self):
+         super().__init__()
+
+     def beta_bar(self, t):
+         return 1 - t
+
+     def beta(self, t, step_size):
+         return (1 - t) / (1 - t + step_size)
+
+
+ class UniformPriorMarkovBridge:
+     r"""
+     Markov bridge model in which z0 is drawn from a uniform prior.
+     Transitions are defined as:
+         Q_t = \beta_t I + (1 - \beta_t) 1_vec z1^T
+     where z1 is a one-hot representation of the final state.
+     We follow the notation from [1] and multiply transition matrices from the
+     right to one-hot state vectors.
+
+     We use the scheduling parameter \beta to linearly remove noise, i.e.
+     \bar{\beta}_t = 1 - t with
+         \bar{Q}_t = \bar{\beta}_t I + (1 - \bar{\beta}_t) 1_vec z1^T
+
+     From this, it follows that each one-step transition matrix satisfies
+     \beta_t = \bar{\beta}_t / \bar{\beta}_{t-h} = \frac{1-t}{1-t+h} (h: step size)
+
+     [1] Austin, Jacob, et al.
+     "Structured denoising diffusion models in discrete state-spaces."
+     Advances in Neural Information Processing Systems 34 (2021): 17981-17993.
+     """
+     def __init__(self, dim, loss_type='CE', step_size=None):
+         assert loss_type in ['VLB', 'CE']
+         self.dim = dim
+         self.step_size = step_size  # required for VLB
+         self.schedule = LinearSchedule()
+         self.loss_type = loss_type
+         super(UniformPriorMarkovBridge, self).__init__()
+
+     @staticmethod
+     def sample_categorical(p):
+         """
+         Sample from the categorical distribution defined by probabilities 'p'
+         :param p: (n, dim)
+         :return: one-hot encoded samples (n, dim)
+         """
+         sampled = torch.multinomial(p, 1).squeeze(-1)
+         return F.one_hot(sampled, num_classes=p.size(1)).float()
+
+     def p_z0(self, batch_mask):
+         return torch.ones((len(batch_mask), self.dim), device=batch_mask.device) / self.dim
+
+     def sample_z0(self, batch_mask):
+         """ Prior. """
+         z0 = self.sample_categorical(self.p_z0(batch_mask))
+         return z0
+
+     def p_zt(self, z0, z1, t, batch_mask):
+         Qt_bar = self.get_Qt_bar(t, z1, batch_mask)
+         return bvm(z0, Qt_bar)
+
+     def sample_zt(self, z0, z1, t, batch_mask):
+         zt = self.sample_categorical(self.p_zt(z0, z1, t, batch_mask))
+         return zt
+
+     def p_zt_given_zs_and_z1(self, zs, z1, s, t, batch_mask):
+         # 'z1' are one-hot "probabilities" for each class
+         Qt = self.get_Qt(t, s, z1, batch_mask)
+         p_zt = bvm(zs, Qt)
+         return p_zt
+
+     def p_zt_given_zs(self, zs, p_z1_hat, s, t, batch_mask):
+         r"""
+         Note that the final state z1 can also represent a categorical
+         distribution to compute transitions more efficiently at sampling time:
+         p(z_t|z_s) = \sum_{\hat{z}_1} p(z_t | z_s, \hat{z}_1) * p(\hat{z}_1 | z_s)
+                    = \sum_i z_s (\beta_t I + (1 - \beta_t) 1_vec z1_i^T) \hat{p}_i
+                    = \beta_t z_s I + (1 - \beta_t) z_s 1_vec \hat{p}^T
+         """
+         return self.p_zt_given_zs_and_z1(zs, p_z1_hat, s, t, batch_mask)
+
+     def sample_zt_given_zs(self, zs, z1_logits, s, t, batch_mask):
+         p_z1 = z1_logits.softmax(dim=-1)
+         zt = self.sample_categorical(self.p_zt_given_zs(zs, p_z1, s, t, batch_mask))
+         return zt
+
+     def compute_loss(self, pred_logits, zs, z1, batch_mask, s, t, reduce='mean'):
+         """ Compute loss per sample. """
+         assert reduce in {'mean', 'sum', 'none'}
+
+         if self.loss_type == 'CE':
+             loss = F.cross_entropy(pred_logits, z1, reduction='none')
+
+         else:  # VLB
+             true_p_zs = self.p_zt_given_zs_and_z1(zs, z1, s, t, batch_mask)
+             pred_p_zs = self.p_zt_given_zs(zs, pred_logits.softmax(dim=-1), s, t, batch_mask)
+             loss = F.kl_div(pred_p_zs.log(), true_p_zs, reduction='none').sum(dim=-1)
+
+         if reduce == 'mean':
+             loss = scatter_mean(loss, batch_mask, dim=0)
+         elif reduce == 'sum':
+             loss = scatter_add(loss, batch_mask, dim=0)
+
+         return loss
+
+     def get_Qt(self, t, s, z1, batch_mask):
+         """ Returns the one-step transition matrix from step s to step t. """
+
+         beta_t_given_s = self.schedule.beta(t, t - s)
+         beta_t_given_s = beta_t_given_s.unsqueeze(-1)[batch_mask]
+
+         # Q_t = beta_t * I + (1 - beta_t) * ones (dot) z1^T
+         Qt = beta_t_given_s * torch.eye(self.dim, device=t.device).unsqueeze(0) + \
+              (1 - beta_t_given_s) * z1.unsqueeze(1)
+         # (1 - beta_t_given_s) * (torch.ones(self.dim, 1, device=t.device) @ z1)
+
+         # assert (Qt.sum(-1) == 1).all()
+
+         return Qt
+
+     def get_Qt_bar(self, t, z1, batch_mask):
+         """ Returns the transition matrix from step 0 to step t. """
+
+         beta_bar_t = self.schedule.beta_bar(t)
+         beta_bar_t = beta_bar_t.unsqueeze(-1)[batch_mask]
+
+         # Q_t_bar = beta_bar * I + (1 - beta_bar) * ones (dot) z1^T
+         Qt_bar = beta_bar_t * torch.eye(self.dim, device=t.device).unsqueeze(0) + \
+                  (1 - beta_bar_t) * z1.unsqueeze(1)
+         # (1 - beta_bar_t) * (torch.ones(self.dim, 1, device=t.device) @ z1)
+
+         # assert (Qt_bar.sum(-1) == 1).all()
+
+         return Qt_bar
+
+
+ class MarginalPriorMarkovBridge(UniformPriorMarkovBridge):
+     def __init__(self, dim, prior_p, loss_type='CE', step_size=None):
+         self.prior_p = prior_p
+         print('Marginal Prior MB')
+         super(MarginalPriorMarkovBridge, self).__init__(dim, loss_type, step_size)
+
+     def p_z0(self, batch_mask):
+         device = batch_mask.device
+         p = torch.ones((len(batch_mask), self.dim), device=device) * self.prior_p.view(1, -1).to(device)
+         return p
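A small end-to-end sketch of the uniform-prior bridge on dummy one-hot data: z0 is uniform noise, zt interpolates towards the target z1, and every row of the cumulative transition matrix sums to 1 (the assert that is commented out above). `bvm`, used internally, is assumed to be the repository's batched vector-matrix product.

import torch
import torch.nn.functional as F

dim, n = 4, 6
mb = UniformPriorMarkovBridge(dim=dim)
batch_mask = torch.zeros(n, dtype=torch.long)     # a single graph
z1 = F.one_hot(torch.randint(dim, (n,)), dim).float()
t = torch.full((1, 1), 0.7)

z0 = mb.sample_z0(batch_mask)
Qt_bar = mb.get_Qt_bar(t, z1, batch_mask)
assert torch.allclose(Qt_bar.sum(-1), torch.ones(n, dim))
zt = mb.sample_zt(z0, z1, t, batch_mask)
print(zt.shape)                                   # torch.Size([6, 4])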
src/sample_and_evaluate.py ADDED
@@ -0,0 +1,164 @@
+ import argparse
+ import sys
+ import yaml
+ import torch
+ import numpy as np
+ import pickle
+ from argparse import Namespace
+
+ from pathlib import Path
+
+ basedir = Path(__file__).resolve().parent.parent
+ sys.path.append(str(basedir))
+
+ from src import utils
+ from src.utils import dict_to_namespace, namespace_to_dict
+ from src.analysis.visualization_utils import mols_to_pdbfile, mol_as_pdb
+ from src.data.data_utils import TensorDict, Residues
+ from src.data.postprocessing import process_all
+ from src.model.lightning import DrugFlow
+ from src.sbdd_metrics.evaluation import compute_all_metrics_drugflow
+
+ from tqdm import tqdm
+
+
+ def combine(base_args, override_args):
+     assert not isinstance(base_args, dict)
+     assert not isinstance(override_args, dict)
+
+     arg_dict = base_args.__dict__
+     for key, value in override_args.__dict__.items():
+         if key not in arg_dict or arg_dict[key] is None:  # parameter not provided previously
+             print(f"Add parameter {key}: {value}")
+             arg_dict[key] = value
+         elif isinstance(value, Namespace):
+             arg_dict[key] = combine(arg_dict[key], value)
+         else:
+             print(f"Replace parameter {key}: {arg_dict[key]} -> {value}")
+             arg_dict[key] = value
+     return base_args
+
+
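A hedged illustration of how `combine` merges a checkpoint's stored arguments with overrides from the sampling config: nested Namespaces are merged recursively, scalars are replaced, and unknown keys are added. All values below are made up.

from argparse import Namespace

base = Namespace(n_steps=100, sampling=Namespace(temperature=1.0))
override = Namespace(n_steps=500, sampling=Namespace(temperature=0.7), device='cuda')
merged = combine(base, override)
print(merged.n_steps, merged.sampling.temperature, merged.device)  # 500 0.7 cuda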
+ def path_to_str(input_dict):
+     for key, value in input_dict.items():
+         if isinstance(value, dict):
+             input_dict[key] = path_to_str(value)
+         else:
+             input_dict[key] = str(value) if isinstance(value, Path) else value
+     return input_dict
+
+
+ def sample(cfg, model_params, samples_dir, job_id=0, n_jobs=1):
+     print('Sampling...')
+     model = DrugFlow.load_from_checkpoint(cfg.checkpoint, map_location=cfg.device, strict=False,
+                                           **model_params)
+     model.setup(stage='fit' if cfg.set == 'train' else cfg.set)
+     model.eval().to(cfg.device)
+
+     dataloader = getattr(model, f'{cfg.set}_dataloader')()
+     print(f'Real batch size is {dataloader.batch_size * cfg.n_samples}')
+
+     name2count = {}
+     for i, data in enumerate(tqdm(dataloader)):
+         if i % n_jobs != job_id:
+             print(f'Skipping batch {i}')
+             continue
+
+         new_data = {
+             'ligand': TensorDict(**data['ligand']).to(cfg.device),
+             'pocket': Residues(**data['pocket']).to(cfg.device),
+         }
+         try:
+             rdmols, rdpockets, names = model.sample(
+                 data=new_data,
+                 n_samples=cfg.n_samples,
+                 num_nodes=("ground_truth" if cfg.sample_with_ground_truth_size else None)
+             )
+         except Exception as e:
+             if cfg.set == 'train':
+                 names = data['ligand']['name']
+                 print(f'Failed to sample for {names}: {e}')
+                 continue
+             else:
+                 raise e
+
+         for mol, pocket, name in zip(rdmols, rdpockets, names):
+             name = name.replace('.sdf', '')
+             idx = name2count.setdefault(name, 0)
+             output_dir = Path(samples_dir, name)
+             output_dir.mkdir(parents=True, exist_ok=True)
+             if cfg.postprocess:
+                 mol = process_all(mol, largest_frag=True, adjust_aromatic_Ns=True, relax_iter=0)
+                 if mol is None:  # post-processing can fail; skip the sample but keep numbering
+                     print(f'Post-processing failed for sample {idx} of {name}')
+                     name2count[name] += 1
+                     continue
+
+             for prop in mol.GetAtomWithIdx(0).GetPropsAsDict().keys():
+                 # compute avg uncertainty
+                 mol.SetDoubleProp(prop, np.mean([a.GetDoubleProp(prop) for a in mol.GetAtoms()]))
+
+                 # visualise local differences
+                 out_pdb_path = Path(output_dir, f'{idx}_ligand_{prop}.pdb')
+                 mol_as_pdb(mol, out_pdb_path, bfactor=prop)
+
+             out_sdf_path = Path(output_dir, f'{idx}_ligand.sdf')
+             out_pdb_path = Path(output_dir, f'{idx}_pocket.pdb')
+             utils.write_sdf_file(out_sdf_path, [mol])
+             mols_to_pdbfile([pocket], out_pdb_path)
+
+             name2count[name] += 1
+
+
+ def evaluate(cfg, model_params, samples_dir):
+     print('Evaluation...')
+     data, table_detailed, table_aggregated = compute_all_metrics_drugflow(
+         in_dir=samples_dir,
+         gnina_path=model_params['train_params'].gnina,
+         reduce_path=cfg.reduce,
+         reference_smiles_path=Path(model_params['train_params'].datadir, 'train_smiles.npy'),
+         n_samples=cfg.n_samples,
+         exclude_evaluators=[] if cfg.exclude_evaluators is None else cfg.exclude_evaluators,
+     )
+     with open(Path(samples_dir, 'metrics_data.pkl'), 'wb') as f:
+         pickle.dump(data, f)
+     table_detailed.to_csv(Path(samples_dir, 'metrics_detailed.csv'), index=False)
+     table_aggregated.to_csv(Path(samples_dir, 'metrics_aggregated.csv'), index=False)
+
+
+ if __name__ == "__main__":
+     p = argparse.ArgumentParser()
+     p.add_argument('--config', type=str)
+     p.add_argument('--job_id', type=int, default=0, help='Job ID')
+     p.add_argument('--n_jobs', type=int, default=1, help='Number of jobs')
+     args = p.parse_args()
+
+     with open(args.config, 'r') as f:
+         cfg = yaml.safe_load(f)
+     cfg = dict_to_namespace(cfg)
+
+     utils.set_deterministic(seed=cfg.seed)
+     utils.disable_rdkit_logging()
+
+     model_params = torch.load(cfg.checkpoint, map_location=cfg.device)['hyper_parameters']
+     if 'model_args' in cfg:
+         ckpt_args = dict_to_namespace(model_params)
+         model_params = combine(ckpt_args, cfg.model_args).__dict__
+
+     assert cfg.set in {'val', 'test', 'train'}
+     ckpt_path = Path(cfg.checkpoint)
+     ckpt_name = ckpt_path.parts[-1].split('.')[0]
+     n_steps = model_params['simulation_params'].n_steps
+     # `Path(...) or Path(...)` always picks the first path (a Path is truthy),
+     # so fall back to the checkpoint directory explicitly instead
+     if getattr(cfg, 'sample_outdir', None) is not None:
+         samples_dir = Path(cfg.sample_outdir, cfg.set, f'{ckpt_name}_T={n_steps}')
+     else:
+         samples_dir = Path(ckpt_path.parent.parent, 'samples', cfg.set, f'{ckpt_name}_T={n_steps}')
+     samples_dir.mkdir(parents=True, exist_ok=True)
+
+     # save configs
+     with open(Path(samples_dir, 'model_params.yaml'), 'w') as f:
+         yaml.dump(path_to_str(namespace_to_dict(model_params)), f)
+     with open(Path(samples_dir, 'sampling_params.yaml'), 'w') as f:
+         yaml.dump(path_to_str(namespace_to_dict(cfg)), f)
+
+     if cfg.sample:
+         sample(cfg, model_params, samples_dir, job_id=args.job_id, n_jobs=args.n_jobs)
+
+     if cfg.evaluate:
+         assert args.job_id == 0 and args.n_jobs == 1, 'Evaluation is not parallelised on GPU machines'
+         evaluate(cfg, model_params, samples_dir)
src/sbdd_metrics/evaluation.py ADDED
@@ -0,0 +1,239 @@
+ import os
+ import sys
+ import re
+
+ from pathlib import Path
+ from typing import Collection, List, Dict, Type
+
+ import numpy as np
+ import pandas as pd
+ from tqdm import tqdm
+
+ from .metrics import FullEvaluator, FullCollectionEvaluator
+
+ AUXILIARY_COLUMNS = ['sample', 'sdf_file', 'pdb_file', 'subdir']
+ VALIDITY_METRIC_NAME = 'medchem.valid'
+
+
+ def get_data_type(key: str, data_types: Dict[str, Type], default=float) -> Type:
+     found_data_type_key = None
+     found_data_type_value = None
+     for data_type_key, data_type_value in data_types.items():
+         if re.match(data_type_key, key) is not None:
+             if found_data_type_key is not None:
+                 raise ValueError(f'Multiple data type keys match [{key}]: {found_data_type_key}, {data_type_key}')
+
+             found_data_type_value = data_type_value
+             found_data_type_key = data_type_key
+
+     if found_data_type_key is None:
+         if default is None:
+             raise KeyError(key)
+         else:
+             found_data_type_value = default
+
+     return found_data_type_value
+
+
+ def convert_data_to_table(data: List[Dict], data_types: Dict[str, Type]) -> pd.DataFrame:
+     """
+     Converts data from `evaluate_drugflow` to a detailed table
+     """
+     table = []
+     for entry in data:
+         table_entry = {}
+         for key, value in entry.items():
+             if key in AUXILIARY_COLUMNS:
+                 table_entry[key] = value
+                 continue
+             if get_data_type(key, data_types) != list:
+                 table_entry[key] = value
+         table.append(table_entry)
+
+     return pd.DataFrame(table)
+
+ def aggregated_metrics(table: pd.DataFrame, data_types: Dict[str, Type], validity_metric_name: str = None):
+     """
+     Args:
+         table (pd.DataFrame): table with metrics computed for each sample
+         data_types (Dict[str, Type]): dictionary with data types for each column
+         validity_metric_name (str): name of the column that has the validity metric
+
+     Returns:
+         agg_table (pd.DataFrame): table with columns ['metric', 'value', 'std']
+     """
+     aggregated_results = []
+
+     # If a validity column name is provided:
+     # 1. compute validity on the entire data
+     # 2. drop all invalid molecules to compute the rest
+     if validity_metric_name is not None:
+         aggregated_results.append({
+             'metric': validity_metric_name,
+             'value': table[validity_metric_name].fillna(False).astype(float).mean(),
+             'std': None,
+         })
+         table = table[table[validity_metric_name]]
+
+     # Compute aggregated metrics + standard deviations where applicable
+     for column in table.columns:
+         if column in AUXILIARY_COLUMNS + [validity_metric_name] or get_data_type(column, data_types) == str:
+             continue
+         with pd.option_context("future.no_silent_downcasting", True):
+             if get_data_type(column, data_types) == bool:
+                 values = table[column].fillna(0).values.astype(float).mean()
+                 std = None
+             else:
+                 values = table[column].dropna().values.astype(float).mean()
+                 std = table[column].dropna().values.astype(float).std()
+
+         aggregated_results.append({
+             'metric': column,
+             'value': values,
+             'std': std,
+         })
+
+     agg_table = pd.DataFrame(aggregated_results)
+     return agg_table
+
+
+ def collection_metrics(
101
+ table: pd.DataFrame,
102
+ reference_smiles: Collection[str],
103
+ validity_metric_name: str = None,
104
+ exclude_evaluators: Collection[str] = [],
105
+ ):
106
+ """
107
+ Args:
108
+ table (pd.DataFrame): table with metrics computed for each sample
109
+ reference_smiles (Collection[str]): list of reference SMILES (e.g. training set)
110
+ validity_metric_name (str): name of the column that has validity metric
111
+ exclude_evaluators (Collection[str]): Evaluator IDs to exclude
112
+
113
+ Returns:
114
+ col_table (pd.DataFrame): table with columns ['metric', 'value']
115
+ """
116
+
117
+ # If validity column name is provided drop all invalid molecules
118
+ if validity_metric_name is not None:
119
+ table = table[table[validity_metric_name].fillna(False).astype(bool)]
120
+
121
+ evaluator = FullCollectionEvaluator(reference_smiles, exclude_evaluators=exclude_evaluators)
122
+ smiles = table['representation.smiles'].values
123
+ if len(smiles) == 0:
124
+ print('No valid input molecules')
125
+ return pd.DataFrame(columns=['metric', 'value'])
126
+
127
+ collection_results = evaluator(smiles)
128
+ results = [
129
+ {'metric': key, 'value': value}
130
+ for key, value in collection_results.items()
131
+ ]
132
+
133
+ col_table = pd.DataFrame(results)
134
+ return col_table
135
+
136
+
137
+ def evaluate_drugflow_subdir(
138
+ in_dir: Path,
139
+ evaluator: FullEvaluator,
140
+ desc: str = None,
141
+ n_samples: int = None,
142
+ ) -> List[Dict]:
143
+ """
144
+ Computes per-molecule metrics for a single directory of samples for one target
145
+ """
146
+ results = []
147
+ valid_files = [
148
+ int(fname.split('_')[0])
149
+ for fname in os.listdir(in_dir)
150
+ if fname.endswith('_ligand.sdf') and not fname.startswith('.')
151
+ ]
152
+ if len(valid_files) == 0:
153
+ return []
154
+
155
+ upper_bound = max(valid_files) + 1
156
+ if n_samples is not None:
157
+ upper_bound = min(upper_bound, n_samples)
158
+
159
+ for i in tqdm(range(upper_bound), desc=desc, file=sys.stdout):
160
+ in_mol = Path(in_dir, f'{i}_ligand.sdf')
161
+ in_prot = Path(in_dir, f'{i}_pocket.pdb')
162
+ res = evaluator(in_mol, in_prot)
163
+
164
+ res['sample'] = i
165
+ res['sdf_file'] = str(in_mol)
166
+ res['pdb_file'] = str(in_prot)
167
+ results.append(res)
168
+
169
+ return results
170
+
171
+
172
+ def evaluate_drugflow(
173
+ in_dir: Path,
174
+ evaluator: FullEvaluator,
175
+ n_samples: int = None,
176
+ job_id: int = 0,
177
+ n_jobs: int = 1,
178
+ ) -> List[Dict]:
179
+ """
180
+ Computes per-molecule metrics for every per-target subdirectory of `in_dir`.
181
+ Subdirectories are sharded round-robin across parallel workers: this worker
182
+ only processes the i-th subdirectory if (i - 1) % n_jobs == job_id.
183
+ """
184
+ data = []
185
+ total_number_of_subdirs = len([path for path in in_dir.glob("[!.]*") if os.path.isdir(path)])
186
+ i = 0
187
+ for subdir in in_dir.glob("[!.]*"):
188
+ if not os.path.isdir(subdir):
189
+ continue
190
+
191
+ i += 1
192
+ if (i - 1) % n_jobs != job_id:
193
+ continue
194
+
195
+ curr_data = evaluate_drugflow_subdir(
196
+ in_dir=subdir,
197
+ evaluator=evaluator,
198
+ desc=f'[{i}/{total_number_of_subdirs}] {str(subdir.name)}',
199
+ n_samples=n_samples,
200
+ )
201
+ for entry in curr_data:
202
+ entry['subdir'] = str(subdir)
203
+ data.append(entry)
204
+
205
+ return data
206
+
207
+
208
+ def compute_all_metrics_drugflow(
209
+ in_dir: Path,
210
+ gnina_path: Path,
211
+ reduce_path: Path = None,
212
+ reference_smiles_path: Path = None,
213
+ n_samples: int = None,
214
+ validity_metric_name: str = VALIDITY_METRIC_NAME,
215
+ exclude_evaluators: Collection[str] = [],
216
+ job_id: int = 0,
217
+ n_jobs: int = 1,
218
+ ):
219
+ evaluator = FullEvaluator(gnina=gnina_path, reduce=reduce_path, exclude_evaluators=exclude_evaluators)
220
+ data = evaluate_drugflow(in_dir=in_dir, evaluator=evaluator, n_samples=n_samples, job_id=job_id, n_jobs=n_jobs)
221
+ table_detailed = convert_data_to_table(data, evaluator.dtypes)
222
+ table_aggregated = aggregated_metrics(
223
+ table_detailed,
224
+ data_types=evaluator.dtypes,
225
+ validity_metric_name=validity_metric_name
226
+ )
227
+
228
+ # Add collection metrics (uniqueness, novelty, FCD, etc.) if reference smiles are provided
229
+ if reference_smiles_path is not None:
230
+ reference_smiles = np.load(reference_smiles_path)
231
+ col_metrics = collection_metrics(
232
+ table=table_detailed,
233
+ reference_smiles=reference_smiles,
234
+ validity_metric_name=validity_metric_name,
235
+ exclude_evaluators=exclude_evaluators
236
+ )
237
+ table_aggregated = pd.concat([table_aggregated, col_metrics])
238
+
239
+ return data, table_detailed, table_aggregated
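
A minimal usage sketch of the pipeline above, assuming the module is imported as part of the package; all paths are placeholders, and `gnina_path`/`reduce_path` may be None, in which case the corresponding evaluators are skipped:

    from pathlib import Path
    from src.sbdd_metrics.evaluation import compute_all_metrics_drugflow

    data, table_detailed, table_aggregated = compute_all_metrics_drugflow(
        in_dir=Path('samples'),                          # one subdirectory per target
        gnina_path=Path('./gnina'),                      # docking scores (or None)
        reduce_path=Path('./reduce'),                    # interaction profiling (or None)
        reference_smiles_path=Path('train_smiles.npy'),  # enables uniqueness/novelty/FCD
        n_samples=100,
    )
    table_detailed.to_csv('metrics_detailed.csv', index=False)
    table_aggregated.to_csv('metrics_aggregated.csv', index=False)
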
src/sbdd_metrics/fpscores.pkl.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:10dcef9340c873e7b987924461b0af5365eb8dd96be607203debe8ddf80c1e73
3
+ size 3848394
src/sbdd_metrics/interactions.py ADDED
@@ -0,0 +1,231 @@
1
+ import prody
2
+ import prolif as plf
3
+ import pandas as pd
4
+ import subprocess
5
+
6
+ from io import StringIO
7
+ from prolif.fingerprint import Fingerprint
8
+ from prolif.plotting.complex3d import Complex3D
9
+ from prolif.residue import ResidueId
10
+ from prolif.ifp import IFP
11
+ from rdkit import Chem
12
+ from tqdm import tqdm
13
+
14
+
15
+ prody.confProDy(verbosity='none')
16
+
17
+
18
+ INTERACTION_LIST = [
19
+ 'Anionic', 'Cationic', # Salt Bridges ~400 kJ/mol
20
+ 'HBAcceptor', 'HBDonor', # Hydrogen bonds ~10 kJ/mol
21
+ 'XBAcceptor', 'XBDonor', # Halogen bonds ~5-30 kJ/mol
22
+ 'CationPi', 'PiCation', # 5-10 kJ/mol
23
+ 'PiStacking', # ~2-10 kJ/mol
24
+ 'Hydrophobic', # 1-10 kJ/mol
25
+ ]
26
+
27
+ INTERACTION_ALIASES = {
28
+ 'Anionic': 'SaltBridge',
29
+ 'Cationic': 'SaltBridge',
30
+ 'HBAcceptor': 'HBAcceptor',
31
+ 'HBDonor': 'HBDonor',
32
+ 'XBAcceptor': 'HalogenBond',
33
+ 'XBDonor': 'HalogenBond',
34
+ 'CationPi': 'CationPi',
35
+ 'PiCation': 'PiCation',
36
+ 'PiStacking': 'PiStacking',
37
+ 'Hydrophobic': 'Hydrophobic',
38
+ }
39
+
40
+ INTERACTION_COLORS = {
41
+ 'SaltBridge': '#eba823',
42
+ 'HBDonor': '#3d5dfc',
43
+ 'HBAcceptor': '#3d5dfc',
44
+ 'HalogenBond': '#53f514',
45
+ 'CationPi': '#ff0000',
46
+ 'PiCation': '#ff0000',
47
+ 'PiStacking': '#e359d8',
48
+ 'Hydrophobic': '#c9c5c5',
49
+ }
50
+
51
+ INTERACTION_IMPORTANCE = ['SaltBridge', 'HalogenBond', 'HBAcceptor', 'HBDonor', 'CationPi', 'PiCation', 'PiStacking', 'Hydrophobic']
52
+
53
+ REDUCE_EXEC = './reduce'
54
+
55
+ def remove_residue_by_atomic_number(structure, resnum, chain_id, icode):
56
+ exclude_selection = f'not (chain {chain_id} and resnum {resnum} and icode {icode})'
57
+ structure = structure.select(exclude_selection)
58
+ return structure
59
+
60
+
61
+ def read_protein(protein_path, verbose=False, reduce_exec=REDUCE_EXEC):
62
+ structure = prody.parsePDB(protein_path).select('protein')
63
+ hydrogens = structure.select('hydrogen')
64
+ if hydrogens is None or len(hydrogens) < len(set(structure.getResnums())):
65
+ if verbose:
66
+ print('Target structure is not protonated. Adding hydrogens...')
67
+
68
+ reduce_cmd = f'{str(reduce_exec)} {protein_path}'
69
+ reduce_result = subprocess.run(reduce_cmd, shell=True, capture_output=True, text=True)
70
+ if reduce_result.returncode != 0:
71
+ raise RuntimeError('Error during reduce execution:', reduce_result.stderr)
72
+
73
+ pdb_content = reduce_result.stdout
74
+ stream = StringIO()
75
+ stream.write(pdb_content)
76
+ stream.seek(0)
77
+ structure = prody.parsePDBStream(stream).select('protein')
78
+
79
+ # Select only one (largest) altloc
80
+ altlocs = set(structure.getAltlocs())
81
+ try:
82
+ best_altloc = max(altlocs, key=lambda a: structure.select(f'altloc "{a}"').numAtoms())
83
+ structure = structure.select(f'altloc "{best_altloc}"')
84
+ except TypeError:
85
+ # ProDy occasionally raises a spurious TypeError on the first selection attempt; retrying the identical selection succeeds
86
+ best_altloc = max(altlocs, key=lambda a: structure.select(f'altloc "{a}"').numAtoms())
87
+ structure = structure.select(f'altloc "{best_altloc}"')
88
+
89
+ return prepare_protein(structure, to_exclude=[], verbose=verbose)
90
+
91
+
92
+ def prepare_protein(input_structure, to_exclude=[], verbose=False):
93
+ structure = input_structure.copy()
94
+
95
+ # Remove residues with bad atoms
96
+ if verbose and len(to_exclude) > 0:
97
+ print(f'Removing {len(to_exclude)} residues...')
98
+ for resnum, chain_id, icode in to_exclude:
99
+ exclude_selection = f'not (chain {chain_id} and resnum {resnum})'
100
+ structure = structure.select(exclude_selection)
101
+
102
+ # Write new PDB content to the stream
103
+ stream = StringIO()
104
+ prody.writePDBStream(stream, structure)
105
+ stream.seek(0)
106
+
107
+ # Sanitize
108
+ rdprot = Chem.MolFromPDBBlock(stream.read(), sanitize=False, removeHs=False)
109
+ try:
110
+ Chem.SanitizeMol(rdprot)
111
+ plfprot = plf.Molecule(rdprot)
112
+ return plfprot
113
+
114
+ except Chem.AtomValenceException as e:
115
+ atom_num = int(e.args[0].replace('Explicit valence for atom # ', '').split()[0])
116
+ info = rdprot.GetAtomWithIdx(atom_num).GetPDBResidueInfo()
117
+ resnum = info.GetResidueNumber()
118
+ chain_id = info.GetChainId()
119
+ icode = f'"{info.GetInsertionCode()}"'
120
+
121
+ to_exclude_next = to_exclude + [(resnum, chain_id, icode)]
122
+ if verbose:
123
+ print(f'[{len(to_exclude_next)}] Removing broken residue with atom={atom_num}, resnum={resnum}, chain_id={chain_id}, icode={icode}')
124
+ return prepare_protein(input_structure, to_exclude=to_exclude_next, verbose=verbose)
125
+
126
+
127
+ def prepare_ligand(mol):
128
+ Chem.SanitizeMol(mol)
129
+ mol = Chem.AddHs(mol, addCoords=True)
130
+ ligand_plf = plf.Molecule.from_rdkit(mol)
131
+ return ligand_plf
132
+
133
+
134
+ def sdf_reader(sdf_path, progress_bar=False):
135
+ supp = Chem.SDMolSupplier(sdf_path, removeHs=True, sanitize=False)
136
+ for mol in tqdm(supp) if progress_bar else supp:
137
+ yield prepare_ligand(mol)
138
+
139
+
140
+ def profile_detailed(
141
+ ligand_plf, protein_plf, interaction_list=INTERACTION_LIST, ligand_name='ligand', protein_name='protein'
142
+ ):
143
+
144
+ fp = Fingerprint(interactions=interaction_list)
145
+ fp.run_from_iterable(lig_iterable=[ligand_plf], prot_mol=protein_plf, progress=False)
146
+
147
+ profile = []
148
+
149
+ for ligand_residue in ligand_plf.residues:
150
+ for protein_residue in protein_plf.residues:
151
+ metadata = fp.metadata(ligand_plf[ligand_residue], protein_plf[protein_residue])
152
+ for int_name, int_metadata in metadata.items():
153
+ for int_instance in int_metadata:
154
+ profile.append({
155
+ 'ligand': ligand_name,
156
+ 'protein': protein_name,
157
+ 'ligand_residue': str(ligand_residue),
158
+ 'protein_residue': str(protein_residue),
159
+ 'interaction': int_name,
160
+ 'alias': INTERACTION_ALIASES[int_name],
161
+ 'ligand_atoms': ','.join(map(str, int_instance['indices']['ligand'])),
162
+ 'protein_atoms': ','.join(map(str, int_instance['indices']['protein'])),
163
+ 'ligand_orig_atoms': ','.join(map(str, int_instance['parent_indices']['ligand'])),
164
+ 'protein_orig_atoms': ','.join(map(str, int_instance['parent_indices']['protein'])),
165
+ 'distance': int_instance['distance'],
166
+ 'plane_angle': int_instance.get('plane_angle', None),
167
+ 'normal_to_centroid_angle': int_instance.get('normal_to_centroid_angle', None),
168
+ 'intersect_distance': int_instance.get('intersect_distance', None),
169
+ 'intersect_radius': int_instance.get('intersect_radius', None),
170
+ 'pi_ring': int_instance.get('pi_ring', None),
171
+ })
172
+
173
+ return pd.DataFrame(profile)
174
+
175
+
176
+ def map_orig_atoms_to_new(atoms, mol):
177
+ orig2new = dict()
178
+ for atom in mol.GetAtoms():
179
+ orig2new[atom.GetUnsignedProp("mapindex")] = atom.GetIdx()
180
+
181
+ atoms = list(map(int, atoms.split(',')))
182
+ new_atoms = ','.join(map(str, [orig2new[atom] for atom in atoms]))
183
+ return new_atoms
184
+
185
+
186
+ def visualize(profile, ligand_plf, protein_plf):
187
+ metadata = dict()
188
+
189
+ for _, row in profile.iterrows():
190
+ if 'ligand_atoms' not in row:
191
+ row['ligand_atoms'] = map_orig_atoms_to_new(row['ligand_orig_atoms'], ligand_plf)
192
+ if 'protein_atoms' not in row:
193
+ row['protein_atoms'] = map_orig_atoms_to_new(row['protein_orig_atoms'], protein_plf[row['residue']])
194
+
195
+ namenum, chain = row['residue'].split('.')
196
+ name = namenum[:3]
197
+ num = int(namenum[3:])
198
+ protres = ResidueId(name=name, number=num, chain=chain)
199
+ key = (ResidueId(name='UNL', number=1, chain=None), protres)
200
+
201
+ metadata.setdefault(key, dict())
202
+ interaction = {
203
+ 'indices': {
204
+ 'ligand': tuple(map(int, row['ligand_atoms'].split(','))),
205
+ 'protein': tuple(map(int, row['protein_atoms'].split(','))),
206
+ },
207
+ 'parent_indices': {
208
+ 'ligand': tuple(map(int, row['ligand_atoms'].split(','))),
209
+ 'protein': tuple(map(int, row['protein_atoms'].split(','))),
210
+ },
211
+ 'distance': row['distance'],
212
+ }
213
+ # if row['plane_angle'] is not None:
214
+ # interaction['plane_angle'] = row['plane_angle']
215
+ # if row['normal_to_centroid_angle'] is not None:
216
+ # interaction['normal_to_centroid_angle'] = row['normal_to_centroid_angle']
217
+ # if row['intersect_distance'] is not None:
218
+ # interaction['intersect_distance'] = row['intersect_distance']
219
+ # if row['intersect_radius'] is not None:
220
+ # interaction['intersect_radius'] = row['intersect_radius']
221
+ # if row['pi_ring'] is not None:
222
+ # interaction['pi_ring'] = row['pi_ring']
223
+
224
+ metadata[key].setdefault(row['alias'], list()).append(interaction)
225
+
226
+ ifp = IFP(metadata)
227
+ fp = Fingerprint(interactions=INTERACTION_LIST, vicinity_cutoff=8.0)
228
+ fp.ifp = {0: ifp}
229
+ Complex3D.COLORS.update(INTERACTION_COLORS)
230
+ v = fp.plot_3d(ligand_mol=ligand_plf, protein_mol=protein_plf, frame=0)
231
+ return v
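
A minimal interaction-profiling sketch using the functions above together with the example files shipped in the repository; the `reduce` executable path is an assumption and is only needed when the PDB file lacks hydrogens:

    from rdkit import Chem
    from src.sbdd_metrics.interactions import read_protein, prepare_ligand, profile_detailed

    protein_plf = read_protein('examples/kras.pdb', reduce_exec='./reduce')
    mol = Chem.SDMolSupplier('examples/kras_ref_ligand.sdf', removeHs=True, sanitize=False)[0]
    ligand_plf = prepare_ligand(mol)

    profile = profile_detailed(ligand_plf, protein_plf)
    if not profile.empty:
        print(profile['alias'].value_counts())  # contact counts per interaction type
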
src/sbdd_metrics/metrics.py ADDED
@@ -0,0 +1,929 @@
1
+ import multiprocessing
2
+ import subprocess
3
+ import tempfile
4
+ from abc import abstractmethod
5
+ from collections import defaultdict
6
+ from pathlib import Path
7
+ from typing import Union, Dict, Collection, Set, Optional
8
+ import signal
9
+ import numpy as np
10
+ import pandas as pd
11
+ from unittest.mock import patch
12
+ from scipy.spatial.distance import jensenshannon
13
+ from fcd import get_fcd
14
+ from posebusters import PoseBusters
15
+ from posebusters.modules.distance_geometry import _get_bond_atom_indices, _get_angle_atom_indices
16
+ from rdkit import Chem, RDLogger
17
+ from rdkit.Chem import Descriptors, Crippen, Lipinski, QED, KekulizeException, AtomKekulizeException
18
+ from rdkit.Chem.rdForceFieldHelpers import UFFGetMoleculeForceField
20
+ from tqdm import tqdm
21
+ from useful_rdkit_utils import REOS, RingSystemLookup, get_min_ring_frequency, RingSystemFinder
22
+
23
+ from .interactions import INTERACTION_LIST, prepare_ligand, read_protein, profile_detailed
24
+ from .sascorer import calculateScore
25
+
26
+ def timeout_handler(signum, frame):
27
+ raise TimeoutError('Timeout')
28
+
29
+ BOND_SYMBOLS = {
30
+ Chem.rdchem.BondType.SINGLE: '-',
31
+ Chem.rdchem.BondType.DOUBLE: '=',
32
+ Chem.rdchem.BondType.TRIPLE: '#',
33
+ Chem.rdchem.BondType.AROMATIC: ':',
34
+ }
35
+
36
+
37
+ def is_nan(value):
38
+ return value is None or pd.isna(value)
39
+
40
+
41
+ def safe_run(func, timeout, **kwargs):
42
+ def _run(f, q, **kwargs):
43
+ r = f(**kwargs)
44
+ q.put(r)
45
+
46
+ queue = multiprocessing.Queue()
47
+ process = multiprocessing.Process(target=_run, kwargs={'f': func, 'q': queue, **kwargs})
48
+ process.start()
49
+ process.join(timeout)
50
+ if process.is_alive():
51
+ print(f"Function {func} didn't finish in {timeout} seconds. Terminating it.")
52
+ process.terminate()
53
+ process.join()
54
+ return None
55
+ elif not queue.empty():
56
+ return queue.get()
57
+ return None
58
+
59
+
60
+ class AbstractEvaluator:
61
+ ID = None
62
+ def __call__(self, molecule: Union[str, Path, Chem.Mol], protein: Union[str, Path] = None,
63
+ timeout=350):
64
+ """
65
+ Args:
66
+ molecule (Union[str, Path, Chem.Mol]): input molecule
67
+ protein (str): target protein
68
+
69
+ Returns:
70
+ metrics (dict): dictionary of metrics
71
+ """
72
+ RDLogger.DisableLog('rdApp.*')
73
+ self.check_format(molecule, protein)
74
+
75
+ # timeout handler
76
+ signal.signal(signal.SIGALRM, timeout_handler)
77
+ try:
78
+ signal.alarm(timeout)
79
+ results = self.evaluate(molecule, protein)
80
+ except TimeoutError:
81
+ print(f'Error when evaluating [{self.ID}]: Timeout after {timeout} seconds')
82
+ signal.alarm(0)
83
+ return {}
84
+ except Exception as e:
85
+ print(f'Error when evaluating [{self.ID}]: {e}')
86
+ signal.alarm(0)
87
+ return {}
88
+ finally:
89
+ signal.alarm(0)
90
+ return self.add_id(results)
91
+
92
+ def add_id(self, results):
93
+ if self.ID is not None:
94
+ return {f'{self.ID}.{key}': value for key, value in results.items()}
95
+ else:
96
+ return results
97
+
98
+ @abstractmethod
99
+ def evaluate(self, molecule: Union[str, Path, Chem.Mol], protein: Union[str, Path]) -> Dict[str, Union[int, float, str]]:
100
+ raise NotImplementedError
101
+
102
+ @staticmethod
103
+ def check_format(molecule, protein):
104
+ assert isinstance(molecule, (str, Path, Chem.Mol)), 'Supported molecule types: str, Path, Chem.Mol'
105
+ assert protein is None or isinstance(protein, (str, Path)), 'Supported protein types: str, Path'
106
+ if isinstance(molecule, (str, Path)):
107
+ supp = Chem.SDMolSupplier(str(molecule), sanitize=False)
108
+ assert len(supp) == 1, 'Only one molecule per file is supported'
109
+
110
+ @staticmethod
111
+ def load_molecule(molecule):
112
+ if isinstance(molecule, (str, Path)):
113
+ return Chem.SDMolSupplier(str(molecule), sanitize=False)[0]
114
+ return Chem.Mol(molecule) # create copy to avoid overriding properties of the input molecule
115
+
116
+ @staticmethod
117
+ def save_molecule(molecule, sdf_path):
118
+ if isinstance(molecule, (str, Path)):
119
+ return molecule
120
+
121
+ with Chem.SDWriter(str(sdf_path)) as w:
122
+ try:
123
+ w.write(molecule)
124
+ except (RuntimeError, ValueError) as e:
125
+ if isinstance(e, (KekulizeException, AtomKekulizeException)):
126
+ w.SetKekulize(False)
127
+ w.write(molecule)
128
+ w.SetKekulize(True)
129
+ else:
130
+ w.write(Chem.Mol())
131
+ print('[AbstractEvaluator] Error when saving the molecule')
132
+
133
+ return sdf_path
134
+
135
+ @property
136
+ def dtypes(self):
137
+ return self.add_id(self._dtypes)
138
+
139
+ @property
140
+ @abstractmethod
141
+ def _dtypes(self):
142
+ raise NotImplementedError
143
+
144
+
145
+ class RepresentationEvaluator(AbstractEvaluator):
146
+ ID = 'representation'
147
+
148
+ def evaluate(self, molecule, protein=None):
149
+ molecule = self.load_molecule(molecule)
150
+ try:
151
+ smiles = Chem.MolToSmiles(molecule)
152
+ except:
153
+ smiles = None
154
+
155
+ return {'smiles': smiles}
156
+
157
+ @property
158
+ def _dtypes(self):
159
+ return {'smiles': str}
160
+
161
+
162
+ class MolPropertyEvaluator(AbstractEvaluator):
163
+ ID = 'mol_props'
164
+
165
+ def evaluate(self, molecule, protein=None):
166
+ molecule = self.load_molecule(molecule)
167
+ return {k: v for k, v in molecule.GetPropsAsDict().items() if isinstance(v, float)}
168
+
169
+ @property
170
+ def _dtypes(self):
171
+ return {'*': float}
172
+
173
+
174
+ class PoseBustersEvaluator(AbstractEvaluator):
175
+ ID = 'posebusters'
176
+ def __init__(self, pb_conf: str = 'dock'):
177
+ self.posebusters = PoseBusters(config=pb_conf)
178
+
179
+ @patch('rdkit.RDLogger.EnableLog', lambda x: None)
180
+ @patch('rdkit.RDLogger.DisableLog', lambda x: None)
181
+ def evaluate(self, molecule, protein=None):
182
+ result = safe_run(self.posebusters.bust, timeout=20, mol_pred=molecule, mol_cond=protein)
183
+ if result is None:
184
+ return dict()
185
+
186
+ with pd.option_context("future.no_silent_downcasting", True):
187
+ result = dict(result.fillna(False).iloc[0])
188
+ result['all'] = all([bool(value) if not is_nan(value) else False for value in result.values()])
189
+ return result
190
+
191
+ @property
192
+ def _dtypes(self):
193
+ return {'*': bool}
194
+
195
+
196
+ class GeometryEvaluator(AbstractEvaluator):
197
+ ID = 'geometry'
198
+
199
+ def evaluate(self, molecule, protein=None):
200
+ mol = self.load_molecule(molecule)
201
+ data = self.get_distances_and_angles(mol)
202
+ return data
203
+
204
+ @staticmethod
205
+ def angle_repr(mol, triplet):
206
+ i = mol.GetAtomWithIdx(triplet[0]).GetSymbol()
207
+ j = mol.GetAtomWithIdx(triplet[1]).GetSymbol()
208
+ k = mol.GetAtomWithIdx(triplet[2]).GetSymbol()
209
+ ij = BOND_SYMBOLS[mol.GetBondBetweenAtoms(triplet[0], triplet[1]).GetBondType()]
210
+ jk = BOND_SYMBOLS[mol.GetBondBetweenAtoms(triplet[1], triplet[2]).GetBondType()]
211
+
212
+ # Unified (sorted) representation
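+ # e.g. both O-C=C and C=C-O map to 'C=C-O': endpoints are sorted first, then bond symbols break ties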
213
+ if i < k:
214
+ return f'{i}{ij}{j}{jk}{k}'
215
+ elif i > k:
216
+ return f'{k}{jk}{j}{ij}{i}'
217
+ elif ij <= jk:
218
+ return f'{i}{ij}{j}{jk}{k}'
219
+ else:
220
+ return f'{k}{jk}{j}{ij}{i}'
221
+
222
+ @staticmethod
223
+ def bond_repr(mol, pair):
224
+ i = mol.GetAtomWithIdx(pair[0]).GetSymbol()
225
+ j = mol.GetAtomWithIdx(pair[1]).GetSymbol()
226
+ ij = BOND_SYMBOLS[mol.GetBondBetweenAtoms(pair[0], pair[1]).GetBondType()]
227
+ # Unified (sorted) representation
228
+ return f'{i}{ij}{j}' if i <= j else f'{j}{ij}{i}'
229
+
230
+ @staticmethod
231
+ def get_bond_distances(mol, bonds):
232
+ i, j = np.array(bonds).T
233
+ x = mol.GetConformer().GetPositions()
234
+ xi = x[i]
235
+ xj = x[j]
236
+ bond_distances = np.linalg.norm(xi - xj, axis=1)
237
+ return bond_distances
238
+
239
+ @staticmethod
240
+ def get_angle_values(mol, triplets):
241
+ i, j, k = np.array(triplets).T
242
+ x = mol.GetConformer().GetPositions()
243
+ xi = x[i]
244
+ xj = x[j]
245
+ xk = x[k]
246
+ vji = xi - xj
247
+ vjk = xk - xj
248
+ angles = np.arccos((vji * vjk).sum(axis=1) / (np.linalg.norm(vji, axis=1) * np.linalg.norm(vjk, axis=1)))
249
+ return np.degrees(angles)
250
+
251
+ @staticmethod
252
+ def get_distances_and_angles(mol):
253
+ data = defaultdict(list)
254
+ bonds = _get_bond_atom_indices(mol)
255
+ distances = GeometryEvaluator.get_bond_distances(mol, bonds)
256
+ for b, d in zip(bonds, distances):
257
+ data[GeometryEvaluator.bond_repr(mol, b)].append(d)
258
+
259
+ triplets = _get_angle_atom_indices(bonds)
260
+ angles = GeometryEvaluator.get_angle_values(mol, triplets)
261
+ for t, a in zip(triplets, angles):
262
+ data[GeometryEvaluator.angle_repr(mol, t)].append(a)
263
+
264
+ return data
265
+
266
+ @property
267
+ def _dtypes(self):
268
+ return {'*': list}
269
+
270
+
271
+ class EnergyEvaluator(AbstractEvaluator):
272
+ ID = 'energy'
273
+
274
+ def evaluate(self, molecule, protein=None):
275
+ molecule = self.load_molecule(molecule)
276
+ try:
277
+ energy = self.get_energy(molecule)
278
+ except:
279
+ energy = None
280
+ return {'energy': energy}
281
+
282
+ @staticmethod
283
+ def get_energy(mol, conf_id=-1):
284
+ mol = Chem.AddHs(mol, addCoords=True)
285
+ uff = UFFGetMoleculeForceField(mol, confId=conf_id)
286
+ e_uff = uff.CalcEnergy()
287
+ return e_uff
288
+
289
+ @property
290
+ def _dtypes(self):
291
+ return {'energy': float}
292
+
293
+
294
+ class InteractionsEvaluator(AbstractEvaluator):
295
+ ID = 'interactions'
296
+
297
+ def __init__(self, reduce='./reduce'):
298
+ self.reduce = reduce
299
+
300
+ @property
301
+ def default_profile(self):
302
+ return {i: 0 for i in INTERACTION_LIST}
303
+
304
+ def evaluate(self, molecule, protein=None):
305
+ molecule = self.load_molecule(molecule)
306
+ profile = self.default_profile
307
+ try:
308
+ ligand_plf = prepare_ligand(molecule)
309
+ protein_plf = read_protein(str(protein), reduce_exec=self.reduce)
310
+ interactions = profile_detailed(ligand_plf, protein_plf)
311
+ if not interactions.empty:
312
+ profile.update(dict(interactions.interaction.value_counts()))
313
+ except Exception:
314
+ pass
315
+ return profile
316
+
317
+ @property
318
+ def _dtypes(self):
319
+ return {'*': int}
320
+
321
+
322
+ class GninaEvalulator(AbstractEvaluator):
323
+ ID = 'gnina'
324
+ def __init__(self, gnina):
325
+ self.gnina = gnina
326
+
327
+ def evaluate(self, molecule, protein=None):
328
+ with tempfile.TemporaryDirectory() as tmpdir:
329
+ molecule = self.save_molecule(molecule, sdf_path=Path(tmpdir, 'molecule.sdf'))
330
+ gnina_cmd = f'{self.gnina} -r {str(protein)} -l {str(molecule)} --minimize --seed 42 --no_gpu'
331
+ gnina_result = subprocess.run(gnina_cmd, shell=True, capture_output=True, text=True)
332
+ n_atoms = self.load_molecule(molecule).GetNumAtoms()
333
+
334
+ gnina_scores = self.read_gnina_results(gnina_result)
335
+
336
+ # Additionally computing ligand efficiency
337
+ gnina_scores['vina_efficiency'] = gnina_scores['vina_score'] / n_atoms if n_atoms > 0 else None
338
+ gnina_scores['gnina_efficiency'] = gnina_scores['gnina_score'] / n_atoms if n_atoms > 0 else None
339
+ return gnina_scores
340
+
341
+ @staticmethod
342
+ def read_gnina_results(gnina_result):
343
+ res = {
344
+ 'vina_score': None,
345
+ 'gnina_score': None,
346
+ 'minimisation_rmsd': None,
347
+ 'cnn_score': None,
348
+ }
349
+ if gnina_result.returncode != 0:
350
+ print(gnina_result.stderr)
351
+ return res
352
+
353
+ for line in gnina_result.stdout.split('\n'):
354
+ if line.startswith('Affinity'):
355
+ res['vina_score'] = float(line.split(' ')[1].strip())
356
+ if line.startswith('CNNaffinity'):
357
+ res['gnina_score'] = float(line.split(' ')[1].strip())
358
+ if line.startswith('CNNscore'):
359
+ res['cnn_score'] = float(line.split(' ')[1].strip())
360
+ if line.startswith('RMSD'):
361
+ res['minimisation_rmsd'] = float(line.split(' ')[1].strip())
362
+
363
+ return res
364
+
365
+ @property
366
+ def _dtypes(self):
367
+ return {'*': float}
368
+
369
+
370
+ class MedChemEvaluator(AbstractEvaluator):
371
+ ID = 'medchem'
372
+ def __init__(self, connectivity_threshold=1.0):
373
+ self.connectivity_threshold = connectivity_threshold
374
+
375
+ def evaluate(self, molecule, protein=None):
376
+ molecule = self.load_molecule(molecule)
377
+ valid = self.is_valid(molecule)
378
+
379
+ if valid:
380
+ Chem.SanitizeMol(molecule)
381
+
382
+ connected = None if not valid else self.is_connected(molecule)
383
+ qed = None if not valid else self.calculate_qed(molecule)
384
+ sa = None if not valid else self.calculate_sa(molecule)
385
+ logp = None if not valid else self.calculate_logp(molecule)
386
+ lipinski = None if not valid else self.calculate_lipinski(molecule)
387
+ n_rotatable_bonds = None if not valid else self.calculate_rotatable_bonds(molecule)
388
+ size = self.calculate_molecule_size(molecule)
389
+
390
+ return {
391
+ 'valid': valid,
392
+ 'connected': connected,
393
+ 'qed': qed,
394
+ 'sa': sa,
395
+ 'logp': logp,
396
+ 'lipinski': lipinski,
397
+ 'size': size,
398
+ 'n_rotatable_bonds': n_rotatable_bonds,
399
+ }
400
+
401
+ @staticmethod
402
+ def is_valid(rdmol):
403
+ if rdmol.GetNumAtoms() < 1:
404
+ return False
405
+
406
+ _mol = Chem.Mol(rdmol)
407
+ try:
408
+ Chem.SanitizeMol(_mol)
409
+ except ValueError:
410
+ return False
411
+
412
+ return True
413
+
414
+ def is_connected(self, rdmol):
415
+ if rdmol.GetNumAtoms() < 1:
416
+ return False
417
+
418
+ try:
419
+ mol_frags = Chem.rdmolops.GetMolFrags(rdmol, asMols=True)
420
+ largest_frag = max(mol_frags, default=rdmol, key=lambda m: m.GetNumAtoms())
421
+ return largest_frag.GetNumAtoms() / rdmol.GetNumAtoms() >= self.connectivity_threshold
422
+ except:
423
+ return False
424
+
425
+ @staticmethod
426
+ def calculate_qed(rdmol):
427
+ try:
428
+ return QED.qed(rdmol)
429
+ except:
430
+ return None
431
+
432
+ @staticmethod
433
+ def calculate_sa(rdmol):
434
+ try:
435
+ sa = calculateScore(rdmol)
436
+ return sa
437
+ except:
438
+ return None
439
+
440
+ @staticmethod
441
+ def calculate_logp(rdmol):
442
+ try:
443
+ return Crippen.MolLogP(rdmol)
444
+ except:
445
+ return None
446
+
447
+ @staticmethod
448
+ def calculate_lipinski(rdmol):
449
+ try:
450
+ rule_1 = Descriptors.ExactMolWt(rdmol) < 500
451
+ rule_2 = Lipinski.NumHDonors(rdmol) <= 5
452
+ rule_3 = Lipinski.NumHAcceptors(rdmol) <= 10
453
+ rule_4 = -2 <= (logp := Crippen.MolLogP(rdmol)) <= 5
454
+ rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(rdmol) <= 10
455
+ return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]])
456
+ except:
457
+ return None
458
+
459
+ @staticmethod
460
+ def calculate_molecule_size(rdmol):
461
+ try:
462
+ return rdmol.GetNumAtoms()
463
+ except:
464
+ return None
465
+
466
+ @staticmethod
467
+ def calculate_rotatable_bonds(rdmol):
468
+ try:
469
+ return Chem.rdMolDescriptors.CalcNumRotatableBonds(rdmol)
470
+ except:
471
+ return None
472
+
473
+ @property
474
+ def _dtypes(self):
475
+ return {
476
+ 'valid': bool,
477
+ 'connected': bool,
478
+ 'qed': float,
479
+ 'sa': float,
480
+ 'logp': float,
481
+ 'lipinski': int,
482
+ 'size': int,
483
+ 'n_rotatable_bonds': int,
484
+ }
485
+
486
+
487
+ class ClashEvaluator(AbstractEvaluator):
488
+ ID = 'clashes'
489
+ def __init__(self, margin=0.75, ignore={'H'}):
490
+ self.margin = margin
491
+ self.ignore = ignore
492
+
493
+ def evaluate(self, molecule=None, protein=None):
494
+ result = {
495
+ 'passed_clash_score_ligands': None,
496
+ 'passed_clash_score_pockets': None,
497
+ 'passed_clash_score_between': None,
498
+ }
499
+ if molecule is not None:
500
+ molecule = self.load_molecule(molecule)
501
+ clash_score = self.clash_score(molecule)
502
+ result['clash_score_ligands'] = clash_score
503
+ result['passed_clash_score_ligands'] = (clash_score == 0)
504
+
505
+ if protein is not None:
506
+ protein = Chem.MolFromPDBFile(str(protein), sanitize=False)
507
+ clash_score = self.clash_score(protein)
508
+ result['clash_score_pockets'] = clash_score
509
+ result['passed_clash_score_pockets'] = (clash_score == 0)
510
+
511
+ if molecule is not None and protein is not None:
512
+ clash_score = self.clash_score(molecule, protein)
513
+ result['clash_score_between'] = clash_score
514
+ result['passed_clash_score_between'] = (clash_score == 0)
515
+
516
+ return result
517
+
518
+ def clash_score(self, rdmol1, rdmol2=None):
519
+ """
520
+ Computes a clash score as the number of atoms that have at least one
521
+ clash divided by the number of atoms in the molecule.
522
+
523
+ INTERMOLECULAR CLASH SCORE
524
+ If rdmol2 is provided, the score is the fraction of atoms in rdmol1
525
+ that have at least one clash with rdmol2.
526
+ We define a clash if two atoms are closer than "margin times the sum of
527
+ their van der Waals radii".
528
+
529
+ INTRAMOLECULAR CLASH SCORE
530
+ If rdmol2 is not provided, the score is the fraction of atoms in rdmol1
531
+ that have at least one clash with other atoms in rdmol1.
532
+ In this case, a clash is defined by margin times the atoms' smallest
533
+ covalent radii (among single, double and triple bond radii). This is done
534
+ so that this function is applicable even if no connectivity information is
535
+ available.
536
+ """
537
+
538
+ intramolecular = rdmol2 is None
539
+ if intramolecular:
540
+ rdmol2 = rdmol1
541
+
542
+ coord1, radii1 = self.coord_and_radii(rdmol1, intramolecular=intramolecular)
543
+ coord2, radii2 = self.coord_and_radii(rdmol2, intramolecular=intramolecular)
544
+
545
+ dist = np.sqrt(np.sum((coord1[:, None, :] - coord2[None, :, :]) ** 2, axis=-1))
546
+ if intramolecular:
547
+ np.fill_diagonal(dist, np.inf)
548
+
549
+ clashes = dist < self.margin * (radii1[:, None] + radii2[None, :])
550
+ clashes = np.any(clashes, axis=1)
551
+ return np.mean(clashes)
552
+
553
+ def coord_and_radii(self, rdmol, intramolecular):
554
+ _periodic_table = Chem.GetPeriodicTable()
555
+ _get_radius = _periodic_table.GetRcovalent if intramolecular else _periodic_table.GetRvdw
556
+
557
+ coord = rdmol.GetConformer().GetPositions()
558
+ radii = np.array([_get_radius(a.GetSymbol()) for a in rdmol.GetAtoms()])
559
+
560
+ mask = np.array([a.GetSymbol() not in self.ignore for a in rdmol.GetAtoms()])
561
+ coord = coord[mask]
562
+ radii = radii[mask]
563
+
564
+ assert coord.shape[0] == radii.shape[0]
565
+ return coord, radii
566
+
567
+ @property
568
+ def _dtypes(self):
569
+ return {
570
+ 'clash_score_ligands': float,
571
+ 'clash_score_pockets': float,
572
+ 'clash_score_between': float,
573
+ 'passed_clash_score_ligands': bool,
574
+ 'passed_clash_score_pockets': bool,
575
+ 'passed_clash_score_between': bool,
576
+ }
577
+
578
+
579
+ class RingCountEvaluator(AbstractEvaluator):
580
+ ID = 'ring_count'
581
+
582
+ def evaluate(self, molecule, protein=None):
583
+ _mol = self.load_molecule(molecule)
584
+
585
+ # compute ring info if not yet available
586
+ try:
587
+ _mol.UpdatePropertyCache()
588
+ except ValueError:
589
+ return {}
590
+ Chem.GetSymmSSSR(_mol)
591
+
592
+ rings = _mol.GetRingInfo().AtomRings()
593
+ ring_sizes = [len(r) for r in rings]
594
+
595
+ ring_counts = defaultdict(int)
596
+ for k in ring_sizes:
597
+ ring_counts[f"num_{k}_rings"] += 1
598
+
599
+ return ring_counts
600
+
601
+ @property
602
+ def _dtypes(self):
603
+ return {'*': int}
604
+
605
+
606
+ class ChemblRingEvaluator(AbstractEvaluator):
607
+ ID = 'chembl_ring_systems'
608
+
609
+ def __init__(self):
610
+ self.ring_system_lookup = RingSystemLookup.default() # ChEMBL
611
+
612
+ def evaluate(self, molecule, protein=None):
613
+
614
+ results = {
615
+ 'min_ring_smi': None,
616
+ 'min_ring_freq_gt0_': None,
617
+ 'min_ring_freq_gt10_': None,
618
+ 'min_ring_freq_gt100_': None,
619
+ }
620
+
621
+ molecule = self.load_molecule(molecule)
622
+
623
+ try:
624
+ Chem.SanitizeMol(molecule)
625
+ freq_list = self.ring_system_lookup.process_mol(molecule)
627
+ except ValueError:
628
+ return results
629
+
630
+ min_ring, min_freq = get_min_ring_frequency(freq_list)
631
+
632
+ return {
633
+ 'min_ring_smi': min_ring,
634
+ 'min_ring_freq_gt0_': min_freq > 0,
635
+ 'min_ring_freq_gt10_': min_freq > 10,
636
+ 'min_ring_freq_gt100_': min_freq > 100,
637
+ }
638
+
639
+ @property
640
+ def _dtypes(self):
641
+ return {
642
+ 'min_ring_smi': str,
643
+ 'min_ring_freq_gt0_': bool,
644
+ 'min_ring_freq_gt10_': bool,
645
+ 'min_ring_freq_gt100_': bool,
646
+ }
647
+
648
+
649
+ class REOSEvaluator(AbstractEvaluator):
650
+ # Based on https://practicalcheminformatics.blogspot.com/2024/05/generative-molecular-design-isnt-as.html
651
+ ID = 'reos'
652
+
653
+ def __init__(self):
654
+ self.reos = REOS()
655
+
656
+ def evaluate(self, molecule, protein=None):
657
+
658
+ molecule = self.load_molecule(molecule)
659
+ try:
660
+ Chem.SanitizeMol(molecule)
661
+ except ValueError:
662
+ return {rule_set: False for rule_set in self.reos.get_available_rule_sets()}
663
+
664
+ results = {}
665
+ for rule_set in self.reos.get_available_rule_sets():
666
+ self.reos.set_active_rule_sets([rule_set])
667
+ if rule_set == 'PW':
668
+ self.reos.drop_rule('furans')
669
+
670
+ reos_res = self.reos.process_mol(molecule)
671
+ results[rule_set] = reos_res[0] == 'ok'
672
+
673
+ results['all'] = all([bool(value) if not is_nan(value) else False for value in results.values()])
674
+ return results
675
+
676
+ @property
677
+ def _dtypes(self):
678
+ return {'*': bool}
679
+
680
+
681
+ class FullEvaluator(AbstractEvaluator):
682
+ def __init__(
683
+ self,
684
+ pb_conf: str = 'dock',
685
+ gnina: Optional[Union[Path, str]] = None,
686
+ reduce: Optional[Union[Path, str]] = None,
687
+ connectivity_threshold: float = 1.0,
688
+ margin: float = 0.75,
689
+ ignore: Set[str] = {'H'},
690
+ exclude_evaluators: Collection[str] = [],
691
+ ):
692
+ all_evaluators = [
693
+ RepresentationEvaluator(),
694
+ MolPropertyEvaluator(),
695
+ PoseBustersEvaluator(pb_conf=pb_conf),
696
+ MedChemEvaluator(connectivity_threshold=connectivity_threshold),
697
+ ClashEvaluator(margin=margin, ignore=ignore),
698
+ GeometryEvaluator(),
699
+ RingCountEvaluator(),
700
+ EnergyEvaluator(),
701
+ ChemblRingEvaluator(),
702
+ REOSEvaluator()
703
+ ]
704
+ if gnina is not None:
705
+ all_evaluators.append(GninaEvalulator(gnina=gnina))
706
+ else:
707
+ print(f'Evaluator [{GninaEvalulator.ID}] is not included')
708
+ if reduce is not None:
709
+ all_evaluators.append(InteractionsEvaluator(reduce=reduce))
710
+ else:
711
+ print(f'Evaluator [{InteractionsEvaluator.ID}] is not included')
712
+
713
+ self.evaluators = []
714
+ for e in all_evaluators:
715
+ if e.ID in exclude_evaluators:
716
+ print(f'Excluded Evaluator [{e.ID}]')
717
+ else:
718
+ self.evaluators.append(e)
719
+
720
+ print('Will use the following evaluators:')
721
+ for e in self.evaluators:
722
+ print(f'- [{e.ID}]')
723
+
724
+
725
+ def evaluate(self, molecule, protein):
726
+ results = {}
727
+ for evaluator in self.evaluators:
728
+ results.update(evaluator(molecule, protein))
729
+ return results
730
+
731
+ @property
732
+ def _dtypes(self):
733
+ all_dtypes = {}
734
+ for evaluator in self.evaluators:
735
+ all_dtypes.update(evaluator.dtypes)
736
+ return all_dtypes
737
+
738
+
739
+ ########################################################################################
740
+ ################################# Collection Metrics ###################################
741
+ ########################################################################################
742
+
743
+
744
+ class AbstractCollectionEvaluator:
745
+ ID = None
746
+ def __call__(self, smiles: Collection[str], timeout=300):
747
+ """
748
+ Args:
749
+ smiles (Collection[smiles]): input list of SMILES
750
+
751
+ Returns:
752
+ metrics (dict): dictionary of metrics
753
+ """
754
+ if self.ID is not None:
755
+ print(f'Running CollectionEvaluator [{self.ID}]')
756
+
757
+ RDLogger.DisableLog('rdApp.*')
758
+ self.check_format(smiles)
759
+ # timeout handler
760
+ signal.signal(signal.SIGALRM, timeout_handler)
761
+ try:
762
+ signal.alarm(timeout)
763
+ results = self.evaluate(smiles)
764
+ except TimeoutError:
765
+ print(f'Error when evaluating [{self.ID}]: Timeout after {timeout} seconds')
766
+ signal.alarm(0)
767
+ return {}
768
+ except Exception as e:
769
+ print(f'Error when evaluating [{self.ID}]: {e}')
770
+ signal.alarm(0)
771
+ return {}
772
+ finally:
773
+ print(f'Finished CollectionEvaluator [{self.ID}]')
774
+ signal.alarm(0)
775
+ return results
776
+
777
+ @staticmethod
778
+ def check_format(smiles):
779
+ assert len(smiles) > 0, 'List of input SMILES cannot be empty'
780
+ assert isinstance(smiles, Collection), 'Only list of SMILES supported'
781
+ assert isinstance(smiles[0], str), 'Only list of SMILES supported'
782
+
783
+
784
+ class UniquenessEvaluator(AbstractCollectionEvaluator):
785
+ ID = 'uniqueness'
786
+ def evaluate(self, smiles: Collection[str]):
787
+ uniqueness = len(set(smiles)) / len(smiles)
788
+ return {'uniqueness': uniqueness}
789
+
790
+
791
+ class NoveltyEvaluator(AbstractCollectionEvaluator):
792
+ ID = 'novelty'
793
+ def __init__(self, reference_smiles: Collection[str]):
794
+ self.reference_smiles = set(list(reference_smiles))
795
+ assert len(self.reference_smiles) > 0, 'List of reference SMILES cannot be empty'
796
+
797
+ def evaluate(self, smiles: Collection[str]):
798
+ smiles = set(smiles)
799
+ novel = [smi for smi in smiles if smi not in self.reference_smiles]
800
+ novelty = len(novel) / len(smiles)
801
+ return {'novelty': novelty}
802
+
803
+ def canonical_smiles(smiles):
804
+ for smi in smiles:
805
+ try:
806
+ mol = Chem.MolFromSmiles(smi)
807
+ if mol is not None:
808
+ yield Chem.MolToSmiles(mol)
809
+ except:
810
+ yield None
811
+
812
+ class FCDEvaluator(AbstractCollectionEvaluator):
813
+ ID = 'fcd'
814
+ def __init__(self, reference_smiles: Collection[str]):
815
+ self.reference_smiles = list(reference_smiles)
816
+ assert len(self.reference_smiles) > 0, 'List of reference SMILES cannot be empty'
817
+
818
+ def evaluate(self, smiles: Collection[str]):
819
+ if len(smiles) > len(self.reference_smiles):
820
+ print('Number of reference molecules must be at least the number of input molecules')
821
+ return {'fcd': None}
822
+
823
+ np.random.seed(42)
824
+ reference_smiles = np.random.choice(self.reference_smiles, len(smiles), replace=False).tolist()
825
+ reference_smiles_canonical = [w for w in canonical_smiles(reference_smiles) if w is not None]
826
+ smiles_canonical = [w for w in canonical_smiles(smiles) if w is not None]
827
+ fcd = get_fcd(reference_smiles_canonical, smiles_canonical)
828
+ return {'fcd': fcd}
829
+
830
+
831
+ class RingDistributionEvaluator(AbstractCollectionEvaluator):
832
+ ID = 'ring_system_distribution'
833
+
834
+ def __init__(self, reference_smiles: Collection[str], jsd_on_k_most_freq: Collection[int] = ()):
835
+ self.ring_system_finder = RingSystemFinder()
836
+ self.ref_ring_dict = self.compute_ring_dict(reference_smiles)
837
+ self.jsd_on_k_most_freq = jsd_on_k_most_freq
838
+
839
+ def compute_ring_dict(self, molecules):
840
+
841
+ ring_system_dict = defaultdict(int)
842
+
843
+ for mol in tqdm(molecules, desc="Computing ring systems"):
844
+
845
+ if isinstance(mol, str):
846
+ mol = Chem.MolFromSmiles(mol)
847
+
848
+ try:
849
+ ring_system_list = self.ring_system_finder.find_ring_systems(mol, as_mols=True)
850
+ except ValueError:
851
+ print(f"WARNING[{type(self).__name__}]: error while computing ring systems; skipping molecule.")
852
+ continue
853
+
854
+ for ring in ring_system_list:
855
+ inchi_key = Chem.MolToInchiKey(ring)
856
+ ring_system_dict[inchi_key] += 1
857
+
858
+ return ring_system_dict
859
+
860
+ def precision(self, query_ring_dict):
861
+ query_ring_systems = set(query_ring_dict.keys())
862
+ ref_ring_systems = set(self.ref_ring_dict.keys())
863
+ intersection = ref_ring_systems & query_ring_systems
864
+ return len(intersection) / len(query_ring_systems) if len(query_ring_systems) > 0 else 0
865
+
866
+ def recall(self, query_ring_dict):
867
+ query_ring_systems = set(query_ring_dict.keys())
868
+ ref_ring_systems = set(self.ref_ring_dict.keys())
869
+ intersection = ref_ring_systems & query_ring_systems
870
+ return len(intersection) / len(ref_ring_systems) if len(ref_ring_systems) > 0 else 0
871
+
872
+ def jsd(self, query_ring_dict, k_most_freq=None):
873
+
874
+ if k_most_freq is None:
875
+ # evaluate on the union of all ring systems
876
+ sample_space = set(self.ref_ring_dict.keys()) | set(query_ring_dict.keys())
877
+ else:
878
+ # evaluate only on the k most common rings from the reference set
879
+ sorted_rings = [k for k, v in sorted(self.ref_ring_dict.items(), key=lambda item: item[1], reverse=True)]
880
+ sample_space = sorted_rings[:k_most_freq]
881
+
882
+ p = np.zeros(len(sample_space))
883
+ q = np.zeros(len(sample_space))
884
+
885
+ for i, inchi_key in enumerate(sample_space):
886
+ p[i] = self.ref_ring_dict.get(inchi_key, 0)
887
+ q[i] = query_ring_dict.get(inchi_key, 0)
888
+
889
+ # normalize
890
+ p = p / np.sum(p)
891
+ q = q / np.sum(q)
892
+
893
+ return jensenshannon(p, q)
894
+
895
+ def evaluate(self, smiles: Collection[str]):
896
+
897
+ query_ring_dict = self.compute_ring_dict(smiles)
898
+
899
+ out = {
900
+ "precision": self.precision(query_ring_dict),
901
+ "recall": self.recall(query_ring_dict),
902
+ "jsd": self.jsd(query_ring_dict),
903
+ }
904
+
905
+ out.update(
906
+ {f"jsd_{k}_most_freq": self.jsd(query_ring_dict, k_most_freq=k) for k in self.jsd_on_k_most_freq}
907
+ )
908
+
909
+ return out
910
+
911
+
912
+ class FullCollectionEvaluator(AbstractCollectionEvaluator):
913
+ def __init__(self, reference_smiles: Collection[str], exclude_evaluators: Collection[str] = []):
914
+ self.evaluators = [
915
+ UniquenessEvaluator(),
916
+ NoveltyEvaluator(reference_smiles=reference_smiles),
917
+ FCDEvaluator(reference_smiles=reference_smiles),
918
+ RingDistributionEvaluator(reference_smiles, jsd_on_k_most_freq=[10, 100, 1000, 10000]),
919
+ ]
920
+ # Build a filtered list instead of removing elements while iterating,
+ # which would silently skip the evaluator after each removal
+ kept_evaluators = []
+ for e in self.evaluators:
+     if e.ID in exclude_evaluators:
+         print(f'Excluding CollectionEvaluator [{e.ID}]')
+     else:
+         kept_evaluators.append(e)
+ self.evaluators = kept_evaluators
924
+
925
+ def evaluate(self, smiles):
926
+ results = {}
927
+ for evaluator in self.evaluators:
928
+ results.update(evaluator(smiles))
929
+ return results
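
A minimal sketch tying the two evaluator families together; the file paths and binary locations are placeholders, and in practice `reference_smiles` would be the training-set SMILES (FCD in particular needs a sizeable reference set):

    import numpy as np
    from src.sbdd_metrics.metrics import FullEvaluator, FullCollectionEvaluator

    # Per-molecule metrics for one ligand/pocket pair
    evaluator = FullEvaluator(gnina='./gnina', reduce='./reduce')
    result = evaluator('samples/target_0/0_ligand.sdf', 'samples/target_0/0_pocket.pdb')
    print(result.get('medchem.qed'), result.get('gnina.vina_score'))

    # Collection-level metrics over many generated molecules
    reference_smiles = list(np.load('train_smiles.npy'))  # hypothetical reference set
    collection_evaluator = FullCollectionEvaluator(reference_smiles=reference_smiles)
    generated_smiles = ['CCO', 'c1ccccc1O']               # placeholder SMILES
    print(collection_evaluator(generated_smiles))
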
src/sbdd_metrics/sascorer.py ADDED
@@ -0,0 +1,173 @@
1
+ #
2
+ # calculation of synthetic accessibility score as described in:
3
+ #
4
+ # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
5
+ # Peter Ertl and Ansgar Schuffenhauer
6
+ # Journal of Cheminformatics 1:8 (2009)
7
+ # http://www.jcheminf.com/content/1/1/8
8
+ #
9
+ # several small modifications to the original paper are included
10
+ # particularly a slightly different formula for the macrocyclic penalty
11
+ # and taking into account also molecule symmetry (fingerprint density)
12
+ #
13
+ # for a set of 10k diverse molecules the agreement between the original method
14
+ # as implemented in PipelinePilot and this implementation is r2 = 0.97
15
+ #
16
+ # peter ertl & greg landrum, september 2013
17
+ #
18
+
19
+
20
+ from rdkit import Chem
21
+ from rdkit.Chem import rdMolDescriptors
22
+ import pickle
23
+
24
+ import math
25
+ from collections import defaultdict
26
+
27
+ import os.path as op
28
+
29
+ _fscores = None
30
+
31
+
32
+ def readFragmentScores(name='fpscores'):
33
+ import gzip
34
+ global _fscores
35
+ # generate the full path filename:
36
+ if name == "fpscores":
37
+ name = op.join(op.dirname(__file__), name)
38
+ data = pickle.load(gzip.open('%s.pkl.gz' % name))
39
+ outDict = {}
40
+ for i in data:
41
+ for j in range(1, len(i)):
42
+ outDict[i[j]] = float(i[0])
43
+ _fscores = outDict
44
+
45
+
46
+ def numBridgeheadsAndSpiro(mol, ri=None):
47
+ nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
48
+ nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
49
+ return nBridgehead, nSpiro
50
+
51
+
52
+ def calculateScore(m):
53
+ if _fscores is None:
54
+ readFragmentScores()
55
+
56
+ # fragment score
57
+ fp = rdMolDescriptors.GetMorganFingerprint(m,
58
+ 2) # <- 2 is the *radius* of the circular fingerprint
59
+ fps = fp.GetNonzeroElements()
60
+ score1 = 0.
61
+ nf = 0
62
+ for bitId, v in fps.items():
63
+ nf += v
64
+ sfp = bitId
65
+ score1 += _fscores.get(sfp, -4) * v
66
+ score1 /= nf
67
+
68
+ # features score
69
+ nAtoms = m.GetNumAtoms()
70
+ nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
71
+ ri = m.GetRingInfo()
72
+ nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
73
+ nMacrocycles = 0
74
+ for x in ri.AtomRings():
75
+ if len(x) > 8:
76
+ nMacrocycles += 1
77
+
78
+ sizePenalty = nAtoms**1.005 - nAtoms
79
+ stereoPenalty = math.log10(nChiralCenters + 1)
80
+ spiroPenalty = math.log10(nSpiro + 1)
81
+ bridgePenalty = math.log10(nBridgeheads + 1)
82
+ macrocyclePenalty = 0.
83
+ # ---------------------------------------
84
+ # This differs from the paper, which defines:
85
+ # macrocyclePenalty = math.log10(nMacrocycles+1)
86
+ # This form generates better results when 2 or more macrocycles are present
87
+ if nMacrocycles > 0:
88
+ macrocyclePenalty = math.log10(2)
89
+
90
+ score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty
91
+
92
+ # correction for the fingerprint density
93
+ # not in the original publication, added in version 1.1
94
+ # to make highly symmetrical molecules easier to synthesise
95
+ score3 = 0.
96
+ if nAtoms > len(fps):
97
+ score3 = math.log(float(nAtoms) / len(fps)) * .5
98
+
99
+ sascore = score1 + score2 + score3
100
+
101
+ # need to transform "raw" value into scale between 1 and 10
102
+ min = -4.0
103
+ max = 2.5
104
+ sascore = 11. - (sascore - min + 1) / (max - min) * 9.
105
+ # smooth the 10-end
106
+ if sascore > 8.:
107
+ sascore = 8. + math.log(sascore + 1. - 9.)
108
+ if sascore > 10.:
109
+ sascore = 10.0
110
+ elif sascore < 1.:
111
+ sascore = 1.0
112
+
113
+ return sascore
114
+
115
+
116
+ def processMols(mols):
117
+ print('smiles\tName\tsa_score')
118
+ for i, m in enumerate(mols):
119
+ if m is None:
120
+ continue
121
+
122
+ s = calculateScore(m)
123
+
124
+ smiles = Chem.MolToSmiles(m)
125
+ print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s)
126
+
127
+
128
+ if __name__ == '__main__':
129
+ import sys
130
+ import time
131
+
132
+ t1 = time.time()
133
+ readFragmentScores("fpscores")
134
+ t2 = time.time()
135
+
136
+ suppl = Chem.SmilesMolSupplier(sys.argv[1])
137
+ t3 = time.time()
138
+ processMols(suppl)
139
+ t4 = time.time()
140
+
141
+ print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
142
+ file=sys.stderr)
143
+
144
+ #
145
+ # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
146
+ # All rights reserved.
147
+ #
148
+ # Redistribution and use in source and binary forms, with or without
149
+ # modification, are permitted provided that the following conditions are
150
+ # met:
151
+ #
152
+ # * Redistributions of source code must retain the above copyright
153
+ # notice, this list of conditions and the following disclaimer.
154
+ # * Redistributions in binary form must reproduce the above
155
+ # copyright notice, this list of conditions and the following
156
+ # disclaimer in the documentation and/or other materials provided
157
+ # with the distribution.
158
+ # * Neither the name of Novartis Institutes for BioMedical Research Inc.
159
+ # nor the names of its contributors may be used to endorse or promote
160
+ # products derived from this software without specific prior written permission.
161
+ #
162
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
163
+ # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
164
+ # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
165
+ # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
166
+ # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
167
+ # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
168
+ # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
169
+ # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
170
+ # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
171
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
172
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
173
+ #
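
For completeness, scoring a single molecule with this module (the import path assumes the package layout above; exact values depend on the molecule):

    from rdkit import Chem
    from src.sbdd_metrics.sascorer import calculateScore

    mol = Chem.MolFromSmiles('CC(=O)Oc1ccccc1C(=O)O')  # aspirin
    sa = calculateScore(mol)                           # scale: 1 (easy) to 10 (hard)
    print(f'SA score: {sa:.2f}')                       # simple drug-like molecules score low
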