|
|
import math |
|
|
import torch |
|
|
|
|
|
|
|
|
def _batch_trace(m): |
|
|
return torch.einsum('...ii', m) |
|
|
|
|
|
|
|
|
def regularize(point, eps=1e-6):
    """
    Map rotation vectors to an equivalent vector whose norm lies in [0, pi].

    The norm (rotation angle) is wrapped modulo 2 pi. If the wrapped angle
    exceeds pi, the vector is replaced by the equivalent rotation by
    (2 pi - angle) about the opposite axis.

    Args:
        point: rotation vectors, (..., 3)
        eps: norms below this are returned unchanged (they are far below pi and
            need no wrapping); it also guards the division by the norm.

    Returns:
        regularized point, (..., 3)
    """
    theta = torch.linalg.norm(point, dim=-1)

    theta_wrapped = theta % (2 * math.pi)
    # An angle in (pi, 2 pi) equals (angle - 2 pi) about the same axis, i.e. a
    # rotation by (2 pi - angle) about the flipped axis.
    theta_wrapped = torch.where(theta_wrapped > math.pi,
                                theta_wrapped - 2 * math.pi,
                                theta_wrapped)

    # Tiny vectors must keep a scale of exactly 1: dividing by the clamped norm
    # (eps) instead of the true norm would otherwise shrink them by theta / eps.
    scale = torch.where(theta < eps,
                        torch.ones_like(theta),
                        theta_wrapped / torch.clamp(theta, min=eps))
    point = point * scale.unsqueeze(-1)

    assert not point.isnan().any()
    return point
|
|
|
|
|
|
|
|
def random_uniform(n_samples, device=None):
    """
    Sample rotation vectors following the geomstats implementation:
    https://geomstats.github.io/_modules/geomstats/geometry/special_orthogonal.html

    Each coordinate is drawn uniformly from [-pi, pi) and the result is
    regularized so every norm lies in [0, pi].

    NOTE(review): uniform sampling in the rotation-vector cube is not the
    Haar (uniform) measure on SO(3); confirm this matches the intended use.

    Args:
        n_samples: int
        device: torch device for the samples (default device if None)

    Returns:
        rotation vectors, (n, 3)
    """
    unit_cube = torch.rand(n_samples, 3, device=device)
    return regularize((unit_cube * 2 - 1) * math.pi)
|
|
|
|
|
|
|
|
def hat(rot_vec):
    """
    Map R^3 vectors to skew-symmetric matrices (r in R^{3x3}, r^T = -r).

    The result is the cross-product matrix of ``rot_vec``:
        hat(rot_vec) @ v = rot_vec x v      for all v in R^3
        (equivalently hat(rot_vec)^T @ v = v x rot_vec)

    See also:
        https://en.wikipedia.org/wiki/Cross_product#Conversion_to_matrix_multiplication
        https://en.wikipedia.org/wiki/Hat_notation#Cross_product

    Args:
        rot_vec: (..., 3)

    Returns:
        skew-symmetric matrices (..., 3, 3), same dtype/device as ``rot_vec``
    """
    # basis[i] = hat(e_i); built in the input's dtype so einsum does not fail
    # on a dtype mismatch (e.g. float64 inputs with a float32 default dtype).
    basis = torch.tensor([
        [[0., 0., 0.], [0., 0., -1.], [0., 1., 0.]],
        [[0., 0., 1.], [0., 0., 0.], [-1., 0., 0.]],
        [[0., -1., 0.], [1., 0., 0.], [0., 0., 0.]],
    ], dtype=rot_vec.dtype, device=rot_vec.device)

    return torch.einsum('...i,ijk->...jk', rot_vec, basis)
|
|
|
|
|
|
|
|
def inv_hat(skew_mat):
    """
    Inverse of the hat operation: extract the vector from a skew-symmetric matrix.

    Args:
        skew_mat: skew-symmetric matrices, (..., 3, 3)

    Returns:
        rotation vectors, (..., 3)

    Raises:
        AssertionError: if the input is not skew-symmetric within atol=1e-4.
    """
    assert torch.allclose(-skew_mat, skew_mat.transpose(-2, -1), atol=1e-4), \
        f"Input not skew-symmetric (err={(-skew_mat - skew_mat.transpose(-2, -1)).abs().max():.4g})"

    # hat(w) = [[0, -w2, w1], [w2, 0, -w0], [-w1, w0, 0]]; read off the
    # below/above-diagonal entries. Ellipsis indexing supports any batch shape.
    return torch.stack([
        skew_mat[..., 2, 1],
        skew_mat[..., 0, 2],
        skew_mat[..., 1, 0],
    ], dim=-1)
