diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..c34ac4784aabe2ca5773a2a90abb30738d696d80 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+docs/drugflow.jpg filter=lfs diff=lfs merge=lfs -text
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..d41e0ee28e326db91de7b893d357febe12a66efb
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,9 @@
+MIT License
+
+Copyright (c) 2025 Arne Schneuing, Ilia Igashov, Adrian Dobbelstein
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/configs/sampling/sample_and_maybe_evaluate.yml b/configs/sampling/sample_and_maybe_evaluate.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d709a2f0b501e9e1acf426f9c03ec19c40548a99
--- /dev/null
+++ b/configs/sampling/sample_and_maybe_evaluate.yml
@@ -0,0 +1,25 @@
+checkpoint:
+set: test
+sample_outdir: ./samples
+n_samples: 100
+sample_with_ground_truth_size: False
+device: cuda
+seed: 42
+sample: True
+postprocess: False
+evaluate: False
+reduce: reduce
+
+# Override training config parameters if necessary
+model_args:
+
+  virtual_nodes: [0, 5]
+
+  train_params:
+    datadir: ./processed_crossdocked
+    gnina: gnina
+
+  eval_params:
+    n_sampling_steps: 500
+    eval_batch_size: 1
+ 
\ No newline at end of file
diff --git a/configs/sampling/sample_train_split.yml b/configs/sampling/sample_train_split.yml
new file mode 100644
index 0000000000000000000000000000000000000000..52b9626db839e683889e3ca14c370514e4253253
--- /dev/null
+++ b/configs/sampling/sample_train_split.yml
@@ -0,0 +1,25 @@
+checkpoint:
+set: train
+sample_outdir: ./samples
+n_samples: 50
+sample_with_ground_truth_size: False
+device: cuda
+seed: 42
+sample: True
+postprocess: False
+evaluate: False
+reduce: reduce
+
+# Override training config parameters if necessary
+model_args:
+
+  virtual_nodes: [0, 10]
+
+  train_params:
+    datadir: ./processed_crossdocked
+    gnina: gnina
+    batch_size: 2
+
+  eval_params:
+    n_sampling_steps: 100
+ 
\ No newline at end of file
diff --git a/configs/training/drugflow.yml b/configs/training/drugflow.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f49916b84c716a62533851b81e8422e787b2ba66
--- /dev/null
+++ b/configs/training/drugflow.yml
@@ -0,0 +1,82 @@
+run_name: drugflow # iclr_drugflow_T5000
+pocket_representation: CA+
+virtual_nodes: [0, 10]
+flexible: False
+flexible_bb: False
+
+train_params:
+  logdir: ./runs # symlink to any location you like
+  datadir: ./processed_crossdocked # symlink to the dataset location
+  enable_progress_bar: True
+  num_sanity_val_steps: 0
+  batch_size: 64
+  accumulate_grad_batches: 2
+  lr: 5.0e-4
+  n_epochs: 1000
+  num_workers: 0
+  gpus: 1
+  clip_grad: True
+  gnina: gnina
+  sample_from_clusters: False
+  sharded_dataset: False
+
+wandb_params:
+  mode: online # disabled, offline, online
+  entity:
+  group: crossdocked
+
+loss_params:
+  discrete_loss: VLB # VLB or CE
+  lambda_x: 1.0
+  lambda_h: 50.0
+  lambda_e: 50.0
+  lambda_chi: null
+  lambda_trans: null
+  lambda_rot: null
+  lambda_clash: null
+  timestep_weights: null
+
+simulation_params:
+  n_steps: 5000
+  prior_h: marginal # uniform, marginal
+  prior_e: uniform # uniform, marginal
+  predict_final: False
+  predict_confidence: False
+
+eval_params:
+  eval_epochs: 100
+  n_eval_samples: 4
+  n_sampling_steps: 500
+  eval_batch_size: 16
+  visualize_sample_epoch: 1
+  n_visualize_samples: 100
+  visualize_chain_epoch: 1
+  keep_frames: 100
+  sample_with_ground_truth_size: True
+
+predictor_params:
+  heterogeneous_graph: True
+  backbone: gvp
+  num_rbf_time: 16
+  edge_cutoff_ligand: null
+  edge_cutoff_pocket: 10.0
+  edge_cutoff_interaction: 10.0
+  cycle_counts: True
+  spectral_feat: False
+  reflection_equivariant: False
+  num_rbf: 16
+  d_max: 15.0
+  self_conditioning: True
+  augment_residue_sc: False
+  augment_ligand_sc: False
+  normal_modes: False
+  add_chi_as_feature: False
+  angle_act_fn: null
+  add_all_atom_diff: False
+
+  gvp_params:
+    n_layers: 5
+    node_h_dim: [ 128, 32 ] # (s, V)
+    edge_h_dim: [ 128, 32 ]
+    dropout: 0.0
+    vector_gate: True
\ No newline at end of file
diff --git a/configs/training/drugflow_no_virtual_nodes.yml b/configs/training/drugflow_no_virtual_nodes.yml
new file mode 100644
index 0000000000000000000000000000000000000000..776bea42bb96ddd721a5ce5c57071249daa889f1
--- /dev/null
+++ b/configs/training/drugflow_no_virtual_nodes.yml
@@ -0,0 +1,82 @@
+run_name: drugflow_no_virtual_nodes # iclr_drugflow_T5000_no_virtual_nodes
+pocket_representation: CA+
+virtual_nodes: null
+flexible: False
+flexible_bb: False
+
+train_params:
+  logdir: ./runs # symlink to any location you like
+  datadir: ./processed_crossdocked # symlink to the dataset location
+  enable_progress_bar: True
+  num_sanity_val_steps: 0
+  batch_size: 64
+  accumulate_grad_batches: 2
+  lr: 5.0e-4
+  n_epochs: 1000
+  num_workers: 0
+  gpus: 1
+  clip_grad: True
+  gnina: gnina
+  sample_from_clusters: False
+  sharded_dataset: False
+
+wandb_params:
+  mode: online # disabled, offline, online
+  entity: lpdi
+  group: crossdocked
+
+loss_params:
+  discrete_loss: VLB # VLB or CE
+  lambda_x: 1.0
+  lambda_h: 50.0
+  lambda_e: 50.0
+  lambda_chi: null
+  lambda_trans: null
+  lambda_rot: null
+  lambda_clash: null
+  timestep_weights: null
+
+simulation_params:
+  n_steps: 5000
+  prior_h: marginal # uniform, marginal
+  prior_e: uniform # uniform, marginal
+  predict_final: False
+  predict_confidence: False
+
+eval_params:
+  eval_epochs: 100
+  n_eval_samples: 4
+  n_sampling_steps: 500
+  eval_batch_size: 16
+  visualize_sample_epoch: 1
+  n_visualize_samples: 100
+  visualize_chain_epoch: 1
+  keep_frames: 100
+  sample_with_ground_truth_size: True
+
+predictor_params:
+  heterogeneous_graph: True
+  backbone: gvp
+  num_rbf_time: 16
+  edge_cutoff_ligand: null
+  edge_cutoff_pocket: 10.0
+  edge_cutoff_interaction: 10.0
+  cycle_counts: True
+  spectral_feat: False
+  reflection_equivariant: False
+  num_rbf: 16
+  d_max: 15.0
+  self_conditioning: True
+  augment_residue_sc: False
+  augment_ligand_sc: False
+  normal_modes: False
+  add_chi_as_feature: False
+  angle_act_fn: null
+  add_all_atom_diff: False
+
+  gvp_params:
+    n_layers: 5
+    node_h_dim: [ 128, 32 ] # (s, V)
+    edge_h_dim: [ 128, 32 ]
+    dropout: 0.0
+    vector_gate: True
diff --git a/configs/training/drugflow_ood.yml b/configs/training/drugflow_ood.yml
new file mode 100644
index 0000000000000000000000000000000000000000..aac36391eb6c04586075f46d29e6d1c1a8ce7378
--- /dev/null
+++ b/configs/training/drugflow_ood.yml
@@ -0,0 +1,83 @@
+run_name: drugflow_ood # iclr_drugflow_T5000_confidence_ru10
+pocket_representation: CA+
+virtual_nodes: [0, 10]
+flexible: False
+flexible_bb: False
+
+train_params:
+  logdir: ./runs # symlink to any location you like
+  datadir: ./processed_crossdocked # symlink to the dataset location
+  enable_progress_bar: True
+  num_sanity_val_steps: 0
+  batch_size: 64
+  accumulate_grad_batches: 2
+  lr: 5.0e-4
+  n_epochs: 1000
+  num_workers: 0
+  gpus: 1
+  clip_grad: True
+  gnina: gnina
+  sample_from_clusters: False
+  sharded_dataset: False
+
+wandb_params:
+  mode: online # disabled, offline, online
+  entity: lpdi
+  group: crossdocked
+
+loss_params:
+  discrete_loss: VLB # VLB or CE
+  lambda_x: 1.0
+  lambda_h: 50.0
+  lambda_e: 50.0
+  lambda_chi: null
+  lambda_trans: null
+  lambda_rot: null
+  lambda_clash: null
+  timestep_weights: null
+  regularize_uncertainty: 10.0
+
+simulation_params:
+  n_steps: 5000
+  prior_h: marginal # uniform, marginal
+  prior_e: uniform # uniform, marginal
+  predict_final: False
+  predict_confidence: True
+
+eval_params:
+  eval_epochs: 100
+  n_eval_samples: 4
+  n_sampling_steps: 500
+  eval_batch_size: 16
+  visualize_sample_epoch: 1
+  n_visualize_samples: 100
+  visualize_chain_epoch: 1
+  keep_frames: 100
+  sample_with_ground_truth_size: True
+
+predictor_params:
+  heterogeneous_graph: True
+  backbone: gvp
+  num_rbf_time: 16
+  edge_cutoff_ligand: null
+  edge_cutoff_pocket: 10.0
+  edge_cutoff_interaction: 10.0
+  cycle_counts: True
+  spectral_feat: False
+  reflection_equivariant: False
+  num_rbf: 16
+  d_max: 15.0
+  self_conditioning: True
+  augment_residue_sc: False
+  augment_ligand_sc: False
+  normal_modes: False
+  add_chi_as_feature: False
+  angle_act_fn: null
+  add_all_atom_diff: False
+
+  gvp_params:
+    n_layers: 5
+    node_h_dim: [ 128, 32 ] # (s, V)
+    edge_h_dim: [ 128, 32 ]
+    dropout: 0.0
+    vector_gate: True
\ No newline at end of file
diff --git a/configs/training/flexflow.yml b/configs/training/flexflow.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d2192cb63b8602b48b4a40ba9482c9a6e34fd3c1
--- /dev/null
+++ b/configs/training/flexflow.yml
@@ -0,0 +1,90 @@
+run_name: flexflow
+pocket_representation: CA+
+virtual_nodes: [0, 10]
+flexible: True
+flexible_bb: False
+
+train_params:
+  logdir: ./runs # symlink to any location you like
+  datadir: ./processed_crossdocked # symlink to the dataset location
+  enable_progress_bar: False
+  num_sanity_val_steps: 0
+  batch_size: 64
+  accumulate_grad_batches: 2
+  lr: 5.0e-4
+  lr_step_size: null
+  lr_gamma: null
+  n_epochs: 700
+  num_workers: 4
+  gpus: 1
+  clip_grad: True
+  gnina: gnina # add Gnina location to path
+  sample_from_clusters: False
+  sharded_dataset: False
+
+wandb_params:
+  mode: online # disabled, offline, online
+  entity:
+  group: crossdocked
+
+loss_params:
+  discrete_loss: VLB # VLB or CE
+  reduce: sum # 'mean' or 'sum'
+  lambda_x: 0.015
+  lambda_h: 2.5
+  lambda_e: 0.25
+  lambda_chi: 0.002
+  lambda_trans: null
+  lambda_rot: null
+  lambda_clash: null
+  regularize_uncertainty: null
+  timestep_weights: null
+
+simulation_params:
+  n_steps: 5000
+  prior_h: marginal # uniform, marginal
+  prior_e: uniform # uniform, marginal
+  predict_final: False
+  predict_confidence: False
+  scheduler_chi:
+    type: polynomial
+    k: 3 # constant for exponential scheduler kappa(t)=(1-t)^k
+
+eval_params:
+  eval_epochs: 100
+  n_loss_per_sample: 100
+  n_eval_samples: 4
+  n_sampling_steps: 500
+  eval_batch_size: 16
+  visualize_sample_epoch: 1
+  n_visualize_samples: 100
+  visualize_chain_epoch: 1
+  keep_frames: 100
+  sample_with_ground_truth_size: True
+
+predictor_params:
+  heterogeneous_graph: True
+  backbone: gvp
+  num_rbf_time: 16
+  edge_cutoff_ligand: null
+  edge_cutoff_pocket: 10.0
+  edge_cutoff_interaction: 10.0
+  cycle_counts: True
+  spectral_feat: False
+  reflection_equivariant: False
+  num_rbf: 16
+  d_max: 15.0
+  self_conditioning: True
+  augment_residue_sc: False
+  augment_ligand_sc: False
+  normal_modes: False
+  add_chi_as_feature: False
+  angle_act_fn: null
+  add_all_atom_diff: True
+
+  gvp_params:
+    n_layers: 5
+    node_h_dim: [ 128, 32 ] # (s, V)
+    edge_h_dim: [ 128, 32 ]
+    dropout: 0.0
+    vector_gate: True
diff --git a/configs/training/preference_alignment.yml b/configs/training/preference_alignment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..1591e6b23e7682a3244d95cf96fef53de814059c
--- /dev/null
+++ b/configs/training/preference_alignment.yml
@@ -0,0 +1,93 @@
+run_name: drugflow_preference_alignment
+
+checkpoint: ./reference.ckpt # TODO: specify reference checkpoint
+dpo_mode: single_dpo_comp_v3
+
+pocket_representation: CA+
+virtual_nodes: [0, 10]
+flexible: False
+flexible_bb: False
+
+train_params:
+  logdir: ./runs # symlink to any location you like
+  datadir: ./processed_crossdocked # symlink to the dataset location
+  enable_progress_bar: True
+  num_sanity_val_steps: 0
+  batch_size: 64
+  accumulate_grad_batches: 2
+  lr: 5.0e-5
+  n_epochs: 500
+  num_workers: 0
+  gpus: 1
+  clip_grad: True
+  gnina: gnina # path to gnina binary
+  sample_from_clusters: False
+  sharded_dataset: False
+
+wandb_params:
+  mode: online # disabled, offline, online
+  entity:
+  group: crossdocked
+
+loss_params:
+  discrete_loss: VLB # VLB or CE
+  lambda_x: 1.0
+  lambda_h: 500
+  dpo_lambda_h: 2500
+  lambda_e: 500
+  dpo_lambda_e: 2500
+  lambda_chi: 0.5 # only effective if flexible=True
+  lambda_trans: 1.0 # only effective if flexible_bb=True
+  lambda_rot: 0.1 # only effective if flexible_bb=True
+  lambda_clash: null
+  timestep_weights: null # sigmoid_a=1_b=10 # null, sigmoid_a=?_b=?
+  dpo_beta: 100.0
+  dpo_beta_schedule: 't'
+  dpo_lambda_w: 1.0
+  dpo_lambda_l: 0.2
+  clamp_dpo: False
+
+simulation_params:
+  n_steps: 5000
+  prior_h: marginal # uniform, marginal
+  prior_e: uniform # uniform, marginal
+  predict_final: False
+  predict_confidence: False
+
+eval_params:
+  eval_epochs: 4
+  n_eval_samples: 1
+  n_sampling_steps: 500
+  eval_batch_size: 16
+  visualize_sample_epoch: 1
+  n_visualize_samples: 10
+  visualize_chain_epoch: 1
+  keep_frames: 100
+  sample_with_ground_truth_size: True
+
+predictor_params:
+  heterogeneous_graph: True
+  backbone: gvp
+  num_rbf_time: 16
+  edge_cutoff_ligand: null
+  edge_cutoff_pocket: 10.0
+  edge_cutoff_interaction: 10.0
+  cycle_counts: True
+  spectral_feat: False
+  reflection_equivariant: False
+  num_rbf: 16
+  d_max: 15.0
+  self_conditioning: True
+  augment_residue_sc: False
+  augment_ligand_sc: False
+  normal_modes: False
+  add_chi_as_feature: False
+  angle_act_fn: null
+  add_all_atom_diff: False
+
+  gvp_params:
+    n_layers: 5
+    node_h_dim: [ 128, 32 ] # (s, V)
+    edge_h_dim: [ 128, 32 ]
+    dropout: 0.0
+    vector_gate: True
diff --git a/docs/drugflow.jpg b/docs/drugflow.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ff128bb39f640bed29baa3929941bb54fd541c02
--- /dev/null
+++ b/docs/drugflow.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c16f816eafa8e13658526b74f06f8ee5fb3258f51172c711ec1d1d539b48a4ef
+size 762431
diff --git a/environment.yaml b/environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..02473f0a9f035c8e451bdd764e9a6031539441e7
--- /dev/null
+++ b/environment.yaml
@@ -0,0 +1,30 @@
+name: sbdd
+
+channels:
+  - pytorch
+  - conda-forge
+  - anaconda
+  - pyg
+  - nvidia
+
+dependencies:
+  - python=3.11.8
+  - pytorch=2.2.1=*cuda12.1*
+  - pytorch-cuda=12.1
+  - pytorch-lightning=2.2.1
+  - rdkit=2023.09.6
+  - openbabel=3.1.1
+  - biopython=1.83
+  - scipy=1.12.0
+  - pyg=2.5.1
+  - pytorch-scatter=2.1.2
+  - ProDy=2.4.0
+  - wandb=0.16.3
+  - pandas=2.2.2
+  - pip=24.0
+  - pip:
+    - posebusters==0.3.1
+    - useful_rdkit_utils==0.65
+    - fcd==1.2.2
+    - webdataset==0.2.86
+    - prolif==2.0.3
diff --git a/examples/kras.pdb b/examples/kras.pdb
new file mode 100644
index 0000000000000000000000000000000000000000..0be256e448e3381da6b12077a36e2f43f6883325
--- /dev/null
+++ b/examples/kras.pdb
@@ -0,0 +1,1346 @@
+CRYST1 86.847 40.300 55.887 90.00 90.00 90.00 P 21 21 2 1 +ATOM 1 N GLY A 0 -9.024 8.471 19.532 1.00 16.21 N +ATOM 2 CA GLY A 0 -8.153 7.369 19.898 1.00 22.53 C +ATOM 3 C GLY A 0 -6.736 7.880 20.035 1.00 25.73 C +ATOM 4 O GLY A 0 -6.448 9.021 19.687 1.00 25.69 O +ATOM 5 N MET A 1 -5.841 7.048 20.547 1.00 17.40 N +ATOM 6 CA MET A 1 -4.477 7.486 20.800 1.00 18.67 C +ATOM 7 C MET A 1 -3.656 7.163 19.557 1.00 15.66 C +ATOM 8 O MET A 1 -3.681 6.028 19.072 1.00 15.26 O +ATOM 9 CB MET A 1 -3.940 6.828 22.075 1.00 17.78 C +ATOM 10 CG MET A 1 -2.481 7.106 22.432 1.00 17.55 C +ATOM 11 SD MET A 1 -1.822 6.027 23.717 1.00 16.83 S +ATOM 12 CE MET A 1 -2.924 6.405 25.093 1.00 17.61 C +ATOM 13 N THR A 2 -2.999 8.176 18.998 1.00 12.92 N +ATOM 14 CA THR A 2 -2.213 7.975 17.783 1.00 13.28 C +ATOM 15 C THR A 2 -0.940 7.186 18.083 1.00 17.65 C +ATOM 16 O THR A 2 -0.237 7.455 19.063 1.00 12.72 O +ATOM 17 CB THR A 2 -1.843 9.320 17.153 1.00 16.05 C +ATOM 18 CG2 THR A 2 -1.075 9.115 15.862 1.00 15.07 C +ATOM 19 OG1 THR A 2 -3.032 10.078 16.890 1.00 21.26 O +ATOM 20 N GLU A 3 -0.639 6.216 17.219 1.00 14.26 N +ATOM 21 CA GLU A 3 0.611 5.471 17.280 1.00 12.42 C +ATOM 22 C GLU A 3 1.540 5.956 16.167
1.00 11.31 C +ATOM 23 O GLU A 3 1.106 6.184 15.033 1.00 10.50 O +ATOM 24 CB GLU A 3 0.370 3.961 17.162 1.00 11.29 C +ATOM 25 CG GLU A 3 1.646 3.111 17.301 1.00 9.09 C +ATOM 26 CD GLU A 3 1.337 1.616 17.395 1.00 13.38 C +ATOM 27 OE1 GLU A 3 0.158 1.240 17.226 1.00 13.60 O +ATOM 28 OE2 GLU A 3 2.257 0.819 17.673 1.00 14.65 O1- +ATOM 29 N TYR A 4 2.818 6.134 16.506 1.00 12.77 N +ATOM 30 CA TYR A 4 3.841 6.593 15.571 1.00 12.84 C +ATOM 31 C TYR A 4 4.901 5.505 15.443 1.00 10.58 C +ATOM 32 O TYR A 4 5.497 5.100 16.443 1.00 9.16 O +ATOM 33 CB TYR A 4 4.485 7.898 16.059 1.00 14.15 C +ATOM 34 CG TYR A 4 3.532 9.067 16.106 1.00 16.47 C +ATOM 35 CD1 TYR A 4 3.235 9.795 14.965 1.00 19.05 C +ATOM 36 CD2 TYR A 4 2.914 9.428 17.298 1.00 18.43 C +ATOM 37 CE1 TYR A 4 2.355 10.868 15.015 1.00 21.16 C +ATOM 38 CE2 TYR A 4 2.042 10.488 17.350 1.00 16.04 C +ATOM 39 CZ TYR A 4 1.772 11.202 16.209 1.00 20.16 C +ATOM 40 OH TYR A 4 0.903 12.266 16.271 1.00 24.07 O +ATOM 41 N LYS A 5 5.109 5.017 14.223 1.00 11.58 N +ATOM 42 CA LYS A 5 6.116 3.996 13.939 1.00 12.13 C +ATOM 43 C LYS A 5 7.434 4.681 13.610 1.00 10.31 C +ATOM 44 O LYS A 5 7.579 5.285 12.542 1.00 10.17 O +ATOM 45 CB LYS A 5 5.642 3.100 12.795 1.00 13.50 C +ATOM 46 CG LYS A 5 6.316 1.728 12.652 1.00 25.83 C +ATOM 47 CD LYS A 5 6.842 1.057 13.940 1.00 20.73 C +ATOM 48 CE LYS A 5 7.851 -0.030 13.526 1.00 26.18 C +ATOM 49 NZ LYS A 5 7.325 -0.916 12.395 1.00 22.24 N1+ +ATOM 50 N LEU A 6 8.388 4.591 14.536 1.00 8.84 N +ATOM 51 CA LEU A 6 9.699 5.219 14.419 1.00 6.69 C +ATOM 52 C LEU A 6 10.756 4.144 14.209 1.00 8.74 C +ATOM 53 O LEU A 6 10.762 3.131 14.916 1.00 10.99 O +ATOM 54 CB LEU A 6 10.040 6.023 15.679 1.00 7.27 C +ATOM 55 CG LEU A 6 8.977 6.987 16.225 1.00 11.53 C +ATOM 56 CD1 LEU A 6 9.520 7.821 17.382 1.00 10.27 C +ATOM 57 CD2 LEU A 6 8.447 7.888 15.121 1.00 10.42 C +ATOM 58 N VAL A 7 11.678 4.377 13.275 1.00 7.92 N +ATOM 59 CA VAL A 7 12.716 3.398 12.951 1.00 6.15 C +ATOM 60 C VAL A 7 14.074 4.031 13.199 1.00 7.80 C +ATOM 61 O VAL A 7 14.352 5.125 12.696 1.00 6.25 O +ATOM 62 CB VAL A 7 12.607 2.922 11.492 1.00 5.84 C +ATOM 63 CG1 VAL A 7 13.662 1.851 11.182 1.00 9.90 C +ATOM 64 CG2 VAL A 7 11.214 2.405 11.200 1.00 7.52 C +ATOM 65 N VAL A 8 14.923 3.341 13.955 1.00 4.86 N +ATOM 66 CA VAL A 8 16.262 3.819 14.292 1.00 6.13 C +ATOM 67 C VAL A 8 17.263 3.074 13.423 1.00 4.98 C +ATOM 68 O VAL A 8 17.382 1.843 13.531 1.00 5.58 O +ATOM 69 CB VAL A 8 16.580 3.603 15.779 1.00 4.64 C +ATOM 70 CG1 VAL A 8 17.960 4.234 16.119 1.00 4.33 C +ATOM 71 CG2 VAL A 8 15.458 4.179 16.669 1.00 5.83 C +ATOM 72 N VAL A 9 18.005 3.812 12.585 1.00 5.74 N +ATOM 73 CA VAL A 9 18.950 3.215 11.649 1.00 6.52 C +ATOM 74 C VAL A 9 20.315 3.872 11.802 1.00 6.40 C +ATOM 75 O VAL A 9 20.457 4.963 12.366 1.00 7.37 O +ATOM 76 CB VAL A 9 18.477 3.313 10.173 1.00 5.91 C +ATOM 77 CG1 VAL A 9 17.064 2.767 10.041 1.00 6.87 C +ATOM 78 CG2 VAL A 9 18.563 4.777 9.645 1.00 4.36 C +ATOM 79 N GLY A 10 21.328 3.183 11.271 1.00 6.37 N +ATOM 80 CA GLY A 10 22.689 3.677 11.294 1.00 7.05 C +ATOM 81 C GLY A 10 23.660 2.529 11.499 1.00 7.65 C +ATOM 82 O GLY A 10 23.234 1.402 11.786 1.00 8.84 O +ATOM 83 N ALA A 11 24.958 2.811 11.363 1.00 6.94 N +ATOM 84 CA ALA A 11 25.989 1.777 11.429 1.00 8.38 C +ATOM 85 C ALA A 11 25.973 1.039 12.765 1.00 8.81 C +ATOM 86 O ALA A 11 25.614 1.592 13.816 1.00 7.06 O +ATOM 87 CB ALA A 11 27.364 2.402 11.206 1.00 10.80 C +ATOM 88 N GLY A 12 26.363 -0.236 12.714 1.00 12.10 N +ATOM 89 CA GLY A 12 26.527 -0.995 13.937 1.00 9.18 C +ATOM 90 C GLY A 12 27.429 -0.255 14.901 1.00 
8.76 C +ATOM 91 O GLY A 12 28.448 0.309 14.481 1.00 9.88 O +ATOM 92 N GLY A 13 27.048 -0.210 16.182 1.00 8.52 N +ATOM 93 CA GLY A 13 27.874 0.369 17.230 1.00 7.41 C +ATOM 94 C GLY A 13 27.705 1.857 17.481 1.00 10.08 C +ATOM 95 O GLY A 13 28.399 2.404 18.343 1.00 8.92 O +ATOM 96 N VAL A 14 26.820 2.537 16.753 1.00 6.88 N +ATOM 97 CA VAL A 14 26.648 3.969 16.973 1.00 7.80 C +ATOM 98 C VAL A 14 25.796 4.260 18.207 1.00 9.95 C +ATOM 99 O VAL A 14 25.796 5.397 18.699 1.00 7.84 O +ATOM 100 CB VAL A 14 26.015 4.639 15.738 1.00 8.21 C +ATOM 101 CG1 VAL A 14 26.894 4.433 14.476 1.00 7.49 C +ATOM 102 CG2 VAL A 14 24.605 4.108 15.506 1.00 6.23 C +ATOM 103 N GLY A 15 25.068 3.277 18.718 1.00 5.43 N +ATOM 104 CA GLY A 15 24.223 3.474 19.871 1.00 5.97 C +ATOM 105 C GLY A 15 22.734 3.451 19.602 1.00 8.04 C +ATOM 106 O GLY A 15 21.985 4.030 20.388 1.00 6.49 O +ATOM 107 N LYS A 16 22.285 2.810 18.517 1.00 5.68 N +ATOM 108 CA LYS A 16 20.853 2.705 18.254 1.00 5.84 C +ATOM 109 C LYS A 16 20.136 2.030 19.411 1.00 6.71 C +ATOM 110 O LYS A 16 19.091 2.506 19.873 1.00 5.76 O +ATOM 111 CB LYS A 16 20.608 1.925 16.957 1.00 7.44 C +ATOM 112 CG LYS A 16 21.262 2.533 15.722 1.00 3.61 C +ATOM 113 CD LYS A 16 20.934 1.736 14.440 1.00 5.01 C +ATOM 114 CE LYS A 16 21.574 0.360 14.452 1.00 5.84 C +ATOM 115 NZ LYS A 16 23.046 0.367 14.638 1.00 8.17 N1+ +ATOM 116 N SER A 17 20.681 0.912 19.896 1.00 8.97 N +ATOM 117 CA SER A 17 20.007 0.202 20.979 1.00 5.88 C +ATOM 118 C SER A 17 20.059 1.003 22.266 1.00 5.58 C +ATOM 119 O SER A 17 19.047 1.132 22.967 1.00 6.43 O +ATOM 120 CB SER A 17 20.627 -1.183 21.157 1.00 5.85 C +ATOM 121 OG SER A 17 20.361 -1.962 19.997 1.00 4.69 O +ATOM 122 N ALA A 18 21.213 1.592 22.575 1.00 4.67 N +ATOM 123 CA ALA A 18 21.318 2.349 23.815 1.00 5.12 C +ATOM 124 C ALA A 18 20.350 3.529 23.822 1.00 7.45 C +ATOM 125 O ALA A 18 19.716 3.809 24.844 1.00 6.08 O +ATOM 126 CB ALA A 18 22.758 2.808 24.023 1.00 4.60 C +ATOM 127 N LEU A 19 20.203 4.225 22.690 1.00 4.80 N +ATOM 128 CA LEU A 19 19.212 5.310 22.639 1.00 5.71 C +ATOM 129 C LEU A 19 17.795 4.778 22.803 1.00 6.94 C +ATOM 130 O LEU A 19 16.987 5.352 23.549 1.00 5.56 O +ATOM 131 CB LEU A 19 19.324 6.084 21.330 1.00 4.70 C +ATOM 132 CG LEU A 19 20.540 7.001 21.176 1.00 4.62 C +ATOM 133 CD1 LEU A 19 20.582 7.462 19.732 1.00 5.41 C +ATOM 134 CD2 LEU A 19 20.508 8.188 22.158 1.00 5.60 C +ATOM 135 N THR A 20 17.475 3.677 22.118 1.00 4.75 N +ATOM 136 CA THR A 20 16.121 3.122 22.200 1.00 8.86 C +ATOM 137 C THR A 20 15.793 2.692 23.627 1.00 8.97 C +ATOM 138 O THR A 20 14.714 3.009 24.145 1.00 6.87 O +ATOM 139 CB THR A 20 15.958 1.945 21.223 1.00 5.94 C +ATOM 140 CG2 THR A 20 14.580 1.325 21.346 1.00 7.54 C +ATOM 141 OG1 THR A 20 16.136 2.384 19.861 1.00 6.69 O +ATOM 142 N ILE A 21 16.729 2.004 24.290 1.00 7.80 N +ATOM 143 CA ILE A 21 16.520 1.557 25.670 1.00 8.44 C +ATOM 144 C ILE A 21 16.484 2.736 26.656 1.00 8.18 C +ATOM 145 O ILE A 21 15.774 2.685 27.669 1.00 6.58 O +ATOM 146 CB ILE A 21 17.597 0.517 26.052 1.00 7.97 C +ATOM 147 CG1 ILE A 21 17.456 -0.770 25.231 1.00 9.50 C +ATOM 148 CG2 ILE A 21 17.501 0.142 27.512 1.00 9.04 C +ATOM 149 CD1 ILE A 21 16.030 -1.166 24.915 1.00 18.81 C +ATOM 150 N GLN A 22 17.238 3.806 26.405 1.00 7.38 N +ATOM 151 CA GLN A 22 17.088 4.987 27.256 1.00 9.24 C +ATOM 152 C GLN A 22 15.661 5.507 27.194 1.00 6.92 C +ATOM 153 O GLN A 22 15.054 5.821 28.223 1.00 8.70 O +ATOM 154 CB GLN A 22 18.062 6.086 26.848 1.00 7.59 C +ATOM 155 CG GLN A 22 19.328 6.113 27.653 1.00 11.53 C +ATOM 156 CD GLN A 22 19.183 6.600 29.126 1.00 12.00 C 
+ATOM 157 NE2 GLN A 22 20.321 6.842 29.729 1.00 11.88 N +ATOM 158 OE1 GLN A 22 18.080 6.768 29.695 1.00 11.93 O +ATOM 159 N LEU A 23 15.100 5.600 25.991 1.00 6.49 N +ATOM 160 CA LEU A 23 13.710 6.030 25.880 1.00 5.84 C +ATOM 161 C LEU A 23 12.761 5.031 26.551 1.00 9.77 C +ATOM 162 O LEU A 23 11.830 5.429 27.265 1.00 7.78 O +ATOM 163 CB LEU A 23 13.339 6.232 24.406 1.00 7.58 C +ATOM 164 CG LEU A 23 11.955 6.884 24.242 1.00 8.31 C +ATOM 165 CD1 LEU A 23 12.020 8.384 24.388 1.00 12.37 C +ATOM 166 CD2 LEU A 23 11.281 6.573 22.984 1.00 14.17 C +ATOM 167 N ILE A 24 12.985 3.731 26.354 1.00 7.30 N +ATOM 168 CA ILE A 24 12.030 2.737 26.840 1.00 6.01 C +ATOM 169 C ILE A 24 12.170 2.534 28.341 1.00 12.00 C +ATOM 170 O ILE A 24 11.176 2.563 29.078 1.00 12.67 O +ATOM 171 CB ILE A 24 12.197 1.412 26.061 1.00 10.63 C +ATOM 172 CG1 ILE A 24 11.837 1.617 24.580 1.00 8.50 C +ATOM 173 CG2 ILE A 24 11.360 0.270 26.703 1.00 9.45 C +ATOM 174 CD1 ILE A 24 10.423 2.015 24.336 1.00 9.23 C +ATOM 175 N GLN A 25 13.403 2.335 28.819 1.00 6.05 N +ATOM 176 CA GLN A 25 13.637 1.986 30.216 1.00 8.13 C +ATOM 177 C GLN A 25 13.986 3.175 31.098 1.00 9.66 C +ATOM 178 O GLN A 25 13.939 3.046 32.328 1.00 10.68 O +ATOM 179 CB GLN A 25 14.790 0.978 30.335 1.00 11.88 C +ATOM 180 CG GLN A 25 14.526 -0.430 29.800 1.00 15.30 C +ATOM 181 CD GLN A 25 15.638 -1.382 30.185 1.00 15.85 C +ATOM 182 NE2 GLN A 25 15.697 -2.532 29.521 1.00 14.63 N +ATOM 183 OE1 GLN A 25 16.449 -1.077 31.067 1.00 13.34 O +ATOM 184 N ASN A 26 14.344 4.314 30.513 1.00 7.71 N +ATOM 185 CA ASN A 26 14.940 5.423 31.264 1.00 9.91 C +ATOM 186 C ASN A 26 16.218 4.970 31.970 1.00 11.50 C +ATOM 187 O ASN A 26 16.468 5.327 33.119 1.00 12.67 O +ATOM 188 CB ASN A 26 13.950 6.031 32.262 1.00 8.62 C +ATOM 189 CG ASN A 26 14.241 7.499 32.557 1.00 17.51 C +ATOM 190 ND2 ASN A 26 14.192 7.865 33.835 1.00 18.42 N +ATOM 191 OD1 ASN A 26 14.492 8.297 31.640 1.00 12.48 O +ATOM 192 N HIS A 27 17.028 4.160 31.278 1.00 9.27 N +ATOM 193 CA HIS A 27 18.233 3.578 31.859 1.00 10.75 C +ATOM 194 C HIS A 27 19.264 3.356 30.753 1.00 8.66 C +ATOM 195 O HIS A 27 18.915 2.921 29.645 1.00 8.70 O +ATOM 196 CB HIS A 27 17.886 2.268 32.597 1.00 11.90 C +ATOM 197 CG HIS A 27 19.051 1.574 33.242 1.00 10.84 C +ATOM 198 CD2 HIS A 27 19.682 0.415 32.928 1.00 17.11 C +ATOM 199 ND1 HIS A 27 19.625 2.014 34.417 1.00 18.93 N +ATOM 200 CE1 HIS A 27 20.591 1.183 34.773 1.00 16.75 C +ATOM 201 NE2 HIS A 27 20.644 0.202 33.889 1.00 17.00 N +ATOM 202 N PHE A 28 20.530 3.669 31.053 1.00 10.26 N +ATOM 203 CA PHE A 28 21.613 3.469 30.093 1.00 9.35 C +ATOM 204 C PHE A 28 22.065 2.015 30.127 1.00 9.15 C +ATOM 205 O PHE A 28 22.446 1.495 31.184 1.00 10.93 O +ATOM 206 CB PHE A 28 22.801 4.387 30.378 1.00 11.03 C +ATOM 207 CG PHE A 28 23.996 4.126 29.484 1.00 8.50 C +ATOM 208 CD1 PHE A 28 23.833 4.044 28.111 1.00 9.07 C +ATOM 209 CD2 PHE A 28 25.271 3.970 30.013 1.00 9.50 C +ATOM 210 CE1 PHE A 28 24.922 3.816 27.274 1.00 7.04 C +ATOM 211 CE2 PHE A 28 26.364 3.735 29.182 1.00 10.43 C +ATOM 212 CZ PHE A 28 26.188 3.661 27.814 1.00 9.62 C +ATOM 213 N VAL A 29 22.039 1.368 28.973 1.00 9.85 N +ATOM 214 CA VAL A 29 22.510 -0.001 28.829 1.00 10.04 C +ATOM 215 C VAL A 29 23.726 0.053 27.912 1.00 8.89 C +ATOM 216 O VAL A 29 23.600 0.346 26.713 1.00 10.94 O +ATOM 217 CB VAL A 29 21.399 -0.915 28.292 1.00 10.78 C +ATOM 218 CG1 VAL A 29 21.936 -2.239 27.851 1.00 17.26 C +ATOM 219 CG2 VAL A 29 20.338 -1.114 29.372 1.00 13.17 C +ATOM 220 N ASP A 30 24.916 -0.165 28.486 1.00 12.22 N +ATOM 221 CA ASP A 30 26.159 -0.104 27.721 1.00 11.11 C 
+ATOM 222 C ASP A 30 26.428 -1.375 26.934 1.00 11.02 C +ATOM 223 O ASP A 30 27.363 -1.403 26.120 1.00 11.94 O +ATOM 224 CB ASP A 30 27.345 0.219 28.649 1.00 10.50 C +ATOM 225 CG ASP A 30 27.693 -0.910 29.653 1.00 13.55 C +ATOM 226 OD1 ASP A 30 27.144 -2.032 29.616 1.00 12.75 O +ATOM 227 OD2 ASP A 30 28.532 -0.627 30.530 1.00 18.80 O1- +ATOM 228 N GLU A 31 25.613 -2.405 27.137 1.00 9.03 N +ATOM 229 CA GLU A 31 25.803 -3.700 26.505 1.00 10.13 C +ATOM 230 C GLU A 31 24.429 -4.218 26.105 1.00 11.20 C +ATOM 231 O GLU A 31 23.605 -4.541 26.969 1.00 8.66 O +ATOM 232 CB GLU A 31 26.519 -4.671 27.453 1.00 8.97 C +ATOM 233 CG GLU A 31 26.710 -6.055 26.848 1.00 8.96 C +ATOM 234 CD GLU A 31 27.485 -7.000 27.725 1.00 10.49 C +ATOM 235 OE1 GLU A 31 27.680 -6.712 28.916 1.00 8.31 O +ATOM 236 OE2 GLU A 31 27.919 -8.042 27.206 1.00 12.71 O1- +ATOM 237 N TYR A 32 24.168 -4.278 24.803 1.00 7.30 N +ATOM 238 CA TYR A 32 22.937 -4.866 24.285 1.00 9.40 C +ATOM 239 C TYR A 32 23.314 -5.643 23.036 1.00 7.30 C +ATOM 240 O TYR A 32 23.912 -5.072 22.120 1.00 8.91 O +ATOM 241 CB TYR A 32 21.897 -3.782 23.977 1.00 8.50 C +ATOM 242 CG TYR A 32 20.453 -4.258 23.856 1.00 10.02 C +ATOM 243 CD1 TYR A 32 19.704 -4.555 24.991 1.00 11.70 C +ATOM 244 CD2 TYR A 32 19.839 -4.386 22.611 1.00 9.10 C +ATOM 245 CE1 TYR A 32 18.398 -4.967 24.906 1.00 13.65 C +ATOM 246 CE2 TYR A 32 18.504 -4.808 22.511 1.00 10.42 C +ATOM 247 CZ TYR A 32 17.797 -5.096 23.674 1.00 15.11 C +ATOM 248 OH TYR A 32 16.480 -5.514 23.622 1.00 19.66 O +ATOM 249 N ASP A 33 23.021 -6.954 23.029 1.00 7.35 N +ATOM 250 CA ASP A 33 23.380 -7.869 21.942 1.00 9.05 C +ATOM 251 C ASP A 33 23.287 -7.136 20.607 1.00 7.12 C +ATOM 252 O ASP A 33 22.193 -6.741 20.193 1.00 7.00 O +ATOM 253 CB ASP A 33 22.451 -9.088 21.997 1.00 7.93 C +ATOM 254 CG ASP A 33 22.763 -10.156 20.944 1.00 12.57 C +ATOM 255 OD1 ASP A 33 23.346 -9.857 19.872 1.00 10.81 O +ATOM 256 OD2 ASP A 33 22.369 -11.322 21.198 1.00 11.67 O1- +ATOM 257 N APRO A 34 24.410 -6.886 19.931 0.64 9.70 N +ATOM 258 N BPRO A 34 24.407 -6.946 19.902 0.36 9.69 N +ATOM 259 CA APRO A 34 24.360 -6.086 18.691 0.64 10.14 C +ATOM 260 CA BPRO A 34 24.380 -6.122 18.682 0.36 10.14 C +ATOM 261 C APRO A 34 23.548 -6.716 17.562 0.64 9.10 C +ATOM 262 C BPRO A 34 23.551 -6.720 17.563 0.36 9.13 C +ATOM 263 O APRO A 34 23.239 -6.015 16.589 0.64 9.36 O +ATOM 264 O BPRO A 34 23.247 -6.009 16.596 0.36 9.38 O +ATOM 265 CB APRO A 34 25.839 -5.950 18.306 0.64 10.13 C +ATOM 266 CB BPRO A 34 25.856 -6.051 18.282 0.36 10.14 C +ATOM 267 CG APRO A 34 26.561 -6.070 19.613 0.64 11.32 C +ATOM 268 CG BPRO A 34 26.412 -7.333 18.773 0.36 12.04 C +ATOM 269 CD APRO A 34 25.795 -7.112 20.380 0.64 10.46 C +ATOM 270 CD BPRO A 34 25.713 -7.600 20.090 0.36 11.11 C +ATOM 271 N THR A 35 23.185 -7.995 17.656 1.00 7.99 N +ATOM 272 CA THR A 35 22.408 -8.647 16.609 1.00 12.76 C +ATOM 273 C THR A 35 20.910 -8.533 16.809 1.00 8.31 C +ATOM 274 O THR A 35 20.158 -8.893 15.895 1.00 9.45 O +ATOM 275 CB THR A 35 22.768 -10.143 16.484 1.00 11.97 C +ATOM 276 CG2 THR A 35 24.276 -10.320 16.337 1.00 12.47 C +ATOM 277 OG1 THR A 35 22.279 -10.881 17.622 1.00 13.05 O +ATOM 278 N ILE A 36 20.445 -8.098 17.980 1.00 9.52 N +ATOM 279 CA ILE A 36 19.013 -8.103 18.260 1.00 9.30 C +ATOM 280 C ILE A 36 18.357 -6.919 17.562 1.00 9.73 C +ATOM 281 O ILE A 36 18.622 -5.757 17.892 1.00 8.67 O +ATOM 282 CB ILE A 36 18.719 -8.063 19.770 1.00 10.51 C +ATOM 283 CG1 ILE A 36 19.203 -9.325 20.477 1.00 9.60 C +ATOM 284 CG2 ILE A 36 17.226 -7.849 20.000 1.00 11.09 C +ATOM 285 CD1 ILE A 36 18.977 -9.285 
21.985 1.00 11.22 C +ATOM 286 N GLU A 37 17.457 -7.214 16.637 1.00 10.42 N +ATOM 287 CA GLU A 37 16.590 -6.210 16.042 1.00 12.13 C +ATOM 288 C GLU A 37 15.197 -6.458 16.596 1.00 11.36 C +ATOM 289 O GLU A 37 14.687 -7.580 16.514 1.00 9.48 O +ATOM 290 CB GLU A 37 16.597 -6.310 14.519 1.00 9.32 C +ATOM 291 CG GLU A 37 15.475 -5.549 13.861 1.00 9.19 C +ATOM 292 CD GLU A 37 15.667 -5.467 12.369 1.00 15.01 C +ATOM 293 OE1 GLU A 37 16.599 -4.774 11.913 1.00 11.69 O +ATOM 294 OE2 GLU A 37 14.887 -6.122 11.651 1.00 17.25 O1- +ATOM 295 N ASP A 38 14.593 -5.441 17.203 1.00 8.99 N +ATOM 296 CA ASP A 38 13.272 -5.659 17.779 1.00 10.35 C +ATOM 297 C ASP A 38 12.560 -4.321 17.880 1.00 11.79 C +ATOM 298 O ASP A 38 13.144 -3.262 17.630 1.00 8.04 O +ATOM 299 CB ASP A 38 13.370 -6.350 19.142 1.00 14.46 C +ATOM 300 CG ASP A 38 12.164 -7.224 19.441 1.00 17.59 C +ATOM 301 OD1 ASP A 38 11.193 -7.185 18.659 1.00 17.55 O +ATOM 302 OD2 ASP A 38 12.199 -7.949 20.455 1.00 24.28 O1- +ATOM 303 N SER A 39 11.288 -4.381 18.241 1.00 8.64 N +ATOM 304 CA SER A 39 10.480 -3.178 18.362 1.00 11.67 C +ATOM 305 C SER A 39 9.935 -3.064 19.779 1.00 12.26 C +ATOM 306 O SER A 39 9.668 -4.069 20.450 1.00 13.35 O +ATOM 307 CB SER A 39 9.343 -3.162 17.347 1.00 12.34 C +ATOM 308 OG SER A 39 8.497 -4.279 17.504 1.00 21.55 O +ATOM 309 N TYR A 40 9.793 -1.820 20.237 1.00 8.09 N +ATOM 310 CA TYR A 40 9.482 -1.522 21.631 1.00 8.93 C +ATOM 311 C TYR A 40 8.475 -0.390 21.646 1.00 8.52 C +ATOM 312 O TYR A 40 8.669 0.596 20.937 1.00 10.19 O +ATOM 313 CB TYR A 40 10.748 -1.098 22.401 1.00 8.62 C +ATOM 314 CG TYR A 40 11.903 -2.046 22.197 1.00 11.34 C +ATOM 315 CD1 TYR A 40 12.760 -1.901 21.110 1.00 9.24 C +ATOM 316 CD2 TYR A 40 12.141 -3.089 23.085 1.00 12.57 C +ATOM 317 CE1 TYR A 40 13.814 -2.784 20.906 1.00 11.01 C +ATOM 318 CE2 TYR A 40 13.204 -3.962 22.894 1.00 12.70 C +ATOM 319 CZ TYR A 40 14.033 -3.802 21.800 1.00 12.92 C +ATOM 320 OH TYR A 40 15.097 -4.654 21.587 1.00 14.41 O +ATOM 321 N ARG A 41 7.419 -0.517 22.452 1.00 10.06 N +ATOM 322 CA ARG A 41 6.375 0.500 22.530 1.00 10.68 C +ATOM 323 C ARG A 41 6.432 1.266 23.852 1.00 10.24 C +ATOM 324 O ARG A 41 6.804 0.719 24.897 1.00 13.02 O +ATOM 325 CB ARG A 41 4.975 -0.116 22.357 1.00 12.42 C +ATOM 326 CG ARG A 41 4.706 -0.645 20.943 1.00 12.41 C +ATOM 327 CD ARG A 41 3.210 -0.764 20.628 1.00 15.47 C +ATOM 328 NE ARG A 41 3.002 -1.326 19.296 1.00 9.33 N +ATOM 329 CZ ARG A 41 3.108 -2.617 19.013 1.00 11.72 C +ATOM 330 NH1 ARG A 41 3.321 -3.516 19.962 1.00 13.48 N1+ +ATOM 331 NH2 ARG A 41 2.992 -3.014 17.751 1.00 13.00 N +ATOM 332 N LYS A 42 6.038 2.538 23.785 1.00 11.94 N +ATOM 333 CA LYS A 42 6.004 3.428 24.939 1.00 13.99 C +ATOM 334 C LYS A 42 4.946 4.498 24.755 1.00 9.52 C +ATOM 335 O LYS A 42 4.911 5.170 23.716 1.00 11.68 O +ATOM 336 CB LYS A 42 7.345 4.131 25.170 1.00 17.79 C +ATOM 337 CG LYS A 42 7.294 5.035 26.399 1.00 17.60 C +ATOM 338 CD LYS A 42 8.645 5.181 27.066 1.00 23.59 C +ATOM 339 CE LYS A 42 8.499 5.000 28.573 1.00 25.01 C +ATOM 340 NZ LYS A 42 8.500 3.523 28.888 1.00 25.62 N1+ +ATOM 341 N GLN A 43 4.114 4.682 25.778 1.00 14.11 N +ATOM 342 CA GLN A 43 3.179 5.797 25.800 1.00 13.96 C +ATOM 343 C GLN A 43 3.901 7.019 26.340 1.00 12.30 C +ATOM 344 O GLN A 43 4.534 6.951 27.397 1.00 19.48 O +ATOM 345 CB GLN A 43 1.968 5.477 26.679 1.00 15.07 C +ATOM 346 CG GLN A 43 0.951 4.569 26.029 1.00 23.81 C +ATOM 347 CD GLN A 43 -0.265 4.357 26.909 1.00 24.91 C +ATOM 348 NE2 GLN A 43 -0.740 3.113 26.983 1.00 26.39 N +ATOM 349 OE1 GLN A 43 -0.766 5.301 27.527 1.00 22.84 O +ATOM 350 N VAL A 44 
3.829 8.127 25.609 1.00 11.85 N +ATOM 351 CA VAL A 44 4.522 9.348 25.996 1.00 20.14 C +ATOM 352 C VAL A 44 3.623 10.544 25.740 1.00 17.59 C +ATOM 353 O VAL A 44 2.770 10.538 24.844 1.00 20.58 O +ATOM 354 CB VAL A 44 5.859 9.545 25.240 1.00 21.62 C +ATOM 355 CG1 VAL A 44 6.816 8.419 25.531 1.00 18.06 C +ATOM 356 CG2 VAL A 44 5.623 9.718 23.743 1.00 17.68 C +ATOM 357 N VAL A 45 3.846 11.591 26.520 1.00 21.60 N +ATOM 358 CA VAL A 45 3.175 12.866 26.326 1.00 24.33 C +ATOM 359 C VAL A 45 4.126 13.763 25.545 1.00 19.76 C +ATOM 360 O VAL A 45 5.210 14.111 26.026 1.00 25.21 O +ATOM 361 CB VAL A 45 2.756 13.487 27.666 1.00 24.84 C +ATOM 362 CG1 VAL A 45 1.650 14.471 27.461 1.00 32.56 C +ATOM 363 CG2 VAL A 45 2.291 12.386 28.609 1.00 22.02 C +ATOM 364 N ILE A 46 3.732 14.095 24.324 1.00 20.79 N +ATOM 365 CA ILE A 46 4.463 15.013 23.464 1.00 25.02 C +ATOM 366 C ILE A 46 3.533 16.181 23.204 1.00 29.11 C +ATOM 367 O ILE A 46 2.489 16.023 22.553 1.00 29.90 O +ATOM 368 CB ILE A 46 4.914 14.365 22.149 1.00 19.26 C +ATOM 369 CG1 ILE A 46 5.935 13.259 22.425 1.00 20.94 C +ATOM 370 CG2 ILE A 46 5.502 15.409 21.221 1.00 24.47 C +ATOM 371 CD1 ILE A 46 6.123 12.292 21.268 1.00 14.32 C +ATOM 372 N ASP A 47 3.894 17.348 23.733 1.00 30.89 N +ATOM 373 CA ASP A 47 3.094 18.554 23.563 1.00 32.65 C +ATOM 374 C ASP A 47 1.695 18.343 24.146 1.00 26.04 C +ATOM 375 O ASP A 47 0.683 18.692 23.538 1.00 28.75 O +ATOM 376 CB ASP A 47 3.043 18.954 22.087 1.00 29.41 C +ATOM 377 CG ASP A 47 4.406 19.397 21.560 1.00 32.42 C +ATOM 378 OD1 ASP A 47 5.274 19.757 22.384 1.00 40.24 O +ATOM 379 OD2 ASP A 47 4.629 19.361 20.332 1.00 31.12 O1- +ATOM 380 N GLY A 48 1.645 17.738 25.337 1.00 35.98 N +ATOM 381 CA GLY A 48 0.399 17.451 26.021 1.00 39.84 C +ATOM 382 C GLY A 48 -0.394 16.276 25.479 1.00 41.71 C +ATOM 383 O GLY A 48 -1.325 15.813 26.153 1.00 41.88 O +ATOM 384 N GLU A 49 -0.060 15.776 24.292 1.00 37.36 N +ATOM 385 CA GLU A 49 -0.784 14.672 23.672 1.00 38.47 C +ATOM 386 C GLU A 49 -0.187 13.348 24.123 1.00 32.48 C +ATOM 387 O GLU A 49 1.020 13.123 23.983 1.00 30.44 O +ATOM 388 CB GLU A 49 -0.717 14.770 22.147 1.00 33.82 C +ATOM 389 CG GLU A 49 -1.198 16.094 21.568 1.00 42.98 C +ATOM 390 CD GLU A 49 -2.694 16.114 21.309 1.00 53.90 C +ATOM 391 OE1 GLU A 49 -3.155 15.351 20.427 1.00 56.07 O +ATOM 392 OE2 GLU A 49 -3.407 16.893 21.981 1.00 52.20 O1- +ATOM 393 N THR A 50 -1.025 12.470 24.656 1.00 24.36 N +ATOM 394 CA THR A 50 -0.586 11.109 24.907 1.00 23.63 C +ATOM 395 C THR A 50 -0.606 10.364 23.584 1.00 27.57 C +ATOM 396 O THR A 50 -1.653 10.268 22.934 1.00 24.43 O +ATOM 397 CB THR A 50 -1.473 10.418 25.936 1.00 30.22 C +ATOM 398 CG2 THR A 50 -0.887 9.065 26.300 1.00 24.99 C +ATOM 399 OG1 THR A 50 -1.543 11.227 27.115 1.00 39.12 O +ATOM 400 N CYS A 51 0.554 9.881 23.159 1.00 22.18 N +ATOM 401 CA CYS A 51 0.635 9.072 21.958 1.00 21.60 C +ATOM 402 C CYS A 51 1.482 7.844 22.242 1.00 15.85 C +ATOM 403 O CYS A 51 2.176 7.753 23.258 1.00 16.33 O +ATOM 404 CB CYS A 51 1.188 9.861 20.768 1.00 22.20 C +ATOM 405 SG CYS A 51 2.706 10.764 21.118 1.00 26.74 S +ATOM 406 N LEU A 52 1.375 6.879 21.339 1.00 12.60 N +ATOM 407 CA LEU A 52 1.990 5.573 21.473 1.00 9.53 C +ATOM 408 C LEU A 52 3.148 5.521 20.495 1.00 8.39 C +ATOM 409 O LEU A 52 2.940 5.626 19.286 1.00 11.15 O +ATOM 410 CB LEU A 52 0.970 4.466 21.180 1.00 11.35 C +ATOM 411 CG LEU A 52 1.442 3.020 21.333 1.00 13.29 C +ATOM 412 CD1 LEU A 52 1.950 2.772 22.743 1.00 10.86 C +ATOM 413 CD2 LEU A 52 0.314 2.034 20.957 1.00 14.40 C +ATOM 414 N LEU A 53 4.366 5.416 21.012 1.00 8.96 N +ATOM 415 CA 
LEU A 53 5.521 5.252 20.144 1.00 8.65 C +ATOM 416 C LEU A 53 5.797 3.768 19.946 1.00 10.07 C +ATOM 417 O LEU A 53 5.740 2.983 20.895 1.00 11.09 O +ATOM 418 CB LEU A 53 6.757 5.944 20.725 1.00 12.84 C +ATOM 419 CG LEU A 53 6.624 7.415 21.135 1.00 11.26 C +ATOM 420 CD1 LEU A 53 8.029 7.997 21.429 1.00 10.00 C +ATOM 421 CD2 LEU A 53 5.863 8.243 20.086 1.00 10.91 C +ATOM 422 N ASP A 54 6.113 3.395 18.711 1.00 8.58 N +ATOM 423 CA ASP A 54 6.496 2.033 18.364 1.00 7.14 C +ATOM 424 C ASP A 54 7.866 2.150 17.709 1.00 9.81 C +ATOM 425 O ASP A 54 7.963 2.570 16.556 1.00 9.93 O +ATOM 426 CB ASP A 54 5.469 1.397 17.425 1.00 11.17 C +ATOM 427 CG ASP A 54 5.769 -0.057 17.119 1.00 15.05 C +ATOM 428 OD1 ASP A 54 6.667 -0.640 17.768 1.00 14.33 O +ATOM 429 OD2 ASP A 54 5.099 -0.623 16.224 1.00 15.73 O1- +ATOM 430 N ILE A 55 8.927 1.832 18.440 1.00 8.28 N +ATOM 431 CA ILE A 55 10.292 2.089 17.982 1.00 7.94 C +ATOM 432 C ILE A 55 10.916 0.796 17.492 1.00 9.09 C +ATOM 433 O ILE A 55 11.072 -0.158 18.265 1.00 7.36 O +ATOM 434 CB ILE A 55 11.142 2.691 19.103 1.00 8.69 C +ATOM 435 CG1 ILE A 55 10.513 3.989 19.567 1.00 13.07 C +ATOM 436 CG2 ILE A 55 12.575 2.928 18.642 1.00 6.87 C +ATOM 437 CD1 ILE A 55 11.000 4.284 20.857 1.00 18.24 C +ATOM 438 N LEU A 56 11.305 0.766 16.215 1.00 7.44 N +ATOM 439 CA LEU A 56 12.026 -0.374 15.660 1.00 8.91 C +ATOM 440 C LEU A 56 13.515 -0.062 15.745 1.00 9.27 C +ATOM 441 O LEU A 56 14.010 0.880 15.107 1.00 6.69 O +ATOM 442 CB LEU A 56 11.595 -0.671 14.225 1.00 7.90 C +ATOM 443 CG LEU A 56 12.488 -1.670 13.458 1.00 7.51 C +ATOM 444 CD1 LEU A 56 12.498 -3.045 14.135 1.00 10.21 C +ATOM 445 CD2 LEU A 56 12.039 -1.804 11.996 1.00 10.98 C +ATOM 446 N ASP A 57 14.210 -0.829 16.565 1.00 6.07 N +ATOM 447 CA ASP A 57 15.653 -0.721 16.738 1.00 5.68 C +ATOM 448 C ASP A 57 16.282 -1.700 15.748 1.00 8.83 C +ATOM 449 O ASP A 57 16.261 -2.919 15.969 1.00 8.43 O +ATOM 450 CB ASP A 57 16.004 -1.044 18.188 1.00 7.10 C +ATOM 451 CG ASP A 57 17.500 -1.116 18.449 1.00 5.99 C +ATOM 452 OD1 ASP A 57 18.274 -0.466 17.719 1.00 8.86 O +ATOM 453 OD2 ASP A 57 17.904 -1.822 19.416 1.00 5.70 O1- +ATOM 454 N THR A 58 16.820 -1.183 14.639 1.00 7.22 N +ATOM 455 CA THR A 58 17.320 -2.058 13.585 1.00 8.38 C +ATOM 456 C THR A 58 18.715 -2.578 13.918 1.00 9.27 C +ATOM 457 O THR A 58 19.441 -2.017 14.737 1.00 7.35 O +ATOM 458 CB THR A 58 17.328 -1.353 12.227 1.00 6.62 C +ATOM 459 CG2 THR A 58 15.924 -0.804 11.882 1.00 8.02 C +ATOM 460 OG1 THR A 58 18.316 -0.305 12.213 1.00 8.30 O +ATOM 461 N ALA A 59 19.072 -3.685 13.272 1.00 8.84 N +ATOM 462 CA ALA A 59 20.337 -4.364 13.519 1.00 10.10 C +ATOM 463 C ALA A 59 20.616 -5.301 12.361 1.00 14.74 C +ATOM 464 O ALA A 59 19.721 -5.628 11.575 1.00 14.77 O +ATOM 465 CB ALA A 59 20.309 -5.155 14.822 1.00 7.11 C +ATOM 466 N GLY A 60 21.864 -5.736 12.268 1.00 10.95 N +ATOM 467 CA GLY A 60 22.208 -6.787 11.326 1.00 17.05 C +ATOM 468 C GLY A 60 22.771 -6.269 10.019 1.00 17.06 C +ATOM 469 O GLY A 60 23.035 -5.080 9.826 1.00 13.67 O +ATOM 470 N GLN A 61 22.953 -7.204 9.095 1.00 16.30 N +ATOM 471 CA GLN A 61 23.602 -6.880 7.833 1.00 17.71 C +ATOM 472 C GLN A 61 22.605 -6.315 6.822 1.00 17.69 C +ATOM 473 O GLN A 61 21.381 -6.437 6.971 1.00 17.37 O +ATOM 474 CB GLN A 61 24.303 -8.114 7.259 1.00 23.34 C +ATOM 475 CG GLN A 61 23.382 -9.298 6.994 1.00 27.05 C +ATOM 476 CD GLN A 61 24.148 -10.557 6.568 1.00 48.43 C +ATOM 477 NE2 GLN A 61 23.695 -11.192 5.488 1.00 42.78 N +ATOM 478 OE1 GLN A 61 25.121 -10.961 7.217 1.00 47.48 O +ATOM 479 N GLU A 62 23.159 -5.686 5.781 1.00 11.07 N +ATOM 480 CA 
GLU A 62 22.366 -5.242 4.641 1.00 13.41 C +ATOM 481 C GLU A 62 21.525 -6.388 4.098 1.00 15.73 C +ATOM 482 O GLU A 62 22.012 -7.509 3.944 1.00 15.39 O +ATOM 483 CB GLU A 62 23.278 -4.749 3.519 1.00 18.00 C +ATOM 484 CG GLU A 62 23.774 -3.357 3.667 1.00 21.50 C +ATOM 485 CD GLU A 62 24.854 -3.041 2.656 1.00 21.93 C +ATOM 486 OE1 GLU A 62 25.930 -3.679 2.711 1.00 32.49 O +ATOM 487 OE2 GLU A 62 24.627 -2.145 1.818 1.00 20.78 O1- +ATOM 488 N GLU A 63 20.257 -6.104 3.812 1.00 15.82 N +ATOM 489 CA GLU A 63 19.428 -7.013 3.021 1.00 15.70 C +ATOM 490 C GLU A 63 18.455 -6.176 2.207 1.00 18.43 C +ATOM 491 O GLU A 63 17.627 -5.456 2.780 1.00 16.83 O +ATOM 492 CB GLU A 63 18.653 -8.007 3.888 1.00 17.01 C +ATOM 493 CG GLU A 63 17.825 -8.959 3.030 1.00 21.96 C +ATOM 494 CD GLU A 63 17.064 -9.980 3.843 1.00 32.94 C +ATOM 495 OE1 GLU A 63 17.706 -10.698 4.637 1.00 40.57 O +ATOM 496 OE2 GLU A 63 15.828 -10.091 3.662 1.00 32.75 O1- +ATOM 497 N TYR A 64 18.540 -6.271 0.879 1.00 12.96 N +ATOM 498 CA TYR A 64 17.643 -5.496 0.025 1.00 9.98 C +ATOM 499 C TYR A 64 16.624 -6.409 -0.656 1.00 11.78 C +ATOM 500 O TYR A 64 16.473 -6.400 -1.877 1.00 16.53 O +ATOM 501 CB TYR A 64 18.461 -4.684 -0.986 1.00 11.98 C +ATOM 502 CG TYR A 64 19.472 -3.788 -0.287 1.00 11.84 C +ATOM 503 CD1 TYR A 64 19.057 -2.696 0.474 1.00 13.28 C +ATOM 504 CD2 TYR A 64 20.838 -4.076 -0.333 1.00 13.91 C +ATOM 505 CE1 TYR A 64 19.981 -1.893 1.147 1.00 8.25 C +ATOM 506 CE2 TYR A 64 21.763 -3.287 0.336 1.00 12.78 C +ATOM 507 CZ TYR A 64 21.330 -2.200 1.070 1.00 11.19 C +ATOM 508 OH TYR A 64 22.259 -1.426 1.719 1.00 14.39 O +ATOM 509 N SER A 65 15.929 -7.209 0.144 1.00 17.74 N +ATOM 510 CA SER A 65 14.835 -8.055 -0.306 1.00 18.55 C +ATOM 511 C SER A 65 13.521 -7.280 -0.309 1.00 17.66 C +ATOM 512 O SER A 65 13.381 -6.238 0.340 1.00 14.67 O +ATOM 513 CB SER A 65 14.706 -9.268 0.608 1.00 21.23 C +ATOM 514 OG SER A 65 14.363 -8.850 1.922 1.00 18.62 O +ATOM 515 N ALA A 66 12.544 -7.810 -1.056 1.00 13.92 N +ATOM 516 CA ALA A 66 11.199 -7.239 -1.021 1.00 14.38 C +ATOM 517 C ALA A 66 10.642 -7.253 0.395 1.00 14.58 C +ATOM 518 O ALA A 66 9.985 -6.297 0.825 1.00 19.99 O +ATOM 519 CB ALA A 66 10.271 -8.015 -1.955 1.00 16.02 C +ATOM 520 N MET A 67 10.897 -8.337 1.129 1.00 13.20 N +ATOM 521 CA MET A 67 10.402 -8.474 2.498 1.00 20.68 C +ATOM 522 C MET A 67 10.954 -7.389 3.414 1.00 18.55 C +ATOM 523 O MET A 67 10.212 -6.785 4.197 1.00 19.08 O +ATOM 524 CB MET A 67 10.784 -9.850 3.043 1.00 14.90 C +ATOM 525 CG MET A 67 10.487 -10.032 4.513 1.00 35.29 C +ATOM 526 SD MET A 67 11.173 -11.586 5.111 1.00 61.90 S +ATOM 527 CE MET A 67 12.874 -11.096 5.440 1.00 38.88 C +ATOM 528 N ARG A 68 12.267 -7.170 3.378 1.00 15.94 N +ATOM 529 CA ARG A 68 12.855 -6.182 4.276 1.00 15.63 C +ATOM 530 C ARG A 68 12.429 -4.773 3.891 1.00 15.56 C +ATOM 531 O ARG A 68 12.195 -3.931 4.769 1.00 15.12 O +ATOM 532 CB ARG A 68 14.380 -6.299 4.288 1.00 16.74 C +ATOM 533 CG ARG A 68 15.033 -5.451 5.398 1.00 18.39 C +ATOM 534 CD ARG A 68 16.291 -6.095 5.938 1.00 23.53 C +ATOM 535 NE ARG A 68 16.944 -5.317 6.990 1.00 24.97 N +ATOM 536 CZ ARG A 68 16.550 -5.243 8.258 1.00 24.44 C +ATOM 537 NH1 ARG A 68 15.467 -5.873 8.692 1.00 17.08 N1+ +ATOM 538 NH2 ARG A 68 17.270 -4.527 9.117 1.00 17.50 N +ATOM 539 N ASP A 69 12.338 -4.505 2.582 1.00 11.29 N +ATOM 540 CA ASP A 69 11.819 -3.230 2.092 1.00 14.17 C +ATOM 541 C ASP A 69 10.410 -2.985 2.620 1.00 16.08 C +ATOM 542 O ASP A 69 10.109 -1.922 3.177 1.00 18.79 O +ATOM 543 CB ASP A 69 11.806 -3.221 0.559 1.00 15.30 C +ATOM 544 CG ASP A 69 13.176 -2.946 -0.059 
1.00 13.51 C +ATOM 545 OD1 ASP A 69 14.165 -2.705 0.671 1.00 14.20 O +ATOM 546 OD2 ASP A 69 13.253 -2.948 -1.307 1.00 11.55 O1- +ATOM 547 N GLN A 70 9.525 -3.962 2.430 1.00 18.57 N +ATOM 548 CA GLN A 70 8.173 -3.868 2.969 1.00 21.34 C +ATOM 549 C GLN A 70 8.196 -3.547 4.457 1.00 20.24 C +ATOM 550 O GLN A 70 7.455 -2.677 4.930 1.00 21.09 O +ATOM 551 CB GLN A 70 7.432 -5.182 2.706 1.00 24.62 C +ATOM 552 CG GLN A 70 5.921 -5.073 2.747 1.00 34.56 C +ATOM 553 CD GLN A 70 5.386 -3.942 1.876 1.00 41.51 C +ATOM 554 NE2 GLN A 70 5.338 -4.172 0.566 1.00 45.72 N +ATOM 555 OE1 GLN A 70 5.021 -2.875 2.377 1.00 37.10 O +ATOM 556 N TYR A 71 9.074 -4.212 5.199 1.00 15.73 N +ATOM 557 CA TYR A 71 9.156 -4.005 6.639 1.00 20.47 C +ATOM 558 C TYR A 71 9.590 -2.583 6.976 1.00 19.46 C +ATOM 559 O TYR A 71 8.975 -1.912 7.822 1.00 18.40 O +ATOM 560 CB TYR A 71 10.121 -5.023 7.237 1.00 22.94 C +ATOM 561 CG TYR A 71 10.122 -5.031 8.735 1.00 25.33 C +ATOM 562 CD1 TYR A 71 8.923 -5.000 9.440 1.00 24.63 C +ATOM 563 CD2 TYR A 71 11.307 -5.116 9.450 1.00 22.49 C +ATOM 564 CE1 TYR A 71 8.902 -5.028 10.814 1.00 26.94 C +ATOM 565 CE2 TYR A 71 11.297 -5.153 10.842 1.00 24.23 C +ATOM 566 CZ TYR A 71 10.086 -5.106 11.511 1.00 21.00 C +ATOM 567 OH TYR A 71 10.050 -5.137 12.886 1.00 35.87 O +ATOM 568 N MET A 72 10.659 -2.103 6.343 1.00 12.79 N +ATOM 569 CA MET A 72 11.074 -0.737 6.647 1.00 15.82 C +ATOM 570 C MET A 72 10.079 0.293 6.150 1.00 15.00 C +ATOM 571 O MET A 72 9.996 1.384 6.728 1.00 15.55 O +ATOM 572 CB MET A 72 12.475 -0.454 6.096 1.00 13.44 C +ATOM 573 CG MET A 72 13.445 -1.510 6.529 1.00 12.70 C +ATOM 574 SD MET A 72 13.875 -1.004 8.219 1.00 15.58 S +ATOM 575 CE MET A 72 14.603 -2.492 8.859 1.00 17.29 C +ATOM 576 N ARG A 73 9.298 -0.053 5.129 1.00 14.74 N +ATOM 577 CA ARG A 73 8.253 0.818 4.621 1.00 13.42 C +ATOM 578 C ARG A 73 7.177 1.109 5.669 1.00 15.26 C +ATOM 579 O ARG A 73 6.474 2.118 5.552 1.00 17.30 O +ATOM 580 CB ARG A 73 7.627 0.159 3.391 1.00 20.60 C +ATOM 581 CG ARG A 73 6.825 1.055 2.475 1.00 33.17 C +ATOM 582 CD ARG A 73 6.582 0.350 1.125 1.00 24.67 C +ATOM 583 NE ARG A 73 7.763 0.278 0.273 1.00 23.63 N +ATOM 584 CZ ARG A 73 8.243 -0.835 -0.272 1.00 29.37 C +ATOM 585 NH1 ARG A 73 7.704 -2.022 -0.021 1.00 25.11 N1+ +ATOM 586 NH2 ARG A 73 9.287 -0.756 -1.096 1.00 26.95 N +ATOM 587 N THR A 74 7.018 0.245 6.677 1.00 15.54 N +ATOM 588 CA THR A 74 6.055 0.531 7.738 1.00 15.22 C +ATOM 589 C THR A 74 6.482 1.721 8.579 1.00 16.23 C +ATOM 590 O THR A 74 5.649 2.296 9.288 1.00 15.74 O +ATOM 591 CB THR A 74 5.855 -0.670 8.670 1.00 15.89 C +ATOM 592 CG2 THR A 74 5.404 -1.908 7.891 1.00 22.44 C +ATOM 593 OG1 THR A 74 7.071 -0.955 9.379 1.00 18.88 O +ATOM 594 N GLY A 75 7.755 2.102 8.523 1.00 13.07 N +ATOM 595 CA GLY A 75 8.222 3.208 9.345 1.00 12.01 C +ATOM 596 C GLY A 75 7.659 4.523 8.843 1.00 11.69 C +ATOM 597 O GLY A 75 7.668 4.795 7.635 1.00 11.47 O +ATOM 598 N GLU A 76 7.148 5.342 9.769 1.00 8.45 N +ATOM 599 CA GLU A 76 6.646 6.667 9.431 1.00 10.26 C +ATOM 600 C GLU A 76 7.718 7.736 9.526 1.00 13.45 C +ATOM 601 O GLU A 76 7.647 8.743 8.810 1.00 12.63 O +ATOM 602 CB GLU A 76 5.485 7.036 10.350 1.00 12.12 C +ATOM 603 CG GLU A 76 4.290 6.117 10.205 1.00 10.63 C +ATOM 604 CD GLU A 76 3.200 6.456 11.202 1.00 17.07 C +ATOM 605 OE1 GLU A 76 3.372 6.125 12.385 1.00 12.88 O +ATOM 606 OE2 GLU A 76 2.190 7.085 10.810 1.00 24.77 O1- +ATOM 607 N GLY A 77 8.698 7.544 10.402 1.00 9.91 N +ATOM 608 CA GLY A 77 9.822 8.455 10.510 1.00 10.31 C +ATOM 609 C GLY A 77 11.050 7.679 10.928 1.00 10.88 C +ATOM 610 O GLY A 77 10.950 6.566 
11.457 1.00 8.66 O +ATOM 611 N PHE A 78 12.222 8.274 10.677 1.00 8.57 N +ATOM 612 CA PHE A 78 13.493 7.584 10.835 1.00 8.21 C +ATOM 613 C PHE A 78 14.498 8.454 11.573 1.00 5.70 C +ATOM 614 O PHE A 78 14.648 9.642 11.265 1.00 7.98 O +ATOM 615 CB PHE A 78 14.065 7.204 9.467 1.00 7.00 C +ATOM 616 CG PHE A 78 13.228 6.192 8.729 1.00 8.24 C +ATOM 617 CD1 PHE A 78 12.057 6.570 8.092 1.00 8.74 C +ATOM 618 CD2 PHE A 78 13.608 4.862 8.701 1.00 9.14 C +ATOM 619 CE1 PHE A 78 11.268 5.618 7.441 1.00 9.96 C +ATOM 620 CE2 PHE A 78 12.834 3.911 8.055 1.00 9.77 C +ATOM 621 CZ PHE A 78 11.671 4.293 7.410 1.00 9.67 C +ATOM 622 N LEU A 79 15.212 7.855 12.527 1.00 4.41 N +ATOM 623 CA LEU A 79 16.403 8.466 13.106 1.00 7.09 C +ATOM 624 C LEU A 79 17.620 7.903 12.387 1.00 6.20 C +ATOM 625 O LEU A 79 17.837 6.685 12.387 1.00 7.19 O +ATOM 626 CB LEU A 79 16.536 8.180 14.603 1.00 7.22 C +ATOM 627 CG LEU A 79 15.550 8.758 15.606 1.00 10.80 C +ATOM 628 CD1 LEU A 79 16.044 8.428 17.024 1.00 6.13 C +ATOM 629 CD2 LEU A 79 15.339 10.260 15.393 1.00 9.44 C +ATOM 630 N CYS A 80 18.411 8.772 11.779 1.00 6.80 N +ATOM 631 CA CYS A 80 19.664 8.353 11.161 1.00 6.70 C +ATOM 632 C CYS A 80 20.784 8.709 12.129 1.00 6.00 C +ATOM 633 O CYS A 80 21.107 9.891 12.304 1.00 8.06 O +ATOM 634 CB CYS A 80 19.853 9.009 9.800 1.00 9.61 C +ATOM 635 SG CYS A 80 18.591 8.475 8.628 1.00 10.98 S +ATOM 636 N VAL A 81 21.345 7.687 12.773 1.00 4.93 N +ATOM 637 CA VAL A 81 22.258 7.846 13.897 1.00 6.42 C +ATOM 638 C VAL A 81 23.680 7.585 13.430 1.00 7.70 C +ATOM 639 O VAL A 81 23.953 6.573 12.768 1.00 5.76 O +ATOM 640 CB VAL A 81 21.898 6.902 15.059 1.00 5.33 C +ATOM 641 CG1 VAL A 81 22.701 7.258 16.293 1.00 5.37 C +ATOM 642 CG2 VAL A 81 20.409 6.953 15.383 1.00 5.27 C +ATOM 643 N PHE A 82 24.589 8.490 13.787 1.00 6.43 N +ATOM 644 CA PHE A 82 26.024 8.251 13.700 1.00 6.33 C +ATOM 645 C PHE A 82 26.632 8.542 15.071 1.00 6.72 C +ATOM 646 O PHE A 82 25.963 9.053 15.968 1.00 7.41 O +ATOM 647 CB PHE A 82 26.664 9.092 12.591 1.00 7.10 C +ATOM 648 CG PHE A 82 26.702 10.548 12.895 1.00 5.24 C +ATOM 649 CD1 PHE A 82 25.612 11.365 12.616 1.00 7.28 C +ATOM 650 CD2 PHE A 82 27.827 11.105 13.498 1.00 6.31 C +ATOM 651 CE1 PHE A 82 25.653 12.719 12.919 1.00 8.15 C +ATOM 652 CE2 PHE A 82 27.878 12.445 13.813 1.00 8.28 C +ATOM 653 CZ PHE A 82 26.796 13.262 13.522 1.00 6.12 C +ATOM 654 N ALA A 83 27.911 8.220 15.251 1.00 7.06 N +ATOM 655 CA ALA A 83 28.603 8.498 16.501 1.00 6.27 C +ATOM 656 C ALA A 83 29.643 9.579 16.255 1.00 8.37 C +ATOM 657 O ALA A 83 30.361 9.542 15.245 1.00 9.59 O +ATOM 658 CB ALA A 83 29.251 7.229 17.076 1.00 6.50 C +ATOM 659 N ILE A 84 29.696 10.565 17.153 1.00 7.41 N +ATOM 660 CA ILE A 84 30.532 11.737 16.906 1.00 9.74 C +ATOM 661 C ILE A 84 32.018 11.423 16.941 1.00 9.78 C +ATOM 662 O ILE A 84 32.819 12.277 16.546 1.00 11.80 O +ATOM 663 CB ILE A 84 30.243 12.873 17.908 1.00 10.58 C +ATOM 664 CG1 ILE A 84 30.577 12.435 19.341 1.00 9.22 C +ATOM 665 CG2 ILE A 84 28.814 13.386 17.761 1.00 9.51 C +ATOM 666 CD1 ILE A 84 31.140 13.555 20.180 1.00 18.71 C +ATOM 667 N ASN A 85 32.406 10.234 17.415 1.00 7.21 N +ATOM 668 CA ASN A 85 33.809 9.820 17.454 1.00 12.24 C +ATOM 669 C ASN A 85 34.121 8.747 16.418 1.00 13.59 C +ATOM 670 O ASN A 85 35.143 8.057 16.526 1.00 16.01 O +ATOM 671 CB ASN A 85 34.187 9.319 18.852 1.00 10.60 C +ATOM 672 CG ASN A 85 33.499 8.023 19.216 1.00 15.19 C +ATOM 673 ND2 ASN A 85 34.085 7.281 20.165 1.00 13.94 N +ATOM 674 OD1 ASN A 85 32.460 7.675 18.644 1.00 11.03 O +ATOM 675 N ASN A 86 33.258 8.583 15.419 1.00 11.20 N +ATOM 676 CA 
ASN A 86 33.455 7.573 14.382 1.00 12.37 C +ATOM 677 C ASN A 86 33.117 8.190 13.028 1.00 10.66 C +ATOM 678 O ASN A 86 31.954 8.226 12.617 1.00 9.66 O +ATOM 679 CB ASN A 86 32.612 6.338 14.648 1.00 13.68 C +ATOM 680 CG ASN A 86 32.963 5.202 13.729 1.00 19.40 C +ATOM 681 ND2 ASN A 86 33.765 4.255 14.219 1.00 21.62 N +ATOM 682 OD1 ASN A 86 32.523 5.169 12.588 1.00 14.80 O +ATOM 683 N THR A 87 34.145 8.652 12.324 1.00 11.90 N +ATOM 684 CA THR A 87 33.925 9.291 11.032 1.00 11.11 C +ATOM 685 C THR A 87 33.205 8.368 10.049 1.00 10.77 C +ATOM 686 O THR A 87 32.339 8.819 9.287 1.00 10.46 O +ATOM 687 CB THR A 87 35.267 9.774 10.477 1.00 14.23 C +ATOM 688 CG2 THR A 87 35.094 10.431 9.111 1.00 13.55 C +ATOM 689 OG1 THR A 87 35.810 10.725 11.407 1.00 13.30 O +ATOM 690 N LYS A 88 33.523 7.074 10.050 1.00 12.70 N +ATOM 691 CA LYS A 88 32.915 6.210 9.043 1.00 12.90 C +ATOM 692 C LYS A 88 31.406 6.123 9.246 1.00 12.66 C +ATOM 693 O LYS A 88 30.643 6.106 8.271 1.00 11.55 O +ATOM 694 CB LYS A 88 33.543 4.817 9.063 1.00 15.27 C +ATOM 695 CG LYS A 88 32.894 3.842 8.085 1.00 22.22 C +ATOM 696 CD LYS A 88 33.020 2.377 8.487 1.00 27.89 C +ATOM 697 CE LYS A 88 31.779 1.556 8.084 1.00 33.80 C +ATOM 698 NZ LYS A 88 30.534 1.922 8.821 1.00 26.11 N1+ +ATOM 699 N SER A 89 30.952 6.080 10.505 1.00 11.53 N +ATOM 700 CA SER A 89 29.512 6.031 10.749 1.00 10.11 C +ATOM 701 C SER A 89 28.828 7.291 10.230 1.00 8.80 C +ATOM 702 O SER A 89 27.694 7.224 9.734 1.00 7.88 O +ATOM 703 CB SER A 89 29.236 5.822 12.252 1.00 9.28 C +ATOM 704 OG SER A 89 29.379 7.019 13.021 1.00 8.05 O +ATOM 705 N PHE A 90 29.513 8.440 10.284 1.00 5.94 N +ATOM 706 CA PHE A 90 28.933 9.661 9.720 1.00 9.00 C +ATOM 707 C PHE A 90 28.921 9.598 8.196 1.00 10.90 C +ATOM 708 O PHE A 90 27.931 9.977 7.556 1.00 10.87 O +ATOM 709 CB PHE A 90 29.714 10.886 10.221 1.00 10.86 C +ATOM 710 CG PHE A 90 29.246 12.198 9.654 1.00 12.67 C +ATOM 711 CD1 PHE A 90 28.095 12.795 10.134 1.00 7.81 C +ATOM 712 CD2 PHE A 90 29.991 12.868 8.687 1.00 13.74 C +ATOM 713 CE1 PHE A 90 27.662 14.019 9.635 1.00 11.69 C +ATOM 714 CE2 PHE A 90 29.569 14.093 8.190 1.00 12.05 C +ATOM 715 CZ PHE A 90 28.404 14.671 8.667 1.00 11.50 C +ATOM 716 N GLU A 91 30.008 9.111 7.598 1.00 8.43 N +ATOM 717 CA GLU A 91 30.084 8.991 6.149 1.00 13.33 C +ATOM 718 C GLU A 91 29.126 7.939 5.601 1.00 12.90 C +ATOM 719 O GLU A 91 28.866 7.939 4.396 1.00 12.72 O +ATOM 720 CB GLU A 91 31.534 8.698 5.744 1.00 13.51 C +ATOM 721 CG GLU A 91 32.419 9.943 5.915 1.00 17.04 C +ATOM 722 CD GLU A 91 33.883 9.736 5.560 1.00 25.02 C +ATOM 723 OE1 GLU A 91 34.334 8.575 5.443 1.00 22.95 O +ATOM 724 OE2 GLU A 91 34.592 10.758 5.424 1.00 23.06 O1- +ATOM 725 N ASP A 92 28.560 7.080 6.459 1.00 13.13 N +ATOM 726 CA ASP A 92 27.535 6.123 6.040 1.00 12.50 C +ATOM 727 C ASP A 92 26.138 6.729 5.971 1.00 11.92 C +ATOM 728 O ASP A 92 25.249 6.122 5.358 1.00 11.00 O +ATOM 729 CB ASP A 92 27.466 4.936 7.007 1.00 10.47 C +ATOM 730 CG ASP A 92 28.551 3.893 6.769 1.00 16.00 C +ATOM 731 OD1 ASP A 92 29.251 3.950 5.730 1.00 14.29 O +ATOM 732 OD2 ASP A 92 28.672 2.993 7.637 1.00 16.04 O1- +ATOM 733 N ILE A 93 25.917 7.887 6.603 1.00 9.11 N +ATOM 734 CA ILE A 93 24.563 8.414 6.763 1.00 10.75 C +ATOM 735 C ILE A 93 23.867 8.581 5.415 1.00 9.39 C +ATOM 736 O ILE A 93 22.675 8.279 5.278 1.00 8.84 O +ATOM 737 CB ILE A 93 24.589 9.747 7.540 1.00 6.83 C +ATOM 738 CG1 ILE A 93 24.939 9.528 9.022 1.00 8.64 C +ATOM 739 CG2 ILE A 93 23.252 10.483 7.369 1.00 8.75 C +ATOM 740 CD1 ILE A 93 23.834 8.860 9.836 1.00 6.96 C +ATOM 741 N HIS A 94 24.582 9.072 4.397 
1.00 10.88 N +ATOM 742 CA HIS A 94 23.912 9.302 3.117 1.00 10.92 C +ATOM 743 C HIS A 94 23.343 8.008 2.545 1.00 11.55 C +ATOM 744 O HIS A 94 22.280 8.031 1.917 1.00 10.94 O +ATOM 745 CB HIS A 94 24.850 9.969 2.105 1.00 13.11 C +ATOM 746 CG HIS A 94 25.894 9.050 1.557 1.00 14.02 C +ATOM 747 CD2 HIS A 94 27.109 8.693 2.039 1.00 18.09 C +ATOM 748 ND1 HIS A 94 25.743 8.382 0.359 1.00 16.52 N +ATOM 749 CE1 HIS A 94 26.821 7.652 0.129 1.00 18.23 C +ATOM 750 NE2 HIS A 94 27.663 7.817 1.135 1.00 19.12 N +ATOM 751 N HIS A 95 23.997 6.868 2.791 1.00 9.65 N +ATOM 752 CA HIS A 95 23.442 5.595 2.348 1.00 10.50 C +ATOM 753 C HIS A 95 22.097 5.315 3.007 1.00 11.70 C +ATOM 754 O HIS A 95 21.114 4.994 2.329 1.00 10.36 O +ATOM 755 CB HIS A 95 24.416 4.458 2.649 1.00 10.79 C +ATOM 756 CG HIS A 95 23.885 3.107 2.293 1.00 13.74 C +ATOM 757 CD2 HIS A 95 23.201 2.198 3.027 1.00 10.84 C +ATOM 758 ND1 HIS A 95 24.031 2.555 1.035 1.00 12.23 N +ATOM 759 CE1 HIS A 95 23.461 1.366 1.011 1.00 14.63 C +ATOM 760 NE2 HIS A 95 22.953 1.121 2.206 1.00 10.72 N +ATOM 761 N TYR A 96 22.041 5.407 4.338 1.00 9.53 N +ATOM 762 CA TYR A 96 20.791 5.140 5.041 1.00 5.81 C +ATOM 763 C TYR A 96 19.690 6.081 4.571 1.00 8.98 C +ATOM 764 O TYR A 96 18.557 5.647 4.326 1.00 11.44 O +ATOM 765 CB TYR A 96 21.015 5.257 6.551 1.00 7.41 C +ATOM 766 CG TYR A 96 21.880 4.143 7.096 1.00 8.22 C +ATOM 767 CD1 TYR A 96 21.339 2.891 7.375 1.00 8.39 C +ATOM 768 CD2 TYR A 96 23.239 4.328 7.298 1.00 7.16 C +ATOM 769 CE1 TYR A 96 22.114 1.868 7.853 1.00 8.61 C +ATOM 770 CE2 TYR A 96 24.034 3.295 7.790 1.00 7.08 C +ATOM 771 CZ TYR A 96 23.459 2.066 8.062 1.00 8.17 C +ATOM 772 OH TYR A 96 24.211 1.025 8.531 1.00 11.46 O +ATOM 773 N ARG A 97 20.005 7.366 4.410 1.00 9.76 N +ATOM 774 CA ARG A 97 19.009 8.304 3.887 1.00 10.94 C +ATOM 775 C ARG A 97 18.543 7.901 2.495 1.00 13.06 C +ATOM 776 O ARG A 97 17.340 7.912 2.208 1.00 12.96 O +ATOM 777 CB ARG A 97 19.541 9.733 3.859 1.00 13.32 C +ATOM 778 CG ARG A 97 18.555 10.652 3.117 1.00 16.39 C +ATOM 779 CD ARG A 97 18.592 12.075 3.534 1.00 29.29 C +ATOM 780 NE ARG A 97 17.606 12.873 2.804 1.00 26.48 N +ATOM 781 CZ ARG A 97 16.517 13.401 3.346 1.00 31.89 C +ATOM 782 NH1 ARG A 97 16.182 13.158 4.605 1.00 31.23 N1+ +ATOM 783 NH2 ARG A 97 15.746 14.197 2.610 1.00 32.05 N +ATOM 784 N GLU A 98 19.484 7.544 1.615 1.00 10.58 N +ATOM 785 CA GLU A 98 19.104 7.154 0.262 1.00 14.26 C +ATOM 786 C GLU A 98 18.208 5.923 0.257 1.00 10.88 C +ATOM 787 O GLU A 98 17.244 5.862 -0.514 1.00 12.14 O +ATOM 788 CB GLU A 98 20.352 6.958 -0.616 1.00 16.19 C +ATOM 789 CG GLU A 98 21.165 8.264 -0.785 1.00 21.14 C +ATOM 790 CD GLU A 98 22.618 8.073 -1.224 1.00 29.01 C +ATOM 791 OE1 GLU A 98 23.383 9.074 -1.182 1.00 22.04 O +ATOM 792 OE2 GLU A 98 23.005 6.937 -1.588 1.00 34.18 O1- +ATOM 793 N GLN A 99 18.480 4.943 1.125 1.00 10.73 N +ATOM 794 CA GLN A 99 17.606 3.776 1.188 1.00 11.77 C +ATOM 795 C GLN A 99 16.230 4.110 1.764 1.00 12.18 C +ATOM 796 O GLN A 99 15.221 3.536 1.326 1.00 11.51 O +ATOM 797 CB GLN A 99 18.293 2.659 1.984 1.00 10.65 C +ATOM 798 CG GLN A 99 19.565 2.178 1.320 1.00 11.81 C +ATOM 799 CD GLN A 99 19.313 1.719 -0.103 1.00 13.96 C +ATOM 800 NE2 GLN A 99 20.042 2.286 -1.056 1.00 12.67 N +ATOM 801 OE1 GLN A 99 18.457 0.880 -0.339 1.00 13.68 O +ATOM 802 N ILE A 100 16.151 5.043 2.716 1.00 12.33 N +ATOM 803 CA ILE A 100 14.843 5.448 3.234 1.00 11.29 C +ATOM 804 C ILE A 100 14.028 6.120 2.134 1.00 13.35 C +ATOM 805 O ILE A 100 12.861 5.776 1.903 1.00 14.35 O +ATOM 806 CB ILE A 100 15.002 6.362 4.464 1.00 9.89 C +ATOM 807 CG1 ILE A 
100 15.515 5.551 5.656 1.00 10.32 C +ATOM 808 CG2 ILE A 100 13.656 7.040 4.809 1.00 11.67 C +ATOM 809 CD1 ILE A 100 16.238 6.393 6.694 1.00 13.69 C +ATOM 810 N LYS A 101 14.646 7.069 1.426 1.00 15.71 N +ATOM 811 CA LYS A 101 13.984 7.731 0.304 1.00 16.12 C +ATOM 812 C LYS A 101 13.495 6.711 -0.720 1.00 16.91 C +ATOM 813 O LYS A 101 12.368 6.809 -1.222 1.00 17.77 O +ATOM 814 CB LYS A 101 14.951 8.714 -0.362 1.00 16.74 C +ATOM 815 CG LYS A 101 15.447 9.930 0.476 1.00 21.94 C +ATOM 816 CD LYS A 101 14.442 11.062 0.720 1.00 35.77 C +ATOM 817 CE LYS A 101 13.500 10.778 1.889 1.00 24.75 C +ATOM 818 NZ LYS A 101 13.185 11.984 2.685 1.00 27.90 N1+ +ATOM 819 N ARG A 102 14.323 5.706 -1.027 1.00 14.73 N +ATOM 820 CA ARG A 102 13.944 4.706 -2.021 1.00 15.57 C +ATOM 821 C ARG A 102 12.740 3.897 -1.566 1.00 14.59 C +ATOM 822 O ARG A 102 11.773 3.721 -2.317 1.00 17.28 O +ATOM 823 CB ARG A 102 15.118 3.768 -2.310 1.00 14.42 C +ATOM 824 CG ARG A 102 14.761 2.642 -3.298 1.00 11.42 C +ATOM 825 CD ARG A 102 15.887 1.617 -3.440 1.00 11.27 C +ATOM 826 NE ARG A 102 16.190 0.920 -2.198 1.00 9.40 N +ATOM 827 CZ ARG A 102 15.459 -0.053 -1.678 1.00 11.14 C +ATOM 828 NH1 ARG A 102 14.364 -0.491 -2.276 1.00 9.26 N1+ +ATOM 829 NH2 ARG A 102 15.845 -0.614 -0.538 1.00 14.32 N +ATOM 830 N VAL A 103 12.792 3.365 -0.344 1.00 14.43 N +ATOM 831 CA VAL A 103 11.751 2.437 0.078 1.00 14.10 C +ATOM 832 C VAL A 103 10.444 3.163 0.390 1.00 17.87 C +ATOM 833 O VAL A 103 9.357 2.591 0.226 1.00 18.41 O +ATOM 834 CB VAL A 103 12.259 1.591 1.260 1.00 14.91 C +ATOM 835 CG1 VAL A 103 12.203 2.380 2.582 1.00 14.56 C +ATOM 836 CG2 VAL A 103 11.460 0.333 1.373 1.00 21.92 C +ATOM 837 N LYS A 104 10.512 4.426 0.794 1.00 14.45 N +ATOM 838 CA LYS A 104 9.315 5.219 1.015 1.00 14.68 C +ATOM 839 C LYS A 104 8.846 5.914 -0.261 1.00 20.64 C +ATOM 840 O LYS A 104 7.746 6.480 -0.276 1.00 20.43 O +ATOM 841 CB LYS A 104 9.587 6.259 2.107 1.00 13.96 C +ATOM 842 CG LYS A 104 9.948 5.658 3.470 1.00 12.08 C +ATOM 843 CD LYS A 104 8.794 4.915 4.138 1.00 16.11 C +ATOM 844 CE LYS A 104 7.682 5.883 4.562 1.00 18.88 C +ATOM 845 NZ LYS A 104 6.445 5.154 4.993 1.00 15.69 N1+ +ATOM 846 N ASP A 105 9.654 5.869 -1.319 1.00 19.27 N +ATOM 847 CA ASP A 105 9.376 6.524 -2.597 1.00 23.09 C +ATOM 848 C ASP A 105 8.931 7.973 -2.389 1.00 29.02 C +ATOM 849 O ASP A 105 7.858 8.397 -2.826 1.00 26.54 O +ATOM 850 CB ASP A 105 8.341 5.727 -3.397 1.00 27.52 C +ATOM 851 CG ASP A 105 8.147 6.265 -4.808 1.00 28.61 C +ATOM 852 OD1 ASP A 105 9.084 6.885 -5.350 1.00 26.40 O +ATOM 853 OD2 ASP A 105 7.053 6.065 -5.373 1.00 36.53 O1- +ATOM 854 N SER A 106 9.782 8.736 -1.703 1.00 21.21 N +ATOM 855 CA SER A 106 9.477 10.128 -1.404 1.00 21.62 C +ATOM 856 C SER A 106 10.758 10.857 -1.039 1.00 27.19 C +ATOM 857 O SER A 106 11.617 10.299 -0.356 1.00 20.39 O +ATOM 858 CB SER A 106 8.468 10.247 -0.256 1.00 28.77 C +ATOM 859 OG SER A 106 8.322 11.596 0.154 1.00 30.42 O +ATOM 860 N GLU A 107 10.872 12.102 -1.499 1.00 22.42 N +ATOM 861 CA GLU A 107 11.907 13.005 -1.020 1.00 29.19 C +ATOM 862 C GLU A 107 11.560 13.610 0.334 1.00 24.26 C +ATOM 863 O GLU A 107 12.419 14.247 0.954 1.00 28.97 O +ATOM 864 CB GLU A 107 12.140 14.117 -2.051 1.00 35.67 C +ATOM 865 CG GLU A 107 13.439 14.912 -1.883 1.00 40.47 C +ATOM 866 CD GLU A 107 14.693 14.059 -2.025 1.00 46.57 C +ATOM 867 OE1 GLU A 107 14.642 13.020 -2.726 1.00 52.45 O +ATOM 868 OE2 GLU A 107 15.734 14.437 -1.440 1.00 48.86 O1- +ATOM 869 N ASP A 108 10.330 13.413 0.805 1.00 24.05 N +ATOM 870 CA ASP A 108 9.811 14.017 2.032 1.00 29.56 C +ATOM 871 C ASP A 
108 9.399 12.903 2.993 1.00 27.21 C +ATOM 872 O ASP A 108 8.221 12.559 3.122 1.00 33.19 O +ATOM 873 CB ASP A 108 8.617 14.972 1.723 1.00 32.13 C +ATOM 874 CG ASP A 108 8.373 15.974 2.840 1.00 43.89 C +ATOM 875 OD1 ASP A 108 9.351 16.340 3.527 1.00 51.97 O +ATOM 876 OD2 ASP A 108 7.210 16.391 3.039 1.00 49.41 O1- +ATOM 877 N VAL A 109 10.380 12.328 3.666 1.00 18.97 N +ATOM 878 CA VAL A 109 10.148 11.354 4.726 1.00 16.82 C +ATOM 879 C VAL A 109 10.586 12.000 6.033 1.00 14.24 C +ATOM 880 O VAL A 109 11.707 12.519 6.103 1.00 11.37 O +ATOM 881 CB VAL A 109 10.924 10.044 4.481 1.00 17.29 C +ATOM 882 CG1 VAL A 109 10.759 9.085 5.662 1.00 12.78 C +ATOM 883 CG2 VAL A 109 10.491 9.405 3.153 1.00 16.59 C +ATOM 884 N PRO A 110 9.744 12.022 7.064 1.00 11.39 N +ATOM 885 CA PRO A 110 10.178 12.551 8.366 1.00 10.04 C +ATOM 886 C PRO A 110 11.468 11.884 8.814 1.00 13.84 C +ATOM 887 O PRO A 110 11.572 10.658 8.843 1.00 8.39 O +ATOM 888 CB PRO A 110 9.006 12.211 9.294 1.00 11.76 C +ATOM 889 CG PRO A 110 7.823 12.164 8.378 1.00 17.30 C +ATOM 890 CD PRO A 110 8.328 11.613 7.072 1.00 15.11 C +ATOM 891 N MET A 111 12.468 12.700 9.138 1.00 10.68 N +ATOM 892 CA MET A 111 13.781 12.165 9.467 1.00 11.53 C +ATOM 893 C MET A 111 14.523 13.157 10.352 1.00 10.74 C +ATOM 894 O MET A 111 14.313 14.371 10.258 1.00 11.37 O +ATOM 895 CB MET A 111 14.544 11.836 8.174 1.00 9.39 C +ATOM 896 CG MET A 111 15.926 11.282 8.349 1.00 16.88 C +ATOM 897 SD MET A 111 16.677 11.134 6.719 1.00 23.95 S +ATOM 898 CE MET A 111 15.477 10.168 5.865 1.00 8.98 C +ATOM 899 N VAL A 112 15.343 12.622 11.259 1.00 7.79 N +ATOM 900 CA VAL A 112 16.192 13.424 12.129 1.00 6.95 C +ATOM 901 C VAL A 112 17.592 12.832 12.080 1.00 5.70 C +ATOM 902 O VAL A 112 17.753 11.611 12.193 1.00 7.11 O +ATOM 903 CB VAL A 112 15.669 13.448 13.582 1.00 8.37 C +ATOM 904 CG1 VAL A 112 16.670 14.159 14.487 1.00 9.34 C +ATOM 905 CG2 VAL A 112 14.280 14.109 13.663 1.00 8.57 C +ATOM 906 N LEU A 113 18.594 13.692 11.897 1.00 7.11 N +ATOM 907 CA LEU A 113 19.995 13.295 11.995 1.00 7.03 C +ATOM 908 C LEU A 113 20.414 13.346 13.459 1.00 8.76 C +ATOM 909 O LEU A 113 20.181 14.354 14.143 1.00 6.94 O +ATOM 910 CB LEU A 113 20.877 14.239 11.176 1.00 9.63 C +ATOM 911 CG LEU A 113 22.368 13.920 11.127 1.00 7.16 C +ATOM 912 CD1 LEU A 113 22.543 12.526 10.498 1.00 7.38 C +ATOM 913 CD2 LEU A 113 23.151 14.984 10.354 1.00 9.64 C +ATOM 914 N VAL A 114 21.014 12.262 13.949 1.00 6.45 N +ATOM 915 CA VAL A 114 21.362 12.134 15.364 1.00 5.85 C +ATOM 916 C VAL A 114 22.866 11.903 15.500 1.00 6.17 C +ATOM 917 O VAL A 114 23.386 10.879 15.040 1.00 6.71 O +ATOM 918 CB VAL A 114 20.585 10.988 16.037 1.00 6.02 C +ATOM 919 CG1 VAL A 114 21.081 10.783 17.477 1.00 6.40 C +ATOM 920 CG2 VAL A 114 19.071 11.270 16.013 1.00 6.43 C +ATOM 921 N GLY A 115 23.547 12.812 16.207 1.00 6.91 N +ATOM 922 CA GLY A 115 24.951 12.639 16.552 1.00 6.17 C +ATOM 923 C GLY A 115 25.082 12.112 17.969 1.00 4.99 C +ATOM 924 O GLY A 115 24.970 12.879 18.939 1.00 7.17 O +ATOM 925 N ASN A 116 25.313 10.809 18.106 1.00 5.27 N +ATOM 926 CA ASN A 116 25.327 10.189 19.421 1.00 4.99 C +ATOM 927 C ASN A 116 26.742 10.123 20.016 1.00 6.10 C +ATOM 928 O ASN A 116 27.751 10.381 19.353 1.00 5.85 O +ATOM 929 CB ASN A 116 24.689 8.792 19.341 1.00 7.38 C +ATOM 930 CG ASN A 116 24.401 8.191 20.718 1.00 8.00 C +ATOM 931 ND2 ASN A 116 24.737 6.908 20.878 1.00 7.93 N +ATOM 932 OD1 ASN A 116 23.884 8.862 21.621 1.00 5.88 O +ATOM 933 N LYS A 117 26.793 9.744 21.300 1.00 6.04 N +ATOM 934 CA LYS A 117 28.015 9.665 22.107 1.00 6.58 C +ATOM 935 C LYS A 
117 28.606 11.049 22.358 1.00 5.88 C +ATOM 936 O LYS A 117 29.827 11.210 22.452 1.00 10.17 O +ATOM 937 CB LYS A 117 29.052 8.706 21.497 1.00 5.56 C +ATOM 938 CG LYS A 117 28.464 7.313 21.173 1.00 8.01 C +ATOM 939 CD LYS A 117 29.558 6.287 20.831 1.00 7.99 C +ATOM 940 CE LYS A 117 28.938 4.976 20.348 1.00 8.73 C +ATOM 941 NZ LYS A 117 29.952 3.880 20.151 1.00 10.54 N1+ +ATOM 942 N SER A 118 27.725 12.044 22.496 1.00 7.17 N +ATOM 943 CA SER A 118 28.129 13.418 22.762 1.00 7.00 C +ATOM 944 C SER A 118 28.830 13.567 24.108 1.00 10.20 C +ATOM 945 O SER A 118 29.467 14.596 24.338 1.00 10.34 O +ATOM 946 CB SER A 118 26.914 14.349 22.697 1.00 11.39 C +ATOM 947 OG SER A 118 26.316 14.331 21.393 1.00 7.90 O +ATOM 948 N ASP A 119 28.735 12.569 24.991 1.00 7.74 N +ATOM 949 CA ASP A 119 29.437 12.603 26.274 1.00 10.43 C +ATOM 950 C ASP A 119 30.933 12.325 26.150 1.00 10.17 C +ATOM 951 O ASP A 119 31.672 12.579 27.109 1.00 13.18 O +ATOM 952 CB ASP A 119 28.842 11.569 27.225 1.00 10.09 C +ATOM 953 CG ASP A 119 28.877 10.181 26.639 1.00 9.61 C +ATOM 954 OD1 ASP A 119 28.086 9.905 25.713 1.00 10.18 O +ATOM 955 OD2 ASP A 119 29.718 9.365 27.071 1.00 12.15 O1- +ATOM 956 N LEU A 120 31.405 11.766 24.988 1.00 13.16 N +ATOM 957 CA LEU A 120 32.800 11.356 24.852 1.00 12.34 C +ATOM 958 C LEU A 120 33.687 12.498 24.362 1.00 14.31 C +ATOM 959 O LEU A 120 33.257 13.342 23.564 1.00 14.90 O +ATOM 960 CB LEU A 120 32.929 10.181 23.886 1.00 13.59 C +ATOM 961 CG LEU A 120 32.300 8.881 24.365 1.00 12.64 C +ATOM 962 CD1 LEU A 120 32.380 7.806 23.288 1.00 15.75 C +ATOM 963 CD2 LEU A 120 32.961 8.407 25.655 1.00 15.75 C +ATOM 964 N PRO A 121 34.940 12.528 24.817 1.00 15.51 N +ATOM 965 CA PRO A 121 35.870 13.579 24.388 1.00 15.23 C +ATOM 966 C PRO A 121 36.571 13.292 23.070 1.00 14.61 C +ATOM 967 O PRO A 121 37.267 14.175 22.554 1.00 18.02 O +ATOM 968 CB PRO A 121 36.881 13.602 25.552 1.00 20.53 C +ATOM 969 CG PRO A 121 36.978 12.173 25.961 1.00 19.25 C +ATOM 970 CD PRO A 121 35.572 11.587 25.766 1.00 16.51 C +ATOM 971 N SER A 122 36.379 12.109 22.493 1.00 15.66 N +ATOM 972 CA SER A 122 37.129 11.673 21.314 1.00 21.94 C +ATOM 973 C SER A 122 36.499 12.105 19.990 1.00 17.66 C +ATOM 974 O SER A 122 36.509 11.342 19.021 1.00 17.43 O +ATOM 975 CB SER A 122 37.296 10.155 21.380 1.00 26.44 C +ATOM 976 OG SER A 122 36.073 9.512 21.762 1.00 23.67 O +ATOM 977 N ARG A 123 35.979 13.328 19.902 1.00 17.58 N +ATOM 978 CA ARG A 123 35.240 13.760 18.719 1.00 10.69 C +ATOM 979 C ARG A 123 36.106 13.732 17.461 1.00 13.90 C +ATOM 980 O ARG A 123 37.251 14.199 17.461 1.00 17.06 O +ATOM 981 CB ARG A 123 34.701 15.170 18.956 1.00 16.35 C +ATOM 982 CG ARG A 123 33.933 15.767 17.809 1.00 16.14 C +ATOM 983 CD ARG A 123 33.461 17.166 18.171 1.00 15.82 C +ATOM 984 NE ARG A 123 32.344 17.147 19.109 1.00 17.57 N +ATOM 985 CZ ARG A 123 31.090 16.882 18.765 1.00 17.82 C +ATOM 986 NH1 ARG A 123 30.768 16.573 17.520 1.00 12.66 N1+ +ATOM 987 NH2 ARG A 123 30.137 16.924 19.695 1.00 17.13 N +ATOM 988 N THR A 124 35.557 13.169 16.383 1.00 12.73 N +ATOM 989 CA THR A 124 36.168 13.242 15.062 1.00 12.36 C +ATOM 990 C THR A 124 35.270 13.904 14.023 1.00 14.79 C +ATOM 991 O THR A 124 35.761 14.278 12.951 1.00 15.92 O +ATOM 992 CB THR A 124 36.575 11.837 14.574 1.00 12.09 C +ATOM 993 CG2 THR A 124 37.548 11.181 15.537 1.00 14.64 C +ATOM 994 OG1 THR A 124 35.411 11.014 14.434 1.00 16.42 O +ATOM 995 N VAL A 125 33.987 14.099 14.318 1.00 11.53 N +ATOM 996 CA VAL A 125 33.037 14.749 13.416 1.00 14.11 C +ATOM 997 C VAL A 125 32.652 16.086 14.036 1.00 14.53 C +ATOM 998 O 
VAL A 125 32.019 16.125 15.100 1.00 13.99 O +ATOM 999 CB VAL A 125 31.791 13.882 13.181 1.00 15.95 C +ATOM 1000 CG1 VAL A 125 30.852 14.541 12.153 1.00 11.44 C +ATOM 1001 CG2 VAL A 125 32.183 12.487 12.757 1.00 14.53 C +ATOM 1002 N ASP A 126 33.027 17.181 13.372 1.00 13.97 N +ATOM 1003 CA ASP A 126 32.681 18.515 13.846 1.00 12.18 C +ATOM 1004 C ASP A 126 31.171 18.696 13.869 1.00 10.84 C +ATOM 1005 O ASP A 126 30.474 18.292 12.936 1.00 10.64 O +ATOM 1006 CB ASP A 126 33.285 19.593 12.932 1.00 20.52 C +ATOM 1007 CG ASP A 126 34.804 19.599 12.929 1.00 28.96 C +ATOM 1008 OD1 ASP A 126 35.427 18.991 13.822 1.00 32.81 O +ATOM 1009 OD2 ASP A 126 35.375 20.226 12.013 1.00 27.11 O1- +ATOM 1010 N THR A 127 30.670 19.330 14.931 1.00 10.62 N +ATOM 1011 CA THR A 127 29.248 19.659 15.001 1.00 11.88 C +ATOM 1012 C THR A 127 28.810 20.472 13.790 1.00 13.53 C +ATOM 1013 O THR A 127 27.695 20.301 13.286 1.00 12.06 O +ATOM 1014 CB THR A 127 28.956 20.423 16.293 1.00 15.59 C +ATOM 1015 CG2 THR A 127 27.539 21.019 16.288 1.00 14.41 C +ATOM 1016 OG1 THR A 127 29.130 19.541 17.411 1.00 15.85 O +ATOM 1017 N LYS A 128 29.679 21.357 13.295 1.00 11.91 N +ATOM 1018 CA LYS A 128 29.283 22.169 12.148 1.00 12.62 C +ATOM 1019 C LYS A 128 29.096 21.297 10.910 1.00 14.89 C +ATOM 1020 O LYS A 128 28.153 21.501 10.132 1.00 11.93 O +ATOM 1021 CB LYS A 128 30.316 23.270 11.911 1.00 13.99 C +ATOM 1022 CG LYS A 128 30.282 23.915 10.515 1.00 20.94 C +ATOM 1023 CD LYS A 128 31.598 24.641 10.228 1.00 26.60 C +ATOM 1024 CE LYS A 128 31.603 25.372 8.884 1.00 27.57 C +ATOM 1025 NZ LYS A 128 30.244 25.492 8.277 1.00 34.95 N1+ +ATOM 1026 N GLN A 129 29.964 20.297 10.727 1.00 11.03 N +ATOM 1027 CA GLN A 129 29.798 19.392 9.597 1.00 14.13 C +ATOM 1028 C GLN A 129 28.473 18.645 9.695 1.00 10.60 C +ATOM 1029 O GLN A 129 27.741 18.530 8.703 1.00 9.83 O +ATOM 1030 CB GLN A 129 30.980 18.423 9.511 1.00 13.62 C +ATOM 1031 CG GLN A 129 30.983 17.602 8.230 1.00 12.29 C +ATOM 1032 CD GLN A 129 32.160 16.648 8.119 1.00 23.91 C +ATOM 1033 NE2 GLN A 129 32.621 16.421 6.889 1.00 21.64 N +ATOM 1034 OE1 GLN A 129 32.629 16.100 9.112 1.00 20.55 O +ATOM 1035 N ALA A 130 28.119 18.193 10.899 1.00 9.40 N +ATOM 1036 CA ALA A 130 26.860 17.478 11.097 1.00 10.14 C +ATOM 1037 C ALA A 130 25.650 18.385 10.879 1.00 8.71 C +ATOM 1038 O ALA A 130 24.689 17.998 10.205 1.00 10.45 O +ATOM 1039 CB ALA A 130 26.825 16.877 12.498 1.00 8.86 C +ATOM 1040 N GLN A 131 25.656 19.578 11.480 1.00 6.88 N +ATOM 1041 CA GLN A 131 24.580 20.537 11.240 1.00 7.78 C +ATOM 1042 C GLN A 131 24.438 20.845 9.755 1.00 10.18 C +ATOM 1043 O GLN A 131 23.318 20.921 9.236 1.00 9.23 O +ATOM 1044 CB GLN A 131 24.836 21.837 12.004 1.00 11.28 C +ATOM 1045 CG GLN A 131 24.700 21.741 13.522 1.00 11.17 C +ATOM 1046 CD GLN A 131 25.178 23.012 14.204 1.00 22.22 C +ATOM 1047 NE2 GLN A 131 24.801 23.186 15.465 1.00 18.03 N +ATOM 1048 OE1 GLN A 131 25.909 23.815 13.608 1.00 20.58 O +ATOM 1049 N ASP A 132 25.567 21.032 9.056 1.00 10.08 N +ATOM 1050 CA ASP A 132 25.510 21.337 7.626 1.00 10.18 C +ATOM 1051 C ASP A 132 24.965 20.167 6.810 1.00 11.44 C +ATOM 1052 O ASP A 132 24.382 20.388 5.742 1.00 11.14 O +ATOM 1053 CB ASP A 132 26.890 21.716 7.098 1.00 12.77 C +ATOM 1054 CG ASP A 132 27.315 23.121 7.499 1.00 13.76 C +ATOM 1055 OD1 ASP A 132 26.479 23.896 8.026 1.00 16.89 O +ATOM 1056 OD2 ASP A 132 28.505 23.436 7.281 1.00 19.22 O1- +ATOM 1057 N LEU A 133 25.177 18.922 7.259 1.00 8.72 N +ATOM 1058 CA LEU A 133 24.590 17.786 6.542 1.00 7.83 C +ATOM 1059 C LEU A 133 23.081 17.745 6.736 1.00 10.68 C +ATOM 1060 
O LEU A 133 22.326 17.509 5.781 1.00 9.35 O +ATOM 1061 CB LEU A 133 25.217 16.469 7.006 1.00 9.15 C +ATOM 1062 CG LEU A 133 24.654 15.208 6.323 1.00 9.19 C +ATOM 1063 CD1 LEU A 133 24.917 15.215 4.820 1.00 8.65 C +ATOM 1064 CD2 LEU A 133 25.206 13.928 6.940 1.00 7.20 C +ATOM 1065 N ALA A 134 22.620 17.976 7.966 1.00 7.06 N +ATOM 1066 CA ALA A 134 21.191 18.084 8.195 1.00 5.29 C +ATOM 1067 C ALA A 134 20.608 19.215 7.361 1.00 7.58 C +ATOM 1068 O ALA A 134 19.528 19.082 6.780 1.00 8.98 O +ATOM 1069 CB ALA A 134 20.923 18.295 9.683 1.00 10.01 C +ATOM 1070 N ARG A 135 21.334 20.325 7.269 1.00 9.14 N +ATOM 1071 CA ARG A 135 20.891 21.439 6.436 1.00 10.33 C +ATOM 1072 C ARG A 135 20.755 21.016 4.974 1.00 11.60 C +ATOM 1073 O ARG A 135 19.766 21.369 4.311 1.00 13.71 O +ATOM 1074 CB ARG A 135 21.874 22.598 6.585 1.00 12.51 C +ATOM 1075 CG ARG A 135 21.632 23.752 5.633 1.00 26.01 C +ATOM 1076 CD ARG A 135 20.409 24.568 6.023 1.00 29.99 C +ATOM 1077 NE ARG A 135 20.160 25.652 5.078 1.00 31.05 N +ATOM 1078 CZ ARG A 135 18.962 26.151 4.808 1.00 34.62 C +ATOM 1079 NH1 ARG A 135 17.880 25.734 5.445 1.00 40.71 N1+ +ATOM 1080 NH2 ARG A 135 18.846 27.095 3.876 1.00 40.78 N +ATOM 1081 N SER A 136 21.739 20.263 4.452 1.00 12.59 N +ATOM 1082 CA SER A 136 21.677 19.782 3.065 1.00 9.96 C +ATOM 1083 C SER A 136 20.414 18.975 2.808 1.00 13.52 C +ATOM 1084 O SER A 136 19.804 19.081 1.736 1.00 13.45 O +ATOM 1085 CB SER A 136 22.903 18.933 2.731 1.00 12.66 C +ATOM 1086 OG SER A 136 24.097 19.691 2.774 1.00 11.55 O +ATOM 1087 N TYR A 137 20.014 18.153 3.774 1.00 8.75 N +ATOM 1088 CA TYR A 137 18.816 17.335 3.669 1.00 10.62 C +ATOM 1089 C TYR A 137 17.534 18.057 4.073 1.00 12.53 C +ATOM 1090 O TYR A 137 16.451 17.502 3.865 1.00 13.78 O +ATOM 1091 CB TYR A 137 18.951 16.095 4.543 1.00 13.23 C +ATOM 1092 CG TYR A 137 19.993 15.104 4.103 1.00 15.69 C +ATOM 1093 CD1 TYR A 137 20.172 14.812 2.760 1.00 15.55 C +ATOM 1094 CD2 TYR A 137 20.759 14.404 5.039 1.00 17.18 C +ATOM 1095 CE1 TYR A 137 21.112 13.871 2.351 1.00 18.35 C +ATOM 1096 CE2 TYR A 137 21.709 13.460 4.631 1.00 15.47 C +ATOM 1097 CZ TYR A 137 21.867 13.201 3.286 1.00 14.34 C +ATOM 1098 OH TYR A 137 22.801 12.271 2.858 1.00 17.17 O +ATOM 1099 N GLY A 138 17.624 19.244 4.667 1.00 10.87 N +ATOM 1100 CA GLY A 138 16.444 19.925 5.170 1.00 13.61 C +ATOM 1101 C GLY A 138 15.778 19.234 6.344 1.00 12.99 C +ATOM 1102 O GLY A 138 14.548 19.231 6.435 1.00 14.05 O +ATOM 1103 N ILE A 139 16.557 18.641 7.243 1.00 10.10 N +ATOM 1104 CA ILE A 139 16.016 17.916 8.393 1.00 10.40 C +ATOM 1105 C ILE A 139 16.684 18.454 9.651 1.00 10.40 C +ATOM 1106 O ILE A 139 17.726 19.120 9.581 1.00 12.38 O +ATOM 1107 CB ILE A 139 16.230 16.387 8.260 1.00 10.12 C +ATOM 1108 CG1 ILE A 139 17.708 16.053 8.452 1.00 9.52 C +ATOM 1109 CG2 ILE A 139 15.697 15.882 6.914 1.00 10.83 C +ATOM 1110 CD1 ILE A 139 18.107 14.587 8.161 1.00 10.87 C +ATOM 1111 N PRO A 140 16.108 18.187 10.822 1.00 7.85 N +ATOM 1112 CA PRO A 140 16.764 18.591 12.072 1.00 11.07 C +ATOM 1113 C PRO A 140 17.949 17.707 12.441 1.00 11.99 C +ATOM 1114 O PRO A 140 18.055 16.544 12.035 1.00 9.33 O +ATOM 1115 CB PRO A 140 15.639 18.482 13.115 1.00 12.89 C +ATOM 1116 CG PRO A 140 14.357 18.402 12.297 1.00 16.39 C +ATOM 1117 CD PRO A 140 14.764 17.646 11.072 1.00 9.85 C +ATOM 1118 N PHE A 141 18.859 18.301 13.215 1.00 12.38 N +ATOM 1119 CA PHE A 141 20.023 17.630 13.784 1.00 7.37 C +ATOM 1120 C PHE A 141 19.942 17.722 15.301 1.00 10.93 C +ATOM 1121 O PHE A 141 19.745 18.812 15.854 1.00 11.19 O +ATOM 1122 CB PHE A 141 21.330 18.266 13.304 1.00 
10.16 C +ATOM 1123 CG PHE A 141 22.552 17.763 14.019 1.00 9.62 C +ATOM 1124 CD1 PHE A 141 22.915 16.415 13.955 1.00 9.53 C +ATOM 1125 CD2 PHE A 141 23.340 18.633 14.775 1.00 9.96 C +ATOM 1126 CE1 PHE A 141 24.052 15.946 14.610 1.00 7.33 C +ATOM 1127 CE2 PHE A 141 24.478 18.177 15.432 1.00 7.76 C +ATOM 1128 CZ PHE A 141 24.839 16.836 15.362 1.00 9.89 C +ATOM 1129 N ILE A 142 20.097 16.588 15.974 1.00 10.73 N +ATOM 1130 CA ILE A 142 20.020 16.526 17.431 1.00 7.80 C +ATOM 1131 C ILE A 142 21.271 15.828 17.943 1.00 6.73 C +ATOM 1132 O ILE A 142 21.639 14.762 17.440 1.00 8.06 O +ATOM 1133 CB ILE A 142 18.765 15.773 17.922 1.00 7.31 C +ATOM 1134 CG1 ILE A 142 17.464 16.432 17.427 1.00 8.16 C +ATOM 1135 CG2 ILE A 142 18.769 15.679 19.451 1.00 7.89 C +ATOM 1136 CD1 ILE A 142 17.226 17.871 17.937 1.00 14.28 C +ATOM 1137 N GLU A 143 21.903 16.415 18.956 1.00 7.30 N +ATOM 1138 CA GLU A 143 23.087 15.851 19.584 1.00 8.14 C +ATOM 1139 C GLU A 143 22.635 15.029 20.782 1.00 8.49 C +ATOM 1140 O GLU A 143 21.928 15.553 21.655 1.00 8.60 O +ATOM 1141 CB GLU A 143 24.031 16.964 20.039 1.00 10.97 C +ATOM 1142 CG GLU A 143 24.616 17.817 18.919 1.00 9.60 C +ATOM 1143 CD GLU A 143 25.472 18.953 19.459 1.00 17.68 C +ATOM 1144 OE1 GLU A 143 26.421 18.663 20.224 1.00 13.39 O +ATOM 1145 OE2 GLU A 143 25.175 20.136 19.139 1.00 13.78 O1- +ATOM 1146 N THR A 144 23.035 13.755 20.836 1.00 6.00 N +ATOM 1147 CA THR A 144 22.582 12.881 21.906 1.00 6.09 C +ATOM 1148 C THR A 144 23.734 12.204 22.637 1.00 7.72 C +ATOM 1149 O THR A 144 24.850 12.069 22.126 1.00 6.67 O +ATOM 1150 CB THR A 144 21.636 11.780 21.408 1.00 6.55 C +ATOM 1151 CG2 THR A 144 20.377 12.360 20.749 1.00 4.53 C +ATOM 1152 OG1 THR A 144 22.344 10.910 20.503 1.00 8.38 O +ATOM 1153 N SER A 145 23.418 11.787 23.860 1.00 7.21 N +ATOM 1154 CA SER A 145 24.170 10.768 24.589 1.00 10.14 C +ATOM 1155 C SER A 145 23.190 9.787 25.211 1.00 8.17 C +ATOM 1156 O SER A 145 22.404 10.163 26.094 1.00 7.83 O +ATOM 1157 CB SER A 145 25.052 11.372 25.686 1.00 7.63 C +ATOM 1158 OG SER A 145 25.605 10.321 26.483 1.00 7.97 O +ATOM 1159 N ALA A 146 23.235 8.531 24.764 1.00 7.38 N +ATOM 1160 CA ALA A 146 22.488 7.494 25.464 1.00 6.97 C +ATOM 1161 C ALA A 146 23.010 7.296 26.890 1.00 6.74 C +ATOM 1162 O ALA A 146 22.267 6.846 27.766 1.00 9.24 O +ATOM 1163 CB ALA A 146 22.560 6.186 24.687 1.00 6.86 C +ATOM 1164 N LYS A 147 24.285 7.605 27.136 1.00 7.42 N +ATOM 1165 CA LYS A 147 24.845 7.445 28.479 1.00 9.01 C +ATOM 1166 C LYS A 147 24.212 8.406 29.481 1.00 10.54 C +ATOM 1167 O LYS A 147 23.814 7.997 30.583 1.00 7.96 O +ATOM 1168 CB LYS A 147 26.355 7.645 28.442 1.00 9.99 C +ATOM 1169 CG LYS A 147 27.012 7.469 29.808 1.00 13.13 C +ATOM 1170 CD LYS A 147 28.520 7.446 29.670 1.00 18.42 C +ATOM 1171 CE LYS A 147 29.191 7.483 31.041 1.00 25.32 C +ATOM 1172 NZ LYS A 147 30.620 7.079 30.951 1.00 32.48 N1+ +ATOM 1173 N THR A 148 24.153 9.699 29.147 1.00 9.28 N +ATOM 1174 CA THR A 148 23.652 10.708 30.078 1.00 8.17 C +ATOM 1175 C THR A 148 22.199 11.092 29.844 1.00 8.42 C +ATOM 1176 O THR A 148 21.627 11.824 30.663 1.00 7.85 O +ATOM 1177 CB THR A 148 24.490 11.992 29.988 1.00 7.02 C +ATOM 1178 CG2 THR A 148 25.984 11.715 30.194 1.00 10.29 C +ATOM 1179 OG1 THR A 148 24.257 12.596 28.703 1.00 7.47 O +ATOM 1180 N ARG A 149 21.598 10.637 28.745 1.00 8.72 N +ATOM 1181 CA ARG A 149 20.238 10.896 28.305 1.00 8.97 C +ATOM 1182 C ARG A 149 20.134 12.238 27.588 1.00 9.19 C +ATOM 1183 O ARG A 149 19.030 12.607 27.172 1.00 8.40 O +ATOM 1184 CB ARG A 149 19.216 10.814 29.453 1.00 11.33 C +ATOM 1185 CG ARG A 
149 17.970 10.016 29.127 1.00 16.21 C +ATOM 1186 CD ARG A 149 17.073 9.946 30.351 1.00 14.22 C +ATOM 1187 NE ARG A 149 17.564 8.956 31.297 1.00 16.79 N +ATOM 1188 CZ ARG A 149 17.479 9.058 32.615 1.00 13.53 C +ATOM 1189 NH1 ARG A 149 16.989 10.143 33.200 1.00 18.28 N1+ +ATOM 1190 NH2 ARG A 149 17.889 8.042 33.368 1.00 19.84 N +ATOM 1191 N GLN A 150 21.235 12.979 27.422 1.00 7.38 N +ATOM 1192 CA GLN A 150 21.261 14.177 26.591 1.00 10.00 C +ATOM 1193 C GLN A 150 20.565 13.935 25.258 1.00 6.79 C +ATOM 1194 O GLN A 150 20.899 12.994 24.530 1.00 6.19 O +ATOM 1195 CB GLN A 150 22.716 14.622 26.355 1.00 8.75 C +ATOM 1196 CG GLN A 150 22.864 15.684 25.236 1.00 9.14 C +ATOM 1197 CD GLN A 150 24.308 16.145 24.996 1.00 12.97 C +ATOM 1198 NE2 GLN A 150 24.534 16.800 23.855 1.00 13.33 N +ATOM 1199 OE1 GLN A 150 25.199 15.913 25.812 1.00 9.23 O +ATOM 1200 N GLY A 151 19.550 14.747 24.969 1.00 6.94 N +ATOM 1201 CA GLY A 151 18.933 14.792 23.651 1.00 9.02 C +ATOM 1202 C GLY A 151 17.997 13.657 23.293 1.00 10.44 C +ATOM 1203 O GLY A 151 17.432 13.674 22.183 1.00 8.23 O +ATOM 1204 N VAL A 152 17.804 12.669 24.179 1.00 8.79 N +ATOM 1205 CA VAL A 152 17.070 11.470 23.772 1.00 8.75 C +ATOM 1206 C VAL A 152 15.628 11.813 23.434 1.00 6.82 C +ATOM 1207 O VAL A 152 15.144 11.476 22.345 1.00 7.26 O +ATOM 1208 CB VAL A 152 17.158 10.376 24.848 1.00 7.72 C +ATOM 1209 CG1 VAL A 152 16.264 9.221 24.464 1.00 9.79 C +ATOM 1210 CG2 VAL A 152 18.624 9.920 24.998 1.00 6.19 C +ATOM 1211 N ASP A 153 14.917 12.475 24.360 1.00 8.45 N +ATOM 1212 CA ASP A 153 13.552 12.906 24.066 1.00 10.12 C +ATOM 1213 C ASP A 153 13.527 13.846 22.866 1.00 10.05 C +ATOM 1214 O ASP A 153 12.662 13.727 21.990 1.00 8.47 O +ATOM 1215 CB ASP A 153 12.919 13.600 25.281 1.00 9.99 C +ATOM 1216 CG ASP A 153 12.487 12.635 26.383 1.00 11.02 C +ATOM 1217 OD1 ASP A 153 12.451 11.399 26.187 1.00 13.38 O +ATOM 1218 OD2 ASP A 153 12.160 13.141 27.477 1.00 13.58 O1- +ATOM 1219 N ASP A 154 14.478 14.784 22.805 1.00 7.80 N +ATOM 1220 CA ASP A 154 14.546 15.711 21.676 1.00 7.55 C +ATOM 1221 C ASP A 154 14.561 14.971 20.351 1.00 10.30 C +ATOM 1222 O ASP A 154 13.877 15.362 19.401 1.00 8.35 O +ATOM 1223 CB ASP A 154 15.800 16.589 21.767 1.00 9.60 C +ATOM 1224 CG ASP A 154 15.712 17.636 22.846 1.00 12.61 C +ATOM 1225 OD1 ASP A 154 14.621 17.815 23.424 1.00 12.20 O +ATOM 1226 OD2 ASP A 154 16.756 18.260 23.124 1.00 12.27 O1- +ATOM 1227 N ALA A 155 15.353 13.908 20.259 1.00 6.20 N +ATOM 1228 CA ALA A 155 15.455 13.196 18.991 1.00 9.84 C +ATOM 1229 C ALA A 155 14.139 12.514 18.636 1.00 8.22 C +ATOM 1230 O ALA A 155 13.600 12.710 17.538 1.00 7.76 O +ATOM 1231 CB ALA A 155 16.603 12.191 19.061 1.00 8.65 C +ATOM 1232 N PHE A 156 13.579 11.737 19.571 1.00 8.27 N +ATOM 1233 CA PHE A 156 12.354 11.002 19.270 1.00 7.91 C +ATOM 1234 C PHE A 156 11.157 11.945 19.149 1.00 9.74 C +ATOM 1235 O PHE A 156 10.322 11.779 18.257 1.00 8.66 O +ATOM 1236 CB PHE A 156 12.108 9.928 20.331 1.00 7.56 C +ATOM 1237 CG PHE A 156 13.067 8.757 20.248 1.00 5.35 C +ATOM 1238 CD1 PHE A 156 12.888 7.757 19.290 1.00 8.99 C +ATOM 1239 CD2 PHE A 156 14.145 8.667 21.103 1.00 6.09 C +ATOM 1240 CE1 PHE A 156 13.756 6.661 19.209 1.00 6.24 C +ATOM 1241 CE2 PHE A 156 15.044 7.571 21.018 1.00 5.92 C +ATOM 1242 CZ PHE A 156 14.842 6.578 20.059 1.00 4.02 C +ATOM 1243 N TYR A 157 11.067 12.957 20.014 1.00 6.41 N +ATOM 1244 CA TYR A 157 9.907 13.845 19.954 1.00 8.10 C +ATOM 1245 C TYR A 157 9.927 14.695 18.698 1.00 9.30 C +ATOM 1246 O TYR A 157 8.877 14.947 18.101 1.00 11.75 O +ATOM 1247 CB TYR A 157 9.850 14.738 
21.190 1.00 9.39 C +ATOM 1248 CG TYR A 157 9.497 14.046 22.495 1.00 11.75 C +ATOM 1249 CD1 TYR A 157 9.510 12.665 22.618 1.00 9.42 C +ATOM 1250 CD2 TYR A 157 9.206 14.797 23.622 1.00 12.42 C +ATOM 1251 CE1 TYR A 157 9.205 12.048 23.829 1.00 13.02 C +ATOM 1252 CE2 TYR A 157 8.905 14.202 24.824 1.00 11.96 C +ATOM 1253 CZ TYR A 157 8.903 12.835 24.931 1.00 15.12 C +ATOM 1254 OH TYR A 157 8.603 12.267 26.152 1.00 21.02 O +ATOM 1255 N THR A 158 11.102 15.175 18.299 1.00 9.48 N +ATOM 1256 CA THR A 158 11.192 15.930 17.052 1.00 8.87 C +ATOM 1257 C THR A 158 10.695 15.092 15.891 1.00 11.82 C +ATOM 1258 O THR A 158 9.932 15.575 15.051 1.00 10.07 O +ATOM 1259 CB THR A 158 12.634 16.394 16.810 1.00 8.71 C +ATOM 1260 CG2 THR A 158 12.760 17.158 15.475 1.00 8.85 C +ATOM 1261 OG1 THR A 158 13.027 17.275 17.864 1.00 8.61 O +ATOM 1262 N LEU A 159 11.066 13.810 15.862 1.00 10.82 N +ATOM 1263 CA LEU A 159 10.619 12.958 14.769 1.00 9.35 C +ATOM 1264 C LEU A 159 9.098 12.856 14.754 1.00 10.10 C +ATOM 1265 O LEU A 159 8.471 12.939 13.691 1.00 14.53 O +ATOM 1266 CB LEU A 159 11.288 11.578 14.882 1.00 9.18 C +ATOM 1267 CG LEU A 159 11.116 10.617 13.708 1.00 12.35 C +ATOM 1268 CD1 LEU A 159 11.589 11.225 12.385 1.00 9.34 C +ATOM 1269 CD2 LEU A 159 11.881 9.320 14.035 1.00 9.63 C +ATOM 1270 N VAL A 160 8.480 12.723 15.932 1.00 11.11 N +ATOM 1271 CA VAL A 160 7.021 12.693 16.007 1.00 14.09 C +ATOM 1272 C VAL A 160 6.433 13.996 15.480 1.00 18.13 C +ATOM 1273 O VAL A 160 5.441 13.995 14.737 1.00 14.56 O +ATOM 1274 CB VAL A 160 6.559 12.411 17.449 1.00 13.46 C +ATOM 1275 CG1 VAL A 160 5.057 12.633 17.578 1.00 14.02 C +ATOM 1276 CG2 VAL A 160 6.917 10.989 17.853 1.00 13.15 C +ATOM 1277 N ARG A 161 7.036 15.128 15.853 1.00 15.52 N +ATOM 1278 CA ARG A 161 6.564 16.415 15.344 1.00 18.31 C +ATOM 1279 C ARG A 161 6.709 16.501 13.824 1.00 19.78 C +ATOM 1280 O ARG A 161 5.829 17.039 13.143 1.00 22.47 O +ATOM 1281 CB ARG A 161 7.315 17.554 16.037 1.00 18.06 C +ATOM 1282 CG ARG A 161 6.948 17.682 17.517 1.00 17.33 C +ATOM 1283 CD ARG A 161 7.699 18.807 18.245 1.00 20.04 C +ATOM 1284 NE ARG A 161 7.482 18.731 19.687 1.00 24.33 N +ATOM 1285 CZ ARG A 161 8.425 18.499 20.591 1.00 17.72 C +ATOM 1286 NH1 ARG A 161 9.692 18.362 20.249 1.00 20.49 N1+ +ATOM 1287 NH2 ARG A 161 8.085 18.407 21.873 1.00 22.31 N +ATOM 1288 N GLU A 162 7.810 15.972 13.275 1.00 19.00 N +ATOM 1289 CA GLU A 162 7.991 15.948 11.823 1.00 16.05 C +ATOM 1290 C GLU A 162 6.919 15.095 11.148 1.00 21.04 C +ATOM 1291 O GLU A 162 6.453 15.421 10.050 1.00 22.14 O +ATOM 1292 CB GLU A 162 9.384 15.413 11.477 1.00 18.48 C +ATOM 1293 CG GLU A 162 10.591 16.309 11.863 1.00 16.07 C +ATOM 1294 CD GLU A 162 10.711 17.579 11.036 1.00 28.60 C +ATOM 1295 OE1 GLU A 162 11.036 17.470 9.835 1.00 29.83 O +ATOM 1296 OE2 GLU A 162 10.502 18.686 11.579 1.00 26.45 O1- +ATOM 1297 N ILE A 163 6.531 13.990 11.783 1.00 17.67 N +ATOM 1298 CA ILE A 163 5.474 13.137 11.241 1.00 22.80 C +ATOM 1299 C ILE A 163 4.137 13.868 11.249 1.00 26.55 C +ATOM 1300 O ILE A 163 3.397 13.849 10.255 1.00 27.00 O +ATOM 1301 CB ILE A 163 5.396 11.813 12.022 1.00 20.54 C +ATOM 1302 CG1 ILE A 163 6.597 10.924 11.677 1.00 13.43 C +ATOM 1303 CG2 ILE A 163 4.053 11.103 11.766 1.00 18.66 C +ATOM 1304 CD1 ILE A 163 6.867 9.844 12.696 1.00 12.15 C +ATOM 1305 N ARG A 164 3.799 14.504 12.376 1.00 23.97 N +ATOM 1306 CA ARG A 164 2.570 15.292 12.455 1.00 31.12 C +ATOM 1307 C ARG A 164 2.507 16.337 11.345 1.00 31.71 C +ATOM 1308 O ARG A 164 1.481 16.477 10.671 1.00 39.49 O +ATOM 1309 CB ARG A 164 2.452 15.958 13.828 1.00 25.88 
C +ATOM 1310 CG ARG A 164 1.961 15.030 14.927 1.00 29.69 C +ATOM 1311 CD ARG A 164 2.113 15.667 16.309 1.00 32.96 C +ATOM 1312 NE ARG A 164 1.644 14.791 17.377 1.00 36.23 N +ATOM 1313 CZ ARG A 164 1.907 14.966 18.666 1.00 36.99 C +ATOM 1314 NH1 ARG A 164 2.665 15.967 19.089 1.00 27.56 N1+ +ATOM 1315 NH2 ARG A 164 1.394 14.114 19.553 1.00 34.65 N +ATOM 1316 N LYS A 165 3.599 17.076 11.136 1.00 29.32 N +ATOM 1317 CA LYS A 165 3.648 18.025 10.026 1.00 28.77 C +ATOM 1318 C LYS A 165 3.367 17.333 8.693 1.00 39.28 C +ATOM 1319 O LYS A 165 2.501 17.767 7.922 1.00 37.35 O +ATOM 1320 CB LYS A 165 5.013 18.715 9.980 1.00 31.82 C +ATOM 1321 CG LYS A 165 5.196 19.826 10.997 1.00 37.13 C +ATOM 1322 CD LYS A 165 6.217 20.844 10.506 1.00 47.87 C +ATOM 1323 CE LYS A 165 7.523 20.185 10.099 1.00 38.42 C +ATOM 1324 NZ LYS A 165 8.532 20.245 11.195 1.00 39.05 N1+ +ATOM 1325 N HIS A 166 4.092 16.244 8.411 1.00 35.20 N +ATOM 1326 CA HIS A 166 4.020 15.608 7.099 1.00 33.17 C +ATOM 1327 C HIS A 166 2.612 15.113 6.789 1.00 33.01 C +ATOM 1328 O HIS A 166 2.208 15.084 5.621 1.00 38.14 O +ATOM 1329 CB HIS A 166 5.024 14.456 7.031 1.00 34.12 C +ATOM 1330 CG HIS A 166 5.099 13.788 5.692 1.00 34.94 C +ATOM 1331 CD2 HIS A 166 5.776 14.119 4.568 1.00 30.80 C +ATOM 1332 ND1 HIS A 166 4.454 12.601 5.414 1.00 34.91 N +ATOM 1333 CE1 HIS A 166 4.710 12.243 4.167 1.00 32.18 C +ATOM 1334 NE2 HIS A 166 5.511 13.147 3.632 1.00 39.83 N +ATOM 1335 N LYS A 167 1.850 14.735 7.809 1.00 35.95 N +ATOM 1336 CA LYS A 167 0.481 14.275 7.605 1.00 37.35 C +ATOM 1337 C LYS A 167 -0.418 15.401 7.100 1.00 46.52 C +ATOM 1338 O LYS A 167 -1.493 15.150 6.547 1.00 51.89 O +ATOM 1339 CB LYS A 167 -0.092 13.698 8.898 1.00 35.94 C +ATOM 1340 CG LYS A 167 0.478 12.352 9.317 1.00 33.03 C +ATOM 1341 CD LYS A 167 -0.153 11.935 10.636 1.00 40.63 C +ATOM 1342 CE LYS A 167 0.246 10.536 11.051 1.00 34.85 C +ATOM 1343 NZ LYS A 167 -0.348 10.202 12.384 1.00 24.17 N1+ +TER +END diff --git a/examples/kras_ref_ligand.sdf b/examples/kras_ref_ligand.sdf new file mode 100644 index 0000000000000000000000000000000000000000..7cdf6de30228481540e08fe84bb159c34f0cdff8 --- /dev/null +++ b/examples/kras_ref_ligand.sdf @@ -0,0 +1,74 @@ +8AZR + PyMOL2.5 3D 0 + + 32 36 0 0 0 0 0 0 0 0999 V2000 + 15.7084 1.6569 4.9428 C 0 0 0 0 0 0 0 0 0 0 0 0 + 16.2939 1.9182 6.3219 C 0 0 0 0 0 0 0 0 0 0 0 0 + 17.7757 1.5677 6.3468 C 0 0 0 0 0 0 0 0 0 0 0 0 + 18.0388 0.0580 6.1328 C 0 0 0 0 0 0 0 0 0 0 0 0 + 16.1458 0.3026 4.4709 C 0 0 0 0 0 0 0 0 0 0 0 0 + 17.1748 -0.4207 4.9854 C 0 0 0 0 0 0 0 0 0 0 0 0 + 17.2894 -1.6945 4.3617 C 0 0 0 0 0 0 0 0 0 0 0 0 + 16.3332 -1.9132 3.3763 C 0 0 0 0 0 0 0 0 0 0 0 0 + 15.2948 -0.5437 3.2188 S 0 0 0 0 0 0 0 0 0 0 0 0 + 17.6856 -0.7371 7.4005 C 0 0 0 0 0 0 0 0 0 0 0 0 + 19.5008 -0.1084 5.7694 C 0 0 0 0 0 0 0 0 0 0 0 0 + 19.9420 0.4778 4.6523 O 0 0 0 0 0 0 0 0 0 0 0 0 + 21.3366 0.1893 4.6052 N 0 0 0 0 0 0 0 0 0 0 0 0 + 21.5306 -0.5212 5.6843 C 0 0 0 0 0 0 0 0 0 0 0 0 + 20.3929 -0.7319 6.4483 N 0 0 0 0 0 0 0 0 0 0 0 0 + 16.1651 -3.0052 2.6033 N 0 0 0 0 0 0 0 0 0 0 0 0 + 22.8349 -1.0932 6.0768 C 0 0 0 0 0 0 0 0 0 0 0 0 + 23.9207 -0.6312 5.4365 N 0 0 0 0 0 0 0 0 0 0 0 0 + 25.1129 -1.1528 5.7755 C 0 0 0 0 0 0 0 0 0 0 0 0 + 25.2639 -2.1387 6.7500 C 0 0 0 0 0 0 0 0 0 0 0 0 + 24.1280 -2.5941 7.3940 C 0 0 0 0 0 0 0 0 0 0 0 0 + 22.8945 -2.0709 7.0591 C 0 0 0 0 0 0 0 0 0 0 0 0 + 18.2816 -2.6789 4.6625 C 0 0 0 0 0 0 0 0 0 0 0 0 + 19.0589 -3.4973 4.8688 N 0 0 0 0 0 0 0 0 0 0 0 0 + 26.1982 -0.6750 5.0820 N 0 0 0 0 0 0 0 0 0 0 0 0 + 26.0358 0.4071 4.0954 C 0 0 0 0 0 0 0 0 
0 0 0 0 + 26.8978 0.1468 2.8491 C 0 0 0 0 0 0 0 0 0 0 0 0 + 28.2989 -0.1678 3.2648 N 0 0 0 0 0 0 0 0 0 0 0 0 + 28.3171 -1.4142 4.0851 C 0 0 0 0 0 0 0 0 0 0 0 0 + 27.5312 -1.2091 5.3777 C 0 0 0 0 0 0 0 0 0 0 0 0 + 29.1988 -0.2741 2.0804 C 0 0 0 0 0 0 0 0 0 0 0 0 + 26.3415 1.7618 4.7132 C 0 0 0 0 0 0 0 0 0 0 0 0 + 1 2 1 0 0 0 0 + 2 3 1 0 0 0 0 + 3 4 1 0 0 0 0 + 4 6 1 0 0 0 0 + 4 10 1 0 0 0 0 + 4 11 1 0 0 0 0 + 1 5 1 0 0 0 0 + 5 6 4 0 0 0 0 + 5 9 4 0 0 0 0 + 6 7 4 0 0 0 0 + 7 8 4 0 0 0 0 + 7 23 1 0 0 0 0 + 8 9 4 0 0 0 0 + 8 16 1 0 0 0 0 + 11 12 4 0 0 0 0 + 11 15 4 0 0 0 0 + 12 13 4 0 0 0 0 + 13 14 4 0 0 0 0 + 14 15 4 0 0 0 0 + 14 17 1 0 0 0 0 + 17 18 4 0 0 0 0 + 17 22 4 0 0 0 0 + 18 19 4 0 0 0 0 + 19 25 1 0 0 0 0 + 19 20 4 0 0 0 0 + 20 21 4 0 0 0 0 + 21 22 4 0 0 0 0 + 23 24 3 0 0 0 0 + 25 26 1 0 0 0 0 + 26 27 1 0 0 0 0 + 26 32 1 0 0 0 0 + 27 28 1 0 0 0 0 + 28 29 1 0 0 0 0 + 29 30 1 0 0 0 0 + 25 30 1 0 0 0 0 + 28 31 1 0 0 0 0 +M END +$$$$ diff --git a/scripts/python/evaluate_baselines.py b/scripts/python/evaluate_baselines.py new file mode 100644 index 0000000000000000000000000000000000000000..41029139d2952b109073b0ac7fb27abb44388052 --- /dev/null +++ b/scripts/python/evaluate_baselines.py @@ -0,0 +1,53 @@ +import argparse +import pickle +import sys +from pathlib import Path + +basedir = Path(__file__).resolve().parent.parent.parent +sys.path.append(str(basedir)) + +from src.sbdd_metrics.evaluation import compute_all_metrics_drugflow + +if __name__ == '__main__': + p = argparse.ArgumentParser() + p.add_argument('--in_dir', type=Path, required=True, help='Directory with samples') + p.add_argument('--out_dir', type=str, required=True, help='Output directory') + p.add_argument('--reference_smiles', type=str, default=None, help='Path to the .npy file with reference SMILES (optional)') + p.add_argument('--gnina', type=str, default=None, help='Path to the gnina binary file (optional)') + p.add_argument('--reduce', type=str, default=None, help='Path to the reduce binary file (optional)') + p.add_argument('--n_samples', type=int, default=None, help='Top-N samples to evaluate (optional)') + p.add_argument('--exclude', type=str, nargs='+', default=[], help='Evaluator IDs to exclude') + p.add_argument('--job_id', type=int, default=0, help='Job ID') + p.add_argument('--n_jobs', type=int, default=1, help='Number of jobs') + args = p.parse_args() + + Path(args.out_dir).mkdir(exist_ok=True, parents=True) + if args.job_id == 0 and args.n_jobs == 1: + out_detailed_table = Path(args.out_dir, 'metrics_detailed.csv') + out_aggregated_table = Path(args.out_dir, 'metrics_aggregated.csv') + out_distributions_file = Path(args.out_dir, 'metrics_data.pkl') + else: + out_detailed_table = Path(args.out_dir, f'metrics_detailed_{args.job_id}.csv') + out_aggregated_table = Path(args.out_dir, f'metrics_aggregated_{args.job_id}.csv') + out_distributions_file = Path(args.out_dir, f'metrics_data_{args.job_id}.pkl') + + if out_detailed_table.exists() and out_aggregated_table.exists() and out_distributions_file.exists(): + print('Data already exists. 
Terminating') + sys.exit(0) + + print(f'Evaluating: {args.in_dir}') + data, detailed, aggregated = compute_all_metrics_drugflow( + in_dir=args.in_dir, + gnina_path=args.gnina, + reduce_path=args.reduce, + reference_smiles_path=args.reference_smiles, + n_samples=args.n_samples, + exclude_evaluators=args.exclude, + job_id=args.job_id, + n_jobs=args.n_jobs, + ) + + detailed.to_csv(out_detailed_table, index=False) + aggregated.to_csv(out_aggregated_table, index=False) + with open(Path(out_distributions_file), 'wb') as f: + pickle.dump(data, f) \ No newline at end of file diff --git a/scripts/python/postprocess_metrics.py b/scripts/python/postprocess_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..9415bd6f0dbc1d98f51120c8d0073d01570bbf28 --- /dev/null +++ b/scripts/python/postprocess_metrics.py @@ -0,0 +1,271 @@ +import argparse +import os +import pickle +import sys +from collections import Counter, defaultdict +from pathlib import Path + +import numpy as np +import pandas as pd +from rdkit import Chem +from scipy.stats import wasserstein_distance +from scipy.spatial.distance import jensenshannon +from tqdm import tqdm + +basedir = Path(__file__).resolve().parent.parent.parent +sys.path.append(str(basedir)) + +from src.data.data_utils import atom_encoder, bond_encoder, encode_atom +from src.sbdd_metrics.evaluation import VALIDITY_METRIC_NAME, aggregated_metrics, collection_metrics, get_data_type +from src.sbdd_metrics.metrics import FullEvaluator + + +DATA_TYPES = data_types = FullEvaluator().dtypes + +MEDCHEM_PROPS = [ + 'medchem.qed', + 'medchem.sa', + 'medchem.logp', + 'medchem.lipinski', + 'medchem.size', + 'medchem.n_rotatable_bonds', + 'energy.energy', +] + +DOCKING_PROPS = [ + 'gnina.vina_score', + 'gnina.gnina_score', + 'gnina.vina_efficiency', + 'gnina.gnina_efficiency', +] + +RELEVANT_INTERACTIONS = [ + 'interactions.HBAcceptor', + 'interactions.HBDonor', + 'interactions.HB', + 'interactions.PiStacking', + 'interactions.Hydrophobic', + # + 'interactions.HBAcceptor.normalized', + 'interactions.HBDonor.normalized', + 'interactions.HB.normalized', + 'interactions.PiStacking.normalized', + 'interactions.Hydrophobic.normalized' +] + + +def compute_discrete_distributions(smiles, name): + atom_counter = Counter() + bond_counter = Counter() + + for smi in tqdm(smiles, desc=name): + mol = Chem.MolFromSmiles(smi) + mol = Chem.RemoveAllHs(mol, sanitize=False) + for atom in mol.GetAtoms(): + try: + encoded_atom = encode_atom(atom, atom_encoder=atom_encoder) + except KeyError: + continue + atom_counter[encoded_atom] += 1 + for bond in mol.GetBonds(): + bond_counter[bond_encoder[str(bond.GetBondType())]] += 1 + + atom_distribution = np.zeros(len(atom_encoder)) + bond_distribution = np.zeros(len(bond_encoder)) + + for k, v in atom_counter.items(): + atom_distribution[k] = v + for k, v in bond_counter.items(): + bond_distribution[k] = v + + atom_distribution = atom_distribution / atom_distribution.sum() + bond_distribution = bond_distribution / bond_distribution.sum() + + return atom_distribution, bond_distribution + + +def flatten_distribution(data, name, table): + aux = ['sample', 'sdf_file', 'pdb_file'] + method_distributions = defaultdict(list) + + sdf2sample2size = defaultdict(dict) + for _, row in table.iterrows(): + sdf2sample2size[row['sdf_file']][int(row['sample'])] = row['medchem.size'] + + for item in tqdm(data, desc=name): + if item['medchem.valid'] is not True: + continue + + if 'interactions.HBAcceptor' in item and 'interactions.HBDonor' in item: + 
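# Acceptor and donor contacts are summed into one combined hydrogen-bond count. +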
item['interactions.HB'] = item['interactions.HBAcceptor'] + item['interactions.HBDonor'] + + new_entries = {} + for key, value in item.items(): + if key.startswith('interactions'): + size = sdf2sample2size.get(item['sdf_file'], dict()).get(int(item['sample'])) + if size is not None: + new_entries[key + '.normalized'] = value / size + item.update(new_entries) + + for key, value in item.items(): + if value is None: + continue + if key in aux: + continue + if key == 'energy.energy' and abs(value) > 1000: + continue + + if get_data_type(key, DATA_TYPES, default=type(value)) == list: + method_distributions[key] += value + else: + method_distributions[key].append(value) + + return method_distributions + + +def prepare_baseline_data(root_path, baseline_name): + metrics_detailed = pd.read_csv(f'{root_path}/metrics_detailed.csv') + metrics_detailed = metrics_detailed[metrics_detailed['medchem.valid']] + distributions = pickle.load(open(f'{root_path}/metrics_data.pkl', 'rb')) + distributions = flatten_distribution(distributions, name=baseline_name, table=metrics_detailed) + distributions['energy.energy'] = [v for v in distributions['energy.energy'] if -1000 <= v <= 1000] + for prop in MEDCHEM_PROPS + DOCKING_PROPS: + distributions[prop] = metrics_detailed[prop].dropna().values.tolist() + + smiles = metrics_detailed['representation.smiles'] + atom_distribution, bond_distribution = compute_discrete_distributions(smiles, name=baseline_name) + discrete_distributions = { + 'atom_types': atom_distribution, + 'bond_types': bond_distribution, + } + + return distributions, discrete_distributions + + +if __name__ == '__main__': + p = argparse.ArgumentParser() + p.add_argument('--in_dir', type=Path, required=True, help='Directory with samples') + p.add_argument('--out_dir', type=str, required=True, help='Output directory') + p.add_argument('--n_samples', type=int, required=False, default=None, help='N samples per target') + p.add_argument('--reference_smiles', type=str, default=None, help='Path to the .npy file with reference SMILES (optional)') + p.add_argument('--crossdocked_dir', type=str, required=False, default=None, help='Crossdocked data dir for computing distances between distributions') + args = p.parse_args() + + Path(args.out_dir).mkdir(parents=True, exist_ok=True) + + print('Combining data') + data = [] + for file_path in tqdm(Path(args.in_dir).glob('metrics_data_*.pkl')): + with open(file_path, 'rb') as f: + d = pickle.load(f) + if args.n_samples is not None: + d = d[:args.n_samples] + data += d + with open(Path(args.out_dir, 'metrics_data.pkl'), 'wb') as f: + pickle.dump(data, f) + + print('Combining detailed metrics') + tables = [] + for file_path in tqdm(Path(args.in_dir).glob('metrics_detailed_*.csv')): + table = pd.read_csv(file_path) + if args.n_samples is not None: + table = table.head(args.n_samples) + tables.append(table) + + table_detailed = pd.concat(tables) + table_detailed.to_csv(Path(args.out_dir, 'metrics_detailed.csv'), index=False) + + print('Computing aggregated metrics') + evaluator = FullEvaluator(gnina='gnina', reduce='reduce') + table_aggregated = aggregated_metrics( + table_detailed, + data_types=evaluator.dtypes, + validity_metric_name=VALIDITY_METRIC_NAME + ) + + if args.reference_smiles is not None: + reference_smiles = np.load(args.reference_smiles) + col_metrics = collection_metrics( + table=table_detailed, + reference_smiles=reference_smiles, + validity_metric_name=VALIDITY_METRIC_NAME, + exclude_evaluators=[], + ) + table_aggregated = pd.concat([table_aggregated, 
col_metrics]) + + table_aggregated.to_csv(Path(args.out_dir, 'metrics_aggregated.csv'), index=False) + + # Computing distributions + if args.crossdocked_dir is not None: + + # Loading training data distributions + crossdocked_distributions = None + crossdocked_discrete_distributions = None + precomputed_distr_path = f'{args.crossdocked_dir}/crossdocked_distributions.pkl' + precomputed_discrete_distr_path = f'{args.crossdocked_dir}/crossdocked_discrete_distributions.pkl' + if os.path.exists(precomputed_distr_path) and os.path.exists(precomputed_discrete_distr_path): + # Use precomputed distributions in case they exist + with open(precomputed_distr_path, 'rb') as f: + crossdocked_distributions = pickle.load(f) + with open(precomputed_discrete_distr_path, 'rb') as f: + crossdocked_discrete_distributions = pickle.load(f) + else: + assert os.path.exists(f'{args.crossdocked_dir}/metrics_detailed.csv') + assert os.path.exists(f'{args.crossdocked_dir}/metrics_data.pkl') + crossdocked_distributions, crossdocked_discrete_distributions = prepare_baseline_data( + root_path=args.crossdocked_dir, + baseline_name='crossdocked' + ) + # Save precomputed distributions for faster next runs + with open(precomputed_distr_path, 'wb') as f: + pickle.dump(crossdocked_distributions, f) + with open(precomputed_discrete_distr_path, 'wb') as f: + pickle.dump(crossdocked_discrete_distributions, f) + + # Selecting the top-5 most frequent bond types (two-atom geometry keys) and bond angles (three-atom geometry keys) + bonds = sorted([ + (k, len(v)) for k, v in crossdocked_distributions.items() + if k.startswith('geometry.') and sum(s.isalpha() for s in k.split('.')[1]) == 2 + ], key=lambda t: t[1], reverse=True)[:5] + top_5_bonds = [t[0] for t in bonds] + + angles = sorted([ + (k, len(v)) for k, v in crossdocked_distributions.items() + if k.startswith('geometry.') and sum(s.isalpha() for s in k.split('.')[1]) == 3 + ], key=lambda t: t[1], reverse=True)[:5] + top_5_angles = [t[0] for t in angles] + + # Loading distributions of samples + distributions, discrete_distributions = prepare_baseline_data(args.out_dir, 'samples') + + # Computing distances between distributions + distances = {'method': 'method',} + relevant_columns = MEDCHEM_PROPS + DOCKING_PROPS + RELEVANT_INTERACTIONS + top_5_bonds + top_5_angles + for metric in distributions.keys(): + if metric not in relevant_columns: + continue + + ref = crossdocked_distributions.get(metric) + cur = [x for x in distributions.get(metric) if not pd.isna(x)] + + if ref is not None and cur is not None and len(cur) > 0: + try: + distance = wasserstein_distance(ref, cur) + except Exception as e: + print(f'Failed to compute Wasserstein distance for {metric}: {e}') + continue + distances[f'WD.{metric}'] = distance + + for metric in crossdocked_discrete_distributions.keys(): + ref = crossdocked_discrete_distributions.get(metric) + cur = discrete_distributions.get(metric) + if ref is not None and cur is not None: + distance = jensenshannon(p=ref, q=cur) + distances[f'JS.{metric}'] = distance + + dist_table = pd.DataFrame([distances]) + dist_table.to_csv(Path(args.out_dir, 'metrics_distances.csv'), index=False) \ No newline at end of file diff --git a/src/analysis/SA_Score/README.md b/src/analysis/SA_Score/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cbb4194781f62ea0b31895618e8e3f2d87f3505c --- /dev/null +++ b/src/analysis/SA_Score/README.md @@ -0,0 +1 @@ +Files taken from: 
https://github.com/rdkit/rdkit/tree/master/Contrib/SA_Score \ No newline at end of file diff --git a/src/analysis/SA_Score/fpscores.pkl.gz b/src/analysis/SA_Score/fpscores.pkl.gz new file mode 100644 index 0000000000000000000000000000000000000000..aa6f88c9c3fa56161b7df08e74ea6824f3071d08 --- /dev/null +++ b/src/analysis/SA_Score/fpscores.pkl.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:10dcef9340c873e7b987924461b0af5365eb8dd96be607203debe8ddf80c1e73 +size 3848394 diff --git a/src/analysis/SA_Score/sascorer.py b/src/analysis/SA_Score/sascorer.py new file mode 100644 index 0000000000000000000000000000000000000000..862d191032bb0b366260f9d6e306fb0ddf98ccf3 --- /dev/null +++ b/src/analysis/SA_Score/sascorer.py @@ -0,0 +1,173 @@ +# +# calculation of synthetic accessibility score as described in: +# +# Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions +# Peter Ertl and Ansgar Schuffenhauer +# Journal of Cheminformatics 1:8 (2009) +# http://www.jcheminf.com/content/1/1/8 +# +# several small modifications to the original paper are included +# particularly a slightly different formula for the macrocyclic penalty +# and also taking molecule symmetry into account (fingerprint density) +# +# for a set of 10k diverse molecules the agreement between the original method +# as implemented in PipelinePilot and this implementation is r2 = 0.97 +# +# peter ertl & greg landrum, september 2013 +# + + +from rdkit import Chem +from rdkit.Chem import rdMolDescriptors +import pickle + +import math +from collections import defaultdict + +import os.path as op + +_fscores = None + + +def readFragmentScores(name='fpscores'): + import gzip + global _fscores + # generate the full path filename: + if name == "fpscores": + name = op.join(op.dirname(__file__), name) + data = pickle.load(gzip.open('%s.pkl.gz' % name)) + outDict = {} + for i in data: + for j in range(1, len(i)): + outDict[i[j]] = float(i[0]) + _fscores = outDict + + +def numBridgeheadsAndSpiro(mol, ri=None): + nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) + nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) + return nBridgehead, nSpiro + + +def calculateScore(m): + if _fscores is None: + readFragmentScores() + + # fragment score + fp = rdMolDescriptors.GetMorganFingerprint(m, + 2) # <- 2 is the *radius* of the circular fingerprint + fps = fp.GetNonzeroElements() + score1 = 0. + nf = 0 + for bitId, v in fps.items(): + nf += v + sfp = bitId + score1 += _fscores.get(sfp, -4) * v + score1 /= nf + + # features score + nAtoms = m.GetNumAtoms() + nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) + ri = m.GetRingInfo() + nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) + nMacrocycles = 0 + for x in ri.AtomRings(): + if len(x) > 8: + nMacrocycles += 1 + + sizePenalty = nAtoms**1.005 - nAtoms + stereoPenalty = math.log10(nChiralCenters + 1) + spiroPenalty = math.log10(nSpiro + 1) + bridgePenalty = math.log10(nBridgeheads + 1) + macrocyclePenalty = 0. + # --------------------------------------- + # This differs from the paper, which defines: + # macrocyclePenalty = math.log10(nMacrocycles+1) + # This form generates better results when 2 or more macrocycles are present + if nMacrocycles > 0: + macrocyclePenalty = math.log10(2) + + score2 = 0. 
- sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty + + # correction for the fingerprint density + # not in the original publication, added in version 1.1 + # to make highly symmetrical molecules easier to synthesize + score3 = 0. + if nAtoms > len(fps): + score3 = math.log(float(nAtoms) / len(fps)) * .5 + + sascore = score1 + score2 + score3 + + # need to transform "raw" value into scale between 1 and 10 + min = -4.0 + max = 2.5 + sascore = 11. - (sascore - min + 1) / (max - min) * 9. + # smooth the 10-end + if sascore > 8.: + sascore = 8. + math.log(sascore + 1. - 9.) + if sascore > 10.: + sascore = 10.0 + elif sascore < 1.: + sascore = 1.0 + + return sascore + + +def processMols(mols): + print('smiles\tName\tsa_score') + for i, m in enumerate(mols): + if m is None: + continue + + s = calculateScore(m) + + smiles = Chem.MolToSmiles(m) + print(smiles + "\t" + m.GetProp('_Name') + "\t%.3f" % s) + + +if __name__ == '__main__': + import sys + import time + + t1 = time.time() + readFragmentScores("fpscores") + t2 = time.time() + + suppl = Chem.SmilesMolSupplier(sys.argv[1]) + t3 = time.time() + processMols(suppl) + t4 = time.time() + + print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)), + file=sys.stderr) + +# +# Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. +# * Neither the name of Novartis Institutes for BioMedical Research Inc. +# nor the names of its contributors may be used to endorse or promote +# products derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
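+# +# Minimal usage sketch (illustrative; assumes RDKit is installed and the fpscores.pkl.gz data file sits next to this module): +# +# from rdkit import Chem +# from src.analysis.SA_Score.sascorer import calculateScore +# mol = Chem.MolFromSmiles('CC(=O)Oc1ccccc1C(=O)O') # aspirin +# print(calculateScore(mol)) # score between 1 (easy to synthesize) and 10 (hard to synthesize)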
+# \ No newline at end of file diff --git a/src/analysis/metrics.py b/src/analysis/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..b42014e06615514380c3e93ec1c6be15d0a6aa9a --- /dev/null +++ b/src/analysis/metrics.py @@ -0,0 +1,544 @@ +import subprocess + +import numpy as np +import tempfile +from pathlib import Path +from tqdm import tqdm +from rdkit import Chem, DataStructs +from rdkit.Chem import AllChem +from rdkit.Chem import Descriptors, Crippen, Lipinski, QED +from rdkit.Chem import AtomKekulizeException, AtomValenceException, \ + KekulizeException, MolSanitizeException +from src.analysis.SA_Score.sascorer import calculateScore +from src.utils import write_sdf_file + +from copy import deepcopy + + +class CategoricalDistribution: + EPS = 1e-10 + + def __init__(self, histogram_dict, mapping): + histogram = np.zeros(len(mapping)) + for k, v in histogram_dict.items(): + histogram[mapping[k]] = v + + # Normalize histogram + self.p = histogram / histogram.sum() + self.mapping = deepcopy(mapping) + + def kl_divergence(self, other_sample): + sample_histogram = np.zeros(len(self.mapping)) + for x in other_sample: + # sample_histogram[self.mapping[x]] += 1 + sample_histogram[x] += 1 + + # Normalize + q = sample_histogram / sample_histogram.sum() + + return -np.sum(self.p * np.log(q / (self.p + self.EPS) + self.EPS)) + + +def check_mol(rdmol): + """ + See also: https://www.rdkit.org/docs/RDKit_Book.html#molecular-sanitization + """ + if rdmol is None: + return 'is_none' + + _rdmol = Chem.Mol(rdmol) + try: + Chem.SanitizeMol(_rdmol) + return 'valid' + except ValueError as e: + assert isinstance(e, MolSanitizeException) + return type(e).__name__ + + +def validity_analysis(rdmol_list): + """ + For explanations, see: https://www.rdkit.org/docs/RDKit_Book.html#molecular-sanitization + """ + + result = { + 'AtomValenceException': 0, # atoms in higher-than-allowed valence states + 'AtomKekulizeException': 0, + 'KekulizeException': 0, # ring cannot be kekulized or aromatic bonds found outside of rings + 'other': 0, + 'valid': 0 + } + + for rdmol in rdmol_list: + flag = check_mol(rdmol) + + try: + result[flag] += 1 + except KeyError: + result['other'] += 1 + + assert sum(result.values()) == len(rdmol_list) + + return result + + +class MoleculeValidity: + def __init__(self, connectivity_thresh=1.0): + self.connectivity_thresh = connectivity_thresh + + def compute_validity(self, generated): + """ generated: list of RDKit molecules. """ + if len(generated) < 1: + return [], 0.0 + + # Return copies of the valid molecules + valid = [Chem.Mol(mol) for mol in generated if check_mol(mol) == 'valid'] + return valid, len(valid) / len(generated) + + def compute_connectivity(self, valid): + """ + Consider a molecule connected if its largest fragment contains at + least connectivity_thresh (a fraction between 0 and 1) of all atoms. 
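+ With the default connectivity_thresh=1.0, only single-fragment molecules count as connected.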
+ :param valid: list of valid RDKit molecules + """ + if len(valid) < 1: + return [], 0.0 + + for mol in valid: + Chem.SanitizeMol(mol) # all molecules should be valid + + connected = [] + for mol in valid: + + if mol.GetNumAtoms() < 1: + continue + + try: + mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True) + except MolSanitizeException as e: + print('Error while computing connectivity:', e) + continue + + largest_frag = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms()) + if largest_frag.GetNumAtoms() / mol.GetNumAtoms() >= self.connectivity_thresh: + connected.append(largest_frag) + + return connected, len(connected) / len(valid) + + def __call__(self, rdmols, verbose=False): + """ + :param rdmols: list of RDKit molecules + """ + + results = {} + results['n_total'] = len(rdmols) + + valid, validity = self.compute_validity(rdmols) + results['n_valid'] = len(valid) + results['validity'] = validity + + connected, connectivity = self.compute_connectivity(valid) + results['n_connected'] = len(connected) + results['connectivity'] = connectivity + results['valid_and_connected'] = results['n_connected'] / results['n_total'] + + if verbose: + print(f"Validity over {results['n_total']} molecules: {validity * 100 :.2f}%") + print(f"Connectivity over {results['n_valid']} valid molecules: {connectivity * 100 :.2f}%") + + return results + + +class MolecularMetrics: + def __init__(self, connectivity_thresh=1.0): + self.connectivity_thresh = connectivity_thresh + + @staticmethod + def is_valid(rdmol): + if rdmol.GetNumAtoms() < 1: + return False + + _mol = Chem.Mol(rdmol) + try: + Chem.SanitizeMol(_mol) + except ValueError: + return False + + return True + + def is_connected(self, rdmol): + + if rdmol.GetNumAtoms() < 1: + return False + + mol_frags = Chem.rdmolops.GetMolFrags(rdmol, asMols=True) + + largest_frag = max(mol_frags, default=rdmol, key=lambda m: m.GetNumAtoms()) + if largest_frag.GetNumAtoms() / rdmol.GetNumAtoms() >= self.connectivity_thresh: + return True + else: + return False + + @staticmethod + def calculate_qed(rdmol): + return QED.qed(rdmol) + + @staticmethod + def calculate_sa(rdmol): + sa = calculateScore(rdmol) + return sa + + @staticmethod + def calculate_logp(rdmol): + return Crippen.MolLogP(rdmol) + + @staticmethod + def calculate_lipinski(rdmol): + rule_1 = Descriptors.ExactMolWt(rdmol) < 500 + rule_2 = Lipinski.NumHDonors(rdmol) <= 5 + rule_3 = Lipinski.NumHAcceptors(rdmol) <= 10 + logp = Crippen.MolLogP(rdmol) + rule_4 = -2 <= logp <= 5 + rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(rdmol) <= 10 + return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]]) + + def __call__(self, rdmol): + valid = self.is_valid(rdmol) + + if valid: + Chem.SanitizeMol(rdmol) + + connected = None if not valid else self.is_connected(rdmol) + qed = None if not valid else self.calculate_qed(rdmol) + sa = None if not valid else self.calculate_sa(rdmol) + logp = None if not valid else self.calculate_logp(rdmol) + lipinski = None if not valid else self.calculate_lipinski(rdmol) + + return { + 'valid': valid, + 'connected': connected, + 'qed': qed, + 'sa': sa, + 'logp': logp, + 'lipinski': lipinski + } + + +class Diversity: + @staticmethod + def similarity(fp1, fp2): + return DataStructs.TanimotoSimilarity(fp1, fp2) + + def get_fingerprint(self, mol): + # fp = AllChem.GetMorganFingerprintAsBitVect( + # mol, 2, nBits=2048, useChirality=False) + fp = Chem.RDKFingerprint(mol) + return fp + + def __call__(self, pocket_mols): + + if len(pocket_mols) < 2: + return 0.0
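+ + # Diversity is the average pairwise Tanimoto dissimilarity (1 - similarity) over all N*(N-1)/2 fingerprint pairs in the pocket.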
+        pocket_fps = [self.get_fingerprint(m) for m in pocket_mols]
+
+        div = 0
+        total = 0
+        for i in range(len(pocket_fps)):
+            for j in range(i + 1, len(pocket_fps)):
+                div += 1 - self.similarity(pocket_fps[i], pocket_fps[j])
+                total += 1
+
+        return div / total
+
+
+class MoleculeUniqueness:
+    def __call__(self, smiles_list):
+        """ smiles_list: list of SMILES strings. """
+        if len(smiles_list) < 1:
+            return 0.0
+
+        return len(set(smiles_list)) / len(smiles_list)
+
+
+class MoleculeNovelty:
+    def __init__(self, reference_smiles):
+        """
+        :param reference_smiles: list of SMILES strings
+        """
+        self.reference_smiles = set(reference_smiles)
+
+    def __call__(self, smiles_list):
+        if len(smiles_list) < 1:
+            return 0.0
+
+        novel = [smi for smi in smiles_list if smi not in self.reference_smiles]
+        return len(novel) / len(smiles_list)
+
+
+class MolecularProperties:
+
+    @staticmethod
+    def calculate_qed(rdmol):
+        return QED.qed(rdmol)
+
+    @staticmethod
+    def calculate_sa(rdmol):
+        sa = calculateScore(rdmol)
+        # return round((10 - sa) / 9, 2)  # from pocket2mol
+        return sa
+
+    @staticmethod
+    def calculate_logp(rdmol):
+        return Crippen.MolLogP(rdmol)
+
+    @staticmethod
+    def calculate_lipinski(rdmol):
+        rule_1 = Descriptors.ExactMolWt(rdmol) < 500
+        rule_2 = Lipinski.NumHDonors(rdmol) <= 5
+        rule_3 = Lipinski.NumHAcceptors(rdmol) <= 10
+        # the assignment expression must be parenthesized so that `logp`
+        # holds the logP value itself, not the boolean comparison result
+        rule_4 = ((logp := Crippen.MolLogP(rdmol)) >= -2) & (logp <= 5)
+        rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(rdmol) <= 10
+        return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]])
+
+    @classmethod
+    def calculate_diversity(cls, pocket_mols):
+        if len(pocket_mols) < 2:
+            return 0.0
+
+        div = 0
+        total = 0
+        for i in range(len(pocket_mols)):
+            for j in range(i + 1, len(pocket_mols)):
+                div += 1 - cls.similarity(pocket_mols[i], pocket_mols[j])
+                total += 1
+        return div / total
+
+    @staticmethod
+    def similarity(mol_a, mol_b):
+        # fp1 = AllChem.GetMorganFingerprintAsBitVect(
+        #     mol_a, 2, nBits=2048, useChirality=False)
+        # fp2 = AllChem.GetMorganFingerprintAsBitVect(
+        #     mol_b, 2, nBits=2048, useChirality=False)
+        fp1 = Chem.RDKFingerprint(mol_a)
+        fp2 = Chem.RDKFingerprint(mol_b)
+        return DataStructs.TanimotoSimilarity(fp1, fp2)
+
+    def evaluate_pockets(self, pocket_rdmols, verbose=False):
+        """
+        Run full evaluation
+        Args:
+            pocket_rdmols: list of lists, the inner list contains all RDKit
+                molecules generated for a pocket
+        Returns:
+            QED, SA, LogP, Lipinski (per molecule), and Diversity (per pocket)
+        """
+
+        for pocket in pocket_rdmols:
+            for mol in pocket:
+                Chem.SanitizeMol(mol)  # only evaluate valid molecules
+
+        all_qed = []
+        all_sa = []
+        all_logp = []
+        all_lipinski = []
+        per_pocket_diversity = []
+        for pocket in tqdm(pocket_rdmols):
+            all_qed.append([self.calculate_qed(mol) for mol in pocket])
+            all_sa.append([self.calculate_sa(mol) for mol in pocket])
+            all_logp.append([self.calculate_logp(mol) for mol in pocket])
+            all_lipinski.append([self.calculate_lipinski(mol) for mol in pocket])
+            per_pocket_diversity.append(self.calculate_diversity(pocket))
+
+        qed_flattened = [x for px in all_qed for x in px]
+        sa_flattened = [x for px in all_sa for x in px]
+        logp_flattened = [x for px in all_logp for x in px]
+        lipinski_flattened = [x for px in all_lipinski for x in px]
+
+        if verbose:
+            print(f"{sum([len(p) for p in pocket_rdmols])} molecules from "
+                  f"{len(pocket_rdmols)} pockets evaluated.")
+            print(f"QED: {np.mean(qed_flattened):.3f} +/- {np.std(qed_flattened):.2f}")
+            print(f"SA: {np.mean(sa_flattened):.3f} +/- {np.std(sa_flattened):.2f}")
+            print(f"LogP: {np.mean(logp_flattened):.3f} +/- {np.std(logp_flattened):.2f}")
+            print(f"Lipinski: {np.mean(lipinski_flattened):.3f} +/- {np.std(lipinski_flattened):.2f}")
+            print(f"Diversity: {np.mean(per_pocket_diversity):.3f} +/- {np.std(per_pocket_diversity):.2f}")
+
+        return all_qed, all_sa, all_logp, all_lipinski, per_pocket_diversity
+
+    def __call__(self, rdmols):
+        """
+        Run full evaluation and return mean of each property
+        Args:
+            rdmols: list of RDKit molecules
+        Returns:
+            Dictionary with mean QED, SA, LogP, Lipinski, and Diversity values
+        """
+
+        if len(rdmols) < 1:
+            return {'QED': 0.0, 'SA': 0.0, 'LogP': 0.0, 'Lipinski': 0.0,
+                    'Diversity': 0.0}
+
+        _rdmols = []
+        for mol in rdmols:
+            try:
+                Chem.SanitizeMol(mol)  # only evaluate valid molecules
+                _rdmols.append(mol)
+            except ValueError:
+                print("Tried to analyze invalid molecule")
+        rdmols = _rdmols
+
+        qed = np.mean([self.calculate_qed(mol) for mol in rdmols])
+        sa = np.mean([self.calculate_sa(mol) for mol in rdmols])
+        logp = np.mean([self.calculate_logp(mol) for mol in rdmols])
+        lipinski = np.mean([self.calculate_lipinski(mol) for mol in rdmols])
+        diversity = self.calculate_diversity(rdmols)
+
+        return {'QED': qed, 'SA': sa, 'LogP': logp, 'Lipinski': lipinski,
+                'Diversity': diversity}
+
+
+def compute_gnina_scores(ligands, receptors, gnina):
+    metrics = ['minimizedAffinity', 'minimizedRMSD', 'CNNscore', 'CNNaffinity', 'CNN_VS', 'CNNaffinity_variance']
+    out = {m: [] for m in metrics}
+    with tempfile.TemporaryDirectory() as tmpdir:
+        for ligand, receptor in zip(tqdm(ligands, desc='Docking'), receptors):
+            in_ligand_path = Path(tmpdir, 'in_ligand.sdf')
+            out_ligand_path = Path(tmpdir, 'out_ligand.sdf')
+            receptor_path = Path(tmpdir, 'receptor.pdb')
+            write_sdf_file(in_ligand_path, [ligand], catch_errors=True)
+            Chem.MolToPDBFile(receptor, str(receptor_path))
+            if (
+                    (not in_ligand_path.exists()) or
+                    (not receptor_path.exists()) or
+                    in_ligand_path.read_text() == '' or
+                    receptor_path.read_text() == ''
+            ):
+                continue
+
+            cmd = (
+                f'{gnina} -r {receptor_path} -l {in_ligand_path} '
+                f'--minimize --seed 42 -o {out_ligand_path} --no_gpu 1> /dev/null'
+            )
+            subprocess.run(cmd, shell=True)
+            if not out_ligand_path.exists() or out_ligand_path.read_text() == '':
+                continue
+
+            mol = Chem.SDMolSupplier(str(out_ligand_path), sanitize=False)[0]
+            for metric in metrics:
+                out[metric].append(float(mol.GetProp(metric)))
+
+    for metric in metrics:
+        out[metric] = sum(out[metric]) / len(out[metric]) if len(out[metric]) > 0 else 0
+
+    return out
+
+
+def legacy_clash_score(rdmol1, rdmol2=None, margin=0.75):
+    """
+    Computes a clash score as the number of atoms that have at least one
+    clash divided by the number of atoms in the molecule.
+
+    INTERMOLECULAR CLASH SCORE
+    If rdmol2 is provided, the score is the percentage of atoms in rdmol1
+    that have at least one clash with rdmol2.
+    We define a clash if two atoms are closer than "margin times the sum of
+    their van der Waals radii".
+
+    INTRAMOLECULAR CLASH SCORE
+    If rdmol2 is not provided, the score is the percentage of atoms in rdmol1
+    that have at least one clash with other atoms in rdmol1.
+    In this case, a clash is defined by margin times the atoms' smallest
+    covalent radii (among single, double and triple bond radii). This is done
+    so that this function is applicable even if no connectivity information is
+    available.
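+
+    Returns a float in [0, 1]: the fraction of atoms in rdmol1 that have at
+    least one clash.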
+ """ + # source: https://en.wikipedia.org/wiki/Van_der_Waals_radius + vdw_radii = {'N': 1.55, 'O': 1.52, 'C': 1.70, 'H': 1.10, 'S': 1.80, 'P': 1.80, + 'Se': 1.90, 'K': 2.75, 'Na': 2.27, 'Mg': 1.73, 'Zn': 1.39, 'B': 1.92, + 'Br': 1.85, 'Cl': 1.75, 'I': 1.98, 'F': 1.47} + + # https://en.wikipedia.org/wiki/Covalent_radius#Radii_for_multiple_bonds + covalent_radii = {'H': 0.32, 'C': 0.60, 'N': 0.54, 'O': 0.53, 'F': 0.53, 'B': 0.73, + 'Al': 1.11, 'Si': 1.02, 'P': 0.94, 'S': 0.94, 'Cl': 0.93, 'As': 1.06, + 'Br': 1.09, 'I': 1.25, 'Hg': 1.33, 'Bi': 1.35} + + coord1 = rdmol1.GetConformer().GetPositions() + + if rdmol2 is None: + radii1 = np.array([covalent_radii[a.GetSymbol()] for a in rdmol1.GetAtoms()]) + assert coord1.shape[0] == radii1.shape[0] + + dist = np.sqrt(np.sum((coord1[:, None, :] - coord1[None, :, :]) ** 2, axis=-1)) + np.fill_diagonal(dist, np.inf) + clashes = dist < margin * (radii1[:, None] + radii1[None, :]) + + else: + coord2 = rdmol2.GetConformer().GetPositions() + + radii1 = np.array([vdw_radii[a.GetSymbol()] for a in rdmol1.GetAtoms()]) + assert coord1.shape[0] == radii1.shape[0] + radii2 = np.array([vdw_radii[a.GetSymbol()] for a in rdmol2.GetAtoms()]) + assert coord2.shape[0] == radii2.shape[0] + + dist = np.sqrt(np.sum((coord1[:, None, :] - coord2[None, :, :]) ** 2, axis=-1)) + clashes = dist < margin * (radii1[:, None] + radii2[None, :]) + + clashes = np.any(clashes, axis=1) + return np.mean(clashes) + + +def clash_score(rdmol1, rdmol2=None, margin=0.75, ignore={'H'}): + """ + Computes a clash score as the number of atoms that have at least one + clash divided by the number of atoms in the molecule. + + INTERMOLECULAR CLASH SCORE + If rdmol2 is provided, the score is the percentage of atoms in rdmol1 + that have at least one clash with rdmol2. + We define a clash if two atoms are closer than "margin times the sum of + their van der Waals radii". + + INTRAMOLECULAR CLASH SCORE + If rdmol2 is not provided, the score is the percentage of atoms in rdmol1 + that have at least one clash with other atoms in rdmol1. + In this case, a clash is defined by margin times the atoms' smallest + covalent radii (among single, double and triple bond radii). This is done + so that this function is applicable even if no connectivity information is + available. 
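+
+    Atom radii are taken from RDKit's periodic table, and atoms whose element
+    symbol is in `ignore` (hydrogen by default) are excluded from the score.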
+ """ + + intramolecular = rdmol2 is None + + _periodic_table = AllChem.GetPeriodicTable() + + def _coord_and_radii(rdmol): + coord = rdmol.GetConformer().GetPositions() + radii = np.array([_get_radius(a.GetSymbol()) for a in rdmol.GetAtoms()]) + + mask = np.array([a.GetSymbol() not in ignore for a in rdmol.GetAtoms()]) + coord = coord[mask] + radii = radii[mask] + + assert coord.shape[0] == radii.shape[0] + return coord, radii + + # INTRAMOLECULAR CLASH SCORE + if intramolecular: + rdmol2 = rdmol1 + _get_radius = _periodic_table.GetRcovalent # covalent radii + + # INTERMOLECULAR CLASH SCORE + else: + _get_radius = _periodic_table.GetRvdw # vdW radii + + coord1, radii1 = _coord_and_radii(rdmol1) + coord2, radii2 = _coord_and_radii(rdmol2) + + dist = np.sqrt(np.sum((coord1[:, None, :] - coord2[None, :, :]) ** 2, axis=-1)) + if intramolecular: + np.fill_diagonal(dist, np.inf) + + clashes = dist < margin * (radii1[:, None] + radii2[None, :]) + clashes = np.any(clashes, axis=1) + return np.mean(clashes) diff --git a/src/analysis/visualization_utils.py b/src/analysis/visualization_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cfac11b2939eed0b3357a43c13bb382f9e6f85ce --- /dev/null +++ b/src/analysis/visualization_utils.py @@ -0,0 +1,192 @@ +import warnings + +import torch +from rdkit import Chem +from rdkit.Chem import Draw, AllChem +from rdkit.Chem import SanitizeFlags +from src.analysis.metrics import check_mol +from src import utils +from src.data.molecule_builder import build_molecule +from src.data.misc import protein_letters_1to3 + + +# def pocket_to_rdkit(pocket, pocket_representation, atom_encoder=None, +# atom_decoder=None, aa_decoder=None, residue_decoder=None, +# aa_atom_index=None): +# +# rdpockets = [] +# for i in torch.unique(pocket['mask']): +# +# node_coord = pocket['x'][pocket['mask'] == i] +# h = pocket['one_hot'][pocket['mask'] == i] +# +# if pocket_representation == 'side_chain_bead': +# coord = node_coord +# +# node_types = [residue_decoder[b] for b in h[:, -len(residue_decoder):].argmax(-1)] +# atom_types = ['C' if r == 'CA' else 'F' for r in node_types] +# +# elif pocket_representation == 'CA+': +# aa_types = [aa_decoder[b] for b in h.argmax(-1)] +# side_chain_vec = pocket['v'][pocket['mask'] == i] +# +# coord = [] +# atom_types = [] +# for xyz, aa, vec in zip(node_coord, aa_types, side_chain_vec): +# # C_alpha +# coord.append(xyz) +# atom_types.append('C') +# +# # all other atoms +# for atom_name, idx in aa_atom_index[aa].items(): +# coord.append(xyz + vec[idx]) +# atom_types.append(atom_name[0]) +# +# coord = torch.stack(coord, dim=0) +# +# else: +# raise NotImplementedError(f"{pocket_representation} residue representation not supported") +# +# atom_types = torch.tensor([atom_encoder[a] for a in atom_types]) +# rdpockets.append(build_molecule(coord, atom_types, atom_decoder=atom_decoder)) +# +# return rdpockets +def pocket_to_rdkit(pocket, pocket_representation, atom_encoder=None, + atom_decoder=None, aa_decoder=None, residue_decoder=None, + aa_atom_index=None): + + rdpockets = [] + for i in torch.unique(pocket['mask']): + + node_coord = pocket['x'][pocket['mask'] == i] + h = pocket['one_hot'][pocket['mask'] == i] + atom_mask = pocket['atom_mask'][pocket['mask'] == i] + + pdb_infos = [] + + if pocket_representation == 'side_chain_bead': + coord = node_coord + + node_types = [residue_decoder[b] for b in h[:, -len(residue_decoder):].argmax(-1)] + atom_types = ['C' if r == 'CA' else 'F' for r in node_types] + + elif pocket_representation == 
'CA+': + aa_types = [aa_decoder[b] for b in h.argmax(-1)] + side_chain_vec = pocket['v'][pocket['mask'] == i] + + coord = [] + atom_types = [] + for resi, (xyz, aa, vec, am) in enumerate(zip(node_coord, aa_types, side_chain_vec, atom_mask)): + + # CA not treated differently with updated atom dictionary + for atom_name, idx in aa_atom_index[aa].items(): + + if ~am[idx]: + warnings.warn(f"Missing atom {atom_name} in {aa}:{resi}") + continue + + coord.append(xyz + vec[idx]) + atom_types.append(atom_name[0]) + + info = Chem.AtomPDBResidueInfo() + # info.SetChainId('A') + info.SetResidueName(protein_letters_1to3[aa]) + info.SetResidueNumber(resi + 1) + info.SetOccupancy(1.0) + info.SetTempFactor(0.0) + info.SetName(f' {atom_name:<3}') + pdb_infos.append(info) + + coord = torch.stack(coord, dim=0) + + else: + raise NotImplementedError(f"{pocket_representation} residue representation not supported") + + atom_types = torch.tensor([atom_encoder[a] for a in atom_types]) + rdmol = build_molecule(coord, atom_types, atom_decoder=atom_decoder) + + if len(pdb_infos) == len(rdmol.GetAtoms()): + for a, info in zip(rdmol.GetAtoms(), pdb_infos): + a.SetPDBResidueInfo(info) + + rdpockets.append(rdmol) + + return rdpockets + + +def mols_to_pdbfile(rdmols, filename, flavor=0): + pdb_str = "" + for i, mol in enumerate(rdmols): + pdb_str += f"MODEL{i + 1:>9}\n" + block = Chem.MolToPDBBlock(mol, flavor=flavor) + block = "\n".join(block.split("\n")[:-2]) # remove END + pdb_str += block + "\n" + pdb_str += f"ENDMDL\n" + pdb_str += f"END\n" + + with open(filename, 'w') as f: + f.write(pdb_str) + + return pdb_str + + +def mol_as_pdb(rdmol, filename=None, bfactor=None): + + _rdmol = Chem.Mol(rdmol) # copy + for a in _rdmol.GetAtoms(): + a.SetIsAromatic(False) + for b in _rdmol.GetBonds(): + b.SetIsAromatic(False) + + if bfactor is not None: + for a in _rdmol.GetAtoms(): + val = a.GetPropsAsDict()[bfactor] + + info = Chem.AtomPDBResidueInfo() + info.SetResidueName('UNL') + info.SetResidueNumber(1) + info.SetName(f' {a.GetSymbol():<3}') + info.SetIsHeteroAtom(True) + info.SetOccupancy(1.0) + info.SetTempFactor(val) + a.SetPDBResidueInfo(info) + + pdb_str = Chem.MolToPDBBlock(_rdmol) + + if filename is not None: + with open(filename, 'w') as f: + f.write(pdb_str) + + return pdb_str + + +def draw_grid(molecules, mols_per_row=5, fig_size=(200, 200), + label=check_mol, + highlight_atom=lambda atom: False, + highlight_bond=lambda bond: False): + + draw_mols = [] + marked_atoms = [] + marked_bonds = [] + for mol in molecules: + draw_mol = Chem.Mol(mol) # copy + Chem.SanitizeMol(draw_mol, sanitizeOps=SanitizeFlags.SANITIZE_NONE) + AllChem.Compute2DCoords(draw_mol) + draw_mol = Draw.rdMolDraw2D.PrepareMolForDrawing(draw_mol, + kekulize=False) + draw_mols.append(draw_mol) + marked_atoms.append([a.GetIdx() for a in draw_mol.GetAtoms() if highlight_atom(a)]) + marked_bonds.append([b.GetIdx() for b in draw_mol.GetBonds() if highlight_bond(b)]) + + drawOptions = Draw.rdMolDraw2D.MolDrawOptions() + drawOptions.prepareMolsBeforeDrawing = False + drawOptions.highlightBondWidthMultiplier = 20 + + return Draw.MolsToGridImage(draw_mols, + molsPerRow=mols_per_row, + subImgSize=fig_size, + drawOptions=drawOptions, + highlightAtomLists=marked_atoms, + highlightBondLists=marked_bonds, + legends=[f'[{i}] {label(mol)}' for + i, mol in enumerate(draw_mols)]) diff --git a/src/constants.py b/src/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..ac5b810da5cfb0a24ddd7991fa88d172f0f3ed55 --- /dev/null +++ 
b/src/constants.py @@ -0,0 +1,256 @@ +import os +from rdkit import Chem +import torch +import numpy as np + +# ------------------------------------------------------------------------------ +# Computational +# ------------------------------------------------------------------------------ +FLOAT_TYPE = torch.float32 +INT_TYPE = torch.int64 + + +# ------------------------------------------------------------------------------ +# Type encoding/decoding +# ------------------------------------------------------------------------------ + +atom_dict = os.environ.get('ATOM_DICT') +if atom_dict == 'simple': + atom_encoder = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'B': 4, 'Br': 5, 'Cl': 6, 'P': 7, 'I': 8, 'F': 9, 'NOATOM': 10} + atom_decoder = ['C', 'N', 'O', 'S', 'B', 'Br', 'Cl', 'P', 'I', 'F', 'NOATOM'] + +else: + atom_encoder = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'B': 4, 'Br': 5, 'Cl': 6, 'P': 7, 'I': 8, 'F': 9, 'NH': 10, 'N+': 11, 'O-': 12, 'NOATOM': 13} + atom_decoder = ['C', 'N', 'O', 'S', 'B', 'Br', 'Cl', 'P', 'I', 'F', 'NH', 'N+', 'O-', 'NOATOM'] + +bond_encoder = {"NOBOND": 0, "SINGLE": 1, "DOUBLE": 2, "TRIPLE": 3, 'AROMATIC': 4} +bond_decoder = [None, Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC] + +aa_encoder = {'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4, 'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9, 'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14, 'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19} +aa_decoder = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y'] + +residue_encoder = {'CA': 0, 'SS': 1} +residue_decoder = ['CA', 'SS'] + +residue_bond_encoder = {'CA-CA': 0, 'CA-SS': 1, 'NOBOND': 2} +residue_bond_decoder = ['CA-CA', 'CA-SS', None] + +# aa_atom_index = { +# 'A': {'N': 0, 'C': 1, 'O': 2, 'CB': 3}, +# 'C': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'SG': 4}, +# 'D': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'OD1': 5, 'OD2': 6}, +# 'E': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD': 5, 'OE1': 6, 'OE2': 7}, +# 'F': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD1': 5, 'CD2': 6, 'CE1': 7, 'CE2': 8, 'CZ': 9}, +# 'G': {'N': 0, 'C': 1, 'O': 2}, +# 'H': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'ND1': 5, 'CD2': 6, 'CE1': 7, 'NE2': 8}, +# 'I': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG1': 4, 'CG2': 5, 'CD1': 6}, +# 'K': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD': 5, 'CE': 6, 'NZ': 7}, +# 'L': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD1': 5, 'CD2': 6}, +# 'M': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'SD': 5, 'CE': 6}, +# 'N': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'OD1': 5, 'ND2': 6}, +# 'P': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD': 5}, +# 'Q': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD': 5, 'OE1': 6, 'NE2': 7}, +# 'R': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD': 5, 'NE': 6, 'CZ': 7, 'NH1': 8, 'NH2': 9}, +# 'S': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'OG': 4}, +# 'T': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'OG1': 4, 'CG2': 5}, +# 'V': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG1': 4, 'CG2': 5}, +# 'W': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD1': 5, 'CD2': 6, 'NE1': 7, 'CE2': 8, 'CE3': 9, 'CZ2': 10, 'CZ3': 11, 'CH2': 12}, +# 'Y': {'N': 0, 'C': 1, 'O': 2, 'CB': 3, 'CG': 4, 'CD1': 5, 'CD2': 6, 'CE1': 7, 'CE2': 8, 'CZ': 9, 'OH': 10}, +# } +aa_atom_index = { + 'A': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4}, + 'C': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'SG': 5}, + 'D': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'OD1': 6, 'OD2': 7}, + 'E': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 
4, 'CG': 5, 'CD': 6, 'OE1': 7, 'OE2': 8}, + 'F': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD1': 6, 'CD2': 7, 'CE1': 8, 'CE2': 9, 'CZ': 10}, + 'G': {'N': 0, 'CA': 1, 'C': 2, 'O': 3}, + 'H': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'ND1': 6, 'CD2': 7, 'CE1': 8, 'NE2': 9}, + 'I': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG1': 5, 'CG2': 6, 'CD1': 7}, + 'K': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD': 6, 'CE': 7, 'NZ': 8}, + 'L': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD1': 6, 'CD2': 7}, + 'M': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'SD': 6, 'CE': 7}, + 'N': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'OD1': 6, 'ND2': 7}, + 'P': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD': 6}, + 'Q': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD': 6, 'OE1': 7, 'NE2': 8}, + 'R': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD': 6, 'NE': 7, 'CZ': 8, 'NH1': 9, 'NH2': 10}, + 'S': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'OG': 5}, + 'T': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'OG1': 5, 'CG2': 6}, + 'V': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG1': 5, 'CG2': 6}, + 'W': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD1': 6, 'CD2': 7, 'NE1': 8, 'CE2': 9, 'CE3': 10, 'CZ2': 11, 'CZ3': 12, 'CH2': 13}, + 'Y': {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4, 'CG': 5, 'CD1': 6, 'CD2': 7, 'CE1': 8, 'CE2': 9, 'CZ': 10, 'OH': 11}, +} + +# ------------------------------------------------------------------------------ +# NERF +# ------------------------------------------------------------------------------ + +# indicates whether atom exists +aa_atom_mask = { + 'A': [True, True, True, True, True, False, False, False, False, False, False, False, False, False], + 'C': [True, True, True, True, True, True, False, False, False, False, False, False, False, False], + 'D': [True, True, True, True, True, True, True, True, False, False, False, False, False, False], + 'E': [True, True, True, True, True, True, True, True, True, False, False, False, False, False], + 'F': [True, True, True, True, True, True, True, True, True, True, True, False, False, False], + 'G': [True, True, True, True, False, False, False, False, False, False, False, False, False, False], + 'H': [True, True, True, True, True, True, True, True, True, True, False, False, False, False], + 'I': [True, True, True, True, True, True, True, True, False, False, False, False, False, False], + 'K': [True, True, True, True, True, True, True, True, True, False, False, False, False, False], + 'L': [True, True, True, True, True, True, True, True, False, False, False, False, False, False], + 'M': [True, True, True, True, True, True, True, True, False, False, False, False, False, False], + 'N': [True, True, True, True, True, True, True, True, False, False, False, False, False, False], + 'P': [True, True, True, True, True, True, True, False, False, False, False, False, False, False], + 'Q': [True, True, True, True, True, True, True, True, True, False, False, False, False, False], + 'R': [True, True, True, True, True, True, True, True, True, True, True, False, False, False], + 'S': [True, True, True, True, True, True, False, False, False, False, False, False, False, False], + 'T': [True, True, True, True, True, True, True, False, False, False, False, False, False, False], + 'V': [True, True, True, True, True, True, True, False, False, False, False, False, False, False], + 'W': [True, True, True, True, True, True, True, True, True, True, True, True, True, True], + 'Y': [True, 
True, True, True, True, True, True, True, True, True, True, True, False, False], +} + +# (14, 3) index tensor with atom indices of atoms a, b and c for NERF reconstruction +# in principle, columns 1 and 2 can be inferred from column one (immediate predecessor) alone +aa_nerf_indices = { + 'A': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + 'C': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + 'D': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [5, 4, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + 'E': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [6, 5, 4], [6, 5, 4], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + 'F': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [5, 4, 1], [6, 5, 4], [7, 5, 4], [8, 6, 5], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + 'G': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + 'H': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [5, 4, 1], [6, 5, 4], [7, 5, 4], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + 'I': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [4, 1, 0], [5, 4, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + 'K': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [6, 5, 4], [7, 6, 5], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + 'L': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [5, 4, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + 'M': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [6, 5, 4], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + 'N': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [5, 4, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + 'P': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + 'Q': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [6, 5, 4], [6, 5, 4], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + 'R': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [6, 5, 4], [7, 6, 5], [8, 7, 6], [8, 7, 6], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + 'S': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + 'T': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [4, 1, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + 'V': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [4, 1, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + 'W': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [5, 4, 1], [6, 5, 4], [7, 5, 4], [7, 5, 4], [9, 7, 5], [10, 7, 5], [11, 9, 7]], + 'Y': [[0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0], [1, 0, 0], [4, 1, 0], [5, 4, 1], [5, 4, 1], [6, 5, 4], [7, 5, 4], [8, 6, 5], [10, 8, 6], [0, 0, 0], [0, 0, 0]], +} 
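+
+# Illustration only (hedged sketch, not used at runtime): given the positions
+# of the three reference atoms a, b, c selected by one row of
+# `aa_nerf_indices`, plus a bond length, bond angle and dihedral, the next
+# atom can be placed with the standard natural-extension-reference-frame
+# (NERF) construction. The actual reconstruction code lives in
+# src/data/nerf.py and may use different conventions.
+#
+# def _nerf_place(a, b, c, length, theta, dihedral):
+#     bc = (c - b) / np.linalg.norm(c - b)
+#     n = np.cross(b - a, bc)
+#     n = n / np.linalg.norm(n)
+#     m = np.stack([bc, np.cross(n, bc), n], axis=-1)  # local frame (columns)
+#     d = length * np.array([-np.cos(theta),
+#                            np.sin(theta) * np.cos(dihedral),
+#                            np.sin(theta) * np.sin(dihedral)])
+#     return c + m @ d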
+ +# unique id for each rotatable bond (0=chi1, 1=chi, ...) +aa_bond_to_chi = { + 'A': {}, + 'C': {('CA', 'CB'): 0}, + 'D': {('CA', 'CB'): 0, ('CB', 'CG'): 1}, + 'E': {('CA', 'CB'): 0, ('CB', 'CG'): 1, ('CG', 'CD'): 2}, + 'F': {('CA', 'CB'): 0, ('CB', 'CG'): 1}, + 'G': {}, + 'H': {('CA', 'CB'): 0, ('CB', 'CG'): 1}, + 'I': {('CA', 'CB'): 0, ('CB', 'CG2'): 1}, + 'K': {('CA', 'CB'): 0, ('CB', 'CG'): 1, ('CG', 'CD'): 2, ('CD', 'CE'): 3}, + 'L': {('CA', 'CB'): 0, ('CB', 'CG'): 1}, + 'M': {('CA', 'CB'): 0, ('CB', 'CG'): 1, ('CG', 'SD'): 2}, + 'N': {('CA', 'CB'): 0, ('CB', 'CG'): 1}, + 'P': {}, + 'Q': {('CA', 'CB'): 0, ('CB', 'CG'): 1, ('CG', 'CD'): 2}, + 'R': {('CA', 'CB'): 0, ('CB', 'CG'): 1, ('CG', 'CD'): 2, ('CD', 'NE'): 3, ('NE', 'CZ'): 4}, + 'S': {('CA', 'CB'): 0}, + 'T': {('CA', 'CB'): 0}, + 'V': {('CA', 'CB'): 0}, + 'W': {('CA', 'CB'): 0, ('CB', 'CG'): 1}, + 'Y': {('CA', 'CB'): 0, ('CB', 'CG'): 1}, +} + +# index between 0 and 4 to retrieve chi angles, -1 means not a rotatable bond +aa_chi_indices = { + 'A': [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + 'C': [-1, -1, -1, -1, -1, 0, -1, -1, -1, -1, -1, -1, -1, -1], + 'D': [-1, -1, -1, -1, -1, 0, 1, 1, -1, -1, -1, -1, -1, -1], + 'E': [-1, -1, -1, -1, -1, 0, 1, 2, 2, -1, -1, -1, -1, -1], + 'F': [-1, -1, -1, -1, -1, 0, 1, 1, -1, -1, -1, -1, -1, -1], + 'G': [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + 'H': [-1, -1, -1, -1, -1, 0, 1, 1, -1, -1, -1, -1, -1, -1], + 'I': [-1, -1, -1, -1, -1, 0, 0, 1, -1, -1, -1, -1, -1, -1], + 'K': [-1, -1, -1, -1, -1, 0, 1, 2, 3, -1, -1, -1, -1, -1], + 'L': [-1, -1, -1, -1, -1, 0, 1, 1, -1, -1, -1, -1, -1, -1], + 'M': [-1, -1, -1, -1, -1, 0, 1, 2, -1, -1, -1, -1, -1, -1], + 'N': [-1, -1, -1, -1, -1, 0, 1, 1, -1, -1, -1, -1, -1, -1], + 'P': [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + 'Q': [-1, -1, -1, -1, -1, 0, 1, 2, 2, -1, -1, -1, -1, -1], + 'R': [-1, -1, -1, -1, -1, 0, 1, 2, 3, 4, 4, -1, -1, -1], + 'S': [-1, -1, -1, -1, -1, 0, -1, -1, -1, -1, -1, -1, -1, -1], + 'T': [-1, -1, -1, -1, -1, 0, 0, -1, -1, -1, -1, -1, -1, -1], + 'V': [-1, -1, -1, -1, -1, 0, 0, -1, -1, -1, -1, -1, -1, -1], + 'W': [-1, -1, -1, -1, -1, 0, 1, 1, -1, -1, -1, -1, -1, -1], + 'Y': [-1, -1, -1, -1, -1, 0, 1, 1, -1, -1, -1, -1, -1, -1], +} + +# key: chi index (0=chi1, 1=chi, ...); value: index of atom that defines the chi angle (together with its three predecessors) +aa_chi_anchor_atom = { + 'A': {}, + 'C': {0: 5}, + 'D': {0: 5, 1: 6}, + 'E': {0: 5, 1: 6, 2: 7}, + 'F': {0: 5, 1: 6}, + 'G': {}, + 'H': {0: 5, 1: 6}, + 'I': {0: 5, 1: 7}, + 'K': {0: 5, 1: 6, 2: 7, 3: 8}, + 'L': {0: 5, 1: 6}, + 'M': {0: 5, 1: 6, 2: 7}, + 'N': {0: 5, 1: 6}, + 'P': {}, + 'Q': {0: 5, 1: 6, 2: 7}, + 'R': {0: 5, 1: 6, 2: 7, 3: 8, 4: 9}, + 'S': {0: 5}, + 'T': {0: 5}, + 'V': {0: 5}, + 'W': {0: 5, 1: 6}, + 'Y': {0: 5, 1: 6}, +} + +# ------------------------------------------------------------------------------ +# Visualization +# ------------------------------------------------------------------------------ +# PyMOL colors, see: https://pymolwiki.org/index.php/Color_Values#Chemical_element_colours +colors_dic = ['#33ff33', '#3333ff', '#ff4d4d', '#e6c540', '#ffb5b5', '#A62929', '#1FF01F', '#ff8000', '#940094', '#B3FFFF', '#b3e3f5'] +radius_dic = [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3] + + +# ------------------------------------------------------------------------------ +# Backbone geometry +# Taken from: Bhagavan, N. V., and C. E. Ha. 
+# "Chapter 4-Three-dimensional structure of proteins and disorders of protein misfolding." +# Essentials of Medical Biochemistry (2015): 31-51. +# https://www.sciencedirect.com/science/article/pii/B978012416687500004X +# ------------------------------------------------------------------------------ +N_CA_DIST = 1.47 +CA_C_DIST = 1.53 +N_CA_C_ANGLE = 110 * np.pi / 180 + +# ------------------------------------------------------------------------------ +# Atom radii +# ------------------------------------------------------------------------------ +# # https://en.wikipedia.org/wiki/Covalent_radius#Radii_for_multiple_bonds +# # (2023/04/14) +# covalent_radii = {'H': [32, None, None], +# 'C': [75, 67, 60], +# 'N': [71, 60, 54], +# 'O': [63, 57, 53], +# 'F': [64, 59, 53], +# 'B': [85, 78, 73], +# 'Al': [126, 113, 111], +# 'Si': [116, 107, 102], +# 'P': [111, 102, 94], +# 'S': [103, 94, 95], +# 'Cl': [99, 95, 93], +# 'As': [121, 114, 106], +# 'Br': [114, 109, 110], +# 'I': [133, 129, 125], +# 'Hg': [133, 142, None], +# 'Bi': [151, 141, 135]} + +# source: https://en.wikipedia.org/wiki/Van_der_Waals_radius +vdw_radii = {'N': 1.55, 'O': 1.52, 'C': 1.70, 'H': 1.10, 'S': 1.80, 'P': 1.80, + 'Se': 1.90, 'K': 2.75, 'Na': 2.27, 'Mg': 1.73, 'Zn': 1.39, 'B': 1.92, + 'Br': 1.85, 'Cl': 1.75, 'I': 1.98, 'F': 1.47} + + +WEBDATASET_SHARD_SIZE = 50000 +WEBDATASET_VAL_SIZE = 100 \ No newline at end of file diff --git a/src/data/data_utils.py b/src/data/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..44aa5895840bfd68dd1c2eb1a9c76f9e2a5eab17 --- /dev/null +++ b/src/data/data_utils.py @@ -0,0 +1,901 @@ +import io +from itertools import accumulate, chain +from copy import deepcopy +import random +import torch +import torch.nn.functional as F +import numpy as np +from rdkit import Chem +from torch_scatter import scatter_mean +from Bio.PDB import StructureBuilder, Chain, Model, Structure +from Bio.PDB.PICIO import read_PIC, write_PIC +from scipy.ndimage import gaussian_filter +from pdb import set_trace + +from src.constants import FLOAT_TYPE, INT_TYPE +from src.constants import atom_encoder, bond_encoder, aa_encoder, residue_encoder, residue_bond_encoder, aa_atom_index +from src import utils +from src.data.misc import protein_letters_3to1, is_aa +from src.data.normal_modes import pdb_to_normal_modes +from src.data.nerf import get_nerf_params, ic_to_coords +import src.data.so3_utils as so3 + + +class TensorDict(dict): + def __init__(self, **kwargs): + super(TensorDict, self).__init__(**kwargs) + + def _apply(self, func: str, *args, **kwargs): + """ Apply function to all tensors. """ + for k, v in self.items(): + if torch.is_tensor(v): + self[k] = getattr(v, func)(*args, **kwargs) + return self + + # def to(self, device): + # for k, v in self.items(): + # if torch.is_tensor(v): + # self[k] = v.to(device) + # return self + + def cuda(self): + return self.to('cuda') + + def cpu(self): + return self.to('cpu') + + def to(self, device): + return self._apply("to", device) + + def detach(self): + return self._apply("detach") + + def __repr__(self): + def val_to_str(val): + if isinstance(val, torch.Tensor): + # if val.isnan().any(): + # return "(!nan)" + return "%r" % list(val.size()) + if isinstance(val, list): + return "[%r,]" % len(val) + else: + return "?" 
+ + return f"{type(self).__name__}({', '.join(f'{k}={val_to_str(v)}' for k, v in self.items())})" + + +def collate_entity(batch): + + out = {} + for prop in batch[0].keys(): + + if prop == 'name': + out[prop] = [x[prop] for x in batch] + + elif prop == 'size' or prop == 'n_bonds': + out[prop] = torch.tensor([x[prop] for x in batch]) + + elif prop == 'bonds': + # index offset + offset = list(accumulate([x['size'] for x in batch], initial=0)) + out[prop] = torch.cat([x[prop] + offset[i] for i, x in enumerate(batch)], dim=1) + + elif prop == 'residues': + out[prop] = list(chain.from_iterable(x[prop] for x in batch)) + + elif prop in {'mask', 'bond_mask'}: + pass # batch masks will be written later + + else: + out[prop] = torch.cat([x[prop] for x in batch], dim=0) + + # Create batch masks + # make sure indices in batch start at zero (needed for torch_scatter) + if prop == 'x': + out['mask'] = torch.cat([i * torch.ones(len(x[prop]), dtype=torch.int64, device=x[prop].device) + for i, x in enumerate(batch)], dim=0) + if prop == 'bond_one_hot': + # TODO: this is not necessary as it can be computed on-the-fly as bond_mask = mask[bonds[0]] or bond_mask = mask[bonds[1]] + out['bond_mask'] = torch.cat([i * torch.ones(len(x[prop]), dtype=torch.int64, device=x[prop].device) + for i, x in enumerate(batch)], dim=0) + + return out + + +def split_entity( + batch, + *, + index_types={'bonds'}, + edge_types={'bond_one_hot', 'bond_mask'}, + no_split={'name', 'size', 'n_bonds'}, + skip={'fragments'}, + batch_mask=None, + edge_mask=None + ): + """ Splits a batch into items and returns a list. """ + + batch_mask = batch["mask"] if batch_mask is None else batch_mask + edge_mask = batch["bond_mask"] if edge_mask is None else edge_mask + sizes = batch['size'] if 'size' in batch else torch.unique(batch_mask, return_counts=True)[1].tolist() + + batch_size = len(torch.unique(batch['mask'])) + out = {} + for prop in batch.keys(): + if prop in skip: + continue + if prop in no_split: + out[prop] = batch[prop] # already a list + + elif prop in index_types: + offsets = list(accumulate(sizes[:-1], initial=0)) + out[prop] = utils.batch_to_list_for_indices(batch[prop], edge_mask, offsets) + + elif prop in edge_types: + out[prop] = utils.batch_to_list(batch[prop], edge_mask) + + else: + out[prop] = utils.batch_to_list(batch[prop], batch_mask) + + out = [{k: v[i] for k, v in out.items()} for i in range(batch_size)] + return out + + +def repeat_items(batch, repeats): + batch_list = split_entity(batch) + out = collate_entity([x for _ in range(repeats) for x in batch_list]) + return type(batch)(**out) + + +def get_side_chain_bead_coord(biopython_residue): + """ + Places side chain bead at the location of the farthest side chain atom. 
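+    Returns None for glycine, which has no side chain; for alanine the CB
+    position is used directly.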
+ """ + if biopython_residue.get_resname() == 'GLY': + return None + if biopython_residue.get_resname() == 'ALA': + return biopython_residue['CB'].get_coord() + + ca_coord = biopython_residue['CA'].get_coord() + side_chain_atoms = [a for a in biopython_residue.get_atoms() if + a.id not in {'N', 'CA', 'C', 'O'} and a.element != 'H'] + side_chain_coords = np.stack([a.get_coord() for a in side_chain_atoms]) + + atom_idx = np.argmax(np.sum((side_chain_coords - ca_coord[None, :]) ** 2, axis=-1)) + + return side_chain_coords[atom_idx, :] + + +def get_side_chain_vectors(res, index_dict, size=None): + if size is None: + size = max([x for aa in index_dict.values() for x in aa.values()]) + 1 + + resname = protein_letters_3to1[res.get_resname()] + + out = np.zeros((size, 3)) + for atom in res.get_atoms(): + if atom.get_name() in index_dict[resname]: + idx = index_dict[resname][atom.get_name()] + out[idx] = atom.get_coord() - res['CA'].get_coord() + # else: + # if atom.get_name() != 'CA' and not atom.get_name().startswith('H'): + # print(resname, atom.get_name()) + + return out + + +def get_normal_modes(res, normal_mode_dict): + nm = normal_mode_dict[(res.get_parent().id, res.id[1], 'CA')] # (n_modes, 3) + return nm + + +def get_torsion_angles(res, device=None): + """ + Return the five chi angles. Missing angles are filled with zeros. + """ + ANGLES = ['chi1', 'chi2', 'chi3', 'chi4', 'chi5'] + + ic_res = res.internal_coord + chi_angles = [ic_res.get_angle(chi) for chi in ANGLES] + chi_angles = [chi if chi is not None else float('nan') for chi in chi_angles] + + return torch.tensor(chi_angles, device=device) * np.pi / 180 + + +def apply_torsion_angles(res, chi_angles): + """ + Set side chain torsion angles of a biopython residue object with + internal coordinates. 
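+    Angles are expected in radians and are converted to degrees internally,
+    as used by Biopython's internal-coordinate module.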
+ """ + ANGLES = ['chi1', 'chi2', 'chi3', 'chi4', 'chi5'] + + chi_angles = chi_angles * 180 / np.pi + + # res.parent.internal_coord.build_atomArray() # rebuild atom pointers + + ic_res = res.internal_coord + for chi, angle in zip(ANGLES, chi_angles): + if ic_res.pick_angle(chi) is None: + continue + ic_res.bond_set(chi, angle) + + res.parent.internal_to_atom_coordinates(verbose=False) + # res.parent.internal_coord.init_atom_coords() + # res.internal_coord.assemble() + + return res + + +def prepare_internal_coord(res): + + # Make new structure with a single residue + new_struct = Structure.Structure('X') + new_struct.header = {} + new_model = Model.Model(0) + new_struct.add(new_model) + new_chain = Chain.Chain('X') + new_model.add(new_chain) + new_chain.add(res) + res.set_parent(new_chain) # update pointer + + # Compute internal coordinates + new_chain.atom_to_internal_coordinates() + + pic_io = io.StringIO() + write_PIC(new_struct, pic_io) + return pic_io.getvalue() + + +def residue_from_internal_coord(ic_string): + pic_io = io.StringIO(ic_string) + struct = read_PIC(pic_io, quick=True) + res = struct.child_list[0].child_list[0].child_list[0] + res.parent.internal_to_atom_coordinates(verbose=False) + return res + + +def prepare_pocket(biopython_residues, amino_acid_encoder, residue_encoder, + residue_bond_encoder, pocket_representation='side_chain_bead', + compute_nerf_params=False, compute_bb_frames=False, + nma_input=None): + + assert nma_input is None or pocket_representation == 'CA+', \ + "vector features are only supported for CA+ pockets" + + # sort residues + biopython_residues = sorted(biopython_residues, key=lambda x: (x.parent.id, x.id[1])) + + if nma_input is not None: + # preprocessed normal mode eigenvectors + if isinstance(nma_input, dict): + nma_dict = nma_input + + # PDB file + else: + nma_dict = pdb_to_normal_modes(str(nma_input)) + + if pocket_representation == 'side_chain_bead': + ca_coords = np.zeros((len(biopython_residues), 3)) + ca_types = np.zeros(len(biopython_residues), dtype='int64') + side_chain_coords = [] + side_chain_aa_types = [] + edges = [] # CA-CA and CA-side_chain + edge_types = [] + last_res_id = None + for i, res in enumerate(biopython_residues): + aa = amino_acid_encoder[protein_letters_3to1[res.get_resname()]] + ca_coords[i, :] = res['CA'].get_coord() + ca_types[i] = aa + side_chain_coord = get_side_chain_bead_coord(res) + if side_chain_coord is not None: + side_chain_coords.append(side_chain_coord) + side_chain_aa_types.append(aa) + edges.append((i, len(ca_coords) + len(side_chain_coords) - 1)) + edge_types.append(residue_bond_encoder['CA-SS']) + + # add edges between contiguous CA atoms + if i > 0 and res.id[1] == last_res_id + 1: + edges.append((i - 1, i)) + edge_types.append(residue_bond_encoder['CA-CA']) + + last_res_id = res.id[1] + + # Coordinates + side_chain_coords = np.stack(side_chain_coords) + pocket_coords = np.concatenate([ca_coords, side_chain_coords], axis=0) + pocket_coords = torch.from_numpy(pocket_coords) + + # Features + amino_acid_onehot = F.one_hot( + torch.cat([torch.from_numpy(ca_types), torch.tensor(side_chain_aa_types, dtype=torch.int64)], dim=0), + num_classes=len(amino_acid_encoder) + ) + side_chain_onehot = np.concatenate([ + np.tile(np.eye(1, len(residue_encoder), residue_encoder['CA']), + [len(ca_coords), 1]), + np.tile(np.eye(1, len(residue_encoder), residue_encoder['SS']), + [len(side_chain_coords), 1]) + ], axis=0) + side_chain_onehot = torch.from_numpy(side_chain_onehot) + pocket_onehot = 
torch.cat([amino_acid_onehot, side_chain_onehot], dim=1) + + vector_features = None + nma_features = None + + # Bonds + edges = torch.tensor(edges).T + edge_types = F.one_hot(torch.tensor(edge_types), num_classes=len(residue_bond_encoder)) + + elif pocket_representation == 'CA+': + ca_coords = np.zeros((len(biopython_residues), 3)) + ca_types = np.zeros(len(biopython_residues), dtype='int64') + + v_dim = max([x for aa in aa_atom_index.values() for x in aa.values()]) + 1 + vec_feats = np.zeros((len(biopython_residues), v_dim, 3), dtype='float32') + nf_nma = 5 + nma_feats = np.zeros((len(biopython_residues), nf_nma, 3), dtype='float32') + + edges = [] # CA-CA and CA-side_chain + edge_types = [] + last_res_id = None + for i, res in enumerate(biopython_residues): + aa = amino_acid_encoder[protein_letters_3to1[res.get_resname()]] + ca_coords[i, :] = res['CA'].get_coord() + ca_types[i] = aa + + vec_feats[i] = get_side_chain_vectors(res, aa_atom_index, v_dim) + if nma_input is not None: + nma_feats[i] = get_normal_modes(res, nma_dict) + + # add edges between contiguous CA atoms + if i > 0 and res.id[1] == last_res_id + 1: + edges.append((i - 1, i)) + edge_types.append(residue_bond_encoder['CA-CA']) + + last_res_id = res.id[1] + + # Coordinates + pocket_coords = torch.from_numpy(ca_coords) + + # Features + pocket_onehot = F.one_hot(torch.from_numpy(ca_types), + num_classes=len(amino_acid_encoder)) + + vector_features = torch.from_numpy(vec_feats) + nma_features = torch.from_numpy(nma_feats) + + # Bonds + if len(edges) < 1: + edges = torch.empty(2, 0) + edge_types = torch.empty(0, len(residue_bond_encoder)) + else: + edges = torch.tensor(edges).T + edge_types = F.one_hot(torch.tensor(edge_types), + num_classes=len(residue_bond_encoder)) + + else: + raise NotImplementedError( + f"Pocket representation '{pocket_representation}' not implemented") + + # pocket_ids = [f'{res.parent.id}:{res.id[1]}' for res in biopython_residues] + + pocket = { + 'x': pocket_coords.to(dtype=FLOAT_TYPE), + 'one_hot': pocket_onehot.to(dtype=FLOAT_TYPE), + # 'ids': pocket_ids, + 'size': torch.tensor([len(pocket_coords)], dtype=INT_TYPE), + 'mask': torch.zeros(len(pocket_coords), dtype=INT_TYPE), + 'bonds': edges.to(INT_TYPE), + 'bond_one_hot': edge_types.to(FLOAT_TYPE), + 'bond_mask': torch.zeros(edges.size(1), dtype=INT_TYPE), + 'n_bonds': torch.tensor([len(edge_types)], dtype=INT_TYPE), + } + + if vector_features is not None: + pocket['v'] = vector_features.to(dtype=FLOAT_TYPE) + + if nma_input is not None: + pocket['nma_vec'] = nma_features.to(dtype=FLOAT_TYPE) + + if compute_nerf_params: + nerf_params = [get_nerf_params(r) for r in biopython_residues] + nerf_params = {k: torch.stack([x[k] for x in nerf_params], dim=0) + for k in nerf_params[0].keys()} + pocket.update(nerf_params) + + if compute_bb_frames: + n_xyz = torch.from_numpy(np.stack([r['N'].get_coord() for r in biopython_residues])) + ca_xyz = torch.from_numpy(np.stack([r['CA'].get_coord() for r in biopython_residues])) + c_xyz = torch.from_numpy(np.stack([r['C'].get_coord() for r in biopython_residues])) + pocket['axis_angle'], _ = get_bb_transform(n_xyz, ca_xyz, c_xyz) + + return pocket, biopython_residues + + +def encode_atom(rd_atom, atom_encoder): + element = rd_atom.GetSymbol().capitalize() + + explicitHs = rd_atom.GetNumExplicitHs() + if explicitHs == 1 and f'{element}H' in atom_encoder: + return atom_encoder[f'{element}H'] + + charge = rd_atom.GetFormalCharge() + if charge == 1 and f'{element}+' in atom_encoder: + return atom_encoder[f'{element}+'] + if 
charge == -1 and f'{element}-' in atom_encoder: + return atom_encoder[f'{element}-'] + + return atom_encoder[element] + + +def prepare_ligand(rdmol, atom_encoder, bond_encoder): + + # remove H atoms if not in atom_encoder + if 'H' not in atom_encoder: + rdmol = Chem.RemoveAllHs(rdmol, sanitize=False) + + # Coordinates + ligand_coord = rdmol.GetConformer().GetPositions() + ligand_coord = torch.from_numpy(ligand_coord) + + # Features + ligand_onehot = F.one_hot( + torch.tensor([encode_atom(a, atom_encoder) for a in rdmol.GetAtoms()]), + num_classes=len(atom_encoder) + ) + + # Bonds + adj = np.ones((rdmol.GetNumAtoms(), rdmol.GetNumAtoms())) * bond_encoder['NOBOND'] + for b in rdmol.GetBonds(): + i = b.GetBeginAtomIdx() + j = b.GetEndAtomIdx() + adj[i, j] = bond_encoder[str(b.GetBondType())] + adj[j, i] = adj[i, j] # undirected graph + + # molecular graph is undirected -> don't save redundant information + bonds = np.stack(np.triu_indices(len(ligand_coord), k=1), axis=0) + # bonds = np.stack(np.ones_like(adj).nonzero(), axis=0) + bond_types = adj[bonds[0], bonds[1]].astype('int64') + bonds = torch.from_numpy(bonds) + bond_types = F.one_hot(torch.from_numpy(bond_types), num_classes=len(bond_encoder)) + + ligand = { + 'x': ligand_coord.to(dtype=FLOAT_TYPE), + 'one_hot': ligand_onehot.to(dtype=FLOAT_TYPE), + 'mask': torch.zeros(len(ligand_coord), dtype=INT_TYPE), + 'bonds': bonds.to(INT_TYPE), + 'bond_one_hot': bond_types.to(FLOAT_TYPE), + 'bond_mask': torch.zeros(bonds.size(1), dtype=INT_TYPE), + 'size': torch.tensor([len(ligand_coord)], dtype=INT_TYPE), + 'n_bonds': torch.tensor([len(bond_types)], dtype=INT_TYPE), + } + + return ligand + + +def process_raw_molecule_with_empty_pocket(rdmol): + ligand = prepare_ligand(rdmol, atom_encoder, bond_encoder) + pocket = { + 'x': torch.tensor([], dtype=FLOAT_TYPE), + 'one_hot': torch.tensor([], dtype=FLOAT_TYPE), + 'size': torch.tensor([], dtype=INT_TYPE), + 'mask': torch.tensor([], dtype=INT_TYPE), + 'bonds': torch.tensor([], dtype=INT_TYPE), + 'bond_one_hot': torch.tensor([], dtype=FLOAT_TYPE), + 'bond_mask': torch.tensor([], dtype=INT_TYPE), + 'n_bonds': torch.tensor([], dtype=INT_TYPE), + } + return ligand, pocket + + +def process_raw_pair(biopython_model, rdmol, dist_cutoff=None, + pocket_representation='side_chain_bead', + compute_nerf_params=False, compute_bb_frames=False, + nma_input=None, return_pocket_pdb=False): + + # Process ligand + ligand = prepare_ligand(rdmol, atom_encoder, bond_encoder) + + # Find interacting pocket residues based on distance cutoff + pocket_residues = [] + for residue in biopython_model.get_residues(): + + # Remove non-standard amino acids and HETATMs + if not is_aa(residue.get_resname(), standard=True): + continue + + res_coords = torch.from_numpy(np.array([a.get_coord() for a in residue.get_atoms()])) + if dist_cutoff is None or (((res_coords[:, None, :] - ligand['x'][None, :, :]) ** 2).sum(-1) ** 0.5).min() < dist_cutoff: + pocket_residues.append(residue) + + pocket, pocket_residues = prepare_pocket( + pocket_residues, aa_encoder, residue_encoder, residue_bond_encoder, + pocket_representation, compute_nerf_params, compute_bb_frames, nma_input + ) + + if return_pocket_pdb: + builder = StructureBuilder.StructureBuilder() + builder.init_structure("") + builder.init_model(0) + pocket_struct = builder.get_structure() + for residue in pocket_residues: + chain = residue.get_parent().get_id() + + # init chain if necessary + if not pocket_struct[0].has_id(chain): + builder.init_chain(chain) + + # add residue + 
pocket_struct[0][chain].add(residue) + + pocket['pocket_pdb'] = pocket_struct + # if return_pocket_pdb: + # pocket['residues'] = [prepare_internal_coord(res) for res in pocket_residues] + + return ligand, pocket + + +class AppendVirtualNodes: + def __init__(self, atom_encoder, bond_encoder, max_ligand_size, scale=1.0): + self.max_size = max_ligand_size + self.atom_encoder = atom_encoder + self.bond_encoder = bond_encoder + self.vidx = atom_encoder['NOATOM'] + self.bidx = bond_encoder['NOBOND'] + self.scale = scale + + def __call__(self, ligand, max_size=None, eps=1e-6): + if max_size is None: + max_size = self.max_size + + n_virt = max_size - ligand['size'] + + C = torch.cov(ligand['x'].T) + L = torch.linalg.cholesky(C + torch.eye(3) * eps) + mu = ligand['x'].mean(0, keepdim=True) + virt_coords = mu + torch.randn(n_virt, 3) @ L.T * self.scale + + # insert virtual atom column + virt_one_hot = F.one_hot(torch.ones(n_virt, dtype=torch.int64) * self.vidx, num_classes=len(self.atom_encoder)) + virt_mask = torch.cat([torch.zeros(ligand['size'], dtype=bool), torch.ones(n_virt, dtype=bool)]) + + ligand['x'] = torch.cat([ligand['x'], virt_coords]) + ligand['one_hot'] = torch.cat(([ligand['one_hot'], virt_one_hot])) + ligand['virtual_mask'] = virt_mask + ligand['size'] = max_size + + # Bonds + new_bonds = torch.triu_indices(max_size, max_size, offset=1) + + bond_types = torch.ones(max_size, max_size, dtype=INT_TYPE) * self.bidx + row, col = ligand['bonds'] + bond_types[row, col] = ligand['bond_one_hot'].argmax(dim=1) + new_row, new_col = new_bonds + bond_types = bond_types[new_row, new_col] + + ligand['bonds'] = new_bonds + ligand['bond_one_hot'] = F.one_hot(bond_types, num_classes=len(self.bond_encoder)).to(ligand['bond_one_hot'].dtype) + ligand['n_bonds'] = len(ligand['bond_one_hot']) + + return ligand + + +class AppendVirtualNodesInCoM: + def __init__(self, atom_encoder, bond_encoder, add_min=0, add_max=10): + self.atom_encoder = atom_encoder + self.bond_encoder = bond_encoder + self.vidx = atom_encoder['NOATOM'] + self.bidx = bond_encoder['NOBOND'] + self.add_min = add_min + self.add_max = add_max + + def __call__(self, ligand): + + n_virt = random.randint(self.add_min, self.add_max) + + # all virtual coordinates in the CoM + virt_coords = ligand['x'].mean(0, keepdim=True).repeat(n_virt, 1) + + # insert virtual atom column + virt_one_hot = F.one_hot(torch.ones(n_virt, dtype=torch.int64) * self.vidx, num_classes=len(self.atom_encoder)) + virt_mask = torch.cat([torch.zeros(ligand['size'], dtype=bool), torch.ones(n_virt, dtype=bool)]) + + ligand['x'] = torch.cat([ligand['x'], virt_coords]) + ligand['one_hot'] = torch.cat(([ligand['one_hot'], virt_one_hot])) + ligand['virtual_mask'] = virt_mask + ligand['size'] = len(ligand['x']) + + # Bonds + new_bonds = torch.triu_indices(ligand['size'], ligand['size'], offset=1) + + bond_types = torch.ones(ligand['size'], ligand['size'], dtype=INT_TYPE) * self.bidx + row, col = ligand['bonds'] + bond_types[row, col] = ligand['bond_one_hot'].argmax(dim=1) + new_row, new_col = new_bonds + bond_types = bond_types[new_row, new_col] + + ligand['bonds'] = new_bonds + ligand['bond_one_hot'] = F.one_hot(bond_types, num_classes=len(self.bond_encoder)).to(ligand['bond_one_hot'].dtype) + ligand['n_bonds'] = len(ligand['bond_one_hot']) + + return ligand + + +def rdmol_to_smiles(rdmol): + mol = Chem.Mol(rdmol) + Chem.RemoveStereochemistry(mol) + mol = Chem.RemoveHs(mol) + return Chem.MolToSmiles(mol) + + +def get_n_nodes(lig_positions, pocket_positions, smooth_sigma=None): + 
# Joint distribution of ligand's and pocket's number of nodes + n_nodes_lig = [len(x) for x in lig_positions] + n_nodes_pocket = [len(x) for x in pocket_positions] + + joint_histogram = np.zeros((np.max(n_nodes_lig) + 1, + np.max(n_nodes_pocket) + 1)) + + for nlig, npocket in zip(n_nodes_lig, n_nodes_pocket): + joint_histogram[nlig, npocket] += 1 + + print(f'Original histogram: {np.count_nonzero(joint_histogram)}/' + f'{joint_histogram.shape[0] * joint_histogram.shape[1]} bins filled') + + # Smooth the histogram + if smooth_sigma is not None: + filtered_histogram = gaussian_filter( + joint_histogram, sigma=smooth_sigma, order=0, mode='constant', + cval=0.0, truncate=4.0) + + print(f'Smoothed histogram: {np.count_nonzero(filtered_histogram)}/' + f'{filtered_histogram.shape[0] * filtered_histogram.shape[1]} bins filled') + + joint_histogram = filtered_histogram + + return joint_histogram + + +# def get_type_histograms(lig_one_hot, pocket_one_hot, lig_encoder, pocket_encoder): +# +# lig_one_hot = np.concatenate(lig_one_hot, axis=0) +# pocket_one_hot = np.concatenate(pocket_one_hot, axis=0) +# +# atom_decoder = list(lig_encoder.keys()) +# lig_counts = {k: 0 for k in lig_encoder.keys()} +# for a in [atom_decoder[x] for x in lig_one_hot.argmax(1)]: +# lig_counts[a] += 1 +# +# aa_decoder = list(pocket_encoder.keys()) +# pocket_counts = {k: 0 for k in pocket_encoder.keys()} +# for r in [aa_decoder[x] for x in pocket_one_hot.argmax(1)]: +# pocket_counts[r] += 1 +# +# return lig_counts, pocket_counts + + +def get_type_histogram(one_hot, type_encoder): + + one_hot = np.concatenate(one_hot, axis=0) + + decoder = list(type_encoder.keys()) + counts = {k: 0 for k in type_encoder.keys()} + for a in [decoder[x] for x in one_hot.argmax(1)]: + counts[a] += 1 + + return counts + + +def get_residue_with_resi(pdb_chain, resi): + res = [x for x in pdb_chain.get_residues() if x.id[1] == resi] + assert len(res) == 1 + return res[0] + + +def get_pocket_from_ligand(pdb_model, ligand, dist_cutoff=8.0): + + if ligand.endswith(".sdf"): + # ligand as sdf file + rdmol = Chem.SDMolSupplier(str(ligand))[0] + ligand_coords = torch.from_numpy(rdmol.GetConformer().GetPositions()).float() + resi = None + else: + # ligand contained in PDB; given in : format + chain, resi = ligand.split(':') + ligand = get_residue_with_resi(pdb_model[chain], int(resi)) + ligand_coords = torch.from_numpy( + np.array([a.get_coord() for a in ligand.get_atoms()])) + + pocket_residues = [] + for residue in pdb_model.get_residues(): + if residue.id[1] == resi: + continue # skip ligand itself + + res_coords = torch.from_numpy( + np.array([a.get_coord() for a in residue.get_atoms()])) + if is_aa(residue.get_resname(), standard=True) \ + and torch.cdist(res_coords, ligand_coords).min() < dist_cutoff: + pocket_residues.append(residue) + + return pocket_residues + + +def encode_residues(biopython_residues, type_encoder, level='atom', + remove_H=True): + assert level in {'atom', 'residue'} + + if level == 'atom': + entities = [a for res in biopython_residues for a in res.get_atoms() + if (a.element != 'H' or not remove_H)] + types = [a.element.capitalize() for a in entities] + else: + entities = [res['CA'] for res in biopython_residues] + types = [protein_letters_3to1[res.get_resname()] for res in biopython_residues] + + coord = torch.tensor(np.stack([e.get_coord() for e in entities])) + one_hot = F.one_hot(torch.tensor([type_encoder[t] for t in types]), + num_classes=len(type_encoder)) + + return coord, one_hot + + +def center_data(ligand, pocket): + if 
pocket['x'].numel() > 0:
+        pocket_com = pocket.center()
+    else:
+        pocket_com = scatter_mean(ligand['x'], ligand['mask'], dim=0)
+
+    ligand['x'] = ligand['x'] - pocket_com[ligand['mask']]
+    return ligand, pocket
+
+
+def get_bb_transform(n_xyz, ca_xyz, c_xyz):
+    """
+    Compute translation and rotation of the canonical backbone frame
+    (triangle N-Ca-C) from a position with Ca at the origin, N on the x-axis
+    and C in the xy-plane to the global position of the backbone frame.
+
+    Args:
+        n_xyz: (n, 3)
+        ca_xyz: (n, 3)
+        c_xyz: (n, 3)
+
+    Returns:
+        axis-angle representation of the rotation, shape (n, 3)
+        translation vector of shape (n, 3)
+    """
+
+    def rotation_matrix(angle, axis):
+        axis_mapping = {'x': 0, 'y': 1, 'z': 2}
+        axis = axis_mapping[axis]
+        vector = torch.zeros(len(angle), 3)
+        vector[:, axis] = 1
+        # return axis_angle_to_matrix(angle * vector)
+        return so3.matrix_from_rotation_vector(angle.view(-1, 1) * vector)
+
+    translation = ca_xyz
+    n_xyz = n_xyz - translation
+    c_xyz = c_xyz - translation
+
+    # Find rotation matrix that aligns the coordinate systems
+
+    # rotate around y-axis to move N into the xy-plane
+    theta_y = torch.arctan2(n_xyz[:, 2], -n_xyz[:, 0])
+    Ry = rotation_matrix(theta_y, 'y')
+    Ry = Ry.transpose(2, 1)
+    n_xyz = torch.einsum('noi,ni->no', Ry, n_xyz)
+
+    # rotate around z-axis to move N onto the x-axis
+    theta_z = torch.arctan2(n_xyz[:, 1], n_xyz[:, 0])
+    Rz = rotation_matrix(theta_z, 'z')
+    Rz = Rz.transpose(2, 1)
+    # print(torch.einsum('noi,ni->no', Rz, n_xyz))
+
+    # n_xyz = torch.einsum('noi,ni->no', Rz.transpose(0, 2, 1), n_xyz)
+
+    # rotate around x-axis to move C into the xy-plane
+    c_xyz = torch.einsum('noj,nji,ni->no', Rz, Ry, c_xyz)
+    theta_x = torch.arctan2(c_xyz[:, 2], c_xyz[:, 1])
+    Rx = rotation_matrix(theta_x, 'x')
+    Rx = Rx.transpose(2, 1)
+    # print(torch.einsum('noi,ni->no', Rx, c_xyz))
+
+    # Final rotation matrix
+    Ry = Ry.transpose(2, 1)
+    Rz = Rz.transpose(2, 1)
+    Rx = Rx.transpose(2, 1)
+    R = torch.einsum('nok,nkj,nji->noi', Ry, Rz, Rx)
+
+    # return R, translation
+    # return matrix_to_axis_angle(R), translation
+    return so3.rotation_vector_from_matrix(R), translation
+
+
+class Residues(TensorDict):
+    """
+    Dictionary-like container for residues that supports some basic
+    transformations.
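+
+    The class-level key sets below document which entries hold coordinate- or
+    vector-valued tensors, and which entries become stale when side-chain
+    torsions (set_chi) or backbone frames (set_frame) are updated.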
+ """ + + # all keys + KEYS = {'x', 'one_hot', 'bonds', 'bond_one_hot', 'v', 'nma_vec', 'fixed_coord', + 'atom_mask', 'nerf_indices', 'length', 'theta', 'chi', 'ddihedral', + 'chi_indices', 'axis_angle', 'mask', 'bond_mask'} + + # coordinate-type values, shape (..., 3) + COORD_KEYS = {'x', 'fixed_coord'} + + # vector-type values, shape (n_residues, n_feat, 3) + VECTOR_KEYS = {'v', 'nma_vec'} + + # properties that change if the side chains and/or backbones are updated + MUTABLE_PROPS_SS_AND_BB = {'v'} + + # properties that only change if the side chains are updated + MUTABLE_PROPS_SS = {'chi'} + + # properties that only change if the backbones are updated + MUTABLE_PROPS_BB = {'x', 'fixed_coord', 'axis_angle', 'nma_vec'} + + # properties that remain fixed in all cases + IMMUTABLE_PROPS = {'mask', 'one_hot', 'bonds', 'bond_one_hot', 'bond_mask', + 'atom_mask', 'nerf_indices', 'length', 'theta', + 'ddihedral', 'chi_indices', 'name', 'size', 'n_bonds'} + + def copy(self): + data = super().copy() + return Residues(**data) + + def deepcopy(self): + data = {k: v.clone() if torch.is_tensor(v) else deepcopy(v) + for k, v in self.items()} + return Residues(**data) + + def center(self): + com = scatter_mean(self['x'], self['mask'], dim=0) + self['x'] = self['x'] - com[self['mask']] + self['fixed_coord'] = self['fixed_coord'] - com[self['mask']].unsqueeze(1) + return com + + def set_empty_v(self): + self['v'] = torch.tensor([], device=self['x'].device) + + @torch.no_grad() + def set_chi(self, chi_angles): + self['chi'][:, :5] = chi_angles + nerf_params = {k: self[k] for k in ['fixed_coord', 'atom_mask', + 'nerf_indices', 'length', 'theta', + 'chi', 'ddihedral', 'chi_indices']} + self['v'] = ic_to_coords(**nerf_params) - self['x'].unsqueeze(1) + + @torch.no_grad() + def set_frame(self, new_ca_coord, new_axis_angle): + bb_coord = self['fixed_coord'] + bb_coord = bb_coord - self['x'].unsqueeze(1) + rotmat_before = so3.matrix_from_rotation_vector(self['axis_angle']) + rotmat_after = so3.matrix_from_rotation_vector(new_axis_angle) + rotmat_diff = rotmat_after @ rotmat_before.transpose(-1, -2) + bb_coord = torch.einsum('boi,bai->bao', rotmat_diff, bb_coord) + bb_coord = bb_coord + new_ca_coord.unsqueeze(1) + + self['x'] = new_ca_coord + self['axis_angle'] = new_axis_angle + self['fixed_coord'] = bb_coord + self['v'] = torch.einsum('boi,bai->bao', rotmat_diff, self['v']) + + @staticmethod + def empty(device): + return Residues( + x=torch.zeros(1, 3, device=device).float(), + mask=torch.zeros(1, 1, device=device).long(), + size=torch.zeros(1, device=device).long(), + ) + + +def randomize_tensors(tensor_dict, exclude_keys=None): + """Replace tensors with random tensors with the same shape.""" + exclude_keys = set() if exclude_keys is None else set(exclude_keys) + for k, v in tensor_dict.items(): + if isinstance(v, torch.Tensor) and k not in exclude_keys: + if torch.is_floating_point(v): + tensor_dict[k] = torch.randn_like(v) + else: + tensor_dict[k] = torch.randint_like(v, low=-42, high=42) + return tensor_dict diff --git a/src/data/dataset.py b/src/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..73d19bf738c125be2929a18b6af24e99c3eb8a45 --- /dev/null +++ b/src/data/dataset.py @@ -0,0 +1,208 @@ +import io +import random +import warnings +import torch +import webdataset as wds + +from pathlib import Path +from torch.utils.data import Dataset + +from src.data.data_utils import TensorDict, collate_entity +from src.constants import WEBDATASET_SHARD_SIZE, WEBDATASET_VAL_SIZE + + 
+class ProcessedLigandPocketDataset(Dataset):
+    def __init__(self, pt_path, ligand_transform=None, pocket_transform=None,
+                 catch_errors=False):
+
+        self.ligand_transform = ligand_transform
+        self.pocket_transform = pocket_transform
+        self.catch_errors = catch_errors
+        self.pt_path = pt_path
+
+        self.data = torch.load(pt_path)
+
+        # add number of nodes for convenience
+        for entity in ['ligands', 'pockets']:
+            self.data[entity]['size'] = torch.tensor([len(x) for x in self.data[entity]['x']])
+            self.data[entity]['n_bonds'] = torch.tensor([len(x) for x in self.data[entity]['bond_one_hot']])
+
+    def __len__(self):
+        return len(self.data['ligands']['name'])
+
+    def __getitem__(self, idx):
+        data = {}
+        data['ligand'] = {key: val[idx] for key, val in self.data['ligands'].items()}
+        data['pocket'] = {key: val[idx] for key, val in self.data['pockets'].items()}
+        try:
+            if self.ligand_transform is not None:
+                data['ligand'] = self.ligand_transform(data['ligand'])
+            if self.pocket_transform is not None:
+                data['pocket'] = self.pocket_transform(data['pocket'])
+        except (RuntimeError, ValueError) as e:
+            if self.catch_errors:
+                warnings.warn(f"{type(e).__name__}('{e}') in data transform. "
+                              f"Returning random item instead")
+                # replace the bad item with a random one
+                rand_idx = random.randint(0, len(self) - 1)
+                return self[rand_idx]
+            else:
+                raise e
+        return data
+
+    @staticmethod
+    def collate_fn(batch_pairs, ligand_transform=None):
+
+        out = {}
+        for entity in ['ligand', 'pocket']:
+            batch = [x[entity] for x in batch_pairs]
+
+            if entity == 'ligand' and ligand_transform is not None:
+                max_size = max(x['size'].item() for x in batch)
+                # TODO: might have to remove elements from the batch if processing fails; warn the user in that case
+                batch = [ligand_transform(x, max_size=max_size) for x in batch]
+
+            out[entity] = TensorDict(**collate_entity(batch))
+
+        return out
+
+
+class ClusteredDataset(ProcessedLigandPocketDataset):
+    def __init__(self, pt_path, ligand_transform=None, pocket_transform=None,
+                 catch_errors=False):
+        super().__init__(pt_path, ligand_transform, pocket_transform, catch_errors)
+        self.clusters = list(self.data['clusters'].values())
+
+    def __len__(self):
+        return len(self.clusters)
+
+    def __getitem__(self, cidx):
+        cluster_inds = self.clusters[cidx]
+        idx = random.choice(cluster_inds)
+        return super().__getitem__(idx)
+
+
+class DPODataset(ProcessedLigandPocketDataset):
+    def __init__(self, pt_path, ligand_transform=None, pocket_transform=None,
+                 catch_errors=False):
+        self.ligand_transform = ligand_transform
+        self.pocket_transform = pocket_transform
+        self.catch_errors = catch_errors
+        self.pt_path = pt_path
+
+        self.data = torch.load(pt_path)
+
+        if 'pockets' not in self.data:
+            self.data['pockets'] = self.data['pockets_w']
+        if 'ligands' not in self.data:
+            self.data['ligands'] = self.data['ligands_w']
+
+        # all three lists must have the same length (note: a chained `!=`
+        # would miss the case where only the last length differs)
+        if not (
+            len(self.data["ligands"]["name"])
+            == len(self.data["ligands_l"]["name"])
+            == len(self.data["pockets"]["name"])
+        ):
+            raise ValueError(
+                "Error while importing DPO dataset: the numbers of winning ligands, losing ligands and pockets must be the same"
+            )
+
+        # add number of nodes for convenience
+        for entity in ['ligands', 'ligands_l', 'pockets']:
+            self.data[entity]['size'] = torch.tensor([len(x) for x in self.data[entity]['x']])
+            self.data[entity]['n_bonds'] = torch.tensor([len(x) for x in self.data[entity]['bond_one_hot']])
+
+    def __len__(self):
+        return len(self.data["ligands"]["name"])
+
+    def
__getitem__(self, idx): + data = {} + data['ligand'] = {key: val[idx] for key, val in self.data['ligands'].items()} + data['ligand_l'] = {key: val[idx] for key, val in self.data['ligands_l'].items()} + data['pocket'] = {key: val[idx] for key, val in self.data['pockets'].items()} + try: + if self.ligand_transform is not None: + data['ligand'] = self.ligand_transform(data['ligand']) + data['ligand_l'] = self.ligand_transform(data['ligand_l']) + if self.pocket_transform is not None: + data['pocket'] = self.pocket_transform(data['pocket']) + except (RuntimeError, ValueError) as e: + if self.catch_errors: + warnings.warn(f"{type(e).__name__}('{e}') in data transform. " + f"Returning random item instead") + # replace bad item with a random one + rand_idx = random.randint(0, len(self) - 1) + return self[rand_idx] + else: + raise e + return data + + @staticmethod + def collate_fn(batch_pairs, ligand_transform=None): + + out = {} + for entity in ['ligand', 'ligand_l', 'pocket']: + batch = [x[entity] for x in batch_pairs] + + if entity in ['ligand', 'ligand_l'] and ligand_transform is not None: + max_size = max(x['size'].item() for x in batch) + batch = [ligand_transform(x, max_size=max_size) for x in batch] + + out[entity] = TensorDict(**collate_entity(batch)) + + return out + +########################################## +############### WebDatasets ############## +########################################## + +class ProteinLigandWebDataset(wds.WebDataset): + @staticmethod + def collate_fn(batch_pairs, ligand_transform=None): + return ProcessedLigandPocketDataset.collate_fn(batch_pairs, ligand_transform) + + +def wds_decoder(key, value): + return torch.load(io.BytesIO(value)) + + +def preprocess_wds_item(data): + out = {} + for entity in ['ligand', 'pocket']: + out[entity] = data['pt'][entity] + for attr in ['size', 'n_bonds']: + if torch.is_tensor(out[entity][attr]): + assert len(out[entity][attr]) == 0 + out[entity][attr] = 0 + + return out + + +def get_wds(data_path, stage, ligand_transform=None, pocket_transform=None): + current_data_dir = Path(data_path, stage) + shards = sorted(current_data_dir.glob('shard-?????.tar'), key=lambda s: int(s.name.split('-')[-1].split('.')[0])) + min_shard = min(shards).name.split('-')[-1].split('.')[0] + max_shard = max(shards).name.split('-')[-1].split('.')[0] + total_size = (int(max_shard) - int(min_shard) + 1) * WEBDATASET_SHARD_SIZE if stage == 'train' else WEBDATASET_VAL_SIZE + + url = f'{data_path}/{stage}/shard-{{{min_shard}..{max_shard}}}.tar' + ligand_transform_wrapper = lambda _data: _data + pocket_transform_wrapper = lambda _data: _data + + if ligand_transform is not None: + def ligand_transform_wrapper(_data): + _data['pt']['ligand'] = ligand_transform(_data['pt']['ligand']) + return _data + + if pocket_transform is not None: + def pocket_transform_wrapper(_data): + _data['pt']['pocket'] = pocket_transform(_data['pt']['pocket']) + return _data + + return ( + ProteinLigandWebDataset(url, nodesplitter=wds.split_by_node) + .decode(wds_decoder) + .map(ligand_transform_wrapper) + .map(pocket_transform_wrapper) + .map(preprocess_wds_item) + .with_length(total_size) + ) diff --git a/src/data/misc.py b/src/data/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..bb37458d36b41d41084be52665fb1e8bf33e9fb4 --- /dev/null +++ b/src/data/misc.py @@ -0,0 +1,19 @@ +# From: https://github.com/biopython/biopython/blob/master/Bio/PDB/Polypeptide.py#L128 + +protein_letters_1to3 = {'A': 'ALA', 'C': 'CYS', 'D': 'ASP', 'E': 'GLU', 'F': 'PHE', 'G': 
'GLY', 'H': 'HIS', 'I': 'ILE', 'K': 'LYS', 'L': 'LEU', 'M': 'MET', 'N': 'ASN', 'P': 'PRO', 'Q': 'GLN', 'R': 'ARG', 'S': 'SER', 'T': 'THR', 'V': 'VAL', 'W': 'TRP', 'Y': 'TYR'} + + +protein_letters_3to1 = {'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L', 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R', 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y'} + + +protein_letters_3to1_extended = {'A5N': 'N', 'A8E': 'V', 'A9D': 'S', 'AA3': 'A', 'AA4': 'A', 'AAR': 'R', 'ABA': 'A', 'ACL': 'R', 'AEA': 'C', 'AEI': 'D', 'AFA': 'N', 'AGM': 'R', 'AGQ': 'Y', 'AGT': 'C', 'AHB': 'N', 'AHL': 'R', 'AHO': 'A', 'AHP': 'A', 'AIB': 'A', 'AKL': 'D', 'AKZ': 'D', 'ALA': 'A', 'ALC': 'A', 'ALM': 'A', 'ALN': 'A', 'ALO': 'T', 'ALS': 'A', 'ALT': 'A', 'ALV': 'A', 'ALY': 'K', 'AME': 'M', 'AN6': 'L', 'AN8': 'A', 'API': 'K', 'APK': 'K', 'AR2': 'R', 'AR4': 'E', 'AR7': 'R', 'ARG': 'R', 'ARM': 'R', 'ARO': 'R', 'AS7': 'N', 'ASA': 'D', 'ASB': 'D', 'ASI': 'D', 'ASK': 'D', 'ASL': 'D', 'ASN': 'N', 'ASP': 'D', 'ASQ': 'D', 'AYA': 'A', 'AZH': 'A', 'AZK': 'K', 'AZS': 'S', 'AZY': 'Y', 'AVJ': 'H', 'A30': 'Y', 'A3U': 'F', 'ECC': 'Q', 'ECX': 'C', 'EFC': 'C', 'EHP': 'F', 'ELY': 'K', 'EME': 'E', 'EPM': 'M', 'EPQ': 'Q', 'ESB': 'Y', 'ESC': 'M', 'EXY': 'L', 'EXA': 'K', 'E0Y': 'P', 'E9V': 'H', 'E9M': 'W', 'EJA': 'C', 'EUP': 'T', 'EZY': 'G', 'E9C': 'Y', 'EW6': 'S', 'EXL': 'W', 'I2M': 'I', 'I4G': 'G', 'I58': 'K', 'IAM': 'A', 'IAR': 'R', 'ICY': 'C', 'IEL': 'K', 'IGL': 'G', 'IIL': 'I', 'ILE': 'I', 'ILG': 'E', 'ILM': 'I', 'ILX': 'I', 'ILY': 'K', 'IML': 'I', 'IOR': 'R', 'IPG': 'G', 'IT1': 'K', 'IYR': 'Y', 'IZO': 'M', 'IC0': 'G', 'M0H': 'C', 'M2L': 'K', 'M2S': 'M', 'M30': 'G', 'M3L': 'K', 'M3R': 'K', 'MA ': 'A', 'MAA': 'A', 'MAI': 'R', 'MBQ': 'Y', 'MC1': 'S', 'MCL': 'K', 'MCS': 'C', 'MD3': 'C', 'MD5': 'C', 'MD6': 'G', 'MDF': 'Y', 'ME0': 'M', 'MEA': 'F', 'MEG': 'E', 'MEN': 'N', 'MEQ': 'Q', 'MET': 'M', 'MEU': 'G', 'MFN': 'E', 'MGG': 'R', 'MGN': 'Q', 'MGY': 'G', 'MH1': 'H', 'MH6': 'S', 'MHL': 'L', 'MHO': 'M', 'MHS': 'H', 'MHU': 'F', 'MIR': 'S', 'MIS': 'S', 'MK8': 'L', 'ML3': 'K', 'MLE': 'L', 'MLL': 'L', 'MLY': 'K', 'MLZ': 'K', 'MME': 'M', 'MMO': 'R', 'MNL': 'L', 'MNV': 'V', 'MP8': 'P', 'MPQ': 'G', 'MSA': 'G', 'MSE': 'M', 'MSL': 'M', 'MSO': 'M', 'MT2': 'M', 'MTY': 'Y', 'MVA': 'V', 'MYK': 'K', 'MYN': 'R', 'QCS': 'C', 'QIL': 'I', 'QMM': 'Q', 'QPA': 'C', 'QPH': 'F', 'Q3P': 'K', 'QVA': 'C', 'QX7': 'A', 'Q2E': 'W', 'Q75': 'M', 'Q78': 'F', 'QM8': 'L', 'QMB': 'A', 'QNQ': 'C', 'QNT': 'C', 'QNW': 'C', 'QO2': 'C', 'QO5': 'C', 'QO8': 'C', 'QQ8': 'Q', 'U2X': 'Y', 'U3X': 'F', 'UF0': 'S', 'UGY': 'G', 'UM1': 'A', 'UM2': 'A', 'UMA': 'A', 'UQK': 'A', 'UX8': 'W', 'UXQ': 'F', 'YCM': 'C', 'YOF': 'Y', 'YPR': 'P', 'YPZ': 'Y', 'YTH': 'T', 'Y1V': 'L', 'Y57': 'K', 'YHA': 'K', '200': 'F', '23F': 'F', '23P': 'A', '26B': 'T', '28X': 'T', '2AG': 'A', '2CO': 'C', '2FM': 'M', '2GX': 'F', '2HF': 'H', '2JG': 'S', '2KK': 'K', '2KP': 'K', '2LT': 'Y', '2LU': 'L', '2ML': 'L', '2MR': 'R', '2MT': 'P', '2OR': 'R', '2P0': 'P', '2QZ': 'T', '2R3': 'Y', '2RA': 'A', '2RX': 'S', '2SO': 'H', '2TY': 'Y', '2VA': 'V', '2XA': 'C', '2ZC': 'S', '6CL': 'K', '6CW': 'W', '6GL': 'A', '6HN': 'K', '60F': 'C', '66D': 'I', '6CV': 'A', '6M6': 'C', '6V1': 'C', '6WK': 'C', '6Y9': 'P', '6DN': 'K', 'DA2': 'R', 'DAB': 'A', 'DAH': 'F', 'DBS': 'S', 'DBU': 'T', 'DBY': 'Y', 'DBZ': 'A', 'DC2': 'C', 'DDE': 'H', 'DDZ': 'A', 'DI7': 'Y', 'DHA': 'S', 'DHN': 'V', 'DIR': 'R', 'DLS': 'K', 'DM0': 'K', 'DMH': 'N', 'DMK': 'D', 'DNL': 'K', 'DNP': 'A', 
'DNS': 'K', 'DNW': 'A', 'DOH': 'D', 'DON': 'L', 'DP1': 'R', 'DPL': 'P', 'DPP': 'A', 'DPQ': 'Y', 'DYS': 'C', 'D2T': 'D', 'DYA': 'D', 'DJD': 'F', 'DYJ': 'P', 'DV9': 'E', 'H14': 'F', 'H1D': 'M', 'H5M': 'P', 'HAC': 'A', 'HAR': 'R', 'HBN': 'H', 'HCM': 'C', 'HGY': 'G', 'HHI': 'H', 'HIA': 'H', 'HIC': 'H', 'HIP': 'H', 'HIQ': 'H', 'HIS': 'H', 'HL2': 'L', 'HLU': 'L', 'HMR': 'R', 'HNC': 'C', 'HOX': 'F', 'HPC': 'F', 'HPE': 'F', 'HPH': 'F', 'HPQ': 'F', 'HQA': 'A', 'HR7': 'R', 'HRG': 'R', 'HRP': 'W', 'HS8': 'H', 'HS9': 'H', 'HSE': 'S', 'HSK': 'H', 'HSL': 'S', 'HSO': 'H', 'HT7': 'W', 'HTI': 'C', 'HTR': 'W', 'HV5': 'A', 'HVA': 'V', 'HY3': 'P', 'HYI': 'M', 'HYP': 'P', 'HZP': 'P', 'HIX': 'A', 'HSV': 'H', 'HLY': 'K', 'HOO': 'H', 'H7V': 'A', 'L5P': 'K', 'LRK': 'K', 'L3O': 'L', 'LA2': 'K', 'LAA': 'D', 'LAL': 'A', 'LBY': 'K', 'LCK': 'K', 'LCX': 'K', 'LDH': 'K', 'LE1': 'V', 'LED': 'L', 'LEF': 'L', 'LEH': 'L', 'LEM': 'L', 'LEN': 'L', 'LET': 'K', 'LEU': 'L', 'LEX': 'L', 'LGY': 'K', 'LLO': 'K', 'LLP': 'K', 'LLY': 'K', 'LLZ': 'K', 'LME': 'E', 'LMF': 'K', 'LMQ': 'Q', 'LNE': 'L', 'LNM': 'L', 'LP6': 'K', 'LPD': 'P', 'LPG': 'G', 'LPS': 'S', 'LSO': 'K', 'LTR': 'W', 'LVG': 'G', 'LVN': 'V', 'LWY': 'P', 'LYF': 'K', 'LYK': 'K', 'LYM': 'K', 'LYN': 'K', 'LYO': 'K', 'LYP': 'K', 'LYR': 'K', 'LYS': 'K', 'LYU': 'K', 'LYX': 'K', 'LYZ': 'K', 'LAY': 'L', 'LWI': 'F', 'LBZ': 'K', 'P1L': 'C', 'P2Q': 'Y', 'P2Y': 'P', 'P3Q': 'Y', 'PAQ': 'Y', 'PAS': 'D', 'PAT': 'W', 'PBB': 'C', 'PBF': 'F', 'PCA': 'Q', 'PCC': 'P', 'PCS': 'F', 'PE1': 'K', 'PEC': 'C', 'PF5': 'F', 'PFF': 'F', 'PG1': 'S', 'PGY': 'G', 'PHA': 'F', 'PHD': 'D', 'PHE': 'F', 'PHI': 'F', 'PHL': 'F', 'PHM': 'F', 'PKR': 'P', 'PLJ': 'P', 'PM3': 'F', 'POM': 'P', 'PPN': 'F', 'PR3': 'C', 'PR4': 'P', 'PR7': 'P', 'PR9': 'P', 'PRJ': 'P', 'PRK': 'K', 'PRO': 'P', 'PRS': 'P', 'PRV': 'G', 'PSA': 'F', 'PSH': 'H', 'PTH': 'Y', 'PTM': 'Y', 'PTR': 'Y', 'PVH': 'H', 'PXU': 'P', 'PYA': 'A', 'PYH': 'K', 'PYX': 'C', 'PH6': 'P', 'P9S': 'C', 'P5U': 'S', 'POK': 'R', 'T0I': 'Y', 'T11': 'F', 'TAV': 'D', 'TBG': 'V', 'TBM': 'T', 'TCQ': 'Y', 'TCR': 'W', 'TEF': 'F', 'TFQ': 'F', 'TH5': 'T', 'TH6': 'T', 'THC': 'T', 'THR': 'T', 'THZ': 'R', 'TIH': 'A', 'TIS': 'S', 'TLY': 'K', 'TMB': 'T', 'TMD': 'T', 'TNB': 'C', 'TNR': 'S', 'TNY': 'T', 'TOQ': 'W', 'TOX': 'W', 'TPJ': 'P', 'TPK': 'P', 'TPL': 'W', 'TPO': 'T', 'TPQ': 'Y', 'TQI': 'W', 'TQQ': 'W', 'TQZ': 'C', 'TRF': 'W', 'TRG': 'K', 'TRN': 'W', 'TRO': 'W', 'TRP': 'W', 'TRQ': 'W', 'TRW': 'W', 'TRX': 'W', 'TRY': 'W', 'TS9': 'I', 'TSY': 'C', 'TTQ': 'W', 'TTS': 'Y', 'TXY': 'Y', 'TY1': 'Y', 'TY2': 'Y', 'TY3': 'Y', 'TY5': 'Y', 'TY8': 'Y', 'TY9': 'Y', 'TYB': 'Y', 'TYC': 'Y', 'TYE': 'Y', 'TYI': 'Y', 'TYJ': 'Y', 'TYN': 'Y', 'TYO': 'Y', 'TYQ': 'Y', 'TYR': 'Y', 'TYS': 'Y', 'TYT': 'Y', 'TYW': 'Y', 'TYY': 'Y', 'T8L': 'T', 'T9E': 'T', 'TNQ': 'W', 'TSQ': 'F', 'TGH': 'W', 'X2W': 'E', 'XCN': 'C', 'XPR': 'P', 'XSN': 'N', 'XW1': 'A', 'XX1': 'K', 'XYC': 'A', 'XA6': 'F', '11Q': 'P', '11W': 'E', '12L': 'P', '12X': 'P', '12Y': 'P', '143': 'C', '1AC': 'A', '1L1': 'A', '1OP': 'Y', '1PA': 'F', '1PI': 'A', '1TQ': 'W', '1TY': 'Y', '1X6': 'S', '56A': 'H', '5AB': 'A', '5CS': 'C', '5CW': 'W', '5HP': 'E', '5OH': 'A', '5PG': 'G', '51T': 'Y', '54C': 'W', '5CR': 'F', '5CT': 'K', '5FQ': 'A', '5GM': 'I', '5JP': 'S', '5T3': 'K', '5MW': 'K', '5OW': 'K', '5R5': 'S', '5VV': 'N', '5XU': 'A', '55I': 'F', '999': 'D', '9DN': 'N', '9NE': 'E', '9NF': 'F', '9NR': 'R', '9NV': 'V', '9E7': 'K', '9KP': 'K', '9WV': 'A', '9TR': 'K', '9TU': 'K', '9TX': 'K', '9U0': 'K', '9IJ': 'F', 'B1F': 'F', 'B27': 'T', 'B2A': 'A', 'B2F': 'F', 
'B2I': 'I', 'B2V': 'V', 'B3A': 'A', 'B3D': 'D', 'B3E': 'E', 'B3K': 'K', 'B3U': 'H', 'B3X': 'N', 'B3Y': 'Y', 'BB6': 'C', 'BB7': 'C', 'BB8': 'F', 'BB9': 'C', 'BBC': 'C', 'BCS': 'C', 'BCX': 'C', 'BFD': 'D', 'BG1': 'S', 'BH2': 'D', 'BHD': 'D', 'BIF': 'F', 'BIU': 'I', 'BL2': 'L', 'BLE': 'L', 'BLY': 'K', 'BMT': 'T', 'BNN': 'F', 'BOR': 'R', 'BP5': 'A', 'BPE': 'C', 'BSE': 'S', 'BTA': 'L', 'BTC': 'C', 'BTK': 'K', 'BTR': 'W', 'BUC': 'C', 'BUG': 'V', 'BYR': 'Y', 'BWV': 'R', 'BWB': 'S', 'BXT': 'S', 'F2F': 'F', 'F2Y': 'Y', 'FAK': 'K', 'FB5': 'A', 'FB6': 'A', 'FC0': 'F', 'FCL': 'F', 'FDL': 'K', 'FFM': 'C', 'FGL': 'G', 'FGP': 'S', 'FH7': 'K', 'FHL': 'K', 'FHO': 'K', 'FIO': 'R', 'FLA': 'A', 'FLE': 'L', 'FLT': 'Y', 'FME': 'M', 'FOE': 'C', 'FP9': 'P', 'FPK': 'P', 'FT6': 'W', 'FTR': 'W', 'FTY': 'Y', 'FVA': 'V', 'FZN': 'K', 'FY3': 'Y', 'F7W': 'W', 'FY2': 'Y', 'FQA': 'K', 'F7Q': 'Y', 'FF9': 'K', 'FL6': 'D', 'JJJ': 'C', 'JJK': 'C', 'JJL': 'C', 'JLP': 'K', 'J3D': 'C', 'J9Y': 'R', 'J8W': 'S', 'JKH': 'P', 'N10': 'S', 'N7P': 'P', 'NA8': 'A', 'NAL': 'A', 'NAM': 'A', 'NBQ': 'Y', 'NC1': 'S', 'NCB': 'A', 'NEM': 'H', 'NEP': 'H', 'NFA': 'F', 'NIY': 'Y', 'NLB': 'L', 'NLE': 'L', 'NLN': 'L', 'NLO': 'L', 'NLP': 'L', 'NLQ': 'Q', 'NLY': 'G', 'NMC': 'G', 'NMM': 'R', 'NNH': 'R', 'NOT': 'L', 'NPH': 'C', 'NPI': 'A', 'NTR': 'Y', 'NTY': 'Y', 'NVA': 'V', 'NWD': 'A', 'NYB': 'C', 'NYS': 'C', 'NZH': 'H', 'N80': 'P', 'NZC': 'T', 'NLW': 'L', 'N0A': 'F', 'N9P': 'A', 'N65': 'K', 'R1A': 'C', 'R4K': 'W', 'RE0': 'W', 'RE3': 'W', 'RGL': 'R', 'RGP': 'E', 'RT0': 'P', 'RVX': 'S', 'RZ4': 'S', 'RPI': 'R', 'RVJ': 'A', 'VAD': 'V', 'VAF': 'V', 'VAH': 'V', 'VAI': 'V', 'VAL': 'V', 'VB1': 'K', 'VH0': 'P', 'VR0': 'R', 'V44': 'C', 'V61': 'F', 'VPV': 'K', 'V5N': 'H', 'V7T': 'K', 'Z01': 'A', 'Z3E': 'T', 'Z70': 'H', 'ZBZ': 'C', 'ZCL': 'F', 'ZU0': 'T', 'ZYJ': 'P', 'ZYK': 'P', 'ZZD': 'C', 'ZZJ': 'A', 'ZIQ': 'W', 'ZPO': 'P', 'ZDJ': 'Y', 'ZT1': 'K', '30V': 'C', '31Q': 'C', '33S': 'F', '33W': 'A', '34E': 'V', '3AH': 'H', '3BY': 'P', '3CF': 'F', '3CT': 'Y', '3GA': 'A', '3GL': 'E', '3MD': 'D', '3MY': 'Y', '3NF': 'Y', '3O3': 'E', '3PX': 'P', '3QN': 'K', '3TT': 'P', '3XH': 'G', '3YM': 'Y', '3WS': 'A', '3WX': 'P', '3X9': 'C', '3ZH': 'H', '7JA': 'I', '73C': 'S', '73N': 'R', '73O': 'Y', '73P': 'K', '74P': 'K', '7N8': 'F', '7O5': 'A', '7XC': 'F', '7ID': 'D', '7OZ': 'A', 'C1S': 'C', 'C1T': 'C', 'C1X': 'K', 'C22': 'A', 'C3Y': 'C', 'C4R': 'C', 'C5C': 'C', 'C6C': 'C', 'CAF': 'C', 'CAS': 'C', 'CAY': 'C', 'CCS': 'C', 'CEA': 'C', 'CGA': 'E', 'CGU': 'E', 'CGV': 'C', 'CHP': 'G', 'CIR': 'R', 'CLE': 'L', 'CLG': 'K', 'CLH': 'K', 'CME': 'C', 'CMH': 'C', 'CML': 'C', 'CMT': 'C', 'CR5': 'G', 'CS0': 'C', 'CS1': 'C', 'CS3': 'C', 'CS4': 'C', 'CSA': 'C', 'CSB': 'C', 'CSD': 'C', 'CSE': 'C', 'CSJ': 'C', 'CSO': 'C', 'CSP': 'C', 'CSR': 'C', 'CSS': 'C', 'CSU': 'C', 'CSW': 'C', 'CSX': 'C', 'CSZ': 'C', 'CTE': 'W', 'CTH': 'T', 'CWD': 'A', 'CWR': 'S', 'CXM': 'M', 'CY0': 'C', 'CY1': 'C', 'CY3': 'C', 'CY4': 'C', 'CYA': 'C', 'CYD': 'C', 'CYF': 'C', 'CYG': 'C', 'CYJ': 'K', 'CYM': 'C', 'CYQ': 'C', 'CYR': 'C', 'CYS': 'C', 'CYW': 'C', 'CZ2': 'C', 'CZZ': 'C', 'CG6': 'C', 'C1J': 'R', 'C4G': 'R', 'C67': 'R', 'C6D': 'R', 'CE7': 'N', 'CZS': 'A', 'G01': 'E', 'G8M': 'E', 'GAU': 'E', 'GEE': 'G', 'GFT': 'S', 'GHC': 'E', 'GHG': 'Q', 'GHW': 'E', 'GL3': 'G', 'GLH': 'Q', 'GLJ': 'E', 'GLK': 'E', 'GLN': 'Q', 'GLQ': 'E', 'GLU': 'E', 'GLY': 'G', 'GLZ': 'G', 'GMA': 'E', 'GME': 'E', 'GNC': 'Q', 'GPL': 'K', 'GSC': 'G', 'GSU': 'E', 'GT9': 'C', 'GVL': 'S', 'G3M': 'R', 'G5G': 'L', 'G1X': 'Y', 'G8X': 'P', 'K1R': 'C', 'KBE': 'K', 
'KCX': 'K', 'KFP': 'K', 'KGC': 'K', 'KNB': 'A', 'KOR': 'M', 'KPI': 'K', 'KPY': 'K', 'KST': 'K', 'KYN': 'W', 'KYQ': 'K', 'KCR': 'K', 'KPF': 'K', 'K5L': 'S', 'KEO': 'K', 'KHB': 'K', 'KKD': 'D', 'K5H': 'C', 'K7K': 'S', 'OAR': 'R', 'OAS': 'S', 'OBS': 'K', 'OCS': 'C', 'OCY': 'C', 'OHI': 'H', 'OHS': 'D', 'OLD': 'H', 'OLT': 'T', 'OLZ': 'S', 'OMH': 'S', 'OMT': 'M', 'OMX': 'Y', 'OMY': 'Y', 'ONH': 'A', 'ORN': 'A', 'ORQ': 'R', 'OSE': 'S', 'OTH': 'T', 'OXX': 'D', 'OYL': 'H', 'O7A': 'T', 'O7D': 'W', 'O7G': 'V', 'O2E': 'S', 'O6H': 'W', 'OZW': 'F', 'S12': 'S', 'S1H': 'S', 'S2C': 'C', 'S2P': 'A', 'SAC': 'S', 'SAH': 'C', 'SAR': 'G', 'SBG': 'S', 'SBL': 'S', 'SCH': 'C', 'SCS': 'C', 'SCY': 'C', 'SD4': 'N', 'SDB': 'S', 'SDP': 'S', 'SEB': 'S', 'SEE': 'S', 'SEG': 'A', 'SEL': 'S', 'SEM': 'S', 'SEN': 'S', 'SEP': 'S', 'SER': 'S', 'SET': 'S', 'SGB': 'S', 'SHC': 'C', 'SHP': 'G', 'SHR': 'K', 'SIB': 'C', 'SLL': 'K', 'SLZ': 'K', 'SMC': 'C', 'SME': 'M', 'SMF': 'F', 'SNC': 'C', 'SNN': 'N', 'SOY': 'S', 'SRZ': 'S', 'STY': 'Y', 'SUN': 'S', 'SVA': 'S', 'SVV': 'S', 'SVW': 'S', 'SVX': 'S', 'SVY': 'S', 'SVZ': 'S', 'SXE': 'S', 'SKH': 'K', 'SNM': 'S', 'SNK': 'H', 'SWW': 'S', 'WFP': 'F', 'WLU': 'L', 'WPA': 'F', 'WRP': 'W', 'WVL': 'V', '02K': 'A', '02L': 'N', '02O': 'A', '02Y': 'A', '033': 'V', '037': 'P', '03Y': 'C', '04U': 'P', '04V': 'P', '05N': 'P', '07O': 'C', '0A0': 'D', '0A1': 'Y', '0A2': 'K', '0A8': 'C', '0A9': 'F', '0AA': 'V', '0AB': 'V', '0AC': 'G', '0AF': 'W', '0AG': 'L', '0AH': 'S', '0AK': 'D', '0AR': 'R', '0BN': 'F', '0CS': 'A', '0E5': 'T', '0EA': 'Y', '0FL': 'A', '0LF': 'P', '0NC': 'A', '0PR': 'Y', '0QL': 'C', '0TD': 'D', '0UO': 'W', '0WZ': 'Y', '0X9': 'R', '0Y8': 'P', '4AF': 'F', '4AR': 'R', '4AW': 'W', '4BF': 'F', '4CF': 'F', '4CY': 'M', '4DP': 'W', '4FB': 'P', '4FW': 'W', '4HL': 'Y', '4HT': 'W', '4IN': 'W', '4MM': 'M', '4PH': 'F', '4U7': 'A', '41H': 'F', '41Q': 'N', '42Y': 'S', '432': 'S', '45F': 'P', '4AK': 'K', '4D4': 'R', '4GJ': 'C', '4KY': 'P', '4L0': 'P', '4LZ': 'Y', '4N7': 'P', '4N8': 'P', '4N9': 'P', '4OG': 'W', '4OU': 'F', '4OV': 'S', '4OZ': 'S', '4PQ': 'W', '4SJ': 'F', '4WQ': 'A', '4HH': 'S', '4HJ': 'S', '4J4': 'C', '4J5': 'R', '4II': 'F', '4VI': 'R', '823': 'N', '8SP': 'S', '8AY': 'A'} + + +def is_aa(residue, standard=False): + if not isinstance(residue, str): + residue = f"{residue.get_resname():<3s}" + residue = residue.upper() + if standard: + return residue in protein_letters_3to1 + else: + return residue in protein_letters_3to1_extended \ No newline at end of file diff --git a/src/data/molecule_builder.py b/src/data/molecule_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..87aa8ec6b4e4524182ece926e4596ab9ae82d127 --- /dev/null +++ b/src/data/molecule_builder.py @@ -0,0 +1,107 @@ +from rdkit import Chem + +from src import constants + + +def remove_dummy_atoms(rdmol, sanitize=False): + # find exit atoms to be removed + dummy_inds = [] + for a in rdmol.GetAtoms(): + if a.GetSymbol() == '*': + dummy_inds.append(a.GetIdx()) + + dummy_inds = sorted(dummy_inds, reverse=True) + new_mol = Chem.EditableMol(rdmol) + for idx in dummy_inds: + new_mol.RemoveAtom(idx) + new_mol = new_mol.GetMol() + if sanitize: + Chem.SanitizeMol(new_mol) + return new_mol + + +def build_molecule(coords, atom_types, bonds=None, bond_types=None, + atom_props=None, atom_decoder=None, bond_decoder=None): + """ + Build RDKit molecule with given bonds + :param coords: N x 3 + :param atom_types: N + :param bonds: 2 x N_bonds + :param bond_types: N_bonds + :param atom_props: Dict, key: property name, value: 
list of float values (N,)
+    :param atom_decoder: list
+    :param bond_decoder: list
+    :return: RDKit molecule
+    """
+    if atom_decoder is None:
+        atom_decoder = constants.atom_decoder
+    if bond_decoder is None:
+        bond_decoder = constants.bond_decoder
+    assert len(coords) == len(atom_types)
+    assert bonds is None or bonds.size(1) == len(bond_types)
+
+    mol = Chem.RWMol()
+    for i, atom in enumerate(atom_types):
+        element = atom_decoder[atom.item()]
+        charge = None
+        explicitHs = None
+
+        if len(element) > 1 and element.endswith('H'):
+            explicitHs = 1
+            element = element[:-1]
+        elif element.endswith('+'):
+            charge = 1
+            element = element[:-1]
+        elif element.endswith('-'):
+            charge = -1
+            element = element[:-1]
+
+        if element == 'NOATOM':
+            # represent 'NOATOM' entries as dummy atoms; these are stripped below
+            element = '*'
+
+        a = Chem.Atom(element)
+
+        if explicitHs is not None:
+            a.SetNumExplicitHs(explicitHs)
+        if charge is not None:
+            a.SetFormalCharge(charge)
+
+        if atom_props is not None:
+            for k, vals in atom_props.items():
+                a.SetDoubleProp(k, vals[i].item())
+
+        mol.AddAtom(a)
+
+    # add coordinates
+    conf = Chem.Conformer(mol.GetNumAtoms())
+    for i in range(mol.GetNumAtoms()):
+        conf.SetAtomPosition(i, (coords[i, 0].item(),
+                                 coords[i, 1].item(),
+                                 coords[i, 2].item()))
+    mol.AddConformer(conf)
+
+    # add bonds
+    if bonds is not None:
+        for bond, bond_type in zip(bonds.T, bond_types):
+            bond_type = bond_decoder[bond_type]
+            src = bond[0].item()
+            dst = bond[1].item()
+
+            # skip placeholder bonds and bonds involving dummy atoms
+            if bond_type == 'NOBOND' \
+                    or mol.GetAtomWithIdx(src).GetSymbol() == '*' \
+                    or mol.GetAtomWithIdx(dst).GetSymbol() == '*':
+                continue
+
+            if mol.GetBondBetweenAtoms(src, dst) is not None:
+                assert mol.GetBondBetweenAtoms(src, dst).GetBondType() == bond_type, \
+                    "Trying to assign two different types to the same bond."
+                continue
+
+            if bond_type is None or src == dst:
+                continue
+            mol.AddBond(src, dst, bond_type)
+
+    mol = remove_dummy_atoms(mol, sanitize=False)
+    return mol
diff --git a/src/data/nerf.py b/src/data/nerf.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdcab1716dec671cc2b60020ee2bd56d00f6d887
--- /dev/null
+++ b/src/data/nerf.py
@@ -0,0 +1,250 @@
+"""
+Natural Extension Reference Frame (NERF)
+
+Inspiration for parallel reconstruction:
+https://github.com/EleutherAI/mp_nerf and references therein
+
+For atom names, see also:
+https://www.ccpn.ac.uk/manual/v3/NEFAtomNames.html
+
+References:
+- https://onlinelibrary.wiley.com/doi/10.1002/jcc.20237 (NERF)
+- https://onlinelibrary.wiley.com/doi/10.1002/jcc.26768 (for code)
+"""
+
+import warnings
+import torch
+import numpy as np
+
+from src.data.misc import protein_letters_3to1
+from src.constants import aa_atom_index, aa_atom_mask, aa_nerf_indices, aa_chi_indices, aa_chi_anchor_atom
+
+
+# https://github.com/EleutherAI/mp_nerf/blob/master/mp_nerf/utils.py
+def get_dihedral(c1, c2, c3, c4):
+    """ Returns the dihedral angle in radians.
+ Will use atan2 formula from: + https://en.wikipedia.org/wiki/Dihedral_angle#In_polymer_physics + Inputs: + * c1: (batch, 3) or (3,) + * c2: (batch, 3) or (3,) + * c3: (batch, 3) or (3,) + * c4: (batch, 3) or (3,) + """ + u1 = c2 - c1 + u2 = c3 - c2 + u3 = c4 - c3 + + return torch.atan2( ( (torch.norm(u2, dim=-1, keepdim=True) * u1) * torch.cross(u2,u3, dim=-1) ).sum(dim=-1) , + ( torch.cross(u1,u2, dim=-1) * torch.cross(u2, u3, dim=-1) ).sum(dim=-1) ) + + +# https://github.com/EleutherAI/mp_nerf/blob/master/mp_nerf/utils.py +def get_angle(c1, c2, c3): + """ Returns the angle in radians. + Inputs: + * c1: (batch, 3) or (3,) + * c2: (batch, 3) or (3,) + * c3: (batch, 3) or (3,) + """ + u1 = c2 - c1 + u2 = c3 - c2 + + # dont use acos since norms involved. + # better use atan2 formula: atan2(cross, dot) from here: + # https://johnblackburne.blogspot.com/2012/05/angle-between-two-3d-vectors.html + + # add a minus since we want the angle in reversed order - sidechainnet issues + return torch.atan2( torch.norm(torch.cross(u1,u2, dim=-1), dim=-1), + -(u1*u2).sum(dim=-1) ) + + +def get_nerf_params(biopython_residue): + aa = protein_letters_3to1[biopython_residue.get_resname()] + + # Basic mask and index tensors + atom_mask = torch.tensor(aa_atom_mask[aa], dtype=bool) + nerf_indices = torch.tensor(aa_nerf_indices[aa], dtype=int) + chi_indices = torch.tensor(aa_chi_indices[aa], dtype=int) + + fixed_coord = torch.zeros((5, 3)) + residue_coords = torch.zeros((14, 3)) # only required to compute internal coordinates during pre-processing + atom_found = torch.zeros_like(atom_mask) + for atom in biopython_residue.get_atoms(): + try: + idx = aa_atom_index[aa][atom.get_name()] + atom_found[idx] = True + except KeyError: + warnings.warn(f"{atom.get_name()} not found") + continue + + residue_coords[idx, :] = torch.from_numpy(atom.get_coord()) + + if atom.get_name() in ['N', 'CA', 'C', 'O', 'CB']: + fixed_coord[idx, :] = torch.from_numpy(atom.get_coord()) + + # Determine chi angles + chi = torch.zeros(6) # the last chi angle is a dummy and should always be zero + for chi_idx, anchor in aa_chi_anchor_atom[aa].items(): + idx_a = nerf_indices[anchor, 2] + idx_b = nerf_indices[anchor, 1] + idx_c = nerf_indices[anchor, 0] + + coords_a = residue_coords[idx_a, :] + coords_b = residue_coords[idx_b, :] + coords_c = residue_coords[idx_c, :] + coords_d = residue_coords[anchor, :] + + chi[chi_idx] = get_dihedral(coords_a, coords_b, coords_c, coords_d) + + # Compute remaining internal coordinates + # (parallel version) + idx_a = nerf_indices[:, 2] + idx_b = nerf_indices[:, 1] + idx_c = nerf_indices[:, 0] + + # update atom mask + # remove atoms for which one or several parameters are missing/incorrect + _atom_mask = atom_mask & atom_found & atom_found[idx_a] & atom_found[idx_b] & atom_found[idx_c] + if not torch.all(_atom_mask == atom_mask): + warnings.warn("Some atoms are missing for NERF reconstruction") + atom_mask = _atom_mask + + coords_a = residue_coords[idx_a] + coords_b = residue_coords[idx_b] + coords_c = residue_coords[idx_c] + coords_d = residue_coords + + length = torch.norm(coords_d - coords_c, dim=-1) + theta = get_angle(coords_b, coords_c, coords_d) + ddihedral = get_dihedral(coords_a, coords_b, coords_c, coords_d) + + # subtract chi angles from dihedrals + ddihedral = ddihedral - chi[chi_indices] + + # # (serial version) + # length = torch.zeros(14) + # theta = torch.zeros(14) + # ddihedral = torch.zeros(14) + # for i in range(5, 14): + # if not atom_mask[i]: # atom doesn't exist + # continue + + # idx_a = 
nerf_indices[i, 2] + # idx_b = nerf_indices[i, 1] + # idx_c = nerf_indices[i, 0] + + # coords_a = residue_coords[idx_a] + # coords_b = residue_coords[idx_b] + # coords_c = residue_coords[idx_c] + # coords_d = residue_coords[i] + + # length[i] = torch.norm(coords_d - coords_c, dim=-1) + # theta[i] = get_angle(coords_b, coords_c, coords_d) + # ddihedral[i] = get_dihedral(coords_a, coords_b, coords_c, coords_d) + + # # subtract chi angles from dihedrals + # ddihedral[i] = ddihedral[i] - chi[chi_indices[i]] + + return { + 'fixed_coord': fixed_coord, + 'atom_mask': atom_mask, + 'nerf_indices': nerf_indices, + 'length': length, + 'theta': theta, + 'chi': chi, + 'ddihedral': ddihedral, + 'chi_indices': chi_indices, + } + + +# https://github.com/EleutherAI/mp_nerf/blob/master/mp_nerf/massive_pnerf.py#L38C1-L65C67 +def mp_nerf_torch(a, b, c, l, theta, chi): + """ Custom Natural extension of Reference Frame. + Inputs: + * a: (batch, 3) or (3,). point(s) of the plane, not connected to d + * b: (batch, 3) or (3,). point(s) of the plane, not connected to d + * c: (batch, 3) or (3,). point(s) of the plane, connected to d + * theta: (batch,) or (float). angle(s) between b-c-d + * chi: (batch,) or float. dihedral angle(s) between the a-b-c and b-c-d planes + Outputs: d (batch, 3) or (float). the next point in the sequence, linked to c + """ + # safety check + if not ( (-np.pi <= theta) * (theta <= np.pi) ).all().item(): + raise ValueError(f"theta(s) must be in radians and in [-pi, pi]. theta(s) = {theta}") + # calc vecs + ba = b-a + cb = c-b + # calc rotation matrix. based on plane normals and normalized + n_plane = torch.cross(ba, cb, dim=-1) + n_plane_ = torch.cross(n_plane, cb, dim=-1) + rotate = torch.stack([cb, n_plane_, n_plane], dim=-1) + rotate /= torch.norm(rotate, dim=-2, keepdim=True) + # calc proto point, rotate. add (-1 for sidechainnet convention) + # https://github.com/jonathanking/sidechainnet/issues/14 + d = torch.stack([-torch.cos(theta), + torch.sin(theta) * torch.cos(chi), + torch.sin(theta) * torch.sin(chi)], dim=-1).unsqueeze(-1) + # extend base point, set length + return c + l.unsqueeze(-1) * torch.matmul(rotate, d).squeeze() + + +# inspired by: https://github.com/EleutherAI/mp_nerf/blob/master/mp_nerf/proteins.py#L323C5-L344C65 +def ic_to_coords(fixed_coord, atom_mask, nerf_indices, length, theta, chi, ddihedral, chi_indices): + """ + Run NERF in parallel for all residues. + + :param fixed_coord: (L, 5, 3) coordinates of (N, CA, C, O, CB) atoms, they don't depend on chi angles + :param atom_mask: (L, 14) indicates whether atom exists in this residue + :param nerf_indices: (L, 14, 3) indices of the three previous atoms ({c, b, a} for the NERF algorithm) + :param length: (L, 14) bond length between this and previous atom + :param theta: (L, 14) angle between this and previous two atoms + :param chi: (L, 6) values of the 5 rotatable bonds, plus zero in last column + :param ddihedral: (L, 14) angle offset to which chi is added + :param chi_indices: (L, 14) indexes into the chi array + :returns: (L, 14, 3) tensor with all coordinates, non-existing atoms are assigned CA coords + """ + + if not torch.all(chi[:, 5] == 0): + chi[:, 5] = 0.0 + warnings.warn("Last column of 'chi' tensor should be zero. 
Overriding values.") + assert torch.all(chi[:, 5] == 0) + + L, device = fixed_coord.size(0), fixed_coord.device + coords = torch.zeros((L, 14, 3), device=device) + coords[:, :5, :] = fixed_coord + + for i in range(5, 14): + level_mask = atom_mask[:, i] + # level_mask = torch.ones(len(atom_mask), dtype=bool) + + length_i = length[level_mask, i] + theta_i = theta[level_mask, i] + + # dihedral_i = dihedral[level_mask, i] + dihedral_i = chi[level_mask, chi_indices[level_mask, i]] + ddihedral[level_mask, i] + + idx_a = nerf_indices[level_mask, i, 2] + idx_b = nerf_indices[level_mask, i, 1] + idx_c = nerf_indices[level_mask, i, 0] + + coords[level_mask, i] = mp_nerf_torch(coords[level_mask, idx_a], + coords[level_mask, idx_b], + coords[level_mask, idx_c], + length_i, + theta_i, + dihedral_i) + + if coords.isnan().any(): + warnings.warn("Side chain reconstruction error. Removing affected atoms...") + + # mask out affected atoms + m, n, _ = torch.where(coords.isnan()) + atom_mask[m, n] = False + coords[m, n, :] = 0.0 + + # replace non-existing atom coords with CA coords (TODO: don't hard-code CA index) + coords = atom_mask.unsqueeze(-1) * coords + \ + (~atom_mask.unsqueeze(2)) * coords[:, 1, :].unsqueeze(1) + + return coords diff --git a/src/data/normal_modes.py b/src/data/normal_modes.py new file mode 100644 index 0000000000000000000000000000000000000000..8ec57a380d8955911b91a7a1e94f84f99119fc53 --- /dev/null +++ b/src/data/normal_modes.py @@ -0,0 +1,69 @@ +import warnings +import numpy as np +import prody +prody.confProDy(verbosity='none') +from prody import parsePDB, ANM + + +def pdb_to_normal_modes(pdb_file, num_modes=5, nmax=5000): + """ + Compute normal modes for a PDB file using an Anisotropic Network Model (ANM) + http://prody.csb.pitt.edu/tutorials/enm_analysis/anm.html (accessed 01/11/2023) + """ + protein = parsePDB(pdb_file, model=1).select('calpha') + + if len(protein) > nmax: + warnings.warn("Protein is too big. 
Returning zeros...") + eig_vecs = np.zeros((len(protein), 3, num_modes)) + + else: + # build Hessian + anm = ANM('ANM analysis') + anm.buildHessian(protein, cutoff=15.0, gamma=1.0) + + # calculate normal modes + anm.calcModes(num_modes, zeros=False) + + # only use slowest modes + eig_vecs = anm.getEigvecs() # shape: (num_atoms * 3, num_modes) + eig_vecs = eig_vecs.reshape(len(protein), 3, num_modes) + # eig_vals = anm.getEigvals() # shape: (num_modes,) + + nm_dict = {} + for atom, nm_vec in zip(protein, eig_vecs): + chain = atom.getChid() + resi = atom.getResnum() + name = atom.getName() + nm_dict[(chain, resi, name)] = nm_vec.T + + return nm_dict + + +if __name__ == "__main__": + import argparse + from pathlib import Path + import torch + from tqdm import tqdm + + parser = argparse.ArgumentParser() + parser.add_argument('basedir', type=Path) + parser.add_argument('--outfile', type=Path, default=None) + args = parser.parse_args() + + # Read data split + split_path = Path(args.basedir, 'split_by_name.pt') + data_split = torch.load(split_path) + + pockets = [x[0] for split in data_split.values() for x in split] + + all_normal_modes = {} + for p in tqdm(pockets): + pdb_file = Path(args.basedir, 'crossdocked_pocket10', p) + + try: + nm_dict = pdb_to_normal_modes(str(pdb_file)) + all_normal_modes[p] = nm_dict + except AttributeError as e: + warnings.warn(str(e)) + + np.save(args.outfile, all_normal_modes) diff --git a/src/data/postprocessing.py b/src/data/postprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..c02b2d232e867b28555b90b982cba171ae0d87e3 --- /dev/null +++ b/src/data/postprocessing.py @@ -0,0 +1,93 @@ +import warnings + +from rdkit import Chem +from rdkit.Chem.rdForceFieldHelpers import UFFOptimizeMolecule, UFFHasAllMoleculeParams + +from src.data import sanifix + + +def uff_relax(mol, max_iter=200): + """ + Uses RDKit's universal force field (UFF) implementation to optimize a + molecule. + """ + if not UFFHasAllMoleculeParams(mol): + warnings.warn('UFF parameters not available for all atoms. ' + 'Returning None.') + return None + + try: + more_iterations_required = UFFOptimizeMolecule(mol, maxIters=max_iter) + if more_iterations_required: + warnings.warn(f'Maximum number of FF iterations reached. ' + f'Returning molecule after {max_iter} relaxation steps.') + + except RuntimeError: + return None + + return mol + + +def add_hydrogens(rdmol): + return Chem.AddHs(rdmol, addCoords=(len(rdmol.GetConformers()) > 0)) + + +def get_largest_fragment(rdmol): + mol_frags = Chem.GetMolFrags(rdmol, asMols=True, sanitizeFrags=False) + largest_frag = max(mol_frags, default=rdmol, key=lambda m: m.GetNumAtoms()) + + # try: + # Chem.SanitizeMol(largest_frag) + # except ValueError: + # return None + + return largest_frag + + +def process_all(rdmol, largest_frag=True, adjust_aromatic_Ns=True, relax_iter=0): + """ + Apply all filters and post-processing steps. Returns a new molecule. + + Returns: + RDKit molecule or None if it does not pass the filters or processing + fails + """ + + # Only consider non-trivial molecules + if rdmol.GetNumAtoms() < 1: + return None + + # Create a copy + mol = Chem.Mol(rdmol) + + # try: + # Chem.SanitizeMol(mol) + # except ValueError: + # warnings.warn('Sanitization failed. 
Returning None.') + # return None + + if largest_frag: + mol = get_largest_fragment(mol) + # if mol is None: + # return None + + if adjust_aromatic_Ns: + mol = sanifix.fix_mol(mol) + if mol is None: + return None + + # if add_hydrogens: + # mol = add_hydrogens(mol) + + if relax_iter > 0: + mol = uff_relax(mol, relax_iter) + if mol is None: + return None + + try: + Chem.SanitizeMol(mol) + except ValueError: + warnings.warn('Sanitization failed. Returning None.') + return None + + return mol diff --git a/src/data/process_crossdocked.py b/src/data/process_crossdocked.py new file mode 100644 index 0000000000000000000000000000000000000000..ed4af05d8c5369403f7ca7c1780342eaa8c93b71 --- /dev/null +++ b/src/data/process_crossdocked.py @@ -0,0 +1,176 @@ +from pathlib import Path +from time import time +import argparse +import shutil +import random +import yaml +from collections import defaultdict + +import torch +from tqdm import tqdm +import numpy as np +from Bio.PDB import PDBParser +from rdkit import Chem + +import sys +basedir = Path(__file__).resolve().parent.parent.parent +sys.path.append(str(basedir)) + +from src.data.data_utils import process_raw_pair, get_n_nodes, get_type_histogram +from src.data.data_utils import rdmol_to_smiles +from src.constants import atom_encoder, bond_encoder + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('basedir', type=Path) + parser.add_argument('--outdir', type=Path, default=None) + parser.add_argument('--split_path', type=Path, default=None) + parser.add_argument('--pocket', type=str, default='CA+', + choices=['side_chain_bead', 'CA+']) + parser.add_argument('--random_seed', type=int, default=42) + parser.add_argument('--val_size', type=int, default=100) + parser.add_argument('--normal_modes', action='store_true') + parser.add_argument('--flex', action='store_true') + parser.add_argument('--toy', action='store_true') + args = parser.parse_args() + + random.seed(args.random_seed) + + datadir = args.basedir / 'crossdocked_pocket10/' + + # Make output directory + dirname = f"processed_crossdocked_{args.pocket}" + if args.flex: + dirname += '_flex' + if args.normal_modes: + dirname += '_nma' + if args.toy: + dirname += '_toy' + processed_dir = Path(args.basedir, dirname) if args.outdir is None else args.outdir + processed_dir.mkdir(parents=True) + + # Read data split + split_path = Path(args.basedir, 'split_by_name.pt') if args.split_path is None else args.split_path + data_split = torch.load(split_path) + + # If there is no validation set, copy training examples (the validation set + # is not very important in this application) + if 'val' not in data_split: + random.shuffle(data_split['train']) + data_split['val'] = data_split['train'][-args.val_size:] + data_split['train'] = data_split['train'][:-args.val_size] + + if args.toy: + data_split['train'] = random.sample(data_split['train'], 100) + + failed = {} + train_smiles = [] + + n_samples_after = {} + for split in data_split.keys(): + + print(f"Processing {split} dataset...") + + ligands = defaultdict(list) + pockets = defaultdict(list) + + tic = time() + pbar = tqdm(data_split[split]) + for pocket_fn, ligand_fn in pbar: + + pbar.set_description(f'#failed: {len(failed)}') + + sdffile = datadir / f'{ligand_fn}' + pdbfile = datadir / f'{pocket_fn}' + + try: + pdb_model = PDBParser(QUIET=True).get_structure('', pdbfile)[0] + + rdmol = Chem.SDMolSupplier(str(sdffile))[0] + + ligand, pocket = process_raw_pair( + pdb_model, rdmol, pocket_representation=args.pocket, + 
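+                    # note (inferred from the modules above): compute_nerf_params
+                    # adds side-chain internal coordinates (src/data/nerf.py),
+                    # compute_bb_frames adds backbone frames (get_bb_transform),
+                    # and nma_input attaches normal-mode vectors to the pocket
+                    # (src/data/normal_modes.py)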
compute_nerf_params=args.flex, compute_bb_frames=args.flex, + nma_input=pdbfile if args.normal_modes else None) + + except (KeyError, AssertionError, FileNotFoundError, IndexError, + ValueError, AttributeError) as e: + failed[(split, sdffile, pdbfile)] = (type(e).__name__, str(e)) + continue + + nerf_keys = ['fixed_coord', 'atom_mask', 'nerf_indices', 'length', 'theta', 'chi', 'ddihedral', 'chi_indices'] + for k in ['x', 'one_hot', 'bonds', 'bond_one_hot', 'v', 'nma_vec'] + nerf_keys + ['axis_angle']: + if k in ligand: + ligands[k].append(ligand[k]) + if k in pocket: + pockets[k].append(pocket[k]) + + pocket_file = pdbfile.name.replace('_', '-') + ligand_file = Path(pocket_file).stem + '_' + Path(sdffile).name.replace('_', '-') + ligands['name'].append(ligand_file) + pockets['name'].append(pocket_file) + train_smiles.append(rdmol_to_smiles(rdmol)) + + if split in {'val', 'test'}: + pdb_sdf_dir = processed_dir / split + pdb_sdf_dir.mkdir(exist_ok=True) + + # Copy PDB file + pdb_file_out = Path(pdb_sdf_dir, pocket_file) + shutil.copy(pdbfile, pdb_file_out) + + # Copy SDF file + sdf_file_out = Path(pdb_sdf_dir, ligand_file) + shutil.copy(sdffile, sdf_file_out) + + data = {'ligands': ligands, 'pockets': pockets} + torch.save(data, Path(processed_dir, f'{split}.pt')) + + if split == 'train': + np.save(Path(processed_dir, 'train_smiles.npy'), train_smiles) + + print(f"Processing {split} set took {(time() - tic) / 60.0:.2f} minutes") + + + # -------------------------------------------------------------------------- + # Compute statistics & additional information + # -------------------------------------------------------------------------- + train_data = torch.load(Path(processed_dir, f'train.pt')) + + # Maximum molecule size + max_ligand_size = max([len(x) for x in train_data['ligands']['x']]) + + # Joint histogram of number of ligand and pocket nodes + pocket_coords = train_data['pockets']['x'] + ligand_coords = train_data['ligands']['x'] + n_nodes = get_n_nodes(ligand_coords, pocket_coords) + np.save(Path(processed_dir, 'size_distribution.npy'), n_nodes) + + # Get histograms of ligand node types + lig_one_hot = [x.numpy() for x in train_data['ligands']['one_hot']] + ligand_hist = get_type_histogram(lig_one_hot, atom_encoder) + np.save(Path(processed_dir, 'ligand_type_histogram.npy'), ligand_hist) + + # Get histograms of ligand edge types + lig_bond_one_hot = [x.numpy() for x in train_data['ligands']['bond_one_hot']] + ligand_bond_hist = get_type_histogram(lig_bond_one_hot, bond_encoder) + np.save(Path(processed_dir, 'ligand_bond_type_histogram.npy'), ligand_bond_hist) + + # Write error report + error_str = "" + for k, v in failed.items(): + error_str += f"{'Split':<15}: {k[0]}\n" + error_str += f"{'Ligand':<15}: {k[1]}\n" + error_str += f"{'Pocket':<15}: {k[2]}\n" + error_str += f"{'Error type':<15}: {v[0]}\n" + error_str += f"{'Error msg':<15}: {v[1]}\n\n" + + with open(Path(processed_dir, 'errors.txt'), 'w') as f: + f.write(error_str) + + metadata = { + 'max_ligand_size': max_ligand_size + } + with open(Path(processed_dir, 'metadata.yml'), 'w') as f: + yaml.dump(metadata, f, default_flow_style=False) diff --git a/src/data/process_dpo_dataset.py b/src/data/process_dpo_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f05b1c31385af8db10d8c7c95d4dd6905f2153b2 --- /dev/null +++ b/src/data/process_dpo_dataset.py @@ -0,0 +1,406 @@ +import argparse +from pathlib import Path +import numpy as np +import random +import shutil +from time import time +from collections 
import defaultdict +from Bio.PDB import PDBParser +from rdkit import Chem +import torch +from tqdm import tqdm +import pandas as pd +from itertools import combinations + +import sys +basedir = Path(__file__).resolve().parent.parent.parent +sys.path.append(str(basedir)) + +from src.sbdd_metrics.metrics import REOSEvaluator, MedChemEvaluator, PoseBustersEvaluator, GninaEvalulator +from src.data.data_utils import process_raw_pair, rdmol_to_smiles + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--smplsdir', type=Path, required=True) + parser.add_argument('--metrics-detailed', type=Path, required=False) + parser.add_argument('--ignore-missing-scores', action='store_true') + parser.add_argument('--datadir', type=Path, required=True) + parser.add_argument('--dpo-criterion', type=str, default='reos.all', + choices=['reos.all', 'medchem.sa', 'medchem.qed', 'gnina.vina_efficiency','combined']) + parser.add_argument('--basedir', type=Path, default=None) + parser.add_argument('--pocket', type=str, default='CA+', + choices=['side_chain_bead', 'CA+']) + parser.add_argument('--gnina', type=Path, default='gnina') + parser.add_argument('--random_seed', type=int, default=42) + parser.add_argument('--normal_modes', action='store_true') + parser.add_argument('--flex', action='store_true') + parser.add_argument('--toy', action='store_true') + parser.add_argument('--toy_size', type=int, default=100) + parser.add_argument('--n_pairs', type=int, default=5) + args = parser.parse_args() + return args + +def scan_smpl_dir(samples_dir): + samples_dir = Path(samples_dir) + subdirs = [] + for subdir in tqdm(samples_dir.iterdir(), desc='Scanning samples'): + if not subdir.is_dir(): + continue + if not sample_dir_valid(subdir): + continue + subdirs.append(subdir) + return subdirs + +def sample_dir_valid(samples_dir): + pocket = samples_dir / '0_pocket.pdb' + if not pocket.exists(): + return False + ligands = list(samples_dir.glob('*_ligand.sdf')) + if len(ligands) < 2: + return False + for ligand in ligands: + if ligand.stat().st_size == 0: + return False + return True + +def return_winning_losing_smpl(score_1, score_2, criterion): + if criterion == 'reos.all': + if score_1 == score_2: + return None + return score_1 > score_2 + elif criterion == 'medchem.sa': + if np.abs(score_1 - score_2) < 0.5: + return None + return score_1 < score_2 + elif criterion == 'medchem.qed': + if np.abs(score_1 - score_2) < 0.1: + return None + return score_1 > score_2 + elif criterion == 'gnina.vina_efficiency': + if np.abs(score_1 - score_2) < 0.1: + return None + return score_1 < score_2 + elif criterion == 'combined': + score_reos_1, score_reos_2 = score_1['reos.all'], score_2['reos.all'] + score_sa_1, score_sa_2 = score_1['medchem.sa'], score_2['medchem.sa'] + score_qed_1, score_qed_2 = score_1['medchem.qed'], score_2['medchem.qed'] + score_vina_1, score_vina_2 = score_1['gnina.vina_efficiency'], score_2['gnina.vina_efficiency'] + if score_reos_1 == score_reos_2: return None + # checking consistency + reos_sign = score_reos_1 > score_reos_2 + sa_sign = score_sa_1 < score_sa_2 + qed_sign = score_qed_1 > score_qed_2 + vina_sign = score_vina_1 < score_vina_2 + signs = [reos_sign, sa_sign, qed_sign, vina_sign] + if all(signs) or not any(signs): return signs[0] + return None + +def compute_scores(sample_dirs, evaluator, criterion, n_pairs=5, toy=False, toy_size=100, + precomp_scores=None, ignore_missing_scores=False): + samples = [] + pose_evaluator = PoseBustersEvaluator() + pbar = tqdm(sample_dirs, 
desc='Computing scores for samples') + + for dir in pbar: + pocket = dir / '0_pocket.pdb' + ligands = list(dir.glob('*_ligand.sdf')) + + target_samples = [] + for lig_path in ligands: + try: + mol = Chem.SDMolSupplier(str(lig_path))[0] + if mol is None: + continue + smiles = rdmol_to_smiles(mol) + except Exception as e: + print('Failed to read ligand:', lig_path) + continue + + if precomp_scores is not None and str(lig_path) in precomp_scores.index: + mol_props = precomp_scores.loc[str(lig_path)].to_dict() + if criterion == 'combined': + if not 'reos.all' in mol_props or not 'medchem.sa' in mol_props or not 'medchem.qed' in mol_props or not 'gnina.vina_efficiency' in mol_props: + print(f'Missing combined scores for ligand:', lig_path) + continue + mol_props['combined'] = { + 'reos.all': mol_props['reos.all'], + 'medchem.sa': mol_props['medchem.sa'], + 'medchem.qed': mol_props['medchem.qed'], + 'gnina.vina_efficiency': mol_props['gnina.vina_efficiency'], + 'combined': mol_props['gnina.vina_efficiency'] + } + else: + mol_props = {} + if criterion not in mol_props: + if ignore_missing_scores: + print(f'Missing {criterion} for ligand:', lig_path) + continue + print(f'Recomputing {criterion} for ligand:', lig_path) + try: + eval_res = evaluator.evaluate(mol) + criterion_cat = criterion.split('.')[0] + eval_res = {f'{criterion_cat}.{k}': v for k, v in eval_res.items()} + score = eval_res[criterion] + except: + continue + else: + score = mol_props[criterion] + + if 'posebusters.all' not in mol_props: + if ignore_missing_scores: + print('Missing PoseBusters for ligand:', lig_path) + continue + print('Recomputing PoseBusters for ligand:', lig_path) + try: + pose_eval_res = pose_evaluator.evaluate(lig_path, pocket) + except: + continue + if 'all' not in pose_eval_res or not pose_eval_res['all']: + continue + else: + pose_eval_res = mol_props['posebusters.all'] + if not pose_eval_res: + continue + + target_samples.append({ + 'smiles': smiles, + 'score': score, + 'ligand_path': lig_path, + 'pocket_path': pocket + }) + + # Deduplicate by SMILES + unique_samples = {} + for sample in target_samples: + if sample['smiles'] not in unique_samples: + unique_samples[sample['smiles']] = sample + unique_samples = list(unique_samples.values()) + if len(unique_samples) < 2: + continue + + # Generate all possible pairs + all_pairs = list(combinations(unique_samples, 2)) + + # Calculate score differences and filter valid pairs + valid_pairs = [] + for s1, s2 in all_pairs: + sign = return_winning_losing_smpl(s1['score'], s2['score'], criterion) + if sign is None: + continue + score_diff = abs(s1['score'] - s2['score']) if not criterion == 'combined' else \ + abs(s1['score']['combined'] - s2['score']['combined']) + if sign: + valid_pairs.append((s1, s2, score_diff)) + elif sign is False: + valid_pairs.append((s2, s1, score_diff)) + + # Sort pairs by score difference (descending) and select top N pairs + valid_pairs.sort(key=lambda x: x[2], reverse=True) + used_ligand_paths = set() + selected_pairs = [] + for winning, losing, score_diff in valid_pairs: + if winning['ligand_path'] in used_ligand_paths or losing['ligand_path'] in used_ligand_paths: + continue + + selected_pairs.append((winning, losing, score_diff)) + used_ligand_paths.add(winning['ligand_path']) + used_ligand_paths.add(losing['ligand_path']) + + if len(selected_pairs) == n_pairs: + break + for winning, losing, _ in selected_pairs: + d = { + 'score_w': winning['score'], + 'score_l': losing['score'], + 'pocket_p': winning['pocket_path'], + 'ligand_p_w': 
winning['ligand_path'], + 'ligand_p_l': losing['ligand_path'] + } + if isinstance(winning['score'], dict): + for k, v in winning['score'].items(): + d[f'{k}_w'] = v + d['score_w'] = winning['score']['combined'] + if isinstance(losing['score'], dict): + for k, v in losing['score'].items(): + d[f'{k}_l'] = v + d['score_l'] = losing['score']['combined'] + samples.append(d) + + pbar.set_postfix({'samples': len(samples)}) + + if toy and len(samples) >= toy_size: + break + + return samples + +def main(): + args = parse_args() + + if 'reos' in args.dpo_criterion: + evaluator = REOSEvaluator() + elif 'medchem' in args.dpo_criterion: + evaluator = MedChemEvaluator() + elif 'gnina' in args.dpo_criterion: + evaluator = GninaEvalulator(gnina=args.gnina) + elif 'combined' in args.dpo_criterion: + evaluator = None # for combined criterion, metrics have to be computed separately + if args.metrics_detailed is None: + raise ValueError('For combined criterion, detailed metrics file has to be provided') + if not args.ignore_missing_scores: + raise ValueError('For combined criterion, --ignore-missing-scores flag has to be set') + else: + raise ValueError(f"Unknown DPO criterion: {args.dpo_criterion}") + + # Make output directory + dirname = f"dpo_{args.dpo_criterion.replace('.','_')}_{args.pocket}" + if args.flex: + dirname += '_flex' + if args.normal_modes: + dirname += '_nma' + if args.toy: + dirname += '_toy' + processed_dir = Path(args.basedir, dirname) + processed_dir.mkdir(parents=True, exist_ok=True) + + if (processed_dir / f'samples_{args.dpo_criterion}.csv').exists(): + print(f"Samples already computed for criterion {args.dpo_criterion}, loading from file") + samples = pd.read_csv(processed_dir / f'samples_{args.dpo_criterion}.csv') + samples = [dict(row) for _, row in samples.iterrows()] + print(f"Found {len(samples)} winning/losing samples") + else: + print('Scanning sample directory...') + samples_dir = Path(args.smplsdir) + # scan dir + sample_dirs = scan_smpl_dir(samples_dir) + if args.metrics_detailed: + print(f'Loading precomputed scores from {args.metrics_detailed}') + precomp_scores = pd.read_csv(args.metrics_detailed) + precomp_scores = precomp_scores.set_index('sdf_file') + else: + precomp_scores = None + print(f'Found {len(sample_dirs)} valid sample directories') + print('Computing scores...') + samples = compute_scores(sample_dirs, evaluator, args.dpo_criterion, + n_pairs=args.n_pairs, toy=args.toy, toy_size=args.toy_size, + precomp_scores=precomp_scores, + ignore_missing_scores=args.ignore_missing_scores) + print(f'Found {len(samples)} winning/losing samples, saving to file') + pd.DataFrame(samples).to_csv(Path(processed_dir, f'samples_{args.dpo_criterion}.csv'), index=False) + + data_split = {} + data_split['train'] = samples + if args.toy: + data_split['train'] = random.sample(samples, min(args.toy_size, len(data_split['train']))) + + failed = {} + train_smiles = [] + + for split in data_split.keys(): + + print(f"Processing {split} dataset...") + + ligands_w = defaultdict(list) + ligands_l = defaultdict(list) + pockets = defaultdict(list) + + tic = time() + pbar = tqdm(data_split[split]) + for entry in pbar: + + pbar.set_description(f'#failed: {len(failed)}') + + pdbfile = Path(entry['pocket_p']) + entry['ligand_p_w'] = Path(entry['ligand_p_w']) + entry['ligand_p_l'] = Path(entry['ligand_p_l']) + entry['ligand_w'] = Chem.SDMolSupplier(str(entry['ligand_p_w']))[0] + entry['ligand_l'] = Chem.SDMolSupplier(str(entry['ligand_p_l']))[0] + + try: + pdb_model = 
PDBParser(QUIET=True).get_structure('', pdbfile)[0] + + ligand_w, pocket = process_raw_pair( + pdb_model, entry['ligand_w'], pocket_representation=args.pocket, + compute_nerf_params=args.flex, compute_bb_frames=args.flex, + nma_input=pdbfile if args.normal_modes else None) + ligand_l, _ = process_raw_pair( + pdb_model, entry['ligand_l'], pocket_representation=args.pocket, + compute_nerf_params=args.flex, compute_bb_frames=args.flex, + nma_input=pdbfile if args.normal_modes else None) + + except (KeyError, AssertionError, FileNotFoundError, IndexError, + ValueError, AttributeError) as e: + failed[(split, entry['ligand_p_w'], entry['ligand_p_l'], pdbfile)] \ + = (type(e).__name__, str(e)) + continue + + nerf_keys = ['fixed_coord', 'atom_mask', 'nerf_indices', 'length', 'theta', 'chi', 'ddihedral', 'chi_indices'] + for k in ['x', 'one_hot', 'bonds', 'bond_one_hot', 'v', 'nma_vec'] + nerf_keys + ['axis_angle']: + if k in ligand_w: + ligands_w[k].append(ligand_w[k]) + ligands_l[k].append(ligand_l[k]) + if k in pocket: + pockets[k].append(pocket[k]) + + smpl_n = pdbfile.parent.name + pocket_file = f'{smpl_n}__{pdbfile.stem}.pdb' + ligand_file_w = f'{smpl_n}__{entry["ligand_p_w"].stem}.sdf' + ligand_file_l = f'{smpl_n}__{entry["ligand_p_l"].stem}.sdf' + ligands_w['name'].append(ligand_file_w) + ligands_l['name'].append(ligand_file_l) + pockets['name'].append(pocket_file) + train_smiles.append(rdmol_to_smiles(entry['ligand_w'])) + train_smiles.append(rdmol_to_smiles(entry['ligand_l'])) + + data = {'ligands_w': ligands_w, + 'ligands_l': ligands_l, + 'pockets': pockets} + torch.save(data, Path(processed_dir, f'{split}.pt')) + + if split == 'train': + np.save(Path(processed_dir, 'train_smiles.npy'), train_smiles) + + print(f"Processing {split} set took {(time() - tic) / 60.0:.2f} minutes") + + # cp stats from original dataset + size_distr_p = Path(args.datadir, 'size_distribution.npy') + type_histo_p = Path(args.datadir, 'ligand_type_histogram.npy') + bond_histo_p = Path(args.datadir, 'ligand_bond_type_histogram.npy') + metadata_p = Path(args.datadir, 'metadata.yml') + shutil.copy(size_distr_p, processed_dir) + shutil.copy(type_histo_p, processed_dir) + shutil.copy(bond_histo_p, processed_dir) + shutil.copy(metadata_p, processed_dir) + + # cp val and test .pt and dirs + val_dir = Path(args.datadir, 'val') + test_dir = Path(args.datadir, 'test') + val_pt = Path(args.datadir, 'val.pt') + test_pt = Path(args.datadir, 'test.pt') + assert val_dir.exists() and test_dir.exists() and val_pt.exists() and test_pt.exists() + if (processed_dir / 'val').exists(): + shutil.rmtree(processed_dir / 'val') + if (processed_dir / 'test').exists(): + shutil.rmtree(processed_dir / 'test') + shutil.copytree(val_dir, processed_dir / 'val') + shutil.copytree(test_dir, processed_dir / 'test') + shutil.copy(val_pt, processed_dir) + shutil.copy(test_pt, processed_dir) + + # Write error report + error_str = "" + for k, v in failed.items(): + error_str += f"{'Split':<15}: {k[0]}\n" + error_str += f"{'Ligand W':<15}: {k[1]}\n" + error_str += f"{'Ligand L':<15}: {k[2]}\n" + error_str += f"{'Pocket':<15}: {k[3]}\n" + error_str += f"{'Error type':<15}: {v[0]}\n" + error_str += f"{'Error msg':<15}: {v[1]}\n\n" + + with open(Path(processed_dir, 'errors.txt'), 'w') as f: + f.write(error_str) + + with open(Path(processed_dir, 'dataset_config.txt'), 'w') as f: + f.write(str(args)) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/src/data/sanifix.py b/src/data/sanifix.py new file mode 100644 index 
0000000000000000000000000000000000000000..148542500db7694d8b7ed4b3ad43be7d76e69d36 --- /dev/null +++ b/src/data/sanifix.py @@ -0,0 +1,159 @@ +""" sanifix4.py + + Contribution from James Davidson + adapted from: https://github.com/abradle/rdkitserver/blob/master/MYSITE/src/testproject/mol_parsing/sanifix.py +""" +from rdkit import Chem +from rdkit.Chem import AllChem +import warnings + +def _FragIndicesToMol(oMol,indices): + em = Chem.EditableMol(Chem.Mol()) + + newIndices={} + for i,idx in enumerate(indices): + em.AddAtom(oMol.GetAtomWithIdx(idx)) + newIndices[idx]=i + + for i,idx in enumerate(indices): + at = oMol.GetAtomWithIdx(idx) + for bond in at.GetBonds(): + if bond.GetBeginAtomIdx()==idx: + oidx = bond.GetEndAtomIdx() + else: + oidx = bond.GetBeginAtomIdx() + # make sure every bond only gets added once: + if oidx math.pi + + # angle in [0, pi) & invert + theta_wrapped[inv_mask] = -1 * (2 * math.pi - theta_wrapped[inv_mask]) + + # apply + theta = torch.clamp(theta, min=eps) + point = point * (theta_wrapped / theta).unsqueeze(-1) + assert not point.isnan().any() + return point + + +def random_uniform(n_samples, device=None): + """ + Follow geomstats implementation: + https://geomstats.github.io/_modules/geomstats/geometry/special_orthogonal.html + + Args: + n_samples: int + Returns: + rotation vectors, (n, 3) + """ + random_point = (torch.rand(n_samples, 3, device=device) * 2 - 1) * math.pi + random_point = regularize(random_point) + + return random_point + + +def hat(rot_vec): + """ + Maps R^3 vector to a skew-symmetric matrix r (i.e. r \in R^{3x3} and r^T = -r). + Since we have the identity rv = rot_vec x v for all v \in R^3, this is + identical to a cross-product-matrix representation of rot_vec. + rot_vec x v = hat(rot_vec)^T v + See also: + https://en.wikipedia.org/wiki/Cross_product#Conversion_to_matrix_multiplication + https://en.wikipedia.org/wiki/Hat_notation#Cross_product + Args: + rot_vec: (n, 3) + Returns: + skew-symmetric matrices (n, 3, 3) + """ + basis = torch.tensor([ + [[0., 0., 0.], [0., 0., -1.], [0., 1., 0.]], + [[0., 0., 1.], [0., 0., 0.], [-1., 0., 0.]], + [[0., -1., 0.], [1., 0., 0.], [0., 0., 0.]] + ], device=rot_vec.device) + # basis = torch.tensor([ + # [[0., 0., 0.], [0., 0., 1.], [0., -1., 0.]], + # [[0., 0., -1.], [0., 0., 0.], [1., 0., 0.]], + # [[0., 1., 0.], [-1., 0., 0.], [0., 0., 0.]] + # ], device=rot_vec.device) + + return torch.einsum('...i,ijk->...jk', rot_vec, basis) + + +def inv_hat(skew_mat): + """ + Inverse of hat operation + Args: + skew_mat: skew-symmetric matrices (n, 3, 3) + Returns: + rotation vectors, (n, 3) + """ + + assert torch.allclose(-skew_mat, skew_mat.transpose(-2, -1), atol=1e-4), \ + f"Input not skew-symmetric (err={(-skew_mat - skew_mat.transpose(-2, -1)).abs().max():.4g})" + + # vec = torch.stack([ + # skew_mat[:, 1, 2], + # skew_mat[:, 2, 1], + # skew_mat[:, 0, 1] + # ], dim=1) + + vec = torch.stack([ + skew_mat[:, 2, 1], + skew_mat[:, 0, 2], + skew_mat[:, 1, 0] + ], dim=1) + + return vec + + +def matrix_from_rotation_vector(axis_angle, eps=1e-6): + """ + Args: + axis_angle: (n, 3) + Returns: + rotation matrices, (n, 3, 3) + """ + + axis_angle = regularize(axis_angle) + angle = axis_angle.norm(dim=-1) + _norm = torch.clamp(angle, min=eps).unsqueeze(-1) + skew_mat = hat(axis_angle / _norm) + + # https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula#Matrix_notation + _id = torch.eye(3, device=axis_angle.device).unsqueeze(0) + rot_mat = _id + \ + torch.sin(angle)[:, None, None] * skew_mat + \ + (1 - torch.cos(angle))[:, 
None, None] * torch.bmm(skew_mat, skew_mat) + + return rot_mat + + +class safe_acos(torch.autograd.Function): + """ + Implementation of arccos that avoids NaN in backward pass. + https://github.com/pytorch/pytorch/issues/8069#issuecomment-2041223872 + """ + EPS = 1e-4 + @classmethod + def d_acos_dx(cls, x): + x = torch.clamp(x, min=-1. + cls.EPS, max=1. - cls.EPS) + return -1.0 / (1 - x**2).sqrt() + + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + return input.acos() + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + return grad_output * safe_acos.d_acos_dx(input) + + +def rotation_vector_from_matrix(rot_mat, approx=1e-4): + """ + Args: + rot_mat: (n, 3, 3) + approx: float, minimum angle below which an approximation will be used + for numerical stability + Returns: + rotation vector, (n, 3) + """ + + # https://en.wikipedia.org/wiki/Rotation_matrix#Conversion_from_rotation_matrix_to_axis%E2%80%93angle + # https://en.wikipedia.org/wiki/Axis%E2%80%93angle_representation#Log_map_from_SO(3)_to_%F0%9D%94%B0%F0%9D%94%AC(3) + + # determine axis + skew_mat = rot_mat - rot_mat.transpose(-2, -1) + + # determine the angle + cos_angle = 0.5 * (_batch_trace(rot_mat) - 1) + # arccos is only defined between -1 and 1 + assert torch.all(cos_angle.abs() <= 1 + 1e-6) + cos_angle = torch.clamp(cos_angle, min=-1., max=1.) + # abs_angle = torch.arccos(cos_angle) + abs_angle = safe_acos.apply(cos_angle) + + # avoid numerical instability; use sin(x) \approx x for small x + close_to_0 = abs_angle < approx + _fac = torch.empty_like(abs_angle) + _fac[close_to_0] = 0.5 + _fac[~close_to_0] = 0.5 * abs_angle[~close_to_0] / torch.sin(abs_angle[~close_to_0]) + + axis_angle = inv_hat(_fac[:, None, None] * skew_mat) + return regularize(axis_angle) + + +def get_jacobian(point, left=True, inverse=False, eps=1e-4): + + # # From Geomstats: https://geomstats.github.io/_modules/geomstats/geometry/special_orthogonal.html + # jacobian = so3_vector.jacobian_translation(point, left) + # + # if inverse: + # jacobian = torch.linalg.inv(jacobian) + + # Right Jacobian defined as J_r(theta) = \partial exp([theta]_x) / \partial theta + # https://math.stackexchange.com/questions/301533/jacobian-involving-so3-exponential-map-logr-expm + # Source: + # Chirikjian, Gregory S. Stochastic models, information theory, and Lie + # groups, volume 2: Analytic methods and modern applications. Vol. 2. + # Springer Science & Business Media, 2011. (page 40) + # NOTE: the definitions of 'inverse' and 'left' in the book are the opposite + # of their meanings in Geomstats, whose functionality we're mimicking here. + # This explains the differences in the equations. 
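+    # Concretely (cf. the commented-out reference implementations below), with
+    # S = hat(point) and a = |point|, the two branches evaluate
+    #   inverse=True:  J = I + (1 - cos a)/a^2 * S + (a - sin a)/a^3 * S @ S
+    #   inverse=False: J = I - S/2 + (1/a^2 - (1 + cos a)/(2 a sin a)) * S @ S
+    # with series approximations near a = 0 (and a = pi) for numerical stability.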
+ angle_squared = point.square().sum(-1) + angle = angle_squared.sqrt() + skew_mat = hat(point) + + assert torch.all(angle <= math.pi) + close_to_0 = angle < eps + close_to_pi = (math.pi - angle) < eps + + angle = angle[:, None, None] + angle_squared = angle_squared[:, None, None] + + if inverse: + # _jacobian = torch.eye(3, device=point.device).unsqueeze(0) + \ + # (1 - torch.cos(angle)) / angle_squared * skew_mat + \ + # (angle - torch.sin(angle)) / angle ** 3 * (skew_mat @ skew_mat) + + _term1 = torch.empty_like(angle) + _term1[close_to_0] = 0.5 # approximate with value at zero + _term1[~close_to_0] = (1 - torch.cos(angle)) / angle_squared + + _term2 = torch.empty_like(angle) + _term2[close_to_0] = 1 / 6 # approximate with value at zero + _term2[~close_to_0] = (angle - torch.sin(angle)) / angle ** 3 + + jacobian = torch.eye(3, device=point.device).unsqueeze(0) + \ + _term1 * skew_mat + _term2 * (skew_mat @ skew_mat) + # assert torch.allclose(jacobian, _jacobian, atol=1e-4) + else: + # _jacobian = torch.eye(3, device=point.device).unsqueeze(0) - 0.5 * skew_mat + \ + # (1 / angle_squared - (1 + torch.cos(angle)) / (2 * angle * torch.sin(angle))) * (skew_mat @ skew_mat) + + _term1 = torch.empty_like(angle) + _term1[close_to_0] = 1 / 12 # approximate with value at zero + _term1[close_to_pi] = 1 / math.pi**2 # approximate with value at pi + default = ~close_to_0 & ~close_to_pi + _term1[default] = 1 / angle_squared[default] - \ + (1 + torch.cos(angle[default])) / (2 * angle[default] * torch.sin(angle[default])) + + jacobian = torch.eye(3, device=point.device).unsqueeze(0) - \ + 0.5 * skew_mat + _term1 * (skew_mat @ skew_mat) + # assert torch.allclose(jacobian, _jacobian, atol=1e-4) + + if left: + jacobian = jacobian.transpose(-2, -1) + + return jacobian + + +def compose_rotations(rot_vec_1, rot_vec_2): + rot_mat_1 = matrix_from_rotation_vector(rot_vec_1) + rot_mat_2 = matrix_from_rotation_vector(rot_vec_2) + rot_mat_out = torch.bmm(rot_mat_1, rot_mat_2) + return rotation_vector_from_matrix(rot_mat_out) + + +def exp(tangent): + """ + Exponential map at identity. + Args: + tangent: vector on the tangent space, (n, 3) + Returns: + rotation vector on the manifold, (n, 3) + """ + # rotations are already represented by rotation vectors + exp_from_identity = regularize(tangent) + return exp_from_identity + + +def exp_not_from_identity(tangent_vec, base_point): + """ + Exponential map at base point. + Args: + tangent_vec: vector on the tangent plane, (n, 3) + base_point: base point on the manifold, (n, 3) + Returns: + new point on the manifold, (n, 3) + """ + + tangent_vec = regularize(tangent_vec) + base_point = regularize(base_point) + + # Lie algebra is the tangent space at the identity element of a Lie group + # -> to identity + jacobian = get_jacobian(base_point, left=True, inverse=True) + tangent_vec_at_id = torch.einsum("...ij,...j->...i", jacobian, tangent_vec) + + # exponential map from identity + exp_from_identity = exp(tangent_vec_at_id) + + # -> back to base point + return compose_rotations(base_point, exp_from_identity) + + +def log(rot_vec, as_skew=False): + """ + Logarithm map from tangent space at the identity. 
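+    Maps a point on the manifold to the tangent space at the identity. Since
+    rotations are already represented by rotation vectors, this is simply the
+    input vector itself; with as_skew=True the corresponding skew-symmetric
+    matrix hat(rot_vec) is returned instead.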
+ Args: + rot_vec: point on the manifold, (n, 3) + Returns: + vector on the tangent space, (n, 3) + """ + # rotations are already represented by rotation vectors + # log_from_id = regularize(rot_vec) + log_from_id = rot_vec + if as_skew: + log_from_id = hat(log_from_id) + return log_from_id + + +def log_not_from_identity(point, base_point): + """ + Logarithm map of point from base point. + Args: + point: point on the manifold, (n, 3) + base_point: base point on the manifold, (n, 3) + Returns: + vector on the tangent plane, (n, 3) + """ + point = regularize(point) + base_point = regularize(base_point) + + inv_base_point = -1 * base_point + + point_near_id = compose_rotations(inv_base_point, point) + + # logarithm map from identity + log_from_id = log(point_near_id) + + jacobian = get_jacobian(base_point, inverse=False) + tangent_vec_at_id = torch.einsum("...ij,...j->...i", jacobian, log_from_id) + + return tangent_vec_at_id + + +if __name__ == "__main__": + + import os + os.environ['GEOMSTATS_BACKEND'] = "pytorch" + import scipy.optimize # does not seem to be imported correctly when just loading geomstats + default_dtype = torch.get_default_dtype() + from geomstats.geometry.special_orthogonal import SpecialOrthogonal + torch.set_default_dtype(default_dtype) # Geomstats changes default type when imported + + so3_vector = SpecialOrthogonal(n=3, point_type="vector") + + # decorator + if torch.__version__ >= '2.0.0': + GEOMSTATS_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' + + def geomstats_tensor_type(func): + def inner(*args, **kwargs): + with torch.device(GEOMSTATS_DEVICE): + out = func(*args, **kwargs) + return out + + return inner + else: + GEOMSTATS_TENSOR_TYPE = 'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor' + + # GEOMSTATS_TENSOR_TYPE = 'torch.cuda.DoubleTensor' if torch.cuda.is_available() else 'torch.DoubleTensor' + def geomstats_tensor_type(func): + def inner(*args, **kwargs): + # tensor_type_before = TODO + torch.set_default_tensor_type(GEOMSTATS_TENSOR_TYPE) + out = func(*args, **kwargs) + # torch.set_default_tensor_type(tensor_type_before) + torch.set_default_tensor_type('torch.FloatTensor') + return out + + return inner + + @geomstats_tensor_type + def gs_matrix_from_rotation_vector(*args, **kwargs): + return so3_vector.matrix_from_rotation_vector(*args, **kwargs) + + @geomstats_tensor_type + def gs_rotation_vector_from_matrix(*args, **kwargs): + return so3_vector.rotation_vector_from_matrix(*args, **kwargs) + + @geomstats_tensor_type + def gs_exp_not_from_identity(*args, **kwargs): + return so3_vector.exp_not_from_identity(*args, **kwargs) + + @geomstats_tensor_type + def gs_log_not_from_identity(*args, **kwargs): + # norm of the rotation vector will be between 0 and pi + return so3_vector.log_not_from_identity(*args, **kwargs) + + @geomstats_tensor_type + def compose(*args, **kwargs): + return so3_vector.compose(*args, **kwargs) + + @geomstats_tensor_type + def inverse(*args, **kwargs): + return so3_vector.inverse(*args, **kwargs) + + @geomstats_tensor_type + def gs_random_uniform(*args, **kwargs): + return so3_vector.random_uniform(*args, **kwargs) + + + ############# + # RUN TESTS # + ############# + + n = 16 + device = 'cuda' if torch.cuda.is_available() else None + + ### regularize ### + + # vec = (torch.rand(n, 3) * 2 - 1) * math.pi + vec = (torch.rand(n, 3) * 4 - 2) * math.pi + axis_angle = regularize(vec) + assert torch.all(torch.cross(vec, axis_angle).norm(dim=-1) < 1e-5), "not all vectors collinear" + assert 
torch.all(axis_angle.norm(dim=-1) < math.pi) & torch.all(axis_angle.norm(dim=-1) >= 0), "norm not between 0 and pi" + + + ### matrix_from_rotation_vector ### + + rot_vec = random_uniform(16, device=device) + assert torch.allclose(matrix_from_rotation_vector(rot_vec), + gs_matrix_from_rotation_vector(rot_vec), atol=1e-06) + + + ### rotation_vector_from_matrix ### + + rot_vec = random_uniform(16, device=device) + rot_mat = matrix_from_rotation_vector(rot_vec) + assert torch.allclose(rotation_vector_from_matrix(rot_mat), + gs_rotation_vector_from_matrix(rot_mat), atol=1e-05) + + + ### exp_not_from_identity ### + + tangent_vec = random_uniform(16, device=device) + base_pt = random_uniform(16, device=device) + my_val = exp_not_from_identity(tangent_vec, base_pt) + gs_val = gs_exp_not_from_identity(tangent_vec, base_pt) + assert torch.allclose(my_val, gs_val, atol=1e-03), (my_val - gs_val).abs().max() + + + ### log_not_from_identity ### + + pt = random_uniform(16, device=device) + base_pt = random_uniform(16, device=device) + my_val = log_not_from_identity(pt, base_pt) + gs_val = gs_log_not_from_identity(pt, base_pt) + assert torch.allclose(my_val, gs_val, atol=1e-03), (my_val - gs_val).abs().max() + + + print("All tests successful!") diff --git a/src/default/size_distribution.npy b/src/default/size_distribution.npy new file mode 100644 index 0000000000000000000000000000000000000000..3eb0b0553019e3c4d873420352a9cc050f53cbea --- /dev/null +++ b/src/default/size_distribution.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d4e677a30c4b972051499bb5577a0de773e4f92ec54c282d432f94873406ec7e +size 158488 diff --git a/src/generate.py b/src/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..9bfa7c3ee82ff28aadb68a460b14e9e45ccdafd3 --- /dev/null +++ b/src/generate.py @@ -0,0 +1,204 @@ +import argparse +import sys +import os +import warnings +import tempfile +import pandas as pd + +from Bio.PDB import PDBParser +from pathlib import Path +from rdkit import Chem +from torch.utils.data import DataLoader +from functools import partial + +basedir = Path(__file__).resolve().parent.parent +sys.path.append(str(basedir)) +warnings.filterwarnings("ignore") + +from src import utils +from src.data.dataset import ProcessedLigandPocketDataset +from src.data.data_utils import TensorDict, process_raw_pair +from src.model.lightning import DrugFlow +from src.sbdd_metrics.metrics import FullEvaluator + +from tqdm import tqdm +from pdb import set_trace + + +def aggregate_metrics(table): + agg_col = 'posebusters' + total = 0 + table[agg_col] = 0 + for column in table.columns: + if column.startswith(agg_col) and column != agg_col: + table[agg_col] += table[column].fillna(0).astype(float) + total += 1 + table[agg_col] = table[agg_col] / total + + agg_col = 'reos' + total = 0 + table[agg_col] = 0 + for column in table.columns: + if column.startswith(agg_col) and column != agg_col: + table[agg_col] += table[column].fillna(0).astype(float) + total += 1 + table[agg_col] = table[agg_col] / total + + agg_col = 'chembl_ring_systems' + total = 0 + table[agg_col] = 0 + for column in table.columns: + if column.startswith(agg_col) and column != agg_col and not column.endswith('smi'): + table[agg_col] += table[column].fillna(0).astype(float) + total += 1 + table[agg_col] = table[agg_col] / total + return table + + +if __name__ == "__main__": + p = argparse.ArgumentParser() + p.add_argument('--protein', type=str, required=True, help="Input PDB file.") + p.add_argument('--ref_ligand', 
type=str, required=True, help="SDF file with reference ligand used to define the pocket.")
+    p.add_argument('--checkpoint', type=str, required=True, help="Model checkpoint file.")
+    p.add_argument('--molecule_size', type=str, required=False, default=None, help="Number of atoms in the sampled molecules. Can be a single number or a range, e.g. '15,20'. If None, sizes will be sampled from the training size distribution.")
+    p.add_argument('--output', type=str, required=False, default='samples.sdf', help="Output file.")
+    p.add_argument('--n_samples', type=int, required=False, default=10, help="Number of sampled molecules.")
+    p.add_argument('--batch_size', type=int, required=False, default=32, help="Batch size.")
+    p.add_argument('--pocket_distance_cutoff', type=float, required=False, default=8.0, help="Distance cutoff to define the pocket around the reference ligand.")
+    p.add_argument('--n_steps', type=int, required=False, default=None, help="Number of denoising steps.")
+    p.add_argument('--device', type=str, required=False, default='cuda:0', help="Device to use.")
+    p.add_argument('--datadir', type=Path, required=False, default=Path(basedir, 'src', 'default'), help="Directory with the processed dataset; needed to sample molecule sizes.")
+    p.add_argument('--seed', type=int, required=False, default=42, help="Random seed.")
+    p.add_argument('--filter', action='store_true', required=False, default=False, help="Apply basic filters and keep sampling until `n_samples` molecules passing these filters are found.")
+    p.add_argument('--metrics_output', type=str, required=False, default=None, help="If provided, metrics will be computed and saved in csv format at this location.")
+    p.add_argument('--gnina', type=str, required=False, default=None, help="Path to a gnina executable. Required for computing docking scores.")
+    p.add_argument('--reduce', type=str, required=False, default=None, help="Path to a reduce executable. Required for computing interactions.")
+    args = p.parse_args()
+
+    utils.set_deterministic(seed=args.seed)
+    utils.disable_rdkit_logging()
+
+    if args.molecule_size is None and (args.datadir is None or not args.datadir.exists()):
+        raise ValueError(
+            "Please provide a path to the processed dataset (using `--datadir`) "
+            "to infer the number of nodes. It contains the size distribution histogram."
+ ) + + if not args.filter: + args.batch_size = min(args.batch_size, args.n_samples) + + # Loading model + chkpt_path = Path(args.checkpoint) + chkpt_name = chkpt_path.parts[-1].split('.')[0] + model = DrugFlow.load_from_checkpoint(args.checkpoint, map_location=args.device, strict=False) + if args.datadir is not None: + model.datadir = args.datadir + + model.setup(stage='generation') + model.batch_size = model.eval_batch_size = args.batch_size + model.eval().to(args.device) + if args.n_steps is not None: + model.T = args.n_steps + + # Loading size model + size_model = None + molecule_size = None + molecule_size_boundaries = None + if args.molecule_size is not None: + if args.molecule_size.isdigit(): + molecule_size = int(args.molecule_size) + print(f'Will generate molecules of size {molecule_size}') + else: + boundaries = [x.strip() for x in args.molecule_size.split(',')] + assert len(boundaries) == 2 and boundaries[0].isdigit() and boundaries[1].isdigit() + left = int(boundaries[0]) + right = int(boundaries[1]) + molecule_size = f"uniform_{left}_{right}" + print(f'Will generate molecules with numbers of atoms sampled from U({left}, {right})') + + # Preparing input + pdb_model = PDBParser(QUIET=True).get_structure('', args.protein)[0] + rdmol = Chem.SDMolSupplier(str(args.ref_ligand))[0] + + ligand, pocket = process_raw_pair( + pdb_model, rdmol, + dist_cutoff=args.pocket_distance_cutoff, + pocket_representation=model.pocket_representation, + compute_nerf_params=True, + nma_input=args.protein if model.dynamics.add_nma_feat else None + ) + ligand['name'] = 'ligand' + dataset = [{'ligand': ligand, 'pocket': pocket} for _ in range(args.batch_size)] + dataloader = DataLoader( + dataset=dataset, + batch_size=args.batch_size, + collate_fn=partial(ProcessedLigandPocketDataset.collate_fn, ligand_transform=None), + pin_memory=True + ) + + # Start sampling + smiles = set() + sampled_molecules = [] + metrics = [] + Path(args.output).parent.absolute().mkdir(parents=True, exist_ok=True) + print(f'Will generate {args.n_samples} samples') + + evaluator = FullEvaluator(gnina=args.gnina, reduce=args.reduce) + + with tqdm(total=args.n_samples) as pbar: + while len(sampled_molecules) < args.n_samples: + for i, data in enumerate(dataloader): + new_data = { + 'ligand': TensorDict(**data['ligand']).to(args.device), + 'pocket': TensorDict(**data['pocket']).to(args.device), + } + rdmols, rdpockets, _ = model.sample( + new_data, + n_samples=1, + timesteps=args.n_steps, + num_nodes=molecule_size, + ) + + if args.filter or (args.metrics_output is not None): + results = [] + with tempfile.TemporaryDirectory() as tmpdir: + for mol, receptor in zip(rdmols, rdpockets): + receptor_path = Path(tmpdir, 'receptor.pdb') + Chem.MolToPDBFile(receptor, str(receptor_path)) + results.append(evaluator(mol, receptor_path)) + + table = pd.DataFrame(results) + table['novel'] = ~table['representation.smiles'].isin(smiles) + table = aggregate_metrics(table) + + added_molecules = 0 + if args.filter: + table['passed_filters'] = ( + (table['posebusters'] == 1) & + # (table['reos'] == 1) & + (table['chembl_ring_systems'] == 1) & + (table['novel'] == 1) + ) + for i, (passed, smi) in enumerate(table[['passed_filters', 'representation.smiles']].values): + if passed: + sampled_molecules.append(rdmols[i]) + smiles.add(smi) + added_molecules += 1 + + if args.metrics_output is not None: + metrics.append(table[table['passed_filters']]) + + else: + sampled_molecules.extend(rdmols) + added_molecules = len(rdmols) + if args.metrics_output is not None: 
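+                    # no filtering requested: record metrics for every generated molecule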
+ metrics.append(table) + + pbar.update(added_molecules) + + # Write results + utils.write_sdf_file(args.output, sampled_molecules) + + if args.metrics_output is not None: + metrics = pd.concat(metrics) + metrics.to_csv(args.metrics_output, index=False) diff --git a/src/model/diffusion_utils.py b/src/model/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d5ddd3af0da5606118931628e0bc9ae5ddfa5700 --- /dev/null +++ b/src/model/diffusion_utils.py @@ -0,0 +1,206 @@ +import math +import torch +import torch.nn.functional as F +import numpy as np + + +class DistributionNodes: + def __init__(self, histogram): + + histogram = torch.tensor(histogram).float() + histogram = histogram + 1e-3 # for numerical stability + + prob = histogram / histogram.sum() + + self.idx_to_n_nodes = torch.tensor( + [[(i, j) for j in range(prob.shape[1])] for i in range(prob.shape[0])] + ).view(-1, 2) + + self.n_nodes_to_idx = {tuple(x.tolist()): i + for i, x in enumerate(self.idx_to_n_nodes)} + + self.prob = prob + self.m = torch.distributions.Categorical(self.prob.view(-1), + validate_args=True) + + self.n1_given_n2 = \ + [torch.distributions.Categorical(prob[:, j], validate_args=True) + for j in range(prob.shape[1])] + self.n2_given_n1 = \ + [torch.distributions.Categorical(prob[i, :], validate_args=True) + for i in range(prob.shape[0])] + + # entropy = -torch.sum(self.prob.view(-1) * torch.log(self.prob.view(-1) + 1e-30)) + # entropy = self.m.entropy() + # print("Entropy of n_nodes: H[N]", entropy.item()) + + def sample(self, n_samples=1): + idx = self.m.sample((n_samples,)) + num_nodes_lig, num_nodes_pocket = self.idx_to_n_nodes[idx].T + return num_nodes_lig, num_nodes_pocket + + def sample_conditional(self, n1=None, n2=None): + assert (n1 is None) ^ (n2 is None), \ + "Exactly one input argument must be None" + + m = self.n1_given_n2 if n2 is not None else self.n2_given_n1 + c = n2 if n2 is not None else n1 + + return torch.tensor([m[i].sample() for i in c], device=c.device) + + def log_prob(self, batch_n_nodes_1, batch_n_nodes_2): + assert len(batch_n_nodes_1.size()) == 1 + assert len(batch_n_nodes_2.size()) == 1 + + idx = torch.tensor( + [self.n_nodes_to_idx[(n1, n2)] + for n1, n2 in zip(batch_n_nodes_1.tolist(), batch_n_nodes_2.tolist())] + ) + + # log_probs = torch.log(self.prob.view(-1)[idx] + 1e-30) + log_probs = self.m.log_prob(idx) + + return log_probs.to(batch_n_nodes_1.device) + + def log_prob_n1_given_n2(self, n1, n2): + assert len(n1.size()) == 1 + assert len(n2.size()) == 1 + log_probs = torch.stack([self.n1_given_n2[c].log_prob(i.cpu()) + for i, c in zip(n1, n2)]) + return log_probs.to(n1.device) + + def log_prob_n2_given_n1(self, n2, n1): + assert len(n2.size()) == 1 + assert len(n1.size()) == 1 + log_probs = torch.stack([self.n2_given_n1[c].log_prob(i.cpu()) + for i, c in zip(n2, n1)]) + return log_probs.to(n2.device) + + +def cosine_beta_schedule_midi(timesteps, s=0.008, nu=1.0, clip=False): + """ + Modified cosine schedule as proposed in https://arxiv.org/abs/2302.09048. 
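+    Concretely, with T = timesteps this computes the cumulative signal level
+        alphas_cumprod[t] = cos(0.5 * pi * ((t/T)^nu + s) / (1 + s))^2 / alphas_cumprod[0],
+    so nu > 1 preserves the signal longer (noise concentrated at late t) and
+    nu < 1 destroys it faster early on. Quick shape check:
+        >>> ac = cosine_beta_schedule_midi(1000, nu=1.0)
+        >>> ac.shape
+        torch.Size([1001])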
+    Note: we use (t/T)^\nu, not (t/T + s)^\nu as written in the MiDi paper.
+    We also divide by alphas_cumprod[0], as in the original cosine schedule from
+    https://arxiv.org/abs/2102.09672
+    """
+    device = nu.device if torch.is_tensor(nu) else None
+    x = torch.linspace(0, timesteps, timesteps + 1, device=device)
+    alphas_cumprod = torch.cos(0.5 * np.pi * ((x / timesteps)**nu + s) / (1 + s)) ** 2
+    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+
+    if clip:
+        alphas_cumprod = torch.cat([torch.tensor([1.0], device=alphas_cumprod.device), alphas_cumprod])
+        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+        betas = torch.clip(betas, min=0, max=0.999)
+        alphas = 1. - betas
+        alphas_cumprod = torch.cumprod(alphas, axis=0)
+    return alphas_cumprod
+
+
+class CosineSchedule(torch.nn.Module):
+    """
+    nu=1.0 corresponds to the standard cosine schedule
+    """
+
+    def __init__(self, timesteps, nu=1.0, trainable=False, clip_alpha2_step=0.001):
+        super(CosineSchedule, self).__init__()
+        self.timesteps = timesteps
+        self.trainable = trainable
+        self.nu = nu
+        assert 0.0 <= clip_alpha2_step < 1.0
+        self.clip = clip_alpha2_step
+
+        if self.trainable:
+            self.nu = torch.nn.Parameter(torch.Tensor([nu]), requires_grad=True)
+        else:
+            self._alpha2 = self.alphas2
+            self._gamma = torch.nn.Parameter(self.gammas, requires_grad=False)
+
+    @property
+    def alphas2(self):
+        """
+        Cumulative alpha squared.
+        Called alpha_bar in: Nichol, Alexander Quinn, and Prafulla Dhariwal.
+        "Improved denoising diffusion probabilistic models." PMLR, 2021.
+        """
+        if hasattr(self, '_alpha2'):
+            return self._alpha2
+
+        assert isinstance(self.nu, float) or ~self.nu.isnan()
+
+        # our alpha is equivalent to sqrt(alpha) from https://arxiv.org/abs/2102.09672, where the cosine schedule was introduced
+        alphas2 = cosine_beta_schedule_midi(self.timesteps, nu=self.nu, clip=False)
+
+        # avoid singularities near t=T
+        alphas2 = torch.cat([torch.tensor([1.0], device=alphas2.device), alphas2])
+        alphas2_step = alphas2[1:] / alphas2[:-1]
+        alphas2_step = torch.clip(alphas2_step, min=self.clip, max=1.0)
+        alphas2 = torch.cumprod(alphas2_step, dim=0)
+
+        return alphas2
+
+    @property
+    def alphas2_t_given_tminus1(self):
+        """
+        Alphas for a single transition
+        """
+        alphas2 = torch.cat([torch.tensor([1.0]), self.alphas2])
+        return alphas2[1:] / alphas2[:-1]
+
+    @property
+    def gammas(self):
+        """
+        Gammas as defined in appendix B of the EDM paper:
+        gamma_t = -(log alpha_t^2 - log sigma_t^2)
+        """
+        if hasattr(self, '_gamma'):
+            return self._gamma
+
+        alphas2 = self.alphas2
+        sigmas2 = 1 - alphas2
+
+        gammas = -(torch.log(alphas2) - torch.log(sigmas2))
+
+        return gammas.float()
+
+    def forward(self, t):
+        t_int = torch.round(t * self.timesteps).long()
+        return self.gammas[t_int]
+
+    @staticmethod
+    def alpha(gamma):
+        """ Computes alpha given gamma. """
+        return torch.sqrt(torch.sigmoid(-gamma))
+
+    @staticmethod
+    def sigma(gamma):
+        """ Computes sigma given gamma. """
+        return torch.sqrt(torch.sigmoid(gamma))
+
+    @staticmethod
+    def SNR(gamma):
+        """ Computes signal-to-noise ratio (alpha^2/sigma^2) given gamma. """
+        return torch.exp(-gamma)
+
+    def sigma_and_alpha_t_given_s(self, gamma_t: torch.Tensor, gamma_s: torch.Tensor):
+        """
+        Computes sigma_t_given_s, using gamma_t and gamma_s. Used during sampling.
+        These are defined as:
+            alpha_t_given_s = alpha_t / alpha_s,
+            sigma_t_given_s = sqrt(1 - (alpha_t_given_s)^2).
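+        For numerical stability everything is evaluated in the log domain,
+        using sigma_t^2 = sigmoid(gamma_t) and alpha_t^2 = sigmoid(-gamma_t),
+        which yields sigma2_t_given_s = -expm1(softplus(gamma_s) - softplus(gamma_t)).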
+ """ + sigma2_t_given_s = -torch.expm1( + F.softplus(gamma_s) - F.softplus(gamma_t)) + + # alpha_t_given_s = alpha_t / alpha_s + log_alpha2_t = F.logsigmoid(-gamma_t) + log_alpha2_s = F.logsigmoid(-gamma_s) + log_alpha2_t_given_s = log_alpha2_t - log_alpha2_s + + alpha_t_given_s = torch.exp(0.5 * log_alpha2_t_given_s) + alpha_t_given_s = torch.clip(alpha_t_given_s, min=self.clip ** 0.5, max=1.0) + + sigma_t_given_s = torch.sqrt(sigma2_t_given_s) + + return sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s diff --git a/src/model/dpo.py b/src/model/dpo.py new file mode 100644 index 0000000000000000000000000000000000000000..5a184ba59f7a535e1aa2f5d2035ea6d3bf2c92ac --- /dev/null +++ b/src/model/dpo.py @@ -0,0 +1,252 @@ +from typing import Optional +from pathlib import Path +from contextlib import nullcontext + +import torch +import torch.nn.functional as F +from torch_scatter import scatter_mean + +from src.constants import atom_encoder, bond_encoder +from src.model.lightning import DrugFlow, set_default +from src.data.dataset import ProcessedLigandPocketDataset, DPODataset +from src.data.data_utils import AppendVirtualNodesInCoM, Residues, center_data + +class DPO(DrugFlow): + def __init__(self, dpo_mode, ref_checkpoint_p, **kwargs): + super(DPO, self).__init__(**kwargs) + self.dpo_mode = dpo_mode + self.dpo_beta = kwargs['loss_params'].dpo_beta if 'dpo_beta' in kwargs['loss_params'] else 100.0 + self.dpo_beta_schedule = kwargs['loss_params'].dpo_beta_schedule if 'dpo_beta_schedule' in kwargs['loss_params'] else 't' + self.clamp_dpo = kwargs['loss_params'].clamp_dpo if 'clamp_dpo' in kwargs['loss_params'] else True + self.dpo_lambda_dpo = kwargs['loss_params'].dpo_lambda_dpo if 'dpo_lambda_dpo' in kwargs['loss_params'] else 1 + self.dpo_lambda_w = kwargs['loss_params'].dpo_lambda_w if 'dpo_lambda_w' in kwargs['loss_params'] else 1 + self.dpo_lambda_l = kwargs['loss_params'].dpo_lambda_l if 'dpo_lambda_l' in kwargs['loss_params'] else 0.2 + self.dpo_lambda_h = kwargs['loss_params'].dpo_lambda_h if 'dpo_lambda_h' in kwargs['loss_params'] else kwargs['loss_params'].lambda_h + self.dpo_lambda_e = kwargs['loss_params'].dpo_lambda_e if 'dpo_lambda_e' in kwargs['loss_params'] else kwargs['loss_params'].lambda_e + self.ref_dynamics = self.init_model(kwargs['predictor_params']) + state_dict = torch.load(ref_checkpoint_p)['state_dict'] + self.ref_dynamics.load_state_dict({k.replace('dynamics.',''): v for k, v in state_dict.items() if k.startswith('dynamics.')}) + print(f'Loaded reference model from {ref_checkpoint_p}') + # initializing model params with ref model params + self.dynamics.load_state_dict(self.ref_dynamics.state_dict()) + + def get_dataset(self, stage, pocket_transform=None): + + # when sampling we don't append virtual nodes as we might need access to the ground truth size + if self.virtual_nodes and stage == 'train': + ligand_transform = AppendVirtualNodesInCoM( + atom_encoder, bond_encoder, add_min=self.add_virtual_min, add_max=self.add_virtual_max) + else: + ligand_transform = None + + # we want to know if something goes wrong on the validation or test set + catch_errors = stage == 'train' + + if self.sharded_dataset: + raise NotImplementedError('Sharded dataset not implemented for DPO') + + if self.sample_from_clusters and stage == 'train': # val/test should be deterministic + raise NotImplementedError('Sampling from clusters not implemented for DPO') + + if stage == 'train': + return DPODataset( + Path(self.datadir, 'train.pt'), + ligand_transform=None, + 
pocket_transform=pocket_transform, + catch_errors=True, + ) + else: + return ProcessedLigandPocketDataset( + pt_path=Path(self.datadir, 'val.pt' if self.debug else f'{stage}.pt'), + ligand_transform=ligand_transform, + pocket_transform=pocket_transform, + catch_errors=catch_errors, + ) + + + def training_step(self, data, *args): + ligand_w, ligand_l, pocket = data['ligand'], data['ligand_l'], data['pocket'] + loss, info = self.compute_dpo_loss(pocket, ligand_w=ligand_w, ligand_l=ligand_l, return_info=True) + + if torch.isnan(loss): + print(f'For ligand pair , loss is NaN at epoch {self.current_epoch}. Info: {info}') + + log_dict = {k: v for k, v in info.items() if isinstance(v, float) or torch.numel(v) <= 1} + self.log_metrics({'loss': loss, **log_dict}, 'train', batch_size=len(ligand_w['size'])) + + out = {'loss': loss, **info} + self.training_step_outputs.append(out) + return out + + def validation_step(self, data, *args): + return super().validation_step(data, *args) + + def compute_dpo_loss(self, pocket, ligand_w, ligand_l, return_info=False): + t = torch.rand(ligand_w['size'].size(0), device=ligand_w['x'].device).unsqueeze(-1) + + if self.dpo_beta_schedule == 't': + # from https://arxiv.org/pdf/2407.13981 + beta_t = (self.dpo_beta * t).squeeze() + elif self.dpo_beta_schedule == 'const': + beta_t = self.dpo_beta + else: + raise ValueError(f'Unknown DPO beta schedule: {self.dpo_beta_schedule}') + + loss_dict_w = self.compute_loss_single_pair(ligand_w, pocket, t) + loss_dict_l = self.compute_loss_single_pair(ligand_l, pocket, t) + info = { + 'loss_x_w': loss_dict_w['theta']['x'].mean().item(), + 'loss_h_w': loss_dict_w['theta']['h'].mean().item(), + 'loss_e_w': loss_dict_w['theta']['e'].mean().item(), + 'loss_x_l': loss_dict_l['theta']['x'].mean().item(), + 'loss_h_l': loss_dict_l['theta']['h'].mean().item(), + 'loss_e_l': loss_dict_l['theta']['e'].mean().item(), + } + if self.dpo_mode == 'single_dpo_comp': + loss_w_theta = ( + loss_dict_w['theta']['x'] + + self.dpo_lambda_h * loss_dict_w['theta']['h'] + + self.dpo_lambda_e * loss_dict_w['theta']['e'] + ) + loss_w_ref = ( + loss_dict_w['ref']['x'] + + self.dpo_lambda_h * loss_dict_w['ref']['h'] + + self.dpo_lambda_e * loss_dict_w['ref']['e'] + ) + loss_l_theta = ( + loss_dict_l['theta']['x'] + + self.dpo_lambda_h * loss_dict_l['theta']['h'] + + self.dpo_lambda_e * loss_dict_l['theta']['e'] + ) + loss_l_ref = ( + loss_dict_l['ref']['x'] + + self.dpo_lambda_h * loss_dict_l['ref']['h'] + + self.dpo_lambda_e * loss_dict_l['ref']['e'] + ) + diff_w = loss_w_theta - loss_w_ref + diff_l = loss_l_theta - loss_l_ref + info['diff_w'] = diff_w.mean().item() + info['diff_l'] = diff_l.mean().item() + # print(diff) + diff = -1 * beta_t * (diff_w - diff_l) + loss = -1 * F.logsigmoid(diff) + elif self.dpo_mode == 'single_dpo_comp_v3': + diff_w_x = loss_dict_w['theta']['x'] - loss_dict_w['ref']['x'] + diff_w_h = loss_dict_w['theta']['h'] - loss_dict_w['ref']['h'] + diff_w_e = loss_dict_w['theta']['e'] - loss_dict_w['ref']['e'] + diff_l_x = loss_dict_l['theta']['x'] - loss_dict_l['ref']['x'] + diff_l_h = loss_dict_l['theta']['h'] - loss_dict_l['ref']['h'] + diff_l_e = loss_dict_l['theta']['e'] - loss_dict_l['ref']['e'] + info['diff_w_x'] = diff_w_x.mean().item() + info['diff_w_h'] = diff_w_h.mean().item() + info['diff_w_e'] = diff_w_e.mean().item() + info['diff_l_x'] = diff_l_x.mean().item() + info['diff_l_h'] = diff_l_h.mean().item() + info['diff_l_e'] = diff_l_e.mean().item() + + # not used, just for logging + _diff_w = diff_w_x + self.dpo_lambda_h * 
diff_w_h + self.dpo_lambda_e * diff_w_e + _diff_l = diff_l_x + self.dpo_lambda_h * diff_l_h + self.dpo_lambda_e * diff_l_e + info['diff_w'] = _diff_w.mean().item() + info['diff_l'] = _diff_l.mean().item() + + diff_x = diff_w_x - diff_l_x + diff_h = diff_w_h - diff_l_h + diff_e = diff_w_e - diff_l_e + info['diff_x'] = diff_x.mean().item() + info['diff_h'] = diff_h.mean().item() + info['diff_e'] = diff_e.mean().item() + + diff = -1 * beta_t * (diff_x + self.dpo_lambda_h * diff_h + self.dpo_lambda_e * diff_e) + if self.clamp_dpo: + diff = diff.clamp(-10, 10) + info['dpo_arg_min'] = diff.min().item() + info['dpo_arg_max'] = diff.max().item() + info['dpo_arg_mean'] = diff.mean().item() + dpo_loss = -1 * self.dpo_lambda_dpo * F.logsigmoid(diff) + info['dpo_loss'] = dpo_loss.mean().item() + + loss_w_theta_reg = ( + loss_dict_w['theta']['x'] + + self.lambda_h * loss_dict_w['theta']['h'] + + self.lambda_e * loss_dict_w['theta']['e'] + ) + info['loss_w_theta_reg'] = loss_w_theta_reg.mean().item() + loss_l_theta_reg = ( + loss_dict_l['theta']['x'] + + self.lambda_h * loss_dict_l['theta']['h'] + + self.lambda_e * loss_dict_l['theta']['e'] + ) + info['loss_l_theta_reg'] = loss_l_theta_reg.mean().item() + dpo_reg = self.dpo_lambda_w * loss_w_theta_reg + \ + self.dpo_lambda_l * loss_l_theta_reg + info['dpo_reg'] = dpo_reg.mean().item() + loss = dpo_loss + dpo_reg + else: + raise ValueError(f'Unknown DPO mode: {self.dpo_mode}') + + if self.timestep_weights is not None: + w_t = self.timestep_weights(t).squeeze() + loss = w_t * loss + + loss = loss.mean(0) + + print(f'Loss is {loss}, info is {info}') + + return (loss, info) if return_info else loss + + def compute_loss_single_pair(self, ligand, pocket, t): + pocket = Residues(**pocket) + + # Center sample + ligand, pocket = center_data(ligand, pocket) + pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0) + + # Noise + z0_x = self.module_x.sample_z0(pocket_com, ligand['mask']) + z0_h = self.module_h.sample_z0(ligand['mask']) + z0_e = self.module_e.sample_z0(ligand['bond_mask']) + zt_x = self.module_x.sample_zt(z0_x, ligand['x'], t, ligand['mask']) + zt_h = self.module_h.sample_zt(z0_h, ligand['one_hot'], t, ligand['mask']) + zt_e = self.module_e.sample_zt(z0_e, ligand['bond_one_hot'], t, ligand['bond_mask']) + + # Predict denoising + sc_transform = self.get_sc_transform_fn(None, zt_x, t, None, ligand['mask'], pocket) + + pred_ligand, _ = self.dynamics( + zt_x, zt_h, ligand['mask'], pocket, t, + bonds_ligand=(ligand['bonds'], zt_e), + sc_transform=sc_transform + ) + + # Reference model + with torch.no_grad(): + ref_pred_ligand, _ = self.ref_dynamics( + zt_x, zt_h, ligand['mask'], pocket, t, + bonds_ligand=(ligand['bonds'], zt_e), + sc_transform=sc_transform + ) + + # Compute L2 loss + loss_x = self.module_x.compute_loss(pred_ligand['vel'], z0_x, ligand['x'], t, ligand['mask'], reduce=self.loss_reduce) + ref_loss_x = self.module_x.compute_loss(ref_pred_ligand['vel'], z0_x, ligand['x'], t, ligand['mask'], reduce=self.loss_reduce) + + t_next = torch.clamp(t + self.train_step_size, max=1.0) + + loss_h = self.module_h.compute_loss(pred_ligand['logits_h'], zt_h, ligand['one_hot'], ligand['mask'], t, t_next, reduce=self.loss_reduce) + ref_loss_h = self.module_h.compute_loss(ref_pred_ligand['logits_h'], zt_h, ligand['one_hot'], ligand['mask'], t, t_next, reduce=self.loss_reduce) + loss_e = self.module_e.compute_loss(pred_ligand['logits_e'], zt_e, ligand['bond_one_hot'], ligand['bond_mask'], t, t_next, reduce=self.loss_reduce) + ref_loss_e = 
self.module_e.compute_loss(ref_pred_ligand['logits_e'], zt_e, ligand['bond_one_hot'], ligand['bond_mask'], t, t_next, reduce=self.loss_reduce) + + return { + 'theta': { + 'x': loss_x, + 'h': loss_h, + 'e': loss_e, + }, + 'ref': { + 'x': ref_loss_x, + 'h': ref_loss_h, + 'e': ref_loss_e, + } + } diff --git a/src/model/dynamics.py b/src/model/dynamics.py new file mode 100644 index 0000000000000000000000000000000000000000..76afbeb90ac2ecce1b26ea329a1c71c0558b25ba --- /dev/null +++ b/src/model/dynamics.py @@ -0,0 +1,791 @@ +from collections.abc import Iterable +from abc import abstractmethod +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from src.constants import INT_TYPE +from src.model.gvp import GVPModel, GVP, LayerNorm +from src.model.gvp_transformer import GVPTransformerModel +from src.constants import FLOAT_TYPE + +from pdb import set_trace + + +def binomial_coefficient(n, k): + # source: https://discuss.pytorch.org/t/n-choose-k-function/121974 + return ((n + 1).lgamma() - (k + 1).lgamma() - ((n - k) + 1).lgamma()).exp() + + +def cycle_counts(adj): + assert (adj.diag() == 0).all() + assert (adj == adj.T).all() + + A = adj.float() + d = A.sum(dim=-1) + + # Compute powers + A2 = A @ A + A3 = A2 @ A + A4 = A3 @ A + A5 = A4 @ A + + x3 = A3.diag() / 2 + x4 = (A4.diag() - d * (d - 1) - A @ d) / 2 + + """ New (different from DiGress) + case where correction is relevant: + 2 o + | + 1,3 o--o 4 + | / + 0,5 o + """ + # Triangle count matrix (indicates for each node i how many triangles it shares with node j) + T = adj * A2 + x5 = (A5.diag() - 2 * T @ d - 4 * d * x3 - 2 * A @ x3 + 10 * x3) / 2 + + # # TODO + # A6 = A5 @ A + # + # # 4-cycle count matrix (indicates in how many shared 4-cycles i and j are 2 hops apart) + # Q2 = binomial_coefficient(n=A2 - d.diag(), k=torch.tensor(2)) + # + # # 4-cycle count matrix (indicates in how many shared 4-cycles i and j are 1 (and 3) hop(s) apart) + # Q1 = A * (A3 - (d.view(-1, 1) + d.view(1, -1)) + 1) # "+1" because link between i and j is subtracted twice + # + # x6 = ... + # return torch.stack([x3, x4, x5, x6], dim=-1) + + return torch.stack([x3, x4, x5], dim=-1) + + +# TODO: also consider directional aggregation as in: +# Beaini, Dominique, et al. "Directional graph networks." +# International Conference on Machine Learning. PMLR, 2021. +def eigenfeatures(A, batch_mask, k=5): + # TODO, see: + # - https://github.com/cvignac/DiGress/blob/main/src/diffusion/extra_features.py + # - https://arxiv.org/pdf/2209.14734.pdf (Appendix B.2) + + # split adjacency matrix + batch = [] + for i in torch.unique(batch_mask, sorted=True): # TODO: optimize (try to avoid loop) + batch_inds = torch.where(batch_mask == i)[0] + batch.append(A[torch.meshgrid(batch_inds, batch_inds, indexing='ij')]) + + eigenfeats = [get_nontrivial_eigenvectors(adj)[:, :k] for adj in batch] + # if there are less than k non-trivial eigenvectors + eigenfeats = [torch.cat([ + x, torch.zeros(x.size(0), max(k - x.size(1), 0), device=x.device)], dim=-1) + for x in eigenfeats] + return torch.cat(eigenfeats, dim=0) + + +def get_nontrivial_eigenvectors(A, normalize_l=True, thresh=1e-5, + norm_eps=1e-12): + """ + Compute eigenvectors of the graph Laplacian corresponding to non-zero + eigenvalues. 
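+    With normalize_l=True the symmetrically normalized Laplacian
+    L = I - D^{-1/2} A D^{-1/2} is used. Eigenvectors whose eigenvalue is
+    below thresh (the connected-component indicators) are discarded, e.g.:
+        >>> A = torch.tensor([[0., 1., 0., 1.],
+        ...                   [1., 0., 1., 0.],
+        ...                   [0., 1., 0., 1.],
+        ...                   [1., 0., 1., 0.]])
+        >>> get_nontrivial_eigenvectors(A).shape  # 4-cycle: one zero eigenvalue
+        torch.Size([4, 3])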
+ """ + assert (A == A.T).all(), "undirected graph" + + # Compute laplacian + d = A.sum(-1) + D = d.diag() + L = D - A + + if normalize_l: + D_inv_sqrt = (1 / (d.sqrt() + norm_eps)).diag() + L = D_inv_sqrt @ L @ D_inv_sqrt + + # Eigendecomposition + # eigenvalues are sorted in ascending order + # eigvecs matrix contains eigenvectors as its columns + eigvals, eigvecs = torch.linalg.eigh(L) + + # index of first non-trivial eigenvector + try: + idx = torch.nonzero(eigvals > thresh)[0].item() + except IndexError: + # recover if no non-trivial eigenvectors are found + idx = eigvecs.size(1) + + return eigvecs[:, idx:] + + +class DynamicsBase(nn.Module): + """ + Implements self-conditioning logic and basic functions + """ + def __init__( + self, + predict_angles=False, + predict_frames=False, + add_cycle_counts=False, + add_spectral_feat=False, + self_conditioning=False, + augment_residue_sc=False, + augment_ligand_sc=False + ): + super().__init__() + + if not hasattr(self, 'predict_angles'): + self.predict_angles = predict_angles + + if not hasattr(self, 'predict_frames'): + self.predict_frames = predict_frames + + if not hasattr(self, 'add_cycle_counts'): + self.add_cycle_counts = add_cycle_counts + + if not hasattr(self, 'add_spectral_feat'): + self.add_spectral_feat = add_spectral_feat + + if not hasattr(self, 'self_conditioning'): + self.self_conditioning = self_conditioning + + if not hasattr(self, 'augment_residue_sc'): + self.augment_residue_sc = augment_residue_sc + + if not hasattr(self, 'augment_ligand_sc'): + self.augment_ligand_sc = augment_ligand_sc + + if self.self_conditioning: + self.prev_ligand = None + self.prev_residues = None + + @abstractmethod + def _forward(self, x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand=None, + h_atoms_sc=None, e_atoms_sc=None, h_residues_sc=None): + """ + Implement forward pass. 
+ Returns: + - vel + - h_final_atoms + - edge_final_atoms + - residue_angles + - residue_trans + - residue_rot + """ + pass + + def make_sc_input(self, pred_ligand, pred_residues, sc_transform): + + if self.predict_confidence: + h_atoms_sc = (torch.cat([pred_ligand['logits_h'], pred_ligand['uncertainty_vel'].unsqueeze(1)], dim=-1), + pred_ligand['vel'].unsqueeze(1)) + else: + h_atoms_sc = (pred_ligand['logits_h'], pred_ligand['vel'].unsqueeze(1)) + e_atoms_sc = pred_ligand['logits_e'] + + if self.predict_frames: + h_residues_sc = (torch.cat([pred_residues['chi'], pred_residues['rot']], dim=-1), + pred_residues['trans'].unsqueeze(1)) + elif self.predict_angles: + h_residues_sc = pred_residues['chi'] + else: + h_residues_sc = None + + if self.augment_residue_sc and h_residues_sc is not None: + if self.predict_frames: + h_residues_sc = (h_residues_sc[0], torch.cat( + [h_residues_sc[1], sc_transform['residues'](pred_residues['chi'], pred_residues['trans'].squeeze(1), pred_residues['rot'])], dim=1)) + + else: + h_residues_sc = (h_residues_sc, sc_transform['residues'](pred_residues['chi'])) + + if self.augment_ligand_sc: + h_atoms_sc = (h_atoms_sc[0], torch.cat( + [h_atoms_sc[1], sc_transform['atoms'](pred_ligand['vel'].unsqueeze(1))], dim=1)) + + return h_atoms_sc, e_atoms_sc, h_residues_sc + + def forward(self, x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand=None, sc_transform=None): + """ + Implements self-conditioning as in https://arxiv.org/abs/2208.04202 + """ + + h_atoms_sc, e_atoms_sc = None, None + h_residues_sc = None + + if self.self_conditioning: + + # Sampling: use previous prediction in all but the first time step + if not self.training and t.min() > 0.0: + assert t.min() == t.max(), "currently only supports sampling at same time steps" + assert self.prev_ligand is not None + assert self.prev_residues is not None or not self.predict_frames + + else: + # Create zero tensors + zeros_ligand = {'logits_h': torch.zeros_like(h_atoms), + 'vel': torch.zeros_like(x_atoms), + 'logits_e': torch.zeros_like(bonds_ligand[1])} + if self.predict_confidence: + zeros_ligand['uncertainty_vel'] = torch.zeros( + len(x_atoms), dtype=x_atoms.dtype, device=x_atoms.device) + + zeros_residues = {} + if self.predict_angles: + zeros_residues['chi'] = torch.zeros((pocket['one_hot'].size(0), 5), device=pocket['one_hot'].device) + if self.predict_frames: + zeros_residues['trans'] = torch.zeros((pocket['one_hot'].size(0), 3), device=pocket['one_hot'].device) + zeros_residues['rot'] = torch.zeros((pocket['one_hot'].size(0), 3), device=pocket['one_hot'].device) + + # Training: use 50% zeros and 50% predictions with detached gradients + if self.training and random.random() > 0.5: + with torch.no_grad(): + h_atoms_sc, e_atoms_sc, h_residues_sc = self.make_sc_input( + zeros_ligand, zeros_residues, sc_transform) + + self.prev_ligand, self.prev_residues = self._forward( + x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand, + h_atoms_sc, e_atoms_sc, h_residues_sc) + + # use zeros for first sampling step and 50% of training + else: + self.prev_ligand = zeros_ligand + self.prev_residues = zeros_residues + + h_atoms_sc, e_atoms_sc, h_residues_sc = self.make_sc_input( + self.prev_ligand, self.prev_residues, sc_transform) + + pred_ligand, pred_residues = self._forward( + x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand, + h_atoms_sc, e_atoms_sc, h_residues_sc + ) + + if self.self_conditioning and not self.training: + self.prev_ligand = pred_ligand.copy() + self.prev_residues = pred_residues.copy() + + return 
pred_ligand, pred_residues + + def compute_extra_features(self, batch_mask, edge_indices, edge_types): + + feat = torch.zeros(len(batch_mask), 0, device=batch_mask.device) + + if not (self.add_cycle_counts or self.add_spectral_feat): + return feat + + adj = batch_mask[:, None] == batch_mask[None, :] + + E = torch.zeros_like(adj, dtype=INT_TYPE) + E[edge_indices[0], edge_indices[1]] = edge_types + + A = (E > 0).float() + + if self.add_cycle_counts: + cycle_features = cycle_counts(A) + cycle_features[cycle_features > 10] = 10 # avoid large values + + feat = torch.cat([feat, cycle_features], dim=-1) + + if self.add_spectral_feat: + feat = torch.cat([feat, eigenfeatures(A, batch_mask)], dim=-1) + + return feat + + +class Dynamics(DynamicsBase): + def __init__(self, atom_nf, residue_nf, joint_nf, bond_dict, pocket_bond_dict, + edge_nf, hidden_nf, act_fn=torch.nn.SiLU(), condition_time=True, + model='egnn', model_params=None, + edge_cutoff_ligand=None, edge_cutoff_pocket=None, + edge_cutoff_interaction=None, + predict_angles=False, predict_frames=False, + add_cycle_counts=False, add_spectral_feat=False, + add_nma_feat=False, self_conditioning=False, + augment_residue_sc=False, augment_ligand_sc=False, + add_chi_as_feature=False, angle_act_fn=False): + super().__init__() + self.model = model + self.edge_cutoff_l = edge_cutoff_ligand + self.edge_cutoff_p = edge_cutoff_pocket + self.edge_cutoff_i = edge_cutoff_interaction + self.hidden_nf = hidden_nf + self.predict_angles = predict_angles + self.predict_frames = predict_frames + self.bond_dict = bond_dict + self.pocket_bond_dict = pocket_bond_dict + self.bond_nf = len(bond_dict) + self.pocket_bond_nf = len(pocket_bond_dict) + self.edge_nf = edge_nf + self.add_cycle_counts = add_cycle_counts + self.add_spectral_feat = add_spectral_feat + self.add_nma_feat = add_nma_feat + self.self_conditioning = self_conditioning + self.augment_residue_sc = augment_residue_sc + self.augment_ligand_sc = augment_ligand_sc + self.add_chi_as_feature = add_chi_as_feature + self.predict_confidence = False + + if self.self_conditioning: + self.prev_vel = None + self.prev_h = None + self.prev_e = None + self.prev_a = None + self.prev_ca = None + self.prev_rot = None + + lig_nf = atom_nf + if self.add_cycle_counts: + lig_nf = lig_nf + 3 + if self.add_spectral_feat: + lig_nf = lig_nf + 5 + + + if not isinstance(joint_nf, Iterable): + # joint_nf contains only scalars + joint_nf = (joint_nf, 0) + + + if isinstance(residue_nf, Iterable): + _atom_in_nf = (lig_nf, 0) + _residue_atom_dim = residue_nf[1] + + if self.add_nma_feat: + residue_nf = (residue_nf[0], residue_nf[1] + 5) + + if self.self_conditioning: + _atom_in_nf = (_atom_in_nf[0] + atom_nf, 1) + + if self.augment_ligand_sc: + _atom_in_nf = (_atom_in_nf[0], _atom_in_nf[1] + 1) + + if self.predict_angles: + residue_nf = (residue_nf[0] + 5, residue_nf[1]) + + if self.predict_frames: + residue_nf = (residue_nf[0], residue_nf[1] + 2) + + if self.augment_residue_sc: + assert self.predict_angles + residue_nf = (residue_nf[0], residue_nf[1] + _residue_atom_dim) + + if self.add_chi_as_feature: + residue_nf = (residue_nf[0] + 5, residue_nf[1]) + + self.atom_encoder = nn.Sequential( + GVP(_atom_in_nf, joint_nf, activations=(act_fn, torch.sigmoid)), + LayerNorm(joint_nf, learnable_vector_weight=True), + GVP(joint_nf, joint_nf, activations=(None, None)), + ) + + self.residue_encoder = nn.Sequential( + GVP(residue_nf, joint_nf, activations=(act_fn, torch.sigmoid)), + LayerNorm(joint_nf, learnable_vector_weight=True), + GVP(joint_nf, 
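+                # final GVP maps to the shared joint dimension without output nonlinearity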
joint_nf, activations=(None, None)), + ) + + else: + # No vector-valued input features + assert joint_nf[1] == 0 + + # self-conditioning not yet supported + assert not self.self_conditioning + + # Normal mode features are vectors + assert not self.add_nma_feat + + if self.add_chi_as_feature: + residue_nf += 5 + + self.atom_encoder = nn.Sequential( + nn.Linear(lig_nf, 2 * atom_nf), + act_fn, + nn.Linear(2 * atom_nf, joint_nf[0]) + ) + + self.residue_encoder = nn.Sequential( + nn.Linear(residue_nf, 2 * residue_nf), + act_fn, + nn.Linear(2 * residue_nf, joint_nf[0]) + ) + + self.atom_decoder = nn.Sequential( + nn.Linear(joint_nf[0], 2 * atom_nf), + act_fn, + nn.Linear(2 * atom_nf, atom_nf) + ) + + self.edge_decoder = nn.Sequential( + nn.Linear(hidden_nf, hidden_nf), + act_fn, + nn.Linear(hidden_nf, self.bond_nf) + ) + + _atom_bond_nf = 2 * self.bond_nf if self.self_conditioning else self.bond_nf + self.ligand_bond_encoder = nn.Sequential( + nn.Linear(_atom_bond_nf, hidden_nf), + act_fn, + nn.Linear(hidden_nf, self.edge_nf) + ) + + self.pocket_bond_encoder = nn.Sequential( + nn.Linear(self.pocket_bond_nf, hidden_nf), + act_fn, + nn.Linear(hidden_nf, self.edge_nf) + ) + + out_nf = (joint_nf[0], 1) + res_out_nf = (0, 0) + if self.predict_angles: + res_out_nf = (res_out_nf[0] + 5, res_out_nf[1]) + if self.predict_frames: + res_out_nf = (res_out_nf[0], res_out_nf[1] + 2) + self.residue_decoder = nn.Sequential( + GVP(out_nf, out_nf, activations=(act_fn, torch.sigmoid)), + LayerNorm(out_nf, learnable_vector_weight=True), + GVP(out_nf, res_out_nf, activations=(None, None)), + ) if res_out_nf != (0, 0) else None + + if angle_act_fn is None: + self.angle_act_fn = None + elif angle_act_fn == 'tanh': + self.angle_act_fn = lambda x: np.pi * F.tanh(x) + else: + raise NotImplementedError(f"Angle activation {angle_act_fn} not available") + + # self.ligand_nobond_emb = nn.Parameter(torch.zeros(self.edge_nf)) + # self.pocket_nobond_emb = nn.Parameter(torch.zeros(self.edge_nf)) + self.cross_emb = nn.Parameter(torch.zeros(self.edge_nf), + requires_grad=True) + + if condition_time: + dynamics_node_nf = (joint_nf[0] + 1, joint_nf[1]) + else: + print('Warning: dynamics model is NOT conditioned on time.') + dynamics_node_nf = (joint_nf[0], joint_nf[1]) + + if model == 'egnn': + raise NotImplementedError + # self.net = EGNN( + # in_node_nf=dynamics_node_nf[0], in_edge_nf=self.edge_nf, + # hidden_nf=hidden_nf, out_node_nf=joint_nf[0], + # device=model_params.device, act_fn=act_fn, + # n_layers=model_params.n_layers, + # attention=model_params.attention, + # tanh=model_params.tanh, + # norm_constant=model_params.norm_constant, + # inv_sublayers=model_params.inv_sublayers, + # sin_embedding=model_params.sin_embedding, + # normalization_factor=model_params.normalization_factor, + # aggregation_method=model_params.aggregation_method, + # reflection_equiv=model_params.reflection_equivariant, + # update_edge_attr=True + # ) + # self.node_nf = dynamics_node_nf[0] + + elif model == 'gvp': + self.net = GVPModel( + node_in_dim=dynamics_node_nf, node_h_dim=model_params.node_h_dim, + node_out_nf=joint_nf[0], edge_in_nf=self.edge_nf, + edge_h_dim=model_params.edge_h_dim, edge_out_nf=hidden_nf, + num_layers=model_params.n_layers, + drop_rate=model_params.dropout, + vector_gate=model_params.vector_gate, + reflection_equiv=model_params.reflection_equivariant, + d_max=model_params.d_max, + num_rbf=model_params.num_rbf, + update_edge_attr=True + ) + + elif model == 'gvp_transformer': + self.net = GVPTransformerModel( + 
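+            # attention-augmented GVP variant; geometric settings mirror the
+            # 'gvp' branch above, extended by attention dimensions and head count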
node_in_dim=dynamics_node_nf, + node_h_dim=model_params.node_h_dim, + node_out_nf=joint_nf[0], + edge_in_nf=self.edge_nf, + edge_h_dim=model_params.edge_h_dim, + edge_out_nf=hidden_nf, + num_layers=model_params.n_layers, + dk=model_params.dk, + dv=model_params.dv, + de=model_params.de, + db=model_params.db, + dy=model_params.dy, + attn_heads=model_params.attn_heads, + n_feedforward=model_params.n_feedforward, + drop_rate=model_params.dropout, + reflection_equiv=model_params.reflection_equivariant, + d_max=model_params.d_max, + num_rbf=model_params.num_rbf, + vector_gate=model_params.vector_gate, + attention=model_params.attention, + ) + + elif model == 'gnn': + raise NotImplementedError + # n_dims = 3 + # self.net = GNN( + # in_node_nf=dynamics_node_nf + n_dims, in_edge_nf=self.edge_emb_dim, + # hidden_nf=hidden_nf, out_node_nf=n_dims + dynamics_node_nf, + # device=model_params.device, act_fn=act_fn, n_layers=model_params.n_layers, + # attention=model_params.attention, normalization_factor=model_params.normalization_factor, + # aggregation_method=model_params.aggregation_method) + + else: + raise NotImplementedError(f"{model} is not available") + + # self.device = device + # self.n_dims = n_dims + self.condition_time = condition_time + + def _forward(self, x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand=None, + h_atoms_sc=None, e_atoms_sc=None, h_residues_sc=None): + """ + :param x_atoms: + :param h_atoms: + :param mask_atoms: + :param pocket: must contain keys: 'x', 'one_hot', 'mask', 'bonds' and 'bond_one_hot' + :param t: + :param bonds_ligand: tuple - bond indices (2, n_bonds) & bond types (n_bonds, bond_nf) + :param h_atoms_sc: additional node feature for self-conditioning, (s, V) + :param e_atoms_sc: additional edge feature for self-conditioning, only scalar + :param h_residues_sc: additional node feature for self-conditioning, tensor or tuple + :return: + """ + x_residues, h_residues, mask_residues = pocket['x'], pocket['one_hot'], pocket['mask'] + if 'bonds' in pocket: + bonds_pocket = (pocket['bonds'], pocket['bond_one_hot']) + else: + bonds_pocket = None + + if self.add_chi_as_feature: + h_residues = torch.cat([h_residues, pocket['chi'][:, :5]], dim=-1) + + if 'v' in pocket: + v_residues = pocket['v'] + if self.add_nma_feat: + v_residues = torch.cat([v_residues, pocket['nma_vec']], dim=1) + h_residues = (h_residues, v_residues) + + if h_residues_sc is not None: + # if self.augment_residue_sc: + if isinstance(h_residues_sc, tuple): + h_residues = (torch.cat([h_residues[0], h_residues_sc[0]], dim=-1), + torch.cat([h_residues[1], h_residues_sc[1]], dim=1)) + else: + h_residues = (torch.cat([h_residues[0], h_residues_sc], dim=-1), + h_residues[1]) + + # get graph edges and edge attributes + if bonds_ligand is not None: + # NOTE: 'bond' denotes one-directional edges and 'edge' means bi-directional + ligand_bond_indices = bonds_ligand[0] + + # make sure messages are passed both ways + ligand_edge_indices = torch.cat( + [bonds_ligand[0], bonds_ligand[0].flip(dims=[0])], dim=1) + ligand_edge_types = torch.cat([bonds_ligand[1], bonds_ligand[1]], dim=0) + # edges_ligand = (ligand_edge_indices, ligand_edge_types) + + # add auxiliary features to ligand nodes + extra_features = self.compute_extra_features( + mask_atoms, ligand_edge_indices, ligand_edge_types.argmax(-1)) + h_atoms = torch.cat([h_atoms, extra_features], dim=-1) + + if bonds_pocket is not None: + # make sure messages are passed both ways + pocket_edge_indices = torch.cat( + [bonds_pocket[0], bonds_pocket[0].flip(dims=[0])], 
dim=1) + pocket_edge_types = torch.cat([bonds_pocket[1], bonds_pocket[1]], dim=0) + # edges_pocket = (pocket_edge_indices, pocket_edge_types) + + if h_atoms_sc is not None: + h_atoms = (torch.cat([h_atoms, h_atoms_sc[0]], dim=-1), + h_atoms_sc[1]) + + if e_atoms_sc is not None: + e_atoms_sc = torch.cat([e_atoms_sc, e_atoms_sc], dim=0) + ligand_edge_types = torch.cat([ligand_edge_types, e_atoms_sc], dim=-1) + + # embed atom features and residue features in a shared space + h_atoms = self.atom_encoder(h_atoms) + e_ligand = self.ligand_bond_encoder(ligand_edge_types) + + if len(h_residues) > 0: + h_residues = self.residue_encoder(h_residues) + e_pocket = self.pocket_bond_encoder(pocket_edge_types) + else: + e_pocket = pocket_edge_types + h_residues = (h_residues, h_residues) + pocket_edge_indices = torch.tensor([[], []], dtype=torch.long, device=h_residues[0].device) + pocket_edge_types = torch.tensor([[], []], dtype=torch.long, device=h_residues[0].device) + + if isinstance(h_atoms, tuple): + h_atoms, v_atoms = h_atoms + h_residues, v_residues = h_residues + v = torch.cat((v_atoms, v_residues), dim=0) + else: + v = None + + edges, edge_feat = self.get_edges( + mask_atoms, mask_residues, x_atoms, x_residues, + bond_inds_ligand=ligand_edge_indices, bond_inds_pocket=pocket_edge_indices, + bond_feat_ligand=e_ligand, bond_feat_pocket=e_pocket) + + # combine the two node types + x = torch.cat((x_atoms, x_residues), dim=0) + h = torch.cat((h_atoms, h_residues), dim=0) + mask = torch.cat([mask_atoms, mask_residues]) + + if self.condition_time: + if np.prod(t.size()) == 1: + # t is the same for all elements in batch. + h_time = torch.empty_like(h[:, 0:1]).fill_(t.item()) + else: + # t is different over the batch dimension. + h_time = t[mask] + h = torch.cat([h, h_time], dim=1) + + assert torch.all(mask[edges[0]] == mask[edges[1]]) + + if self.model == 'egnn': + # Don't update pocket coordinates + update_coords_mask = torch.cat((torch.ones_like(mask_atoms), + torch.zeros_like(mask_residues))).unsqueeze(1) + h_final, vel, edge_final = self.net( + h, x, edges, batch_mask=mask, edge_attr=edge_feat, + update_coords_mask=update_coords_mask) + # vel = (x_final - x) + + elif self.model == 'gvp' or self.model == 'gvp_transformer': + h_final, vel, edge_final = self.net( + h, x, edges, v=v, batch_mask=mask, edge_attr=edge_feat) + + elif self.model == 'gnn': + xh = torch.cat([x, h], dim=1) + output = self.net(xh, edges, node_mask=None, edge_attr=edge_feat) + vel = output[:, :3] + h_final = output[:, 3:] + + else: + raise NotImplementedError(f"Wrong model ({self.model})") + + # if self.condition_time: + # # Slice off last dimension which represented time. 
+ # h_final = h_final[:, :-1] + + # decode atom and residue features + h_final_atoms = self.atom_decoder(h_final[:len(mask_atoms)]) + + if torch.any(torch.isnan(vel)) or torch.any(torch.isnan(h_final_atoms)): + if self.training: + vel[torch.isnan(vel)] = 0.0 + h_final_atoms[torch.isnan(h_final_atoms)] = 0.0 + else: + raise ValueError("NaN detected in network output") + + # predict edge type + ligand_edge_mask = (edges[0] < len(mask_atoms)) & (edges[1] < len(mask_atoms)) + edge_final = edge_final[ligand_edge_mask] + edges = edges[:, ligand_edge_mask] + + # Symmetrize + edge_logits = torch.zeros( + (len(mask_atoms), len(mask_atoms), self.hidden_nf), + device=mask_atoms.device) + edge_logits[edges[0], edges[1]] = edge_final + edge_logits = (edge_logits + edge_logits.transpose(0, 1)) * 0.5 + # edge_logits = edge_logits[lig_edge_indices[0], lig_edge_indices[1]] + + # return upper triangular elements only (matching the input) + edge_logits = edge_logits[ligand_bond_indices[0], ligand_bond_indices[1]] + # assert (edge_logits == 0).sum() == 0 + + edge_final_atoms = self.edge_decoder(edge_logits) + + # Predict torsion angles + residue_angles = None + residue_trans, residue_rot = None, None + if self.residue_decoder is not None: + h_residues = h_final[len(mask_atoms):] + vec_residues = vel[len(mask_atoms):].unsqueeze(1) + residue_angles = self.residue_decoder((h_residues, vec_residues)) + if self.predict_frames: + residue_angles, residue_frames = residue_angles + residue_trans = residue_frames[:, 0, :].squeeze(1) + residue_rot = residue_frames[:, 1, :].squeeze(1) + if self.angle_act_fn is not None: + residue_angles = self.angle_act_fn(residue_angles) + + # return vel[:len(mask_atoms)], h_final_atoms, edge_final_atoms, residue_angles, residue_trans, residue_rot + pred_ligand = {'vel': vel[:len(mask_atoms)], 'logits_h': h_final_atoms, 'logits_e': edge_final_atoms} + pred_residues = {'chi': residue_angles, 'trans': residue_trans, 'rot': residue_rot} + return pred_ligand, pred_residues + + def get_edges(self, batch_mask_ligand, batch_mask_pocket, x_ligand, + x_pocket, bond_inds_ligand=None, bond_inds_pocket=None, + bond_feat_ligand=None, bond_feat_pocket=None, self_edges=False): + + # Adjacency matrix + adj_ligand = batch_mask_ligand[:, None] == batch_mask_ligand[None, :] + adj_pocket = batch_mask_pocket[:, None] == batch_mask_pocket[None, :] + adj_cross = batch_mask_ligand[:, None] == batch_mask_pocket[None, :] + + if self.edge_cutoff_l is not None: + adj_ligand = adj_ligand & (torch.cdist(x_ligand, x_ligand) <= self.edge_cutoff_l) + + # Add missing bonds if they got removed + adj_ligand[bond_inds_ligand[0], bond_inds_ligand[1]] = True + + if self.edge_cutoff_p is not None and len(x_pocket) > 0: + adj_pocket = adj_pocket & (torch.cdist(x_pocket, x_pocket) <= self.edge_cutoff_p) + + # Add missing bonds if they got removed + adj_pocket[bond_inds_pocket[0], bond_inds_pocket[1]] = True + + if self.edge_cutoff_i is not None and len(x_pocket) > 0: + adj_cross = adj_cross & (torch.cdist(x_ligand, x_pocket) <= self.edge_cutoff_i) + + adj = torch.cat((torch.cat((adj_ligand, adj_cross), dim=1), + torch.cat((adj_cross.T, adj_pocket), dim=1)), dim=0) + + if not self_edges: + adj = adj ^ torch.eye(*adj.size(), out=torch.empty_like(adj)) + + # # ensure that edge definition is consistent if bonds are provided (for loss computation) + # if bond_inds_ligand is not None: + # # remove ligand edges + # adj[:adj_ligand.size(0), :adj_ligand.size(1)] = False + # edges = torch.stack(torch.where(adj), dim=0) + # # add ligand 
edges back with original definition + # edges = torch.cat([bond_inds_ligand, edges], dim=-1) + # else: + # edges = torch.stack(torch.where(adj), dim=0) + + # Feature matrix + ligand_nobond_onehot = F.one_hot(torch.tensor( + self.bond_dict['NOBOND'], device=bond_feat_ligand.device), + num_classes=self.ligand_bond_encoder[0].in_features) + ligand_nobond_emb = self.ligand_bond_encoder( + ligand_nobond_onehot.to(FLOAT_TYPE)) + feat_ligand = ligand_nobond_emb.repeat(*adj_ligand.shape, 1) + feat_ligand[bond_inds_ligand[0], bond_inds_ligand[1]] = bond_feat_ligand + + if len(adj_pocket) > 0: + pocket_nobond_onehot = F.one_hot(torch.tensor( + self.pocket_bond_dict['NOBOND'], device=bond_feat_pocket.device), + num_classes=self.pocket_bond_nf) + pocket_nobond_emb = self.pocket_bond_encoder( + pocket_nobond_onehot.to(FLOAT_TYPE)) + feat_pocket = pocket_nobond_emb.repeat(*adj_pocket.shape, 1) + feat_pocket[bond_inds_pocket[0], bond_inds_pocket[1]] = bond_feat_pocket + + feat_cross = self.cross_emb.repeat(*adj_cross.shape, 1) + + feats = torch.cat((torch.cat((feat_ligand, feat_cross), dim=1), + torch.cat((feat_cross.transpose(0, 1), feat_pocket), dim=1)), dim=0) + else: + feats = feat_ligand + + # Return results + edges = torch.stack(torch.where(adj), dim=0) + edge_feat = feats[edges[0], edges[1]] + + return edges, edge_feat diff --git a/src/model/dynamics_hetero.py b/src/model/dynamics_hetero.py new file mode 100644 index 0000000000000000000000000000000000000000..22631e95d3d1141a56c896101b54a28dc562d95c --- /dev/null +++ b/src/model/dynamics_hetero.py @@ -0,0 +1,1008 @@ +from collections.abc import Iterable +from collections import defaultdict +from functools import partial +import functools +import warnings +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch_scatter import scatter_mean +from torch_geometric.nn import MessagePassing +from torch_geometric.nn.module_dict import ModuleDict +from torch_geometric.utils.hetero import check_add_self_loops +try: + from torch_geometric.nn.conv.hgt_conv import group +except ImportError as e: + from torch_geometric.nn.conv.hetero_conv import group + +from src.model.dynamics import DynamicsBase +from src.model import gvp +from src.model.gvp import GVP, _rbf, _normalize, tuple_index, tuple_sum, _split, tuple_cat, _merge + + +class MyModuleDict(nn.ModuleDict): + def __init__(self, modules): + # a mapping (dictionary) of (string: module) or an iterable of key-value pairs of type (string, module) + if isinstance(modules, dict): + super().__init__({str(k): v for k, v in modules.items()}) + else: + raise NotImplementedError + + def __getitem__(self, key): + return super().__getitem__(str(key)) + + def __setitem__(self, key, value): + super().__setitem__(str(key), value) + + def __delitem__(self, key): + super().__delitem__(str(key)) + + +class MyHeteroConv(nn.Module): + """ + Implementation from PyG 2.2.0 with minor changes. 
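+    Each wrapped conv additionally exposes its output vector channel count
+    (`module.vo`), collected per destination type so that merged (s, V)
+    outputs can be split back into tuples after aggregation.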
+ Override forward pass to control the final aggregation + Ref.: https://pytorch-geometric.readthedocs.io/en/2.2.0/_modules/torch_geometric/nn/conv/hetero_conv.html + """ + def __init__(self, convs, aggr="sum"): + self.vo = {} + for k, module in convs.items(): + dst = k[-1] + if dst not in self.vo: + self.vo[dst] = module.vo + else: + assert self.vo[dst] == module.vo + + # from the original implementation in PyTorch Geometric + super().__init__() + + for edge_type, module in convs.items(): + check_add_self_loops(module, [edge_type]) + + src_node_types = set([key[0] for key in convs.keys()]) + dst_node_types = set([key[-1] for key in convs.keys()]) + if len(src_node_types - dst_node_types) > 0: + warnings.warn( + f"There exist node types ({src_node_types - dst_node_types}) " + f"whose representations do not get updated during message " + f"passing as they do not occur as destination type in any " + f"edge type. This may lead to unexpected behaviour.") + + self.convs = ModuleDict({'__'.join(k): v for k, v in convs.items()}) + self.aggr = aggr + + def reset_parameters(self): + for conv in self.convs.values(): + conv.reset_parameters() + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(num_relations={len(self.convs)})' + + def forward( + self, + x_dict, + edge_index_dict, + *args_dict, + **kwargs_dict, + ): + r""" + Args: + x_dict (Dict[str, Tensor]): A dictionary holding node feature + information for each individual node type. + edge_index_dict (Dict[Tuple[str, str, str], Tensor]): A dictionary + holding graph connectivity information for each individual + edge type. + *args_dict (optional): Additional forward arguments of individual + :class:`torch_geometric.nn.conv.MessagePassing` layers. + **kwargs_dict (optional): Additional forward arguments of + individual :class:`torch_geometric.nn.conv.MessagePassing` + layers. + For example, if a specific GNN layer at edge type + :obj:`edge_type` expects edge attributes :obj:`edge_attr` as a + forward argument, then you can pass them to + :meth:`~torch_geometric.nn.conv.HeteroConv.forward` via + :obj:`edge_attr_dict = { edge_type: edge_attr }`.
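+
+        Example (hypothetical tensors; a minimal sketch of how
+        `GVPHeteroConvLayer` below calls this wrapper)::
+
+            out_nodes, out_edges = conv(x_dict, edge_index_dict, edge_attr_dict)
+            # out_nodes[dst] is an aggregated (s, V) tuple per destination type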
+ """ + out_dict = defaultdict(list) + out_dict_edge = {} + for edge_type, edge_index in edge_index_dict.items(): + src, rel, dst = edge_type + + str_edge_type = '__'.join(edge_type) + if str_edge_type not in self.convs: + continue + + args = [] + for value_dict in args_dict: + if edge_type in value_dict: + args.append(value_dict[edge_type]) + elif src == dst and src in value_dict: + args.append(value_dict[src]) + elif src in value_dict or dst in value_dict: + args.append( + (value_dict.get(src, None), value_dict.get(dst, None))) + + kwargs = {} + for arg, value_dict in kwargs_dict.items(): + arg = arg[:-5] # `{*}_dict` + if edge_type in value_dict: + kwargs[arg] = value_dict[edge_type] + elif src == dst and src in value_dict: + kwargs[arg] = value_dict[src] + elif src in value_dict or dst in value_dict: + kwargs[arg] = (value_dict.get(src, None), + value_dict.get(dst, None)) + + conv = self.convs[str_edge_type] + + if src == dst: + out = conv(x_dict[src], edge_index, *args, **kwargs) + else: + out = conv((x_dict[src], x_dict[dst]), edge_index, *args, + **kwargs) + + if isinstance(out, (tuple, list)): + out, out_edge = out + out_dict_edge[edge_type] = out_edge + + out_dict[dst].append(out) + + for key, value in out_dict.items(): + out_dict[key] = group(value, self.aggr) + out_dict[key] = _split(out_dict[key], self.vo[key]) + + return out_dict if len(out_dict_edge) <= 0 else out_dict, out_dict_edge + + +class GVPHeteroConv(MessagePassing): + ''' + Graph convolution / message passing with Geometric Vector Perceptrons. + Takes in a graph with node and edge embeddings, + and returns new node embeddings. + + This does NOT do residual updates and pointwise feedforward layers + ---see `GVPConvLayer`. + + :param in_dims: input node embedding dimensions (n_scalar, n_vector) + :param out_dims: output node embedding dimensions (n_scalar, n_vector) + :param edge_dims: input edge embedding dimensions (n_scalar, n_vector) + :param n_layers: number of GVPs in the message function + :param module_list: preconstructed message function, overrides n_layers + :param aggr: should be "add" if some incoming edges are masked, as in + a masked autoregressive decoder architecture, otherwise "mean" + :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs + :param vector_gate: whether to use vector gating. 
+ (vector_act will be used as sigma^+ in vector gating if `True`) + :param update_edge_attr: whether to compute an updated edge representation + ''' + + def __init__(self, in_dims, out_dims, edge_dims, in_dims_other=None, + n_layers=3, module_list=None, aggr="mean", + activations=(F.relu, torch.sigmoid), vector_gate=False, + update_edge_attr=False): + super(GVPHeteroConv, self).__init__(aggr=aggr) + + if in_dims_other is None: + in_dims_other = in_dims + + self.si, self.vi = in_dims + self.si_other, self.vi_other = in_dims_other + self.so, self.vo = out_dims + self.se, self.ve = edge_dims + self.update_edge_attr = update_edge_attr + + GVP_ = functools.partial(GVP, + activations=activations, + vector_gate=vector_gate) + + def get_modules(module_list, out_dims): + module_list = module_list or [] + if not module_list: + if n_layers == 1: + module_list.append( + GVP_((self.si + self.si_other + self.se, self.vi + self.vi_other + self.ve), + (self.so, self.vo), activations=(None, None))) + else: + module_list.append( + GVP_((self.si + self.si_other + self.se, self.vi + self.vi_other + self.ve), + out_dims) + ) + for i in range(n_layers - 2): + module_list.append(GVP_(out_dims, out_dims)) + module_list.append(GVP_(out_dims, out_dims, + activations=(None, None))) + return nn.Sequential(*module_list) + + self.message_func = get_modules(module_list, out_dims) + self.edge_func = get_modules(module_list, edge_dims) if self.update_edge_attr else None + + def forward(self, x, edge_index, edge_attr): + ''' + :param x: tuple (s, V) of `torch.Tensor` + :param edge_index: array of shape [2, n_edges] + :param edge_attr: tuple (s, V) of `torch.Tensor` + ''' + elem_0, elem_1 = x + if isinstance(elem_0, (tuple, list)): + assert isinstance(elem_1, (tuple, list)) + x_s = (elem_0[0], elem_1[0]) + x_v = (elem_0[1].reshape(elem_0[1].shape[0], 3 * elem_0[1].shape[1]), + elem_1[1].reshape(elem_1[1].shape[0], 3 * elem_1[1].shape[1])) + else: + x_s, x_v = elem_0, elem_1 + x_v = x_v.reshape(x_v.shape[0], 3 * x_v.shape[1]) + + message = self.propagate(edge_index, s=x_s, v=x_v, edge_attr=edge_attr) + + if self.update_edge_attr: + if isinstance(x_s, (tuple, list)): + s_i, s_j = x_s[1][edge_index[1]], x_s[0][edge_index[0]] + else: + s_i, s_j = x_s[edge_index[1]], x_s[edge_index[0]] + + if isinstance(x_v, (tuple, list)): + v_i, v_j = x_v[1][edge_index[1]], x_v[0][edge_index[0]] + else: + v_i, v_j = x_v[edge_index[1]], x_v[edge_index[0]] + + edge_out = self.edge_attr(s_i, v_i, s_j, v_j, edge_attr) + # return _split(message, self.vo), edge_out + return message, edge_out + else: + # return _split(message, self.vo) + return message + + def message(self, s_i, v_i, s_j, v_j, edge_attr): + v_j = v_j.view(v_j.shape[0], v_j.shape[1] // 3, 3) + v_i = v_i.view(v_i.shape[0], v_i.shape[1] // 3, 3) + message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i)) + message = self.message_func(message) + return _merge(*message) + + def edge_attr(self, s_i, v_i, s_j, v_j, edge_attr): + v_j = v_j.view(v_j.shape[0], v_j.shape[1] // 3, 3) + v_i = v_i.view(v_i.shape[0], v_i.shape[1] // 3, 3) + message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i)) + return self.edge_func(message) + + +class GVPHeteroConvLayer(nn.Module): + """ + Full graph convolution / message passing layer with + Geometric Vector Perceptrons. Residually updates node embeddings with + aggregated incoming messages, applies a pointwise feedforward + network to node embeddings, and returns updated node embeddings. + + To only compute the aggregated messages, see `GVPConv`. 
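+    Per node type, the update follows the usual GVP-GNN recipe:
+    x <- LayerNorm(x + Dropout(conv(x))), followed by a pointwise feedforward
+    GVP block with a second residual connection; edge attributes are updated
+    analogously when `update_edge_attr` is set.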
+ + :param conv_dims: dictionary defining (src_dim, dst_dim, edge_dim) for each edge type + """ + def __init__(self, conv_dims, + n_message=3, n_feedforward=2, drop_rate=.1, + activations=(F.relu, torch.sigmoid), vector_gate=False, + update_edge_attr=False, ln_vector_weight=False): + + super(GVPHeteroConvLayer, self).__init__() + self.update_edge_attr = update_edge_attr + + gvp_conv = partial(GVPHeteroConv, + n_layers=n_message, + aggr="sum", + activations=activations, + vector_gate=vector_gate, + update_edge_attr=update_edge_attr) + + def get_feedforward(n_dims): + GVP_ = partial(GVP, activations=activations, vector_gate=vector_gate) + + ff_func = [] + if n_feedforward == 1: + ff_func.append(GVP_(n_dims, n_dims, activations=(None, None))) + else: + hid_dims = 4 * n_dims[0], 2 * n_dims[1] + ff_func.append(GVP_(n_dims, hid_dims)) + for i in range(n_feedforward - 2): + ff_func.append(GVP_(hid_dims, hid_dims)) + ff_func.append(GVP_(hid_dims, n_dims, activations=(None, None))) + return nn.Sequential(*ff_func) + + # self.conv = HeteroConv({k: gvp_conv(*dims) for k, dims in conv_dims.items()}, aggr='sum') + self.conv = MyHeteroConv({k: gvp_conv(*dims) for k, dims in conv_dims.items()}, aggr='sum') + + node_dims = {k[-1]: dims[1] for k, dims in conv_dims.items()} + self.norm0 = MyModuleDict({k: gvp.LayerNorm(dims, ln_vector_weight) for k, dims in node_dims.items()}) + self.dropout0 = MyModuleDict({k: gvp.Dropout(drop_rate) for k, dims in node_dims.items()}) + self.ff_func = MyModuleDict({k: get_feedforward(dims) for k, dims in node_dims.items()}) + self.norm1 = MyModuleDict({k: gvp.LayerNorm(dims, ln_vector_weight) for k, dims in node_dims.items()}) + self.dropout1 = MyModuleDict({k: gvp.Dropout(drop_rate) for k, dims in node_dims.items()}) + + if self.update_edge_attr: + self.edge_norm0 = MyModuleDict({k: gvp.LayerNorm(dims[2], ln_vector_weight) for k, dims in conv_dims.items()}) + self.edge_dropout0 = MyModuleDict({k: gvp.Dropout(drop_rate) for k, dims in conv_dims.items()}) + self.edge_ff = MyModuleDict({k: get_feedforward(dims[2]) for k, dims in conv_dims.items()}) + self.edge_norm1 = MyModuleDict({k: gvp.LayerNorm(dims[2], ln_vector_weight) for k, dims in conv_dims.items()}) + self.edge_dropout1 = MyModuleDict({k: gvp.Dropout(drop_rate) for k, dims in conv_dims.items()}) + + def forward(self, x_dict, edge_index_dict, edge_attr_dict, node_mask_dict=None): + ''' + :param x: tuple (s, V) of `torch.Tensor` + :param edge_index: array of shape [2, n_edges] + :param edge_attr: tuple (s, V) of `torch.Tensor` + :param node_mask: array of type `bool` to index into the first + dim of node embeddings (s, V). If not `None`, only + these nodes will be updated. 
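+        :return: `(x_dict, edge_attr_dict)` with updated features if
+            `update_edge_attr` is set, otherwise only the updated `x_dict`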
+ ''' + + dh_dict = self.conv(x_dict, edge_index_dict, edge_attr_dict) + + if self.update_edge_attr: + dh_dict, de_dict = dh_dict + + for k, edge_attr in edge_attr_dict.items(): + de = de_dict[k] + + edge_attr = self.edge_norm0[k](tuple_sum(edge_attr, self.edge_dropout0[k](de))) + de = self.edge_ff[k](edge_attr) + edge_attr = self.edge_norm1[k](tuple_sum(edge_attr, self.edge_dropout1[k](de))) + + edge_attr_dict[k] = edge_attr + + for k, x in x_dict.items(): + dh = dh_dict[k] + node_mask = None if node_mask_dict is None else node_mask_dict[k] + + if node_mask is not None: + x_ = x + x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask) + + x = self.norm0[k](tuple_sum(x, self.dropout0[k](dh))) + + dh = self.ff_func[k](x) + x = self.norm1[k](tuple_sum(x, self.dropout1[k](dh))) + + if node_mask is not None: + x_[0][node_mask], x_[1][node_mask] = x[0], x[1] + x = x_ + + x_dict[k] = x + + return (x_dict, edge_attr_dict) if self.update_edge_attr else x_dict + + +class GVPModel(torch.nn.Module): + """ + GVP-GNN model + inspired by: https://github.com/drorlab/gvp-pytorch/blob/main/gvp/models.py + and: https://github.com/drorlab/gvp-pytorch/blob/82af6b22eaf8311c15733117b0071408d24ed877/gvp/atom3d.py#L115 + """ + def __init__(self, + node_in_dim_ligand, node_in_dim_pocket, + edge_in_dim_ligand, edge_in_dim_pocket, edge_in_dim_interaction, + node_h_dim_ligand, node_h_dim_pocket, + edge_h_dim_ligand, edge_h_dim_pocket, edge_h_dim_interaction, + node_out_dim_ligand=None, node_out_dim_pocket=None, + edge_out_dim_ligand=None, edge_out_dim_pocket=None, edge_out_dim_interaction=None, + num_layers=3, drop_rate=0.1, vector_gate=False, update_edge_attr=False): + + super(GVPModel, self).__init__() + + self.update_edge_attr = update_edge_attr + + self.node_in = nn.ModuleDict({ + 'ligand': GVP(node_in_dim_ligand, node_h_dim_ligand, activations=(None, None), vector_gate=vector_gate), + 'pocket': GVP(node_in_dim_pocket, node_h_dim_pocket, activations=(None, None), vector_gate=vector_gate), + }) + # self.edge_in = MyModuleDict({ + # ('ligand', 'ligand'): GVP(edge_in_dim_ligand, edge_h_dim_ligand, activations=(None, None), vector_gate=vector_gate), + # ('pocket', 'pocket'): GVP(edge_in_dim_pocket, edge_h_dim_pocket, activations=(None, None), vector_gate=vector_gate), + # ('ligand', 'pocket'): GVP(edge_in_dim_interaction, edge_h_dim_interaction, activations=(None, None), vector_gate=vector_gate), + # ('pocket', 'ligand'): GVP(edge_in_dim_interaction, edge_h_dim_interaction, activations=(None, None), vector_gate=vector_gate), + # }) + self.edge_in = MyModuleDict({ + ('ligand', '', 'ligand'): GVP(edge_in_dim_ligand, edge_h_dim_ligand, activations=(None, None), vector_gate=vector_gate), + ('pocket', '', 'pocket'): GVP(edge_in_dim_pocket, edge_h_dim_pocket, activations=(None, None), vector_gate=vector_gate), + ('ligand', '', 'pocket'): GVP(edge_in_dim_interaction, edge_h_dim_interaction, activations=(None, None), vector_gate=vector_gate), + ('pocket', '', 'ligand'): GVP(edge_in_dim_interaction, edge_h_dim_interaction, activations=(None, None), vector_gate=vector_gate), + }) + + # conv_dims = { + # ('ligand', 'ligand'): (node_h_dim_ligand, node_h_dim_ligand, edge_h_dim_ligand), + # ('pocket', 'pocket'): (node_h_dim_pocket, node_h_dim_pocket, edge_h_dim_pocket), + # ('ligand', 'pocket'): (node_h_dim_ligand, node_h_dim_pocket, edge_h_dim_interaction), + # ('pocket', 'ligand'): (node_h_dim_pocket, node_h_dim_ligand, edge_h_dim_interaction), + # } + conv_dims = { + ('ligand', '', 'ligand'): (node_h_dim_ligand, 
node_h_dim_ligand, edge_h_dim_ligand), + ('pocket', '', 'pocket'): (node_h_dim_pocket, node_h_dim_pocket, edge_h_dim_pocket), + ('ligand', '', 'pocket'): (node_h_dim_ligand, node_h_dim_pocket, edge_h_dim_interaction, node_h_dim_pocket), + ('pocket', '', 'ligand'): (node_h_dim_pocket, node_h_dim_ligand, edge_h_dim_interaction, node_h_dim_ligand), + } + + self.layers = nn.ModuleList( + GVPHeteroConvLayer(conv_dims, + drop_rate=drop_rate, + update_edge_attr=self.update_edge_attr, + activations=(F.relu, None), + vector_gate=vector_gate, + ln_vector_weight=True) + for _ in range(num_layers)) + + self.node_out = nn.ModuleDict({ + 'ligand': GVP(node_h_dim_ligand, node_out_dim_ligand, activations=(None, None), vector_gate=vector_gate), + 'pocket': GVP(node_h_dim_pocket, node_out_dim_pocket, activations=(None, None), vector_gate=vector_gate) if node_out_dim_pocket is not None else None, + }) + # self.edge_out = MyModuleDict({ + # ('ligand', 'ligand'): GVP(edge_h_dim_ligand, edge_out_dim_ligand, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_ligand is not None else None, + # ('pocket', 'pocket'): GVP(edge_h_dim_pocket, edge_out_dim_pocket, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_pocket is not None else None, + # ('ligand', 'pocket'): GVP(edge_h_dim_interaction, edge_out_dim_interaction, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_interaction is not None else None, + # ('pocket', 'ligand'): GVP(edge_h_dim_interaction, edge_out_dim_interaction, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_interaction is not None else None, + # }) + self.edge_out = MyModuleDict({ + ('ligand', '', 'ligand'): GVP(edge_h_dim_ligand, edge_out_dim_ligand, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_ligand is not None else None, + ('pocket', '', 'pocket'): GVP(edge_h_dim_pocket, edge_out_dim_pocket, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_pocket is not None else None, + ('ligand', '', 'pocket'): GVP(edge_h_dim_interaction, edge_out_dim_interaction, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_interaction is not None else None, + ('pocket', '', 'ligand'): GVP(edge_h_dim_interaction, edge_out_dim_interaction, activations=(None, None), vector_gate=vector_gate) if edge_out_dim_interaction is not None else None, + }) + + def forward(self, node_attr, batch_mask, edge_index, edge_attr): + + # to hidden dimension + for k in node_attr.keys(): + node_attr[k] = self.node_in[k](node_attr[k]) + + for k in edge_attr.keys(): + edge_attr[k] = self.edge_in[k](edge_attr[k]) + + # convolutions + for layer in self.layers: + out = layer(node_attr, edge_index, edge_attr) + if self.update_edge_attr: + node_attr, edge_attr = out + else: + node_attr = out + + # to output dimension + for k in node_attr.keys(): + node_attr[k] = self.node_out[k](node_attr[k]) \ + if self.node_out[k] is not None else None + + if self.update_edge_attr: + for k in edge_attr.keys(): + if self.edge_out[k] is not None: + edge_attr[k] = self.edge_out[k](edge_attr[k]) + + return node_attr, edge_attr + + +class DynamicsHetero(DynamicsBase): + def __init__(self, atom_nf, residue_nf, bond_dict, pocket_bond_dict, + condition_time=True, + num_rbf_time=None, + model='gvp', + model_params=None, + edge_cutoff_ligand=None, + edge_cutoff_pocket=None, + edge_cutoff_interaction=None, + predict_angles=False, + predict_frames=False, + add_cycle_counts=False, + add_spectral_feat=False, + add_nma_feat=False, + 
reflection_equiv=False, + d_max=15.0, + num_rbf_dist=16, + self_conditioning=False, + augment_residue_sc=False, + augment_ligand_sc=False, + add_chi_as_feature=False, + angle_act_fn=False, + add_all_atom_diff=False, + predict_confidence=False): + + super().__init__( + predict_angles=predict_angles, + predict_frames=predict_frames, + add_cycle_counts=add_cycle_counts, + add_spectral_feat=add_spectral_feat, + self_conditioning=self_conditioning, + augment_residue_sc=augment_residue_sc, + augment_ligand_sc=augment_ligand_sc + ) + + self.model = model + self.edge_cutoff_l = edge_cutoff_ligand + self.edge_cutoff_p = edge_cutoff_pocket + self.edge_cutoff_i = edge_cutoff_interaction + self.bond_dict = bond_dict + self.pocket_bond_dict = pocket_bond_dict + self.bond_nf = len(bond_dict) + self.pocket_bond_nf = len(pocket_bond_dict) + # self.edge_dim = edge_dim + self.add_nma_feat = add_nma_feat + self.add_chi_as_feature = add_chi_as_feature + self.add_all_atom_diff = add_all_atom_diff + self.condition_time = condition_time + self.predict_confidence = predict_confidence + + # edge encoding params + self.reflection_equiv = reflection_equiv + self.d_max = d_max + self.num_rbf = num_rbf_dist + + + # Output dimensions, always tuple (scalar, vector) + _atom_out = (atom_nf[0], 1) if isinstance(atom_nf, Iterable) else (atom_nf, 1) + _residue_out = (0, 0) + + if self.predict_confidence: + _atom_out = tuple_sum(_atom_out, (1, 0)) + + if self.predict_angles: + _residue_out = tuple_sum(_residue_out, (5, 0)) + + if self.predict_frames: + _residue_out = tuple_sum(_residue_out, (3, 1)) + + + # Input dimensions, always tuple (scalar, vector) + assert isinstance(atom_nf, int), "expected: element onehot" + _atom_in = (atom_nf, 0) + assert isinstance(residue_nf, Iterable), "expected: (AA-onehot, vectors to atoms)" + _residue_in = tuple(residue_nf) + _residue_atom_dim = residue_nf[1] + + if self.add_cycle_counts: + _atom_in = tuple_sum(_atom_in, (3, 0)) + if self.add_spectral_feat: + _atom_in = tuple_sum(_atom_in, (5, 0)) + + if self.add_nma_feat: + _residue_in = tuple_sum(_residue_in, (0, 5)) + + if self.add_chi_as_feature: + _residue_in = tuple_sum(_residue_in, (5, 0)) + + if self.condition_time: + self.embed_time = num_rbf_time is not None + self.time_dim = num_rbf_time if self.embed_time else 1 + + _atom_in = tuple_sum(_atom_in, (self.time_dim, 0)) + _residue_in = tuple_sum(_residue_in, (self.time_dim, 0)) + else: + print('Warning: dynamics model is NOT conditioned on time.') + + if self.self_conditioning: + _atom_in = tuple_sum(_atom_in, _atom_out) + _residue_in = tuple_sum(_residue_in, _residue_out) + + if self.augment_ligand_sc: + _atom_in = tuple_sum(_atom_in, (0, 1)) + + if self.augment_residue_sc: + assert self.predict_angles + _residue_in = tuple_sum(_residue_in, (0, _residue_atom_dim)) + + + # Edge output dimensions, always tuple (scalar, vector) + _edge_ligand_out = (self.bond_nf, 0) + _edge_ligand_before_symmetrization = (model_params.edge_h_dim[0], 0) + + + # Edge input dimensions, always tuple (scalar, vector) + _edge_ligand_in = (self.bond_nf + self.num_rbf, 1 if self.reflection_equiv else 2) + _edge_ligand_in = tuple_sum(_edge_ligand_in, _atom_in) # src node + _edge_ligand_in = tuple_sum(_edge_ligand_in, _atom_in) # dst node + + if self_conditioning: + _edge_ligand_in = tuple_sum(_edge_ligand_in, _edge_ligand_out) + + _n_dist_residue = _residue_atom_dim ** 2 if self.add_all_atom_diff else 1 + _edge_pocket_in = (_n_dist_residue * self.num_rbf + self.pocket_bond_nf,
_n_dist_residue) + _edge_pocket_in = tuple_sum(_edge_pocket_in, _residue_in) # src node + _edge_pocket_in = tuple_sum(_edge_pocket_in, _residue_in) # dst node + + _n_dist_interaction = _residue_atom_dim if self.add_all_atom_diff else 1 + _edge_interaction_in = (_n_dist_interaction * self.num_rbf, _n_dist_interaction) + _edge_interaction_in = tuple_sum(_edge_interaction_in, _atom_in) # atom node + _edge_interaction_in = tuple_sum(_edge_interaction_in, _residue_in) # residue node + + + # Embeddings for newly added edges + _ligand_nobond_nf = self.bond_nf + _edge_ligand_out[0] if self.self_conditioning else self.bond_nf + self.ligand_nobond_emb = nn.Parameter(torch.zeros(_ligand_nobond_nf), requires_grad=True) + self.pocket_nobond_emb = nn.Parameter(torch.zeros(self.pocket_bond_nf), requires_grad=True) + + # for access in self-conditioning + self.atom_out_dim = _atom_out + self.residue_out_dim = _residue_out + self.edge_out_dim = _edge_ligand_out + + if model == 'gvp': + + self.net = GVPModel( + node_in_dim_ligand=_atom_in, + node_in_dim_pocket=_residue_in, + edge_in_dim_ligand=_edge_ligand_in, + edge_in_dim_pocket=_edge_pocket_in, + edge_in_dim_interaction=_edge_interaction_in, + node_h_dim_ligand=model_params.node_h_dim, + node_h_dim_pocket=model_params.node_h_dim, + edge_h_dim_ligand=model_params.edge_h_dim, + edge_h_dim_pocket=model_params.edge_h_dim, + edge_h_dim_interaction=model_params.edge_h_dim, + node_out_dim_ligand=_atom_out, + node_out_dim_pocket=_residue_out, + edge_out_dim_ligand=_edge_ligand_before_symmetrization, + edge_out_dim_pocket=None, + edge_out_dim_interaction=None, + num_layers=model_params.n_layers, + drop_rate=model_params.dropout, + vector_gate=model_params.vector_gate, + update_edge_attr=True + ) + + else: + raise NotImplementedError(f"{model} is not available") + + assert _edge_ligand_out[1] == 0 + assert _edge_ligand_before_symmetrization[1] == 0 + self.edge_decoder = nn.Sequential( + nn.Linear(_edge_ligand_before_symmetrization[0], _edge_ligand_before_symmetrization[0]), + torch.nn.SiLU(), + nn.Linear(_edge_ligand_before_symmetrization[0], _edge_ligand_out[0]) + ) + + if angle_act_fn is None: + self.angle_act_fn = None + elif angle_act_fn == 'tanh': + self.angle_act_fn = lambda x: np.pi * F.tanh(x) + else: + raise NotImplementedError(f"Angle activation {angle_act_fn} not available") + + def _forward(self, x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand=None, + h_atoms_sc=None, e_atoms_sc=None, h_residues_sc=None): + """ + :param x_atoms: + :param h_atoms: + :param mask_atoms: + :param pocket: must contain keys: 'x', 'one_hot', 'mask', 'bonds' and 'bond_one_hot' + :param t: + :param bonds_ligand: tuple - bond indices (2, n_bonds) & bond types (n_bonds, bond_nf) + :param h_atoms_sc: additional node feature for self-conditioning, (s, V) + :param e_atoms_sc: additional edge feature for self-conditioning, only scalar + :param h_residues_sc: additional node feature for self-conditioning, tensor or tuple + :return: + """ + x_residues, h_residues, mask_residues = pocket['x'], pocket['one_hot'], pocket['mask'] + if 'bonds' in pocket: + bonds_pocket = (pocket['bonds'], pocket['bond_one_hot']) + else: + bonds_pocket = None + + if self.add_chi_as_feature: + h_residues = torch.cat([h_residues, pocket['chi'][:, :5]], dim=-1) + + if 'v' in pocket: + v_residues = pocket['v'] + if self.add_nma_feat: + v_residues = torch.cat([v_residues, pocket['nma_vec']], dim=1) + h_residues = (h_residues, v_residues) + + # NOTE: 'bond' denotes one-directional edges and 'edge' means 
bi-directional + # get graph edges and edge attributes + if bonds_ligand is not None: + + ligand_bond_indices = bonds_ligand[0] + + # make sure messages are passed both ways + ligand_edge_indices = torch.cat( + [bonds_ligand[0], bonds_ligand[0].flip(dims=[0])], dim=1) + ligand_edge_types = torch.cat([bonds_ligand[1], bonds_ligand[1]], dim=0) + if e_atoms_sc is not None: + e_atoms_sc = torch.cat([e_atoms_sc, e_atoms_sc], dim=0) + + # add auxiliary features to ligand nodes + extra_features = self.compute_extra_features( + mask_atoms, ligand_edge_indices, ligand_edge_types.argmax(-1)) + h_atoms = torch.cat([h_atoms, extra_features], dim=-1) + + if bonds_pocket is not None: + # make sure messages are passed both ways + pocket_edge_indices = torch.cat( + [bonds_pocket[0], bonds_pocket[0].flip(dims=[0])], dim=1) + pocket_edge_types = torch.cat([bonds_pocket[1], bonds_pocket[1]], dim=0) + + + # Self-conditioning + if h_atoms_sc is not None: + h_atoms = (torch.cat([h_atoms, h_atoms_sc[0]], dim=-1), h_atoms_sc[1]) + + if e_atoms_sc is not None: + ligand_edge_types = torch.cat([ligand_edge_types, e_atoms_sc], dim=-1) + + if h_residues_sc is not None: + # if self.augment_residue_sc: + if isinstance(h_residues_sc, tuple): + h_residues = (torch.cat([h_residues[0], h_residues_sc[0]], dim=-1), + torch.cat([h_residues[1], h_residues_sc[1]], dim=1)) + else: + h_residues = (torch.cat([h_residues[0], h_residues_sc], dim=-1), + h_residues[1]) + + if self.condition_time: + if self.embed_time: + t = _rbf(t.squeeze(-1), D_min=0.0, D_max=1.0, D_count=self.time_dim, device=t.device) + if isinstance(h_atoms, tuple) : + h_atoms = (torch.cat([h_atoms[0], t[mask_atoms]], dim=1), h_atoms[1]) + else: + h_atoms = torch.cat([h_atoms, t[mask_atoms]], dim=1) + h_residues = (torch.cat([h_residues[0], t[mask_residues]], dim=1), h_residues[1]) + + empty_pocket = (len(pocket['x']) == 0) + + # Process edges and encode in shared feature space + edge_index_dict, edge_attr_dict = self.get_edges( + x_atoms, h_atoms, mask_atoms, ligand_edge_indices, ligand_edge_types, + x_residues, h_residues, mask_residues, pocket['v'], pocket_edge_indices, pocket_edge_types, + empty_pocket=empty_pocket + ) + + if not empty_pocket: + node_attr_dict = { + 'ligand': h_atoms, + 'pocket': h_residues, + } + batch_mask_dict = { + 'ligand': mask_atoms, + 'pocket': mask_residues, + } + else: + node_attr_dict = {'ligand': h_atoms} + batch_mask_dict = {'ligand': mask_atoms} + + if self.model == 'gvp' or self.model == 'gvp_transformer': + out_node_attr, out_edge_attr = self.net( + node_attr_dict, batch_mask_dict, edge_index_dict, edge_attr_dict) + + else: + raise NotImplementedError(f"Wrong model ({self.model})") + + h_final_atoms = out_node_attr['ligand'][0] + vel = out_node_attr['ligand'][1].squeeze(-2) + + if torch.any(torch.isnan(vel)) or torch.any(torch.isnan(h_final_atoms)): + if self.training: + vel[torch.isnan(vel)] = 0.0 + h_final_atoms[torch.isnan(h_final_atoms)] = 0.0 + else: + raise ValueError("NaN detected in network output") + + # predict edge type + edge_final = out_edge_attr[('ligand', '', 'ligand')] + edges = edge_index_dict[('ligand', '', 'ligand')] + + # Symmetrize + edge_logits = torch.zeros( + (len(mask_atoms), len(mask_atoms), edge_final.size(-1)), + device=mask_atoms.device) + edge_logits[edges[0], edges[1]] = edge_final + edge_logits = (edge_logits + edge_logits.transpose(0, 1)) * 0.5 + + # return upper triangular elements only (matching the input) + edge_logits = edge_logits[ligand_bond_indices[0], ligand_bond_indices[1]] + # assert 
(edge_logits == 0).sum() == 0 + + edge_final_atoms = self.edge_decoder(edge_logits) + + pred_ligand = {'vel': vel, 'logits_e': edge_final_atoms} + + if self.predict_confidence: + pred_ligand['logits_h'] = h_final_atoms[:, :-1] + pred_ligand['uncertainty_vel'] = F.softplus(h_final_atoms[:, -1]) + else: + pred_ligand['logits_h'] = h_final_atoms + + pred_residues = {} + + # Predict torsion angles + if self.predict_angles and self.predict_frames: + residue_s, residue_v = out_node_attr['pocket'] + pred_residues['chi'] = residue_s[:, :5] + pred_residues['rot'] = residue_s[:, 5:] + pred_residues['trans'] = residue_v.squeeze(1) + + elif self.predict_frames: + pred_residues['rot'], pred_residues['trans'] = out_node_attr['pocket'] + pred_residues['trans'] = pred_residues['trans'].squeeze(1) + + elif self.predict_angles: + pred_residues['chi'] = out_node_attr['pocket'] + + if self.angle_act_fn is not None and 'chi' in pred_residues: + pred_residues['chi'] = self.angle_act_fn(pred_residues['chi']) + + return pred_ligand, pred_residues + + def get_edges(self, x_ligand, h_ligand, batch_mask_ligand, edges_ligand, edge_feat_ligand, + x_pocket, h_pocket, batch_mask_pocket, atom_vectors_pocket, edges_pocket, edge_feat_pocket, + self_edges=False, empty_pocket=False): + + # Adjacency matrix + adj_ligand = batch_mask_ligand[:, None] == batch_mask_ligand[None, :] + adj_pocket = batch_mask_pocket[:, None] == batch_mask_pocket[None, :] + adj_cross = batch_mask_ligand[:, None] == batch_mask_pocket[None, :] + + if self.edge_cutoff_l is not None: + adj_ligand = adj_ligand & (torch.cdist(x_ligand, x_ligand) <= self.edge_cutoff_l) + + # Add missing bonds if they got removed + adj_ligand[edges_ligand[0], edges_ligand[1]] = True + + if not self_edges: + adj_ligand = adj_ligand ^ torch.eye(*adj_ligand.size(), out=torch.empty_like(adj_ligand)) + + if self.edge_cutoff_p is not None and not empty_pocket: + adj_pocket = adj_pocket & (torch.cdist(x_pocket, x_pocket) <= self.edge_cutoff_p) + + # Add missing bonds if they got removed + adj_pocket[edges_pocket[0], edges_pocket[1]] = True + + if not self_edges: + adj_pocket = adj_pocket ^ torch.eye(*adj_pocket.size(), out=torch.empty_like(adj_pocket)) + + if self.edge_cutoff_i is not None and not empty_pocket: + adj_cross = adj_cross & (torch.cdist(x_ligand, x_pocket) <= self.edge_cutoff_i) + + # ligand-ligand edge features + edges_ligand_updated = torch.stack(torch.where(adj_ligand), dim=0) + feat_ligand = self.ligand_nobond_emb.repeat(*adj_ligand.shape, 1) + feat_ligand[edges_ligand[0], edges_ligand[1]] = edge_feat_ligand + feat_ligand = feat_ligand[edges_ligand_updated[0], edges_ligand_updated[1]] + feat_ligand = self.ligand_edge_features(h_ligand, x_ligand, edges_ligand_updated, batch_mask_ligand, edge_attr=feat_ligand) + + if not empty_pocket: + # residue-residue edge features + edges_pocket_updated = torch.stack(torch.where(adj_pocket), dim=0) + feat_pocket = self.pocket_nobond_emb.repeat(*adj_pocket.shape, 1) + feat_pocket[edges_pocket[0], edges_pocket[1]] = edge_feat_pocket + feat_pocket = feat_pocket[edges_pocket_updated[0], edges_pocket_updated[1]] + feat_pocket = self.pocket_edge_features(h_pocket, x_pocket, atom_vectors_pocket, edges_pocket_updated, edge_attr=feat_pocket) + + # ligand-residue edge features + edges_cross = torch.stack(torch.where(adj_cross), dim=0) + feat_cross = self.cross_edge_features(h_ligand, x_ligand, h_pocket, x_pocket, atom_vectors_pocket, edges_cross) + + edge_index = { + ('ligand', '', 'ligand'): edges_ligand_updated, + ('pocket', '', 
'pocket'): edges_pocket_updated, + ('ligand', '', 'pocket'): edges_cross, + ('pocket', '', 'ligand'): edges_cross.flip(dims=[0]), + } + + edge_attr = { + ('ligand', '', 'ligand'): feat_ligand, + ('pocket', '', 'pocket'): feat_pocket, + ('ligand', '', 'pocket'): feat_cross, + ('pocket', '', 'ligand'): feat_cross, + } + else: + edge_index = {('ligand', '', 'ligand'): edges_ligand_updated} + edge_attr = {('ligand', '', 'ligand'): feat_ligand} + + return edge_index, edge_attr + + def ligand_edge_features(self, h, x, edge_index, batch_mask=None, edge_attr=None): + """ + :param h: (s, V) + :param x: + :param edge_index: + :param batch_mask: + :param edge_attr: + :return: scalar and vector-valued edge features + """ + row, col = edge_index + coord_diff = x[row] - x[col] + dist = coord_diff.norm(dim=-1) + rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf, + device=x.device) + + if isinstance(h, tuple): + edge_s = torch.cat([h[0][row], h[0][col], rbf], dim=1) + edge_v = torch.cat([h[1][row], h[1][col], _normalize(coord_diff).unsqueeze(-2)], dim=1) + else: + edge_s = torch.cat([h[row], h[col], rbf], dim=1) + edge_v = _normalize(coord_diff).unsqueeze(-2) + + # edge_s = rbf + # edge_v = _normalize(coord_diff).unsqueeze(-2) + + if edge_attr is not None: + edge_s = torch.cat([edge_s, edge_attr], dim=1) + + # self.reflection_equiv: bool, use reflection-sensitive feature based on + # the cross product if False + if not self.reflection_equiv: + mean = scatter_mean(x, batch_mask, dim=0, + dim_size=batch_mask.max() + 1) + row, col = edge_index + cross = torch.cross(x[row] - mean[batch_mask[row]], + x[col] - mean[batch_mask[col]], dim=1) + cross = _normalize(cross).unsqueeze(-2) + + edge_v = torch.cat([edge_v, cross], dim=-2) + + return torch.nan_to_num(edge_s), torch.nan_to_num(edge_v) + + def pocket_edge_features(self, h, x, v, edge_index, edge_attr=None): + """ + :param h: (s, V) + :param x: + :param v: + :param edge_index: + :param edge_attr: + :return: scalar and vector-valued edge features + """ + row, col = edge_index + + if self.add_all_atom_diff: + all_coord = v + x.unsqueeze(1) # (nR, nA, 3) + coord_diff = all_coord[row, :, None, :] - all_coord[col, None, :, :] # (nB, nA, nA, 3) + coord_diff = coord_diff.flatten(1, 2) + dist = coord_diff.norm(dim=-1) # (nB, nA^2) + rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf, device=x.device) # (nB, nA^2, rbf_dim) + rbf = rbf.flatten(1, 2) + coord_diff = _normalize(coord_diff) + else: + coord_diff = x[row] - x[col] + dist = coord_diff.norm(dim=-1) + rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf, device=x.device) + coord_diff = _normalize(coord_diff).unsqueeze(-2) + + edge_s = torch.cat([h[0][row], h[0][col], rbf], dim=1) + edge_v = torch.cat([h[1][row], h[1][col], coord_diff], dim=1) + # edge_s = rbf + # edge_v = coord_diff + + if edge_attr is not None: + edge_s = torch.cat([edge_s, edge_attr], dim=1) + + return torch.nan_to_num(edge_s), torch.nan_to_num(edge_v) + + def cross_edge_features(self, h_ligand, x_ligand, h_pocket, x_pocket, v_pocket, edge_index): + """ + :param h_ligand: (s, V) + :param x_ligand: + :param h_pocket: (s, V) + :param x_pocket: + :param v_pocket: + :param edge_index: first row indexes into the ligand tensors, second row into the pocket tensors + + :return: scalar and vector-valued edge features + """ + ligand_idx, pocket_idx = edge_index + + if self.add_all_atom_diff: + all_coord_pocket = v_pocket + x_pocket.unsqueeze(1) # (nR, nA, 3) + coord_diff = x_ligand[ligand_idx, None, :] - all_coord_pocket[pocket_idx] #
(nB, nA, 3) + dist = coord_diff.norm(dim=-1) # (nB, nA) + rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf, device=x_ligand.device) # (nB, nA, rbf_dim) + rbf = rbf.flatten(1, 2) + coord_diff = _normalize(coord_diff) + else: + coord_diff = x_ligand[ligand_idx] - x_pocket[pocket_idx] + dist = coord_diff.norm(dim=-1) # (nB, nA) + rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf, device=x_ligand.device) + coord_diff = _normalize(coord_diff).unsqueeze(-2) + + if isinstance(h_ligand, tuple): + edge_s = torch.cat([h_ligand[0][ligand_idx], h_pocket[0][pocket_idx], rbf], dim=1) + edge_v = torch.cat([h_ligand[1][ligand_idx], h_pocket[1][pocket_idx], coord_diff], dim=1) + else: + edge_s = torch.cat([h_ligand[ligand_idx], h_pocket[0][pocket_idx], rbf], dim=1) + edge_v = torch.cat([h_pocket[1][pocket_idx], coord_diff], dim=1) + + # edge_s = rbf + # edge_v = coord_diff + + return torch.nan_to_num(edge_s), torch.nan_to_num(edge_v) diff --git a/src/model/flows.py b/src/model/flows.py new file mode 100644 index 0000000000000000000000000000000000000000..5ba293d07b8c8ecdb296c21fdbf08ecb06b59ec8 --- /dev/null +++ b/src/model/flows.py @@ -0,0 +1,448 @@ +from abc import ABC +from abc import abstractmethod +import math +import torch +from torch_scatter import scatter_mean, scatter_add + +import src.data.so3_utils as so3 + + +class ICFM(ABC): + """ + Abstract base class for all Independent-coupling CFM classes. + Defines a common interface. + Notation: + - zt is the intermediate representation at time step t \in [0, 1] + - zs is the noised representation at time step s < t + + # TODO: add interpolation schedule (not necessarily linear) + """ + def __init__(self, sigma): + self.sigma = sigma + + @abstractmethod + def sample_zt(self, z0, z1, t, *args, **kwargs): + """ TODO. """ + pass + + @abstractmethod + def sample_zt_given_zs(self, *args, **kwargs): + """ Perform update, typically using an explicit Euler step. """ + pass + + @abstractmethod + def sample_z0(self, *args, **kwargs): + """ Prior. """ + pass + + @abstractmethod + def compute_loss(self, pred, z0, z1, *args, **kwargs): + """ Compute loss per sample. """ + pass + + +class CoordICFM(ICFM): + def __init__(self, sigma): + self.dim = 3 + self.scale = 2.7 + super().__init__(sigma) + + def sample_zt(self, z0, z1, t, batch_mask): + zt = t[batch_mask] * z1 + (1 - t)[batch_mask] * z0 + # zt = self.sigma * z0 + t[batch_mask] * z1 + (1 - t)[batch_mask] * z0 # TODO: do we have to compute Psi? + return zt + + def sample_zt_given_zs(self, zs, pred, s, t, batch_mask): + """ Perform an explicit Euler step. """ + step_size = t - s + zt = zs + step_size[batch_mask] * self.scale * pred + return zt + + def sample_z0(self, com, batch_mask): + """ Prior. """ + z0 = torch.randn((len(batch_mask), self.dim), device=batch_mask.device) + + # Move center of mass + z0 = z0 + com[batch_mask] + + return z0 + + def reduce_loss(self, loss, batch_mask, reduce): + assert reduce in {'mean', 'sum', 'none'} + + if reduce == 'mean': + loss = scatter_mean(loss / self.dim, batch_mask, dim=0) + elif reduce == 'sum': + loss = scatter_add(loss, batch_mask, dim=0) + + return loss + + def compute_loss(self, pred, z0, z1, t, batch_mask, reduce='mean'): + """ Compute loss per sample. """ + + loss = torch.sum((pred - (z1 - z0) / self.scale) ** 2, dim=-1) + + return self.reduce_loss(loss, batch_mask, reduce) + + def get_z1_given_zt_and_pred(self, zt, pred, z0, t, batch_mask): + """ Make a best guess on the final state z1 given the current state and + the network prediction.
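+        For the linear interpolant used here, this extrapolates the predicted
+        velocity to t=1: z1 = zt + (1 - t) * pred.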
""" + # z1 = z0 + pred + z1 = zt + (1 - t)[batch_mask] * pred + return z1 + + +class TorusICFM(ICFM): + """ + Following: + Chen, Ricky TQ, and Yaron Lipman. + "Riemannian flow matching on general geometries." + arXiv preprint arXiv:2302.03660 (2023). + """ + def __init__(self, sigma, dim, scheduler_args=None): + super().__init__(sigma) + self.dim = dim + + # Scheduler that determines the rate at which the geodesic distance decreases + scheduler_args = scheduler_args or {} + scheduler_args["type"] = scheduler_args.get("type", "linear") # default + scheduler_args["learn_scaled"] = scheduler_args.get("learn_scaled", False) # default + + # linear scheduler: kappa(t) = 1-t (default) + if scheduler_args["type"] == "linear": + # equivalent to: 1 - kappa(t) + self.flow_scaling = lambda t: t + + # equivalent to: -1 * d/dt kappa(t) + self.velocity_scaling = lambda t: torch.ones_like(t) + + # exponential scheduler: kappa(t) = exp(-c*t) + elif scheduler_args["type"] == "exponential": + + self.c = scheduler_args["c"] + assert self.c > 0 + + # equivalent to: 1 - kappa(t) + self.flow_scaling = lambda t: 1 - torch.exp(-self.c * t) + + # equivalent to: -1 * d/dt kappa(t) + self.velocity_scaling = lambda t: self.c * torch.exp(-self.c * t) + + # polynomial scheduler: kappa(t) = (1-t)^k + elif scheduler_args["type"] == "polynomial": + self.k = scheduler_args["k"] + assert self.k > 0 + + # equivalent to: 1 - kappa(t) + self.flow_scaling = lambda t: 1 - (1 - t)**self.k + + # equivalent to: -1 * d/dt kappa(t) + self.velocity_scaling = lambda t: self.k * (1 - t)**(self.k - 1) + + else: + raise NotImplementedError(f"Scheduler {scheduler_args['type']} not implemented.") + + kappa_interval = self.flow_scaling(torch.tensor([0.0, 1.0])) + if kappa_interval[0] != 0.0 or kappa_interval[1] != 1.0: + print(f"Scheduler should satisfy kappa(0)=1 and kappa(1)=0. Found " + f"interval {kappa_interval.tolist()} instead.") + + # determines whether the scaled vector field is learned or the scheduler + # is post-multiplied + self.learn_scaled = scheduler_args["learn_scaled"] + + @staticmethod + def wrap(angle): + """ Maps angles to range [-\pi, \pi). """ + return ((angle + math.pi) % (2 * math.pi)) - math.pi + + def exponential_map(self, x, u): + """ + :param x: point on the manifold + :param u: point on the tangent space + """ + return self.wrap(x + u) + + @staticmethod + def logarithm_map(x, y): + """ + :param x, y: points on the manifold + """ + return torch.atan2(torch.sin(y - x), torch.cos(y - x)) + + def sample_zt(self, z0, z1, t, batch_mask): + """ expressed in terms of exponential and logarithm maps """ + + # apply logarithm map + # zt_tangent = t[batch_mask] * self.logarithm_map(z0, z1) + zt_tangent = self.flow_scaling(t)[batch_mask] * self.logarithm_map(z0, z1) + + # apply exponential map + return self.exponential_map(z0, zt_tangent) + + def get_z1_given_zt_and_pred(self, zt, pred, z0, t, batch_mask): + """ Make a best guess on the final state z1 given the current state and + the network prediction. """ + + # estimate z1_tangent based on zt and pred only + if self.learn_scaled: + pred = pred / torch.clamp(self.velocity_scaling(t), min=1e-3)[batch_mask] + + z1_tangent = (1 - t)[batch_mask] * pred + + # exponential map + return self.exponential_map(zt, z1_tangent) + + def sample_zt_given_zs(self, zs, pred, s, t, batch_mask): + """ Perform update, typically using an explicit Euler step. 
""" + + step_size = t - s + zt_tangent = step_size[batch_mask] * pred + + if not self.learn_scaled: + zt_tangent = self.velocity_scaling(t)[batch_mask] * zt_tangent + + # exponential map + return self.exponential_map(zs, zt_tangent) + + def sample_z0(self, batch_mask): + """ Prior. """ + + # Uniform distribution + z0 = torch.rand((len(batch_mask), self.dim), device=batch_mask.device) + + return 2 * math.pi * z0 - math.pi + + def compute_loss(self, pred, z0, z1, zt, t, batch_mask, reduce='mean'): + """ Compute loss per sample. """ + assert reduce in {'mean', 'sum', 'none'} + mask = ~torch.isnan(z1) + z1 = torch.nan_to_num(z1, nan=0.0) + + zt_dot = self.logarithm_map(z0, z1) + if self.learn_scaled: + # NOTE: potentially requires output magnitude to vary substantially + zt_dot = self.velocity_scaling(t)[batch_mask] * zt_dot + loss = mask * (pred - zt_dot) ** 2 + loss = torch.sum(loss, dim=-1) + + if reduce == 'mean': + denom = mask.sum(dim=-1) + 1e-6 + loss = scatter_mean(loss / denom, batch_mask, dim=0) + elif reduce == 'sum': + loss = scatter_add(loss, batch_mask, dim=0) + return loss + + +class SO3ICFM(ICFM): + """ + All rotations are assumed to be in axis-angle format. + Mostly following descriptions from the FoldFlow paper: + https://openreview.net/forum?id=kJFIH23hXb + + See also: + https://geomstats.github.io/_modules/geomstats/geometry/special_orthogonal.html#SpecialOrthogonal + https://geomstats.github.io/_modules/geomstats/geometry/lie_group.html#LieGroup + """ + def __init__(self, sigma): + super().__init__(sigma) + + def exponential_map(self, base, tangent): + """ + Args: + base: base point (rotation vector) on the manifold + tangent: point in tangent space at identity + Returns: + rotation vector on the manifold + """ + # return so3.exp_not_from_identity(tangent, base_point=base) + return so3.compose_rotations(base, so3.exp(tangent)) + + def logarithm_map(self, base, r): + """ + Args: + base: base point (rotation vector) on the manifold + r: rotation vector on the manifold + Return: + point in tangent space at identity + """ + # return so3.log_not_from_identity(r, base_point=base) + return so3.log(so3.compose_rotations(-base, r)) + + def sample_zt(self, z0, z1, t, batch_mask): + """ + Expressed in terms of exponential and logarithm maps. + Corresponds to SLERP interpolation: R(t) = R1 exp( t * log(R1^T R2) ) + (see https://lucaballan.altervista.org/pdfs/IK.pdf, slide 16) + """ + + # apply logarithm map + zt_tangent = t[batch_mask] * self.logarithm_map(z0, z1) + + # apply exponential map + return self.exponential_map(z0, zt_tangent) + + def get_z1_given_zt_and_pred(self, zt, pred, z0, t, batch_mask): + """ Make a best guess on the final state z1 given the current state and + the network prediction. """ + + # estimate z1_tangent based on zt and pred only + z1_tangent = (1 - t)[batch_mask] * pred + + # exponential map + return self.exponential_map(zt, z1_tangent) + + def sample_zt_given_zs(self, zs, pred, s, t, batch_mask): + """ Perform update, typically using an explicit Euler step. """ + + # # parallel transport vector field to lie algebra so3 (at identity) + # # (FoldFlow paper, Algorithm 3, line 8) + # # TODO: is this correct? is it necessary? + # pred = so3.compose(so3.inverse(zs), pred) + + step_size = t - s + zt_tangent = step_size[batch_mask] * pred + + # exponential map + return self.exponential_map(zs, zt_tangent) + + def sample_z0(self, batch_mask): + """ Prior. 
""" + return so3.random_uniform(n_samples=len(batch_mask), device=batch_mask.device) + + @staticmethod + def d_R_squared_SO3(rot_vec_1, rot_vec_2): + """ + Squared Riemannian metric on SO(3). + Defined as d(R1, R2) = sqrt(0.5) ||log(R1^T R2)||_F + where R1, R2 are rotation matrices. + + The following is equivalent if the difference between the rotations is + expressed as a rotation vector \omega_diff: + d(r1, r2) = ||\omega_diff||_2 + ----- + With the definition of the Frobenius matrix norm ||A||_F^2 = trace(A^H A): + d^2(R1, R2) = 1/2 ||log(R1^T R2)||_F^2 + = 1/2 || hat(R_d) ||_F^2 + = 1/2 tr( hat(R_d)^T hat(R_d) ) + = 1/2 * 2 * ||\omega||_2^2 + """ + + # rot_mat_1 = so3.matrix_from_rotation_vector(rot_vec_1) + # rot_mat_2 = so3.matrix_from_rotation_vector(rot_vec_2) + # rot_mat_diff = rot_mat_1.transpose(-2, -1) @ rot_mat_2 + # return torch.norm(so3.log(rot_mat_diff, as_skew=True), p='fro', dim=(-2, -1)) + + diff_rot = so3.compose_rotations(-rot_vec_1, rot_vec_2) + return diff_rot.square().sum(dim=-1) + + def compute_loss(self, pred, z0, z1, zt, t, batch_mask, reduce='mean', eps=5e-2): + """ Compute loss per sample. """ + assert reduce in {'mean', 'sum', 'none'} + + zt_dot = self.logarithm_map(zt, z1) / torch.clamp(1 - t, min=eps)[batch_mask] + + # TODO: do I need this? + # pred_at_id = self.logarithm_map(zt, pred) / torch.clamp(1 - t, min=eps)[batch_mask] + + loss = torch.sum((pred - zt_dot)**2, dim=-1) # TODO: is this the right loss in SO3? + # loss = self.d_R_squared_SO3(zt_dot, pred) + + if reduce == 'mean': + loss = scatter_mean(loss, batch_mask, dim=0) + elif reduce == 'sum': + loss = scatter_add(loss, batch_mask, dim=0) + + return loss + + +################# +# Predicting z1 # +################# + +class CoordICFMPredictFinal(CoordICFM): + def __init__(self, sigma): + self.dim = 3 + super().__init__(sigma) + + def sample_zt_given_zs(self, zs, z1_minus_zs_pred, s, t, batch_mask): + """ Perform an explicit Euler step. """ + + # step_size = t - s + # zt = zs + step_size[batch_mask] * z1_minus_zs_pred / (1.0 - s)[batch_mask] + + # for numerical stability + step_size = (t - s) / (1.0 - s) + assert torch.all(step_size <= 1.0) + # step_size = torch.clamp(step_size, max=1.0) + zt = zs + step_size[batch_mask] * z1_minus_zs_pred + return zt + + def compute_loss(self, z1_minus_zt_pred, z0, z1, t, batch_mask, reduce='mean'): + """ Compute loss per sample. """ + assert reduce in {'mean', 'sum', 'none'} + t = torch.clamp(t, max=0.9) + zt = self.sample_zt(z0, z1, t, batch_mask) + loss = torch.sum((z1_minus_zt_pred + zt - z1) ** 2, dim=-1) / torch.square(1 - t)[batch_mask].squeeze() + + if reduce == 'mean': + loss = scatter_mean(loss / self.dim, batch_mask, dim=0) + elif reduce == 'sum': + loss = scatter_add(loss, batch_mask, dim=0) + + return loss + + def get_z1_given_zt_and_pred(self, zt, z1_minus_zt_pred, z0, t, batch_mask): + return z1_minus_zt_pred + zt + + +class TorusICFMPredictFinal(TorusICFM): + """ + Following: + Chen, Ricky TQ, and Yaron Lipman. + "Riemannian flow matching on general geometries." + arXiv preprint arXiv:2302.03660 (2023). + """ + def __init__(self, sigma, dim): + super().__init__(sigma, dim) + + def get_z1_given_zt_and_pred(self, zt, z1_tangent_pred, z0, t, batch_mask): + """ Make a best guess on the final state z1 given the current state and + the network prediction. 
""" + + # exponential map + return self.exponential_map(zt, z1_tangent_pred) + + def sample_zt_given_zs(self, zs, z1_tangent_pred, s, t, batch_mask): + """ Perform update, typically using an explicit Euler step. """ + + # step_size = t - s + # zt_tangent = step_size[batch_mask] * z1_tangent_pred / (1.0 - s)[batch_mask] + + # for numerical stability + step_size = (t - s) / (1.0 - s) + assert torch.all(step_size <= 1.0) + # step_size = torch.clamp(step_size, max=1.0) + zt_tangent = step_size[batch_mask] * z1_tangent_pred + + # exponential map + return self.exponential_map(zs, zt_tangent) + + def compute_loss(self, z1_tangent_pred, z0, z1, t, batch_mask, reduce='mean'): + """ Compute loss per sample. """ + assert reduce in {'mean', 'sum', 'none'} + zt = self.sample_zt(z0, z1, t, batch_mask) + t = torch.clamp(t, max=0.9) + + mask = ~torch.isnan(z1) + z1 = torch.nan_to_num(z1, nan=0.0) + loss = mask * (z1_tangent_pred - self.logarithm_map(zt, z1)) ** 2 + loss = torch.sum(loss, dim=-1) / torch.square(1 - t)[batch_mask].squeeze() + + if reduce == 'mean': + denom = mask.sum(dim=-1) + 1e-6 + loss = scatter_mean(loss / denom, batch_mask, dim=0) + elif reduce == 'sum': + loss = scatter_add(loss, batch_mask, dim=0) + + return loss diff --git a/src/model/gvp.py b/src/model/gvp.py new file mode 100644 index 0000000000000000000000000000000000000000..a174fc01f60ce9c022ecc0d01be4543dc0c3b24d --- /dev/null +++ b/src/model/gvp.py @@ -0,0 +1,650 @@ +""" +Geometric Vector Perceptron implementation taken from: +https://github.com/drorlab/gvp-pytorch/blob/main/gvp/__init__.py +""" +import copy +import warnings + +import torch, functools +from torch import nn +import torch.nn.functional as F +from torch_geometric.nn import MessagePassing +from torch_scatter import scatter_add, scatter_mean + + +def tuple_sum(*args): + ''' + Sums any number of tuples (s, V) elementwise. + ''' + return tuple(map(sum, zip(*args))) + + +def tuple_cat(*args, dim=-1): + ''' + Concatenates any number of tuples (s, V) elementwise. + + :param dim: dimension along which to concatenate when viewed + as the `dim` index for the scalar-channel tensors. + This means that `dim=-1` will be applied as + `dim=-2` for the vector-channel tensors. + ''' + dim %= len(args[0][0].shape) + s_args, v_args = list(zip(*args)) + return torch.cat(s_args, dim=dim), torch.cat(v_args, dim=dim) + + +def tuple_index(x, idx): + ''' + Indexes into a tuple (s, V) along the first dimension. + + :param idx: any object which can be used to index into a `torch.Tensor` + ''' + return x[0][idx], x[1][idx] + + +def randn(n, dims, device="cpu"): + ''' + Returns random tuples (s, V) drawn elementwise from a normal distribution. + + :param n: number of data points + :param dims: tuple of dimensions (n_scalar, n_vector) + + :return: (s, V) with s.shape = (n, n_scalar) and + V.shape = (n, n_vector, 3) + ''' + return torch.randn(n, dims[0], device=device), \ + torch.randn(n, dims[1], 3, device=device) + + +def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True): + ''' + L2 norm of tensor clamped above a minimum value `eps`. + + :param sqrt: if `False`, returns the square of the L2 norm + ''' + out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps) + return torch.sqrt(out) if sqrt else out + + +def _split(x, nv): + ''' + Splits a merged representation of (s, V) back into a tuple. + Should be used only with `_merge(s, V)` and only if the tuple + representation cannot be used. 
+ + :param x: the `torch.Tensor` returned from `_merge` + :param nv: the number of vector channels in the input to `_merge` + ''' + v = torch.reshape(x[..., -3 * nv:], x.shape[:-1] + (nv, 3)) + s = x[..., :-3 * nv] + return s, v + + +def _merge(s, v): + ''' + Merges a tuple (s, V) into a single `torch.Tensor`, where the + vector channels are flattened and appended to the scalar channels. + Should be used only if the tuple representation cannot be used. + Use `_split(x, nv)` to reverse. + ''' + v = torch.reshape(v, v.shape[:-2] + (3 * v.shape[-2],)) + return torch.cat([s, v], -1) + + +class GVP(nn.Module): + ''' + Geometric Vector Perceptron. See manuscript and README.md + for more details. + + :param in_dims: tuple (n_scalar, n_vector) + :param out_dims: tuple (n_scalar, n_vector) + :param h_dim: intermediate number of vector channels, optional + :param activations: tuple of functions (scalar_act, vector_act) + :param vector_gate: whether to use vector gating. + (vector_act will be used as sigma^+ in vector gating if `True`) + ''' + + def __init__(self, in_dims, out_dims, h_dim=None, + activations=(F.relu, torch.sigmoid), vector_gate=False): + super(GVP, self).__init__() + self.si, self.vi = in_dims + self.so, self.vo = out_dims + self.vector_gate = vector_gate + if self.vi: + self.h_dim = h_dim or max(self.vi, self.vo) + self.wh = nn.Linear(self.vi, self.h_dim, bias=False) + self.ws = nn.Linear(self.h_dim + self.si, self.so) + if self.vo: + self.wv = nn.Linear(self.h_dim, self.vo, bias=False) + if self.vector_gate: self.wsv = nn.Linear(self.so, self.vo) + else: + self.ws = nn.Linear(self.si, self.so) + + self.scalar_act, self.vector_act = activations + self.dummy_param = nn.Parameter(torch.empty(0)) + + def forward(self, x): + ''' + :param x: tuple (s, V) of `torch.Tensor`, + or (if vectors_in is 0), a single `torch.Tensor` + :return: tuple (s, V) of `torch.Tensor`, + or (if vectors_out is 0), a single `torch.Tensor` + ''' + if self.vi: + s, v = x + v = torch.transpose(v, -1, -2) + vh = self.wh(v) + vn = _norm_no_nan(vh, axis=-2) + s = self.ws(torch.cat([s, vn], -1)) + if self.vo: + v = self.wv(vh) + v = torch.transpose(v, -1, -2) + if self.vector_gate: + if self.vector_act: + gate = self.wsv(self.vector_act(s)) + else: + gate = self.wsv(s) + v = v * torch.sigmoid(gate).unsqueeze(-1) + elif self.vector_act: + v = v * self.vector_act( + _norm_no_nan(v, axis=-1, keepdims=True)) + else: + s = self.ws(x) + if self.vo: + v = torch.zeros(s.shape[0], self.vo, 3, + device=self.dummy_param.device) + if self.scalar_act: + s = self.scalar_act(s) + + return (s, v) if self.vo else s + + +class _VDropout(nn.Module): + ''' + Vector channel dropout where the elements of each + vector channel are dropped together. + ''' + + def __init__(self, drop_rate): + super(_VDropout, self).__init__() + self.drop_rate = drop_rate + self.dummy_param = nn.Parameter(torch.empty(0)) + + def forward(self, x): + ''' + :param x: `torch.Tensor` corresponding to vector channels + ''' + device = self.dummy_param.device + if not self.training: + return x + mask = torch.bernoulli( + (1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device) + ).unsqueeze(-1) + x = mask * x / (1 - self.drop_rate) + return x + + +class Dropout(nn.Module): + ''' + Combined dropout for tuples (s, V). + Takes tuples (s, V) as input and as output. 
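The defining property of the `GVP` module above is that scalar outputs are rotation-invariant while vector outputs are rotation-equivariant, in the same sense as the equivariance test shipped with `gvp_transformer.py` below. A quick sketch (assumes the `GVP` class from this file is in scope and `scipy` is installed; dimensions are arbitrary):

```python
import torch
from scipy.spatial.transform import Rotation
# from src.model.gvp import GVP  # module path introduced by this diff

gvp = GVP((8, 4), (5, 2))                 # (scalars, vectors) in and out
s = torch.randn(10, 8)
v = torch.randn(10, 4, 3)
R = torch.as_tensor(Rotation.random().as_matrix(), dtype=torch.float32)

with torch.no_grad():
    s_out, v_out = gvp((s, v))
    s_rot, v_rot = gvp((s, v @ R))        # rotate the input vectors

assert torch.allclose(s_out, s_rot, atol=1e-5)      # invariant scalars
assert torch.allclose(v_out @ R, v_rot, atol=1e-5)  # equivariant vectors
```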
+ ''' + + def __init__(self, drop_rate): + super(Dropout, self).__init__() + self.sdropout = nn.Dropout(drop_rate) + self.vdropout = _VDropout(drop_rate) + + def forward(self, x): + ''' + :param x: tuple (s, V) of `torch.Tensor`, + or single `torch.Tensor` + (will be assumed to be scalar channels) + ''' + if type(x) is torch.Tensor: + return self.sdropout(x) + s, v = x + return self.sdropout(s), self.vdropout(v) + + +class LayerNorm(nn.Module): + ''' + Combined LayerNorm for tuples (s, V). + Takes tuples (s, V) as input and as output. + ''' + + def __init__(self, dims, learnable_vector_weight=False): + super(LayerNorm, self).__init__() + self.s, self.v = dims + self.scalar_norm = nn.LayerNorm(self.s) + self.vector_norm = VectorLayerNorm(self.v, learnable_vector_weight) \ + if self.v > 0 else None + + def forward(self, x): + ''' + :param x: tuple (s, V) of `torch.Tensor`, + or single `torch.Tensor` + (will be assumed to be scalar channels) + ''' + if not self.v: + return self.scalar_norm(x) + s, v = x + # vn = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False) + # vn = torch.sqrt(torch.mean(vn, dim=-2, keepdim=True)) + # return self.scalar_norm(s), v / vn + return self.scalar_norm(s), self.vector_norm(v) + + +class VectorLayerNorm(nn.Module): + """ + Equivariant normalization of vector-valued features inspired by: + Liao, Yi-Lun, and Tess Smidt. + "Equiformer: Equivariant graph attention transformer for 3d atomistic graphs." + arXiv preprint arXiv:2206.11990 (2022). + Section 4.1, "Layer Normalization" + """ + def __init__(self, n_channels, learnable_weight=True): + super(VectorLayerNorm, self).__init__() + self.gamma = nn.Parameter(torch.ones(1, n_channels, 1)) \ + if learnable_weight else None # (1, c, 1) + + def forward(self, x): + """ + Computes LN(x) = ( x / RMS( L2-norm(x) ) ) * gamma + :param x: input tensor (n, c, 3) + :return: layer normalized vector feature + """ + norm2 = _norm_no_nan(x, axis=-1, keepdims=True, sqrt=False) # (n, c, 1) + rms = torch.sqrt(torch.mean(norm2, dim=-2, keepdim=True)) # (n, 1, 1) + x = x / rms # (n, c, 3) + if self.gamma is not None: + x = x * self.gamma + return x + + +class GVPConv(MessagePassing): + ''' + Graph convolution / message passing with Geometric Vector Perceptrons. + Takes in a graph with node and edge embeddings, + and returns new node embeddings. + + This does NOT do residual updates and pointwise feedforward layers + ---see `GVPConvLayer`. + + :param in_dims: input node embedding dimensions (n_scalar, n_vector) + :param out_dims: output node embedding dimensions (n_scalar, n_vector) + :param edge_dims: input edge embedding dimensions (n_scalar, n_vector) + :param n_layers: number of GVPs in the message function + :param module_list: preconstructed message function, overrides n_layers + :param aggr: should be "add" if some incoming edges are masked, as in + a masked autoregressive decoder architecture, otherwise "mean" + :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs + :param vector_gate: whether to use vector gating. 
+ (vector_act will be used as sigma^+ in vector gating if `True`) + :param update_edge_attr: whether to compute an updated edge representation + ''' + + def __init__(self, in_dims, out_dims, edge_dims, + n_layers=3, module_list=None, aggr="mean", + activations=(F.relu, torch.sigmoid), vector_gate=False, + update_edge_attr=False): + super(GVPConv, self).__init__(aggr=aggr) + self.si, self.vi = in_dims + self.so, self.vo = out_dims + self.se, self.ve = edge_dims + self.update_edge_attr = update_edge_attr + + GVP_ = functools.partial(GVP, + activations=activations, + vector_gate=vector_gate) + + module_list = module_list or [] + if not module_list: + if n_layers == 1: + module_list.append( + GVP_((2 * self.si + self.se, 2 * self.vi + self.ve), + (self.so, self.vo), activations=(None, None))) + else: + module_list.append( + GVP_((2 * self.si + self.se, 2 * self.vi + self.ve), + out_dims) + ) + for i in range(n_layers - 2): + module_list.append(GVP_(out_dims, out_dims)) + module_list.append(GVP_(out_dims, out_dims, + activations=(None, None))) + self.message_func = nn.Sequential(*module_list) + + self.edge_func = copy.deepcopy(self.message_func) \ + if self.update_edge_attr else None + + def forward(self, x, edge_index, edge_attr): + ''' + :param x: tuple (s, V) of `torch.Tensor` + :param edge_index: array of shape [2, n_edges] + :param edge_attr: tuple (s, V) of `torch.Tensor` + ''' + x_s, x_v = x + message = self.propagate(edge_index, + s=x_s, + v=x_v.reshape(x_v.shape[0], 3 * x_v.shape[1]), + edge_attr=edge_attr) + + if self.update_edge_attr: + s_i, s_j = x_s[edge_index[0]], x_s[edge_index[1]] + x_v = x_v.reshape(x_v.shape[0], 3 * x_v.shape[1]) + v_i, v_j = x_v[edge_index[0]], x_v[edge_index[1]] + + edge_out = self.edge_attr(s_i, v_i, s_j, v_j, edge_attr) + return _split(message, self.vo), edge_out + else: + return _split(message, self.vo) + + def message(self, s_i, v_i, s_j, v_j, edge_attr): + v_j = v_j.view(v_j.shape[0], v_j.shape[1] // 3, 3) + v_i = v_i.view(v_i.shape[0], v_i.shape[1] // 3, 3) + message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i)) + message = self.message_func(message) + return _merge(*message) + + def edge_attr(self, s_i, v_i, s_j, v_j, edge_attr): + v_j = v_j.view(v_j.shape[0], v_j.shape[1] // 3, 3) + v_i = v_i.view(v_i.shape[0], v_i.shape[1] // 3, 3) + message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i)) + return self.edge_func(message) + + +class GVPConvLayer(nn.Module): + ''' + Full graph convolution / message passing layer with + Geometric Vector Perceptrons. Residually updates node embeddings with + aggregated incoming messages, applies a pointwise feedforward + network to node embeddings, and returns updated node embeddings. + + To only compute the aggregated messages, see `GVPConv`. + + :param node_dims: node embedding dimensions (n_scalar, n_vector) + :param edge_dims: input edge embedding dimensions (n_scalar, n_vector) + :param n_message: number of GVPs to use in message function + :param n_feedforward: number of GVPs to use in feedforward function + :param drop_rate: drop probability in all dropout layers + :param autoregressive: if `True`, this `GVPConvLayer` will be used + with a different set of input node embeddings for messages + where src >= dst + :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs + :param vector_gate: whether to use vector gating. 
+ (vector_act will be used as sigma^+ in vector gating if `True`) + :param update_edge_attr: whether to compute an updated edge representation + :param ln_vector_weight: whether to include a learnable weight in the vector + layer norm + ''' + + def __init__(self, node_dims, edge_dims, + n_message=3, n_feedforward=2, drop_rate=.1, + autoregressive=False, + activations=(F.relu, torch.sigmoid), vector_gate=False, + update_edge_attr=False, ln_vector_weight=False): + + super(GVPConvLayer, self).__init__() + assert not (update_edge_attr and autoregressive), "Not implemented" + self.update_edge_attr = update_edge_attr + self.conv = GVPConv(node_dims, node_dims, edge_dims, n_message, + aggr="add" if autoregressive else "mean", + activations=activations, vector_gate=vector_gate, + update_edge_attr=update_edge_attr) + GVP_ = functools.partial(GVP, + activations=activations, + vector_gate=vector_gate) + self.norm = nn.ModuleList([LayerNorm(node_dims, ln_vector_weight) + for _ in range(2)]) + self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)]) + + def get_feedforward(n_dims): + ff_func = [] + if n_feedforward == 1: + ff_func.append(GVP_(n_dims, n_dims, activations=(None, None))) + else: + hid_dims = 4 * n_dims[0], 2 * n_dims[1] + ff_func.append(GVP_(n_dims, hid_dims)) + for i in range(n_feedforward - 2): + ff_func.append(GVP_(hid_dims, hid_dims)) + ff_func.append(GVP_(hid_dims, n_dims, activations=(None, None))) + return nn.Sequential(*ff_func) + + self.ff_func = get_feedforward(node_dims) + + if self.update_edge_attr: + self.edge_norm = nn.ModuleList([LayerNorm(edge_dims, ln_vector_weight) + for _ in range(2)]) + self.edge_dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)]) + self.edge_ff = get_feedforward(edge_dims) + + def forward(self, x, edge_index, edge_attr, + autoregressive_x=None, node_mask=None): + ''' + :param x: tuple (s, V) of `torch.Tensor` + :param edge_index: array of shape [2, n_edges] + :param edge_attr: tuple (s, V) of `torch.Tensor` + :param autoregressive_x: tuple (s, V) of `torch.Tensor`. + If not `None`, will be used as src node embeddings + for forming messages where src >= dst. The current node + embeddings `x` will still be the base of the update and the + pointwise feedforward. + :param node_mask: array of type `bool` to index into the first + dim of node embeddings (s, V). If not `None`, only + these nodes will be updated. 
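One detail of the autoregressive branch below deserves unpacking: messages are aggregated with `aggr="add"` and then divided by a per-node in-degree that is clamped to a minimum of 1, so nodes that receive no messages are left unchanged instead of triggering a division by zero. A sketch of that normalization (torch and torch_scatter only):

```python
import torch
from torch_scatter import scatter_add

dst = torch.tensor([0, 0, 2])                 # node 1 receives no message
msg = torch.tensor([[1.0], [3.0], [5.0]])

summed = scatter_add(msg, dst, dim=0, dim_size=4)
count = scatter_add(torch.ones_like(dst), dst, dim_size=4).clamp(min=1)
print(summed / count.unsqueeze(-1))
# tensor([[2.], [0.], [5.], [0.]]) -> node 0 averages its two messages
```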
+ ''' + + if autoregressive_x is not None: + src, dst = edge_index + mask = src < dst + edge_index_forward = edge_index[:, mask] + edge_index_backward = edge_index[:, ~mask] + edge_attr_forward = tuple_index(edge_attr, mask) + edge_attr_backward = tuple_index(edge_attr, ~mask) + + dh = tuple_sum( + self.conv(x, edge_index_forward, edge_attr_forward), + self.conv(autoregressive_x, edge_index_backward, + edge_attr_backward) + ) + + count = scatter_add(torch.ones_like(dst), dst, + dim_size=dh[0].size(0)).clamp(min=1).unsqueeze( + -1) + + dh = dh[0] / count, dh[1] / count.unsqueeze(-1) + + else: + dh = self.conv(x, edge_index, edge_attr) + + if self.update_edge_attr: + dh, de = dh + edge_attr = self.edge_norm[0](tuple_sum(edge_attr, self.dropout[0](de))) + de = self.edge_ff(edge_attr) + edge_attr = self.edge_norm[1](tuple_sum(edge_attr, self.dropout[1](de))) + + if node_mask is not None: + x_ = x + x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask) + + x = self.norm[0](tuple_sum(x, self.dropout[0](dh))) + + dh = self.ff_func(x) + x = self.norm[1](tuple_sum(x, self.dropout[1](dh))) + + if node_mask is not None: + x_[0][node_mask], x_[1][node_mask] = x[0], x[1] + x = x_ + return (x, edge_attr) if self.update_edge_attr else x + + +################################################################################ +def _normalize(tensor, dim=-1, eps=1e-8): + ''' + Normalizes a `torch.Tensor` along dimension `dim` without `nan`s. + ''' + return torch.nan_to_num( + torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True) + eps)) + + +def _rbf(D, D_min=0., D_max=20., D_count=16, device='cpu'): + ''' + From https://github.com/jingraham/neurips19-graph-protein-design + + Returns an RBF embedding of `torch.Tensor` `D` along a new axis=-1. + That is, if `D` has shape [...dims], then the returned tensor will have + shape [...dims, D_count]. 
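A standalone sketch of the computation that follows (torch only): `_rbf` soft-bins each distance onto `D_count` Gaussian basis functions whose centers lie on a uniform grid over [D_min, D_max], so similar distances receive similar embeddings.

```python
import torch

def rbf(D, D_min=0.0, D_max=20.0, D_count=16):
    mu = torch.linspace(D_min, D_max, D_count).view(1, -1)
    sigma = (D_max - D_min) / D_count
    return torch.exp(-((D.unsqueeze(-1) - mu) / sigma) ** 2)

d = torch.tensor([2.5, 7.0])
emb = rbf(d)
print(emb.shape)           # torch.Size([2, 16])
print(emb.argmax(dim=-1))  # tensor([2, 5]), the bins nearest 2.5 and 7.0
```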
+ ''' + D_mu = torch.linspace(D_min, D_max, D_count, device=device) + D_mu = D_mu.view([1, -1]) + D_sigma = (D_max - D_min) / D_count + D_expand = torch.unsqueeze(D, -1) + + RBF = torch.exp(-((D_expand - D_mu) / D_sigma) ** 2) + return RBF + + +class GVPModel(torch.nn.Module): + """ + GVP-GNN model + inspired by: https://github.com/drorlab/gvp-pytorch/blob/main/gvp/models.py + and: https://github.com/drorlab/gvp-pytorch/blob/82af6b22eaf8311c15733117b0071408d24ed877/gvp/atom3d.py#L115 + + :param node_in_dim: node dimension in input graph, scalars or tuple (scalars, vectors) + :param node_h_dim: node dimensions to use in GVP-GNN layers, tuple (s, V) + :param node_out_nf: node dimensions in output graph, tuple (s, V) + :param edge_in_nf: edge dimension in input graph (scalars) + :param edge_h_dim: edge dimensions to embed to before use in GVP-GNN layers, + tuple (s, V) + :param edge_out_nf: edge dimensions in output graph, tuple (s, V) + :param num_layers: number of GVP-GNN layers + :param drop_rate: rate to use in all dropout layers + :param vector_gate: use vector gates in all GVPs + :param reflection_equiv: bool, use reflection-sensitive feature based on the + cross product if False + :param d_max: + :param num_rbf: + :param update_edge_attr: bool, update edge attributes at each layer in a + learnable way + """ + def __init__(self, node_in_dim, node_h_dim, node_out_nf, + edge_in_nf, edge_h_dim, edge_out_nf, + num_layers=3, drop_rate=0.1, vector_gate=False, + reflection_equiv=True, d_max=20.0, num_rbf=16, + update_edge_attr=False): + + super(GVPModel, self).__init__() + + self.reflection_equiv = reflection_equiv + self.update_edge_attr = update_edge_attr + self.d_max = d_max + self.num_rbf = num_rbf + + # node_in_dim = (node_in_dim, 1) + if not isinstance(node_in_dim, tuple): + node_in_dim = (node_in_dim, 0) + + edge_in_dim = (edge_in_nf + 2 * node_in_dim[0] + self.num_rbf, 1) + if not self.reflection_equiv: + edge_in_dim = (edge_in_dim[0], edge_in_dim[1] + 1) + + # self.W_v = nn.Sequential( + # GVP(node_in_dim, node_h_dim, activations=(None, None), vector_gate=True), + # LayerNorm(node_h_dim) + # ) + self.W_v = nn.Sequential( + LayerNorm(node_in_dim, learnable_vector_weight=True), + GVP(node_in_dim, node_h_dim, activations=(None, None), vector_gate=vector_gate), + ) + # self.W_e = nn.Sequential( + # GVP(edge_in_dim, edge_h_dim, activations=(None, None), vector_gate=True), + # LayerNorm(edge_h_dim) + # ) + self.W_e = nn.Sequential( + LayerNorm(edge_in_dim, learnable_vector_weight=True), + GVP(edge_in_dim, edge_h_dim, activations=(None, None), vector_gate=vector_gate), + ) + + self.layers = nn.ModuleList( + GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate, + update_edge_attr=self.update_edge_attr, + activations=(F.relu, None), vector_gate=vector_gate, + ln_vector_weight=True) + # activations=(F.relu, torch.sigmoid)) + # GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate, + # update_edge_attr=self.update_edge_attr, + # activations=(nn.SiLU(), nn.SiLU())) + for _ in range(num_layers)) + + # self.W_v_out = GVP(node_h_dim, (node_out_nf, 1), + # activations=(None, None), vector_gate=True) + self.W_v_out = nn.Sequential( + LayerNorm(node_h_dim, learnable_vector_weight=True), + GVP(node_h_dim, (node_out_nf, 1), activations=(None, None), vector_gate=vector_gate), + ) + # self.W_e_out = GVP(edge_h_dim, (edge_out_nf, 0), + # activations=(None, None), vector_gate=True) \ + # if self.update_edge_attr else None + self.W_e_out = nn.Sequential( + LayerNorm(edge_h_dim, 
learnable_vector_weight=True), + GVP(edge_h_dim, (edge_out_nf, 0), activations=(None, None), vector_gate=vector_gate) + ) if self.update_edge_attr else None + + def edge_features(self, h, x, edge_index, batch_mask=None, edge_attr=None): + """ + :param h: + :param x: + :param edge_index: + :param batch_mask: + :param edge_attr: + :return: scalar and vector-valued edge features + """ + row, col = edge_index + coord_diff = x[row] - x[col] + dist = coord_diff.norm(dim=-1) + rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf, + device=x.device) + + edge_s = torch.cat([h[row], h[col], rbf], dim=1) + edge_v = _normalize(coord_diff).unsqueeze(-2) + + if edge_attr is not None: + edge_s = torch.cat([edge_s, edge_attr], dim=1) + + if not self.reflection_equiv: + mean = scatter_mean(x, batch_mask, dim=0, + dim_size=batch_mask.max() + 1) + row, col = edge_index + cross = torch.cross(x[row] - mean[batch_mask[row]], + x[col] - mean[batch_mask[col]], dim=1) + cross = _normalize(cross).unsqueeze(-2) + + edge_v = torch.cat([edge_v, cross], dim=-2) + + return torch.nan_to_num(edge_s), torch.nan_to_num(edge_v) + + def forward(self, h, x, edge_index, v=None, batch_mask=None, edge_attr=None): + + # h_v = (h, x.unsqueeze(-2)) + h_v = h if v is None else (h, v) + h_e = self.edge_features(h, x, edge_index, batch_mask, edge_attr) + + h_v = self.W_v(h_v) + h_e = self.W_e(h_e) + + for layer in self.layers: + h_v = layer(h_v, edge_index, edge_attr=h_e) + if self.update_edge_attr: + h_v, h_e = h_v + + # h, x = self.W_v_out(h_v) + # x = x.squeeze(-2) + h, vel = self.W_v_out(h_v) + # x = x + vel.squeeze(-2) + + if self.update_edge_attr: + edge_attr = self.W_e_out(h_e) + + # return h, x, edge_attr + return h, vel.squeeze(-2), edge_attr diff --git a/src/model/gvp_transformer.py b/src/model/gvp_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..f36bdc9b3e31a44ac8ad045f443760ac26ddc08e --- /dev/null +++ b/src/model/gvp_transformer.py @@ -0,0 +1,471 @@ +import math +import functools +import torch +from torch import nn +import torch.nn.functional as F +from torch_scatter import scatter_mean, scatter_std, scatter_min, scatter_max, scatter_softmax + + +# ## debug +# import sys +# from pathlib import Path +# +# basedir = Path(__file__).resolve().parent.parent.parent +# sys.path.append(str(basedir)) +# ### + +from src.model.gvp import GVP, _norm_no_nan, tuple_sum, Dropout, LayerNorm, \ + tuple_cat, tuple_index, _rbf, _normalize + + +def tuple_mul(tup, val): + if isinstance(val, torch.Tensor): + return (tup[0] * val, tup[1] * val.unsqueeze(-1)) + return (tup[0] * val, tup[1] * val) + + +class GVPBlock(nn.Module): + def __init__(self, in_dims, out_dims, n_layers=1, + activations=(F.relu, torch.sigmoid), vector_gate=False, + dropout=0.0, skip=False, layernorm=False): + super(GVPBlock, self).__init__() + self.si, self.vi = in_dims + self.so, self.vo = out_dims + assert not skip or (self.si == self.so and self.vi == self.vo) + self.skip = skip + + GVP_ = functools.partial(GVP, activations=activations, vector_gate=vector_gate) + + module_list = [] + if n_layers == 1: + module_list.append(GVP_(in_dims, out_dims, activations=(None, None))) + else: + module_list.append(GVP_(in_dims, out_dims)) + for i in range(n_layers - 2): + module_list.append(GVP_(out_dims, out_dims)) + module_list.append(GVP_(out_dims, out_dims, activations=(None, None))) + + self.layers = nn.Sequential(*module_list) + + self.norm = LayerNorm(out_dims, learnable_vector_weight=True) if layernorm else None + self.dropout = 
Dropout(dropout) if dropout > 0 else None + + def forward(self, x): + """ + :param x: tuple (s, V) of `torch.Tensor` + :return: tuple (s, V) of `torch.Tensor` + """ + + dx = self.layers(x) + + if self.dropout is not None: + dx = self.dropout(dx) + + if self.skip: + x = tuple_sum(x, dx) + else: + x = dx + + if self.norm is not None: + x = self.norm(x) + + return x + + +class GeometricPNA(nn.Module): + def __init__(self, d_in, d_out): + """ Map features to global features """ + super().__init__() + si, vi = d_in + so, vo = d_out + self.gvp = GVPBlock((4 * si + 3 * vi, vi), d_out) + + def forward(self, x, batch_mask, batch_size=None): + """ x: tuple (s, V) """ + s, v = x + + sm = scatter_mean(s, batch_mask, dim=0, dim_size=batch_size) + smi = scatter_min(s, batch_mask, dim=0, dim_size=batch_size)[0] + sma = scatter_max(s, batch_mask, dim=0, dim_size=batch_size)[0] + sstd = scatter_std(s, batch_mask, dim=0, dim_size=batch_size) + + vnorm = _norm_no_nan(v) + vm = scatter_mean(v, batch_mask, dim=0, dim_size=batch_size) + vmi = scatter_min(vnorm, batch_mask, dim=0, dim_size=batch_size)[0] + vma = scatter_max(vnorm, batch_mask, dim=0, dim_size=batch_size)[0] + vstd = scatter_std(vnorm, batch_mask, dim=0, dim_size=batch_size) + + z = torch.hstack((sm, smi, sma, sstd, vmi, vma, vstd)) + out = self.gvp((z, vm)) + return out + + +class TupleLinear(nn.Module): + def __init__(self, in_dims, out_dims, bias=True): + super().__init__() + self.si, self.vi = in_dims + self.so, self.vo = out_dims + assert self.si and self.so + self.ws = nn.Linear(self.si, self.so, bias=bias) + self.wv = nn.Linear(self.vi, self.vo, bias=bias) if self.vi and self.vo else None + + def forward(self, x): + if self.vi: + s, v = x + + s = self.ws(s) + + if self.vo: + v = v.transpose(-1, -2) + v = self.wv(v) + v = v.transpose(-1, -2) + + else: + s = self.ws(x) + + if self.vo: + v = torch.zeros(s.size(0), self.vo, 3, device=s.device) + + return (s, v) if self.vo else s + + +class GVPTransformerLayer(nn.Module): + """ + Full graph transformer layer with Geometric Vector Perceptrons. + Inspired by + - GVP: Jing, Bowen, et al. "Learning from protein structure with geometric vector perceptrons." arXiv preprint arXiv:2009.01411 (2020). + - Transformer architecture: Vignac, Clement, et al. "Digress: Discrete denoising diffusion for graph generation." arXiv preprint arXiv:2209.14734 (2022). + - Invariant point attention: Jumper, John, et al. "Highly accurate protein structure prediction with AlphaFold." Nature 596.7873 (2021): 583-589. + + :param node_dims: node embedding dimensions (n_scalar, n_vector) + :param edge_dims: input edge embedding dimensions (n_scalar, n_vector) + :param global_dims: global feature dimension (n_scalar, n_vector) + :param dk: key dimension, (n_scalar, n_vector) + :param dv: node value dimension, (n_scalar, n_vector) + :param de: edge value dimension, (n_scalar, n_vector) + :param db: dimension of edge contribution to attention, int + :param attn_heads: number of attention heads, int + :param n_feedforward: number of GVPs to use in feedforward function + :param drop_rate: drop probability in all dropout layers + :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs + :param vector_gate: whether to use vector gating. 
+ (vector_act will be used as sigma^+ in vector gating if `True`) + :param attention: can be used to turn off the attention mechanism + """ + + def __init__(self, node_dims, edge_dims, global_dims, dk, dv, de, db, + attn_heads, n_feedforward=1, drop_rate=0.0, + activations=(F.relu, torch.sigmoid), vector_gate=False, + attention=True): + + super(GVPTransformerLayer, self).__init__() + + self.attention = attention + + dq = dk + self.dq = dq + self.dk = dk + self.dv = dv + self.de = de + self.db = db + + self.h = attn_heads + + self.q = TupleLinear(node_dims, tuple_mul(dq, self.h), bias=False) if self.attention else None + self.k = TupleLinear(node_dims, tuple_mul(dk, self.h), bias=False) if self.attention else None + self.vx = TupleLinear(node_dims, tuple_mul(dv, self.h), bias=False) + + self.ve = TupleLinear(edge_dims, tuple_mul(de, self.h), bias=False) + self.b = TupleLinear(edge_dims, (db * self.h, 0), bias=False) if self.attention else None + + m_dim = tuple_sum(tuple_mul(dv, self.h), tuple_mul(de, self.h)) + self.msg = GVPBlock(m_dim, m_dim, n_feedforward, + activations=activations, vector_gate=vector_gate) + + m_dim = tuple_sum(m_dim, global_dims) + self.x_out = GVPBlock(m_dim, node_dims, n_feedforward, + activations=activations, vector_gate=vector_gate) + self.x_norm = LayerNorm(node_dims, learnable_vector_weight=True) + self.x_dropout = Dropout(drop_rate) + + e_dim = tuple_sum(tuple_mul(node_dims, 2), edge_dims, global_dims) + if self.attention: + e_dim = (e_dim[0] + 3 * attn_heads, e_dim[1]) + self.e_out = GVPBlock(e_dim, edge_dims, n_feedforward, + activations=activations, vector_gate=vector_gate) + self.e_norm = LayerNorm(edge_dims, learnable_vector_weight=True) + self.e_dropout = Dropout(drop_rate) + + self.pna_x = GeometricPNA(node_dims, node_dims) + self.pna_e = GeometricPNA(edge_dims, edge_dims) + self.y = GVP(global_dims, global_dims, activations=(None, None), vector_gate=vector_gate) + _dim = tuple_sum(node_dims, edge_dims, global_dims) + self.y_out = GVPBlock(_dim, global_dims, n_feedforward, + activations=activations, vector_gate=vector_gate) + self.y_norm = LayerNorm(global_dims, learnable_vector_weight=True) + self.y_dropout = Dropout(drop_rate) + + def forward(self, x, edge_index, batch_mask, edge_attr, global_attr=None, + node_mask=None): + """ + :param x: tuple (s, V) of `torch.Tensor` + :param edge_index: array of shape [2, n_edges] + :param batch_mask: array indicating different graphs + :param edge_attr: tuple (s, V) of `torch.Tensor` + :param global_attr: tuple (s, V) of `torch.Tensor` + :param node_mask: array of type `bool` to index into the first + dim of node embeddings (s, V). If not `None`, only + these nodes will be updated. + """ + + row, col = edge_index + n = len(x[0]) + batch_size = len(torch.unique(batch_mask)) + + # Compute attention + if self.attention: + Q = self.q(x) + K = self.k(x) + b = self.b(edge_attr) + + qs, qv = Q # (n, dq * h), (n, dq * h, 3) + ks, kv = K # (n, dq * h), (n, dq * h, 3) + attn_s = (qs[row] * ks[col]).reshape(len(row), self.h, self.dq[0]).sum(dim=-1) # (m, h) + # NOTE: attn_v is the Frobenius inner product between vector-valued queries and keys of size [dq, 3] + # (generalizes the dot-product between queries and keys similar to Pocket2Mol) + # TODO: double-check if this is correctly implemented! 
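The NOTE/TODO above can be checked in isolation: summing `qv * kv` over the channel and spatial axes is exactly the Frobenius inner product ⟨Q, K⟩_F = tr(QᵀK), and because it is built from dot products of 3-vectors it is invariant under a joint rotation of queries and keys, which is what an attention logit should be. A quick verification sketch (torch only):

```python
import torch

q = torch.randn(5, 3)  # one head: dq[1] = 5 vector channels of size 3
k = torch.randn(5, 3)

frob = (q * k).sum()                       # the reduction used above
assert torch.allclose(frob, torch.trace(q.T @ k), atol=1e-5)

R = torch.linalg.qr(torch.randn(3, 3)).Q   # random orthogonal matrix
assert torch.allclose(((q @ R) * (k @ R)).sum(), frob, atol=1e-4)
```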
+ attn_v = (qv[row] * kv[col]).reshape(len(row), self.h, self.dq[1], 3).sum(dim=(-2, -1)) # (m, h) + attn_e = b.reshape(b.size(0), self.h, self.db).sum(dim=-1) # (m, h) + + attn = attn_s / math.sqrt(3 * self.dk[0]) + \ + attn_v / math.sqrt(9 * self.dk[1]) + \ + attn_e / math.sqrt(3 * self.db) + attn = scatter_softmax(attn, row, dim=0) # (m, h) + attn = attn.unsqueeze(-1) # (m, h, 1) + + # Compute new features + Vx = self.vx(x) + Ve = self.ve(edge_attr) + + mx = (Vx[0].reshape(Vx[0].size(0), self.h, self.dv[0]), # (n, h, dv) + Vx[1].reshape(Vx[1].size(0), self.h, self.dv[1], 3)) # (n, h, dv, 3) + me = (Ve[0].reshape(Ve[0].size(0), self.h, self.de[0]), + Ve[1].reshape(Ve[1].size(0), self.h, self.de[1], 3)) + + mx = tuple_index(mx, col) + if self.attention: + mx = tuple_mul(mx, attn) + me = tuple_mul(me, attn) + + _m = tuple_cat(mx, me) + _m = (_m[0].flatten(1), _m[1].flatten(1, 2)) + m = self.msg(_m) # (m, h * dv), (m, h * dv, 3) + m = (scatter_mean(m[0], row, dim=0, dim_size=n), # (n, h * dv) + scatter_mean(m[1], row, dim=0, dim_size=n)) # (n, h * dv, 3) + if global_attr is not None: + m = tuple_cat(m, tuple_index(global_attr, batch_mask)) + X_out = self.x_norm(tuple_sum(x, self.x_dropout(self.x_out(m)))) + + _e = tuple_cat(tuple_index(x, row), tuple_index(x, col), edge_attr) + if self.attention: + _e = (torch.cat([_e[0], attn_s, attn_v, attn_e], dim=-1), _e[1]) + if global_attr is not None: + _e = tuple_cat(_e, tuple_index(global_attr, batch_mask[row])) + E_out = self.e_norm(tuple_sum(edge_attr, self.e_dropout(self.e_out(_e)))) + + _y = tuple_cat(self.pna_x(x, batch_mask, batch_size), + self.pna_e(edge_attr, batch_mask[row], batch_size)) + if global_attr is not None: + _y = tuple_cat(_y, self.y(global_attr)) + y_out = self.y_norm(tuple_sum(global_attr, self.y_dropout(self.y_out(_y)))) + else: + y_out = self.y_norm(self.y_dropout(self.y_out(_y))) + + if node_mask is not None: + X_out[0][~node_mask], X_out[1][~node_mask] = tuple_index(x, ~node_mask) + + return X_out, E_out, y_out + + +class GVPTransformerModel(torch.nn.Module): + """ + GVP-Transformer model + + :param node_in_dim: node dimension in input graph, scalars or tuple (scalars, vectors) + :param node_h_dim: node dimensions to use in GVP-GNN layers, tuple (s, V) + :param node_out_nf: node dimensions in output graph, tuple (s, V) + :param edge_in_nf: edge dimension in input graph (scalars) + :param edge_h_dim: edge dimensions to embed to before use in GVP-GNN layers, + tuple (s, V) + :param edge_out_nf: edge dimensions in output graph, tuple (s, V) + :param num_layers: number of GVP-GNN layers + :param drop_rate: rate to use in all dropout layers + :param reflection_equiv: bool, use reflection-sensitive feature based on the + cross product if False + :param d_max: + :param num_rbf: + :param vector_gate: use vector gates in all GVPs + :param attention: can be used to turn off the attention mechanism + """ + def __init__(self, node_in_dim, node_h_dim, node_out_nf, edge_in_nf, + edge_h_dim, edge_out_nf, num_layers, dk, dv, de, db, dy, + attn_heads, n_feedforward, drop_rate, reflection_equiv=True, + d_max=20.0, num_rbf=16, vector_gate=False, attention=True): + + super(GVPTransformerModel, self).__init__() + + self.reflection_equiv = reflection_equiv + self.d_max = d_max + self.num_rbf = num_rbf + + # node_in_dim = (node_in_dim, 1) + if not isinstance(node_in_dim, tuple): + node_in_dim = (node_in_dim, 0) + + edge_in_dim = (edge_in_nf + 2 * node_in_dim[0] + self.num_rbf, 1) + if not self.reflection_equiv: + edge_in_dim = (edge_in_dim[0], 
edge_in_dim[1] + 1) + + self.W_v = GVP(node_in_dim, node_h_dim, activations=(None, None), vector_gate=vector_gate) + self.W_e = GVP(edge_in_dim, edge_h_dim, activations=(None, None), vector_gate=vector_gate) + # self.W_v = nn.Sequential( + # LayerNorm(node_in_dim, learnable_vector_weight=True), + # GVP(node_in_dim, node_h_dim, activations=(None, None)), + # ) + # self.W_e = nn.Sequential( + # LayerNorm(edge_in_dim, learnable_vector_weight=True), + # GVP(edge_in_dim, edge_h_dim, activations=(None, None)), + # ) + + self.dy = dy + self.layers = nn.ModuleList( + GVPTransformerLayer(node_h_dim, edge_h_dim, dy, dk, dv, de, db, + attn_heads, n_feedforward=n_feedforward, + drop_rate=drop_rate, vector_gate=vector_gate, + activations=(F.relu, None), attention=attention) + for _ in range(num_layers)) + + self.W_v_out = GVP(node_h_dim, (node_out_nf, 1), activations=(None, None), vector_gate=vector_gate) + self.W_e_out = GVP(edge_h_dim, (edge_out_nf, 0), activations=(None, None), vector_gate=vector_gate) + # self.W_v_out = nn.Sequential( + # LayerNorm(node_h_dim, learnable_vector_weight=True), + # GVP(node_h_dim, (node_out_nf, 1), activations=(None, None)), + # ) + # self.W_e_out = nn.Sequential( + # LayerNorm(edge_h_dim, learnable_vector_weight=True), + # GVP(edge_h_dim, (edge_out_nf, 0), activations=(None, None)) + # ) + + def edge_features(self, h, x, edge_index, batch_mask=None, edge_attr=None): + """ + :param h: + :param x: + :param edge_index: + :param batch_mask: + :param edge_attr: + :return: scalar and vector-valued edge features + """ + row, col = edge_index + coord_diff = x[row] - x[col] + dist = coord_diff.norm(dim=-1) + rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf, + device=x.device) + + edge_s = torch.cat([h[row], h[col], rbf], dim=1) + edge_v = _normalize(coord_diff).unsqueeze(-2) + + if edge_attr is not None: + edge_s = torch.cat([edge_s, edge_attr], dim=1) + + if not self.reflection_equiv: + mean = scatter_mean(x, batch_mask, dim=0, + dim_size=batch_mask.max() + 1) + row, col = edge_index + cross = torch.cross(x[row] - mean[batch_mask[row]], + x[col] - mean[batch_mask[col]], dim=1) + cross = _normalize(cross).unsqueeze(-2) + + edge_v = torch.cat([edge_v, cross], dim=-2) + + return torch.nan_to_num(edge_s), torch.nan_to_num(edge_v) + + def forward(self, h, x, edge_index, v=None, batch_mask=None, edge_attr=None): + + bs = len(batch_mask.unique()) + + # h_v = (h, x.unsqueeze(-2)) + h_v = h if v is None else (h, v) + h_e = self.edge_features(h, x, edge_index, batch_mask, edge_attr) + + h_v = self.W_v(h_v) + h_e = self.W_e(h_e) + h_y = (torch.zeros(bs, self.dy[0], device=h.device), + torch.zeros(bs, self.dy[1], 3, device=h.device)) + + for layer in self.layers: + h_v, h_e, h_y = layer(h_v, edge_index, batch_mask, h_e, h_y) + + # h, x = self.W_v_out(h_v) + # x = x.squeeze(-2) + h, vel = self.W_v_out(h_v) + # x = x + vel.squeeze(-2) + + edge_attr = self.W_e_out(h_e) + + # return h, x, edge_attr + return h, vel.squeeze(-2), edge_attr + + +if __name__ == "__main__": + from src.model.gvp import randn + from scipy.spatial.transform import Rotation + + def test_equivariance(model, nodes, edges, glob_feat): + random = torch.as_tensor(Rotation.random().as_matrix(), + dtype=torch.float32, device=device) + + with torch.no_grad(): + X_out, E_out, y_out = model(nodes, edges, glob_feat) + n_v_rot, e_v_rot, y_v_rot = nodes[1] @ random, edges[1] @ random, glob_feat[1] @ random + X_out_v_rot = X_out[1] @ random + E_out_v_rot = E_out[1] @ random + y_out_v_rot = y_out[1] @ random + X_out_prime, 
E_out_prime, y_out_prime = model((nodes[0], n_v_rot), (edges[0], e_v_rot), (glob_feat[0], y_v_rot)) + + assert torch.allclose(X_out[0], X_out_prime[0], atol=1e-5, rtol=1e-4) + assert torch.allclose(X_out_v_rot, X_out_prime[1], atol=1e-5, rtol=1e-4) + assert torch.allclose(E_out[0], E_out_prime[0], atol=1e-5, rtol=1e-4) + assert torch.allclose(E_out_v_rot, E_out_prime[1], atol=1e-5, rtol=1e-4) + assert torch.allclose(y_out[0], y_out_prime[0], atol=1e-5, rtol=1e-4) + assert torch.allclose(y_out_v_rot, y_out_prime[1], atol=1e-5, rtol=1e-4) + print("SUCCESS") + + + n_nodes = 300 + n_edges = 10000 + batch_size = 6 + + node_dim = (16, 8) + edge_dim = (8, 4) + global_dim = (4, 2) + dk = (6, 3) + dv = (7, 4) + de = (5, 2) + db = 10 + attn_heads = 9 + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + + nodes = randn(n_nodes, node_dim, device=device) + edges = randn(n_edges, edge_dim, device=device) + glob_feat = randn(batch_size, global_dim, device=device) + edge_index = torch.randint(0, n_nodes, (2, n_edges), device=device) + batch_idx = torch.randint(0, batch_size, (n_nodes,), device=device) + + model = GVPTransformerLayer(node_dim, edge_dim, global_dim, dk, dv, de, db, + attn_heads, n_feedforward = 2, + drop_rate = 0.1).to(device).eval() + + model_fn = lambda h_V, h_E, h_y: model(h_V, edge_index, batch_idx, h_E, h_y) + test_equivariance(model_fn, nodes, edges, glob_feat) diff --git a/src/model/lightning.py b/src/model/lightning.py new file mode 100644 index 0000000000000000000000000000000000000000..4771b500038737713fbc090f00613a2d223195ec --- /dev/null +++ b/src/model/lightning.py @@ -0,0 +1,1426 @@ +import warnings +import tempfile +from typing import Optional, Union +from time import time +from pathlib import Path +from functools import partial +from itertools import accumulate +from argparse import Namespace + +import numpy as np +import pandas as pd +from rdkit import Chem +import torch +from torch.utils.data import DataLoader, SubsetRandomSampler +from torch.distributions.categorical import Categorical +import pytorch_lightning as pl +from torch_scatter import scatter_mean + +import src.utils as utils +from src.constants import atom_encoder, atom_decoder, aa_encoder, aa_decoder, \ + bond_encoder, bond_decoder, residue_encoder, residue_bond_encoder, \ + residue_decoder, residue_bond_decoder, aa_atom_index, aa_atom_mask +from src.data.dataset import ProcessedLigandPocketDataset, ClusteredDataset, get_wds +from src.data import data_utils +from src.data.data_utils import AppendVirtualNodesInCoM, center_data, Residues, TensorDict, randomize_tensors +from src.model.flows import CoordICFM, TorusICFM, CoordICFMPredictFinal, TorusICFMPredictFinal, SO3ICFM +from src.model.markov_bridge import UniformPriorMarkovBridge, MarginalPriorMarkovBridge +from src.model.dynamics import Dynamics +from src.model.dynamics_hetero import DynamicsHetero +from src.model.diffusion_utils import DistributionNodes +from src.model.loss_utils import TimestepWeights, clash_loss +from src.analysis.visualization_utils import pocket_to_rdkit, mols_to_pdbfile +from src.analysis.metrics import MoleculeValidity, CategoricalDistribution, MolecularProperties +from src.data.molecule_builder import build_molecule +from src.data.postprocessing import process_all +from src.sbdd_metrics.metrics import FullEvaluator +from src.sbdd_metrics.evaluation import VALIDITY_METRIC_NAME, aggregated_metrics, collection_metrics +from tqdm import tqdm + +# derive additional constants +aa_atom_mask_tensor = 
torch.tensor([aa_atom_mask[aa] for aa in aa_decoder]) +aa_atom_decoder = {aa: {v: k for k, v in aa_atom_index[aa].items()} for aa in aa_decoder} +aa_atom_type_tensor = torch.tensor([[atom_encoder.get(aa_atom_decoder[aa].get(i, '-')[0], -42) + for i in range(14)] for aa in aa_decoder]) + + +def set_default(namespace, key, default_val): + val = vars(namespace).get(key, default_val) + setattr(namespace, key, val) + + +class DrugFlow(pl.LightningModule): + def __init__( + self, + pocket_representation: str, + train_params: Namespace, + loss_params: Namespace, + eval_params: Namespace, + predictor_params: Namespace, + simulation_params: Namespace, + virtual_nodes: Union[list, None], + flexible: bool, + flexible_bb: bool = False, + debug: bool = False, + overfit: bool = False, + ): + super(DrugFlow, self).__init__() + self.save_hyperparameters() + + # Set default parameters + set_default(train_params, "sharded_dataset", False) + set_default(train_params, "sample_from_clusters", False) + set_default(train_params, "lr_step_size", None) + set_default(train_params, "lr_gamma", None) + set_default(train_params, "gnina", None) + set_default(loss_params, "lambda_x", 1.0) + set_default(loss_params, "lambda_clash", None) + set_default(loss_params, "reduce", "mean") + set_default(loss_params, "regularize_uncertainty", None) + set_default(eval_params, "n_loss_per_sample", 1) + set_default(eval_params, "n_sampling_steps", simulation_params.n_steps) + set_default(predictor_params, "transform_sc_pred", False) + set_default(predictor_params, "add_chi_as_feature", False) + set_default(predictor_params, "augment_residue_sc", False) + set_default(predictor_params, "augment_ligand_sc", False) + set_default(predictor_params, "add_all_atom_diff", False) + set_default(predictor_params, "angle_act_fn", None) + set_default(simulation_params, "predict_confidence", False) + set_default(simulation_params, "predict_final", False) + set_default(simulation_params, "scheduler_chi", None) + + # Check for invalid configurations + assert pocket_representation in {'side_chain_bead', 'CA+'} + self.pocket_representation = pocket_representation + + assert flexible or not predictor_params.augment_residue_sc + self.augment_residue_sc = predictor_params.augment_residue_sc \ + if 'augment_residue_sc' in predictor_params else False + self.augment_ligand_sc = predictor_params.augment_ligand_sc \ + if 'augment_ligand_sc' in predictor_params else False + + assert not (flexible_bb and predictor_params.normal_modes), \ + "Normal mode eigenvectors are only meaningful for fixed backbones" + assert (not flexible_bb) or flexible, \ + "Currently atom vectors aren't updated if flexible=False" + + assert not (simulation_params.predict_confidence and + (not predictor_params.heterogeneous_graph or simulation_params.predict_final)) + + # Set parameters + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None + self.virtual_nodes = virtual_nodes + self.flexible = flexible + self.flexible_bb = flexible_bb + self.debug = debug + self.overfit = overfit + self.predict_confidence = simulation_params.predict_confidence + + if self.virtual_nodes: + self.add_virtual_min = virtual_nodes[0] + self.add_virtual_max = virtual_nodes[1] + + # Training parameters + self.datadir = train_params.datadir + self.receptor_dir = train_params.datadir + self.batch_size = train_params.batch_size + self.lr = train_params.lr + self.lr_step_size = train_params.lr_step_size + self.lr_gamma = train_params.lr_gamma + self.num_workers = train_params.num_workers 
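The module-level `set_default` helper used throughout this constructor is a backward-compatibility shim: it writes a default onto the hparams `Namespace` only when the key is absent, so checkpoints saved before a parameter existed still load with sensible values. A small usage sketch (standard library only; the config keys are illustrative):

```python
from argparse import Namespace

def set_default(namespace, key, default_val):
    # keep existing values, fill in missing ones
    val = vars(namespace).get(key, default_val)
    setattr(namespace, key, val)

old_cfg = Namespace(batch_size=8)       # saved before 'gnina' existed
set_default(old_cfg, "gnina", None)     # backfilled
set_default(old_cfg, "batch_size", 16)  # existing value wins
print(old_cfg)                          # Namespace(batch_size=8, gnina=None)
```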
+ self.sample_from_clusters = train_params.sample_from_clusters + self.sharded_dataset = train_params.sharded_dataset + self.clip_grad = train_params.clip_grad + if self.clip_grad: + self.gradnorm_queue = utils.Queue() + # Add large value that will be flushed. + self.gradnorm_queue.add(3000) + + # Evaluation parameters + self.outdir = eval_params.outdir + self.eval_batch_size = eval_params.eval_batch_size + self.eval_epochs = eval_params.eval_epochs + # assert eval_params.visualize_sample_epoch % self.eval_epochs == 0 + self.visualize_sample_epoch = eval_params.visualize_sample_epoch + self.visualize_chain_epoch = eval_params.visualize_chain_epoch + self.sample_with_ground_truth_size = eval_params.sample_with_ground_truth_size + self.n_loss_per_sample = eval_params.n_loss_per_sample + self.n_eval_samples = eval_params.n_eval_samples + self.n_visualize_samples = eval_params.n_visualize_samples + self.keep_frames = eval_params.keep_frames + self.gnina = train_params.gnina + + # Feature encoders/decoders + self.atom_encoder = atom_encoder + self.atom_decoder = atom_decoder + self.bond_encoder = bond_encoder + self.bond_decoder = bond_decoder + self.aa_encoder = aa_encoder + self.aa_decoder = aa_decoder + self.residue_encoder = residue_encoder + self.residue_decoder = residue_decoder + self.residue_bond_encoder = residue_bond_encoder + self.residue_bond_decoder = residue_bond_decoder + + self.atom_nf = len(self.atom_decoder) + self.residue_nf = len(self.aa_decoder) + if self.pocket_representation == 'side_chain_bead': + self.residue_nf += len(self.residue_encoder) + if self.pocket_representation == 'CA+': + self.aa_atom_index = aa_atom_index + self.n_atom_aa = max([x for aa in aa_atom_index.values() for x in aa.values()]) + 1 + self.residue_nf = (self.residue_nf, self.n_atom_aa) # (s, V) + self.bond_nf = len(self.bond_decoder) + self.pocket_bond_nf = len(self.residue_bond_decoder) + self.x_dim = 3 + + # Set up the neural network + self.dynamics = self.init_model(predictor_params) + + # Initialize objects for each variable type + if simulation_params.predict_final: + self.module_x = CoordICFMPredictFinal(None) + self.module_chi = TorusICFMPredictFinal(None, 5) if self.flexible else None + if self.flexible_bb: + raise NotImplementedError() + else: + self.module_x = CoordICFM(None) + # self.module_chi = AngleICFM(None, 5) if self.flexible else None + scheduler_args = None if simulation_params.scheduler_chi is None else vars(simulation_params.scheduler_chi) + self.module_chi = TorusICFM(None, 5, scheduler_args) if self.flexible else None + self.module_trans = CoordICFM(None) if self.flexible_bb else None + self.module_rot = SO3ICFM(None) if self.flexible_bb else None + + if simulation_params.prior_h == 'uniform': + self.module_h = UniformPriorMarkovBridge(self.atom_nf, loss_type=loss_params.discrete_loss) + elif simulation_params.prior_h == 'marginal': + self.register_buffer('prior_h', self.get_categorical_prop('atom')) # add to module + self.module_h = MarginalPriorMarkovBridge(self.atom_nf, self.prior_h, loss_type=loss_params.discrete_loss) + + if simulation_params.prior_e == 'uniform': + self.module_e = UniformPriorMarkovBridge(self.bond_nf, loss_type=loss_params.discrete_loss) + elif simulation_params.prior_e == 'marginal': + self.register_buffer('prior_e', self.get_categorical_prop('bond')) # add to module + self.module_e = MarginalPriorMarkovBridge(self.bond_nf, self.prior_e, loss_type=loss_params.discrete_loss) + + + # Loss parameters + self.loss_reduce = loss_params.reduce + self.lambda_x 
= loss_params.lambda_x + self.lambda_h = loss_params.lambda_h + self.lambda_e = loss_params.lambda_e + self.lambda_chi = loss_params.lambda_chi if self.flexible else None + self.lambda_trans = loss_params.lambda_trans if self.flexible_bb else None + self.lambda_rot = loss_params.lambda_rot if self.flexible_bb else None + self.lambda_clash = loss_params.lambda_clash + self.regularize_uncertainty = loss_params.regularize_uncertainty + + if loss_params.timestep_weights is not None: + weight_type = loss_params.timestep_weights.split('_')[0] + kwargs = loss_params.timestep_weights.split('_')[1:] + kwargs = {x.split('=')[0]: float(x.split('=')[1]) for x in kwargs} + self.timestep_weights = TimestepWeights(weight_type, **kwargs) + else: + self.timestep_weights = None + + + # Sampling + self.T_sampling = eval_params.n_sampling_steps + self.train_step_size = 1 / simulation_params.n_steps + self.size_distribution = None # initialized only if needed + + + # Metrics, initialized only if needed + self.train_smiles = None + self.ligand_metrics = None + self.molecule_properties = None + self.evaluator = None + self.ligand_atom_type_distribution = None + self.ligand_bond_type_distribution = None + + # containers for metric aggregation + self.training_step_outputs = [] + self.validation_step_outputs = [] + + def on_load_checkpoint(self, checkpoint): + """ + This hook is only used for backward compatibility with checkpoints that + did not save prior_h and prior_e in state_dict in the past + """ + if hasattr(self, "prior_h") and "prior_h" not in checkpoint["state_dict"]: + checkpoint["state_dict"]["prior_h"] = self.get_categorical_prop('atom') + if hasattr(self, "prior_e") and "prior_e" not in checkpoint["state_dict"]: + checkpoint["state_dict"]["prior_e"] = self.get_categorical_prop('bond') + if "prior_e" in checkpoint["state_dict"] and not hasattr(self, "prior_e"): + # NOTE: a very exotic case that happened to one model. 
Potentially can be removed in the future + self.register_buffer("prior_e", self.get_categorical_prop('bond')) + + def init_model(self, predictor_params): + + model_type = predictor_params.backbone + + if 'heterogeneous_graph' in predictor_params and predictor_params.heterogeneous_graph: + return DynamicsHetero( + atom_nf=self.atom_nf, + residue_nf=self.residue_nf, + bond_dict=self.bond_encoder, + pocket_bond_dict=self.residue_bond_encoder, + model=model_type, + num_rbf_time=predictor_params.__dict__.get('num_rbf_time'), + model_params=getattr(predictor_params, model_type + '_params'), + edge_cutoff_ligand=predictor_params.edge_cutoff_ligand, + edge_cutoff_pocket=predictor_params.edge_cutoff_pocket, + edge_cutoff_interaction=predictor_params.edge_cutoff_interaction, + predict_angles=self.flexible, + predict_frames=self.flexible_bb, + add_cycle_counts=predictor_params.cycle_counts, + add_spectral_feat=predictor_params.spectral_feat, + add_nma_feat=predictor_params.normal_modes, + reflection_equiv=predictor_params.reflection_equivariant, + d_max=predictor_params.d_max, + num_rbf_dist=predictor_params.num_rbf, + self_conditioning=predictor_params.self_conditioning, + augment_residue_sc=self.augment_residue_sc, + augment_ligand_sc=self.augment_ligand_sc, + add_chi_as_feature=predictor_params.add_chi_as_feature, + angle_act_fn=predictor_params.angle_act_fn, + add_all_atom_diff=predictor_params.add_all_atom_diff, + predict_confidence=self.predict_confidence, + ) + + else: + if predictor_params.__dict__.get('num_rbf_time') is not None: + raise NotImplementedError("RBF time embedding not yet implemented") + + return Dynamics( + atom_nf=self.atom_nf, + residue_nf=self.residue_nf, + joint_nf=predictor_params.joint_nf, + bond_dict=self.bond_encoder, + pocket_bond_dict=self.residue_bond_encoder, + edge_nf=predictor_params.edge_nf, + hidden_nf=predictor_params.hidden_nf, + model=model_type, + model_params=getattr(predictor_params, model_type + '_params'), + edge_cutoff_ligand=predictor_params.edge_cutoff_ligand, + edge_cutoff_pocket=predictor_params.edge_cutoff_pocket, + edge_cutoff_interaction=predictor_params.edge_cutoff_interaction, + predict_angles=self.flexible, + predict_frames=self.flexible_bb, + add_cycle_counts=predictor_params.cycle_counts, + add_spectral_feat=predictor_params.spectral_feat, + add_nma_feat=predictor_params.normal_modes, + self_conditioning=predictor_params.self_conditioning, + augment_residue_sc=self.augment_residue_sc, + augment_ligand_sc=self.augment_ligand_sc, + add_chi_as_feature=predictor_params.add_chi_as_feature, + angle_act_fn=predictor_params.angle_act_fn, + ) + + def _load_histogram(self, type): + """ + Load empirical categorical distributions of atom or bond types from disk. + Returns None if the required file is not found. 
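The histogram loaded below is a pickled dict saved with `np.save`, and its normalized counts become the marginal prior registered as a buffer in `__init__`; when the file is missing, `get_categorical_prop` deliberately returns an all-NaN tensor so any code path that forgets to restore the real prior fails loudly. A sketch of both cases (numpy/torch; the toy encoder, counts, and dict layout are assumptions for illustration):

```python
import numpy as np
import torch

atom_encoder = {"C": 0, "N": 1, "O": 2}  # toy encoder
hist = {"C": 700, "N": 150, "O": 150}    # toy counts, dict layout assumed

p = np.array([hist[a] for a in atom_encoder], dtype=float)
prior = torch.tensor(p / p.sum())        # marginal prior, sums to 1
print(prior)                             # tensor([0.7000, 0.1500, 0.1500], ...)

missing = torch.zeros(len(atom_encoder)) * float("nan")  # NaN sentinel
assert torch.isnan(missing).all()        # fails loudly if ever used
```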
+ """ + assert type in {"atom", "bond"} + filename = 'ligand_type_histogram.npy' if type == 'atom' else 'ligand_bond_type_histogram.npy' + encoder = self.atom_encoder if type == 'atom' else self.bond_encoder + hist_file = Path(self.datadir, filename) + if not hist_file.exists(): + return None + hist = np.load(hist_file, allow_pickle=True).item() + return CategoricalDistribution(hist, encoder) + + def get_categorical_prop(self, type): + hist = self._load_histogram(type) + encoder = self.atom_encoder if type == 'atom' else self.bond_encoder + # Note: default value ensures that code will crash if prior is not + # read from disk or loaded from checkpoint later on + return torch.zeros(len(encoder)) * float("nan") if hist is None else torch.tensor(hist.p) + + def configure_optimizers(self): + optimizers = [ + torch.optim.AdamW(self.parameters(), lr=self.lr, amsgrad=True, weight_decay=1e-12), + ] + + if self.lr_step_size is None or self.lr_gamma is None: + lr_schedulers = [] + else: + lr_schedulers = [ + torch.optim.lr_scheduler.StepLR(optimizers[0], step_size=self.lr_step_size, gamma=self.lr_gamma), + ] + return optimizers, lr_schedulers + + def setup(self, stage: Optional[str] = None): + + self.setup_sampling() + + if stage == 'fit': + self.train_dataset = self.get_dataset(stage='train') + self.val_dataset = self.get_dataset(stage='val') + self.setup_metrics() + elif stage == 'val': + self.val_dataset = self.get_dataset(stage='val') + self.setup_metrics() + elif stage == 'test': + self.test_dataset = self.get_dataset(stage='test') + self.setup_metrics() + elif stage == 'generation': + pass + else: + raise NotImplementedError + + def get_dataset(self, stage, pocket_transform=None): + + # when sampling we don't append virtual nodes as we might need access to the ground truth size + if self.virtual_nodes and stage == "train": + ligand_transform = AppendVirtualNodesInCoM( + atom_encoder, bond_encoder, add_min=self.add_virtual_min, add_max=self.add_virtual_max) + else: + ligand_transform = None + + # we want to know if something goes wrong on the validation or test set + catch_errors = stage == "train" + + if self.sharded_dataset: + return get_wds( + data_path=self.datadir, + stage='val' if self.debug else stage, + ligand_transform=ligand_transform, + pocket_transform=pocket_transform, + ) + + if self.sample_from_clusters and stage == "train": # val/test should be deterministic + return ClusteredDataset( + pt_path=Path(self.datadir, 'val.pt' if self.debug else f'{stage}.pt'), + ligand_transform=ligand_transform, + pocket_transform=pocket_transform, + catch_errors=catch_errors + ) + + return ProcessedLigandPocketDataset( + pt_path=Path(self.datadir, 'val.pt' if self.debug else f'{stage}.pt'), + ligand_transform=ligand_transform, + pocket_transform=pocket_transform, + catch_errors=catch_errors + ) + + def setup_sampling(self): + # distribution of nodes + histogram_file = Path(self.datadir, 'size_distribution.npy') # TODO: store this in model checkpoint so that we can sample without this file + size_histogram = np.load(histogram_file).tolist() + self.size_distribution = DistributionNodes(size_histogram) + + def setup_metrics(self): + # For metrics + smiles_file = Path(self.datadir, 'train_smiles.npy') + self.train_smiles = None if not smiles_file.exists() else np.load(smiles_file) + + self.ligand_metrics = MoleculeValidity() + self.molecule_properties = MolecularProperties() + self.evaluator = FullEvaluator(gnina=self.gnina, exclude_evaluators=['geometry', 'ring_count']) + 
self.ligand_atom_type_distribution = self._load_histogram('atom') + self.ligand_bond_type_distribution = self._load_histogram('bond') + + def train_dataloader(self): + shuffle = None if self.overfit else False if self.sharded_dataset else True + return DataLoader(self.train_dataset, self.batch_size, shuffle=shuffle, + sampler=SubsetRandomSampler([0]) if self.overfit else None, + num_workers=self.num_workers, + collate_fn=self.train_dataset.collate_fn, + # collate_fn=partial(self.train_dataset.collate_fn, ligand_transform=batch_transform), + pin_memory=True) + + def val_dataloader(self): + if self.overfit: + return self.train_dataloader() + + return DataLoader(self.val_dataset, self.eval_batch_size, + shuffle=False, num_workers=self.num_workers, + collate_fn=self.val_dataset.collate_fn, + pin_memory=True) + + def test_dataloader(self): + return DataLoader(self.test_dataset, self.eval_batch_size, shuffle=False, + num_workers=self.num_workers, + collate_fn=self.test_dataset.collate_fn, + pin_memory=True) + + def log_metrics(self, metrics_dict, split, batch_size=None, **kwargs): + for m, value in metrics_dict.items(): + self.log(f'{m}/{split}', value, batch_size=batch_size, **kwargs) + + def aggregate_metrics(self, step_outputs, prefix): + if 'timestep' in step_outputs[0]: + timesteps = torch.cat([x['timestep'] for x in step_outputs]).squeeze() + + if 'loss_per_sample' in step_outputs[0]: + losses = torch.cat([x['loss_per_sample'] for x in step_outputs]) + pearson_corr = torch.corrcoef(torch.stack([timesteps, losses], dim=0))[0, 1] + self.log(f'corr_loss_timestep/{prefix}', pearson_corr, prog_bar=False) + + if 'eps_hat_norm' in step_outputs[0]: + eps_norm = torch.cat([x['eps_hat_norm'] for x in step_outputs]) + pearson_corr = torch.corrcoef(torch.stack([timesteps, eps_norm], dim=0))[0, 1] + self.log(f'corr_eps_timestep/{prefix}', pearson_corr, prog_bar=False) + + def on_train_epoch_end(self): + self.aggregate_metrics(self.training_step_outputs, 'train') + self.training_step_outputs.clear() + + # TODO: doesn't work in multi-GPU mode + # def on_before_batch_transfer(self, batch, dataloader_idx): + # """ + # Performs operations on data before it is transferred to the GPU. + # Hence, supports multiple dataloaders for speedup. + # """ + # batch['pocket'] = Residues(**batch['pocket']) + # return batch + + # # TODO: try if this is compatible with DDP + # def on_after_batch_transfer(self, batch, dataloader_idx): + # """ + # Performs operations on data after it is transferred to the GPU. 
+ # """ + # batch['pocket'] = Residues(**batch['pocket']) + # batch['ligand'] = TensorDict(**batch['ligand']) + # return batch + + def get_sc_transform_fn(self, zt_chi, zt_x, t, z0_chi, ligand_mask, pocket): + sc_transform = {} + + if self.augment_residue_sc: + def pred_all_atom(pred_chi, pred_trans=None, pred_rot=None): + temp_pocket = pocket.deepcopy() + + if pred_trans is not None and pred_rot is not None: + zt_trans = pocket['x'] + zt_rot = pocket['axis_angle'] + z1_trans_pred = self.module_trans.get_z1_given_zt_and_pred( + zt_trans, pred_trans, None, t, pocket['mask']) + z1_rot_pred = self.module_rot.get_z1_given_zt_and_pred( + zt_rot, pred_rot, None, t, pocket['mask']) + temp_pocket.set_frame(z1_trans_pred, z1_rot_pred) + + z1_chi_pred = self.module_chi.get_z1_given_zt_and_pred( + zt_chi[..., :5], pred_chi, z0_chi, t, pocket['mask']) + temp_pocket.set_chi(z1_chi_pred) + + all_coord = temp_pocket['v'] + temp_pocket['x'].unsqueeze(1) + return all_coord - pocket['x'].unsqueeze(1) + + sc_transform['residues'] = pred_all_atom + + if self.augment_ligand_sc: + # sc_transform['atoms'] = partial(self.module_x.get_z1_given_zt_and_pred, zt=zs_x, z0=None, t=t, batch_mask=lig_mask) + sc_transform['atoms'] = lambda pred: (self.module_x.get_z1_given_zt_and_pred( + zt_x, pred.squeeze(1), None, t, ligand_mask) - zt_x).unsqueeze(1) + + return sc_transform + + def compute_loss(self, ligand, pocket, return_info=False): + """ + Samples time steps and computes network predictions + """ + # TODO: move somewhere else (like collate_fn) + pocket = Residues(**pocket) + + # Center sample + ligand, pocket = center_data(ligand, pocket) + if pocket['x'].numel() > 0: + pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0) + else: + pocket_com = scatter_mean(ligand['x'], ligand['mask'], dim=0) + + # # Normalize pocket coordinates + # pocket['x'] = self.module_x.normalize(pocket['x']) + + # Sample a timestep t for each example in batch + t = torch.rand(ligand['size'].size(0), device=ligand['x'].device).unsqueeze(-1) + + # Noise + z0_x = self.module_x.sample_z0(pocket_com, ligand['mask']) + z0_h = self.module_h.sample_z0(ligand['mask']) + z0_e = self.module_e.sample_z0(ligand['bond_mask']) + zt_x = self.module_x.sample_zt(z0_x, ligand['x'], t, ligand['mask']) + zt_h = self.module_h.sample_zt(z0_h, ligand['one_hot'], t, ligand['mask']) + zt_e = self.module_e.sample_zt(z0_e, ligand['bond_one_hot'], t, ligand['bond_mask']) + + if self.flexible_bb: + z0_trans = self.module_trans.sample_z0(pocket_com, pocket['mask']) + z1_trans = pocket['x'].detach().clone() + zt_trans = self.module_trans.sample_zt(z0_trans, z1_trans, t, pocket['mask']) + + z0_rot = self.module_rot.sample_z0(pocket['mask']) + z1_rot = pocket['axis_angle'].detach().clone() + zt_rot = self.module_rot.sample_zt(z0_rot, z1_rot, t, pocket['mask']) + + # update pocket + pocket.set_frame(zt_trans, zt_rot) + + z0_chi, zt_chi = None, None + if self.flexible: + # residues = [data_utils.residue_from_internal_coord(ic) for ic in pocket['residues']] + # residues = pocket['residues'] + # z1_chi = torch.stack([data_utils.get_torsion_angles(r, device=self.device) for r in residues], dim=0) + z1_chi = pocket['chi'][:, :5].detach().clone() + + z0_chi = self.module_chi.sample_z0(pocket['mask']) + zt_chi = self.module_chi.sample_zt(z0_chi, z1_chi, t, pocket['mask']) + + # internal to external coordinates + pocket.set_chi(zt_chi) + if pocket['x'].numel() == 0: + pocket.set_empty_v() + + # Predict denoising + sc_transform = self.get_sc_transform_fn(zt_chi, zt_x, t, 
z0_chi, ligand['mask'], pocket) + # sc_transform = None + pred_ligand, pred_residues = self.dynamics( + zt_x, zt_h, ligand['mask'], pocket, t, + bonds_ligand=(ligand['bonds'], zt_e), sc_transform=sc_transform + ) + + # Compute L2 loss + if self.predict_confidence: + loss_x = self.module_x.compute_loss(pred_ligand['vel'], z0_x, ligand['x'], t, ligand['mask'], reduce='none') + + # compute confidence regularization + k = self.module_x.dim # pred.size(-1) + sigma = pred_ligand['uncertainty_vel'] + loss_x = loss_x / (2 * sigma ** 2) + k * torch.log(sigma) + + if self.regularize_uncertainty is not None: + loss_x = loss_x + self.regularize_uncertainty * (pred_ligand['uncertainty_vel'] - 1) ** 2 + + loss_x = self.module_x.reduce_loss(loss_x, ligand['mask'], reduce=self.loss_reduce) + else: + loss_x = self.module_x.compute_loss(pred_ligand['vel'], z0_x, ligand['x'], t, ligand['mask'], reduce=self.loss_reduce) + + # Loss for categorical variables + t_next = torch.clamp(t + self.train_step_size, max=1.0) + loss_h = self.module_h.compute_loss(pred_ligand['logits_h'], zt_h, ligand['one_hot'], ligand['mask'], t, t_next, reduce=self.loss_reduce) + loss_e = self.module_e.compute_loss(pred_ligand['logits_e'], zt_e, ligand['bond_one_hot'], ligand['bond_mask'], t, t_next, reduce=self.loss_reduce) + + loss = self.lambda_x * loss_x + self.lambda_h * loss_h + self.lambda_e * loss_e + if self.flexible: + loss_chi = self.module_chi.compute_loss(pred_residues['chi'], z0_chi, z1_chi, zt_chi, t, pocket['mask'], reduce=self.loss_reduce) + loss = loss + self.lambda_chi * loss_chi + + if self.flexible_bb: + loss_trans = self.module_trans.compute_loss(pred_residues['trans'], z0_trans, z1_trans, t, pocket['mask'], reduce=self.loss_reduce) + loss_rot = self.module_rot.compute_loss(pred_residues['rot'], z0_rot, z1_rot, zt_rot, t, pocket['mask'], reduce=self.loss_reduce) + loss = loss + self.lambda_trans * loss_trans + self.lambda_rot * loss_rot + + if self.lambda_clash is not None and self.lambda_clash > 0: + + if self.flexible_bb: + pred_z1_trans = self.module_trans.get_z1_given_zt_and_pred(zt_trans, pred_residues['trans'], z0_trans, t, pocket['mask']) + pred_z1_rot = self.module_rot.get_z1_given_zt_and_pred(zt_rot, pred_residues['rot'], z0_rot, t, pocket['mask']) + pocket.set_frame(pred_z1_trans, pred_z1_rot) + + if self.flexible: + # internal to external coordinates + pred_z1_chi = self.module_chi.get_z1_given_zt_and_pred(zt_chi, pred_residues['chi'], z0_chi, t, pocket['mask']) + pocket.set_chi(pred_z1_chi) + + pocket_coord = pocket['x'].unsqueeze(1) + pocket['v'] + pocket_types = aa_atom_type_tensor[pocket['one_hot'].argmax(dim=-1)] + pocket_mask = pocket['mask'].unsqueeze(-1).repeat((1, pocket['v'].size(1))) + + # Extract only existing atoms + atom_mask = aa_atom_mask_tensor[pocket['one_hot'].argmax(dim=-1)] + pocket_coord = pocket_coord[atom_mask] + pocket_types = pocket_types[atom_mask] + pocket_mask = pocket_mask[atom_mask] + + # pred_z1_x = pred_x + z0_x + pred_z1_x = self.module_x.get_z1_given_zt_and_pred(zt_x, pred_ligand['vel'], z0_x, t, ligand['mask']) + pred_z1_h = pred_ligand['logits_h'].argmax(dim=-1) + loss_clash = clash_loss(pred_z1_x, pred_z1_h, ligand['mask'], + pocket_coord, pocket_types, pocket_mask) + loss = loss + self.lambda_clash * loss_clash + + if self.timestep_weights is not None: + w_t = self.timestep_weights(t).squeeze() + loss = w_t * loss + + loss = loss.mean(0) + + info = { + 'loss_x': loss_x.mean().item(), + 'loss_h': loss_h.mean().item(), + 'loss_e': loss_e.mean().item(), + } + if 
self.flexible: + info['loss_chi'] = loss_chi.mean().item() + if self.flexible_bb: + info['loss_trans'] = loss_trans.mean().item() + info['loss_rot'] = loss_rot.mean().item() + if self.lambda_clash is not None: + info['loss_clash'] = loss_clash.mean().item() + if self.predict_confidence: + sigma_x_mol = scatter_mean(pred_ligand['uncertainty_vel'], ligand['mask'], dim=0) + info['pearson_sigma_x'] = torch.corrcoef(torch.stack([sigma_x_mol.detach(), t.squeeze()]))[0, 1].item() + info['mean_sigma_x'] = sigma_x_mol.mean().item() + entropy_h = Categorical(logits=pred_ligand['logits_h']).entropy() + entropy_h_mol = scatter_mean(entropy_h, ligand['mask'], dim=0) + info['pearson_entropy_h'] = torch.corrcoef(torch.stack([entropy_h_mol.detach(), t.squeeze()]))[0, 1].item() + info['mean_entropy_h'] = entropy_h_mol.mean().item() + entropy_e = Categorical(logits=pred_ligand['logits_e']).entropy() + entropy_e_mol = scatter_mean(entropy_e, ligand['bond_mask'], dim=0) + info['pearson_entropy_e'] = torch.corrcoef(torch.stack([entropy_e_mol.detach(), t.squeeze()]))[0, 1].item() + info['mean_entropy_e'] = entropy_e_mol.mean().item() + + return (loss, info) if return_info else loss + + def training_step(self, data, *args): + ligand, pocket = data['ligand'], data['pocket'] + try: + loss, info = self.compute_loss(ligand, pocket, return_info=True) + except RuntimeError as e: + # this is not supported for multi-GPU + if self.trainer.num_devices < 2 and 'out of memory' in str(e): + print('WARNING: ran out of memory, skipping to the next batch') + return None + else: + raise e + + log_dict = {k: v for k, v in info.items() if isinstance(v, float) + or torch.numel(v) <= 1} + # if self.learn_nu: + # log_dict['nu_x'] = self.noise_schedules['x'].nu.item() + # log_dict['nu_h'] = self.noise_schedules['h'].nu.item() + # log_dict['nu_e'] = self.noise_schedules['e'].nu.item() + + self.log_metrics({'loss': loss, **log_dict}, 'train', + batch_size=len(ligand['size'])) + + out = {'loss': loss, **info} + self.training_step_outputs.append(out) + return out + + def validation_step(self, data, *args): + + # Compute the loss N times and average to get a better estimate + loss_list, info_list = [], [] + self.dynamics.train() # TODO: this is currently necessary to make self-conditioning work + for _ in range(self.n_loss_per_sample): + loss, info = self.compute_loss(data['ligand'].copy(), + data['pocket'].copy(), + return_info=True) + loss_list.append(loss.item()) + info_list.append(info) + self.dynamics.eval() + if len(loss_list) >= 1: + loss = np.mean(loss_list) + info = {k: np.mean([x[k] for x in info_list]) for k in info_list[0]} + self.log_metrics({'loss': loss, **info}, 'val', batch_size=len(data['ligand']['size'])) + + # Sample + rdmols, rdpockets, _ = self.sample( + data=data, + n_samples=self.n_eval_samples, + num_nodes="ground_truth" if self.sample_with_ground_truth_size else None, + ) + + out = { + 'ligands': rdmols, + 'pockets': rdpockets, + 'receptor_files': [Path(self.receptor_dir, 'val', x) for x in data['pocket']['name']] + } + self.validation_step_outputs.append(out) + return out + + # def test_step(self, data, *args): + # self._shared_eval(data, 'test', *args) + + def on_validation_epoch_end(self): + + outdir = Path(self.outdir, f'epoch_{self.current_epoch}') + + rdmols = [m for x in self.validation_step_outputs for m in x['ligands']] + rdpockets = [p for x in self.validation_step_outputs for p in x['pockets']] + receptors = [r for x in self.validation_step_outputs for r in x['receptor_files']] + 
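+        # For the bond-type histogram below, every unordered atom pair without
+        # an explicit bond counts as NOBOND: e.g. a 4-atom molecule with
+        # 3 bonds contributes 4 * 3 / 2 - 3 = 3 NOBOND entries, mirroring the
+        # model's dense upper-triangular edge representation.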
self.validation_step_outputs.clear() + + ligand_atom_types = [atom_encoder[a.GetSymbol()] for m in rdmols for a in m.GetAtoms()] + ligand_bond_types = [] + for m in rdmols: + bonds = m.GetBonds() + no_bonds = m.GetNumAtoms() * (m.GetNumAtoms() - 1) // 2 - m.GetNumBonds() + ligand_bond_types += [bond_encoder['NOBOND']] * no_bonds + for b in bonds: + ligand_bond_types.append(bond_encoder[b.GetBondType().name]) + + tic = time() + results = self.analyze_sample( + rdmols, ligand_atom_types, ligand_bond_types, receptors=(rdpockets if len(rdpockets) != 0 else None) + ) + self.log_metrics(results, 'val') + print(f'Evaluation took {time() - tic:.2f} seconds') + + if (self.current_epoch + 1) % self.visualize_sample_epoch == 0: + tic = time() + + outdir.mkdir(exist_ok=True, parents=True) + + # center for better visualization + rdmols = rdmols[:self.n_visualize_samples] + rdpockets = rdpockets[:self.n_visualize_samples] + for m, p in zip(rdmols, rdpockets): + center = m.GetConformer().GetPositions().mean(axis=0) + for i in range(m.GetNumAtoms()): + x, y, z = m.GetConformer().GetPositions()[i] - center + m.GetConformer().SetAtomPosition(i, (x, y, z)) + for i in range(p.GetNumAtoms()): + x, y, z = p.GetConformer().GetPositions()[i] - center + p.GetConformer().SetAtomPosition(i, (x, y, z)) + + # save molecule + utils.write_sdf_file(Path(outdir, 'molecules.sdf'), rdmols) + + # save pocket + utils.write_sdf_file(Path(outdir, 'pockets.sdf'), rdpockets) + + print(f'Sample visualization took {time() - tic:.2f} seconds') + + if (self.current_epoch + 1) % self.visualize_chain_epoch == 0: + tic = time() + outdir.mkdir(exist_ok=True, parents=True) + + if self.sharded_dataset: + index = torch.randint(len(self.val_dataset), size=(1,)).item() + for i, x in enumerate(self.val_dataset): + if i == index: + break + batch = self.val_dataset.collate_fn([x]) + else: + batch = self.val_dataset.collate_fn([self.val_dataset[torch.randint(len(self.val_dataset), size=(1,))]]) + batch['pocket'] = Residues(**batch['pocket']).to(self.device) + pocket_copy = batch['pocket'].copy() + + if len(batch['pocket']['x']) > 0: + ligand_chain, pocket_chain, info = self.sample_chain(batch['pocket'], self.keep_frames) + else: + num_nodes, _ = self.size_distribution.sample() + ligand_chain, pocket_chain, info = self.sample_chain(batch['pocket'], self.keep_frames, num_nodes=num_nodes) + + # utils.write_sdf_file(Path(outdir, 'chain_pocket.sdf'), pocket_chain) + # utils.write_chain(Path(outdir, 'chain_pocket.xyz'), pocket_chain) + if self.flexible or self.flexible_bb: + # insert ground truth at the beginning so that it's used by PyMOL to determine the connectivity + ground_truth_pocket = pocket_to_rdkit( + pocket_copy, self.pocket_representation, + self.atom_encoder, self.atom_decoder, + self.aa_decoder, self.residue_decoder, + self.aa_atom_index + )[0] + ground_truth_ligand = build_molecule( + batch['ligand']['x'], batch['ligand']['one_hot'].argmax(1), + bonds=batch['ligand']['bonds'], + bond_types=batch['ligand']['bond_one_hot'].argmax(1), + atom_decoder=self.atom_decoder, + bond_decoder=self.bond_decoder + ) + pocket_chain.insert(0, ground_truth_pocket) + ligand_chain.insert(0, ground_truth_ligand) + # pocket_chain.insert(0, pocket_chain[-1]) + # ligand_chain.insert(0, ligand_chain[-1]) + + # save molecules + utils.write_sdf_file(Path(outdir, 'chain_ligand.sdf'), ligand_chain) + + # save pocket + mols_to_pdbfile(pocket_chain, Path(outdir, 'chain_pocket.pdb')) + + self.log_metrics(info, 'val') + print(f'Chain visualization took {time() - 
tic:.2f} seconds') + + + # NOTE: temporary fix of this Lightning bug: + # https://github.com/Lightning-AI/pytorch-lightning/discussions/18110 + # Without it resume training has a strange behavior and fails + @property + def total_batch_idx(self) -> int: + """Returns the current batch index (across epochs)""" + # use `ready` instead of `completed` in case this is accessed after `completed` has been increased + # but before the next `ready` increase + return max(0, self.batch_progress.total.ready - 1) + + @property + def batch_idx(self) -> int: + """Returns the current batch index (within this epoch)""" + # use `ready` instead of `completed` in case this is accessed after `completed` has been increased + # but before the next `ready` increase + return max(0, self.batch_progress.current.ready - 1) + + # def analyze_sample(self, rdmols, atom_types, bond_types, aa_types=None, receptors=None): + # out = {} + + # # Distribution of node types + # kl_div_atom = self.ligand_atom_type_distribution.kl_divergence(atom_types) \ + # if self.ligand_atom_type_distribution is not None else -1 + # out['kl_div_atom_types'] = kl_div_atom + + # # Distribution of edge types + # kl_div_bond = self.ligand_bond_type_distribution.kl_divergence(bond_types) \ + # if self.ligand_bond_type_distribution is not None else -1 + # out['kl_div_bond_types'] = kl_div_bond + + # if aa_types is not None: + # kl_div_aa = self.pocket_type_distribution.kl_divergence(aa_types) \ + # if self.pocket_type_distribution is not None else -1 + # out['kl_div_residue_types'] = kl_div_aa + + # # Post-process sample + # processed_mols = [process_all(m) for m in rdmols] + + # # Other basic metrics + # results = self.ligand_metrics(rdmols) + # out['n_samples'] = results['n_total'] + # out['Validity'] = results['validity'] + # out['Connectivity'] = results['connectivity'] + # out['valid_and_connected'] = results['valid_and_connected'] + + # # connected_mols = [get_largest_fragment(m) for m in rdmols] + # connected_mols = [process_all(m, largest_frag=True, adjust_aromatic_Ns=False, relax_iter=0) for m in rdmols] + # connected_mols = [m for m in connected_mols if m is not None] + # out.update(self.molecule_properties(connected_mols)) + + # # Repeat after post-processing + # results = self.ligand_metrics(processed_mols) + # out['validity_processed'] = results['validity'] + # out['connectivity_processed'] = results['connectivity'] + # out['valid_and_connected_processed'] = results['valid_and_connected'] + + # processed_mols = [m for m in processed_mols if m is not None] + # for k, v in self.molecule_properties(processed_mols).items(): + # out[f"{k}_processed"] = v + + # # Simple docking score + # if receptors is not None and self.gnina is not None: + # assert len(receptors) == len(rdmols) + # docking_results = compute_gnina_scores(rdmols, receptors, gnina=self.gnina) + # out.update(docking_results) + + # # Clash score + # if receptors is not None: + # assert len(receptors) == len(rdmols) + # clashes = { + # 'ligands': [legacy_clash_score(m) for m in rdmols], + # 'pockets': [legacy_clash_score(p) for p in receptors], + # 'between': [legacy_clash_score(m, p) for m, p in zip(rdmols, receptors)], + # 'v2_ligands': [clash_score(m) for m in rdmols], + # 'v2_pockets': [clash_score(p) for p in receptors], + # 'v2_between': [clash_score(m, p) for m, p in zip(rdmols, receptors)] + # } + # for k, v in clashes.items(): + # out[f'mean_clash_score_{k}'] = np.mean(v) + # out[f'frac_no_clashes_{k}'] = np.mean(np.array(v) <= 0.0) + + # return out + + def 
analyze_sample(self, rdmols, atom_types, bond_types, aa_types=None, receptors=None):
+        out = {}
+
+        # Distribution of node types
+        kl_div_atom = self.ligand_atom_type_distribution.kl_divergence(atom_types) \
+            if self.ligand_atom_type_distribution is not None else -1
+        out['kl_div_atom_types'] = kl_div_atom
+
+        # Distribution of edge types
+        kl_div_bond = self.ligand_bond_type_distribution.kl_divergence(bond_types) \
+            if self.ligand_bond_type_distribution is not None else -1
+        out['kl_div_bond_types'] = kl_div_bond
+
+        if aa_types is not None:
+            kl_div_aa = self.pocket_type_distribution.kl_divergence(aa_types) \
+                if self.pocket_type_distribution is not None else -1
+            out['kl_div_residue_types'] = kl_div_aa
+
+        # Evaluation
+        results = []
+        if receptors is not None:
+            with tempfile.TemporaryDirectory() as tmpdir:
+                for mol, receptor in zip(tqdm(rdmols, desc='FullEvaluator'), receptors):
+                    receptor_path = Path(tmpdir, 'receptor.pdb')
+                    Chem.MolToPDBFile(receptor, str(receptor_path))
+                    results.append(self.evaluator(mol, receptor_path))
+        else:
+            # no receptors available: use a molecule-only evaluator
+            self.evaluator = FullEvaluator(pb_conf='mol')
+            for mol in tqdm(rdmols, desc='FullEvaluator'):
+                results.append(self.evaluator(mol))
+
+        results = pd.DataFrame(results)
+        agg_results = aggregated_metrics(results, self.evaluator.dtypes, VALIDITY_METRIC_NAME).fillna(0)
+        # literal replacement; '.' must not be treated as a regex wildcard
+        agg_results['metric'] = agg_results['metric'].str.replace('.', '/', regex=False)
+
+        col_results = collection_metrics(results, self.train_smiles, VALIDITY_METRIC_NAME, exclude_evaluators='fcd')
+        col_results['metric'] = 'collection/' + col_results['metric']
+
+        all_results = pd.concat([agg_results, col_results])
+        out.update(**dict(all_results[['metric', 'value']].values))
+
+        return out
+
+    def sample_zt_given_zs(self, zs_ligand, zs_pocket, s, t, delta_eps_x=None, uncertainty=None):
+
+        sc_transform = self.get_sc_transform_fn(zs_pocket.get('chi'), zs_ligand['x'], s, None, zs_ligand['mask'], zs_pocket)
+        pred_ligand, pred_residues = self.dynamics(
+            zs_ligand['x'], zs_ligand['h'], zs_ligand['mask'], zs_pocket, s,
+            bonds_ligand=(zs_ligand['bonds'], zs_ligand['e']),
+            sc_transform=sc_transform
+        )
+
+        if delta_eps_x is not None:
+            pred_ligand['vel'] = pred_ligand['vel'] + delta_eps_x
+
+        zt_ligand = zs_ligand.copy()
+        zt_ligand['x'] = self.module_x.sample_zt_given_zs(zs_ligand['x'], pred_ligand['vel'], s, t, zs_ligand['mask'])
+        zt_ligand['h'] = self.module_h.sample_zt_given_zs(zs_ligand['h'], pred_ligand['logits_h'], s, t, zs_ligand['mask'])
+        zt_ligand['e'] = self.module_e.sample_zt_given_zs(zs_ligand['e'], pred_ligand['logits_e'], s, t, zs_ligand['edge_mask'])
+
+        zt_pocket = zs_pocket.copy()
+        if self.flexible_bb:
+            zt_trans_pocket = self.module_trans.sample_zt_given_zs(zs_pocket['x'], pred_residues['trans'], s, t, zs_pocket['mask'])
+            zt_rot_pocket = self.module_rot.sample_zt_given_zs(zs_pocket['axis_angle'], pred_residues['rot'], s, t, zs_pocket['mask'])
+
+            # update pocket in-place
+            zt_pocket.set_frame(zt_trans_pocket, zt_rot_pocket)
+
+        if self.flexible:
+            zt_chi_pocket = self.module_chi.sample_zt_given_zs(zs_pocket['chi'][..., :5], pred_residues['chi'], s, t, zs_pocket['mask'])
+
+            # update pocket in-place
+            zt_pocket.set_chi(zt_chi_pocket)
+
+        if self.predict_confidence:
+            assert uncertainty is not None
+            dt = (t - s).view(-1)[zt_ligand['mask']]
+            uncertainty['sigma_x_squared'] += (dt * pred_ligand['uncertainty_vel']**2)
+            uncertainty['entropy_h'] += (dt * Categorical(logits=pred_ligand['logits_h']).entropy())
+
+        return zt_ligand, zt_pocket
+
+    def simulate(self, ligand, pocket,
timesteps, t_start, t_end=1.0,
+                 return_frames=1, guide_log_prob=None):
+        """
+        Take a version of the ligand and pocket (at any time step t_start) and
+        simulate the generative process from t_start to t_end.
+        """
+
+        assert 0 < return_frames <= timesteps
+        assert timesteps % return_frames == 0
+        assert 0.0 <= t_start < 1.0
+        assert 0 < t_end <= 1.0
+        assert t_start < t_end
+
+        device = ligand['x'].device
+        n_samples = len(pocket['size'])
+        delta_t = (t_end - t_start) / timesteps
+
+        # Initialize output tensors
+        out_ligand = {
+            'x': torch.zeros((return_frames, len(ligand['mask']), self.x_dim), device=device),
+            'h': torch.zeros((return_frames, len(ligand['mask']), self.atom_nf), device=device),
+            'e': torch.zeros((return_frames, len(ligand['edge_mask']), self.bond_nf), device=device)
+        }
+        if self.predict_confidence:
+            out_ligand['sigma_x'] = torch.zeros((return_frames, len(ligand['mask'])), device=device)
+            out_ligand['entropy_h'] = torch.zeros((return_frames, len(ligand['mask'])), device=device)
+        out_pocket = {
+            'x': torch.zeros((return_frames, len(pocket['mask']), 3), device=device),  # CA-coord
+            'v': torch.zeros((return_frames, len(pocket['mask']), self.n_atom_aa, 3), device=device)  # difference vectors to all other atoms
+        }
+
+        cumulative_uncertainty = {
+            'sigma_x_squared': torch.zeros(len(ligand['mask']), device=device),
+            'entropy_h': torch.zeros(len(ligand['mask']), device=device)
+        } if self.predict_confidence else None
+
+        for i, t in enumerate(torch.linspace(t_start, t_end - delta_t, timesteps)):
+            t_array = torch.full((n_samples, 1), fill_value=t, device=device)
+
+            if guide_log_prob is not None:
+                raise NotImplementedError('Not yet implemented for flow matching model')
+                # NOTE: everything below this raise is unreachable legacy code
+                # from the diffusion-based predecessor (it references undefined
+                # names such as zt_x_ligand and self.diffusion_x); kept for
+                # reference only.
+                alpha_t = self.diffusion_x.schedule.alpha(self.gamma_x(t_array))
+
+                with torch.enable_grad():
+                    zt_x_ligand.requires_grad = True
+                    g = guide_log_prob(t_array, x=ligand['x'], h=ligand['h'], batch_mask=ligand['mask'],
+                                       bonds=ligand['bonds'], bond_types=ligand['e'])
+
+                    # Compute gradient w.r.t.
coordinates + grad_x_lig = torch.autograd.grad(g.sum(), inputs=ligand['x'])[0] + + # clip gradients + g_max = 1.0 + clip_mask = (grad_x_lig.norm(dim=-1) > g_max) + grad_x_lig[clip_mask] = \ + grad_x_lig[clip_mask] / grad_x_lig[clip_mask].norm( + dim=-1, keepdim=True) * g_max + + delta_eps_lig = -1 * (1 - alpha_t[lig_mask]).sqrt() * grad_x_lig + else: + delta_eps_lig = None + + ligand, pocket = self.sample_zt_given_zs( + ligand, pocket, t_array, t_array + delta_t, delta_eps_lig, cumulative_uncertainty) + + # save frame + if (i + 1) % (timesteps // return_frames) == 0: + idx = (i + 1) // (timesteps // return_frames) + idx = idx - 1 + + out_ligand['x'][idx] = ligand['x'].detach() + out_ligand['h'][idx] = ligand['h'].detach() + out_ligand['e'][idx] = ligand['e'].detach() + if pocket['x'].numel() > 0: + out_pocket['x'][idx] = pocket['x'].detach() + out_pocket['v'][idx] = pocket['v'][:, :self.n_atom_aa, :].detach() + if self.predict_confidence: + out_ligand['sigma_x'][idx] = cumulative_uncertainty['sigma_x_squared'].sqrt().detach() + out_ligand['entropy_h'][idx] = cumulative_uncertainty['entropy_h'].detach() + + # remove frame dimension if only the final molecule is returned + out_ligand = {k: v.squeeze(0) for k, v in out_ligand.items()} + out_pocket = {k: v.squeeze(0) for k, v in out_pocket.items()} + + return out_ligand, out_pocket + + def init_ligand(self, num_nodes_lig, pocket): + device = pocket['x'].device + + n_samples = len(pocket['size']) + lig_mask = utils.num_nodes_to_batch_mask(n_samples, num_nodes_lig, device) + + # only consider upper triangular matrix for symmetry + lig_bonds = torch.stack(torch.where(torch.triu( + lig_mask[:, None] == lig_mask[None, :], diagonal=1)), dim=0) + lig_edge_mask = lig_mask[lig_bonds[0]] + + # Sample from Normal distribution in the pocket center + pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0) + z0_x = self.module_x.sample_z0(pocket_com, lig_mask) + z0_h = self.module_h.sample_z0(lig_mask) + z0_e = self.module_e.sample_z0(lig_edge_mask) + + return TensorDict(**{ + 'x': z0_x, 'h': z0_h, 'e': z0_e, 'mask': lig_mask, + 'bonds': lig_bonds, 'edge_mask': lig_edge_mask + }) + + def init_pocket(self, pocket): + + if self.flexible_bb: + pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0) + z0_trans = self.module_trans.sample_z0(pocket_com, pocket['mask']) + z0_rot = self.module_rot.sample_z0(pocket['mask']) + + # update pocket in-place + pocket.set_frame(z0_trans, z0_rot) + + if self.flexible: + z0_chi = self.module_chi.sample_z0(pocket['mask']) + + # # DEBUG ## + # z0_chi = torch.stack([data_utils.get_torsion_angles(r, device=self.device) for r in pocket['residues']], dim=0) + # #### + + # internal to external coordinates + pocket.set_chi(z0_chi) + + if pocket['x'].numel() == 0: + pocket.set_empty_v() + + return pocket + + def parse_num_nodes_spec(self, batch, spec=None, size_model=None): + + if spec == "2d_histogram" or spec is None: # default option + assert "pocket" in batch + num_nodes = self.size_distribution.sample_conditional( + n1=None, n2=batch['pocket']['size']) + + # make sure there is at least one potential bond + num_nodes[num_nodes < 2] = 2 + + elif isinstance(spec, (int, torch.Tensor)): + num_nodes = spec + + elif spec == "ground_truth": + assert "ligand" in batch + num_nodes = batch['ligand']['size'] + + elif spec == "nn_prediction": + assert size_model is not None + assert "pocket" in batch + predictions = size_model.forward(batch['pocket']) + predictions = torch.softmax(predictions, dim=-1) + predictions[:, :5] = 0.0 + 
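+            # Zero out the smallest size classes (presumably sizes 0-4) so the
+            # size network cannot propose degenerate, near-empty ligands; the
+            # remaining probability mass is renormalized before sampling.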
probabilities = predictions / predictions.sum(dim=1, keepdims=True) + num_nodes = torch.distributions.Categorical(probabilities).sample() + + elif isinstance(spec, str) and spec.startswith("uniform"): + # expected format: uniform_low_high + assert "pocket" in batch + left, right = map(int, spec.split("_")[1:]) + shape = batch['pocket']['size'].shape + num_nodes = torch.randint(left, right + 1, shape, dtype=torch.long) + + else: + raise NotImplementedError(f"Invalid size specification {spec}") + + if self.virtual_nodes: + num_nodes += self.add_virtual_max + + return num_nodes + + @torch.no_grad() + def sample(self, data, n_samples, num_nodes=None, timesteps=None, + guide_log_prob=None, size_model=None, **kwargs): + + # TODO: move somewhere else (like collate_fn) + data['pocket'] = Residues(**data['pocket']) + + timesteps = self.T_sampling if timesteps is None else timesteps + + if len(data['pocket']['x']) > 0: + pocket = data_utils.repeat_items(data['pocket'], n_samples) + else: + pocket = Residues(**{key: value for key, value in data['pocket'].items()}) + pocket['name'] = pocket['name'] * n_samples + pocket['size'] = pocket['size'].repeat(n_samples) + pocket['n_bonds'] = pocket['n_bonds'].repeat(n_samples) + + _ligand = data_utils.repeat_items(data['ligand'], n_samples) + # _ligand = randomize_tensors(_ligand, exclude_keys=['size', 'name']) # avoid data leakage + + batch = {"ligand": _ligand, "pocket": pocket} + num_nodes = self.parse_num_nodes_spec(batch, spec=num_nodes, size_model=size_model) + + # Sample from prior + if pocket['x'].numel() > 0: + ligand = self.init_ligand(num_nodes, pocket) + else: + ligand = self.init_ligand(num_nodes, _ligand) + pocket = self.init_pocket(pocket) + + # return prior samples + if timesteps == 0: + # Convert into rdmols + rdmols = [build_molecule(coords=m['x'], + atom_types=m['h'].argmax(1), + bonds=m['bonds'], + bond_types=m['e'].argmax(1), + atom_decoder=self.atom_decoder, bond_decoder=self.bond_decoder) + for m in data_utils.split_entity(ligand.detach().cpu(), edge_types={"e", "edge_mask"}, edge_mask=ligand["edge_mask"])] + + rdpockets = pocket_to_rdkit(pocket, self.pocket_representation, + self.atom_encoder, self.atom_decoder, + self.aa_decoder, self.residue_decoder, + self.aa_atom_index) + + return rdmols, rdpockets, _ligand['name'] + + out_tensors_ligand, out_tensors_pocket = self.simulate( + ligand, pocket, timesteps, 0.0, 1.0, + guide_log_prob=guide_log_prob + ) + + # Build mol objects + x = out_tensors_ligand['x'].detach().cpu() + ligand_type = out_tensors_ligand['h'].argmax(1).detach().cpu() + edge_type = out_tensors_ligand['e'].argmax(1).detach().cpu() + lig_mask = ligand['mask'].detach().cpu() + lig_bonds = ligand['bonds'].detach().cpu() + lig_edge_mask = ligand['edge_mask'].detach().cpu() + sizes = torch.unique(ligand['mask'], return_counts=True)[1].tolist() + offsets = list(accumulate(sizes[:-1], initial=0)) + mol_kwargs = { + 'coords': utils.batch_to_list(x, lig_mask), + 'atom_types': utils.batch_to_list(ligand_type, lig_mask), + 'bonds': utils.batch_to_list_for_indices(lig_bonds, lig_edge_mask, offsets), + 'bond_types': utils.batch_to_list(edge_type, lig_edge_mask) + } + if self.predict_confidence: + sigma_x = out_tensors_ligand['sigma_x'].detach().cpu() + entropy_h = out_tensors_ligand['entropy_h'].detach().cpu() + mol_kwargs['atom_props'] = [ + {'sigma_x': x[0], 'entropy_h': x[1]} + for x in zip(utils.batch_to_list(sigma_x, lig_mask), + utils.batch_to_list(entropy_h, lig_mask)) + ] + mol_kwargs = [{k: v[i] for k, v in mol_kwargs.items()} 
+ for i in range(len(mol_kwargs['coords']))] + + # Convert into rdmols + rdmols = [build_molecule( + **m, atom_decoder=self.atom_decoder, bond_decoder=self.bond_decoder) + for m in mol_kwargs + ] + + out_pocket = pocket.copy() + out_pocket['x'] = out_tensors_pocket['x'] + out_pocket['v'] = out_tensors_pocket['v'] + rdpockets = pocket_to_rdkit(out_pocket, self.pocket_representation, + self.atom_encoder, self.atom_decoder, + self.aa_decoder, self.residue_decoder, + self.aa_atom_index) + + return rdmols, rdpockets, _ligand['name'] + + @torch.no_grad() + def sample_chain(self, pocket, keep_frames, num_nodes=None, timesteps=None, + guide_log_prob=None, **kwargs): + + # TODO: move somewhere else (like collate_fn) + pocket = Residues(**pocket) + + info = {} + + timesteps = self.T_sampling if timesteps is None else timesteps + + # n_samples = 1 + # TODO: get batch_size differently + assert len(pocket['mask'].unique()) <= 1, "sample_chain only supports a single sample" + + # # Pocket's initial center of mass + # pocket_com_before = scatter_mean(pocket['x'], pocket['mask'], dim=0) + + num_nodes = self.parse_num_nodes_spec(batch={"pocket": pocket}, spec=num_nodes) + + # Sample from prior + if pocket['x'].numel() > 0: + ligand = self.init_ligand(num_nodes, pocket) + else: + dummy_pocket = Residues.empty(pocket['x'].device) + ligand = self.init_ligand(num_nodes, dummy_pocket) + + pocket = self.init_pocket(pocket) + + out_tensors_ligand, out_tensors_pocket = self.simulate( + ligand, pocket, timesteps, 0.0, 1.0, guide_log_prob=guide_log_prob, return_frames=keep_frames) + + # chain_lig = utils.reverse_tensor(chain_lig) + # chain_pocket = utils.reverse_tensor(chain_pocket) + # chain_bond = utils.reverse_tensor(chain_bond) + + info['traj_displacement_lig'] = torch.norm(out_tensors_ligand['x'][-1] - out_tensors_ligand['x'][0], dim=-1).mean() + info['traj_rms_lig'] = out_tensors_ligand['x'].std(dim=0).mean() + + # # Repeat last frame to see final sample better. 
+        # chain_lig = torch.cat([chain_lig, chain_lig[-1:].repeat(10, 1, 1)], dim=0)
+        # chain_pocket = torch.cat([chain_pocket, chain_pocket[-1:].repeat(10, 1, 1)], dim=0)
+        # chain_bond = torch.cat([chain_bond, chain_bond[-1:].repeat(10, 1, 1)], dim=0)
+
+        # Flatten: the frame index acts as a batch index, e.g. a chain of
+        # shape (keep_frames, n_atoms, 3) becomes (keep_frames * n_atoms, 3)
+        # with mask [0, ..., 0, 1, ..., 1, ...], so that utils.batch_to_list
+        # can split the trajectory back into one molecule per frame
+        assert keep_frames == out_tensors_ligand['x'].size(0) == out_tensors_pocket['x'].size(0)
+        n_atoms = out_tensors_ligand['x'].size(1)
+        n_bonds = out_tensors_ligand['e'].size(1)
+        n_residues = out_tensors_pocket['x'].size(1)
+        device = out_tensors_ligand['x'].device
+
+        def flatten_tensor(chain):
+            if len(chain.size()) == 3:  # l=0 values
+                return chain.view(-1, chain.size(-1))
+            elif len(chain.size()) == 4:  # vectors
+                return chain.view(-1, chain.size(-2), chain.size(-1))
+            else:
+                warnings.warn(f"Could not flatten frame dimension of tensor with shape {list(chain.size())}")
+                return chain
+
+        out_tensors_ligand_flat = {k: flatten_tensor(chain) for k, chain in out_tensors_ligand.items()}
+        out_tensors_pocket_flat = {k: flatten_tensor(chain) for k, chain in out_tensors_pocket.items()}
+        # ligand_flat = chain_lig.view(-1, chain_lig.size(-1))
+        # ligand_mask_flat = torch.arange(chain_lig.size(0)).repeat_interleave(chain_lig.size(1)).to(chain_lig.device)
+        ligand_mask_flat = torch.arange(keep_frames).repeat_interleave(n_atoms).to(device)
+
+        # # pocket_flat = chain_pocket.view(-1, chain_pocket.size(-1))
+        # # pocket_v_flat = pocket['v'].repeat(100, 1, 1)
+        # pocket_flat = chain_pocket.view(-1, chain_pocket.size(-2), chain_pocket.size(-1))
+        # pocket_mask_flat = torch.arange(chain_pocket.size(0)).repeat_interleave(chain_pocket.size(1)).to(chain_pocket.device)
+        pocket_mask_flat = torch.arange(keep_frames).repeat_interleave(n_residues).to(device)
+
+        # bond_flat = chain_bond.view(-1, chain_bond.size(-1))
+        # bond_mask_flat = torch.arange(chain_bond.size(0)).repeat_interleave(chain_bond.size(1)).to(chain_bond.device)
+        bond_mask_flat = torch.arange(keep_frames).repeat_interleave(n_bonds).to(device)
+        edges_flat = ligand['bonds'].repeat(1, keep_frames)
+
+        # # Move generated molecule back to the original pocket position
+        # pocket_com_after = scatter_mean(pocket_flat[:, 0, :], pocket_mask_flat, dim=0)
+        # ligand_flat[:, :self.x_dim] += (pocket_com_before - pocket_com_after)[ligand_mask_flat]
+        #
+        # # Move pocket back as well (for visualization purposes)
+        # pocket_flat[:, 0, :] += (pocket_com_before - pocket_com_after)[pocket_mask_flat]
+
+        # Build ligands
+        x = out_tensors_ligand_flat['x'].detach().cpu()
+        ligand_type = out_tensors_ligand_flat['h'].argmax(1).detach().cpu()
+        ligand_mask_flat = ligand_mask_flat.detach().cpu()
+        bond_mask_flat = bond_mask_flat.detach().cpu()
+        edges_flat = edges_flat.detach().cpu()
+        edge_type = out_tensors_ligand_flat['e'].argmax(1).detach().cpu()
+        offsets = torch.zeros(keep_frames, dtype=int)  # edges_flat is already zero-based
+        molecules = list(
+            zip(utils.batch_to_list(x, ligand_mask_flat),
+                utils.batch_to_list(ligand_type, ligand_mask_flat),
+                utils.batch_to_list_for_indices(edges_flat, bond_mask_flat, offsets),
+                utils.batch_to_list(edge_type, bond_mask_flat)
+                )
+        )
+
+        # Convert into rdmols
+        ligand_chain = [build_molecule(
+            *graph, atom_decoder=self.atom_decoder,
+            bond_decoder=self.bond_decoder) for graph in molecules
+        ]
+
+        # Build pockets
+        # as long as the pocket does not change during sampling, we can just
+        # write it once
+        out_pocket = {
+            'x': out_tensors_pocket_flat['x'],
+            'one_hot': pocket['one_hot'].repeat(keep_frames, 1),
+            'mask': pocket_mask_flat,
+            'v':
out_tensors_pocket_flat['v'], + 'atom_mask': pocket['atom_mask'].repeat(keep_frames, 1), + } if self.flexible else pocket + pocket_chain = pocket_to_rdkit(out_pocket, self.pocket_representation, + self.atom_encoder, self.atom_decoder, + self.aa_decoder, self.residue_decoder, + self.aa_atom_index) + + return ligand_chain, pocket_chain, info + + # def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm): + # def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_clip_algorithm): + def configure_gradient_clipping(self, optimizer, *args, **kwargs): + + if not self.clip_grad: + return + + # Allow gradient norm to be 150% + 2 * stdev of the recent history. + max_grad_norm = 1.5 * self.gradnorm_queue.mean() + \ + 2 * self.gradnorm_queue.std() + + # hard upper limit + max_grad_norm = min(max_grad_norm, 10.0) + + # Get current grad_norm + params = [p for g in optimizer.param_groups for p in g['params']] + grad_norm = utils.get_grad_norm(params) + + # Lightning will handle the gradient clipping + self.clip_gradients(optimizer, gradient_clip_val=max_grad_norm, + gradient_clip_algorithm='norm') + + if float(grad_norm) > max_grad_norm: + print(f'Clipped gradient with value {grad_norm:.1f} ' + f'while allowed {max_grad_norm:.1f}') + grad_norm = max_grad_norm + + self.gradnorm_queue.add(float(grad_norm)) \ No newline at end of file diff --git a/src/model/loss_utils.py b/src/model/loss_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0c8d284fc604c33f6d8b70cd09808c7c4462add6 --- /dev/null +++ b/src/model/loss_utils.py @@ -0,0 +1,79 @@ +import torch +from torch_scatter import scatter_add, scatter_mean + +from src.constants import atom_decoder, vdw_radii +_vdw_radii = {**vdw_radii} +_vdw_radii['NH'] = vdw_radii['N'] +_vdw_radii['N+'] = vdw_radii['N'] +_vdw_radii['O-'] = vdw_radii['O'] +_vdw_radii['NOATOM'] = 0 +vdw_radii_array = torch.tensor([_vdw_radii[a] for a in atom_decoder]) + + +def clash_loss(ligand_coord, ligand_types, ligand_mask, pocket_coord, + pocket_types, pocket_mask): + """ + Computes a clash loss that penalizes interatomic distances smaller than the + sum of van der Waals radii between atoms. 
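+
+    The penalty decreases linearly with distance,
+        penalty(d) = max(1 - d / (r_i + r_j), 0),
+    e.g. two carbon atoms (assuming the standard value r_C = 1.7 A) at
+    d = 1.7 A incur a penalty of 1 - 1.7 / 3.4 = 0.5, and no penalty once
+    d >= 3.4 A.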
+ """ + + ligand_radii = vdw_radii_array[ligand_types].to(ligand_coord.device) + pocket_radii = vdw_radii_array[pocket_types].to(pocket_coord.device) + + dist = torch.sqrt(torch.sum((ligand_coord[:, None, :] - pocket_coord[None, :, :]) ** 2, dim=-1)) + # dist[ligand_mask[:, None] != pocket_mask[None, :]] = float('inf') + + # compute linearly decreasing penalty + # penalty = max(1 - 1/sum_vdw * d, 0) + sum_vdw = ligand_radii[:, None] + pocket_radii[None, :] + loss = torch.clamp(1 - dist / sum_vdw, min=0.0) # (n_ligand, n_pocket) + + loss = scatter_add(loss, pocket_mask, dim=1) + loss = scatter_mean(loss, ligand_mask, dim=0) + loss = loss.diag() + + # # DEBUG (non-differentiable version) + # dist = torch.sqrt(torch.sum((ligand_coord[:, None, :] - pocket_coord[None, :, :]) ** 2, dim=-1)) + # dist[ligand_mask[:, None] != pocket_mask[None, :]] = float('inf') + # _loss = torch.clamp(1 - dist / sum_vdw, min=0.0) # (n_ligand, n_pocket) + # _loss = _loss.sum(dim=-1) + # _loss = scatter_mean(_loss, ligand_mask, dim=0) + # assert torch.allclose(loss, _loss) + + return loss + + +class TimestepSampler: + def __init__(self, type='uniform', lowest_t=1, highest_t=500): + assert type in {'uniform', 'sigmoid'} + self.type = type + self.lowest_t = lowest_t + self.highest_t = highest_t + + def __call__(self, n, device=None): + if self.type == 'uniform': + t_int = torch.randint(self.lowest_t, self.highest_t + 1, + size=(n, 1), device=device) + + elif self.type == 'sigmoid': + weight_fun = lambda t: 1.45 * torch.sigmoid(-t * 10 / self.highest_t + 5) + 0.05 + + possible_ts = torch.arange(self.lowest_t, self.highest_t + 1, device=device) + weights = weight_fun(possible_ts) + weights = weights / weights.sum() + t_int = possible_ts[torch.multinomial(weights, n, replacement=True)].unsqueeze(-1) + + return t_int.float() + + +class TimestepWeights: + def __init__(self, weight_type, a, b): + if weight_type != 'sigmoid': + raise NotImplementedError("Only sigmoidal loss weighting is available.") + # self.weight_fn = lambda t: a * torch.sigmoid((-t + 0.5) * b) + (1 - a / 2) + self.weight_fn = lambda t: a * torch.sigmoid((t - 0.5) * b) + (1 - a / 2) + + def __call__(self, t_array): + # normalized t \in [0, 1] + # return self.weight_fn(1 - t_array) + return self.weight_fn(t_array) diff --git a/src/model/markov_bridge.py b/src/model/markov_bridge.py new file mode 100644 index 0000000000000000000000000000000000000000..6685889e4f396de2b72a24d5ca3543feedab05ea --- /dev/null +++ b/src/model/markov_bridge.py @@ -0,0 +1,163 @@ +from functools import reduce +import torch +import torch.nn.functional as F +from torch_scatter import scatter_mean, scatter_add + +from src.utils import bvm + + +class LinearSchedule: + """ + We use the scheduling parameter \beta to linearly remove noise, i.e. + \bar{\beta}_t = 1 - h (h: step size) with + \bar{Q}_t = \bar{\beta}_t I + (1 - \bar{\beta}_t) 1_vec z1^T + + From this, it follows that for each step transition matrix, we have + \beta_t = \bar{\beta}_t / \bar{\beta}_{t-h} = \frac{1-t}{1-t+h} + """ + def __init__(self): + super().__init__() + + def beta_bar(self, t): + return 1 - t + + def beta(self, t, step_size): + return (1 - t) / (1 - t + step_size) + + +class UniformPriorMarkovBridge: + """ + Markov bridge model in which z0 is drawn from a uniform prior. + Transitions are defined as: + Q_t = \beta_t I + (1 - \beta_t) 1_vec z1^T + where z1 is a one-hot representation of the final state. + We follow the notation from [1] and multiply transition matrices from the + right to one-hot state vectors. 
+
+    We use the scheduling parameter \bar{\beta}_t to linearly remove noise,
+    i.e. \bar{\beta}_t = 1 - t, with
+    \bar{Q}_t = \bar{\beta}_t I + (1 - \bar{\beta}_t) 1_vec z1^T
+
+    From this, it follows that the transition matrix of a single step of
+    size h satisfies
+    \beta_t = \bar{\beta}_t / \bar{\beta}_{t-h} = \frac{1-t}{1-t+h}
+
+    [1] Austin, Jacob, et al.
+    "Structured denoising diffusion models in discrete state-spaces."
+    Advances in Neural Information Processing Systems 34 (2021): 17981-17993.
+    """
+    def __init__(self, dim, loss_type='CE', step_size=None):
+        assert loss_type in ['VLB', 'CE']
+        self.dim = dim
+        self.step_size = step_size  # required for VLB
+        self.schedule = LinearSchedule()
+        self.loss_type = loss_type
+        super(UniformPriorMarkovBridge, self).__init__()
+
+    @staticmethod
+    def sample_categorical(p):
+        """
+        Sample from the categorical distribution defined by probabilities 'p'
+        :param p: (n, dim)
+        :return: one-hot encoded samples (n, dim)
+        """
+        sampled = torch.multinomial(p, 1).squeeze(-1)
+        return F.one_hot(sampled, num_classes=p.size(1)).float()
+
+    def p_z0(self, batch_mask):
+        return torch.ones((len(batch_mask), self.dim), device=batch_mask.device) / self.dim
+
+    def sample_z0(self, batch_mask):
+        """ Prior. """
+        z0 = self.sample_categorical(self.p_z0(batch_mask))
+        return z0
+
+    def p_zt(self, z0, z1, t, batch_mask):
+        Qt_bar = self.get_Qt_bar(t, z1, batch_mask)
+        return bvm(z0, Qt_bar)
+
+    def sample_zt(self, z0, z1, t, batch_mask):
+        zt = self.sample_categorical(self.p_zt(z0, z1, t, batch_mask))
+        return zt
+
+    def p_zt_given_zs_and_z1(self, zs, z1, s, t, batch_mask):
+        # 'z1' can be one-hot vectors or per-class probabilities
+        Qt = self.get_Qt(t, s, z1, batch_mask)
+        p_zt = bvm(zs, Qt)
+        return p_zt
+
+    def p_zt_given_zs(self, zs, p_z1_hat, s, t, batch_mask):
+        """
+        Note that p_z1_hat can also represent a categorical distribution to
+        compute transitions more efficiently at sampling time:
+        p(z_t|z_s) = \sum_{\hat{z}_1} p(z_t | z_s, \hat{z}_1) * p(\hat{z}_1 | z_s)
+                   = \sum_i z_s (\beta_t I + (1 - \beta_t) 1_vec z1_i^T) * \hat{p}_i
+                   = \beta_t z_s I + (1 - \beta_t) z_s 1_vec \hat{p}^T
+        """
+        return self.p_zt_given_zs_and_z1(zs, p_z1_hat, s, t, batch_mask)
+
+    def sample_zt_given_zs(self, zs, z1_logits, s, t, batch_mask):
+        p_z1 = z1_logits.softmax(dim=-1)
+        zt = self.sample_categorical(self.p_zt_given_zs(zs, p_z1, s, t, batch_mask))
+        return zt
+
+    def compute_loss(self, pred_logits, zs, z1, batch_mask, s, t, reduce='mean'):
+        """ Compute loss per sample. """
+        assert reduce in {'mean', 'sum', 'none'}
+
+        if self.loss_type == 'CE':
+            loss = F.cross_entropy(pred_logits, z1, reduction='none')
+
+        else:  # VLB
+            true_p_zs = self.p_zt_given_zs_and_z1(zs, z1, s, t, batch_mask)
+            pred_p_zs = self.p_zt_given_zs(zs, pred_logits.softmax(dim=-1), s, t, batch_mask)
+            loss = F.kl_div(pred_p_zs.log(), true_p_zs, reduction='none').sum(dim=-1)
+
+        if reduce == 'mean':
+            loss = scatter_mean(loss, batch_mask, dim=0)
+        elif reduce == 'sum':
+            loss = scatter_add(loss, batch_mask, dim=0)
+
+        return loss
+
+    def get_Qt(self, t, s, z1, batch_mask):
+        """ Returns the one-step transition matrix from step s to step t.
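+
+        A minimal sketch (one sample, a single node with batch_mask = [0],
+        and dim = 3):
+
+            mb = UniformPriorMarkovBridge(dim=3)
+            z1 = F.one_hot(torch.tensor([2]), num_classes=3).float()
+            Qt = mb.get_Qt(t=torch.tensor([[0.6]]), s=torch.tensor([[0.5]]),
+                           z1=z1, batch_mask=torch.tensor([0]))
+            # beta = (1 - 0.6) / (1 - 0.6 + 0.1) = 0.8, and every row of the
+            # resulting (1, 3, 3) matrix sums to one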
""" + + beta_t_given_s = self.schedule.beta(t, t - s) + beta_t_given_s = beta_t_given_s.unsqueeze(-1)[batch_mask] + + # Q_t = beta_t * I + (1 - beta_t) * ones (dot) z1^T + Qt = beta_t_given_s * torch.eye(self.dim, device=t.device).unsqueeze(0) + \ + (1 - beta_t_given_s) * z1.unsqueeze(1) + # (1 - beta_t_given_s) * (torch.ones(self.dim, 1, device=t.device) @ z1) + + # assert (Qt.sum(-1) == 1).all() + + return Qt + + def get_Qt_bar(self, t, z1, batch_mask): + """ Returns transition matrix from step 0 to step t. """ + + beta_bar_t = self.schedule.beta_bar(t) + beta_bar_t = beta_bar_t.unsqueeze(-1)[batch_mask] + + # Q_t_bar = beta_bar * I + (1 - beta_bar) * ones (dot) z1^T + Qt_bar = beta_bar_t * torch.eye(self.dim, device=t.device).unsqueeze(0) + \ + (1 - beta_bar_t) * z1.unsqueeze(1) + # (1 - beta_bar_t) * (torch.ones(self.dim, 1, device=t.device) @ z1) + + # assert (Qt_bar.sum(-1) == 1).all() + + return Qt_bar + + +class MarginalPriorMarkovBridge(UniformPriorMarkovBridge): + def __init__(self, dim, prior_p, loss_type='CE', step_size=None): + self.prior_p = prior_p + print('Marginal Prior MB') + super(MarginalPriorMarkovBridge, self).__init__(dim, loss_type, step_size) + + def p_z0(self, batch_mask): + device = batch_mask.device + p = torch.ones((len(batch_mask), self.dim), device=device) * self.prior_p.view(1, -1).to(device) + return p diff --git a/src/sample_and_evaluate.py b/src/sample_and_evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..a474a06ab4e2f06c3e39d60f63e17e6f5d0c3db5 --- /dev/null +++ b/src/sample_and_evaluate.py @@ -0,0 +1,164 @@ +import argparse +import sys +import yaml +import torch +import numpy as np +import pickle +from argparse import Namespace + +from pathlib import Path + +basedir = Path(__file__).resolve().parent.parent +sys.path.append(str(basedir)) + +from src import utils +from src.utils import dict_to_namespace, namespace_to_dict +from src.analysis.visualization_utils import mols_to_pdbfile, mol_as_pdb +from src.data.data_utils import TensorDict, Residues +from src.data.postprocessing import process_all +from src.model.lightning import DrugFlow +from src.sbdd_metrics.evaluation import compute_all_metrics_drugflow + +from tqdm import tqdm +from pdb import set_trace + + +def combine(base_args, override_args): + assert not isinstance(base_args, dict) + assert not isinstance(override_args, dict) + + arg_dict = base_args.__dict__ + for key, value in override_args.__dict__.items(): + if key not in arg_dict or arg_dict[key] is None: # parameter not provided previously + print(f"Add parameter {key}: {value}") + arg_dict[key] = value + elif isinstance(value, Namespace): + arg_dict[key] = combine(arg_dict[key], value) + else: + print(f"Replace parameter {key}: {arg_dict[key]} -> {value}") + arg_dict[key] = value + return base_args + + +def path_to_str(input_dict): + for key, value in input_dict.items(): + if isinstance(value, dict): + input_dict[key] = path_to_str(value) + else: + input_dict[key] = str(value) if isinstance(value, Path) else value + return input_dict + + +def sample(cfg, model_params, samples_dir, job_id=0, n_jobs=1): + print('Sampling...') + model = DrugFlow.load_from_checkpoint(cfg.checkpoint, map_location=cfg.device, strict=False, + **model_params) + model.setup(stage='fit' if cfg.set == 'train' else cfg.set) + model.eval().to(cfg.device) + + dataloader = getattr(model, f'{cfg.set}_dataloader')() + print(f'Real batch size is {dataloader.batch_size * cfg.n_samples}') + + name2count = {} + for i, data in 
enumerate(tqdm(dataloader)):
+        if i % n_jobs != job_id:
+            print(f'Skipping batch {i}')
+            continue
+
+        new_data = {
+            'ligand': TensorDict(**data['ligand']).to(cfg.device),
+            'pocket': Residues(**data['pocket']).to(cfg.device),
+        }
+        try:
+            rdmols, rdpockets, names = model.sample(
+                data=new_data,
+                n_samples=cfg.n_samples,
+                num_nodes=("ground_truth" if cfg.sample_with_ground_truth_size else None)
+            )
+        except Exception as e:
+            if cfg.set == 'train':
+                names = data['ligand']['name']
+                print(f'Failed to sample for {names}: {e}')
+                continue
+            else:
+                raise
+
+        for mol, pocket, name in zip(rdmols, rdpockets, names):
+            name = name.replace('.sdf', '')
+            idx = name2count.setdefault(name, 0)
+            output_dir = Path(samples_dir, name)
+            output_dir.mkdir(parents=True, exist_ok=True)
+            if cfg.postprocess:
+                mol = process_all(mol, largest_frag=True, adjust_aromatic_Ns=True, relax_iter=0)
+                if mol is None:
+                    # post-processing can fail and return None
+                    print(f'Post-processing failed for {name}, skipping sample')
+                    continue
+
+            for prop in mol.GetAtoms()[0].GetPropsAsDict().keys():
+                # compute avg uncertainty
+                mol.SetDoubleProp(prop, np.mean([a.GetDoubleProp(prop) for a in mol.GetAtoms()]))
+
+                # visualise local differences
+                out_pdb_path = Path(output_dir, f'{idx}_ligand_{prop}.pdb')
+                mol_as_pdb(mol, out_pdb_path, bfactor=prop)
+
+            out_sdf_path = Path(output_dir, f'{idx}_ligand.sdf')
+            out_pdb_path = Path(output_dir, f'{idx}_pocket.pdb')
+            utils.write_sdf_file(out_sdf_path, [mol])
+            mols_to_pdbfile([pocket], out_pdb_path)
+
+            name2count[name] += 1
+
+
+def evaluate(cfg, model_params, samples_dir):
+    print('Evaluation...')
+    data, table_detailed, table_aggregated = compute_all_metrics_drugflow(
+        in_dir=samples_dir,
+        gnina_path=model_params['train_params'].gnina,
+        reduce_path=cfg.reduce,
+        reference_smiles_path=Path(model_params['train_params'].datadir, 'train_smiles.npy'),
+        n_samples=cfg.n_samples,
+        exclude_evaluators=[] if getattr(cfg, 'exclude_evaluators', None) is None else cfg.exclude_evaluators,
+    )
+    with open(Path(samples_dir, 'metrics_data.pkl'), 'wb') as f:
+        pickle.dump(data, f)
+    table_detailed.to_csv(Path(samples_dir, 'metrics_detailed.csv'), index=False)
+    table_aggregated.to_csv(Path(samples_dir, 'metrics_aggregated.csv'), index=False)
+
+
+if __name__ == "__main__":
+    p = argparse.ArgumentParser()
+    p.add_argument('--config', type=str)
+    p.add_argument('--job_id', type=int, default=0, help='Job ID')
+    p.add_argument('--n_jobs', type=int, default=1, help='Number of jobs')
+    args = p.parse_args()

+    with open(args.config, 'r') as f:
+        cfg = yaml.safe_load(f)
+    cfg = dict_to_namespace(cfg)
+
+    utils.set_deterministic(seed=cfg.seed)
+    utils.disable_rdkit_logging()
+
+    model_params = torch.load(cfg.checkpoint, map_location=cfg.device)['hyper_parameters']
+    if 'model_args' in cfg:
+        ckpt_args = dict_to_namespace(model_params)
+        model_params = combine(ckpt_args, cfg.model_args).__dict__
+
+    ckpt_path = Path(cfg.checkpoint)
+    ckpt_name = ckpt_path.parts[-1].split('.')[0]
+    n_steps = model_params['simulation_params'].n_steps
+    # a Path is always truthy, so fall back explicitly when no output
+    # directory is configured
+    samples_dir = Path(cfg.sample_outdir, cfg.set, f'{ckpt_name}_T={n_steps}') \
+        if cfg.sample_outdir is not None \
+        else Path(ckpt_path.parent.parent, 'samples', cfg.set, f'{ckpt_name}_T={n_steps}')
+    assert cfg.set in {'val', 'test', 'train'}
+    samples_dir.mkdir(parents=True, exist_ok=True)
+
+    # save configs
+    with open(Path(samples_dir, 'model_params.yaml'), 'w') as f:
+        yaml.dump(path_to_str(namespace_to_dict(model_params)), f)
+    with open(Path(samples_dir, 'sampling_params.yaml'), 'w') as f:
+        yaml.dump(path_to_str(namespace_to_dict(cfg)), f)
+
+    if cfg.sample:
+        sample(cfg, model_params, samples_dir, job_id=args.job_id, n_jobs=args.n_jobs)
+
+    if
cfg.evaluate: + assert args.job_id == 0 and args.n_jobs == 1, 'Evaluation is not parallelised on GPU machines' + evaluate(cfg, model_params, samples_dir) \ No newline at end of file diff --git a/src/sbdd_metrics/evaluation.py b/src/sbdd_metrics/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..5b80f835063f4a63fc10d673ec8227c7a48462b6 --- /dev/null +++ b/src/sbdd_metrics/evaluation.py @@ -0,0 +1,239 @@ +import os +import sys +import re + +from pathlib import Path +from typing import Collection, List, Dict, Type + +import numpy as np +import pandas as pd +from tqdm import tqdm + +from .metrics import FullEvaluator, FullCollectionEvaluator + +AUXILIARY_COLUMNS = ['sample', 'sdf_file', 'pdb_file', 'subdir'] +VALIDITY_METRIC_NAME = 'medchem.valid' + + +def get_data_type(key: str, data_types: Dict[str, Type], default=float) -> Type: + found_data_type_key = None + found_data_type_value = None + for data_type_key, data_type_value in data_types.items(): + if re.match(data_type_key, key) is not None: + if found_data_type_key is not None: + raise ValueError(f'Multiple data type keys match [{key}]: {found_data_type_key}, {data_type_key}') + + found_data_type_value = data_type_value + found_data_type_key = data_type_key + + if found_data_type_key is None: + if default is None: + raise KeyError(key) + else: + found_data_type_value = default + + return found_data_type_value + + +def convert_data_to_table(data: List[Dict], data_types: Dict[str, Type]) -> pd.DataFrame: + """ + Converts data from `evaluate_drugflow` to a detailed table + """ + table = [] + for entry in data: + table_entry = {} + for key, value in entry.items(): + if key in AUXILIARY_COLUMNS: + table_entry[key] = value + continue + if get_data_type(key, data_types) != list: + table_entry[key] = value + table.append(table_entry) + + return pd.DataFrame(table) + +def aggregated_metrics(table: pd.DataFrame, data_types: Dict[str, Type], validity_metric_name: str = None): + """ + Args: + table (pd.DataFrame): table with metrics computed for each sample + data_types (Dict[str, Type]): dictionary with data types for each column + validity_metric_name (str): name of the column that has validity metric + + Returns: + agg_table (pd.DataFrame): table with columns ['metric', 'value', 'std'] + """ + aggregated_results = [] + + # If validity column name is provided: + # 1. compute validity on the entire data + # 2. 
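How get_data_type() resolves column dtypes (illustrative): evaluator dtype keys are treated as regular expressions, so wildcard entries such as {'*': float} become e.g. 'geometry.*' after ID prefixing.

dtypes = {'geometry.*': list, 'medchem.valid': bool}
get_data_type('geometry.C-C', dtypes)    # -> list ('geometry.*' matches as a regex)
get_data_type('medchem.valid', dtypes)   # -> bool (exact match)
get_data_type('energy.energy', dtypes)   # -> float (no pattern matches, default applies)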
drop all invalid molecules to compute the rest + if validity_metric_name is not None: + aggregated_results.append({ + 'metric': validity_metric_name, + 'value': table[validity_metric_name].fillna(False).astype(float).mean(), + 'std': None, + }) + table = table[table[validity_metric_name]] + + # Compute aggregated metrics + standard deviations where applicable + for column in table.columns: + if column in AUXILIARY_COLUMNS + [validity_metric_name] or get_data_type(column, data_types) == str: + continue + with pd.option_context("future.no_silent_downcasting", True): + if get_data_type(column, data_types) == bool: + values = table[column].fillna(0).values.astype(float).mean() + std = None + else: + values = table[column].dropna().values.astype(float).mean() + std = table[column].dropna().values.astype(float).std() + + aggregated_results.append({ + 'metric': column, + 'value': values, + 'std': std, + }) + + agg_table = pd.DataFrame(aggregated_results) + return agg_table + + +def collection_metrics( + table: pd.DataFrame, + reference_smiles: Collection[str], + validity_metric_name: str = None, + exclude_evaluators: Collection[str] = [], +): + """ + Args: + table (pd.DataFrame): table with metrics computed for each sample + reference_smiles (Collection[str]): list of reference SMILES (e.g. training set) + validity_metric_name (str): name of the column that has validity metric + exclude_evaluators (Collection[str]): Evaluator IDs to exclude + + Returns: + col_table (pd.DataFrame): table with columns ['metric', 'value'] + """ + + # If validity column name is provided drop all invalid molecules + if validity_metric_name is not None: + table = table[table[validity_metric_name]] + + evaluator = FullCollectionEvaluator(reference_smiles, exclude_evaluators=exclude_evaluators) + smiles = table['representation.smiles'].values + if len(smiles) == 0: + print('No valid input molecules') + return pd.DataFrame(columns=['metric', 'value']) + + collection_metrics = evaluator(smiles) + results = [ + {'metric': key, 'value': value} + for key, value in collection_metrics.items() + ] + + col_table = pd.DataFrame(results) + return col_table + + +def evaluate_drugflow_subdir( + in_dir: Path, + evaluator: FullEvaluator, + desc: str = None, + n_samples: int = None, +) -> List[Dict]: + """ + Computes per-molecule metrics for a single directory of samples for one target + """ + results = [] + valid_files = [ + int(fname.split('_')[0]) + for fname in os.listdir(in_dir) + if fname.endswith('_ligand.sdf') and not fname.startswith('.') + ] + if len(valid_files) == 0: + return pd.DataFrame() + + upper_bound = max(valid_files) + 1 + if n_samples is not None: + upper_bound = min(upper_bound, n_samples) + + for i in tqdm(range(upper_bound), desc=desc, file=sys.stdout): + in_mol = Path(in_dir, f'{i}_ligand.sdf') + in_prot = Path(in_dir, f'{i}_pocket.pdb') + res = evaluator(in_mol, in_prot) + + res['sample'] = i + res['sdf_file'] = str(in_mol) + res['pdb_file'] = str(in_prot) + results.append(res) + + return results + + +def evaluate_drugflow( + in_dir: Path, + evaluator: FullEvaluator, + n_samples: int = None, + job_id: int = 0, + n_jobs: int = 1, +) -> List[Dict]: + """ + 1. Computes per-molecule metrics for all single directories of samples + 2. Aggregates these metrics + 3. 
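For reference, the directory layout assumed by evaluate_drugflow_subdir() is the one produced by sample() in sample_and_evaluate.py (names illustrative):

# samples/<set>/<checkpoint>_T=<n_steps>/
#     <target_name>/
#         0_ligand.sdf   0_pocket.pdb
#         1_ligand.sdf   1_pocket.pdb
#         ...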
Computes additional collection metrics (if `reference_smiles_path` is provided) + """ + data = [] + total_number_of_subdirs = len([path for path in in_dir.glob("[!.]*") if os.path.isdir(path)]) + i = 0 + for subdir in in_dir.glob("[!.]*"): + if not os.path.isdir(subdir): + continue + + i += 1 + if (i - 1) % n_jobs != job_id: + continue + + curr_data = evaluate_drugflow_subdir( + in_dir=subdir, + evaluator=evaluator, + desc=f'[{i}/{total_number_of_subdirs}] {str(subdir.name)}', + n_samples=n_samples, + ) + for entry in curr_data: + entry['subdir'] = str(subdir) + data.append(entry) + + return data + + +def compute_all_metrics_drugflow( + in_dir: Path, + gnina_path: Path, + reduce_path: Path = None, + reference_smiles_path: Path = None, + n_samples: int = None, + validity_metric_name: str = VALIDITY_METRIC_NAME, + exclude_evaluators: Collection[str] = [], + job_id: int = 0, + n_jobs: int = 1, +): + evaluator = FullEvaluator(gnina=gnina_path, reduce=reduce_path, exclude_evaluators=exclude_evaluators) + data = evaluate_drugflow(in_dir=in_dir, evaluator=evaluator, n_samples=n_samples, job_id=job_id, n_jobs=n_jobs) + table_detailed = convert_data_to_table(data, evaluator.dtypes) + table_aggregated = aggregated_metrics( + table_detailed, + data_types=evaluator.dtypes, + validity_metric_name=validity_metric_name + ) + + # Add collection metrics (uniqueness, novelty, FCD, etc.) if reference smiles are provided + if reference_smiles_path is not None: + reference_smiles = np.load(reference_smiles_path) + col_metrics = collection_metrics( + table=table_detailed, + reference_smiles=reference_smiles, + validity_metric_name=validity_metric_name, + exclude_evaluators=exclude_evaluators + ) + table_aggregated = pd.concat([table_aggregated, col_metrics]) + + return data, table_detailed, table_aggregated diff --git a/src/sbdd_metrics/fpscores.pkl.gz b/src/sbdd_metrics/fpscores.pkl.gz new file mode 100644 index 0000000000000000000000000000000000000000..aa6f88c9c3fa56161b7df08e74ea6824f3071d08 --- /dev/null +++ b/src/sbdd_metrics/fpscores.pkl.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:10dcef9340c873e7b987924461b0af5365eb8dd96be607203debe8ddf80c1e73 +size 3848394 diff --git a/src/sbdd_metrics/interactions.py b/src/sbdd_metrics/interactions.py new file mode 100644 index 0000000000000000000000000000000000000000..12d792099629a38ba1d459d350e0d6a416ad8277 --- /dev/null +++ b/src/sbdd_metrics/interactions.py @@ -0,0 +1,231 @@ +import prody +import prolif as plf +import pandas as pd +import subprocess + +from io import StringIO +from prolif.fingerprint import Fingerprint +from prolif.plotting.complex3d import Complex3D +from prolif.residue import ResidueId +from prolif.ifp import IFP +from rdkit import Chem +from tqdm import tqdm + + +prody.confProDy(verbosity='none') + + +INTERACTION_LIST = [ + 'Anionic', 'Cationic', # Salt Bridges ~400 kJ/mol + 'HBAcceptor', 'HBDonor', # Hydrogen bonds ~10 kJ/mol + 'XBAcceptor', 'XBDonor', # Halogen bonds ~5-30 kJ/mol + 'CationPi', 'PiCation', # 5-10 kJ/mol + 'PiStacking', # ~2-10 kJ/mol + 'Hydrophobic', # 1-10 kJ/mol +] + +INTERACTION_ALIASES = { + 'Anionic': 'SaltBridge', + 'Cationic': 'SaltBridge', + 'HBAcceptor': 'HBAcceptor', + 'HBDonor': 'HBDonor', + 'XBAcceptor': 'HalogenBond', + 'XBDonor': 'HalogenBond', + 'CationPi': 'CationPi', + 'PiCation': 'PiCation', + 'PiStacking': 'PiStacking', + 'Hydrophobic': 'Hydrophobic', +} + +INTERACTION_COLORS = { + 'SaltBridge': '#eba823', + 'HBDonor': '#3d5dfc', + 'HBAcceptor': '#3d5dfc', + 'HalogenBond': 
'#53f514', + 'CationPi': '#ff0000', + 'PiCation': '#ff0000', + 'PiStacking': '#e359d8', + 'Hydrophobic': '#c9c5c5', +} + +INTERACTION_IMPORTANCE = ['SaltBridge', 'HydrogenBond', 'HBAcceptor', 'HBDonor', 'CationPi', 'PiCation', 'PiStacking', 'Hydrophobic'] + +REDUCE_EXEC = './reduce' + +def remove_residue_by_atomic_number(structure, resnum, chain_id, icode): + exclude_selection = f'not (chain {chain_id} and resnum {resnum} and icode {icode})' + structure = structure.select(exclude_selection) + return structure + + +def read_protein(protein_path, verbose=False, reduce_exec=REDUCE_EXEC): + structure = prody.parsePDB(protein_path).select('protein') + hydrogens = structure.select('hydrogen') + if hydrogens is None or len(hydrogens) < len(set(structure.getResnums())): + if verbose: + print('Target structure is not protonated. Adding hydrogens...') + + reduce_cmd = f'{str(reduce_exec)} {protein_path}' + reduce_result = subprocess.run(reduce_cmd, shell=True, capture_output=True, text=True) + if reduce_result.returncode != 0: + raise RuntimeError('Error during reduce execution:', reduce_result.stderr) + + pdb_content = reduce_result.stdout + stream = StringIO() + stream.write(pdb_content) + stream.seek(0) + structure = prody.parsePDBStream(stream).select('protein') + + # Select only one (largest) altloc + altlocs = set(structure.getAltlocs()) + try: + best_altloc = max(altlocs, key=lambda a: structure.select(f'altloc "{a}"').numAtoms()) + structure = structure.select(f'altloc "{best_altloc}"') + except TypeError: + # ProDy occasionally fails transiently on the first call; retry once + best_altloc = max(altlocs, key=lambda a: structure.select(f'altloc "{a}"').numAtoms()) + structure = structure.select(f'altloc "{best_altloc}"') + + return prepare_protein(structure, to_exclude=[], verbose=verbose) + + +def prepare_protein(input_structure, to_exclude=[], verbose=False): + structure = input_structure.copy() + + # Remove residues with bad atoms + if verbose and len(to_exclude) > 0: + print(f'Removing {len(to_exclude)} residues...') + for resnum, chain_id, icode in to_exclude: + exclude_selection = f'not (chain {chain_id} and resnum {resnum})' + structure = structure.select(exclude_selection) + + # Write new PDB content to the stream + stream = StringIO() + prody.writePDBStream(stream, structure) + stream.seek(0) + + # Sanitize + rdprot = Chem.MolFromPDBBlock(stream.read(), sanitize=False, removeHs=False) + try: + Chem.SanitizeMol(rdprot) + plfprot = plf.Molecule(rdprot) + return plfprot + + except Chem.AtomValenceException as e: + atom_num = int(e.args[0].replace('Explicit valence for atom # ', '').split()[0]) + info = rdprot.GetAtomWithIdx(atom_num).GetPDBResidueInfo() + resnum = info.GetResidueNumber() + chain_id = info.GetChainId() + icode = f'"{info.GetInsertionCode()}"' + + to_exclude_next = to_exclude + [(resnum, chain_id, icode)] + if verbose: + print(f'[{len(to_exclude_next)}] Removing broken residue with atom={atom_num}, resnum={resnum}, chain_id={chain_id}, icode={icode}') + return prepare_protein(input_structure, to_exclude=to_exclude_next, verbose=verbose) + + +def prepare_ligand(mol): + Chem.SanitizeMol(mol) + mol = Chem.AddHs(mol, addCoords=True) + ligand_plf = plf.Molecule.from_rdkit(mol) + return ligand_plf + + +def sdf_reader(sdf_path, progress_bar=False): + supp = Chem.SDMolSupplier(sdf_path, removeHs=True, sanitize=False) + for mol in tqdm(supp) if progress_bar else supp: + yield prepare_ligand(mol) + + +def profile_detailed( + ligand_plf, protein_plf, interaction_list=INTERACTION_LIST, 
ligand_name='ligand', protein_name='protein' + ): + + fp = Fingerprint(interactions=interaction_list) + fp.run_from_iterable(lig_iterable=[ligand_plf], prot_mol=protein_plf, progress=False) + + profile = [] + + for ligand_residue in ligand_plf.residues: + for protein_residue in protein_plf.residues: + metadata = fp.metadata(ligand_plf[ligand_residue], protein_plf[protein_residue]) + for int_name, int_metadata in metadata.items(): + for int_instance in int_metadata: + profile.append({ + 'ligand': ligand_name, + 'protein': protein_name, + 'ligand_residue': str(ligand_residue), + 'protein_residue': str(protein_residue), + 'interaction': int_name, + 'alias': INTERACTION_ALIASES[int_name], + 'ligand_atoms': ','.join(map(str, int_instance['indices']['ligand'])), + 'protein_atoms': ','.join(map(str, int_instance['indices']['protein'])), + 'ligand_orig_atoms': ','.join(map(str, int_instance['parent_indices']['ligand'])), + 'protein_orig_atoms': ','.join(map(str, int_instance['parent_indices']['protein'])), + 'distance': int_instance['distance'], + 'plane_angle': int_instance.get('plane_angle', None), + 'normal_to_centroid_angle': int_instance.get('normal_to_centroid_angle', None), + 'intersect_distance': int_instance.get('intersect_distance', None), + 'intersect_radius': int_instance.get('intersect_radius', None), + 'pi_ring': int_instance.get('pi_ring', None), + }) + + return pd.DataFrame(profile) + + +def map_orig_atoms_to_new(atoms, mol): + orig2new = dict() + for atom in mol.GetAtoms(): + orig2new[atom.GetUnsignedProp("mapindex")] = atom.GetIdx() + + atoms = list(map(int, atoms.split(','))) + new_atoms = ','.join(map(str, [orig2new[atom] for atom in atoms])) + return new_atoms + + +def visualize(profile, ligand_plf, protein_plf): + metadata = dict() + + for _, row in profile.iterrows(): + if 'ligand_atoms' not in row: + row['ligand_atoms'] = map_orig_atoms_to_new(row['ligand_orig_atoms'], ligand_plf) + if 'protein_atoms' not in row: + row['protein_atoms'] = map_orig_atoms_to_new(row['protein_orig_atoms'], protein_plf[row['residue']]) + + namenum, chain = row['residue'].split('.') + name = namenum[:3] + num = int(namenum[3:]) + protres = ResidueId(name=name, number=num, chain=chain) + key = (ResidueId(name='UNL', number=1, chain=None), protres) + + metadata.setdefault(key, dict()) + interaction = { + 'indices': { + 'ligand': tuple(map(int, row['ligand_atoms'].split(','))), + 'protein': tuple(map(int, row['protein_atoms'].split(','))), + }, + 'parent_indices': { + 'ligand': tuple(map(int, row['ligand_atoms'].split(','))), + 'protein': tuple(map(int, row['protein_atoms'].split(','))), + }, + 'distance': row['distance'], + } + # if row['plane_angle'] is not None: + # interaction['plane_angle'] = row['plane_angle'] + # if row['normal_to_centroid_angle'] is not None: + # interaction['normal_to_centroid_angle'] = row['normal_to_centroid_angle'] + # if row['intersect_distance'] is not None: + # interaction['intersect_distance'] = row['intersect_distance'] + # if row['intersect_radius'] is not None: + # interaction['intersect_radius'] = row['intersect_radius'] + # if row['pi_ring'] is not None: + # interaction['pi_ring'] = row['pi_ring'] + + metadata[key].setdefault(row['alias'], list()).append(interaction) + + ifp = IFP(metadata) + fp = Fingerprint(interactions=INTERACTION_LIST, vicinity_cutoff=8.0) + fp.ifp = {0: ifp} + Complex3D.COLORS.update(INTERACTION_COLORS) + v = fp.plot_3d(ligand_mol=ligand_plf, protein_mol=protein_plf, frame=0) + return v \ No newline at end of file diff --git 
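A minimal end-to-end sketch of the interaction-profiling API defined above (file paths are hypothetical placeholders):

from rdkit import Chem
mol = Chem.SDMolSupplier('0_ligand.sdf', removeHs=True, sanitize=False)[0]
ligand_plf = prepare_ligand(mol)
protein_plf = read_protein('0_pocket.pdb', reduce_exec='./reduce')
profile = profile_detailed(ligand_plf, protein_plf)
print(profile[['alias', 'protein_residue', 'distance']].head())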
a/src/sbdd_metrics/metrics.py b/src/sbdd_metrics/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..9a8f4eeec8c9c5f8c25a7936ca7c3aa11b5a582c --- /dev/null +++ b/src/sbdd_metrics/metrics.py @@ -0,0 +1,929 @@ +import multiprocessing +import subprocess +import tempfile +from abc import abstractmethod +from collections import defaultdict +from pathlib import Path +from typing import Union, Dict, Collection, Set, Optional +import signal +import numpy as np +import pandas as pd +from unittest.mock import patch +from scipy.spatial.distance import jensenshannon +from fcd import get_fcd +from posebusters import PoseBusters +from posebusters.modules.distance_geometry import _get_bond_atom_indices, _get_angle_atom_indices +from rdkit import Chem, RDLogger +from rdkit.Chem import Descriptors, Crippen, Lipinski, QED, KekulizeException, AtomKekulizeException +from rdkit.Chem.rdForceFieldHelpers import UFFGetMoleculeForceField +from scipy.spatial.distance import jensenshannon +from tqdm import tqdm +from useful_rdkit_utils import REOS, RingSystemLookup, get_min_ring_frequency, RingSystemFinder + +from .interactions import INTERACTION_LIST, prepare_ligand, read_protein, profile_detailed +from .sascorer import calculateScore + +def timeout_handler(signum, frame): + raise TimeoutError('Timeout') + +BOND_SYMBOLS = { + Chem.rdchem.BondType.SINGLE: '-', + Chem.rdchem.BondType.DOUBLE: '=', + Chem.rdchem.BondType.TRIPLE: '#', + Chem.rdchem.BondType.AROMATIC: ':', +} + + +def is_nan(value): + return value is None or pd.isna(value) or np.isnan(value) + + +def safe_run(func, timeout, **kwargs): + def _run(f, q, **kwargs): + r = f(**kwargs) + q.put(r) + + queue = multiprocessing.Queue() + process = multiprocessing.Process(target=_run, kwargs={'f': func, 'q': queue, **kwargs}) + process.start() + process.join(timeout) + if process.is_alive(): + print(f"Function {func} didn't finish in {timeout} seconds. 
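safe_run() runs a function in a child process and gives up after the timeout; a toy demonstration (the worker must be defined at module level and the calls guarded, so it also works on spawn-based platforms):

import time

def slow_double(x):
    time.sleep(5)
    return 2 * x

if __name__ == '__main__':
    print(safe_run(slow_double, timeout=1, x=21))   # None: worker was terminated
    print(safe_run(slow_double, timeout=10, x=21))  # 42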
Terminating it.") + process.terminate() + process.join() + return None + elif not queue.empty(): + return queue.get() + return None + + +class AbstractEvaluator: + ID = None + def __call__(self, molecule: Union[str, Path, Chem.Mol], protein: Union[str, Path] = None, + timeout=350): + """ + Args: + molecule (Union[str, Path, Chem.Mol]): input molecule + protein (str): target protein + + Returns: + metrics (dict): dictionary of metrics + """ + RDLogger.DisableLog('rdApp.*') + self.check_format(molecule, protein) + + # timeout handler + signal.signal(signal.SIGALRM, timeout_handler) + try: + signal.alarm(timeout) + results = self.evaluate(molecule, protein) + except TimeoutError: + print(f'Error when evaluating [{self.ID}]: Timeout after {timeout} seconds') + signal.alarm(0) + return {} + except Exception as e: + print(f'Error when evaluating [{self.ID}]: {e}') + signal.alarm(0) + return {} + finally: + signal.alarm(0) + return self.add_id(results) + + def add_id(self, results): + if self.ID is not None: + return {f'{self.ID}.{key}': value for key, value in results.items()} + else: + return results + + @abstractmethod + def evaluate(self, molecule: Union[str, Path, Chem.Mol], protein: Union[str, Path]) -> Dict[str, Union[int, float, str]]: + raise NotImplementedError + + @staticmethod + def check_format(molecule, protein): + assert isinstance(molecule, (str, Path, Chem.Mol)), 'Supported molecule types: str, Path, Chem.Mol' + assert protein is None or isinstance(protein, (str, Path)), 'Supported protein types: str' + if isinstance(molecule, (str, Path)): + supp = Chem.SDMolSupplier(str(molecule), sanitize=False) + assert len(supp) == 1, 'Only one molecule per file is supported' + + @staticmethod + def load_molecule(molecule): + if isinstance(molecule, (str, Path)): + return Chem.SDMolSupplier(str(molecule), sanitize=False)[0] + return Chem.Mol(molecule) # create copy to avoid overriding properties of the input molecule + + @staticmethod + def save_molecule(molecule, sdf_path): + if isinstance(molecule, (str, Path)): + return molecule + + with Chem.SDWriter(str(sdf_path)) as w: + try: + w.write(molecule) + except (RuntimeError, ValueError) as e: + if isinstance(e, (KekulizeException, AtomKekulizeException)): + w.SetKekulize(False) + w.write(molecule) + w.SetKekulize(True) + else: + w.write(Chem.Mol()) + print('[AbstractEvaluator] Error when saving the molecule') + + return sdf_path + + @property + def dtypes(self): + return self.add_id(self._dtypes) + + @property + @abstractmethod + def _dtypes(self): + raise NotImplementedError + + +class RepresentationEvaluator(AbstractEvaluator): + ID = 'representation' + + def evaluate(self, molecule, protein=None): + molecule = self.load_molecule(molecule) + try: + smiles = Chem.MolToSmiles(molecule) + except: + smiles = None + + return {'smiles': smiles} + + @property + def _dtypes(self): + return {'smiles': str} + + +class MolPropertyEvaluator(AbstractEvaluator): + ID = 'mol_props' + + def evaluate(self, molecule, protein=None): + molecule = self.load_molecule(molecule) + return {k: v for k, v in molecule.GetPropsAsDict().items() if isinstance(v, float)} + + @property + def _dtypes(self): + return {'*': float} + + +class PoseBustersEvaluator(AbstractEvaluator): + ID = 'posebusters' + def __init__(self, pb_conf: str = 'dock'): + self.posebusters = PoseBusters(config=pb_conf) + + @patch('rdkit.RDLogger.EnableLog', lambda x: None) + @patch('rdkit.RDLogger.DisableLog', lambda x: None) + def evaluate(self, molecule, protein=None): + result = 
safe_run(self.posebusters.bust, timeout=20, mol_pred=molecule, mol_cond=protein) + if result is None: + return dict() + + with pd.option_context("future.no_silent_downcasting", True): + result = dict(result.fillna(False).iloc[0]) + result['all'] = all([bool(value) if not is_nan(value) else False for value in result.values()]) + return result + + @property + def _dtypes(self): + return {'*': bool} + + +class GeometryEvaluator(AbstractEvaluator): + ID = 'geometry' + + def evaluate(self, molecule, protein=None): + mol = self.load_molecule(molecule) + data = self.get_distances_and_angles(mol) + return data + + @staticmethod + def angle_repr(mol, triplet): + i = mol.GetAtomWithIdx(triplet[0]).GetSymbol() + j = mol.GetAtomWithIdx(triplet[1]).GetSymbol() + k = mol.GetAtomWithIdx(triplet[2]).GetSymbol() + ij = BOND_SYMBOLS[mol.GetBondBetweenAtoms(triplet[0], triplet[1]).GetBondType()] + jk = BOND_SYMBOLS[mol.GetBondBetweenAtoms(triplet[1], triplet[2]).GetBondType()] + + # Unified (sorted) representation: order by the endpoint atoms, break ties by bond symbols + if i < k: + return f'{i}{ij}{j}{jk}{k}' + elif i > k: + return f'{k}{jk}{j}{ij}{i}' + elif ij <= jk: + return f'{i}{ij}{j}{jk}{k}' + else: + return f'{k}{jk}{j}{ij}{i}' + + @staticmethod + def bond_repr(mol, pair): + i = mol.GetAtomWithIdx(pair[0]).GetSymbol() + j = mol.GetAtomWithIdx(pair[1]).GetSymbol() + ij = BOND_SYMBOLS[mol.GetBondBetweenAtoms(pair[0], pair[1]).GetBondType()] + # Unified (sorted) representation + return f'{i}{ij}{j}' if i <= j else f'{j}{ij}{i}' + + @staticmethod + def get_bond_distances(mol, bonds): + i, j = np.array(bonds).T + x = mol.GetConformer().GetPositions() + xi = x[i] + xj = x[j] + bond_distances = np.linalg.norm(xi - xj, axis=1) + return bond_distances + + @staticmethod + def get_angle_values(mol, triplets): + i, j, k = np.array(triplets).T + x = mol.GetConformer().GetPositions() + xi = x[i] + xj = x[j] + xk = x[k] + vji = xi - xj + vjk = xk - xj + angles = np.arccos((vji * vjk).sum(axis=1) / (np.linalg.norm(vji, axis=1) * np.linalg.norm(vjk, axis=1))) + return np.degrees(angles) + + @staticmethod + def get_distances_and_angles(mol): + data = defaultdict(list) + bonds = _get_bond_atom_indices(mol) + distances = GeometryEvaluator.get_bond_distances(mol, bonds) + for b, d in zip(bonds, distances): + data[GeometryEvaluator.bond_repr(mol, b)].append(d) + + triplets = _get_angle_atom_indices(bonds) + angles = GeometryEvaluator.get_angle_values(mol, triplets) + for t, a in zip(triplets, angles): + data[GeometryEvaluator.angle_repr(mol, t)].append(a) + + return data + + @property + def _dtypes(self): + return {'*': list} + + +class EnergyEvaluator(AbstractEvaluator): + ID = 'energy' + + def evaluate(self, molecule, protein=None): + molecule = self.load_molecule(molecule) + try: + energy = self.get_energy(molecule) + except: + energy = None + return {'energy': energy} + + @staticmethod + def get_energy(mol, conf_id=-1): + mol = Chem.AddHs(mol, addCoords=True) + uff = UFFGetMoleculeForceField(mol, confId=conf_id) + e_uff = uff.CalcEnergy() + return e_uff + + @property + def _dtypes(self): + return {'energy': float} + + +class InteractionsEvaluator(AbstractEvaluator): + ID = 'interactions' + + def __init__(self, reduce='./reduce'): + self.reduce = reduce + + @property + def default_profile(self): + return {i: 0 for i in INTERACTION_LIST} + + def evaluate(self, molecule, protein=None): + molecule = self.load_molecule(molecule) + profile = self.default_profile + try: + ligand_plf = prepare_ligand(molecule) + protein_plf = read_protein(str(protein), reduce_exec=self.reduce) + 
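A quick numeric check of GeometryEvaluator.get_angle_values() on an embedded water molecule (illustrative; the exact value depends on the embedding):

from rdkit import Chem
from rdkit.Chem import AllChem
mol = Chem.AddHs(Chem.MolFromSmiles('O'))  # H-O-H, atom 0 is the oxygen
AllChem.EmbedMolecule(mol, randomSeed=42)
print(GeometryEvaluator.get_angle_values(mol, [(1, 0, 2)]))  # roughly 105-110 degrees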
interactions = profile_detailed(ligand_plf, protein_plf) + if not interactions.empty: + profile.update(dict(interactions.interaction.value_counts())) + except Exception: + pass + return profile + + @property + def _dtypes(self): + return {'*': int} + + +class GninaEvalulator(AbstractEvaluator): + ID = 'gnina' + def __init__(self, gnina): + self.gnina = gnina + + def evaluate(self, molecule, protein=None): + with tempfile.TemporaryDirectory() as tmpdir: + molecule = self.save_molecule(molecule, sdf_path=Path(tmpdir, 'molecule.sdf')) + gnina_cmd = f'{self.gnina} -r {str(protein)} -l {str(molecule)} --minimize --seed 42 --no_gpu' + gnina_result = subprocess.run(gnina_cmd, shell=True, capture_output=True, text=True) + n_atoms = self.load_molecule(molecule).GetNumAtoms() + + gnina_scores = self.read_gnina_results(gnina_result) + + # Additionally computing ligand efficiency + gnina_scores['vina_efficiency'] = gnina_scores['vina_score'] / n_atoms if n_atoms > 0 else None + gnina_scores['gnina_efficiency'] = gnina_scores['gnina_score'] / n_atoms if n_atoms > 0 else None + return gnina_scores + + @staticmethod + def read_gnina_results(gnina_result): + res = { + 'vina_score': None, + 'gnina_score': None, + 'minimisation_rmsd': None, + 'cnn_score': None, + } + if gnina_result.returncode != 0: + print(gnina_result.stderr) + return res + + for line in gnina_result.stdout.split('\n'): + if line.startswith('Affinity'): + res['vina_score'] = float(line.split(' ')[1].strip()) + if line.startswith('CNNaffinity'): + res['gnina_score'] = float(line.split(' ')[1].strip()) + if line.startswith('CNNscore'): + res['cnn_score'] = float(line.split(' ')[1].strip()) + if line.startswith('RMSD'): + res['minimisation_rmsd'] = float(line.split(' ')[1].strip()) + + return res + + @property + def _dtypes(self): + return {'*': float} + + +class MedChemEvaluator(AbstractEvaluator): + ID = 'medchem' + def __init__(self, connectivity_threshold=1.0): + self.connectivity_threshold = connectivity_threshold + + def evaluate(self, molecule, protein=None): + molecule = self.load_molecule(molecule) + valid = self.is_valid(molecule) + + if valid: + Chem.SanitizeMol(molecule) + + connected = None if not valid else self.is_connected(molecule) + qed = None if not valid else self.calculate_qed(molecule) + sa = None if not valid else self.calculate_sa(molecule) + logp = None if not valid else self.calculate_logp(molecule) + lipinski = None if not valid else self.calculate_lipinski(molecule) + n_rotatable_bonds = None if not valid else self.calculate_rotatable_bonds(molecule) + size = self.calculate_molecule_size(molecule) + + return { + 'valid': valid, + 'connected': connected, + 'qed': qed, + 'sa': sa, + 'logp': logp, + 'lipinski': lipinski, + 'size': size, + 'n_rotatable_bonds': n_rotatable_bonds, + } + + @staticmethod + def is_valid(rdmol): + if rdmol.GetNumAtoms() < 1: + return False + + _mol = Chem.Mol(rdmol) + try: + Chem.SanitizeMol(_mol) + except ValueError: + return False + + return True + + def is_connected(self, rdmol): + if rdmol.GetNumAtoms() < 1: + return False + + try: + mol_frags = Chem.rdmolops.GetMolFrags(rdmol, asMols=True) + largest_frag = max(mol_frags, default=rdmol, key=lambda m: m.GetNumAtoms()) + return largest_frag.GetNumAtoms() / rdmol.GetNumAtoms() >= self.connectivity_threshold + except: + return False + + @staticmethod + def calculate_qed(rdmol): + try: + return QED.qed(rdmol) + except: + return None + + @staticmethod + def calculate_sa(rdmol): + try: + sa = calculateScore(rdmol) + return sa + except: + 
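read_gnina_results() only inspects line prefixes in gnina's stdout, so it can be sanity-checked with a faked subprocess result (toy values; the real gnina output contains more fields):

from types import SimpleNamespace
fake = SimpleNamespace(returncode=0, stderr='',
                       stdout='Affinity: -7.1\nCNNscore: 0.82\nCNNaffinity: 6.5\nRMSD: 0.40\n')
print(GninaEvalulator.read_gnina_results(fake))
# {'vina_score': -7.1, 'gnina_score': 6.5, 'minimisation_rmsd': 0.4, 'cnn_score': 0.82}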
return None + + @staticmethod + def calculate_logp(rdmol): + try: + return Crippen.MolLogP(rdmol) + except: + return None + + @staticmethod + def calculate_lipinski(rdmol): + try: + rule_1 = Descriptors.ExactMolWt(rdmol) < 500 + rule_2 = Lipinski.NumHDonors(rdmol) <= 5 + rule_3 = Lipinski.NumHAcceptors(rdmol) <= 10 + logp = Crippen.MolLogP(rdmol) + rule_4 = -2 <= logp <= 5 + rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(rdmol) <= 10 + return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]]) + except: + return None + + @staticmethod + def calculate_molecule_size(rdmol): + try: + return rdmol.GetNumAtoms() + except: + return None + + @staticmethod + def calculate_rotatable_bonds(rdmol): + try: + return Chem.rdMolDescriptors.CalcNumRotatableBonds(rdmol) + except: + return None + + @property + def _dtypes(self): + return { + 'valid': bool, + 'connected': bool, + 'qed': float, + 'sa': float, + 'logp': float, + 'lipinski': int, + 'size': int, + 'n_rotatable_bonds': int, + } + + +class ClashEvaluator(AbstractEvaluator): + ID = 'clashes' + def __init__(self, margin=0.75, ignore={'H'}): + self.margin = margin + self.ignore = ignore + + def evaluate(self, molecule=None, protein=None): + result = { + 'passed_clash_score_ligands': None, + 'passed_clash_score_pockets': None, + 'passed_clash_score_between': None, + } + if molecule is not None: + molecule = self.load_molecule(molecule) + clash_score = self.clash_score(molecule) + result['clash_score_ligands'] = clash_score + result['passed_clash_score_ligands'] = (clash_score == 0) + + if protein is not None: + protein = Chem.MolFromPDBFile(str(protein), sanitize=False) + clash_score = self.clash_score(protein) + result['clash_score_pockets'] = clash_score + result['passed_clash_score_pockets'] = (clash_score == 0) + + if molecule is not None and protein is not None: + clash_score = self.clash_score(molecule, protein) + result['clash_score_between'] = clash_score + result['passed_clash_score_between'] = (clash_score == 0) + + return result + + def clash_score(self, rdmol1, rdmol2=None): + """ + Computes a clash score as the number of atoms that have at least one + clash divided by the number of atoms in the molecule. + + INTERMOLECULAR CLASH SCORE + If rdmol2 is provided, the score is the fraction of atoms in rdmol1 + that have at least one clash with rdmol2. + We define a clash if two atoms are closer than "margin times the sum of + their van der Waals radii". + + INTRAMOLECULAR CLASH SCORE + If rdmol2 is not provided, the score is the fraction of atoms in rdmol1 + that have at least one clash with other atoms in rdmol1. + In this case, a clash is defined by margin times the atoms' smallest + covalent radii (among single, double and triple bond radii). This is done + so that this function is applicable even if no connectivity information is + available. 
+ """ + + intramolecular = rdmol2 is None + if intramolecular: + rdmol2 = rdmol1 + + coord1, radii1 = self.coord_and_radii(rdmol1, intramolecular=intramolecular) + coord2, radii2 = self.coord_and_radii(rdmol2, intramolecular=intramolecular) + + dist = np.sqrt(np.sum((coord1[:, None, :] - coord2[None, :, :]) ** 2, axis=-1)) + if intramolecular: + np.fill_diagonal(dist, np.inf) + + clashes = dist < self.margin * (radii1[:, None] + radii2[None, :]) + clashes = np.any(clashes, axis=1) + return np.mean(clashes) + + def coord_and_radii(self, rdmol, intramolecular): + _periodic_table = Chem.GetPeriodicTable() + _get_radius = _periodic_table.GetRcovalent if intramolecular else _periodic_table.GetRvdw + + coord = rdmol.GetConformer().GetPositions() + radii = np.array([_get_radius(a.GetSymbol()) for a in rdmol.GetAtoms()]) + + mask = np.array([a.GetSymbol() not in self.ignore for a in rdmol.GetAtoms()]) + coord = coord[mask] + radii = radii[mask] + + assert coord.shape[0] == radii.shape[0] + return coord, radii + + @property + def _dtypes(self): + return { + 'clash_score_ligands': float, + 'clash_score_pockets': float, + 'clash_score_between': float, + 'passed_clash_score_ligands': bool, + 'passed_clash_score_pockets': bool, + 'passed_clash_score_between': bool, + } + + +class RingCountEvaluator(AbstractEvaluator): + ID = 'ring_count' + + def evaluate(self, molecule, protein=None): + _mol = self.load_molecule(molecule) + + # compute ring info if not yet available + try: + _mol.UpdatePropertyCache() + except ValueError: + return {} + Chem.GetSymmSSSR(_mol) + + rings = _mol.GetRingInfo().AtomRings() + ring_sizes = [len(r) for r in rings] + + ring_counts = defaultdict(int) + for k in ring_sizes: + ring_counts[f"num_{k}_rings"] += 1 + + return ring_counts + + @property + def _dtypes(self): + return {'*': int} + + +class ChemblRingEvaluator(AbstractEvaluator): + ID = 'chembl_ring_systems' + + def __init__(self): + self.ring_system_lookup = RingSystemLookup.default() # ChEMBL + + def evaluate(self, molecule, protein=None): + + results = { + 'min_ring_smi': None, + 'min_ring_freq_gt0_': None, + 'min_ring_freq_gt10_': None, + 'min_ring_freq_gt100_': None, + } + + molecule = self.load_molecule(molecule) + + try: + Chem.SanitizeMol(molecule) + freq_list = self.ring_system_lookup.process_mol(molecule) + freq_list = self.ring_system_lookup.process_mol(molecule) + except ValueError: + return results + + min_ring, min_freq = get_min_ring_frequency(freq_list) + + return { + 'min_ring_smi': min_ring, + 'min_ring_freq_gt0_': min_freq > 0, + 'min_ring_freq_gt10_': min_freq > 10, + 'min_ring_freq_gt100_': min_freq > 100, + } + + @property + def _dtypes(self): + return { + 'min_ring_smi': str, + 'min_ring_freq_gt0_': bool, + 'min_ring_freq_gt10_': bool, + 'min_ring_freq_gt100_': bool, + } + + +class REOSEvaluator(AbstractEvaluator): + # Based on https://practicalcheminformatics.blogspot.com/2024/05/generative-molecular-design-isnt-as.html + ID = 'reos' + + def __init__(self): + self.reos = REOS() + + def evaluate(self, molecule, protein=None): + + molecule = self.load_molecule(molecule) + try: + Chem.SanitizeMol(molecule) + except ValueError: + return {rule_set: False for rule_set in self.reos.get_available_rule_sets()} + + results = {} + for rule_set in self.reos.get_available_rule_sets(): + self.reos.set_active_rule_sets([rule_set]) + if rule_set == 'PW': + self.reos.drop_rule('furans') + + reos_res = self.reos.process_mol(molecule) + results[rule_set] = reos_res[0] == 'ok' + + results['all'] = all([bool(value) if 
not is_nan(value) else False for value in results.values()]) + return results + + @property + def _dtypes(self): + return {'*': bool} + + +class FullEvaluator(AbstractEvaluator): + def __init__( + self, + pb_conf: str = 'dock', + gnina: Optional[Union[Path, str]] = None, + reduce: Optional[Union[Path, str]] = None, + connectivity_threshold: float = 1.0, + margin: float = 0.75, + ignore: Set[str] = {'H'}, + exclude_evaluators: Collection[str] = [], + ): + all_evaluators = [ + RepresentationEvaluator(), + MolPropertyEvaluator(), + PoseBustersEvaluator(pb_conf=pb_conf), + MedChemEvaluator(connectivity_threshold=connectivity_threshold), + ClashEvaluator(margin=margin, ignore=ignore), + GeometryEvaluator(), + RingCountEvaluator(), + EnergyEvaluator(), + ChemblRingEvaluator(), + REOSEvaluator() + ] + if gnina is not None: + all_evaluators.append(GninaEvalulator(gnina=gnina)) + else: + print(f'Evaluator [{GninaEvalulator.ID}] is not included') + if reduce is not None: + all_evaluators.append(InteractionsEvaluator(reduce=reduce)) + else: + print(f'Evaluator [{InteractionsEvaluator.ID}] is not included') + + self.evaluators = [] + for e in all_evaluators: + if e.ID in exclude_evaluators: + print(f'Excluded Evaluator [{e.ID}]') + else: + self.evaluators.append(e) + + print('Will use the following evaluators:') + for e in self.evaluators: + print(f'- [{e.ID}]') + + + def evaluate(self, molecule, protein): + results = {} + for evaluator in self.evaluators: + results.update(evaluator(molecule, protein)) + return results + + @property + def _dtypes(self): + all_dtypes = {} + for evaluator in self.evaluators: + all_dtypes.update(evaluator.dtypes) + return all_dtypes + + +######################################################################################## +################################# Collection Metrics ################################### +######################################################################################## + + +class AbstractCollectionEvaluator: + ID = None + def __call__(self, smiles: Collection[str], timeout=300): + """ + Args: + smiles (Collection[smiles]): input list of SMILES + + Returns: + metrics (dict): dictionary of metrics + """ + if self.ID is not None: + print(f'Running CollectionEvaluator [{self.ID}]') + + RDLogger.DisableLog('rdApp.*') + self.check_format(smiles) + # timeout handler + signal.signal(signal.SIGALRM, timeout_handler) + try: + signal.alarm(timeout) + results = self.evaluate(smiles) + except TimeoutError: + print(f'Error when evaluating [{self.ID}]: Timeout after {timeout} seconds') + signal.alarm(0) + return {} + except Exception as e: + print(f'Error when evaluating [{self.ID}]: {e}') + signal.alarm(0) + return {} + finally: + print(f'Finished CollectionEvaluator [{self.ID}]') + signal.alarm(0) + return results + + @staticmethod + def check_format(smiles): + assert len(smiles) > 0, 'List of input SMILES cannot be empty' + assert isinstance(smiles, Collection), 'Only list of SMILES supported' + assert isinstance(smiles[0], str), 'Only list of SMILES supported' + + +class UniquenessEvaluator(AbstractCollectionEvaluator): + ID = 'uniqueness' + def evaluate(self, smiles: Collection[str]): + uniqueness = len(set(smiles)) / len(smiles) + return {'uniqueness': uniqueness} + + +class NoveltyEvaluator(AbstractCollectionEvaluator): + ID = 'novelty' + def __init__(self, reference_smiles: Collection[str]): + self.reference_smiles = set(list(reference_smiles)) + assert len(self.reference_smiles) > 0, 'List of refernce SMILES cannot be empty' + + def 
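Typical FullEvaluator usage (paths hypothetical); without gnina/reduce paths the corresponding evaluators are skipped with a notice:

evaluator = FullEvaluator()  # or FullEvaluator(gnina='gnina', reduce='reduce')
metrics = evaluator('0_ligand.sdf', '0_pocket.pdb')
print(metrics.get('medchem.qed'), metrics.get('posebusters.all'))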
evaluate(self, smiles: Collection[str]): + smiles = set(smiles) + novel = [smi for smi in smiles if smi not in self.reference_smiles] + novelty = len(novel) / len(smiles) + return {'novelty': novelty} + +def canonical_smiles(smiles): + for smi in smiles: + try: + mol = Chem.MolFromSmiles(smi) + if mol is not None: + yield Chem.MolToSmiles(mol) + except: + yield None + +class FCDEvaluator(AbstractCollectionEvaluator): + ID = 'fcd' + def __init__(self, reference_smiles: Collection[str]): + self.reference_smiles = list(reference_smiles) + assert len(self.reference_smiles) > 0, 'List of refernce SMILES cannot be empty' + + def evaluate(self, smiles: Collection[str]): + if len(smiles) > len(self.reference_smiles): + print('Number of reference molecules should be greater than number of input molecules') + return {'fcd': None} + + np.random.seed(42) + reference_smiles = np.random.choice(self.reference_smiles, len(smiles), replace=False).tolist() + reference_smiles_canonical = [w for w in canonical_smiles(reference_smiles) if w is not None] + smiles_canonical = [w for w in canonical_smiles(smiles) if w is not None] + fcd = get_fcd(reference_smiles_canonical, smiles_canonical) + return {'fcd': fcd} + + +class RingDistributionEvaluator(AbstractCollectionEvaluator): + ID = 'ring_system_distribution' + + def __init__(self, reference_smiles: Collection[str], jsd_on_k_most_freq: Collection[int] = ()): + self.ring_system_finder = RingSystemFinder() + self.ref_ring_dict = self.compute_ring_dict(reference_smiles) + self.jsd_on_k_most_freq = jsd_on_k_most_freq + + def compute_ring_dict(self, molecules): + + ring_system_dict = defaultdict(int) + + for mol in tqdm(molecules, desc="Computing ring systems"): + + if isinstance(mol, str): + mol = Chem.MolFromSmiles(mol) + + try: + ring_system_list = self.ring_system_finder.find_ring_systems(mol, as_mols=True) + except ValueError: + print(f"WARNING[{type(self).__name__}]: error while computing ring systems; skipping molecule.") + continue + + for ring in ring_system_list: + inchi_key = Chem.MolToInchiKey(ring) + ring_system_dict[inchi_key] += 1 + + return ring_system_dict + + def precision(self, query_ring_dict): + query_ring_systems = set(query_ring_dict.keys()) + ref_ring_systems = set(self.ref_ring_dict.keys()) + intersection = ref_ring_systems & query_ring_systems + return len(intersection) / len(query_ring_systems) if len(query_ring_systems) > 0 else 0 + + def recall(self, query_ring_dict): + query_ring_systems = set(query_ring_dict.keys()) + ref_ring_systems = set(self.ref_ring_dict.keys()) + intersection = ref_ring_systems & query_ring_systems + return len(intersection) / len(ref_ring_systems) if len(ref_ring_systems) > 0 else 0 + + def jsd(self, query_ring_dict, k_most_freq=None): + + if k_most_freq is None: + # example on the union of all ring systems + sample_space = set(self.ref_ring_dict.keys()) | set(query_ring_dict.keys()) + else: + # evaluate only on the k most common rings from the reference set + sorted_rings = [k for k, v in sorted(self.ref_ring_dict.items(), key=lambda item: item[1], reverse=True)] + sample_space = sorted_rings[:k_most_freq] + + p = np.zeros(len(sample_space)) + q = np.zeros(len(sample_space)) + + for i, inchi_key in enumerate(sample_space): + p[i] = self.ref_ring_dict.get(inchi_key, 0) + q[i] = query_ring_dict.get(inchi_key, 0) + + # normalize + p = p / np.sum(p) + q = q / np.sum(q) + + return jensenshannon(p, q) + + def evaluate(self, smiles: Collection[str]): + + query_ring_dict = self.compute_ring_dict(smiles) + + out = { + 
"precision": self.precision(query_ring_dict), + "recall": self.recall(query_ring_dict), + "jsd": self.jsd(query_ring_dict), + } + + out.update( + {f"jsd_{k}_most_freq": self.jsd(query_ring_dict, k_most_freq=k) for k in self.jsd_on_k_most_freq} + ) + + return out + + +class FullCollectionEvaluator(AbstractCollectionEvaluator): + def __init__(self, reference_smiles: Collection[str], exclude_evaluators: Collection[str] = []): + self.evaluators = [ + UniquenessEvaluator(), + NoveltyEvaluator(reference_smiles=reference_smiles), + FCDEvaluator(reference_smiles=reference_smiles), + RingDistributionEvaluator(reference_smiles, jsd_on_k_most_freq=[10, 100, 1000, 10000]), + ] + for e in self.evaluators: + if e.ID in exclude_evaluators: + print(f'Excluding CollectionEvaluator [{e.ID}]') + self.evaluators.remove(e) + + def evaluate(self, smiles): + results = {} + for evaluator in self.evaluators: + results.update(evaluator(smiles)) + return results diff --git a/src/sbdd_metrics/sascorer.py b/src/sbdd_metrics/sascorer.py new file mode 100644 index 0000000000000000000000000000000000000000..862d191032bb0b366260f9d6e306fb0ddf98ccf3 --- /dev/null +++ b/src/sbdd_metrics/sascorer.py @@ -0,0 +1,173 @@ +# +# calculation of synthetic accessibility score as described in: +# +# Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions +# Peter Ertl and Ansgar Schuffenhauer +# Journal of Cheminformatics 1:8 (2009) +# http://www.jcheminf.com/content/1/1/8 +# +# several small modifications to the original paper are included +# particularly slightly different formula for marocyclic penalty +# and taking into account also molecule symmetry (fingerprint density) +# +# for a set of 10k diverse molecules the agreement between the original method +# as implemented in PipelinePilot and this implementation is r2 = 0.97 +# +# peter ertl & greg landrum, september 2013 +# + + +from rdkit import Chem +from rdkit.Chem import rdMolDescriptors +import pickle + +import math +from collections import defaultdict + +import os.path as op + +_fscores = None + + +def readFragmentScores(name='fpscores'): + import gzip + global _fscores + # generate the full path filename: + if name == "fpscores": + name = op.join(op.dirname(__file__), name) + data = pickle.load(gzip.open('%s.pkl.gz' % name)) + outDict = {} + for i in data: + for j in range(1, len(i)): + outDict[i[j]] = float(i[0]) + _fscores = outDict + + +def numBridgeheadsAndSpiro(mol, ri=None): + nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) + nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) + return nBridgehead, nSpiro + + +def calculateScore(m): + if _fscores is None: + readFragmentScores() + + # fragment score + fp = rdMolDescriptors.GetMorganFingerprint(m, + 2) # <- 2 is the *radius* of the circular fingerprint + fps = fp.GetNonzeroElements() + score1 = 0. + nf = 0 + for bitId, v in fps.items(): + nf += v + sfp = bitId + score1 += _fscores.get(sfp, -4) * v + score1 /= nf + + # features score + nAtoms = m.GetNumAtoms() + nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) + ri = m.GetRingInfo() + nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) + nMacrocycles = 0 + for x in ri.AtomRings(): + if len(x) > 8: + nMacrocycles += 1 + + sizePenalty = nAtoms**1.005 - nAtoms + stereoPenalty = math.log10(nChiralCenters + 1) + spiroPenalty = math.log10(nSpiro + 1) + bridgePenalty = math.log10(nBridgeheads + 1) + macrocyclePenalty = 0. 
+ # --------------------------------------- + # This differs from the paper, which defines: + # macrocyclePenalty = math.log10(nMacrocycles+1) + # This form generates better results when 2 or more macrocycles are present + if nMacrocycles > 0: + macrocyclePenalty = math.log10(2) + + score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty + + # correction for the fingerprint density + # not in the original publication, added in version 1.1 + # to make highly symmetrical molecules easier to synthetise + score3 = 0. + if nAtoms > len(fps): + score3 = math.log(float(nAtoms) / len(fps)) * .5 + + sascore = score1 + score2 + score3 + + # need to transform "raw" value into scale between 1 and 10 + min = -4.0 + max = 2.5 + sascore = 11. - (sascore - min + 1) / (max - min) * 9. + # smooth the 10-end + if sascore > 8.: + sascore = 8. + math.log(sascore + 1. - 9.) + if sascore > 10.: + sascore = 10.0 + elif sascore < 1.: + sascore = 1.0 + + return sascore + + +def processMols(mols): + print('smiles\tName\tsa_score') + for i, m in enumerate(mols): + if m is None: + continue + + s = calculateScore(m) + + smiles = Chem.MolToSmiles(m) + print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s) + + +if __name__ == '__main__': + import sys + import time + + t1 = time.time() + readFragmentScores("fpscores") + t2 = time.time() + + suppl = Chem.SmilesMolSupplier(sys.argv[1]) + t3 = time.time() + processMols(suppl) + t4 = time.time() + + print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)), + file=sys.stderr) + +# +# Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. +# * Neither the name of Novartis Institutes for BioMedical Research Inc. +# nor the names of its contributors may be used to endorse or promote +# products derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# \ No newline at end of file diff --git a/src/size_predictor/size_model.py b/src/size_predictor/size_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e13a8bbeff8dc5c54107012ac3ca7645be44238a --- /dev/null +++ b/src/size_predictor/size_model.py @@ -0,0 +1,495 @@ +from typing import Optional +from pathlib import Path +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +import pytorch_lightning as pl +from torch_scatter import scatter_mean + +from src.model.gvp import GVP, GVPModel, LayerNorm, GVPConvLayer +from src.model.gvp_transformer import GVPTransformerModel, GVPTransformerLayer +from src.constants import aa_decoder, residue_bond_encoder +from src.data.dataset import ProcessedLigandPocketDataset +import src.utils as utils + + +class SizeModel(pl.LightningModule): + def __init__( + self, + max_size, + pocket_representation, + train_params, + loss_params, + eval_params, + predictor_params, + ): + super(SizeModel, self).__init__() + self.save_hyperparameters() + + assert pocket_representation == "CA+" + self.pocket_representation = pocket_representation + + self.type = loss_params.type + assert self.type in {'classifier', 'ordinal', 'regression'} + + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None + self.data_transform = None + + # Training parameters + self.datadir = train_params.datadir + self.batch_size = train_params.batch_size + self.lr = train_params.lr + self.num_workers = train_params.num_workers + self.clip_grad = train_params.clip_grad + if self.clip_grad: + self.gradnorm_queue = utils.Queue() + # Add large value that will be flushed. + self.gradnorm_queue.add(3000) + + # Feature encoders/decoders + self.aa_decoder = aa_decoder + self.residue_bond_encoder = residue_bond_encoder + + # Set up the neural network + self.edge_cutoff = predictor_params.edge_cutoff + self.add_nma_feat = predictor_params.normal_modes + self.max_size = max_size + self.n_classes = max_size if self.type == 'ordinal' else max_size + 1 + backbone = predictor_params.backbone + model_params = getattr(predictor_params, backbone + '_params') + + self.residue_nf = (len(self.aa_decoder), 0) + if self.add_nma_feat: + self.residue_nf = (self.residue_nf[0], self.residue_nf[1] + 5) + + out_nf = 1 if self.type == "regression" else self.n_classes + + if backbone == 'gvp_transformer': + self.net = SizeGVPTransformer( + node_in_dim=self.residue_nf, + node_h_dim=model_params.node_h_dim, + out_nf=out_nf, + edge_in_nf=len(self.residue_bond_encoder), + edge_h_dim=model_params.edge_h_dim, + num_layers=model_params.n_layers, + dk=model_params.dk, + dv=model_params.dv, + de=model_params.de, + db=model_params.db, + dy=model_params.dy, + attn_heads=model_params.attn_heads, + n_feedforward=model_params.n_feedforward, + drop_rate=model_params.dropout, + reflection_equiv=model_params.reflection_equivariant, + d_max=model_params.d_max, + num_rbf=model_params.num_rbf, + vector_gate=model_params.vector_gate, + attention=model_params.attention, + ) + + elif backbone == 'gvp_gnn': + self.net = SizeGVPModel( + node_in_dim=self.residue_nf, + node_h_dim=model_params.node_h_dim, + out_nf=out_nf, + edge_in_nf=len(self.residue_bond_encoder), + edge_h_dim=model_params.edge_h_dim, + num_layers=model_params.n_layers, + drop_rate=model_params.dropout, + vector_gate=model_params.vector_gate, + reflection_equiv=model_params.reflection_equivariant, + d_max=model_params.d_max, + 
num_rbf=model_params.num_rbf, + ) + + else: + raise NotImplementedError(f"{backbone} is not available") + + def configure_optimizers(self): + return torch.optim.AdamW(self.parameters(), lr=self.lr, + amsgrad=True, weight_decay=1e-12) + + def setup(self, stage: Optional[str] = None): + + if stage == 'fit': + self.train_dataset = ProcessedLigandPocketDataset( + Path(self.datadir, 'train.pt'), + ligand_transform=None, catch_errors=True) + # ligand_transform=self.data_transform, catch_errors=True) + self.val_dataset = ProcessedLigandPocketDataset( + Path(self.datadir, 'val.pt'), ligand_transform=None) + elif stage == 'test': + self.test_dataset = ProcessedLigandPocketDataset( + Path(self.datadir, 'test.pt'), ligand_transform=None) + else: + raise NotImplementedError + + def train_dataloader(self): + return DataLoader(self.train_dataset, self.batch_size, shuffle=True, + num_workers=self.num_workers, + # collate_fn=self.train_dataset.collate_fn, + collate_fn=partial(self.train_dataset.collate_fn, ligand_transform=self.data_transform), + pin_memory=True) + + def val_dataloader(self): + return DataLoader(self.val_dataset, self.batch_size, + shuffle=False, num_workers=self.num_workers, + collate_fn=self.val_dataset.collate_fn, + pin_memory=True) + + def test_dataloader(self): + return DataLoader(self.test_dataset, self.batch_size, shuffle=False, + num_workers=self.num_workers, + collate_fn=self.test_dataset.collate_fn, + pin_memory=True) + + def forward(self, pocket): + + # x: CA coordinates + x, h, mask = pocket['x'], pocket['one_hot'], pocket['mask'] + + edges = None + if 'bonds' in pocket: + edges = (pocket['bonds'], pocket['bond_one_hot']) + + v = None + if self.add_nma_feat: + v = pocket['nma_vec'] + + if edges is not None: + # make sure messages are passed both ways + edge_indices = torch.cat( + [edges[0], edges[0].flip(dims=[0])], dim=1) + edge_types = torch.cat([edges[1], edges[1]], dim=0) + + edges, edge_feat = self.get_edges( + mask, x, bond_inds=edge_indices, bond_feat=edge_types) + + assert torch.all(mask[edges[0]] == mask[edges[1]]) + + out = self.net(h, x, edges, v=v, batch_mask=mask, edge_attr=edge_feat) + + if torch.any(torch.isnan(out)): + # print("NaN detected in network output") + # out[torch.isnan(out)] = 0.0 + if self.training: + print("NaN detected in network output") + out[torch.isnan(out)] = 0.0 + else: + raise ValueError("NaN detected in network output") + + return out + + def get_edges(self, batch_mask, coord, bond_inds=None, bond_feat=None, self_edges=False): + + # Adjacency matrix + adj = batch_mask[:, None] == batch_mask[None, :] + + if self.edge_cutoff is not None: + adj = adj & (torch.cdist(coord, coord) <= self.edge_cutoff) + + # Add missing bonds if they got removed + adj[bond_inds[0], bond_inds[1]] = True + + if not self_edges: + adj = adj ^ torch.eye(*adj.size(), out=torch.empty_like(adj)) + + # Feature matrix + nobond_onehot = F.one_hot(torch.tensor( + self.residue_bond_encoder['NOBOND'], device=bond_feat.device), + num_classes=len(self.residue_bond_encoder)).float() + # nobond_emb = self.residue_bond_encoder(nobond_onehot.to(FLOAT_TYPE)) + # feat = nobond_emb.repeat(*adj.shape, 1) + feat = nobond_onehot.repeat(*adj.shape, 1) + feat[bond_inds[0], bond_inds[1]] = bond_feat + + # Return results + edges = torch.stack(torch.where(adj), dim=0) + edge_feat = feat[edges[0], edges[1]] + + return edges, edge_feat + + def compute_loss(self, pred_logits, true_size): + + if self.type == "classifier": + loss = F.cross_entropy(pred_logits, true_size) + + elif self.type == 
"ordinal": + # each binary variable corresponds to P(x > i), i=0,...,(max_size-1) + binary_labels = true_size.unsqueeze(1) > torch.arange(self.n_classes, device=true_size.device).unsqueeze(0) + loss = F.binary_cross_entropy_with_logits(pred_logits, binary_labels.float()) + + elif self.type == 'regression': + loss = F.mse_loss(pred_logits.squeeze(), true_size.float()) + + else: + raise NotImplementedError() + + return loss + + def max_likelihood(self, pred_logits): + + if self.type == "classifier": + pred = pred_logits.argmax(dim=-1) + + elif self.type == "ordinal": + # convert probabilities from P(x > i), i=0,...,(max_size-1) to + # P(i), i=0,...,max_size + prop_greater = pred_logits.sigmoid() + pred = torch.zeros((pred_logits.size(0), pred_logits.size(1) + 1), + device=pred_logits.device) + pred[:, 0] = 1 - prop_greater[:, 0] + pred[:, 1:-1] = prop_greater[:, :-1] - prop_greater[:, 1:] + pred[:, -1] = prop_greater[:, -1] + pred = pred.argmax(dim=-1) + + elif self.type == 'regression': + pred = torch.clip(torch.round(pred_logits), + min=0, max=self.max_size) + pred = pred.squeeze() + + else: + raise NotImplementedError() + + return pred + + def log_metrics(self, metrics_dict, split, batch_size=None, **kwargs): + for m, value in metrics_dict.items(): + self.log(f'{m}/{split}', value, batch_size=batch_size, **kwargs) + + def compute_metrics(self, pred_logits, target): + + pred = self.max_likelihood(pred_logits) + + accuracy = (pred == target).sum() / len(target) + mse = torch.mean((target - pred).float()**2) + + acc_window3 = (torch.abs(target - pred) <= 1).sum() / len(target) + acc_window5 = (torch.abs(target - pred) <= 2).sum() / len(target) + + return {'accuracy': accuracy, + 'mse': mse, + 'accuracy_window3': acc_window3, + 'accuracy_window5': acc_window5} + + def training_step(self, data, *args): + + ligand, pocket = data['ligand'], data['pocket'] + + try: + pred_logits = self.forward(pocket) + true_size = ligand['size'] + + except RuntimeError as e: + # this is not supported for multi-GPU + if self.trainer.num_devices < 2 and 'out of memory' in str(e): + print('WARNING: ran out of memory, skipping to the next batch') + return None + else: + raise e + loss = self.compute_loss(pred_logits, true_size) + + # Compute metrics + metrics = self.compute_metrics(pred_logits, true_size) + self.log_metrics({'loss': loss, **metrics}, 'train', + batch_size=len(true_size), prog_bar=False) + + return loss + + def validation_step(self, data, *args): + ligand, pocket = data['ligand'], data['pocket'] + + pred_logits = self.forward(pocket) + true_size = ligand['size'] + + loss = self.compute_loss(pred_logits, true_size) + + # Compute metrics + metrics = self.compute_metrics(pred_logits, true_size) + self.log_metrics({'loss': loss, **metrics}, 'val', batch_size=len(true_size)) + + return loss + + def configure_gradient_clipping(self, optimizer, optimizer_idx, + gradient_clip_val, gradient_clip_algorithm): + + if not self.clip_grad: + return + + # Allow gradient norm to be 150% + 2 * stdev of the recent history. 
+        max_grad_norm = 1.5 * self.gradnorm_queue.mean() + \
+                        2 * self.gradnorm_queue.std()
+
+        # Get current grad_norm
+        params = [p for g in optimizer.param_groups for p in g['params']]
+        grad_norm = utils.get_grad_norm(params)
+
+        # Lightning will handle the gradient clipping
+        self.clip_gradients(optimizer, gradient_clip_val=max_grad_norm,
+                            gradient_clip_algorithm='norm')
+
+        if float(grad_norm) > max_grad_norm:
+            # record the clipped value so outliers do not skew the history
+            self.gradnorm_queue.add(float(max_grad_norm))
+            print(f'Clipped gradient with value {grad_norm:.1f} '
+                  f'while allowed {max_grad_norm:.1f}')
+        else:
+            self.gradnorm_queue.add(float(grad_norm))
+
+
+class SizeGVPTransformer(GVPTransformerModel):
+    """
+    GVP-Transformer model
+
+    :param node_in_dim: node dimension in input graph, scalars or tuple (scalars, vectors)
+    :param node_h_dim: node dimensions to use in GVP-GNN layers, tuple (s, V)
+    :param out_nf: node dimensions of output feature, tuple (s, V)
+    :param edge_in_nf: edge dimension in input graph (scalars)
+    :param edge_h_dim: edge dimensions to embed to before use in GVP-GNN layers,
+        tuple (s, V)
+    :param num_layers: number of GVP-GNN layers
+    :param drop_rate: rate to use in all dropout layers
+    :param reflection_equiv: bool, use reflection-sensitive feature based on the
+        cross product if False
+    :param d_max: maximum distance for the radial basis expansion of edge lengths
+    :param num_rbf: number of radial basis functions used as edge features
+    :param vector_gate: use vector gates in all GVPs
+    :param attention: can be used to turn off the attention mechanism
+    """
+    def __init__(self, node_in_dim, node_h_dim, out_nf, edge_in_nf,
+                 edge_h_dim, num_layers, dk, dv, de, db, dy,
+                 attn_heads, n_feedforward, drop_rate, reflection_equiv=True,
+                 d_max=20.0, num_rbf=16, vector_gate=False, attention=True):
+
+        # intentionally skip GVPTransformerModel.__init__ and initialize
+        # nn.Module directly, because all submodules are redefined here
+        super(GVPTransformerModel, self).__init__()
+
+        self.reflection_equiv = reflection_equiv
+        self.d_max = d_max
+        self.num_rbf = num_rbf
+
+        if not isinstance(node_in_dim, tuple):
+            node_in_dim = (node_in_dim, 0)
+
+        edge_in_dim = (edge_in_nf + 2 * node_in_dim[0] + self.num_rbf, 1)
+        if not self.reflection_equiv:
+            edge_in_dim = (edge_in_dim[0], edge_in_dim[1] + 1)
+
+        self.W_v = GVP(node_in_dim, node_h_dim, activations=(None, None),
+                       vector_gate=vector_gate)
+        self.W_e = GVP(edge_in_dim, edge_h_dim, activations=(None, None),
+                       vector_gate=vector_gate)
+
+        self.dy = dy
+        self.layers = nn.ModuleList(
+            GVPTransformerLayer(node_h_dim, edge_h_dim, dy, dk, dv, de, db,
+                                attn_heads, n_feedforward=n_feedforward,
+                                drop_rate=drop_rate, vector_gate=vector_gate,
+                                activations=(F.relu, None), attention=attention)
+            for _ in range(num_layers))
+
+        self.W_y_out = GVP(dy, (out_nf, 0), activations=(None, None),
+                           vector_gate=vector_gate)
+
+    def forward(self, h, x, edge_index, v=None, batch_mask=None, edge_attr=None):
+
+        bs = len(batch_mask.unique())
+
+        h_v = h if v is None else (h, v)
+        h_e = self.edge_features(h, x, edge_index, batch_mask, edge_attr)
+
+        h_v = self.W_v(h_v)
+        h_e = self.W_e(h_e)
+        h_y = (torch.zeros(bs, self.dy[0], device=h.device),
+               torch.zeros(bs, self.dy[1], 3, device=h.device))
+
+        for layer in self.layers:
+            h_v, h_e, h_y = layer(h_v, edge_index, batch_mask, h_e, h_y)
+
+        return self.W_y_out(h_y)
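+
+# Editorial sketch of the edge-feature arithmetic used above (numbers are
+# hypothetical): scalar edge inputs combine the raw edge features with both
+# endpoints' scalar node features and the RBF-expanded distance, e.g. with
+# edge_in_nf = 4, 21 scalar node features and num_rbf = 16 this gives
+# 4 + 2 * 21 + 16 = 62 scalars. There is one vector channel (presumably the
+# inter-node direction, as in comparable GVP implementations), and a second
+# one, the cross-product feature from the docstring, is appended when
+# reflection_equiv is False.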
+
+
+class SizeGVPModel(GVPModel):
+    """
+    GVP-GNN model
+    inspired by: https://github.com/drorlab/gvp-pytorch/blob/main/gvp/models.py
+    and: https://github.com/drorlab/gvp-pytorch/blob/82af6b22eaf8311c15733117b0071408d24ed877/gvp/atom3d.py#L115
+
+    :param node_in_dim: node dimension in input graph, scalars or tuple (scalars, vectors)
+    :param node_h_dim: node dimensions to use in GVP-GNN layers, tuple (s, V)
+    :param out_nf: node dimensions of output feature, tuple (s, V)
+    :param edge_in_nf: edge dimension in input graph (scalars)
+    :param edge_h_dim: edge dimensions to embed to before use in GVP-GNN layers,
+        tuple (s, V)
+    :param num_layers: number of GVP-GNN layers
+    :param drop_rate: rate to use in all dropout layers
+    :param vector_gate: use vector gates in all GVPs
+    :param reflection_equiv: bool, use reflection-sensitive feature based on the
+        cross product if False
+    :param d_max: maximum distance for the radial basis expansion of edge lengths
+    :param num_rbf: number of radial basis functions used as edge features
+    :param update_edge_attr: bool, update edge attributes at each layer in a
+        learnable way
+    """
+    def __init__(self, node_in_dim, node_h_dim, out_nf,
+                 edge_in_nf, edge_h_dim, num_layers=3, drop_rate=0.1,
+                 vector_gate=False, reflection_equiv=True, d_max=20.0,
+                 num_rbf=16):
+
+        # intentionally skip GVPModel.__init__ and initialize nn.Module
+        # directly, because all submodules are redefined here
+        super(GVPModel, self).__init__()
+
+        self.reflection_equiv = reflection_equiv
+        self.d_max = d_max
+        self.num_rbf = num_rbf
+
+        if not isinstance(node_in_dim, tuple):
+            node_in_dim = (node_in_dim, 0)
+
+        edge_in_dim = (edge_in_nf + 2 * node_in_dim[0] + self.num_rbf, 1)
+        if not self.reflection_equiv:
+            edge_in_dim = (edge_in_dim[0], edge_in_dim[1] + 1)
+
+        self.W_v = nn.Sequential(
+            LayerNorm(node_in_dim, learnable_vector_weight=True),
+            GVP(node_in_dim, node_h_dim, activations=(None, None),
+                vector_gate=vector_gate),
+        )
+        self.W_e = nn.Sequential(
+            LayerNorm(edge_in_dim, learnable_vector_weight=True),
+            GVP(edge_in_dim, edge_h_dim, activations=(None, None),
+                vector_gate=vector_gate),
+        )
+
+        self.layers = nn.ModuleList(
+            GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate,
+                         update_edge_attr=True, activations=(F.relu, None),
+                         vector_gate=vector_gate, ln_vector_weight=True)
+            for _ in range(num_layers))
+
+        self.W_y_out = nn.Sequential(
+            LayerNorm(node_h_dim, learnable_vector_weight=True),
+            GVP(node_h_dim, (out_nf, 0), activations=(None, None),
+                vector_gate=vector_gate),
+        )
+
+    def forward(self, h, x, edge_index, v=None, batch_mask=None, edge_attr=None):
+
+        batch_size = len(torch.unique(batch_mask))
+
+        h_v = h if v is None else (h, v)
+        h_e = self.edge_features(h, x, edge_index, batch_mask, edge_attr)
+
+        h_v = self.W_v(h_v)
+        h_e = self.W_e(h_e)
+
+        for layer in self.layers:
+            h_v, h_e = layer(h_v, edge_index, edge_attr=h_e)
+
+        # compute graph-level feature by mean-pooling scalars and vectors
+        sm = scatter_mean(h_v[0], batch_mask, dim=0, dim_size=batch_size)
+        vm = scatter_mean(h_v[1], batch_mask, dim=0, dim_size=batch_size)
+
+        return self.W_y_out((sm, vm))
diff --git a/src/size_predictor/train.py b/src/size_predictor/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..006530e9f16f87a538f199313b6bd738d46bc284
--- /dev/null
+++ b/src/size_predictor/train.py
@@ -0,0 +1,161 @@
+import argparse
+from argparse import Namespace
+from pathlib import Path
+import warnings
+
+import torch
+import pytorch_lightning as pl
+import yaml
+
+import sys
+basedir = Path(__file__).resolve().parent.parent.parent
+sys.path.append(str(basedir))
+
+from src.size_predictor.size_model import SizeModel
+from src.utils import set_deterministic, disable_rdkit_logging
+
+
+def dict_to_namespace(input_dict):
+    """ Recursively convert a nested dictionary into a Namespace object """
+    if isinstance(input_dict, dict):
+        output_namespace = Namespace()
+        output = output_namespace.__dict__
+        for key, value in input_dict.items():
+            output[key] = dict_to_namespace(value)
+        return output_namespace
+
+    elif isinstance(input_dict, Namespace):
+        # recurse as Namespace might contain dictionaries
+        return dict_to_namespace(input_dict.__dict__)
+
+    else:
+        return input_dict
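+
+
+# Editorial sketch of the merge semantics implemented below (values
+# hypothetical): given CLI args Namespace(debug=False) and a YAML config
+# {'debug': True, 'train_params': {'batch_size': 8}}, merge_args_and_yaml
+# warns that 'debug' is overwritten and returns a Namespace with
+# args.debug == True and args.train_params.batch_size == 8 (nested dicts
+# become nested Namespaces via dict_to_namespace).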
+def merge_args_and_yaml(args, config_dict):
+    arg_dict = args.__dict__
+    for key, value in config_dict.items():
+        if key in arg_dict:
+            warnings.warn(f"Command line argument '{key}' (value: "
+                          f"{arg_dict[key]}) will be overwritten with value "
+                          f"{value} provided in the config file.")
+        arg_dict[key] = dict_to_namespace(value)
+
+    return args
+
+
+def merge_configs(config, resume_config):
+    for key, value in resume_config.items():
+        if isinstance(value, Namespace):
+            value = value.__dict__
+
+        if isinstance(value, dict):
+            # update dictionaries recursively
+            value = merge_configs(config[key], value)
+
+        if key in config and config[key] != value:
+            warnings.warn(f"Config parameter '{key}' (value: "
+                          f"{config[key]}) will be overwritten with value "
+                          f"{value} from the checkpoint.")
+        config[key] = value
+    return config
+
+
+# ------------------------------------------------------------------------------
+# Training
+# ______________________________________________________________________________
+if __name__ == "__main__":
+    p = argparse.ArgumentParser()
+    p.add_argument('--config', type=str, required=True)
+    p.add_argument('--resume', type=str, default=None)
+    p.add_argument('--debug', action='store_true')
+    args = p.parse_args()
+
+    set_deterministic(seed=42)
+    disable_rdkit_logging()
+
+    with open(args.config, 'r') as f:
+        config = yaml.safe_load(f)
+
+    assert 'resume' not in config
+
+    # Get main config
+    ckpt_path = None if args.resume is None else Path(args.resume)
+    if args.resume is not None:
+        resume_config = torch.load(
+            ckpt_path, map_location=torch.device('cpu'))['hyper_parameters']
+        config = merge_configs(config, resume_config)
+
+    args = merge_args_and_yaml(args, config)
+
+    if args.debug:
+        print('DEBUG MODE')
+        args.run_name = 'debug'
+        args.wandb_params.mode = 'disabled'
+        args.train_params.enable_progress_bar = True
+        args.train_params.num_workers = 0
+
+    out_dir = Path(args.train_params.logdir, args.run_name)
+    pl_module = SizeModel(
+        max_size=args.max_size,
+        pocket_representation=args.pocket_representation,
+        train_params=args.train_params,
+        loss_params=args.loss_params,
+        eval_params=None,
+        predictor_params=args.predictor_params,
+    )
+
+    logger = pl.loggers.WandbLogger(
+        save_dir=args.train_params.logdir,
+        project='FlexFlow',
+        group=args.wandb_params.group,
+        name=args.run_name,
+        id=args.run_name,
+        resume='must' if args.resume is not None else False,
+        entity=args.wandb_params.entity,
+        mode=args.wandb_params.mode,
+    )
+
+    checkpoint_callbacks = [
+        pl.callbacks.ModelCheckpoint(
+            dirpath=Path(out_dir, 'checkpoints'),
+            filename="best-acc={accuracy/val:.2f}-epoch={epoch:02d}",
+            monitor="accuracy/val",
+            save_top_k=1,
+            save_last=True,
+            mode="max",
+        ),
+        pl.callbacks.ModelCheckpoint(
+            dirpath=Path(out_dir, 'checkpoints'),
+            # monitor validation MSE so it matches the filename template
+            filename="best-mse={mse/val:.2f}-epoch={epoch:02d}",
+            monitor="mse/val",
+            save_top_k=1,
+            save_last=False,
+            mode="min",
+        ),
+    ]
+
+    trainer = pl.Trainer(
+        max_epochs=args.train_params.n_epochs,
+        logger=logger,
+        callbacks=checkpoint_callbacks,
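+        # note: strategy=None below assumes Lightning 1.x semantics; 2.x
+        # expects 'auto' (src/train.py handles this with a version check)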
enable_progress_bar=args.train_params.enable_progress_bar, + # check_val_every_n_epoch=args.eval_params.eval_epochs, + num_sanity_val_steps=args.train_params.num_sanity_val_steps, + accumulate_grad_batches=args.train_params.accumulate_grad_batches, + accelerator='gpu' if args.train_params.gpus > 0 else 'cpu', + devices=args.train_params.gpus if args.train_params.gpus > 0 else 'auto', + strategy=('ddp' if args.train_params.gpus > 1 else None) + ) + + trainer.fit(model=pl_module, ckpt_path=ckpt_path) + + # # run test set + # result = trainer.test(ckpt_path='best') diff --git a/src/train.py b/src/train.py new file mode 100644 index 0000000000000000000000000000000000000000..9154ffcaf00b5d7db4fad9a74c432ee3ce34175f --- /dev/null +++ b/src/train.py @@ -0,0 +1,188 @@ +import argparse +from argparse import Namespace +from pathlib import Path +import warnings + +import torch +import pytorch_lightning as pl +import yaml + + +import sys +basedir = Path(__file__).resolve().parent.parent +sys.path.append(str(basedir)) + +from src.model.lightning import DrugFlow +from src.model.dpo import DPO +from src.utils import set_deterministic, disable_rdkit_logging, dict_to_namespace, namespace_to_dict + + +def merge_args_and_yaml(args, config_dict): + arg_dict = args.__dict__ + for key, value in config_dict.items(): + if key in arg_dict: + warnings.warn(f"Command line argument '{key}' (value: " + f"{arg_dict[key]}) will be overwritten with value " + f"{value} provided in the config file.") + # if isinstance(value, dict): + # arg_dict[key] = Namespace(**value) + # else: + # arg_dict[key] = value + arg_dict[key] = dict_to_namespace(value) + + return args + + +def merge_configs(config, resume_config): + for key, value in resume_config.items(): + if isinstance(value, Namespace): + value = value.__dict__ + + if isinstance(value, dict): + # update dictionaries recursively + value = merge_configs(config[key], value) + + if key in config and config[key] != value: + print(f'[CONFIG UPDATE] {key}: {value} -> {config[key]}') + return config + + +# ------------------------------------------------------------------------------ +# Training +# ______________________________________________________________________________ +if __name__ == "__main__": + p = argparse.ArgumentParser() + p.add_argument('--config', type=str, required=True) + p.add_argument('--resume', type=str, default=None) + p.add_argument('--backoff', action='store_true') + p.add_argument('--finetune', action='store_true') + p.add_argument('--debug', action='store_true') + p.add_argument('--overfit', action='store_true') + args = p.parse_args() + + set_deterministic(seed=42) + disable_rdkit_logging() + + with open(args.config, 'r') as f: + config = yaml.safe_load(f) + + assert 'resume' not in config + assert not (args.resume is not None and args.backoff) + config['dpo_mode'] = config.get('dpo_mode', None) + assert not (config['dpo_mode'] and 'checkpoint' not in config), 'DPO mode requires a reference checkpoint' + + if args.debug: + config['run_name'] = 'debug' + + out_dir = Path(config['train_params']['logdir'], config['run_name']) + checkpoints_root_dir = Path(out_dir, 'checkpoints') + if args.backoff: + last_checkpoint = Path(checkpoints_root_dir, 'last.ckpt') + print(f'Checking if there is a checkpoint at: {last_checkpoint}') + if last_checkpoint.exists(): + print(f'Found existing checkpoint: {last_checkpoint}') + args.resume = str(last_checkpoint) + else: + print(f'Did not find {last_checkpoint}') + + # Get main config + ckpt_path = None if args.resume is 
None else Path(args.resume) + if args.resume is not None and not args.finetune: + ckpt = torch.load(ckpt_path, map_location=torch.device('cpu')) + print(f'Resuming from epoch {ckpt["epoch"]}') + resume_config = ckpt['hyper_parameters'] + config = merge_configs(config, resume_config) + + args = merge_args_and_yaml(args, config) + + if args.debug: + print('DEBUG MODE') + args.wandb_params.mode = 'disabled' + args.train_params.enable_progress_bar = True + args.train_params.num_workers = 0 + + if args.overfit: + print('OVERFITTING MODE') + + args.eval_params.outdir = out_dir + model_class = DPO if args.dpo_mode else DrugFlow + model_args = { + 'pocket_representation': args.pocket_representation, + 'train_params': args.train_params, + 'loss_params': args.loss_params, + 'eval_params': args.eval_params, + 'predictor_params': args.predictor_params, + 'simulation_params': args.simulation_params, + 'virtual_nodes': args.virtual_nodes, + 'flexible': args.flexible, + 'flexible_bb': args.flexible_bb, + 'debug': args.debug, + 'overfit': args.overfit, + } + if args.dpo_mode: + print('DPO MODE') + model_args.update({ + 'dpo_mode': args.dpo_mode, + 'ref_checkpoint_p': args.checkpoint, + }) + pl_module = model_class(**model_args) + + resume_logging = False + if args.finetune: + resume_logging = 'allow' + elif args.resume is not None: + resume_logging = 'must' + + logger = pl.loggers.WandbLogger( + save_dir=args.train_params.logdir, + project='FlexFlow', + group=args.wandb_params.group, + name=args.run_name, + id=args.run_name, + resume=resume_logging, + entity=args.wandb_params.entity, + mode=args.wandb_params.mode, + ) + + checkpoint_callbacks = [ + pl.callbacks.ModelCheckpoint( + dirpath=checkpoints_root_dir, + save_last=True, + save_on_train_epoch_end=True, + ), + pl.callbacks.ModelCheckpoint( + dirpath=Path(checkpoints_root_dir, 'val_loss'), + filename="epoch_{epoch:04d}_loss_{loss/val:.3f}", + monitor="loss/val", + save_top_k=5, + mode="min", + auto_insert_metric_name=False, + ), + ] + + # For learning rate logging + lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step') + + default_strategy = 'auto' if pl.__version__ >= '2.0.0' else None + trainer = pl.Trainer( + max_epochs=args.train_params.n_epochs, + logger=logger, + callbacks=checkpoint_callbacks + [lr_monitor], + enable_progress_bar=args.train_params.enable_progress_bar, + check_val_every_n_epoch=args.eval_params.eval_epochs, + num_sanity_val_steps=args.train_params.num_sanity_val_steps, + accumulate_grad_batches=args.train_params.accumulate_grad_batches, + accelerator='gpu' if args.train_params.gpus > 0 else 'cpu', + devices=args.train_params.gpus if args.train_params.gpus > 0 else 'auto', + strategy=('ddp_find_unused_parameters_true' if args.train_params.gpus > 1 else default_strategy), + use_distributed_sampler=False, + ) + + # add all arguments as dictionaries because WandB does not display + # nested Namespace objects correctly + logger.experiment.config.update({'as_dict': namespace_to_dict(args)}, allow_val_change=True) + + trainer.fit(model=pl_module, ckpt_path=ckpt_path) + + # # run test set + # result = trainer.test(ckpt_path='best') diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5c6b45caada87454de1286cd9279a5d777e70527 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,311 @@ +import warnings +from typing import Union, Iterable +import random +# import argparse +from argparse import Namespace + +import numpy as np +import torch +from rdkit import Chem, RDLogger 
+from rdkit.Chem import KekulizeException, AtomKekulizeException +import networkx as nx +from networkx.algorithms import isomorphism +from torch_scatter import scatter_add, scatter_mean + + +class Queue(): + def __init__(self, max_len=50): + self.items = [] + self.max_len = max_len + + def __len__(self): + return len(self.items) + + def add(self, item): + self.items.insert(0, item) + if len(self) > self.max_len: + self.items.pop() + + def mean(self): + return np.mean(self.items) + + def std(self): + return np.std(self.items) + + +def reverse_tensor(x): + return x[torch.arange(x.size(0) - 1, -1, -1)] + + +##### + + +def sum_except_batch(x, indices): + if len(x.size()) < 2: + x = x.unsqueeze(-1) + return scatter_add(x.sum(list(range(1, len(x.size())))), indices, dim=0) + + +def remove_mean_batch(x, batch_mask, dim_size=None): + # Compute center of mass per sample + mean = scatter_mean(x, batch_mask, dim=0, dim_size=dim_size) + x = x - mean[batch_mask] + return x, mean + + +def assert_mean_zero(x, batch_mask, thresh=1e-2, eps=1e-10): + largest_value = x.abs().max().item() + error = scatter_add(x, batch_mask, dim=0).abs().max().item() + rel_error = error / (largest_value + eps) + assert rel_error < thresh, f'Mean is not zero, relative_error {rel_error}' + + +def bvm(v, m): + """ + Batched vector-matrix product of the form out = v @ m + :param v: (b, n_in) + :param m: (b, n_in, n_out) + :return: (b, n_out) + """ + # return (v.unsqueeze(1) @ m).squeeze() + return torch.bmm(v.unsqueeze(1), m).squeeze(1) + + +def get_grad_norm( + parameters: Union[torch.Tensor, Iterable[torch.Tensor]], + norm_type: float = 2.0) -> torch.Tensor: + """ + Adapted from: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm_ + """ + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + + norm_type = float(norm_type) + + if len(parameters) == 0: + return torch.tensor(0.) 
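+    # Editorial note: the norm of the per-parameter norms equals the norm of
+    # all gradient entries concatenated (for the same norm_type). Tiny example
+    # with norm_type=2: grads [3, 0] and [0, 4] give per-parameter norms
+    # [3, 4] and a total norm of 5.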
+
+    device = parameters[0].grad.device
+
+    total_norm = torch.norm(torch.stack(
+        [torch.norm(p.grad.detach(), norm_type).to(device) for p in
+         parameters]), norm_type)
+
+    return total_norm
+
+
+def write_xyz_file(coords, atom_types, filename):
+    out = f"{len(coords)}\n\n"
+    assert len(coords) == len(atom_types)
+    for i in range(len(coords)):
+        out += f"{atom_types[i]} {coords[i, 0]:.3f} {coords[i, 1]:.3f} {coords[i, 2]:.3f}\n"
+    with open(filename, 'w') as f:
+        f.write(out)
+
+
+def write_sdf_file(sdf_path, molecules, catch_errors=True, connected=False):
+    with Chem.SDWriter(str(sdf_path)) as w:
+        for mol in molecules:
+            try:
+                if mol is None:
+                    raise ValueError("Mol is None.")
+                w.write(get_largest_connected_component(mol) if connected else mol)
+
+            except (RuntimeError, ValueError) as e:
+                if not catch_errors:
+                    raise e
+
+                if isinstance(e, (KekulizeException, AtomKekulizeException)):
+                    w.SetKekulize(False)
+                    w.write(get_largest_connected_component(mol) if connected else mol)
+                    w.SetKekulize(True)
+                    warnings.warn("Mol saved without kekulization.")
+                else:
+                    # write empty mol to preserve the original order
+                    w.write(Chem.Mol())
+                    warnings.warn("Erroneous mol replaced with empty dummy.")
+
+
+def get_largest_connected_component(mol):
+    try:
+        frags = Chem.GetMolFrags(mol, asMols=True)
+        newmol = max(frags, key=lambda m: m.GetNumAtoms())
+    except Exception:
+        newmol = mol
+    return newmol
+
+
+def write_chain(filename, rdmol_chain):
+    with open(filename, 'w') as f:
+        f.write("".join([Chem.MolToXYZBlock(m) for m in rdmol_chain]))
+
+
+def combine_sdfs(sdf_list, out_file):
+    all_content = []
+    for sdf in sdf_list:
+        with open(sdf, 'r') as f:
+            all_content.append(f.read())
+    combined_str = '$$$$\n'.join(all_content)
+    with open(out_file, 'w') as f:
+        f.write(combined_str)
+
+
+def batch_to_list(data, batch_mask, keep_order=True):
+    if keep_order:  # preserve order of elements within each sample
+        data_list = [data[batch_mask == i]
+                     for i in torch.unique(batch_mask, sorted=True)]
+        return data_list
+
+    # make sure batch_mask is increasing
+    idx = torch.argsort(batch_mask)
+    batch_mask = batch_mask[idx]
+    data = data[idx]
+
+    chunk_sizes = torch.unique(batch_mask, return_counts=True)[1].tolist()
+    return torch.split(data, chunk_sizes)
+
+
+def batch_to_list_for_indices(indices, batch_mask, offsets=None):
+    # (2, n) -> (n, 2)
+    split = batch_to_list(indices.T, batch_mask)
+
+    # rebase indices at zero & (n, 2) -> (2, n)
+    if offsets is None:
+        warnings.warn("Trying to infer index offset from smallest element in "
+                      "batch. This might be wrong.")
+        split = [x.T - x.min() for x in split]
+    else:
+        # Typically 'offsets' would be accumulate(sizes[:-1], initial=0)
+        assert len(offsets) == len(split) or indices.numel() == 0
+        split = [x.T - offset for x, offset in zip(split, offsets)]
+
+    return split
+
+
+def num_nodes_to_batch_mask(n_samples, num_nodes, device):
+    assert isinstance(num_nodes, int) or len(num_nodes) == n_samples
+
+    if isinstance(num_nodes, torch.Tensor):
+        num_nodes = num_nodes.to(device)
+
+    sample_inds = torch.arange(n_samples, device=device)
+
+    return torch.repeat_interleave(sample_inds, num_nodes)
+
+
+def rdmol_to_nxgraph(rdmol):
+    graph = nx.Graph()
+    for atom in rdmol.GetAtoms():
+        # Add the atoms as nodes
+        graph.add_node(atom.GetIdx(), atom_type=atom.GetAtomicNum())
+
+    # Add the bonds as edges
+    for bond in rdmol.GetBonds():
+        graph.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())
+
+    return graph
+
+
+def calc_rmsd(mol_a, mol_b):
+    """ Calculate RMSD of two molecules with unknown atom correspondence. """
+    graph_a = rdmol_to_nxgraph(mol_a)
+    graph_b = rdmol_to_nxgraph(mol_b)
+
+    gm = isomorphism.GraphMatcher(
+        graph_a, graph_b,
+        node_match=lambda na, nb: na['atom_type'] == nb['atom_type'])
+
+    isomorphisms = list(gm.isomorphisms_iter())
+    if len(isomorphisms) < 1:
+        return None
+
+    all_rmsds = []
+    for mapping in isomorphisms:
+        atom_types_a = [atom.GetAtomicNum() for atom in mol_a.GetAtoms()]
+        atom_types_b = [mol_b.GetAtomWithIdx(mapping[i]).GetAtomicNum()
+                        for i in range(mol_b.GetNumAtoms())]
+        assert atom_types_a == atom_types_b
+
+        conf_a = mol_a.GetConformer()
+        coords_a = np.array([conf_a.GetAtomPosition(i)
+                             for i in range(mol_a.GetNumAtoms())])
+        conf_b = mol_b.GetConformer()
+        coords_b = np.array([conf_b.GetAtomPosition(mapping[i])
+                             for i in range(mol_b.GetNumAtoms())])
+
+        diff = coords_a - coords_b
+        rmsd = np.sqrt(np.mean(np.sum(diff * diff, axis=1)))
+        all_rmsds.append(rmsd)
+
+    if len(isomorphisms) > 1:
+        print("More than one isomorphism found. Returning minimum RMSD.")
+
+    return min(all_rmsds)
+
+
+def set_deterministic(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+
+def disable_rdkit_logging():
+    RDLogger.DisableLog('rdApp.info')
+    RDLogger.DisableLog('rdApp.error')
+    RDLogger.DisableLog('rdApp.warning')
+
+
+def dict_to_namespace(input_dict):
+    """ Recursively convert a nested dictionary into a Namespace object.
""" + if isinstance(input_dict, dict): + output_namespace = Namespace() + output = output_namespace.__dict__ + for key, value in input_dict.items(): + output[key] = dict_to_namespace(value) + return output_namespace + + elif isinstance(input_dict, Namespace): + # recurse as Namespace might contain dictionaries + return dict_to_namespace(input_dict.__dict__) + + else: + return input_dict + + +def namespace_to_dict(x): + """ Recursively convert a nested Namespace object into a dictionary. """ + if not (isinstance(x, Namespace) or isinstance(x, dict)): + return x + + if isinstance(x, Namespace): + x = vars(x) + + # recurse + output = {} + for key, value in x.items(): + output[key] = namespace_to_dict(value) + return output