File size: 9,798 Bytes
6e7d4ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 |
"""
Natural Extension Reference Frame (NERF)
Inspiration for parallel reconstruction:
https://github.com/EleutherAI/mp_nerf and references therein
For atom names, see also:
https://www.ccpn.ac.uk/manual/v3/NEFAtomNames.html
References:
- https://onlinelibrary.wiley.com/doi/10.1002/jcc.20237 (NERF)
- https://onlinelibrary.wiley.com/doi/10.1002/jcc.26768 (for code)
"""
import warnings
import torch
import numpy as np
from src.data.misc import protein_letters_3to1
from src.constants import aa_atom_index, aa_atom_mask, aa_nerf_indices, aa_chi_indices, aa_chi_anchor_atom
# https://github.com/EleutherAI/mp_nerf/blob/master/mp_nerf/utils.py
def get_dihedral(c1, c2, c3, c4):
""" Returns the dihedral angle in radians.
Will use atan2 formula from:
https://en.wikipedia.org/wiki/Dihedral_angle#In_polymer_physics
Inputs:
* c1: (batch, 3) or (3,)
* c2: (batch, 3) or (3,)
* c3: (batch, 3) or (3,)
* c4: (batch, 3) or (3,)
"""
u1 = c2 - c1
u2 = c3 - c2
u3 = c4 - c3
return torch.atan2( ( (torch.norm(u2, dim=-1, keepdim=True) * u1) * torch.cross(u2,u3, dim=-1) ).sum(dim=-1) ,
( torch.cross(u1,u2, dim=-1) * torch.cross(u2, u3, dim=-1) ).sum(dim=-1) )
# https://github.com/EleutherAI/mp_nerf/blob/master/mp_nerf/utils.py
def get_angle(c1, c2, c3):
""" Returns the angle in radians.
Inputs:
* c1: (batch, 3) or (3,)
* c2: (batch, 3) or (3,)
* c3: (batch, 3) or (3,)
"""
u1 = c2 - c1
u2 = c3 - c2
# dont use acos since norms involved.
# better use atan2 formula: atan2(cross, dot) from here:
# https://johnblackburne.blogspot.com/2012/05/angle-between-two-3d-vectors.html
# add a minus since we want the angle in reversed order - sidechainnet issues
return torch.atan2( torch.norm(torch.cross(u1,u2, dim=-1), dim=-1),
-(u1*u2).sum(dim=-1) )
def get_nerf_params(biopython_residue):
aa = protein_letters_3to1[biopython_residue.get_resname()]
# Basic mask and index tensors
atom_mask = torch.tensor(aa_atom_mask[aa], dtype=bool)
nerf_indices = torch.tensor(aa_nerf_indices[aa], dtype=int)
chi_indices = torch.tensor(aa_chi_indices[aa], dtype=int)
fixed_coord = torch.zeros((5, 3))
residue_coords = torch.zeros((14, 3)) # only required to compute internal coordinates during pre-processing
atom_found = torch.zeros_like(atom_mask)
for atom in biopython_residue.get_atoms():
try:
idx = aa_atom_index[aa][atom.get_name()]
atom_found[idx] = True
except KeyError:
warnings.warn(f"{atom.get_name()} not found")
continue
residue_coords[idx, :] = torch.from_numpy(atom.get_coord())
if atom.get_name() in ['N', 'CA', 'C', 'O', 'CB']:
fixed_coord[idx, :] = torch.from_numpy(atom.get_coord())
# Determine chi angles
chi = torch.zeros(6) # the last chi angle is a dummy and should always be zero
for chi_idx, anchor in aa_chi_anchor_atom[aa].items():
idx_a = nerf_indices[anchor, 2]
idx_b = nerf_indices[anchor, 1]
idx_c = nerf_indices[anchor, 0]
coords_a = residue_coords[idx_a, :]
coords_b = residue_coords[idx_b, :]
coords_c = residue_coords[idx_c, :]
coords_d = residue_coords[anchor, :]
chi[chi_idx] = get_dihedral(coords_a, coords_b, coords_c, coords_d)
# Compute remaining internal coordinates
# (parallel version)
idx_a = nerf_indices[:, 2]
idx_b = nerf_indices[:, 1]
idx_c = nerf_indices[:, 0]
# update atom mask
# remove atoms for which one or several parameters are missing/incorrect
_atom_mask = atom_mask & atom_found & atom_found[idx_a] & atom_found[idx_b] & atom_found[idx_c]
if not torch.all(_atom_mask == atom_mask):
warnings.warn("Some atoms are missing for NERF reconstruction")
atom_mask = _atom_mask
coords_a = residue_coords[idx_a]
coords_b = residue_coords[idx_b]
coords_c = residue_coords[idx_c]
coords_d = residue_coords
length = torch.norm(coords_d - coords_c, dim=-1)
theta = get_angle(coords_b, coords_c, coords_d)
ddihedral = get_dihedral(coords_a, coords_b, coords_c, coords_d)
# subtract chi angles from dihedrals
ddihedral = ddihedral - chi[chi_indices]
# # (serial version)
# length = torch.zeros(14)
# theta = torch.zeros(14)
# ddihedral = torch.zeros(14)
# for i in range(5, 14):
# if not atom_mask[i]: # atom doesn't exist
# continue
# idx_a = nerf_indices[i, 2]
# idx_b = nerf_indices[i, 1]
# idx_c = nerf_indices[i, 0]
# coords_a = residue_coords[idx_a]
# coords_b = residue_coords[idx_b]
# coords_c = residue_coords[idx_c]
# coords_d = residue_coords[i]
# length[i] = torch.norm(coords_d - coords_c, dim=-1)
# theta[i] = get_angle(coords_b, coords_c, coords_d)
# ddihedral[i] = get_dihedral(coords_a, coords_b, coords_c, coords_d)
# # subtract chi angles from dihedrals
# ddihedral[i] = ddihedral[i] - chi[chi_indices[i]]
return {
'fixed_coord': fixed_coord,
'atom_mask': atom_mask,
'nerf_indices': nerf_indices,
'length': length,
'theta': theta,
'chi': chi,
'ddihedral': ddihedral,
'chi_indices': chi_indices,
}
# https://github.com/EleutherAI/mp_nerf/blob/master/mp_nerf/massive_pnerf.py#L38C1-L65C67
def mp_nerf_torch(a, b, c, l, theta, chi):
""" Custom Natural extension of Reference Frame.
Inputs:
* a: (batch, 3) or (3,). point(s) of the plane, not connected to d
* b: (batch, 3) or (3,). point(s) of the plane, not connected to d
* c: (batch, 3) or (3,). point(s) of the plane, connected to d
* theta: (batch,) or (float). angle(s) between b-c-d
* chi: (batch,) or float. dihedral angle(s) between the a-b-c and b-c-d planes
Outputs: d (batch, 3) or (float). the next point in the sequence, linked to c
"""
# safety check
if not ( (-np.pi <= theta) * (theta <= np.pi) ).all().item():
raise ValueError(f"theta(s) must be in radians and in [-pi, pi]. theta(s) = {theta}")
# calc vecs
ba = b-a
cb = c-b
# calc rotation matrix. based on plane normals and normalized
n_plane = torch.cross(ba, cb, dim=-1)
n_plane_ = torch.cross(n_plane, cb, dim=-1)
rotate = torch.stack([cb, n_plane_, n_plane], dim=-1)
rotate /= torch.norm(rotate, dim=-2, keepdim=True)
# calc proto point, rotate. add (-1 for sidechainnet convention)
# https://github.com/jonathanking/sidechainnet/issues/14
d = torch.stack([-torch.cos(theta),
torch.sin(theta) * torch.cos(chi),
torch.sin(theta) * torch.sin(chi)], dim=-1).unsqueeze(-1)
# extend base point, set length
return c + l.unsqueeze(-1) * torch.matmul(rotate, d).squeeze()
# inspired by: https://github.com/EleutherAI/mp_nerf/blob/master/mp_nerf/proteins.py#L323C5-L344C65
def ic_to_coords(fixed_coord, atom_mask, nerf_indices, length, theta, chi, ddihedral, chi_indices):
"""
Run NERF in parallel for all residues.
:param fixed_coord: (L, 5, 3) coordinates of (N, CA, C, O, CB) atoms, they don't depend on chi angles
:param atom_mask: (L, 14) indicates whether atom exists in this residue
:param nerf_indices: (L, 14, 3) indices of the three previous atoms ({c, b, a} for the NERF algorithm)
:param length: (L, 14) bond length between this and previous atom
:param theta: (L, 14) angle between this and previous two atoms
:param chi: (L, 6) values of the 5 rotatable bonds, plus zero in last column
:param ddihedral: (L, 14) angle offset to which chi is added
:param chi_indices: (L, 14) indexes into the chi array
:returns: (L, 14, 3) tensor with all coordinates, non-existing atoms are assigned CA coords
"""
if not torch.all(chi[:, 5] == 0):
chi[:, 5] = 0.0
warnings.warn("Last column of 'chi' tensor should be zero. Overriding values.")
assert torch.all(chi[:, 5] == 0)
L, device = fixed_coord.size(0), fixed_coord.device
coords = torch.zeros((L, 14, 3), device=device)
coords[:, :5, :] = fixed_coord
for i in range(5, 14):
level_mask = atom_mask[:, i]
# level_mask = torch.ones(len(atom_mask), dtype=bool)
length_i = length[level_mask, i]
theta_i = theta[level_mask, i]
# dihedral_i = dihedral[level_mask, i]
dihedral_i = chi[level_mask, chi_indices[level_mask, i]] + ddihedral[level_mask, i]
idx_a = nerf_indices[level_mask, i, 2]
idx_b = nerf_indices[level_mask, i, 1]
idx_c = nerf_indices[level_mask, i, 0]
coords[level_mask, i] = mp_nerf_torch(coords[level_mask, idx_a],
coords[level_mask, idx_b],
coords[level_mask, idx_c],
length_i,
theta_i,
dihedral_i)
if coords.isnan().any():
warnings.warn("Side chain reconstruction error. Removing affected atoms...")
# mask out affected atoms
m, n, _ = torch.where(coords.isnan())
atom_mask[m, n] = False
coords[m, n, :] = 0.0
# replace non-existing atom coords with CA coords (TODO: don't hard-code CA index)
coords = atom_mask.unsqueeze(-1) * coords + \
(~atom_mask.unsqueeze(2)) * coords[:, 1, :].unsqueeze(1)
return coords
|