|
|
|
|
|
|
|
|
def matrix_from_rotation_vector(axis_angle, eps=1e-6):
    """
    Convert rotation vectors to rotation matrices with Rodrigues' formula.

    After regularization, with theta = |axis_angle| and
    K = hat(axis_angle / max(theta, eps)):

        R = I + sin(theta) K + (1 - cos(theta)) K @ K

    Args:
        axis_angle: rotation vectors, (..., 3)
        eps: lower bound on theta when normalizing the axis. It avoids division
            by zero (a zero vector gives K = 0 and R = I). For theta < eps the
            vector is divided by eps rather than theta, so such tiny rotations
            are only approximated (error on the order of eps).

    Returns:
        rotation matrices, (..., 3, 3)
    """
    axis_angle = regularize(axis_angle)
    angle = axis_angle.norm(dim=-1)
    skew_mat = hat(axis_angle / torch.clamp(angle, min=eps).unsqueeze(-1))

    # (..., 1, 1) so the per-rotation coefficients broadcast over 3x3 matrices.
    sin_angle = torch.sin(angle)[..., None, None]
    one_minus_cos = (1 - torch.cos(angle))[..., None, None]
    eye = torch.eye(3, dtype=axis_angle.dtype, device=axis_angle.device)

    return eye + sin_angle * skew_mat + one_minus_cos * (skew_mat @ skew_mat)
|
|
|
|
|
|
|
|
class safe_acos(torch.autograd.Function):
    """
    arccos whose backward pass stays finite at x = +-1.

    The forward pass is a plain ``acos``. The backward pass evaluates
    d/dx acos(x) = -1 / sqrt(1 - x^2) at x clamped into
    [-1 + EPS, 1 - EPS], so the gradient saturates instead of becoming inf/NaN.

    https://github.com/pytorch/pytorch/issues/8069#issuecomment-2041223872
    """

    EPS = 1e-4

    @classmethod
    def d_acos_dx(cls, x):
        """Derivative of acos, evaluated at ``x`` clamped away from +-1 by EPS."""
        lower = -1. + cls.EPS
        upper = 1. - cls.EPS
        x_inside = x.clamp(min=lower, max=upper)
        return -1.0 / (1 - x_inside ** 2).sqrt()

    @staticmethod
    def forward(ctx, input):
        # Keep the raw input; the clamping only happens in backward.
        ctx.save_for_backward(input)
        return torch.acos(input)

    @staticmethod
    def backward(ctx, grad_output):
        saved_input = ctx.saved_tensors[0]
        local_grad = safe_acos.d_acos_dx(saved_input)
        return grad_output * local_grad
|
|
|
|
|
|
|
|
def _rotation_vector_near_pi(rot_mat, angle, cos_angle):
    """
    Recover rotation vectors whose angle is close to pi from the symmetric part of R.

    Near pi, R - R^T = 2 sin(angle) hat(axis) vanishes, so the generic formula
    is ill-conditioned. Instead use

        (R + R^T) / 2 - cos(angle) I = (1 - cos(angle)) axis axis^T,

    which stays well-conditioned there because 1 - cos(angle) is close to 2.

    Args:
        rot_mat: rotation matrices, (m, 3, 3)
        angle: rotation angles, (m,)
        cos_angle: cosine of ``angle``, (m,)

    Returns:
        rotation vectors, (m, 3)
    """
    eye = torch.eye(3, dtype=rot_mat.dtype, device=rot_mat.device)
    outer = 0.5 * (rot_mat + rot_mat.transpose(-2, -1)) - cos_angle[:, None, None] * eye

    # Row k of axis axis^T is axis_k * axis. Take the row with the largest
    # diagonal entry (axis_k^2 >= 1/3 for a unit axis) so the normalization is safe.
    k = torch.diagonal(outer, dim1=-2, dim2=-1).argmax(dim=-1)
    rows = outer[torch.arange(outer.shape[0], device=outer.device), k]
    axis = rows / rows.norm(dim=-1, keepdim=True)

    # The row fixes the axis only up to sign. Pick the sign that agrees with the
    # antisymmetric part (2 sin(angle) axis); at exactly pi both signs are the same rotation.
    antisym = torch.stack([
        rot_mat[:, 2, 1] - rot_mat[:, 1, 2],
        rot_mat[:, 0, 2] - rot_mat[:, 2, 0],
        rot_mat[:, 1, 0] - rot_mat[:, 0, 1],
    ], dim=-1)
    flip = (axis * antisym).sum(dim=-1) < 0
    axis = torch.where(flip[:, None], -axis, axis)

    return angle[:, None] * axis


