# DrugFlow: src/model/dynamics.py
from collections.abc import Iterable
from abc import abstractmethod
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from src.constants import INT_TYPE, FLOAT_TYPE
from src.model.gvp import GVPModel, GVP, LayerNorm
from src.model.gvp_transformer import GVPTransformerModel
from pdb import set_trace
def binomial_coefficient(n, k):
# source: https://discuss.pytorch.org/t/n-choose-k-function/121974
return ((n + 1).lgamma() - (k + 1).lgamma() - ((n - k) + 1).lgamma()).exp()
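

# Hedged example (illustrative helper, not in the original file): the lgamma
# identity above computes C(n, k) for float tensors, e.g. C(5, 2) = 10.
def _example_binomial_coefficient():
    n = torch.tensor(5.0)  # inputs must be float tensors for lgamma()
    k = torch.tensor(2.0)
    return binomial_coefficient(n, k)  # -> tensor(10.) up to float error
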
def cycle_counts(adj):
assert (adj.diag() == 0).all()
assert (adj == adj.T).all()
A = adj.float()
d = A.sum(dim=-1)
# Compute powers
A2 = A @ A
A3 = A2 @ A
A4 = A3 @ A
A5 = A4 @ A
x3 = A3.diag() / 2
x4 = (A4.diag() - d * (d - 1) - A @ d) / 2
""" New (different from DiGress)
case where correction is relevant:
2 o
|
1,3 o--o 4
| /
0,5 o
"""
# Triangle count matrix (indicates for each node i how many triangles it shares with node j)
T = adj * A2
x5 = (A5.diag() - 2 * T @ d - 4 * d * x3 - 2 * A @ x3 + 10 * x3) / 2
# # TODO
# A6 = A5 @ A
#
# # 4-cycle count matrix (indicates in how many shared 4-cycles i and j are 2 hops apart)
# Q2 = binomial_coefficient(n=A2 - d.diag(), k=torch.tensor(2))
#
# # 4-cycle count matrix (indicates in how many shared 4-cycles i and j are 1 (and 3) hop(s) apart)
# Q1 = A * (A3 - (d.view(-1, 1) + d.view(1, -1)) + 1) # "+1" because link between i and j is subtracted twice
#
# x6 = ...
# return torch.stack([x3, x4, x5, x6], dim=-1)
return torch.stack([x3, x4, x5], dim=-1)
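

# Hedged example (illustrative helper, not in the original file): on a 3-node
# triangle every node lies on exactly one 3-cycle and no 4- or 5-cycles.
def _example_cycle_counts():
    A = torch.tensor([[0, 1, 1],
                      [1, 0, 1],
                      [1, 1, 0]])
    return cycle_counts(A)  # -> tensor([[1., 0., 0.]]).repeat(3, 1), shape (3, 3)
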
# TODO: also consider directional aggregation as in:
# Beaini, Dominique, et al. "Directional graph networks."
# International Conference on Machine Learning. PMLR, 2021.
def eigenfeatures(A, batch_mask, k=5):
# TODO, see:
# - https://github.com/cvignac/DiGress/blob/main/src/diffusion/extra_features.py
# - https://arxiv.org/pdf/2209.14734.pdf (Appendix B.2)
# split adjacency matrix
batch = []
for i in torch.unique(batch_mask, sorted=True): # TODO: optimize (try to avoid loop)
batch_inds = torch.where(batch_mask == i)[0]
batch.append(A[torch.meshgrid(batch_inds, batch_inds, indexing='ij')])
eigenfeats = [get_nontrivial_eigenvectors(adj)[:, :k] for adj in batch]
    # pad with zeros if there are fewer than k non-trivial eigenvectors
eigenfeats = [torch.cat([
x, torch.zeros(x.size(0), max(k - x.size(1), 0), device=x.device)], dim=-1)
for x in eigenfeats]
return torch.cat(eigenfeats, dim=0)
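

# Hedged example (illustrative helper, not in the original file): a batch of
# two single-edge graphs; each subgraph has only one non-trivial eigenvector,
# so the remaining k - 1 columns are zero-padded.
def _example_eigenfeatures():
    A = torch.tensor([[0., 1., 0., 0.],
                      [1., 0., 0., 0.],
                      [0., 0., 0., 1.],
                      [0., 0., 1., 0.]])
    batch_mask = torch.tensor([0, 0, 1, 1])
    return eigenfeatures(A, batch_mask, k=5)  # -> shape (4, 5)
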
def get_nontrivial_eigenvectors(A, normalize_l=True, thresh=1e-5,
norm_eps=1e-12):
"""
Compute eigenvectors of the graph Laplacian corresponding to non-zero
eigenvalues.
"""
assert (A == A.T).all(), "undirected graph"
# Compute laplacian
d = A.sum(-1)
D = d.diag()
L = D - A
if normalize_l:
D_inv_sqrt = (1 / (d.sqrt() + norm_eps)).diag()
L = D_inv_sqrt @ L @ D_inv_sqrt
# Eigendecomposition
# eigenvalues are sorted in ascending order
# eigvecs matrix contains eigenvectors as its columns
eigvals, eigvecs = torch.linalg.eigh(L)
# index of first non-trivial eigenvector
try:
idx = torch.nonzero(eigvals > thresh)[0].item()
except IndexError:
# recover if no non-trivial eigenvectors are found
idx = eigvecs.size(1)
return eigvecs[:, idx:]
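

# Hedged edge case: for a graph with no edges the Laplacian is all-zero, every
# eigenvalue falls below thresh, and the IndexError branch returns an empty
# (n, 0) matrix; eigenfeatures() then zero-pads it back to k columns.
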
class DynamicsBase(nn.Module):
"""
Implements self-conditioning logic and basic functions
"""
def __init__(
self,
predict_angles=False,
predict_frames=False,
add_cycle_counts=False,
add_spectral_feat=False,
self_conditioning=False,
augment_residue_sc=False,
augment_ligand_sc=False
):
super().__init__()
if not hasattr(self, 'predict_angles'):
self.predict_angles = predict_angles
if not hasattr(self, 'predict_frames'):
self.predict_frames = predict_frames
if not hasattr(self, 'add_cycle_counts'):
self.add_cycle_counts = add_cycle_counts
if not hasattr(self, 'add_spectral_feat'):
self.add_spectral_feat = add_spectral_feat
if not hasattr(self, 'self_conditioning'):
self.self_conditioning = self_conditioning
if not hasattr(self, 'augment_residue_sc'):
self.augment_residue_sc = augment_residue_sc
        if not hasattr(self, 'augment_ligand_sc'):
            self.augment_ligand_sc = augment_ligand_sc
        # forward() and make_sc_input() read this attribute; default to False
        # for subclasses that do not define a confidence head
        if not hasattr(self, 'predict_confidence'):
            self.predict_confidence = False
if self.self_conditioning:
self.prev_ligand = None
self.prev_residues = None
@abstractmethod
def _forward(self, x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand=None,
h_atoms_sc=None, e_atoms_sc=None, h_residues_sc=None):
"""
Implement forward pass.
Returns:
- vel
- h_final_atoms
- edge_final_atoms
- residue_angles
- residue_trans
- residue_rot
"""
pass
def make_sc_input(self, pred_ligand, pred_residues, sc_transform):
if self.predict_confidence:
h_atoms_sc = (torch.cat([pred_ligand['logits_h'], pred_ligand['uncertainty_vel'].unsqueeze(1)], dim=-1),
pred_ligand['vel'].unsqueeze(1))
else:
h_atoms_sc = (pred_ligand['logits_h'], pred_ligand['vel'].unsqueeze(1))
e_atoms_sc = pred_ligand['logits_e']
if self.predict_frames:
h_residues_sc = (torch.cat([pred_residues['chi'], pred_residues['rot']], dim=-1),
pred_residues['trans'].unsqueeze(1))
elif self.predict_angles:
h_residues_sc = pred_residues['chi']
else:
h_residues_sc = None
if self.augment_residue_sc and h_residues_sc is not None:
if self.predict_frames:
h_residues_sc = (h_residues_sc[0], torch.cat(
[h_residues_sc[1], sc_transform['residues'](pred_residues['chi'], pred_residues['trans'].squeeze(1), pred_residues['rot'])], dim=1))
else:
h_residues_sc = (h_residues_sc, sc_transform['residues'](pred_residues['chi']))
if self.augment_ligand_sc:
h_atoms_sc = (h_atoms_sc[0], torch.cat(
[h_atoms_sc[1], sc_transform['atoms'](pred_ligand['vel'].unsqueeze(1))], dim=1))
return h_atoms_sc, e_atoms_sc, h_residues_sc
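
    # Hedged shape sketch (illustrative, assuming no confidence head):
    # h_atoms_sc = ((n_atoms, atom_nf) logits, (n_atoms, 1, 3) velocity
    # vectors), e_atoms_sc is (n_bonds, bond_nf) bond logits, and
    # h_residues_sc is None unless angles or frames are predicted.
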
def forward(self, x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand=None, sc_transform=None):
"""
Implements self-conditioning as in https://arxiv.org/abs/2208.04202
"""
h_atoms_sc, e_atoms_sc = None, None
h_residues_sc = None
if self.self_conditioning:
# Sampling: use previous prediction in all but the first time step
if not self.training and t.min() > 0.0:
assert t.min() == t.max(), "currently only supports sampling at same time steps"
assert self.prev_ligand is not None
assert self.prev_residues is not None or not self.predict_frames
else:
# Create zero tensors
zeros_ligand = {'logits_h': torch.zeros_like(h_atoms),
'vel': torch.zeros_like(x_atoms),
'logits_e': torch.zeros_like(bonds_ligand[1])}
if self.predict_confidence:
zeros_ligand['uncertainty_vel'] = torch.zeros(
len(x_atoms), dtype=x_atoms.dtype, device=x_atoms.device)
zeros_residues = {}
if self.predict_angles:
zeros_residues['chi'] = torch.zeros((pocket['one_hot'].size(0), 5), device=pocket['one_hot'].device)
if self.predict_frames:
zeros_residues['trans'] = torch.zeros((pocket['one_hot'].size(0), 3), device=pocket['one_hot'].device)
zeros_residues['rot'] = torch.zeros((pocket['one_hot'].size(0), 3), device=pocket['one_hot'].device)
# Training: use 50% zeros and 50% predictions with detached gradients
if self.training and random.random() > 0.5:
with torch.no_grad():
h_atoms_sc, e_atoms_sc, h_residues_sc = self.make_sc_input(
zeros_ligand, zeros_residues, sc_transform)
self.prev_ligand, self.prev_residues = self._forward(
x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand,
h_atoms_sc, e_atoms_sc, h_residues_sc)
# use zeros for first sampling step and 50% of training
else:
self.prev_ligand = zeros_ligand
self.prev_residues = zeros_residues
h_atoms_sc, e_atoms_sc, h_residues_sc = self.make_sc_input(
self.prev_ligand, self.prev_residues, sc_transform)
pred_ligand, pred_residues = self._forward(
x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand,
h_atoms_sc, e_atoms_sc, h_residues_sc
)
if self.self_conditioning and not self.training:
self.prev_ligand = pred_ligand.copy()
self.prev_residues = pred_residues.copy()
return pred_ligand, pred_residues
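
    # Hedged usage sketch (loop and variable names hypothetical): sampling
    # starts at t = 0 so the first call initializes the self-conditioning
    # cache with zeros; every later call recycles self.prev_ligand /
    # self.prev_residues because t.min() > 0:
    #
    #   model.eval()
    #   for step in range(n_steps):
    #       t = torch.full((n_samples, 1), step / n_steps)
    #       pred_ligand, pred_residues = model(
    #           x_atoms, h_atoms, mask_atoms, pocket, t,
    #           bonds_ligand=bonds, sc_transform=sc_transform)
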
def compute_extra_features(self, batch_mask, edge_indices, edge_types):
feat = torch.zeros(len(batch_mask), 0, device=batch_mask.device)
if not (self.add_cycle_counts or self.add_spectral_feat):
return feat
adj = batch_mask[:, None] == batch_mask[None, :]
E = torch.zeros_like(adj, dtype=INT_TYPE)
E[edge_indices[0], edge_indices[1]] = edge_types
A = (E > 0).float()
if self.add_cycle_counts:
cycle_features = cycle_counts(A)
cycle_features[cycle_features > 10] = 10 # avoid large values
feat = torch.cat([feat, cycle_features], dim=-1)
if self.add_spectral_feat:
feat = torch.cat([feat, eigenfeatures(A, batch_mask)], dim=-1)
return feat
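
    # Hedged sketch: for a 3-atom chain with bidirectional bonds 0-1 and 1-2,
    # edge_indices = [[0, 1, 1, 2], [1, 0, 2, 1]] with nonzero edge_types, so A
    # becomes the path-graph adjacency; the result then has 3 cycle-count
    # columns and/or 5 spectral columns, depending on the enabled flags.
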
class Dynamics(DynamicsBase):
def __init__(self, atom_nf, residue_nf, joint_nf, bond_dict, pocket_bond_dict,
edge_nf, hidden_nf, act_fn=torch.nn.SiLU(), condition_time=True,
model='egnn', model_params=None,
edge_cutoff_ligand=None, edge_cutoff_pocket=None,
edge_cutoff_interaction=None,
predict_angles=False, predict_frames=False,
add_cycle_counts=False, add_spectral_feat=False,
add_nma_feat=False, self_conditioning=False,
augment_residue_sc=False, augment_ligand_sc=False,
add_chi_as_feature=False, angle_act_fn=False):
super().__init__()
self.model = model
self.edge_cutoff_l = edge_cutoff_ligand
self.edge_cutoff_p = edge_cutoff_pocket
self.edge_cutoff_i = edge_cutoff_interaction
self.hidden_nf = hidden_nf
self.predict_angles = predict_angles
self.predict_frames = predict_frames
self.bond_dict = bond_dict
self.pocket_bond_dict = pocket_bond_dict
self.bond_nf = len(bond_dict)
self.pocket_bond_nf = len(pocket_bond_dict)
self.edge_nf = edge_nf
self.add_cycle_counts = add_cycle_counts
self.add_spectral_feat = add_spectral_feat
self.add_nma_feat = add_nma_feat
self.self_conditioning = self_conditioning
self.augment_residue_sc = augment_residue_sc
self.augment_ligand_sc = augment_ligand_sc
self.add_chi_as_feature = add_chi_as_feature
self.predict_confidence = False
if self.self_conditioning:
self.prev_vel = None
self.prev_h = None
self.prev_e = None
self.prev_a = None
self.prev_ca = None
self.prev_rot = None
lig_nf = atom_nf
if self.add_cycle_counts:
lig_nf = lig_nf + 3
if self.add_spectral_feat:
lig_nf = lig_nf + 5
if not isinstance(joint_nf, Iterable):
# joint_nf contains only scalars
joint_nf = (joint_nf, 0)
if isinstance(residue_nf, Iterable):
_atom_in_nf = (lig_nf, 0)
_residue_atom_dim = residue_nf[1]
if self.add_nma_feat:
residue_nf = (residue_nf[0], residue_nf[1] + 5)
if self.self_conditioning:
_atom_in_nf = (_atom_in_nf[0] + atom_nf, 1)
if self.augment_ligand_sc:
_atom_in_nf = (_atom_in_nf[0], _atom_in_nf[1] + 1)
if self.predict_angles:
residue_nf = (residue_nf[0] + 5, residue_nf[1])
if self.predict_frames:
residue_nf = (residue_nf[0], residue_nf[1] + 2)
if self.augment_residue_sc:
assert self.predict_angles
residue_nf = (residue_nf[0], residue_nf[1] + _residue_atom_dim)
if self.add_chi_as_feature:
residue_nf = (residue_nf[0] + 5, residue_nf[1])
self.atom_encoder = nn.Sequential(
GVP(_atom_in_nf, joint_nf, activations=(act_fn, torch.sigmoid)),
LayerNorm(joint_nf, learnable_vector_weight=True),
GVP(joint_nf, joint_nf, activations=(None, None)),
)
self.residue_encoder = nn.Sequential(
GVP(residue_nf, joint_nf, activations=(act_fn, torch.sigmoid)),
LayerNorm(joint_nf, learnable_vector_weight=True),
GVP(joint_nf, joint_nf, activations=(None, None)),
)
else:
# No vector-valued input features
assert joint_nf[1] == 0
# self-conditioning not yet supported
assert not self.self_conditioning
# Normal mode features are vectors
assert not self.add_nma_feat
if self.add_chi_as_feature:
residue_nf += 5
self.atom_encoder = nn.Sequential(
nn.Linear(lig_nf, 2 * atom_nf),
act_fn,
nn.Linear(2 * atom_nf, joint_nf[0])
)
self.residue_encoder = nn.Sequential(
nn.Linear(residue_nf, 2 * residue_nf),
act_fn,
nn.Linear(2 * residue_nf, joint_nf[0])
)
self.atom_decoder = nn.Sequential(
nn.Linear(joint_nf[0], 2 * atom_nf),
act_fn,
nn.Linear(2 * atom_nf, atom_nf)
)
self.edge_decoder = nn.Sequential(
nn.Linear(hidden_nf, hidden_nf),
act_fn,
nn.Linear(hidden_nf, self.bond_nf)
)
_atom_bond_nf = 2 * self.bond_nf if self.self_conditioning else self.bond_nf
self.ligand_bond_encoder = nn.Sequential(
nn.Linear(_atom_bond_nf, hidden_nf),
act_fn,
nn.Linear(hidden_nf, self.edge_nf)
)
self.pocket_bond_encoder = nn.Sequential(
nn.Linear(self.pocket_bond_nf, hidden_nf),
act_fn,
nn.Linear(hidden_nf, self.edge_nf)
)
out_nf = (joint_nf[0], 1)
res_out_nf = (0, 0)
if self.predict_angles:
res_out_nf = (res_out_nf[0] + 5, res_out_nf[1])
if self.predict_frames:
res_out_nf = (res_out_nf[0], res_out_nf[1] + 2)
self.residue_decoder = nn.Sequential(
GVP(out_nf, out_nf, activations=(act_fn, torch.sigmoid)),
LayerNorm(out_nf, learnable_vector_weight=True),
GVP(out_nf, res_out_nf, activations=(None, None)),
) if res_out_nf != (0, 0) else None
        # both the default (False) and None mean "no activation"
        if not angle_act_fn:
            self.angle_act_fn = None
        elif angle_act_fn == 'tanh':
            # squash raw angle predictions into the open interval (-pi, pi)
            self.angle_act_fn = lambda x: np.pi * torch.tanh(x)
        else:
            raise NotImplementedError(f"Angle activation {angle_act_fn} not available")
# self.ligand_nobond_emb = nn.Parameter(torch.zeros(self.edge_nf))
# self.pocket_nobond_emb = nn.Parameter(torch.zeros(self.edge_nf))
self.cross_emb = nn.Parameter(torch.zeros(self.edge_nf),
requires_grad=True)
if condition_time:
dynamics_node_nf = (joint_nf[0] + 1, joint_nf[1])
else:
print('Warning: dynamics model is NOT conditioned on time.')
dynamics_node_nf = (joint_nf[0], joint_nf[1])
if model == 'egnn':
raise NotImplementedError
# self.net = EGNN(
# in_node_nf=dynamics_node_nf[0], in_edge_nf=self.edge_nf,
# hidden_nf=hidden_nf, out_node_nf=joint_nf[0],
# device=model_params.device, act_fn=act_fn,
# n_layers=model_params.n_layers,
# attention=model_params.attention,
# tanh=model_params.tanh,
# norm_constant=model_params.norm_constant,
# inv_sublayers=model_params.inv_sublayers,
# sin_embedding=model_params.sin_embedding,
# normalization_factor=model_params.normalization_factor,
# aggregation_method=model_params.aggregation_method,
# reflection_equiv=model_params.reflection_equivariant,
# update_edge_attr=True
# )
# self.node_nf = dynamics_node_nf[0]
elif model == 'gvp':
self.net = GVPModel(
node_in_dim=dynamics_node_nf, node_h_dim=model_params.node_h_dim,
node_out_nf=joint_nf[0], edge_in_nf=self.edge_nf,
edge_h_dim=model_params.edge_h_dim, edge_out_nf=hidden_nf,
num_layers=model_params.n_layers,
drop_rate=model_params.dropout,
vector_gate=model_params.vector_gate,
reflection_equiv=model_params.reflection_equivariant,
d_max=model_params.d_max,
num_rbf=model_params.num_rbf,
update_edge_attr=True
)
elif model == 'gvp_transformer':
self.net = GVPTransformerModel(
node_in_dim=dynamics_node_nf,
node_h_dim=model_params.node_h_dim,
node_out_nf=joint_nf[0],
edge_in_nf=self.edge_nf,
edge_h_dim=model_params.edge_h_dim,
edge_out_nf=hidden_nf,
num_layers=model_params.n_layers,
dk=model_params.dk,
dv=model_params.dv,
de=model_params.de,
db=model_params.db,
dy=model_params.dy,
attn_heads=model_params.attn_heads,
n_feedforward=model_params.n_feedforward,
drop_rate=model_params.dropout,
reflection_equiv=model_params.reflection_equivariant,
d_max=model_params.d_max,
num_rbf=model_params.num_rbf,
vector_gate=model_params.vector_gate,
attention=model_params.attention,
)
elif model == 'gnn':
raise NotImplementedError
# n_dims = 3
# self.net = GNN(
# in_node_nf=dynamics_node_nf + n_dims, in_edge_nf=self.edge_emb_dim,
# hidden_nf=hidden_nf, out_node_nf=n_dims + dynamics_node_nf,
# device=model_params.device, act_fn=act_fn, n_layers=model_params.n_layers,
# attention=model_params.attention, normalization_factor=model_params.normalization_factor,
# aggregation_method=model_params.aggregation_method)
else:
raise NotImplementedError(f"{model} is not available")
# self.device = device
# self.n_dims = n_dims
self.condition_time = condition_time
def _forward(self, x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand=None,
h_atoms_sc=None, e_atoms_sc=None, h_residues_sc=None):
"""
:param x_atoms:
:param h_atoms:
:param mask_atoms:
:param pocket: must contain keys: 'x', 'one_hot', 'mask', 'bonds' and 'bond_one_hot'
:param t:
:param bonds_ligand: tuple - bond indices (2, n_bonds) & bond types (n_bonds, bond_nf)
:param h_atoms_sc: additional node feature for self-conditioning, (s, V)
:param e_atoms_sc: additional edge feature for self-conditioning, only scalar
:param h_residues_sc: additional node feature for self-conditioning, tensor or tuple
:return:
"""
x_residues, h_residues, mask_residues = pocket['x'], pocket['one_hot'], pocket['mask']
if 'bonds' in pocket:
bonds_pocket = (pocket['bonds'], pocket['bond_one_hot'])
else:
bonds_pocket = None
if self.add_chi_as_feature:
h_residues = torch.cat([h_residues, pocket['chi'][:, :5]], dim=-1)
if 'v' in pocket:
v_residues = pocket['v']
if self.add_nma_feat:
v_residues = torch.cat([v_residues, pocket['nma_vec']], dim=1)
h_residues = (h_residues, v_residues)
if h_residues_sc is not None:
# if self.augment_residue_sc:
if isinstance(h_residues_sc, tuple):
h_residues = (torch.cat([h_residues[0], h_residues_sc[0]], dim=-1),
torch.cat([h_residues[1], h_residues_sc[1]], dim=1))
else:
h_residues = (torch.cat([h_residues[0], h_residues_sc], dim=-1),
h_residues[1])
# get graph edges and edge attributes
if bonds_ligand is not None:
# NOTE: 'bond' denotes one-directional edges and 'edge' means bi-directional
ligand_bond_indices = bonds_ligand[0]
# make sure messages are passed both ways
ligand_edge_indices = torch.cat(
[bonds_ligand[0], bonds_ligand[0].flip(dims=[0])], dim=1)
ligand_edge_types = torch.cat([bonds_ligand[1], bonds_ligand[1]], dim=0)
# edges_ligand = (ligand_edge_indices, ligand_edge_types)
# add auxiliary features to ligand nodes
extra_features = self.compute_extra_features(
mask_atoms, ligand_edge_indices, ligand_edge_types.argmax(-1))
h_atoms = torch.cat([h_atoms, extra_features], dim=-1)
if bonds_pocket is not None:
# make sure messages are passed both ways
pocket_edge_indices = torch.cat(
[bonds_pocket[0], bonds_pocket[0].flip(dims=[0])], dim=1)
pocket_edge_types = torch.cat([bonds_pocket[1], bonds_pocket[1]], dim=0)
# edges_pocket = (pocket_edge_indices, pocket_edge_types)
if h_atoms_sc is not None:
h_atoms = (torch.cat([h_atoms, h_atoms_sc[0]], dim=-1),
h_atoms_sc[1])
if e_atoms_sc is not None:
e_atoms_sc = torch.cat([e_atoms_sc, e_atoms_sc], dim=0)
ligand_edge_types = torch.cat([ligand_edge_types, e_atoms_sc], dim=-1)
# embed atom features and residue features in a shared space
h_atoms = self.atom_encoder(h_atoms)
e_ligand = self.ligand_bond_encoder(ligand_edge_types)
if len(h_residues) > 0:
h_residues = self.residue_encoder(h_residues)
e_pocket = self.pocket_bond_encoder(pocket_edge_types)
else:
e_pocket = pocket_edge_types
h_residues = (h_residues, h_residues)
pocket_edge_indices = torch.tensor([[], []], dtype=torch.long, device=h_residues[0].device)
pocket_edge_types = torch.tensor([[], []], dtype=torch.long, device=h_residues[0].device)
if isinstance(h_atoms, tuple):
h_atoms, v_atoms = h_atoms
h_residues, v_residues = h_residues
v = torch.cat((v_atoms, v_residues), dim=0)
else:
v = None
edges, edge_feat = self.get_edges(
mask_atoms, mask_residues, x_atoms, x_residues,
bond_inds_ligand=ligand_edge_indices, bond_inds_pocket=pocket_edge_indices,
bond_feat_ligand=e_ligand, bond_feat_pocket=e_pocket)
# combine the two node types
x = torch.cat((x_atoms, x_residues), dim=0)
h = torch.cat((h_atoms, h_residues), dim=0)
mask = torch.cat([mask_atoms, mask_residues])
if self.condition_time:
if np.prod(t.size()) == 1:
# t is the same for all elements in batch.
h_time = torch.empty_like(h[:, 0:1]).fill_(t.item())
else:
# t is different over the batch dimension.
h_time = t[mask]
h = torch.cat([h, h_time], dim=1)
assert torch.all(mask[edges[0]] == mask[edges[1]])
if self.model == 'egnn':
# Don't update pocket coordinates
update_coords_mask = torch.cat((torch.ones_like(mask_atoms),
torch.zeros_like(mask_residues))).unsqueeze(1)
h_final, vel, edge_final = self.net(
h, x, edges, batch_mask=mask, edge_attr=edge_feat,
update_coords_mask=update_coords_mask)
# vel = (x_final - x)
elif self.model == 'gvp' or self.model == 'gvp_transformer':
h_final, vel, edge_final = self.net(
h, x, edges, v=v, batch_mask=mask, edge_attr=edge_feat)
elif self.model == 'gnn':
xh = torch.cat([x, h], dim=1)
output = self.net(xh, edges, node_mask=None, edge_attr=edge_feat)
vel = output[:, :3]
h_final = output[:, 3:]
else:
raise NotImplementedError(f"Wrong model ({self.model})")
# if self.condition_time:
# # Slice off last dimension which represented time.
# h_final = h_final[:, :-1]
# decode atom and residue features
h_final_atoms = self.atom_decoder(h_final[:len(mask_atoms)])
if torch.any(torch.isnan(vel)) or torch.any(torch.isnan(h_final_atoms)):
if self.training:
vel[torch.isnan(vel)] = 0.0
h_final_atoms[torch.isnan(h_final_atoms)] = 0.0
else:
raise ValueError("NaN detected in network output")
# predict edge type
ligand_edge_mask = (edges[0] < len(mask_atoms)) & (edges[1] < len(mask_atoms))
edge_final = edge_final[ligand_edge_mask]
edges = edges[:, ligand_edge_mask]
# Symmetrize
edge_logits = torch.zeros(
(len(mask_atoms), len(mask_atoms), self.hidden_nf),
device=mask_atoms.device)
edge_logits[edges[0], edges[1]] = edge_final
edge_logits = (edge_logits + edge_logits.transpose(0, 1)) * 0.5
# edge_logits = edge_logits[lig_edge_indices[0], lig_edge_indices[1]]
# return upper triangular elements only (matching the input)
edge_logits = edge_logits[ligand_bond_indices[0], ligand_bond_indices[1]]
# assert (edge_logits == 0).sum() == 0
edge_final_atoms = self.edge_decoder(edge_logits)
# Predict torsion angles
residue_angles = None
residue_trans, residue_rot = None, None
if self.residue_decoder is not None:
h_residues = h_final[len(mask_atoms):]
vec_residues = vel[len(mask_atoms):].unsqueeze(1)
residue_angles = self.residue_decoder((h_residues, vec_residues))
if self.predict_frames:
residue_angles, residue_frames = residue_angles
                residue_trans = residue_frames[:, 0, :]
                residue_rot = residue_frames[:, 1, :]
if self.angle_act_fn is not None:
residue_angles = self.angle_act_fn(residue_angles)
# return vel[:len(mask_atoms)], h_final_atoms, edge_final_atoms, residue_angles, residue_trans, residue_rot
pred_ligand = {'vel': vel[:len(mask_atoms)], 'logits_h': h_final_atoms, 'logits_e': edge_final_atoms}
pred_residues = {'chi': residue_angles, 'trans': residue_trans, 'rot': residue_rot}
return pred_ligand, pred_residues
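
    # Hedged shape summary (derived from the code above): pred_ligand['vel'] is
    # (n_atoms, 3), 'logits_h' is (n_atoms, atom_nf) and 'logits_e' is
    # (n_bonds, bond_nf); pred_residues entries remain None unless angle/frame
    # prediction is enabled.
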
def get_edges(self, batch_mask_ligand, batch_mask_pocket, x_ligand,
x_pocket, bond_inds_ligand=None, bond_inds_pocket=None,
bond_feat_ligand=None, bond_feat_pocket=None, self_edges=False):
# Adjacency matrix
adj_ligand = batch_mask_ligand[:, None] == batch_mask_ligand[None, :]
adj_pocket = batch_mask_pocket[:, None] == batch_mask_pocket[None, :]
adj_cross = batch_mask_ligand[:, None] == batch_mask_pocket[None, :]
if self.edge_cutoff_l is not None:
adj_ligand = adj_ligand & (torch.cdist(x_ligand, x_ligand) <= self.edge_cutoff_l)
# Add missing bonds if they got removed
adj_ligand[bond_inds_ligand[0], bond_inds_ligand[1]] = True
if self.edge_cutoff_p is not None and len(x_pocket) > 0:
adj_pocket = adj_pocket & (torch.cdist(x_pocket, x_pocket) <= self.edge_cutoff_p)
# Add missing bonds if they got removed
adj_pocket[bond_inds_pocket[0], bond_inds_pocket[1]] = True
if self.edge_cutoff_i is not None and len(x_pocket) > 0:
adj_cross = adj_cross & (torch.cdist(x_ligand, x_pocket) <= self.edge_cutoff_i)
adj = torch.cat((torch.cat((adj_ligand, adj_cross), dim=1),
torch.cat((adj_cross.T, adj_pocket), dim=1)), dim=0)
if not self_edges:
adj = adj ^ torch.eye(*adj.size(), out=torch.empty_like(adj))
# # ensure that edge definition is consistent if bonds are provided (for loss computation)
# if bond_inds_ligand is not None:
# # remove ligand edges
# adj[:adj_ligand.size(0), :adj_ligand.size(1)] = False
# edges = torch.stack(torch.where(adj), dim=0)
# # add ligand edges back with original definition
# edges = torch.cat([bond_inds_ligand, edges], dim=-1)
# else:
# edges = torch.stack(torch.where(adj), dim=0)
# Feature matrix
ligand_nobond_onehot = F.one_hot(torch.tensor(
self.bond_dict['NOBOND'], device=bond_feat_ligand.device),
num_classes=self.ligand_bond_encoder[0].in_features)
ligand_nobond_emb = self.ligand_bond_encoder(
ligand_nobond_onehot.to(FLOAT_TYPE))
feat_ligand = ligand_nobond_emb.repeat(*adj_ligand.shape, 1)
feat_ligand[bond_inds_ligand[0], bond_inds_ligand[1]] = bond_feat_ligand
if len(adj_pocket) > 0:
pocket_nobond_onehot = F.one_hot(torch.tensor(
self.pocket_bond_dict['NOBOND'], device=bond_feat_pocket.device),
num_classes=self.pocket_bond_nf)
pocket_nobond_emb = self.pocket_bond_encoder(
pocket_nobond_onehot.to(FLOAT_TYPE))
feat_pocket = pocket_nobond_emb.repeat(*adj_pocket.shape, 1)
feat_pocket[bond_inds_pocket[0], bond_inds_pocket[1]] = bond_feat_pocket
feat_cross = self.cross_emb.repeat(*adj_cross.shape, 1)
feats = torch.cat((torch.cat((feat_ligand, feat_cross), dim=1),
torch.cat((feat_cross.transpose(0, 1), feat_pocket), dim=1)), dim=0)
else:
feats = feat_ligand
# Return results
edges = torch.stack(torch.where(adj), dim=0)
edge_feat = feats[edges[0], edges[1]]
return edges, edge_feat
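
    # Hedged usage sketch (toy sizes): with two ligand atoms and one pocket
    # node in a single sample and no distance cutoffs, get_edges returns every
    # ordered pair (self-edges excluded) indexed into the concatenated
    # [atoms; residues] node list; covalent bonds are re-added even when a
    # distance cutoff would have pruned them.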