from abc import ABC, abstractmethod
import math

import torch
from torch_scatter import scatter_mean, scatter_add

import src.data.so3_utils as so3
|
|
class ICFM(ABC):
    r"""
    Abstract base class for all independent-coupling CFM classes.
    Defines a common interface.

    Notation:
    - zt is the intermediate representation at time step t \in [0, 1]
    - zs is the (noisier) representation at an earlier time step s < t

    # TODO: add interpolation schedule (not necessarily linear)
    """

    def __init__(self, sigma):
        self.sigma = sigma

    @abstractmethod
    def sample_zt(self, z0, z1, t, *args, **kwargs):
        """ Sample the interpolant between z0 and z1 at time t. """
        pass

    @abstractmethod
    def sample_zt_given_zs(self, *args, **kwargs):
        """ Perform an update, typically using an explicit Euler step. """
        pass

    @abstractmethod
    def sample_z0(self, *args, **kwargs):
        """ Sample from the prior. """
        pass

    @abstractmethod
    def compute_loss(self, pred, z0, z1, *args, **kwargs):
        """ Compute the loss per sample. """
        pass
|
|
class CoordICFM(ICFM):
    """ ICFM for points in R^3 with a Gaussian prior centred at `com`. """

    def __init__(self, sigma):
        self.dim = 3
        self.scale = 2.7
        super().__init__(sigma)

    def sample_zt(self, z0, z1, t, batch_mask):
        """ Linear interpolation between z0 and z1. """
        zt = t[batch_mask] * z1 + (1 - t)[batch_mask] * z0
        return zt

    def sample_zt_given_zs(self, zs, pred, s, t, batch_mask):
        """ Perform an explicit Euler step. """
        step_size = t - s
        zt = zs + step_size[batch_mask] * self.scale * pred
        return zt

    def sample_z0(self, com, batch_mask):
        """ Prior: standard Gaussian noise shifted to `com`. """
        z0 = torch.randn((len(batch_mask), self.dim), device=batch_mask.device)
        z0 = z0 + com[batch_mask]
        return z0

    def reduce_loss(self, loss, batch_mask, reduce):
        assert reduce in {'mean', 'sum', 'none'}

        if reduce == 'mean':
            loss = scatter_mean(loss / self.dim, batch_mask, dim=0)
        elif reduce == 'sum':
            loss = scatter_add(loss, batch_mask, dim=0)

        return loss

    def compute_loss(self, pred, z0, z1, t, batch_mask, reduce='mean'):
        """ Compute the loss per sample. """
        loss = torch.sum((pred - (z1 - z0) / self.scale) ** 2, dim=-1)
        return self.reduce_loss(loss, batch_mask, reduce)
|
    def get_z1_given_zt_and_pred(self, zt, pred, z0, t, batch_mask):
        """ Make a best guess of the final state z1 given the current state and
        the network prediction. """
        # pred estimates (z1 - z0) / self.scale (see compute_loss), so the
        # remaining displacement is (1 - t) * self.scale * pred, matching the
        # scaling applied in sample_zt_given_zs.
        z1 = zt + (1 - t)[batch_mask] * self.scale * pred
        return z1
|
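# Sketch of a single training step with CoordICFM (illustrative only). The
# network interface ``model(zt, t, batch_mask)`` and the shape conventions --
# ``t`` of shape (n_graphs, 1), per-node tensors indexed by ``batch_mask`` --
# are assumptions, not part of this module.
def _example_coord_training_loss(model, cfm, z1, com, t, batch_mask):
    z0 = cfm.sample_z0(com, batch_mask)        # prior sample, same shape as z1
    zt = cfm.sample_zt(z0, z1, t, batch_mask)  # linear interpolant at time t
    pred = model(zt, t, batch_mask)            # assumed: predicted (scaled) velocity
    return cfm.compute_loss(pred, z0, z1, t, batch_mask, reduce='mean')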
|
class TorusICFM(ICFM):
    """
    ICFM on the flat torus. Following:
    Chen, Ricky TQ, and Yaron Lipman.
    "Riemannian flow matching on general geometries."
    arXiv preprint arXiv:2302.03660 (2023).
    """

    def __init__(self, sigma, dim, scheduler_args=None):
        super().__init__(sigma)
        self.dim = dim

        scheduler_args = scheduler_args or {}
        scheduler_args["type"] = scheduler_args.get("type", "linear")
        scheduler_args["learn_scaled"] = scheduler_args.get("learn_scaled", False)

        # flow_scaling is the interpolation schedule kappa(t);
        # velocity_scaling is its time derivative kappa'(t).
        if scheduler_args["type"] == "linear":
            self.flow_scaling = lambda t: t
            self.velocity_scaling = lambda t: torch.ones_like(t)

        elif scheduler_args["type"] == "exponential":
            self.c = scheduler_args["c"]
            assert self.c > 0
            self.flow_scaling = lambda t: 1 - torch.exp(-self.c * t)
            self.velocity_scaling = lambda t: self.c * torch.exp(-self.c * t)

        elif scheduler_args["type"] == "polynomial":
            self.k = scheduler_args["k"]
            assert self.k > 0
            self.flow_scaling = lambda t: 1 - (1 - t) ** self.k
            self.velocity_scaling = lambda t: self.k * (1 - t) ** (self.k - 1)

        else:
            raise NotImplementedError(f"Scheduler {scheduler_args['type']} not implemented.")

        kappa_interval = self.flow_scaling(torch.tensor([0.0, 1.0]))
        if kappa_interval[0] != 0.0 or kappa_interval[1] != 1.0:
            print(f"Scheduler should satisfy kappa(0)=0 and kappa(1)=1. Found "
                  f"interval {kappa_interval.tolist()} instead.")

        self.learn_scaled = scheduler_args["learn_scaled"]

    @staticmethod
    def wrap(angle):
        r""" Maps angles to the range [-\pi, \pi). """
        return ((angle + math.pi) % (2 * math.pi)) - math.pi

    def exponential_map(self, x, u):
        """
        :param x: point on the manifold
        :param u: point in the tangent space
        """
        return self.wrap(x + u)

    @staticmethod
    def logarithm_map(x, y):
        """
        :param x, y: points on the manifold
        """
        return torch.atan2(torch.sin(y - x), torch.cos(y - x))

    def sample_zt(self, z0, z1, t, batch_mask):
        """ Interpolation expressed in terms of exponential and logarithm maps. """
        zt_tangent = self.flow_scaling(t)[batch_mask] * self.logarithm_map(z0, z1)
        return self.exponential_map(z0, zt_tangent)

    def get_z1_given_zt_and_pred(self, zt, pred, z0, t, batch_mask):
        """ Make a best guess of the final state z1 given the current state and
        the network prediction. """
        if self.learn_scaled:
            pred = pred / torch.clamp(self.velocity_scaling(t), min=1e-3)[batch_mask]

        z1_tangent = (1 - t)[batch_mask] * pred
        return self.exponential_map(zt, z1_tangent)

    def sample_zt_given_zs(self, zs, pred, s, t, batch_mask):
        """ Perform an update, typically using an explicit Euler step. """
        step_size = t - s
        zt_tangent = step_size[batch_mask] * pred

        if not self.learn_scaled:
            zt_tangent = self.velocity_scaling(t)[batch_mask] * zt_tangent

        return self.exponential_map(zs, zt_tangent)

    def sample_z0(self, batch_mask):
        """ Prior: uniform on [-pi, pi). """
        z0 = torch.rand((len(batch_mask), self.dim), device=batch_mask.device)
        return 2 * math.pi * z0 - math.pi

    def compute_loss(self, pred, z0, z1, zt, t, batch_mask, reduce='mean'):
        """ Compute the loss per sample. NaN entries in z1 are masked out. """
        assert reduce in {'mean', 'sum', 'none'}
        mask = ~torch.isnan(z1)
        z1 = torch.nan_to_num(z1, nan=0.0)

        zt_dot = self.logarithm_map(z0, z1)
        if self.learn_scaled:
            zt_dot = self.velocity_scaling(t)[batch_mask] * zt_dot
        loss = mask * (pred - zt_dot) ** 2
        loss = torch.sum(loss, dim=-1)

        if reduce == 'mean':
            denom = mask.sum(dim=-1) + 1e-6
            loss = scatter_mean(loss / denom, batch_mask, dim=0)
        elif reduce == 'sum':
            loss = scatter_add(loss, batch_mask, dim=0)
        return loss
|
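# Small illustrative check (not used by the rest of the code): on the torus the
# interpolation follows the shortest arc, so the midpoint between angles close
# to +pi and -pi sits at the wrap boundary rather than at 0.
def _example_torus_shortest_arc():
    cfm = TorusICFM(sigma=0.0, dim=1)  # default linear scheduler
    z0 = torch.tensor([[3.0]])   # slightly below +pi
    z1 = torch.tensor([[-3.0]])  # slightly above -pi
    t = torch.tensor([[0.5]])
    batch_mask = torch.zeros(1, dtype=torch.long)
    zt = cfm.sample_zt(z0, z1, t, batch_mask)
    return zt  # magnitude approximately pi (the wrap boundary), not 0.0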
|
class SO3ICFM(ICFM):
    """
    ICFM on SO(3). All rotations are assumed to be in axis-angle format.
    Mostly following descriptions from the FoldFlow paper:
    https://openreview.net/forum?id=kJFIH23hXb

    See also:
    https://geomstats.github.io/_modules/geomstats/geometry/special_orthogonal.html#SpecialOrthogonal
    https://geomstats.github.io/_modules/geomstats/geometry/lie_group.html#LieGroup
    """

    def __init__(self, sigma):
        super().__init__(sigma)

    def exponential_map(self, base, tangent):
        """
        Args:
            base: base point (rotation vector) on the manifold
            tangent: point in the tangent space at the identity
        Returns:
            rotation vector on the manifold
        """
        return so3.compose_rotations(base, so3.exp(tangent))

    def logarithm_map(self, base, r):
        """
        Args:
            base: base point (rotation vector) on the manifold
            r: rotation vector on the manifold
        Returns:
            point in the tangent space at the identity
        """
        return so3.log(so3.compose_rotations(-base, r))

    def sample_zt(self, z0, z1, t, batch_mask):
        """
        Expressed in terms of exponential and logarithm maps.
        Corresponds to SLERP interpolation: R(t) = R1 exp( t * log(R1^T R2) )
        (see https://lucaballan.altervista.org/pdfs/IK.pdf, slide 16)
        """
        zt_tangent = t[batch_mask] * self.logarithm_map(z0, z1)
        return self.exponential_map(z0, zt_tangent)

    def get_z1_given_zt_and_pred(self, zt, pred, z0, t, batch_mask):
        """ Make a best guess of the final state z1 given the current state and
        the network prediction. """
        z1_tangent = (1 - t)[batch_mask] * pred
        return self.exponential_map(zt, z1_tangent)

    def sample_zt_given_zs(self, zs, pred, s, t, batch_mask):
        """ Perform an update, typically using an explicit Euler step. """
        step_size = t - s
        zt_tangent = step_size[batch_mask] * pred
        return self.exponential_map(zs, zt_tangent)

    def sample_z0(self, batch_mask):
        """ Prior: uniform distribution on SO(3). """
        return so3.random_uniform(n_samples=len(batch_mask), device=batch_mask.device)

    @staticmethod
    def d_R_squared_SO3(rot_vec_1, rot_vec_2):
        r"""
        Squared geodesic distance on SO(3),
        defined as d(R1, R2) = sqrt(0.5) * ||log(R1^T R2)||_F,
        where R1, R2 are rotation matrices.

        The following is equivalent if the difference between the rotations is
        expressed as a rotation vector \omega_diff:
            d(r1, r2) = ||\omega_diff||_2

        With the definition of the Frobenius matrix norm ||A||_F^2 = trace(A^H A):
            d^2(R1, R2) = 1/2 * ||log(R1^T R2)||_F^2
                        = 1/2 * ||hat(\omega_diff)||_F^2
                        = 1/2 * tr( hat(\omega_diff)^T hat(\omega_diff) )
                        = 1/2 * 2 * ||\omega_diff||_2^2
                        = ||\omega_diff||_2^2
        """
        diff_rot = so3.compose_rotations(-rot_vec_1, rot_vec_2)
        return diff_rot.square().sum(dim=-1)

    def compute_loss(self, pred, z0, z1, zt, t, batch_mask, reduce='mean', eps=5e-2):
        """ Compute the loss per sample. """
        assert reduce in {'mean', 'sum', 'none'}

        # Target vector field: remaining geodesic displacement per unit time,
        # with 1 - t clamped to avoid division by (near-)zero close to t = 1.
        zt_dot = self.logarithm_map(zt, z1) / torch.clamp(1 - t, min=eps)[batch_mask]

        loss = torch.sum((pred - zt_dot) ** 2, dim=-1)

        if reduce == 'mean':
            loss = scatter_mean(loss, batch_mask, dim=0)
        elif reduce == 'sum':
            loss = scatter_add(loss, batch_mask, dim=0)

        return loss
|
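# Small illustrative check (relies only on the behaviour of src.data.so3_utils
# already assumed above, i.e. that negating a rotation vector inverts the
# rotation): the squared geodesic distance between a rotation and itself
# should be numerically zero.
def _example_so3_distance_check(n_samples=4):
    batch_mask = torch.arange(n_samples)
    cfm = SO3ICFM(sigma=0.0)
    z0 = cfm.sample_z0(batch_mask)
    return SO3ICFM.d_R_squared_SO3(z0, z0)  # expected: ~0 for every sample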
|
class CoordICFMPredictFinal(CoordICFM):
    """ Variant of CoordICFM in which the network predicts z1 - zt instead of a
    velocity. """

    def __init__(self, sigma):
        self.dim = 3
        super().__init__(sigma)

    def sample_zt_given_zs(self, zs, z1_minus_zs_pred, s, t, batch_mask):
        """ Perform an explicit Euler step. Since the prediction points all the
        way to z1, only the fraction (t - s) / (1 - s) of it is applied. """
        step_size = (t - s) / (1.0 - s)
        assert torch.all(step_size <= 1.0)

        zt = zs + step_size[batch_mask] * z1_minus_zs_pred
        return zt

    def compute_loss(self, z1_minus_zt_pred, z0, z1, t, batch_mask, reduce='mean'):
        """ Compute the loss per sample. """
        assert reduce in {'mean', 'sum', 'none'}
        t = torch.clamp(t, max=0.9)
        zt = self.sample_zt(z0, z1, t, batch_mask)
        # Weight the squared error by 1 / (1 - t)^2; t is clamped above so the
        # weight stays bounded near t = 1.
        loss = torch.sum((z1_minus_zt_pred + zt - z1) ** 2, dim=-1) / torch.square(1 - t)[batch_mask].squeeze()

        if reduce == 'mean':
            loss = scatter_mean(loss / self.dim, batch_mask, dim=0)
        elif reduce == 'sum':
            loss = scatter_add(loss, batch_mask, dim=0)

        return loss

    def get_z1_given_zt_and_pred(self, zt, z1_minus_zt_pred, z0, t, batch_mask):
        return z1_minus_zt_pred + zt
|
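# Sampling sketch for the "predict final" parameterisation (illustrative only;
# the network interface ``model(zt, t, batch_mask)`` is an assumption). The
# model output is interpreted as z1 - zt, and each Euler step applies the
# fraction (t - s) / (1 - s) of it, as implemented in sample_zt_given_zs above.
def _example_predict_final_sampling(model, cfm, com, batch_mask, n_steps=100):
    n_graphs = int(batch_mask.max()) + 1
    zt = cfm.sample_z0(com, batch_mask)
    times = torch.linspace(0.0, 1.0, n_steps + 1, device=batch_mask.device)
    for s, t in zip(times[:-1], times[1:]):
        s_vec = torch.full((n_graphs, 1), s.item(), device=batch_mask.device)
        t_vec = torch.full((n_graphs, 1), t.item(), device=batch_mask.device)
        pred = model(zt, s_vec, batch_mask)  # assumed to estimate z1 - zt
        zt = cfm.sample_zt_given_zs(zt, pred, s_vec, t_vec, batch_mask)
    return zt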
|
class TorusICFMPredictFinal(TorusICFM):
    """
    Variant of TorusICFM in which the network predicts the tangent vector
    pointing from zt to z1. Following:
    Chen, Ricky TQ, and Yaron Lipman.
    "Riemannian flow matching on general geometries."
    arXiv preprint arXiv:2302.03660 (2023).
    """

    def __init__(self, sigma, dim):
        super().__init__(sigma, dim)

    def get_z1_given_zt_and_pred(self, zt, z1_tangent_pred, z0, t, batch_mask):
        """ Make a best guess of the final state z1 given the current state and
        the network prediction. """
        return self.exponential_map(zt, z1_tangent_pred)

    def sample_zt_given_zs(self, zs, z1_tangent_pred, s, t, batch_mask):
        """ Perform an update, typically using an explicit Euler step. Since
        the prediction points all the way to z1, only the fraction
        (t - s) / (1 - s) of it is applied. """
        step_size = (t - s) / (1.0 - s)
        assert torch.all(step_size <= 1.0)

        zt_tangent = step_size[batch_mask] * z1_tangent_pred
        return self.exponential_map(zs, zt_tangent)
|
    def compute_loss(self, z1_tangent_pred, z0, z1, t, batch_mask, reduce='mean'):
        """ Compute the loss per sample. NaN entries in z1 are masked out. """
        assert reduce in {'mean', 'sum', 'none'}
        # Replace NaN targets before interpolating so they cannot propagate
        # through sample_zt into the unmasked loss terms.
        mask = ~torch.isnan(z1)
        z1 = torch.nan_to_num(z1, nan=0.0)

        zt = self.sample_zt(z0, z1, t, batch_mask)
        t = torch.clamp(t, max=0.9)

        loss = mask * (z1_tangent_pred - self.logarithm_map(zt, z1)) ** 2
        # Weight the squared error by 1 / (1 - t)^2; t is clamped above so the
        # weight stays bounded near t = 1.
        loss = torch.sum(loss, dim=-1) / torch.square(1 - t)[batch_mask].squeeze()

        if reduce == 'mean':
            denom = mask.sum(dim=-1) + 1e-6
            loss = scatter_mean(loss / denom, batch_mask, dim=0)
        elif reduce == 'sum':
            loss = scatter_add(loss, batch_mask, dim=0)

        return loss
|
|
|