def rotation_vector_from_matrix(rot_mat, approx=1e-4):
    """
    Convert rotation matrices to rotation vectors (inverse of Rodrigues' formula).

    For generic angles: axis_angle = angle / (2 sin(angle)) * vee(R - R^T), with
    angle = acos((trace(R) - 1) / 2).

    Args:
        rot_mat: (n, 3, 3)
        approx: float, numerical-stability threshold.
            - Angles below ``approx`` use the limit factor 1/2 in place of
              angle / (2 sin(angle)).
            - Angles within ``approx`` of pi recover the axis from the
              symmetric part of R, because R - R^T vanishes there.

    Returns:
        rotation vector, (n, 3)
    """
    # Antisymmetric part: R - R^T = 2 sin(angle) hat(axis).
    skew_mat = rot_mat - rot_mat.transpose(-2, -1)

    # trace(R) = 1 + 2 cos(angle)
    cos_angle = 0.5 * (_batch_trace(rot_mat) - 1)
    assert torch.all(cos_angle.abs() <= 1 + 1e-6)
    cos_angle = torch.clamp(cos_angle, min=-1., max=1.)

    # safe_acos keeps gradients finite at cos_angle = +-1.
    abs_angle = safe_acos.apply(cos_angle)

    close_to_0 = abs_angle < approx
    close_to_pi = (math.pi - abs_angle) < approx
    generic = ~(close_to_0 | close_to_pi)

    # angle / (2 sin(angle)) -> 1/2 as angle -> 0. Rows near pi also get 1/2
    # here, but their result is overwritten below.
    fac = torch.full_like(abs_angle, 0.5)
    fac[generic] = 0.5 * abs_angle[generic] / torch.sin(abs_angle[generic])
    axis_angle = inv_hat(fac[:, None, None] * skew_mat)

    if close_to_pi.any():
        axis_angle[close_to_pi] = _rotation_vector_near_pi(
            rot_mat[close_to_pi], abs_angle[close_to_pi], cos_angle[close_to_pi])

    return regularize(axis_angle)
|
|
|
|
|
|
|
|
def get_jacobian(point, left=True, inverse=False, eps=1e-4):
    """
    Jacobian of the group translation map at ``point`` (SO(3) rotation vectors).

    With theta = |point| and K = hat(point), this builds:

        inverse=False:  J = I - K / 2 + c(theta) K^2,
                        c(theta) = 1 / theta^2 - (1 + cos theta) / (2 theta sin theta)
        inverse=True:   J = I + a(theta) K + b(theta) K^2,
                        a(theta) = (1 - cos theta) / theta^2,
                        b(theta) = (theta - sin theta) / theta^3

    The two matrices are inverses of each other. With ``left=True`` the
    transpose is returned.

    Near the singular angles the coefficients are replaced by their limits:
    - c -> 1/12 as theta -> 0, and c -> 1/pi^2 as theta -> pi;
    - a -> 1/2 and b -> 1/6 as theta -> 0.

    Args:
        point: rotation vectors with norm <= pi, (n, 3)
        left: if True, return the transpose of the matrix above
        inverse: select which of the two mutually inverse matrices to build
        eps: threshold on theta (and on pi - theta) below which the limit
            values are used

    Returns:
        jacobians, (n, 3, 3)
    """
    angle_squared = point.square().sum(-1)
    angle = angle_squared.sqrt()
    skew_mat = hat(point)

    assert torch.all(angle <= math.pi)
    close_to_0 = angle < eps
    close_to_pi = (math.pi - angle) < eps

    # (n, 1, 1) so coefficients broadcast against (n, 3, 3) matrices.
    angle = angle[:, None, None]
    angle_squared = angle_squared[:, None, None]
    eye = torch.eye(3, device=point.device).unsqueeze(0)
    skew_sq = skew_mat @ skew_mat

    if inverse:
        _term1 = torch.empty_like(angle)
        _term2 = torch.empty_like(angle)
        _term1[close_to_0] = 0.5
        _term2[close_to_0] = 1 / 6

        # Right-hand sides must be indexed with the same mask as the left-hand
        # sides; assigning full (n, 1, 1) tensors fails as soon as some (but not
        # all) angles are close to 0.
        far = ~close_to_0
        far_angle = angle[far]
        _term1[far] = (1 - torch.cos(far_angle)) / angle_squared[far]
        _term2[far] = (far_angle - torch.sin(far_angle)) / far_angle ** 3

        jacobian = eye + _term1 * skew_mat + _term2 * skew_sq

    else:
        _term1 = torch.empty_like(angle)
        _term1[close_to_0] = 1 / 12
        _term1[close_to_pi] = 1 / math.pi**2

        default = ~close_to_0 & ~close_to_pi
        def_angle = angle[default]
        _term1[default] = 1 / angle_squared[default] - \
            (1 + torch.cos(def_angle)) / (2 * def_angle * torch.sin(def_angle))

        jacobian = eye - 0.5 * skew_mat + _term1 * skew_sq

    if left:
        jacobian = jacobian.transpose(-2, -1)

    return jacobian
