from functools import reduce

import torch
import torch.nn.functional as F
from torch_scatter import scatter_mean, scatter_add

from src.utils import bvm


class LinearSchedule:
    r"""
    We use the scheduling parameter \beta to linearly remove noise, i.e.
        \bar{\beta}_t = 1 - t (with step size h), so that
        \bar{Q}_t = \bar{\beta}_t I + (1 - \bar{\beta}_t) 1_vec z1^T.

    From this, it follows that each one-step transition matrix satisfies
        \beta_t = \bar{\beta}_t / \bar{\beta}_{t-h} = \frac{1-t}{1-t+h}.
    """

    def beta_bar(self, t):
        return 1 - t

    def beta(self, t, step_size):
        return (1 - t) / (1 - t + step_size)
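
# Worked example (values follow from the formulas above): with step size
# h = 0.1 we get beta_bar(0.4) = 0.6 and beta_bar(0.5) = 0.5, so
#     beta(0.5, 0.1) = (1 - 0.5) / (1 - 0.5 + 0.1) = 0.5 / 0.6
#                    = beta_bar(0.5) / beta_bar(0.4),
# i.e. each one-step beta is the ratio of consecutive cumulative betas.

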
class UniformPriorMarkovBridge:
    r"""
    Markov bridge model in which z0 is drawn from a uniform prior.
    Transitions are defined as
        Q_t = \beta_t I + (1 - \beta_t) 1_vec z1^T,
    where z1 is a one-hot representation of the final state.
    We follow the notation from [1] and multiply one-hot state (row) vectors
    with transition matrices from the right.

    We use the scheduling parameter \beta to linearly remove noise, i.e.
        \bar{\beta}_t = 1 - t (with step size h), so that
        \bar{Q}_t = \bar{\beta}_t I + (1 - \bar{\beta}_t) 1_vec z1^T.

    From this, it follows that each one-step transition matrix satisfies
        \beta_t = \bar{\beta}_t / \bar{\beta}_{t-h} = \frac{1-t}{1-t+h}.

    [1] Austin, Jacob, et al. "Structured denoising diffusion models in
        discrete state-spaces." Advances in Neural Information Processing
        Systems 34 (2021): 17981-17993.
    """

    def __init__(self, dim, loss_type='CE', step_size=None):
        super().__init__()
        assert loss_type in ['VLB', 'CE']
        self.dim = dim
        self.step_size = step_size
        self.schedule = LinearSchedule()
        self.loss_type = loss_type

    @staticmethod
    def sample_categorical(p):
        """
        Sample from the categorical distribution defined by probabilities 'p'.
        :param p: (n, dim)
        :return: one-hot encoded samples (n, dim)
        """
        sampled = torch.multinomial(p, 1).squeeze(-1)
        return F.one_hot(sampled, num_classes=p.size(1)).float()

    def p_z0(self, batch_mask):
        """ Uniform prior probabilities over the 'dim' categories, one row per node. """
        return torch.ones((len(batch_mask), self.dim), device=batch_mask.device) / self.dim

    def sample_z0(self, batch_mask):
        """ Sample initial states z0 from the prior. """
        z0 = self.sample_categorical(self.p_z0(batch_mask))
        return z0

    def p_zt(self, z0, z1, t, batch_mask):
        """ Probabilities of z_t given z_0 and z_1: z_0 @ Qt_bar. """
        Qt_bar = self.get_Qt_bar(t, z1, batch_mask)
        return bvm(z0, Qt_bar)

    def sample_zt(self, z0, z1, t, batch_mask):
        """ Sample z_t from p(z_t | z_0, z_1). """
        zt = self.sample_categorical(self.p_zt(z0, z1, t, batch_mask))
        return zt

    def p_zt_given_zs_and_z1(self, zs, z1, s, t, batch_mask):
        """ Probabilities of z_t given z_s and z_1: z_s @ Qt. """
        Qt = self.get_Qt(t, s, z1, batch_mask)
        p_zt_given_zs = bvm(zs, Qt)
        return p_zt_given_zs

    def p_zt_given_zs(self, zs, p_z1_hat, s, t, batch_mask):
        r"""
        Note that p_z1_hat can also represent a categorical distribution over
        predicted final states, which computes transitions more efficiently
        at sampling time:
        p(z_t|z_s) = \sum_{\hat{z}_1} p(z_t | z_s, \hat{z}_1) * p(\hat{z}_1 | z_s)
                   = \sum_i z_s (\beta_t I + (1 - \beta_t) 1_vec z1_i^T) * \hat{p}_i
                   = \beta_t z_s I + (1 - \beta_t) z_s 1_vec \hat{p}^T
        """
        return self.p_zt_given_zs_and_z1(zs, p_z1_hat, s, t, batch_mask)
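
    # Sketch of why the identity above holds: since z_s sums to one,
    # z_s 1_vec = 1, so the sum over all dim candidate final states collapses
    # to beta_t * z_s + (1 - beta_t) * p_hat. A single call with p_z1_hat in
    # place of z1 therefore replaces dim separate evaluations of
    # p_zt_given_zs_and_z1.
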
    def sample_zt_given_zs(self, zs, z1_logits, s, t, batch_mask):
        """ Sample z_t from p(z_t | z_s), marginalising over predicted final states. """
        p_z1 = z1_logits.softmax(dim=-1)
        zt = self.sample_categorical(self.p_zt_given_zs(zs, p_z1, s, t, batch_mask))
        return zt

    def compute_loss(self, pred_logits, zs, z1, batch_mask, s, t, reduce='mean'):
        """ Compute the per-node loss, optionally reduced to one value per graph. """
        assert reduce in {'mean', 'sum', 'none'}

        if self.loss_type == 'CE':
            # Cross-entropy between the predicted logits and the final state z1.
            loss = F.cross_entropy(pred_logits, z1, reduction='none')
        else:
            # VLB: KL divergence between the true transition distribution
            # p(z_t | z_s, z_1) and the one implied by the model's prediction.
            true_p_zt = self.p_zt_given_zs_and_z1(zs, z1, s, t, batch_mask)
            pred_p_zt = self.p_zt_given_zs(zs, pred_logits.softmax(dim=-1), s, t, batch_mask)
            loss = F.kl_div(pred_p_zt.log(), true_p_zt, reduction='none').sum(dim=-1)

        if reduce == 'mean':
            loss = scatter_mean(loss, batch_mask, dim=0)
        elif reduce == 'sum':
            loss = scatter_add(loss, batch_mask, dim=0)

        return loss

    def get_Qt(self, t, s, z1, batch_mask):
        """ Returns the one-step transition matrix from step s to step t. """
        # beta(t, t - s) = (1 - t) / (1 - t + (t - s)) = (1 - t) / (1 - s)
        beta_t_given_s = self.schedule.beta(t, t - s)
        # Assuming per-graph times of shape (batch_size, 1), this yields one
        # (1, 1)-shaped factor per node that broadcasts against the
        # (dim, dim) transition matrices below.
        beta_t_given_s = beta_t_given_s.unsqueeze(-1)[batch_mask]

        # Q_t = beta_t I + (1 - beta_t) 1_vec z1^T
        Qt = beta_t_given_s * torch.eye(self.dim, device=t.device).unsqueeze(0) + \
            (1 - beta_t_given_s) * z1.unsqueeze(1)

        return Qt

    def get_Qt_bar(self, t, z1, batch_mask):
        """ Returns the cumulative transition matrix from step 0 to step t. """
        beta_bar_t = self.schedule.beta_bar(t)
        # Same shape convention as in get_Qt: one (1, 1)-shaped factor per node.
        beta_bar_t = beta_bar_t.unsqueeze(-1)[batch_mask]

        # Qt_bar = beta_bar_t I + (1 - beta_bar_t) 1_vec z1^T
        Qt_bar = beta_bar_t * torch.eye(self.dim, device=t.device).unsqueeze(0) + \
            (1 - beta_bar_t) * z1.unsqueeze(1)

        return Qt_bar
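
    # Consistency note (follows from the schedule): beta_bar_t equals
    # beta_bar_s * beta_{t|s}, and (1_vec z1^T)(1_vec z1^T) = 1_vec z1^T for
    # one-hot z1, so the matrices compose as Qt_bar(t) = Qt_bar(s) @ Qt(t|s).

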
class MarginalPriorMarkovBridge(UniformPriorMarkovBridge):
    """ Markov bridge model in which z0 is drawn from the marginal category
    distribution 'prior_p' instead of a uniform prior. """

    def __init__(self, dim, prior_p, loss_type='CE', step_size=None):
        self.prior_p = prior_p
        print('Marginal Prior MB')
        super().__init__(dim, loss_type, step_size)

    def p_z0(self, batch_mask):
        """ Marginal prior probabilities, one row per node. """
        device = batch_mask.device
        p = torch.ones((len(batch_mask), self.dim), device=device) * self.prior_p.view(1, -1).to(device)
        return p
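

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the library API). It assumes
    # that `bvm` implements a batched vector-matrix product,
    # (n, d) x (n, d, d) -> (n, d), and that time arguments have shape
    # (batch_size, 1) so they broadcast correctly inside get_Qt/get_Qt_bar.
    torch.manual_seed(0)

    n_nodes, dim, h = 6, 4, 0.1
    bridge = UniformPriorMarkovBridge(dim=dim, loss_type='VLB', step_size=h)

    batch_mask = torch.tensor([0, 0, 0, 1, 1, 1])  # node -> graph index
    z1 = F.one_hot(torch.randint(dim, (n_nodes,)), dim).float()  # final states

    # Noising: sample an intermediate state between the prior and z1.
    t = torch.full((2, 1), 0.5)
    z0 = bridge.sample_z0(batch_mask)
    zt = bridge.sample_zt(z0, z1, t, batch_mask)

    # Per-graph VLB loss for an oracle "network" that predicts z1 directly.
    loss = bridge.compute_loss(10.0 * z1, zt, z1, batch_mask, t, t + h)
    print('per-graph loss:', loss.tolist())

    # Denoising loop: repeatedly sample p(z_t | z_s) using the oracle logits.
    zs = z0
    for step in range(10):
        s = torch.full((2, 1), step * h)
        zs = bridge.sample_zt_given_zs(zs, 10.0 * z1, s, s + h, batch_mask)
    print('recovered final states:', (zs.argmax(-1) == z1.argmax(-1)).tolist())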