|
|
import random
from abc import abstractmethod
from collections.abc import Iterable

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.constants import FLOAT_TYPE, INT_TYPE
from src.model.gvp import GVP, GVPModel, LayerNorm
from src.model.gvp_transformer import GVPTransformerModel

|
def binomial_coefficient(n, k):
    # Numerically stable "n choose k" for tensors, computed in log-space via
    # the log-gamma function:
    # C(n, k) = exp(lgamma(n+1) - lgamma(k+1) - lgamma(n-k+1))
    return ((n + 1).lgamma() - (k + 1).lgamma() - ((n - k) + 1).lgamma()).exp()
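
# Illustrative sketch (not used by the model itself): binomial_coefficient
# acts elementwise on tensors, and the log-space formulation avoids the
# overflow of evaluating factorials directly.
def _example_binomial_coefficient():
    n = torch.tensor(5.0)
    k = torch.tensor(2.0)
    return binomial_coefficient(n, k)  # ~= 10.0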
|
|
|
|
|
|
|
|
def cycle_counts(adj):
    """Count the 3-, 4- and 5-cycles each node participates in.

    Uses the diagonals of powers of the adjacency matrix, with correction
    terms that remove closed walks which are not simple cycles.
    """
    assert (adj.diag() == 0).all(), "no self-loops allowed"
    assert (adj == adj.T).all(), "adjacency matrix must be symmetric"

    A = adj.float()
    d = A.sum(dim=-1)  # node degrees

    A2 = A @ A
    A3 = A2 @ A
    A4 = A3 @ A
    A5 = A4 @ A

    x3 = A3.diag() / 2
    x4 = (A4.diag() - d * (d - 1) - A @ d) / 2

    # New (different from DiGress). Example where the correction is relevant
    # (labels show the order in which a closed 5-walk visits the nodes):
    #
    #     2   o
    #         |
    #   1,3   o--o 4
    #         | /
    #   0,5   o
    #
    T = adj * A2  # T[i, j]: number of triangles the edge (i, j) is part of
    x5 = (A5.diag() - 2 * T @ d - 4 * d * x3 - 2 * A @ x3 + 10 * x3) / 2

    return torch.stack([x3, x4, x5], dim=-1)
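
# Minimal sanity check (illustrative sketch, not part of the model): in a
# triangle graph, every node lies on exactly one 3-cycle and on no 4- or
# 5-cycles, so every row of the result is [1., 0., 0.].
def _example_cycle_counts():
    A = torch.tensor([[0, 1, 1],
                      [1, 0, 1],
                      [1, 1, 0]])
    return cycle_counts(A)  # shape (3, 3)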
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def eigenfeatures(A, batch_mask, k=5):
    """Return the first k non-trivial Laplacian eigenvectors per graph,
    zero-padded if a graph has fewer than k of them."""
    # Split the dense adjacency matrix into one block per graph.
    batch = []
    for i in torch.unique(batch_mask, sorted=True):
        batch_inds = torch.where(batch_mask == i)[0]
        batch.append(A[torch.meshgrid(batch_inds, batch_inds, indexing='ij')])

    eigenfeats = [get_nontrivial_eigenvectors(adj)[:, :k] for adj in batch]

    # Zero-pad so that every graph contributes exactly k columns.
    eigenfeats = [torch.cat([
        x, torch.zeros(x.size(0), max(k - x.size(1), 0), device=x.device)], dim=-1)
        for x in eigenfeats]

    return torch.cat(eigenfeats, dim=0)
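
# Usage sketch (illustrative only): two disconnected triangles in one batch.
# Each triangle contributes two non-trivial eigenvectors, so the remaining
# k - 2 columns are zero-padded.
def _example_eigenfeatures():
    tri = torch.ones(3, 3) - torch.eye(3)
    A = torch.block_diag(tri, tri)
    batch_mask = torch.tensor([0, 0, 0, 1, 1, 1])
    return eigenfeatures(A, batch_mask, k=5)  # shape (6, 5)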
|
|
|
|
|
|
|
|
def get_nontrivial_eigenvectors(A, normalize_l=True, thresh=1e-5,
                                norm_eps=1e-12):
    """
    Compute eigenvectors of the graph Laplacian corresponding to non-zero
    eigenvalues.
    """
    assert (A == A.T).all(), "graph must be undirected"

    # (Optionally symmetrically normalized) graph Laplacian.
    d = A.sum(-1)
    D = d.diag()
    L = D - A
    if normalize_l:
        D_inv_sqrt = (1 / (d.sqrt() + norm_eps)).diag()
        L = D_inv_sqrt @ L @ D_inv_sqrt

    # eigh returns eigenvalues in ascending order.
    eigvals, eigvecs = torch.linalg.eigh(L)

    # Index of the first eigenvalue above the threshold; every connected
    # component contributes one (near-)zero eigenvalue.
    try:
        idx = torch.nonzero(eigvals > thresh)[0].item()
    except IndexError:
        # All eigenvalues are (near-)zero: no non-trivial eigenvectors.
        idx = eigvecs.size(1)

    return eigvecs[:, idx:]
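
# Sanity-check sketch (illustrative only): the Laplacian of a connected graph
# has exactly one zero eigenvalue, so a path on three nodes yields two
# non-trivial eigenvectors.
def _example_nontrivial_eigenvectors():
    A = torch.tensor([[0., 1., 0.],
                      [1., 0., 1.],
                      [0., 1., 0.]])
    return get_nontrivial_eigenvectors(A).shape  # torch.Size([3, 2])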
|
|
|
|
|
|
|
|
class DynamicsBase(nn.Module):
    """
    Implements self-conditioning logic and basic functions.
    """
    def __init__(
            self,
            predict_angles=False,
            predict_frames=False,
            add_cycle_counts=False,
            add_spectral_feat=False,
            self_conditioning=False,
            augment_residue_sc=False,
            augment_ligand_sc=False
    ):
        super().__init__()

        # Subclasses may set these attributes before calling super().__init__();
        # only fall back to the constructor arguments if they have not.
        if not hasattr(self, 'predict_angles'):
            self.predict_angles = predict_angles

        if not hasattr(self, 'predict_frames'):
            self.predict_frames = predict_frames

        if not hasattr(self, 'add_cycle_counts'):
            self.add_cycle_counts = add_cycle_counts

        if not hasattr(self, 'add_spectral_feat'):
            self.add_spectral_feat = add_spectral_feat

        if not hasattr(self, 'self_conditioning'):
            self.self_conditioning = self_conditioning

        if not hasattr(self, 'augment_residue_sc'):
            self.augment_residue_sc = augment_residue_sc

        if not hasattr(self, 'augment_ligand_sc'):
            self.augment_ligand_sc = augment_ligand_sc

        if not hasattr(self, 'predict_confidence'):
            # Used in make_sc_input; subclasses may enable it.
            self.predict_confidence = False

        if self.self_conditioning:
            self.prev_ligand = None
            self.prev_residues = None

    @abstractmethod
    def _forward(self, x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand=None,
                 h_atoms_sc=None, e_atoms_sc=None, h_residues_sc=None):
        """
        Implement the forward pass.

        Returns a tuple of two dictionaries:
        - pred_ligand: 'vel', 'logits_h' (atom type logits), 'logits_e' (bond logits)
        - pred_residues: 'chi' (angles), 'trans' and 'rot' (frame updates)
        """
        pass

    def make_sc_input(self, pred_ligand, pred_residues, sc_transform):
        """Assemble the previous prediction into self-conditioning features."""
        # Ligand: scalar features are the type logits (plus the velocity
        # uncertainty if predicted); the predicted velocity is a vector feature.
        if self.predict_confidence:
            h_atoms_sc = (torch.cat([pred_ligand['logits_h'], pred_ligand['uncertainty_vel'].unsqueeze(1)], dim=-1),
                          pred_ligand['vel'].unsqueeze(1))
        else:
            h_atoms_sc = (pred_ligand['logits_h'], pred_ligand['vel'].unsqueeze(1))
        e_atoms_sc = pred_ligand['logits_e']

        # Residues: chi angles (and rotation/translation if frames are predicted).
        if self.predict_frames:
            h_residues_sc = (torch.cat([pred_residues['chi'], pred_residues['rot']], dim=-1),
                             pred_residues['trans'].unsqueeze(1))
        elif self.predict_angles:
            h_residues_sc = pred_residues['chi']
        else:
            h_residues_sc = None

        # Optionally augment with transformed, vector-valued features.
        if self.augment_residue_sc and h_residues_sc is not None:
            if self.predict_frames:
                h_residues_sc = (h_residues_sc[0], torch.cat(
                    [h_residues_sc[1], sc_transform['residues'](pred_residues['chi'], pred_residues['trans'].squeeze(1), pred_residues['rot'])], dim=1))
            else:
                h_residues_sc = (h_residues_sc, sc_transform['residues'](pred_residues['chi']))

        if self.augment_ligand_sc:
            h_atoms_sc = (h_atoms_sc[0], torch.cat(
                [h_atoms_sc[1], sc_transform['atoms'](pred_ligand['vel'].unsqueeze(1))], dim=1))

        return h_atoms_sc, e_atoms_sc, h_residues_sc

    def forward(self, x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand=None, sc_transform=None):
        """
        Implements self-conditioning as in https://arxiv.org/abs/2208.04202
        """
        h_atoms_sc, e_atoms_sc = None, None
        h_residues_sc = None

        if self.self_conditioning:

            if not self.training and t.min() > 0.0:
                # Sampling (after the first step): condition on the prediction
                # cached from the previous denoising step.
                assert t.min() == t.max(), "currently only supports sampling at same time steps"
                assert self.prev_ligand is not None
                assert self.prev_residues is not None or not self.predict_frames

            else:
                # First sampling step or training: start from all-zero
                # self-conditioning inputs.
                zeros_ligand = {'logits_h': torch.zeros_like(h_atoms),
                                'vel': torch.zeros_like(x_atoms),
                                'logits_e': torch.zeros_like(bonds_ligand[1])}
                if self.predict_confidence:
                    zeros_ligand['uncertainty_vel'] = torch.zeros(
                        len(x_atoms), dtype=x_atoms.dtype, device=x_atoms.device)

                zeros_residues = {}
                if self.predict_angles:
                    zeros_residues['chi'] = torch.zeros((pocket['one_hot'].size(0), 5), device=pocket['one_hot'].device)
                if self.predict_frames:
                    zeros_residues['trans'] = torch.zeros((pocket['one_hot'].size(0), 3), device=pocket['one_hot'].device)
                    zeros_residues['rot'] = torch.zeros((pocket['one_hot'].size(0), 3), device=pocket['one_hot'].device)

                # During training, replace the zeros with a detached model
                # prediction 50% of the time.
                if self.training and random.random() > 0.5:
                    with torch.no_grad():
                        h_atoms_sc, e_atoms_sc, h_residues_sc = self.make_sc_input(
                            zeros_ligand, zeros_residues, sc_transform)

                        self.prev_ligand, self.prev_residues = self._forward(
                            x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand,
                            h_atoms_sc, e_atoms_sc, h_residues_sc)

                else:
                    self.prev_ligand = zeros_ligand
                    self.prev_residues = zeros_residues

            h_atoms_sc, e_atoms_sc, h_residues_sc = self.make_sc_input(
                self.prev_ligand, self.prev_residues, sc_transform)

        pred_ligand, pred_residues = self._forward(
            x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand,
            h_atoms_sc, e_atoms_sc, h_residues_sc
        )

        # Cache the prediction for self-conditioning at the next sampling step.
        if self.self_conditioning and not self.training:
            self.prev_ligand = pred_ligand.copy()
            self.prev_residues = pred_residues.copy()

        return pred_ligand, pred_residues
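
    # Self-conditioning recipe (Chen et al., https://arxiv.org/abs/2208.04202),
    # summarized for this implementation:
    #   training: with probability 0.5 condition on zeros; otherwise run the
    #     network once with zero conditioning (no gradient) and condition on
    #     its detached prediction.
    #   sampling: condition on zeros at the first step (t = 0), afterwards on
    #     the prediction cached from the previous denoising step.
    # This exposes the network during training to inputs that resemble its own
    # predictions, as it will see them during iterative sampling.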
|
|
|
|
|
    def compute_extra_features(self, batch_mask, edge_indices, edge_types):
        """Compute cycle counts and/or spectral node features for the ligand."""
        feat = torch.zeros(len(batch_mask), 0, device=batch_mask.device)

        if not (self.add_cycle_counts or self.add_spectral_feat):
            return feat

        # Block-diagonal mask: nodes can only be adjacent within the same graph.
        adj = batch_mask[:, None] == batch_mask[None, :]

        E = torch.zeros_like(adj, dtype=INT_TYPE)
        E[edge_indices[0], edge_indices[1]] = edge_types

        # Binary adjacency matrix (edge type 0 means "no bond").
        A = (E > 0).float()

        if self.add_cycle_counts:
            cycle_features = cycle_counts(A)
            # Clamp cycle counts to keep the feature range bounded.
            cycle_features[cycle_features > 10] = 10
            feat = torch.cat([feat, cycle_features], dim=-1)

        if self.add_spectral_feat:
            feat = torch.cat([feat, eigenfeatures(A, batch_mask)], dim=-1)

        return feat
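
    # Feature-width bookkeeping (matching Dynamics.__init__ below): cycle
    # counts contribute 3 extra scalars per atom (3-, 4- and 5-cycles) and
    # spectral features contribute k=5 more, which is why `lig_nf` grows by
    # 3 and 5 when these flags are set.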
|
|
|
|
|
|
|
|
class Dynamics(DynamicsBase):
    def __init__(self, atom_nf, residue_nf, joint_nf, bond_dict, pocket_bond_dict,
                 edge_nf, hidden_nf, act_fn=torch.nn.SiLU(), condition_time=True,
                 model='egnn', model_params=None,
                 edge_cutoff_ligand=None, edge_cutoff_pocket=None,
                 edge_cutoff_interaction=None,
                 predict_angles=False, predict_frames=False,
                 add_cycle_counts=False, add_spectral_feat=False,
                 add_nma_feat=False, self_conditioning=False,
                 augment_residue_sc=False, augment_ligand_sc=False,
                 add_chi_as_feature=False, angle_act_fn=None):
        super().__init__()
        self.model = model
        self.edge_cutoff_l = edge_cutoff_ligand
        self.edge_cutoff_p = edge_cutoff_pocket
        self.edge_cutoff_i = edge_cutoff_interaction
        self.hidden_nf = hidden_nf
        self.predict_angles = predict_angles
        self.predict_frames = predict_frames
        self.bond_dict = bond_dict
        self.pocket_bond_dict = pocket_bond_dict
        self.bond_nf = len(bond_dict)
        self.pocket_bond_nf = len(pocket_bond_dict)
        self.edge_nf = edge_nf
        self.add_cycle_counts = add_cycle_counts
        self.add_spectral_feat = add_spectral_feat
        self.add_nma_feat = add_nma_feat
        self.self_conditioning = self_conditioning
        self.augment_residue_sc = augment_residue_sc
        self.augment_ligand_sc = augment_ligand_sc
        self.add_chi_as_feature = add_chi_as_feature
        self.predict_confidence = False

        if self.self_conditioning:
            # Caches used by DynamicsBase.forward during sampling.
            self.prev_ligand = None
            self.prev_residues = None

        # Extra node features enlarge the ligand input dimension.
        lig_nf = atom_nf
        if self.add_cycle_counts:
            lig_nf = lig_nf + 3
        if self.add_spectral_feat:
            lig_nf = lig_nf + 5

        # A scalar joint_nf is interpreted as (n_scalar_channels, 0 vector channels).
        if not isinstance(joint_nf, Iterable):
            joint_nf = (joint_nf, 0)

        # A tuple-valued residue_nf means (scalar, vector) features -> GVP encoders.
        if isinstance(residue_nf, Iterable):
            _atom_in_nf = (lig_nf, 0)
            _residue_atom_dim = residue_nf[1]

            if self.add_nma_feat:
                residue_nf = (residue_nf[0], residue_nf[1] + 5)

            if self.self_conditioning:
                # Self-conditioning inputs: previous atom type logits (scalars)
                # and the previous velocity (one vector channel).
                _atom_in_nf = (_atom_in_nf[0] + atom_nf, 1)

                if self.augment_ligand_sc:
                    _atom_in_nf = (_atom_in_nf[0], _atom_in_nf[1] + 1)

                if self.predict_angles:
                    residue_nf = (residue_nf[0] + 5, residue_nf[1])

                if self.predict_frames:
                    residue_nf = (residue_nf[0], residue_nf[1] + 2)

                if self.augment_residue_sc:
                    assert self.predict_angles
                    residue_nf = (residue_nf[0], residue_nf[1] + _residue_atom_dim)

            if self.add_chi_as_feature:
                residue_nf = (residue_nf[0] + 5, residue_nf[1])

            self.atom_encoder = nn.Sequential(
                GVP(_atom_in_nf, joint_nf, activations=(act_fn, torch.sigmoid)),
                LayerNorm(joint_nf, learnable_vector_weight=True),
                GVP(joint_nf, joint_nf, activations=(None, None)),
            )

            self.residue_encoder = nn.Sequential(
                GVP(residue_nf, joint_nf, activations=(act_fn, torch.sigmoid)),
                LayerNorm(joint_nf, learnable_vector_weight=True),
                GVP(joint_nf, joint_nf, activations=(None, None)),
            )
        else:
            # Scalar-only features: plain MLP encoders.
            assert joint_nf[1] == 0
            assert not self.self_conditioning
            assert not self.add_nma_feat

            if self.add_chi_as_feature:
                residue_nf += 5

            self.atom_encoder = nn.Sequential(
                nn.Linear(lig_nf, 2 * atom_nf),
                act_fn,
                nn.Linear(2 * atom_nf, joint_nf[0])
            )

            self.residue_encoder = nn.Sequential(
                nn.Linear(residue_nf, 2 * residue_nf),
                act_fn,
                nn.Linear(2 * residue_nf, joint_nf[0])
            )

        self.atom_decoder = nn.Sequential(
            nn.Linear(joint_nf[0], 2 * atom_nf),
            act_fn,
            nn.Linear(2 * atom_nf, atom_nf)
        )

        self.edge_decoder = nn.Sequential(
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, self.bond_nf)
        )

        # With self-conditioning, previous edge logits are concatenated to the
        # bond one-hot, doubling the input dimension.
        _atom_bond_nf = 2 * self.bond_nf if self.self_conditioning else self.bond_nf
        self.ligand_bond_encoder = nn.Sequential(
            nn.Linear(_atom_bond_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, self.edge_nf)
        )

        self.pocket_bond_encoder = nn.Sequential(
            nn.Linear(self.pocket_bond_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, self.edge_nf)
        )

        out_nf = (joint_nf[0], 1)
        res_out_nf = (0, 0)
        if self.predict_angles:
            res_out_nf = (res_out_nf[0] + 5, res_out_nf[1])
        if self.predict_frames:
            res_out_nf = (res_out_nf[0], res_out_nf[1] + 2)
        self.residue_decoder = nn.Sequential(
            GVP(out_nf, out_nf, activations=(act_fn, torch.sigmoid)),
            LayerNorm(out_nf, learnable_vector_weight=True),
            GVP(out_nf, res_out_nf, activations=(None, None)),
        ) if res_out_nf != (0, 0) else None

        if angle_act_fn is None:
            self.angle_act_fn = None
        elif angle_act_fn == 'tanh':
            # Map raw angle outputs to (-pi, pi).
            self.angle_act_fn = lambda x: np.pi * torch.tanh(x)
        else:
            raise NotImplementedError(f"Angle activation {angle_act_fn} not available")

        # Learnable embedding for ligand-pocket (cross) edges.
        self.cross_emb = nn.Parameter(torch.zeros(self.edge_nf),
                                      requires_grad=True)

        if condition_time:
            dynamics_node_nf = (joint_nf[0] + 1, joint_nf[1])
        else:
            print('Warning: dynamics model is NOT conditioned on time.')
            dynamics_node_nf = (joint_nf[0], joint_nf[1])

        if model == 'egnn':
            raise NotImplementedError
        elif model == 'gvp':
            self.net = GVPModel(
                node_in_dim=dynamics_node_nf, node_h_dim=model_params.node_h_dim,
                node_out_nf=joint_nf[0], edge_in_nf=self.edge_nf,
                edge_h_dim=model_params.edge_h_dim, edge_out_nf=hidden_nf,
                num_layers=model_params.n_layers,
                drop_rate=model_params.dropout,
                vector_gate=model_params.vector_gate,
                reflection_equiv=model_params.reflection_equivariant,
                d_max=model_params.d_max,
                num_rbf=model_params.num_rbf,
                update_edge_attr=True
            )

        elif model == 'gvp_transformer':
            self.net = GVPTransformerModel(
                node_in_dim=dynamics_node_nf,
                node_h_dim=model_params.node_h_dim,
                node_out_nf=joint_nf[0],
                edge_in_nf=self.edge_nf,
                edge_h_dim=model_params.edge_h_dim,
                edge_out_nf=hidden_nf,
                num_layers=model_params.n_layers,
                dk=model_params.dk,
                dv=model_params.dv,
                de=model_params.de,
                db=model_params.db,
                dy=model_params.dy,
                attn_heads=model_params.attn_heads,
                n_feedforward=model_params.n_feedforward,
                drop_rate=model_params.dropout,
                reflection_equiv=model_params.reflection_equivariant,
                d_max=model_params.d_max,
                num_rbf=model_params.num_rbf,
                vector_gate=model_params.vector_gate,
                attention=model_params.attention,
            )

        elif model == 'gnn':
            raise NotImplementedError

        else:
            raise NotImplementedError(f"{model} is not available")

        self.condition_time = condition_time
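
    # Construction sketch (illustrative only; the values below are hypothetical
    # and `model_params` is assumed to expose the attributes read above, e.g.
    # via an argparse.Namespace or a config object):
    #   dynamics = Dynamics(
    #       atom_nf=10, residue_nf=(21, 4), joint_nf=(64, 16),
    #       bond_dict={'NOBOND': 0, 'SINGLE': 1, 'DOUBLE': 2, 'TRIPLE': 3},
    #       pocket_bond_dict={'NOBOND': 0, 'BOND': 1},
    #       edge_nf=32, hidden_nf=128, model='gvp', model_params=model_params,
    #       self_conditioning=True)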
|
|
|
|
|
    def _forward(self, x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand=None,
                 h_atoms_sc=None, e_atoms_sc=None, h_residues_sc=None):
        """
        :param x_atoms: ligand atom coordinates, (n_atoms, 3)
        :param h_atoms: ligand atom features, (n_atoms, atom_nf)
        :param mask_atoms: batch assignment of the ligand atoms, (n_atoms,)
        :param pocket: must contain keys: 'x', 'one_hot', 'mask', 'bonds' and 'bond_one_hot'
        :param t: time step, scalar or one value per sample
        :param bonds_ligand: tuple - bond indices (2, n_bonds) & bond types (n_bonds, bond_nf)
        :param h_atoms_sc: additional node feature for self-conditioning, (s, V)
        :param e_atoms_sc: additional edge feature for self-conditioning, only scalar
        :param h_residues_sc: additional node feature for self-conditioning, tensor or tuple
        :return: (pred_ligand, pred_residues) dictionaries
        """
        x_residues, h_residues, mask_residues = pocket['x'], pocket['one_hot'], pocket['mask']

        if 'bonds' in pocket:
            bonds_pocket = (pocket['bonds'], pocket['bond_one_hot'])
        else:
            bonds_pocket = None

        if self.add_chi_as_feature:
            h_residues = torch.cat([h_residues, pocket['chi'][:, :5]], dim=-1)

        # Vector-valued residue features (plus normal mode vectors if used).
        if 'v' in pocket:
            v_residues = pocket['v']
            if self.add_nma_feat:
                v_residues = torch.cat([v_residues, pocket['nma_vec']], dim=1)
            h_residues = (h_residues, v_residues)

        # Append self-conditioning features to the residue representation.
        if h_residues_sc is not None:
            if isinstance(h_residues_sc, tuple):
                h_residues = (torch.cat([h_residues[0], h_residues_sc[0]], dim=-1),
                              torch.cat([h_residues[1], h_residues_sc[1]], dim=1))
            else:
                h_residues = (torch.cat([h_residues[0], h_residues_sc], dim=-1),
                              h_residues[1])

        if bonds_ligand is not None:
            ligand_bond_indices = bonds_ligand[0]

            # Make the edge list symmetric: add each bond in both directions.
            ligand_edge_indices = torch.cat(
                [bonds_ligand[0], bonds_ligand[0].flip(dims=[0])], dim=1)
            ligand_edge_types = torch.cat([bonds_ligand[1], bonds_ligand[1]], dim=0)

            # Cycle counts and/or spectral features derived from the bond graph.
            extra_features = self.compute_extra_features(
                mask_atoms, ligand_edge_indices, ligand_edge_types.argmax(-1))
            h_atoms = torch.cat([h_atoms, extra_features], dim=-1)

        if bonds_pocket is not None:
            pocket_edge_indices = torch.cat(
                [bonds_pocket[0], bonds_pocket[0].flip(dims=[0])], dim=1)
            pocket_edge_types = torch.cat([bonds_pocket[1], bonds_pocket[1]], dim=0)

        # Append self-conditioning features to atoms and ligand edges.
        if h_atoms_sc is not None:
            h_atoms = (torch.cat([h_atoms, h_atoms_sc[0]], dim=-1),
                       h_atoms_sc[1])

        if e_atoms_sc is not None:
            e_atoms_sc = torch.cat([e_atoms_sc, e_atoms_sc], dim=0)
            ligand_edge_types = torch.cat([ligand_edge_types, e_atoms_sc], dim=-1)

        h_atoms = self.atom_encoder(h_atoms)
        e_ligand = self.ligand_bond_encoder(ligand_edge_types)

        if len(h_residues) > 0:
            h_residues = self.residue_encoder(h_residues)
            e_pocket = self.pocket_bond_encoder(pocket_edge_types)
        else:
            # Empty pocket: use empty placeholders instead of the encoders.
            # (Define the placeholders first so e_pocket is always bound.)
            h_residues = (h_residues, h_residues)
            pocket_edge_indices = torch.tensor([[], []], dtype=torch.long, device=h_residues[0].device)
            pocket_edge_types = torch.tensor([[], []], dtype=torch.long, device=h_residues[0].device)
            e_pocket = pocket_edge_types

        # Split scalar and vector channels if the encoders returned tuples.
        if isinstance(h_atoms, tuple):
            h_atoms, v_atoms = h_atoms
            h_residues, v_residues = h_residues
            v = torch.cat((v_atoms, v_residues), dim=0)
        else:
            v = None

        edges, edge_feat = self.get_edges(
            mask_atoms, mask_residues, x_atoms, x_residues,
            bond_inds_ligand=ligand_edge_indices, bond_inds_pocket=pocket_edge_indices,
            bond_feat_ligand=e_ligand, bond_feat_pocket=e_pocket)

        # Combine ligand and pocket nodes into a single graph.
        x = torch.cat((x_atoms, x_residues), dim=0)
        h = torch.cat((h_atoms, h_residues), dim=0)
        mask = torch.cat([mask_atoms, mask_residues])

        if self.condition_time:
            if np.prod(t.size()) == 1:
                # t is the same for all elements in this batch.
                h_time = torch.empty_like(h[:, 0:1]).fill_(t.item())
            else:
                # t differs over the batch dimension.
                h_time = t[mask]
            h = torch.cat([h, h_time], dim=1)

        # Edges must not cross batch boundaries.
        assert torch.all(mask[edges[0]] == mask[edges[1]])

        if self.model == 'egnn':
            # Only ligand coordinates are updated; the pocket stays fixed.
            update_coords_mask = torch.cat((torch.ones_like(mask_atoms),
                                            torch.zeros_like(mask_residues))).unsqueeze(1)
            h_final, vel, edge_final = self.net(
                h, x, edges, batch_mask=mask, edge_attr=edge_feat,
                update_coords_mask=update_coords_mask)

        elif self.model == 'gvp' or self.model == 'gvp_transformer':
            h_final, vel, edge_final = self.net(
                h, x, edges, v=v, batch_mask=mask, edge_attr=edge_feat)

        elif self.model == 'gnn':
            xh = torch.cat([x, h], dim=1)
            output = self.net(xh, edges, node_mask=None, edge_attr=edge_feat)
            vel = output[:, :3]
            h_final = output[:, 3:]

        else:
            raise NotImplementedError(f"Wrong model ({self.model})")

        h_final_atoms = self.atom_decoder(h_final[:len(mask_atoms)])

        if torch.any(torch.isnan(vel)) or torch.any(torch.isnan(h_final_atoms)):
            if self.training:
                vel[torch.isnan(vel)] = 0.0
                h_final_atoms[torch.isnan(h_final_atoms)] = 0.0
            else:
                raise ValueError("NaN detected in network output")

        # Keep only ligand-ligand edges for bond prediction.
        ligand_edge_mask = (edges[0] < len(mask_atoms)) & (edges[1] < len(mask_atoms))
        edge_final = edge_final[ligand_edge_mask]
        edges = edges[:, ligand_edge_mask]

        # Symmetrize the edge logits and read them out at the bond indices.
        edge_logits = torch.zeros(
            (len(mask_atoms), len(mask_atoms), self.hidden_nf),
            device=mask_atoms.device)
        edge_logits[edges[0], edges[1]] = edge_final
        edge_logits = (edge_logits + edge_logits.transpose(0, 1)) * 0.5
        edge_logits = edge_logits[ligand_bond_indices[0], ligand_bond_indices[1]]
        edge_final_atoms = self.edge_decoder(edge_logits)

        # Decode residue outputs (angles and/or frames) if requested.
        residue_angles = None
        residue_trans, residue_rot = None, None
        if self.residue_decoder is not None:
            h_residues = h_final[len(mask_atoms):]
            vec_residues = vel[len(mask_atoms):].unsqueeze(1)
            residue_angles = self.residue_decoder((h_residues, vec_residues))
            if self.predict_frames:
                residue_angles, residue_frames = residue_angles
                residue_trans = residue_frames[:, 0, :].squeeze(1)
                residue_rot = residue_frames[:, 1, :].squeeze(1)
            if self.angle_act_fn is not None:
                residue_angles = self.angle_act_fn(residue_angles)

        pred_ligand = {'vel': vel[:len(mask_atoms)], 'logits_h': h_final_atoms, 'logits_e': edge_final_atoms}
        pred_residues = {'chi': residue_angles, 'trans': residue_trans, 'rot': residue_rot}
        return pred_ligand, pred_residues

    def get_edges(self, batch_mask_ligand, batch_mask_pocket, x_ligand,
                  x_pocket, bond_inds_ligand=None, bond_inds_pocket=None,
                  bond_feat_ligand=None, bond_feat_pocket=None, self_edges=False):

        # Within-batch adjacency: nodes interact only within the same sample.
        adj_ligand = batch_mask_ligand[:, None] == batch_mask_ligand[None, :]
        adj_pocket = batch_mask_pocket[:, None] == batch_mask_pocket[None, :]
        adj_cross = batch_mask_ligand[:, None] == batch_mask_pocket[None, :]

        if self.edge_cutoff_l is not None:
            adj_ligand = adj_ligand & (torch.cdist(x_ligand, x_ligand) <= self.edge_cutoff_l)
            # Always keep covalent bonds, regardless of the distance cutoff.
            adj_ligand[bond_inds_ligand[0], bond_inds_ligand[1]] = True

        if self.edge_cutoff_p is not None and len(x_pocket) > 0:
            adj_pocket = adj_pocket & (torch.cdist(x_pocket, x_pocket) <= self.edge_cutoff_p)
            adj_pocket[bond_inds_pocket[0], bond_inds_pocket[1]] = True

        if self.edge_cutoff_i is not None and len(x_pocket) > 0:
            adj_cross = adj_cross & (torch.cdist(x_ligand, x_pocket) <= self.edge_cutoff_i)

        adj = torch.cat((torch.cat((adj_ligand, adj_cross), dim=1),
                         torch.cat((adj_cross.T, adj_pocket), dim=1)), dim=0)

        if not self_edges:
            # Remove self-loops from the combined adjacency matrix.
            adj = adj ^ torch.eye(*adj.size(), out=torch.empty_like(adj))

        # Edge features: encoded bond features where bonds exist, a learned
        # "no bond" embedding elsewhere, and a learned embedding for all
        # ligand-pocket (cross) edges.
        ligand_nobond_onehot = F.one_hot(torch.tensor(
            self.bond_dict['NOBOND'], device=bond_feat_ligand.device),
            num_classes=self.ligand_bond_encoder[0].in_features)
        ligand_nobond_emb = self.ligand_bond_encoder(
            ligand_nobond_onehot.to(FLOAT_TYPE))
        feat_ligand = ligand_nobond_emb.repeat(*adj_ligand.shape, 1)
        feat_ligand[bond_inds_ligand[0], bond_inds_ligand[1]] = bond_feat_ligand

        if len(adj_pocket) > 0:
            pocket_nobond_onehot = F.one_hot(torch.tensor(
                self.pocket_bond_dict['NOBOND'], device=bond_feat_pocket.device),
                num_classes=self.pocket_bond_nf)
            pocket_nobond_emb = self.pocket_bond_encoder(
                pocket_nobond_onehot.to(FLOAT_TYPE))
            feat_pocket = pocket_nobond_emb.repeat(*adj_pocket.shape, 1)
            feat_pocket[bond_inds_pocket[0], bond_inds_pocket[1]] = bond_feat_pocket

            feat_cross = self.cross_emb.repeat(*adj_cross.shape, 1)

            feats = torch.cat((torch.cat((feat_ligand, feat_cross), dim=1),
                               torch.cat((feat_cross.transpose(0, 1), feat_pocket), dim=1)), dim=0)
        else:
            feats = feat_ligand

        # Return edge indices and the corresponding edge features.
        edges = torch.stack(torch.where(adj), dim=0)
        edge_feat = feats[edges[0], edges[1]]

        return edges, edge_feat
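
    # Shape sketch (illustrative only): with n ligand atoms and m pocket
    # nodes, `adj` is an (n+m, n+m) boolean matrix ordered [ligand; pocket],
    # `edges` has shape (2, n_edges) with indices into the combined node set,
    # and `edge_feat` has shape (n_edges, edge_nf).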
|
|
|