| import torch | |
| from liegroups.torch import utils | |
def test_isclose():
    """Check the element-wise mask returned by ``utils.isclose``.

    Five values are compared against 0 with tolerance ``tol``. The
    expected mask marks only 0 and ``0.1 * tol`` as close; 1, ``tol``
    itself and ``10 * tol`` are expected to be reported as not close.
    """
    tol = 1e-6
    values = torch.Tensor([0, 1, tol, 10 * tol, 0.1 * tol])
    expected_mask = torch.ByteTensor([1, 0, 0, 0, 1])

    close_mask = utils.isclose(values, 0., tol=tol)

    assert (close_mask == expected_mask).all()
def test_allclose():
    """Check ``utils.allclose`` on one passing and one failing tensor.

    The passing tensor holds only zeros and values well below ``tol``;
    the failing tensor mixes such values with entries equal to or
    larger than ``tol``. Both are compared against 0.
    """
    tol = 1e-6
    within_tol = torch.Tensor([0.1 * tol, 0.01 * tol, 0, 0, 0])
    outside_tol = torch.Tensor([0, 1, tol, 10 * tol, 0.1 * tol])

    good_result = utils.allclose(within_tol, 0., tol=tol)
    assert good_result

    bad_result = utils.allclose(outside_tol, 0., tol=tol)
    assert not bad_result
def test_outer():
    """Check ``utils.outer`` against explicit matrix products.

    A single pair of 3-vectors is compared with the ``torch.mm`` product
    of a column and a row; a batch of two pairs is compared with the
    equivalent ``torch.bmm`` product.
    """
    # Single pair: (3, 1) x (1, 3) reference product.
    left = torch.Tensor([1, 2, 3])
    right = torch.Tensor([0, 1, 2])
    expected_single = torch.mm(left.unsqueeze(dim=1), right.unsqueeze(dim=0))
    assert (utils.outer(left, right) == expected_single).all()

    # Batch of two pairs: (2, 3, 1) x (2, 1, 3) reference product.
    left_batch = torch.Tensor([[1, 2, 3], [4, 5, 6]])
    right_batch = torch.Tensor([[0, 1, 2], [3, 4, 5]])
    expected_batch = torch.bmm(
        left_batch.unsqueeze(dim=2), right_batch.unsqueeze(dim=1))
    assert (utils.outer(left_batch, right_batch) == expected_batch).all()
def test_trace():
    """Check ``utils.trace`` against ``torch.trace``.

    First a single 3x3 matrix is used, where element 0 of the result is
    compared with ``torch.trace``. Then a batch of two 3x3 matrices is
    used, where the result must have length 2 and each entry must match
    ``torch.trace`` of the corresponding matrix.
    """
    # Single matrix.
    single = torch.arange(1, 10).view(3, 3)
    assert utils.trace(single)[0] == torch.trace(single)

    # Batch of two matrices stacked along dim 0.
    first = torch.arange(1, 10).view(1, 3, 3)
    second = torch.arange(11, 20).view(1, 3, 3)
    batch = torch.cat([first, second], dim=0)

    traces = utils.trace(batch)
    assert len(traces) == 2
    assert traces[0] == torch.trace(batch[0])
    assert traces[1] == torch.trace(batch[1])