|
|
|
|
|
|
|
|
def compose_rotations(rot_vec_1, rot_vec_2):
    """
    Compose two batches of rotations given as rotation vectors.

    The result is the rotation vector of the matrix product
    R(rot_vec_1) @ R(rot_vec_2).

    Args:
        rot_vec_1: (n, 3)
        rot_vec_2: (n, 3)

    Returns:
        rotation vectors of the composition, (n, 3)
    """
    matrices = [matrix_from_rotation_vector(v) for v in (rot_vec_1, rot_vec_2)]
    product = torch.bmm(*matrices)
    return rotation_vector_from_matrix(product)
|
|
|
|
|
|
|
|
def exp(tangent):
    """
    Exponential map at the identity.

    In the rotation-vector representation this only regularizes the tangent
    vector, so that its norm lies in [0, pi].

    Args:
        tangent: vector on the tangent space, (n, 3)

    Returns:
        rotation vector on the manifold, (n, 3)
    """
    return regularize(tangent)
|
|
|
|
|
|
|
|
def exp_not_from_identity(tangent_vec, base_point):
    """
    Exponential map at an arbitrary base point.

    Steps:
    1. Map the tangent vector to the tangent space at the identity, using the
       matrix from ``get_jacobian(base_point, left=True, inverse=True)``.
    2. Exponentiate it there.
    3. Compose the result on the left with ``base_point``.

    NOTE(review): ``tangent_vec`` is passed through ``regularize`` like a
    rotation vector before being mapped; for tangent vectors with norm > pi
    this changes the input. Confirm this is intended.

    Args:
        tangent_vec: vector on the tangent plane, (n, 3)
        base_point: base point on the manifold, (n, 3)

    Returns:
        new point on the manifold, (n, 3)
    """
    tangent_vec, base_point = regularize(tangent_vec), regularize(base_point)

    # Express the tangent vector in the tangent space at the identity.
    to_identity = get_jacobian(base_point, left=True, inverse=True)
    tangent_at_id = torch.einsum("...ij,...j->...i", to_identity, tangent_vec)

    return compose_rotations(base_point, exp(tangent_at_id))
|
|
|
|
|
|
|
|
def log(rot_vec, as_skew=False):
    """
    Logarithm map at the identity.

    In the rotation-vector representation the log at the identity is the
    input itself. No regularization is applied here, and the same tensor
    object is returned.

    Args:
        rot_vec: point on the manifold, (n, 3)
        as_skew: if True, return ``hat(rot_vec)`` instead, i.e. the
            skew-symmetric matrix form, (n, 3, 3)

    Returns:
        vector on the tangent space, (n, 3) (or (n, 3, 3) if ``as_skew``)
    """
    if not as_skew:
        return rot_vec
    return hat(rot_vec)
