import math

import numpy as np
import torch
import torch.nn.functional as F
from torch_scatter import scatter_add, scatter_mean

import utils
from equivariant_diffusion.en_diffusion import EnVariationalDiffusion


class ConditionalDDPM(EnVariationalDiffusion):
    """
    Conditional Diffusion Module.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The pocket acts as fixed context; its coordinates must never be
        # updated by the dynamics network.
        assert not self.dynamics.update_pocket_coords

    def kl_prior(self, xh_lig, mask_lig, num_nodes):
        """Computes the KL divergence between q(z_T | x) and the prior
        p(z_T) = Normal(0, 1).

        This is essentially a lot of work for something that is in practice
        negligible in the loss. However, you compute it so that you see it
        when you've made a mistake in your noise schedule.
        """
        batch_size = len(num_nodes)

        # Compute the last alpha value, alpha_T.
        ones = torch.ones((batch_size, 1), device=xh_lig.device)
        gamma_T = self.gamma(ones)
        alpha_T = self.alpha(gamma_T, xh_lig)

        # Compute means.
        mu_T_lig = alpha_T[mask_lig] * xh_lig
        mu_T_lig_x, mu_T_lig_h = \
            mu_T_lig[:, :self.n_dims], mu_T_lig[:, self.n_dims:]

        # Compute standard deviations.
        sigma_T_x = self.sigma(gamma_T, mu_T_lig_x).squeeze()
        sigma_T_h = self.sigma(gamma_T, mu_T_lig_h).squeeze()

        # Compute KL for the h-part.
        zeros = torch.zeros_like(mu_T_lig_h)
        ones = torch.ones_like(sigma_T_h)
        mu_norm2 = self.sum_except_batch((mu_T_lig_h - zeros) ** 2, mask_lig)
        kl_distance_h = self.gaussian_KL(mu_norm2, sigma_T_h, ones, d=1)

        # Compute KL for the x-part.
        zeros = torch.zeros_like(mu_T_lig_x)
        ones = torch.ones_like(sigma_T_x)
        mu_norm2 = self.sum_except_batch((mu_T_lig_x - zeros) ** 2, mask_lig)
        subspace_d = self.subspace_dimensionality(num_nodes)
        kl_distance_x = self.gaussian_KL(mu_norm2, sigma_T_x, ones, subspace_d)

        return kl_distance_x + kl_distance_h

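    # For reference: for diagonal Gaussians the KL term above has the usual
    # closed form (a sketch, assuming `gaussian_KL` in the parent class takes
    # the summed squared mean norm, scalar standard deviations and the
    # dimensionality d, as in the EDM code base):
    #
    #   KL(N(mu, sigma^2 I) || N(0, I))
    #       = 0.5 * ( ||mu||^2 + d * sigma^2 - d - 2 * d * log(sigma) )
    #
    # With a well-behaved noise schedule, alpha_T ~ 0 and sigma_T ~ 1, so
    # both mu_norm2 and the variance/log terms vanish and the KL is ~ 0.
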
    def log_pxh_given_z0_without_constants(self, ligand, z_0_lig, eps_lig,
                                           net_out_lig, gamma_0, epsilon=1e-10):

        # Discrete properties are predicted directly from z_0.
        z_h_lig = z_0_lig[:, self.n_dims:]

        # Take only the x-part of epsilon and the network output.
        eps_lig_x = eps_lig[:, :self.n_dims]
        net_lig_x = net_out_lig[:, :self.n_dims]

        # Compute sigma_0 and rescale to the integer scale of the data.
        sigma_0 = self.sigma(gamma_0, target_tensor=z_0_lig)
        sigma_0_cat = sigma_0 * self.norm_values[1]

        # Compute the error for the distribution
        # N(x | 1 / alpha_0 z_0 + sigma_0 / alpha_0 eps_0, sigma_0 / alpha_0);
        # the weighting in the epsilon parametrization is exactly 1.
        squared_error = (eps_lig_x - net_lig_x) ** 2
        if self.vnode_idx is not None:
            # Coordinates of virtual atoms should not contribute to the error.
            squared_error[ligand['one_hot'][:, self.vnode_idx].bool(), :self.n_dims] = 0
        log_p_x_given_z0_without_constants_ligand = -0.5 * (
            self.sum_except_batch(squared_error, ligand['mask'])
        )

        # Un-normalize the one-hot features.
        ligand_onehot = ligand['one_hot'] * self.norm_values[1] + self.norm_biases[1]
        estimated_ligand_onehot = z_h_lig * self.norm_values[1] + self.norm_biases[1]

        # Center h around 1, since the one-hot encoded value is 1.
        centered_ligand_onehot = estimated_ligand_onehot - 1

        # Compute integrals from 0.5 to 1.5 of the normal distribution
        # N(mean=z_h, stdev=sigma_0_cat).
        log_ph_cat_proportional_ligand = torch.log(
            self.cdf_standard_gaussian((centered_ligand_onehot + 0.5) / sigma_0_cat[ligand['mask']])
            - self.cdf_standard_gaussian((centered_ligand_onehot - 0.5) / sigma_0_cat[ligand['mask']])
            + epsilon
        )

        # Normalize the distribution over the categories.
        log_Z = torch.logsumexp(log_ph_cat_proportional_ligand, dim=1,
                                keepdim=True)
        log_probabilities_ligand = log_ph_cat_proportional_ligand - log_Z

        # Select the log-prob of the current category using the one-hot
        # representation.
        log_ph_given_z0_ligand = self.sum_except_batch(
            log_probabilities_ligand * ligand_onehot, ligand['mask'])

        return log_p_x_given_z0_without_constants_ligand, log_ph_given_z0_ligand

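    # For reference: the categorical likelihood above treats each one-hot
    # entry h as an integer and integrates the decoder Gaussian over the
    # unit-width bin around it (a sketch of the identity being used):
    #
    #   p(h | z_0) ∝ ∫_{h-0.5}^{h+0.5} N(u; z_h, sigma_0_cat^2) du
    #             = Phi((h + 0.5 - z_h) / sigma_0_cat)
    #               - Phi((h - 0.5 - z_h) / sigma_0_cat)
    #
    # where Phi is the standard normal CDF. `epsilon` guards the log against
    # vanishing bin mass, and the logsumexp renormalizes over categories.
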
    def sample_p_xh_given_z0(self, z0_lig, xh0_pocket, lig_mask, pocket_mask,
                             batch_size, fix_noise=False):
        """Samples x ~ p(x|z0)."""
        t_zeros = torch.zeros(size=(batch_size, 1), device=z0_lig.device)
        gamma_0 = self.gamma(t_zeros)

        # sigma_x = sqrt(sigma_0^2 / alpha_0^2) = SNR(-0.5 gamma_0).
        sigma_x = self.SNR(-0.5 * gamma_0)
        net_out_lig, _ = self.dynamics(
            z0_lig, xh0_pocket, t_zeros, lig_mask, pocket_mask)

        # Compute mu for p(x | z0).
        mu_x_lig = self.compute_x_pred(net_out_lig, z0_lig, gamma_0, lig_mask)
        xh_lig, xh0_pocket = self.sample_normal_zero_com(
            mu_x_lig, xh0_pocket, sigma_x, lig_mask, pocket_mask, fix_noise)

        x_lig, h_lig = self.unnormalize(
            xh_lig[:, :self.n_dims], z0_lig[:, self.n_dims:])
        x_pocket, h_pocket = self.unnormalize(
            xh0_pocket[:, :self.n_dims], xh0_pocket[:, self.n_dims:])

        # Decode the categorical features by taking the most likely category.
        h_lig = F.one_hot(torch.argmax(h_lig, dim=1), self.atom_nf)

        return x_lig, h_lig, x_pocket, h_pocket

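    # For reference: under the EDM conventions SNR(gamma) = exp(-gamma),
    # alpha_0^2 = sigmoid(-gamma_0) and sigma_0^2 = sigmoid(gamma_0)
    # (an assumption about the parent class, not something defined in this
    # file),
    #
    #   SNR(-0.5 * gamma_0) = exp(0.5 * gamma_0) = sigma_0 / alpha_0,
    #
    # so `sigma_x` above is the standard deviation of p(x | z_0), i.e. the
    # final denoising step samples x ~ N(mu_x, (sigma_0 / alpha_0)^2 I).
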
    def sample_normal(self, *args):
        raise NotImplementedError("Has been replaced by sample_normal_zero_com()")

    def sample_normal_zero_com(self, mu_lig, xh0_pocket, sigma, lig_mask,
                               pocket_mask, fix_noise=False):
        """Samples from a Normal distribution and projects the result to the
        ligand-CoM-free subspace."""
        if fix_noise:
            # Noise is broadcast over the batch axis, useful for visualizations.
            raise NotImplementedError("fix_noise option isn't implemented yet")

        eps_lig = self.sample_gaussian(
            size=(len(lig_mask), self.n_dims + self.atom_nf),
            device=lig_mask.device)

        out_lig = mu_lig + sigma[lig_mask] * eps_lig

        # Remove the ligand's center of mass; the pocket is shifted by the
        # same amount so that relative ligand-pocket geometry is preserved.
        xh_pocket = xh0_pocket.detach().clone()
        out_lig[:, :self.n_dims], xh_pocket[:, :self.n_dims] = \
            self.remove_mean_batch(out_lig[:, :self.n_dims],
                                   xh0_pocket[:, :self.n_dims],
                                   lig_mask, pocket_mask)

        return out_lig, xh_pocket

    def noised_representation(self, xh_lig, xh0_pocket, lig_mask, pocket_mask,
                              gamma_t):
        # Compute alpha_t and sigma_t from gamma_t.
        alpha_t = self.alpha(gamma_t, xh_lig)
        sigma_t = self.sigma(gamma_t, xh_lig)

        eps_lig = self.sample_gaussian(
            size=(len(lig_mask), self.n_dims + self.atom_nf),
            device=lig_mask.device)

        # Sample z_t given x, h for timestep t, from q(z_t | x, h).
        z_t_lig = alpha_t[lig_mask] * xh_lig + sigma_t[lig_mask] * eps_lig

        # Project to the ligand-CoM-free subspace.
        xh_pocket = xh0_pocket.detach().clone()
        z_t_lig[:, :self.n_dims], xh_pocket[:, :self.n_dims] = \
            self.remove_mean_batch(z_t_lig[:, :self.n_dims],
                                   xh_pocket[:, :self.n_dims],
                                   lig_mask, pocket_mask)

        return z_t_lig, xh_pocket, eps_lig

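    # For reference: the forward process above implements
    #
    #   q(z_t | x) = N(z_t; alpha_t * x, sigma_t^2 * I),
    #
    # where in the variance-preserving EDM parametrization (again an
    # assumption about the parent class) alpha_t^2 = sigmoid(-gamma_t) and
    # sigma_t^2 = sigmoid(gamma_t), so alpha_t^2 + sigma_t^2 = 1 for all t.
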
    def log_pN(self, N_lig, N_pocket):
        """
        Prior on the sample size for computing
        log p(x,h,N) = log p(x,h|N) + log p(N), where log p(x,h|N) is the
        model's output.
        Args:
            N_lig: array of ligand sizes
            N_pocket: array of pocket sizes
        Returns:
            log p(N_lig | N_pocket)
        """
        log_pN = self.size_distribution.log_prob_n1_given_n2(N_lig, N_pocket)
        return log_pN

    def delta_log_px(self, num_nodes):
        # Change-of-variables term that accounts for normalizing the
        # coordinates by norm_values[0].
        return -self.subspace_dimensionality(num_nodes) * \
               np.log(self.norm_values[0])

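    # For reference: if coordinates are rescaled as x_norm = x / c with
    # c = norm_values[0], the change-of-variables formula gives
    #
    #   log p(x) = log p(x_norm) - d * log(c),
    #
    # where d is the dimensionality of the (CoM-free) coordinate subspace;
    # delta_log_px returns exactly this -d * log(c) correction.
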
    def forward(self, ligand, pocket, return_info=False):
        """
        Computes the loss and NLL terms.
        """
        # Normalize data, take into account volume change in x.
        ligand, pocket = self.normalize(ligand, pocket)

        # Likelihood change due to normalization.
        delta_log_px = self.delta_log_px(ligand['size'])

        # Sample a timestep t for each example in the batch.
        # At evaluation time, loss_0 is computed separately to decrease
        # variance in the estimator (costs two forward passes).
        lowest_t = 0 if self.training else 1
        t_int = torch.randint(
            lowest_t, self.T + 1, size=(ligand['size'].size(0), 1),
            device=ligand['x'].device).float()
        s_int = t_int - 1  # previous timestep

        # Masks: important to compute log p(x | z0).
        t_is_zero = (t_int == 0).float()
        t_is_not_zero = 1 - t_is_zero

        # Normalize t to [0, 1]. Note that the negative step of s is never
        # used, since p(x | z0) is computed in that case.
        s = s_int / self.T
        t = t_int / self.T

        # Compute gamma_s and gamma_t via the network.
        gamma_s = self.inflate_batch_array(self.gamma(s), ligand['x'])
        gamma_t = self.inflate_batch_array(self.gamma(t), ligand['x'])

        # Concatenate x and h[categorical].
        xh0_lig = torch.cat([ligand['x'], ligand['one_hot']], dim=1)
        xh0_pocket = torch.cat([pocket['x'], pocket['one_hot']], dim=1)

        # Center the input nodes.
        xh0_lig[:, :self.n_dims], xh0_pocket[:, :self.n_dims] = \
            self.remove_mean_batch(xh0_lig[:, :self.n_dims],
                                   xh0_pocket[:, :self.n_dims],
                                   ligand['mask'], pocket['mask'])

        # Find the noised representation.
        z_t_lig, xh_pocket, eps_t_lig = \
            self.noised_representation(xh0_lig, xh0_pocket, ligand['mask'],
                                       pocket['mask'], gamma_t)

        # Neural net prediction.
        net_out_lig, _ = self.dynamics(
            z_t_lig, xh_pocket, t, ligand['mask'], pocket['mask'])

        # Predicted denoised representation; it does not need to be
        # zero-centered as it is only used for distance calculations.
        xh_lig_hat = self.xh_given_zt_and_epsilon(z_t_lig, net_out_lig, gamma_t,
                                                  ligand['mask'])

        # Compute the L2 error.
        squared_error = (eps_t_lig - net_out_lig) ** 2
        if self.vnode_idx is not None:
            # Coordinates of virtual atoms should not contribute to the error.
            squared_error[ligand['one_hot'][:, self.vnode_idx].bool(), :self.n_dims] = 0
        error_t_lig = self.sum_except_batch(squared_error, ligand['mask'])

        # Compute the SNR weighting (1 - SNR(s - t)) for the epsilon
        # parametrization.
        SNR_weight = (1 - self.SNR(gamma_s - gamma_t)).squeeze(1)
        assert error_t_lig.size() == SNR_weight.size()

        # The _constants_ depending on sigma_0 from the
        # cross entropy term E_q(z0 | x) [log p(x | z0)].
        neg_log_constants = -self.log_constants_p_x_given_z0(
            n_nodes=ligand['size'], device=error_t_lig.device)

        # The KL between q(z_T | x) and p(z_T) = Normal(0, 1).
        # Should be close to zero.
        kl_prior = self.kl_prior(xh0_lig, ligand['mask'], ligand['size'])

        if self.training:
            # Compute the L_0 term (even if gamma_t is not actually gamma_0)
            # and select the relevant examples via masking.
            log_p_x_given_z0_without_constants_ligand, log_ph_given_z0 = \
                self.log_pxh_given_z0_without_constants(
                    ligand, z_t_lig, eps_t_lig, net_out_lig, gamma_t)

            loss_0_x_ligand = -log_p_x_given_z0_without_constants_ligand * \
                              t_is_zero.squeeze()
            loss_0_h = -log_ph_given_z0 * t_is_zero.squeeze()

            # Apply the t_is_zero mask.
            error_t_lig = error_t_lig * t_is_not_zero.squeeze()

        else:
            # Compute noise values for t = 0.
            t_zeros = torch.zeros_like(s)
            gamma_0 = self.inflate_batch_array(self.gamma(t_zeros), ligand['x'])

            # Sample z_0 given x, h for timestep t, from q(z_t | x, h).
            z_0_lig, xh_pocket, eps_0_lig = \
                self.noised_representation(xh0_lig, xh0_pocket, ligand['mask'],
                                           pocket['mask'], gamma_0)

            net_out_0_lig, _ = self.dynamics(
                z_0_lig, xh_pocket, t_zeros, ligand['mask'], pocket['mask'])

            log_p_x_given_z0_without_constants_ligand, log_ph_given_z0 = \
                self.log_pxh_given_z0_without_constants(
                    ligand, z_0_lig, eps_0_lig, net_out_0_lig, gamma_0)
            loss_0_x_ligand = -log_p_x_given_z0_without_constants_ligand
            loss_0_h = -log_ph_given_z0

        # Sample size prior.
        log_pN = self.log_pN(ligand['size'], pocket['size'])

        info = {
            'eps_hat_lig_x': scatter_mean(
                net_out_lig[:, :self.n_dims].abs().mean(1), ligand['mask'],
                dim=0).mean(),
            'eps_hat_lig_h': scatter_mean(
                net_out_lig[:, self.n_dims:].abs().mean(1), ligand['mask'],
                dim=0).mean(),
        }
        loss_terms = (delta_log_px, error_t_lig, torch.tensor(0.0), SNR_weight,
                      loss_0_x_ligand, torch.tensor(0.0), loss_0_h,
                      neg_log_constants, kl_prior, log_pN,
                      t_int.squeeze(), xh_lig_hat)
        return (*loss_terms, info) if return_info else loss_terms

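    # For reference: a sketch of how these terms are typically combined into
    # the variational lower bound outside this module (the exact weighting is
    # defined by the training code, not here; in plain-l2 mode, error_t is
    # used unweighted instead):
    #
    #   loss_t = 0.5 * T * SNR_weight * error_t          (VLB weighting)
    #   loss_0 = loss_0_x_ligand + loss_0_h + neg_log_constants
    #   -log p(x,h,N) ~= loss_t + loss_0 + kl_prior - delta_log_px - log_pN
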
    def partially_noised_ligand(self, ligand, pocket, noising_steps):
        """
        Partially noises a ligand to be denoised later.
        """
        # Use the same number of noising steps for every sample in the batch.
        t_int = torch.ones(size=(ligand['size'].size(0), 1),
                           device=ligand['x'].device).float() * noising_steps

        # Normalize t to [0, 1].
        t = t_int / self.T

        # Compute gamma_t via the network.
        gamma_t = self.inflate_batch_array(self.gamma(t), ligand['x'])

        # Concatenate x and h[categorical].
        xh0_lig = torch.cat([ligand['x'], ligand['one_hot']], dim=1)
        xh0_pocket = torch.cat([pocket['x'], pocket['one_hot']], dim=1)

        # Center the input nodes.
        xh0_lig[:, :self.n_dims], xh0_pocket[:, :self.n_dims] = \
            self.remove_mean_batch(xh0_lig[:, :self.n_dims],
                                   xh0_pocket[:, :self.n_dims],
                                   ligand['mask'], pocket['mask'])

        # Find the noised representation.
        z_t_lig, xh_pocket, eps_t_lig = \
            self.noised_representation(xh0_lig, xh0_pocket, ligand['mask'],
                                       pocket['mask'], gamma_t)

        return z_t_lig, xh_pocket, eps_t_lig

    def diversify(self, ligand, pocket, noising_steps):
        """
        Diversifies a set of ligands via noising and denoising.
        """
        # Normalize data, take into account volume change in x.
        ligand, pocket = self.normalize(ligand, pocket)

        z_lig, xh_pocket, _ = self.partially_noised_ligand(ligand, pocket,
                                                           noising_steps)

        timesteps = self.T
        n_samples = len(pocket['size'])
        device = pocket['x'].device

        xh0_pocket = torch.cat([pocket['x'], pocket['one_hot']], dim=1)

        lig_mask = ligand['mask']

        self.assert_mean_zero_with_mask(z_lig[:, :self.n_dims], lig_mask)

        # Iteratively sample p(z_s | z_t) for t = noising_steps, ..., 1 with
        # s = t - 1.
        for s in reversed(range(0, noising_steps)):
            s_array = torch.full((n_samples, 1), fill_value=s,
                                 device=z_lig.device)
            t_array = s_array + 1
            s_array = s_array / timesteps
            t_array = t_array / timesteps

            z_lig, xh_pocket = self.sample_p_zs_given_zt(
                s_array, t_array, z_lig.detach(), xh_pocket.detach(),
                lig_mask, pocket['mask'])

        # Finally sample p(x, h | z_0).
        x_lig, h_lig, x_pocket, h_pocket = self.sample_p_xh_given_z0(
            z_lig, xh_pocket, lig_mask, pocket['mask'], n_samples)

        self.assert_mean_zero_with_mask(x_lig, lig_mask)

        out_lig = torch.cat([x_lig, h_lig], dim=1)
        out_pocket = torch.cat([x_pocket, h_pocket], dim=1)

        return out_lig, out_pocket, lig_mask, pocket['mask']

    def xh_given_zt_and_epsilon(self, z_t, epsilon, gamma_t, batch_mask):
        """ Equation (7) in the EDM paper """
        alpha_t = self.alpha(gamma_t, z_t)
        sigma_t = self.sigma(gamma_t, z_t)
        xh = z_t / alpha_t[batch_mask] - epsilon * sigma_t[batch_mask] / \
             alpha_t[batch_mask]
        return xh

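    # For reference: this is the inversion of the forward process. From
    # z_t = alpha_t * xh + sigma_t * eps it follows directly that
    #
    #   xh_hat = (z_t - sigma_t * eps_hat) / alpha_t
    #          = z_t / alpha_t - eps_hat * sigma_t / alpha_t,
    #
    # which is what this method computes with the network output in place
    # of eps.
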
    def sample_p_zt_given_zs(self, zs_lig, xh0_pocket, ligand_mask, pocket_mask,
                             gamma_t, gamma_s, fix_noise=False):
        sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s = \
            self.sigma_and_alpha_t_given_s(gamma_t, gamma_s, zs_lig)

        mu_lig = alpha_t_given_s[ligand_mask] * zs_lig
        zt_lig, xh0_pocket = self.sample_normal_zero_com(
            mu_lig, xh0_pocket, sigma_t_given_s, ligand_mask, pocket_mask,
            fix_noise)

        return zt_lig, xh0_pocket

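    # For reference: this renoising step (used during resampling in
    # `inpaint`) draws from the forward transition
    #
    #   q(z_t | z_s) = N(z_t; alpha_{t|s} * z_s, sigma_{t|s}^2 * I),
    #
    # with alpha_{t|s} = alpha_t / alpha_s and
    # sigma_{t|s}^2 = sigma_t^2 - alpha_{t|s}^2 * sigma_s^2.
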
    def sample_p_zs_given_zt(self, s, t, zt_lig, xh0_pocket, ligand_mask,
                             pocket_mask, fix_noise=False):
        """Samples from zs ~ p(zs | zt). Only used during sampling."""
        gamma_s = self.gamma(s)
        gamma_t = self.gamma(t)

        sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s = \
            self.sigma_and_alpha_t_given_s(gamma_t, gamma_s, zt_lig)

        sigma_s = self.sigma(gamma_s, target_tensor=zt_lig)
        sigma_t = self.sigma(gamma_t, target_tensor=zt_lig)

        # Neural net prediction.
        eps_t_lig, _ = self.dynamics(
            zt_lig, xh0_pocket, t, ligand_mask, pocket_mask)

        # Compute mu for p(z_s | z_t).
        mu_lig = zt_lig / alpha_t_given_s[ligand_mask] - \
                 (sigma2_t_given_s / alpha_t_given_s / sigma_t)[ligand_mask] * \
                 eps_t_lig

        # Compute sigma for p(z_s | z_t).
        sigma = sigma_t_given_s * sigma_s / sigma_t

        # Sample z_s given the parameters derived from z_t.
        zs_lig, xh0_pocket = self.sample_normal_zero_com(
            mu_lig, xh0_pocket, sigma, ligand_mask, pocket_mask, fix_noise)

        self.assert_mean_zero_with_mask(zt_lig[:, :self.n_dims], ligand_mask)

        return zs_lig, xh0_pocket

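    # For reference: the posterior parameters above follow the standard
    # DDPM/EDM derivation. With alpha_{t|s} = alpha_t / alpha_s and
    # sigma_{t|s}^2 = sigma_t^2 - alpha_{t|s}^2 * sigma_s^2,
    #
    #   mu    = z_t / alpha_{t|s}
    #           - (sigma_{t|s}^2 / (alpha_{t|s} * sigma_t)) * eps_hat
    #   sigma = sigma_{t|s} * sigma_s / sigma_t
    #
    # so the denoising transition is p(z_s | z_t) = N(z_s; mu, sigma^2 * I).
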
    def sample_combined_position_feature_noise(self, lig_indices, xh0_pocket,
                                               pocket_indices):
        """
        Samples mean-centered normal noise for z_x, and standard normal noise
        for z_h.
        """
        raise NotImplementedError("Use sample_normal_zero_com() instead.")

    def sample(self, *args):
        raise NotImplementedError("Conditional model does not support sampling "
                                  "without a given pocket.")

    @torch.no_grad()
    def sample_given_pocket(self, pocket, num_nodes_lig, return_frames=1,
                            timesteps=None):
        """
        Draw samples from the generative model. Optionally, return intermediate
        states for visualization purposes.
        """
        timesteps = self.T if timesteps is None else timesteps
        assert 0 < return_frames <= timesteps
        assert timesteps % return_frames == 0

        n_samples = len(pocket['size'])
        device = pocket['x'].device

        _, pocket = self.normalize(pocket=pocket)

        # xh0_pocket is the original pocket while xh_pocket might be a
        # translated version of it.
        xh0_pocket = torch.cat([pocket['x'], pocket['one_hot']], dim=1)

        lig_mask = utils.num_nodes_to_batch_mask(
            n_samples, num_nodes_lig, device)

        # Sample from the ligand prior, centered at the pocket's center
        # of mass.
        mu_lig_x = scatter_mean(pocket['x'], pocket['mask'], dim=0)
        mu_lig_h = torch.zeros((n_samples, self.atom_nf), device=device)
        mu_lig = torch.cat((mu_lig_x, mu_lig_h), dim=1)[lig_mask]
        sigma = torch.ones_like(pocket['size']).unsqueeze(1)

        z_lig, xh_pocket = self.sample_normal_zero_com(
            mu_lig, xh0_pocket, sigma, lig_mask, pocket['mask'])

        self.assert_mean_zero_with_mask(z_lig[:, :self.n_dims], lig_mask)

        out_lig = torch.zeros((return_frames,) + z_lig.size(),
                              device=z_lig.device)
        out_pocket = torch.zeros((return_frames,) + xh_pocket.size(),
                                 device=device)

        # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
        for s in reversed(range(0, timesteps)):
            s_array = torch.full((n_samples, 1), fill_value=s,
                                 device=z_lig.device)
            t_array = s_array + 1
            s_array = s_array / timesteps
            t_array = t_array / timesteps

            z_lig, xh_pocket = self.sample_p_zs_given_zt(
                s_array, t_array, z_lig, xh_pocket, lig_mask, pocket['mask'])

            # Save frame.
            if (s * return_frames) % timesteps == 0:
                idx = (s * return_frames) // timesteps
                out_lig[idx], out_pocket[idx] = \
                    self.unnormalize_z(z_lig, xh_pocket)

        # Finally sample p(x, h | z_0).
        x_lig, h_lig, x_pocket, h_pocket = self.sample_p_xh_given_z0(
            z_lig, xh_pocket, lig_mask, pocket['mask'], n_samples)

        self.assert_mean_zero_with_mask(x_lig, lig_mask)

        # Correct CoG drift for examples without intermediate states.
        if return_frames == 1:
            max_cog = scatter_add(x_lig, lig_mask, dim=0).abs().max().item()
            if max_cog > 5e-2:
                print(f'Warning CoG drift with error {max_cog:.3f}. Projecting '
                      f'the positions down.')
                x_lig, x_pocket = self.remove_mean_batch(
                    x_lig, x_pocket, lig_mask, pocket['mask'])

        # Overwrite the last frame with the resulting x and h.
        out_lig[0] = torch.cat([x_lig, h_lig], dim=1)
        out_pocket[0] = torch.cat([x_pocket, h_pocket], dim=1)

        # Remove the frame dimension if only the final molecule is returned.
        return out_lig.squeeze(0), out_pocket.squeeze(0), lig_mask, \
               pocket['mask']

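    # A hypothetical usage sketch (the variable names and the exact pocket
    # dictionary layout are illustrative; see the data-loading code for the
    # real keys and preprocessing):
    #
    #   model = ConditionalDDPM(...).eval()
    #   pocket = {
    #       'x': pocket_coords,          # (n_pocket_nodes, 3)
    #       'one_hot': pocket_features,  # (n_pocket_nodes, pocket_nf)
    #       'size': pocket_sizes,        # (batch_size,)
    #       'mask': pocket_batch_mask,   # (n_pocket_nodes,)
    #   }
    #   num_nodes_lig = torch.tensor([22, 18])  # ligand size per sample
    #   xh_lig, xh_pocket, lig_mask, pocket_mask = \
    #       model.sample_given_pocket(pocket, num_nodes_lig)
    #   x_lig, h_lig = xh_lig[:, :3], xh_lig[:, 3:]
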
    @torch.no_grad()
    def inpaint(self, ligand, pocket, lig_fixed, resamplings=1, return_frames=1,
                timesteps=None, center='ligand'):
        """
        Draw samples from the generative model while fixing parts of the input.
        Optionally, return intermediate states for visualization purposes.
        Inspired by Algorithm 1 in:
        Lugmayr, Andreas, et al.
        "RePaint: Inpainting using denoising diffusion probabilistic models."
        Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern
        Recognition. 2022.
        """
        timesteps = self.T if timesteps is None else timesteps
        assert 0 < return_frames <= timesteps
        assert timesteps % return_frames == 0

        if len(lig_fixed.size()) == 1:
            lig_fixed = lig_fixed.unsqueeze(1)

        n_samples = len(ligand['size'])
        device = pocket['x'].device

        # Normalize data, take into account volume change in x.
        ligand, pocket = self.normalize(ligand, pocket)

        # xh0_pocket is the original pocket while xh_pocket might be a
        # translated version of it.
        xh0_pocket = torch.cat([pocket['x'], pocket['one_hot']], dim=1)
        com_pocket_0 = scatter_mean(pocket['x'], pocket['mask'], dim=0)
        xh0_ligand = torch.cat([ligand['x'], ligand['one_hot']], dim=1)
        xh_ligand = xh0_ligand.clone()

        # Center the system around the fixed ligand part or the pocket.
        if center == 'ligand':
            mean_known = scatter_mean(ligand['x'][lig_fixed.bool().view(-1)],
                                      ligand['mask'][lig_fixed.bool().view(-1)],
                                      dim=0)
        elif center == 'pocket':
            mean_known = scatter_mean(pocket['x'], pocket['mask'], dim=0)
        else:
            raise NotImplementedError(
                f"Centering option {center} not implemented")

        # Sample from the prior centered on the known part.
        mu_lig_x = mean_known
        mu_lig_h = torch.zeros((n_samples, self.atom_nf), device=device)
        mu_lig = torch.cat((mu_lig_x, mu_lig_h), dim=1)[ligand['mask']]
        sigma = torch.ones_like(pocket['size']).unsqueeze(1)

        z_lig, xh_pocket = self.sample_normal_zero_com(
            mu_lig, xh0_pocket, sigma, ligand['mask'], pocket['mask'])

        out_lig = torch.zeros((return_frames,) + z_lig.size(),
                              device=z_lig.device)
        out_pocket = torch.zeros((return_frames,) + xh_pocket.size(),
                                 device=device)

        # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
        for s in reversed(range(0, timesteps)):

            # Resampling iterations.
            for u in range(resamplings):
                s_array = torch.full((n_samples, 1), fill_value=s,
                                     device=device)
                t_array = s_array + 1
                s_array = s_array / timesteps
                t_array = t_array / timesteps

                gamma_t = self.gamma(t_array)
                gamma_s = self.gamma(s_array)

                # Sample the unknown part from the model: z_s ~ p(z_s | z_t).
                z_lig_unknown, xh_pocket = self.sample_p_zs_given_zt(
                    s_array, t_array, z_lig, xh_pocket, ligand['mask'],
                    pocket['mask'])

                # Sample the known part from the input: z_s ~ q(z_s | x).
                # The pocket might have been translated, so shift the known
                # ligand part accordingly.
                com_pocket = scatter_mean(xh_pocket[:, :self.n_dims],
                                          pocket['mask'], dim=0)
                xh_ligand[:, :self.n_dims] = \
                    ligand['x'] + (com_pocket - com_pocket_0)[ligand['mask']]
                z_lig_known, xh_pocket, _ = self.noised_representation(
                    xh_ligand, xh_pocket, ligand['mask'], pocket['mask'],
                    gamma_s)

                # Move the center of mass of the noised part to the center of
                # mass of the corresponding denoised part before combining
                # them, so that the resulting system stays CoM-free.
                com_noised = scatter_mean(
                    z_lig_known[lig_fixed.bool().view(-1)][:, :self.n_dims],
                    ligand['mask'][lig_fixed.bool().view(-1)], dim=0)
                com_denoised = scatter_mean(
                    z_lig_unknown[lig_fixed.bool().view(-1)][:, :self.n_dims],
                    ligand['mask'][lig_fixed.bool().view(-1)], dim=0)
                dx = com_denoised - com_noised
                z_lig_known[:, :self.n_dims] = z_lig_known[:, :self.n_dims] + dx[ligand['mask']]
                xh_pocket[:, :self.n_dims] = xh_pocket[:, :self.n_dims] + dx[pocket['mask']]

                # Combine the known and the sampled part.
                z_lig = z_lig_known * lig_fixed + z_lig_unknown * (
                        1 - lig_fixed)

                # Noise the combined representation back to the previous
                # timestep if we resample.
                if u < resamplings - 1:
                    z_lig, xh_pocket = self.sample_p_zt_given_zs(
                        z_lig, xh_pocket, ligand['mask'], pocket['mask'],
                        gamma_t, gamma_s)

                # Save frame at the end of a resampling round.
                if u == resamplings - 1:
                    if (s * return_frames) % timesteps == 0:
                        idx = (s * return_frames) // timesteps
                        out_lig[idx], out_pocket[idx] = \
                            self.unnormalize_z(z_lig, xh_pocket)

        # Finally sample p(x, h | z_0).
        x_lig, h_lig, x_pocket, h_pocket = self.sample_p_xh_given_z0(
            z_lig, xh_pocket, ligand['mask'], pocket['mask'], n_samples)

        # Overwrite the last frame with the resulting x and h.
        out_lig[0] = torch.cat([x_lig, h_lig], dim=1)
        out_pocket[0] = torch.cat([x_pocket, h_pocket], dim=1)

        # Remove the frame dimension if only the final molecule is returned.
        return out_lig.squeeze(0), out_pocket.squeeze(0), ligand['mask'], \
               pocket['mask']

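    # A hypothetical usage sketch for scaffold elaboration (the names are
    # illustrative; `lig_fixed` marks atoms to keep with 1 and atoms to
    # redesign with 0):
    #
    #   lig_fixed = torch.zeros(len(ligand['mask']))
    #   lig_fixed[scaffold_atom_indices] = 1
    #   xh_lig, xh_pocket, lig_mask, pocket_mask = model.inpaint(
    #       ligand, pocket, lig_fixed, resamplings=10, center='ligand')
    #
    # More resampling iterations improve consistency between the fixed and
    # generated parts at the cost of a proportionally longer runtime.
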
    @classmethod
    def remove_mean_batch(cls, x_lig, x_pocket, lig_indices, pocket_indices):
        # Subtract the center of mass of the sampled part (the ligand) from
        # both components; shifting the pocket by the same amount keeps the
        # relative geometry intact while making the ligand CoM-free.
        mean = scatter_mean(x_lig, lig_indices, dim=0)
        x_lig = x_lig - mean[lig_indices]
        x_pocket = x_pocket - mean[pocket_indices]
        return x_lig, x_pocket


class SimpleConditionalDDPM(ConditionalDDPM):
    """
    Simpler conditional diffusion module without the subspace trick.
    - Rotational equivariance is guaranteed by construction.
    - A translationally equivariant likelihood is achieved by first mapping
      samples to a space where the context is CoM-free and evaluating the
      likelihood there.
    - Molecule generation is equivariant because we can first sample in the
      space where the context is CoM-free and translate the whole system back
      to the original position of the context afterwards.
    """
    def subspace_dimensionality(self, input_size):
        """ Override because we don't use the linear subspace anymore. """
        return input_size * self.n_dims

    @classmethod
    def remove_mean_batch(cls, x_lig, x_pocket, lig_indices, pocket_indices):
        """ Hacky way of removing the centering steps without changing too
        much code. """
        return x_lig, x_pocket

    @staticmethod
    def assert_mean_zero_with_mask(x, node_mask, eps=1e-10):
        return

    def forward(self, ligand, pocket, return_info=False):
        # Subtract the pocket's center of mass so that the context is
        # CoM-free before evaluating the likelihood.
        pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0)
        ligand['x'] = ligand['x'] - pocket_com[ligand['mask']]
        pocket['x'] = pocket['x'] - pocket_com[pocket['mask']]

        return super(SimpleConditionalDDPM, self).forward(
            ligand, pocket, return_info)

    @torch.no_grad()
    def sample_given_pocket(self, pocket, num_nodes_lig, return_frames=1,
                            timesteps=None):
        # Subtract the pocket's center of mass before sampling in the
        # CoM-free context space.
        pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0)
        pocket['x'] = pocket['x'] - pocket_com[pocket['mask']]

        return super(SimpleConditionalDDPM, self).sample_given_pocket(
            pocket, num_nodes_lig, return_frames, timesteps)