# DrugFlow: src/model/markov_bridge.py
import torch
import torch.nn.functional as F
from torch_scatter import scatter_mean, scatter_add
from src.utils import bvm


class LinearSchedule:
    r"""
    We use the scheduling parameter \beta to linearly remove noise, i.e.
        \bar{\beta}_t = 1 - t with
        \bar{Q}_t = \bar{\beta}_t I + (1 - \bar{\beta}_t) 1_vec z1^T
    From this, it follows that the per-step transition matrix uses
        \beta_t = \bar{\beta}_t / \bar{\beta}_{t-h} = \frac{1-t}{1-t+h}  (h: step size)
    """
    def __init__(self):
        super().__init__()

    def beta_bar(self, t):
        return 1 - t

    def beta(self, t, step_size):
        return (1 - t) / (1 - t + step_size)
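

# Illustrative sketch (not part of the original module): the cumulative and the
# per-step schedules defined above are consistent, i.e.
# beta_bar(t) = beta(t, h) * beta_bar(t - h) for any step size h,
# since (1 - t) = (1 - t) / (1 - t + h) * (1 - (t - h)).
def _demo_linear_schedule():
    schedule = LinearSchedule()
    t, h = torch.tensor([[0.7]]), 0.1
    assert torch.allclose(schedule.beta_bar(t),
                          schedule.beta(t, h) * schedule.beta_bar(t - h))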


class UniformPriorMarkovBridge:
    r"""
    Markov bridge model in which z0 is drawn from a uniform prior.

    Transitions are defined as:
        Q_t = \beta_t I + (1 - \beta_t) 1_vec z1^T
    where z1 is a one-hot representation of the final state.
    We follow the notation from [1] and multiply one-hot state (row) vectors
    by transition matrices from the right.

    We use the scheduling parameter \beta to linearly remove noise, i.e.
        \bar{\beta}_t = 1 - t with
        \bar{Q}_t = \bar{\beta}_t I + (1 - \bar{\beta}_t) 1_vec z1^T
    From this, it follows that the per-step transition matrix uses
        \beta_t = \bar{\beta}_t / \bar{\beta}_{t-h} = \frac{1-t}{1-t+h}  (h: step size)

    [1] Austin, Jacob, et al.
        "Structured denoising diffusion models in discrete state-spaces."
        Advances in Neural Information Processing Systems 34 (2021): 17981-17993.
    """
    def __init__(self, dim, loss_type='CE', step_size=None):
        assert loss_type in ['VLB', 'CE']
        self.dim = dim
        self.step_size = step_size  # required for VLB
        self.schedule = LinearSchedule()
        self.loss_type = loss_type
        super(UniformPriorMarkovBridge, self).__init__()

    @staticmethod
    def sample_categorical(p):
        """
        Sample from categorical distribution defined by probabilities 'p'
        :param p: (n, dim)
        :return: one-hot encoded samples (n, dim)
        """
        sampled = torch.multinomial(p, 1).squeeze(-1)
        return F.one_hot(sampled, num_classes=p.size(1)).float()

    def p_z0(self, batch_mask):
        return torch.ones((len(batch_mask), self.dim), device=batch_mask.device) / self.dim

    def sample_z0(self, batch_mask):
        """ Prior. """
        z0 = self.sample_categorical(self.p_z0(batch_mask))
        return z0

    def p_zt(self, z0, z1, t, batch_mask):
        Qt_bar = self.get_Qt_bar(t, z1, batch_mask)
        return bvm(z0, Qt_bar)

    def sample_zt(self, z0, z1, t, batch_mask):
        zt = self.sample_categorical(self.p_zt(z0, z1, t, batch_mask))
        return zt

    def p_zt_given_zs_and_z1(self, zs, z1, s, t, batch_mask):
        # 'z1' are one-hot "probabilities" for each class
        Qt = self.get_Qt(t, s, z1, batch_mask)
        p_zt = bvm(zs, Qt)
        return p_zt

    def p_zt_given_zs(self, zs, p_z1_hat, s, t, batch_mask):
        r"""
        Note that `p_z1_hat` can also be a categorical distribution over final
        states (rather than a one-hot vector), which makes transitions cheaper
        to compute at sampling time:
            p(z_t|z_s) = \sum_{\hat{z}_1} p(z_t | z_s, \hat{z}_1) * p(\hat{z}_1 | z_s)
                       = \sum_i z_s (\beta_t I + (1 - \beta_t) 1_vec z1_i^T) \hat{p}_i
                       = \beta_t z_s I + (1 - \beta_t) z_s 1_vec \hat{p}^T
        (see the illustrative sketch after this class)
        """
        return self.p_zt_given_zs_and_z1(zs, p_z1_hat, s, t, batch_mask)

    def sample_zt_given_zs(self, zs, z1_logits, s, t, batch_mask):
        p_z1 = z1_logits.softmax(dim=-1)
        zt = self.sample_categorical(self.p_zt_given_zs(zs, p_z1, s, t, batch_mask))
        return zt

    def compute_loss(self, pred_logits, zs, z1, batch_mask, s, t, reduce='mean'):
        """ Compute the loss, aggregated per sample according to 'reduce'. """
        assert reduce in {'mean', 'sum', 'none'}

        if self.loss_type == 'CE':
            loss = F.cross_entropy(pred_logits, z1, reduction='none')
        else:  # VLB
            true_p_zt = self.p_zt_given_zs_and_z1(zs, z1, s, t, batch_mask)
            pred_p_zt = self.p_zt_given_zs(zs, pred_logits.softmax(dim=-1), s, t, batch_mask)
            loss = F.kl_div(pred_p_zt.log(), true_p_zt, reduction='none').sum(dim=-1)

        if reduce == 'mean':
            loss = scatter_mean(loss, batch_mask, dim=0)
        elif reduce == 'sum':
            loss = scatter_add(loss, batch_mask, dim=0)

        return loss

    def get_Qt(self, t, s, z1, batch_mask):
        """ Returns one-step transition matrix from step s to step t. """
        beta_t_given_s = self.schedule.beta(t, t - s)
        beta_t_given_s = beta_t_given_s.unsqueeze(-1)[batch_mask]

        # Q_t = beta_t * I + (1 - beta_t) * ones (dot) z1^T
        Qt = beta_t_given_s * torch.eye(self.dim, device=t.device).unsqueeze(0) + \
            (1 - beta_t_given_s) * z1.unsqueeze(1)
        # (1 - beta_t_given_s) * (torch.ones(self.dim, 1, device=t.device) @ z1)
        # assert (Qt.sum(-1) == 1).all()
        return Qt

    def get_Qt_bar(self, t, z1, batch_mask):
        """ Returns transition matrix from step 0 to step t. """
        beta_bar_t = self.schedule.beta_bar(t)
        beta_bar_t = beta_bar_t.unsqueeze(-1)[batch_mask]

        # Q_t_bar = beta_bar * I + (1 - beta_bar) * ones (dot) z1^T
        Qt_bar = beta_bar_t * torch.eye(self.dim, device=t.device).unsqueeze(0) + \
            (1 - beta_bar_t) * z1.unsqueeze(1)
        # (1 - beta_bar_t) * (torch.ones(self.dim, 1, device=t.device) @ z1)
        # assert (Qt_bar.sum(-1) == 1).all()
        return Qt_bar
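

# Illustrative sketch (not part of the original module): one noising and one
# denoising step with UniformPriorMarkovBridge. The shapes used here are
# assumptions made for the example: `batch_mask` assigns each node to its
# graph, and the time tensors s, t are per-graph with shape (batch_size, 1),
# which matches how they are indexed with batch_mask in get_Qt / get_Qt_bar.
# `src.utils.bvm` is assumed to be a batched vector-matrix product, as implied
# by its use above. The `_demo_*` function name is only for illustration.
def _demo_uniform_bridge(dim=5, n_nodes=8, batch_size=2):
    bridge = UniformPriorMarkovBridge(dim, loss_type='CE')
    batch_mask = torch.arange(n_nodes) % batch_size
    z1 = F.one_hot(torch.randint(0, dim, (n_nodes,)), num_classes=dim).float()

    # noising: interpolate between the uniform prior (t=0) and z1 (t=1)
    z0 = bridge.sample_z0(batch_mask)
    t = torch.full((batch_size, 1), 0.5)
    zt = bridge.sample_zt(z0, z1, t, batch_mask)

    # cumulative transition matrices are row-stochastic
    Qt_bar = bridge.get_Qt_bar(t, z1, batch_mask)
    assert torch.allclose(Qt_bar.sum(-1), torch.ones(n_nodes, dim))

    # denoising step s -> t with (dummy) predicted logits for z1; passing the
    # full distribution over z1 equals averaging over all one-hot z1
    # (cf. the p_zt_given_zs docstring)
    s, t = torch.full((batch_size, 1), 0.5), torch.full((batch_size, 1), 0.6)
    z1_logits = torch.randn(n_nodes, dim)
    p_z1 = z1_logits.softmax(dim=-1)
    direct = bridge.p_zt_given_zs(zt, p_z1, s, t, batch_mask)
    mixture = sum(
        p_z1[:, i:i + 1] * bridge.p_zt_given_zs_and_z1(
            zt, F.one_hot(torch.full((n_nodes,), i, dtype=torch.long), num_classes=dim).float(),
            s, t, batch_mask)
        for i in range(dim)
    )
    assert torch.allclose(direct, mixture)

    return bridge.sample_zt_given_zs(zt, z1_logits, s, t, batch_mask)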


class MarginalPriorMarkovBridge(UniformPriorMarkovBridge):
    def __init__(self, dim, prior_p, loss_type='CE', step_size=None):
        self.prior_p = prior_p
        print('Marginal Prior MB')
        super(MarginalPriorMarkovBridge, self).__init__(dim, loss_type, step_size)

    def p_z0(self, batch_mask):
        device = batch_mask.device
        p = torch.ones((len(batch_mask), self.dim), device=device) * self.prior_p.view(1, -1).to(device)
        return p
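

# Illustrative sketch (not part of the original module): a bridge whose prior
# samples follow class marginals instead of the uniform distribution. The
# `prior_p` values and the `_demo_*` name are made up for the example.
def _demo_marginal_prior():
    prior_p = torch.tensor([0.5, 0.2, 0.2, 0.05, 0.05])
    bridge = MarginalPriorMarkovBridge(dim=5, prior_p=prior_p, loss_type='CE')
    batch_mask = torch.arange(6) % 2  # 6 nodes spread over 2 graphs
    z0 = bridge.sample_z0(batch_mask)  # (6, 5) one-hot rows drawn from prior_p
    return z0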