|
|
|
|
|
|
|
|
def log_not_from_identity(point, base_point):
    """
    Logarithm map of ``point`` at ``base_point``.

    Steps:
    1. Left-translate ``point`` by the inverse of ``base_point``.
    2. Take the log at the identity.
    3. Apply ``get_jacobian(base_point, inverse=False)`` to the result.

    Args:
        point: point on the manifold, (n, 3)
        base_point: base point on the manifold, (n, 3)

    Returns:
        vector on the tangent plane, (n, 3)
    """
    point = regularize(point)
    base_point = regularize(base_point)

    # Negating a rotation vector gives the inverse rotation, so this is
    # base_point^{-1} composed with point.
    relative = compose_rotations(-base_point, point)
    log_at_id = log(relative)

    jacobian = get_jacobian(base_point, inverse=False)
    return torch.einsum("...ij,...j->...i", jacobian, log_at_id)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Self-test: compare the functions above against geomstats' SO(3)
    # rotation-vector implementation.
    import functools
    import os

    # Select geomstats' PyTorch backend before importing geomstats.
    os.environ['GEOMSTATS_BACKEND'] = "pytorch"

    # NOTE(review): not used directly below; kept as in the original.
    # Confirm whether it is needed before removing.
    import scipy.optimize

    # Restore torch's default dtype in case importing geomstats changed it.
    default_dtype = torch.get_default_dtype()
    from geomstats.geometry.special_orthogonal import SpecialOrthogonal
    torch.set_default_dtype(default_dtype)

    so3_vector = SpecialOrthogonal(n=3, point_type="vector")

    # Wrap geomstats calls so the tensors they create land on the same
    # device as the test tensors.
    if torch.__version__ >= '2.0.0':
        GEOMSTATS_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

        def geomstats_tensor_type(func):
            """Run ``func`` with ``GEOMSTATS_DEVICE`` as torch's default device."""
            @functools.wraps(func)
            def inner(*args, **kwargs):
                with torch.device(GEOMSTATS_DEVICE):
                    return func(*args, **kwargs)

            return inner
    else:
        GEOMSTATS_TENSOR_TYPE = 'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor'

        def geomstats_tensor_type(func):
            """Run ``func`` with ``GEOMSTATS_TENSOR_TYPE`` as the default tensor type."""
            @functools.wraps(func)
            def inner(*args, **kwargs):
                torch.set_default_tensor_type(GEOMSTATS_TENSOR_TYPE)
                try:
                    return func(*args, **kwargs)
                finally:
                    # Restore even if func raises, so later code is unaffected.
                    torch.set_default_tensor_type('torch.FloatTensor')

            return inner

    @geomstats_tensor_type
    def gs_matrix_from_rotation_vector(*args, **kwargs):
        return so3_vector.matrix_from_rotation_vector(*args, **kwargs)

    @geomstats_tensor_type
    def gs_rotation_vector_from_matrix(*args, **kwargs):
        return so3_vector.rotation_vector_from_matrix(*args, **kwargs)

    @geomstats_tensor_type
    def gs_exp_not_from_identity(*args, **kwargs):
        return so3_vector.exp_not_from_identity(*args, **kwargs)

    @geomstats_tensor_type
    def gs_log_not_from_identity(*args, **kwargs):
        return so3_vector.log_not_from_identity(*args, **kwargs)

    @geomstats_tensor_type
    def compose(*args, **kwargs):
        return so3_vector.compose(*args, **kwargs)

    @geomstats_tensor_type
    def inverse(*args, **kwargs):
        return so3_vector.inverse(*args, **kwargs)

    @geomstats_tensor_type
    def gs_random_uniform(*args, **kwargs):
        return so3_vector.random_uniform(*args, **kwargs)

    n = 16
    device = 'cuda' if torch.cuda.is_available() else None

    # regularize: the output is collinear with the input and has norm in [0, pi).
    # dim=-1 is explicit: without it torch.cross picks the first size-3 dim.
    vec = (torch.rand(n, 3) * 4 - 2) * math.pi
    axis_angle = regularize(vec)
    assert torch.all(torch.cross(vec, axis_angle, dim=-1).norm(dim=-1) < 1e-5), "not all vectors collinear"
    assert torch.all(axis_angle.norm(dim=-1) < math.pi) & torch.all(axis_angle.norm(dim=-1) >= 0), "norm not between 0 and pi"

    # matrix_from_rotation_vector vs geomstats.
    rot_vec = random_uniform(n, device=device)
    assert torch.allclose(matrix_from_rotation_vector(rot_vec),
                          gs_matrix_from_rotation_vector(rot_vec), atol=1e-06)

    # rotation_vector_from_matrix vs geomstats.
    rot_vec = random_uniform(n, device=device)
    rot_mat = matrix_from_rotation_vector(rot_vec)
    assert torch.allclose(rotation_vector_from_matrix(rot_mat),
                          gs_rotation_vector_from_matrix(rot_mat), atol=1e-05)

    # exp_not_from_identity vs geomstats.
    tangent_vec = random_uniform(n, device=device)
    base_pt = random_uniform(n, device=device)
    my_val = exp_not_from_identity(tangent_vec, base_pt)
    gs_val = gs_exp_not_from_identity(tangent_vec, base_pt)
    assert torch.allclose(my_val, gs_val, atol=1e-03), (my_val - gs_val).abs().max()

    # log_not_from_identity vs geomstats.
    pt = random_uniform(n, device=device)
    base_pt = random_uniform(n, device=device)
    my_val = log_not_from_identity(pt, base_pt)
    gs_val = gs_log_not_from_identity(pt, base_pt)
    assert torch.allclose(my_val, gs_val, atol=1e-03), (my_val - gs_val).abs().max()

    print("All tests successful!")
|
|
|