"""
Utilities for geometry operations.
References: DUSt3R, MoGe
"""
from numbers import Number
from typing import Tuple, Union
import einops as ein
import numpy as np
import torch
import torch.nn.functional as F
from mapanything.utils.misc import invalid_to_zeros
from mapanything.utils.warnings import no_warnings
def depthmap_to_camera_frame(depthmap, intrinsics):
"""
Convert depth image to a pointcloud in camera frame.
Args:
- depthmap: HxW or BxHxW torch tensor
- intrinsics: 3x3 or Bx3x3 torch tensor
Returns:
pointmap in camera frame (HxWx3 or BxHxWx3 tensor), and a mask specifying valid pixels.
"""
# Add batch dimension if not present
if depthmap.dim() == 2:
depthmap = depthmap.unsqueeze(0)
intrinsics = intrinsics.unsqueeze(0)
squeeze_batch_dim = True
else:
squeeze_batch_dim = False
batch_size, height, width = depthmap.shape
device = depthmap.device
# Compute 3D point in camera frame associated with each pixel
x_grid, y_grid = torch.meshgrid(
torch.arange(width, device=device).float(),
torch.arange(height, device=device).float(),
indexing="xy",
)
x_grid = x_grid.unsqueeze(0).expand(batch_size, -1, -1)
y_grid = y_grid.unsqueeze(0).expand(batch_size, -1, -1)
fx = intrinsics[:, 0, 0].view(-1, 1, 1)
fy = intrinsics[:, 1, 1].view(-1, 1, 1)
cx = intrinsics[:, 0, 2].view(-1, 1, 1)
cy = intrinsics[:, 1, 2].view(-1, 1, 1)
depth_z = depthmap
xx = (x_grid - cx) * depth_z / fx
yy = (y_grid - cy) * depth_z / fy
pts3d_cam = torch.stack((xx, yy, depth_z), dim=-1)
# Compute mask of valid non-zero depth pixels
valid_mask = depthmap > 0.0
# Remove batch dimension if it was added
if squeeze_batch_dim:
pts3d_cam = pts3d_cam.squeeze(0)
valid_mask = valid_mask.squeeze(0)
return pts3d_cam, valid_mask
def depthmap_to_world_frame(depthmap, intrinsics, camera_pose=None):
"""
Convert depth image to a pointcloud in world frame.
Args:
- depthmap: HxW or BxHxW torch tensor
- intrinsics: 3x3 or Bx3x3 torch tensor
- camera_pose: 4x4 or Bx4x4 torch tensor
Returns:
pointmap in world frame (HxWx3 or BxHxWx3 tensor), and a mask specifying valid pixels.
"""
pts3d_cam, valid_mask = depthmap_to_camera_frame(depthmap, intrinsics)
if camera_pose is not None:
# Add batch dimension if not present
if camera_pose.dim() == 2:
camera_pose = camera_pose.unsqueeze(0)
pts3d_cam = pts3d_cam.unsqueeze(0)
squeeze_batch_dim = True
else:
squeeze_batch_dim = False
# Convert points from camera frame to world frame
pts3d_cam_homo = torch.cat(
[pts3d_cam, torch.ones_like(pts3d_cam[..., :1])], dim=-1
)
pts3d_world = ein.einsum(
camera_pose, pts3d_cam_homo, "b i k, b h w k -> b h w i"
)
pts3d_world = pts3d_world[..., :3]
# Remove batch dimension if it was added
if squeeze_batch_dim:
pts3d_world = pts3d_world.squeeze(0)
else:
pts3d_world = pts3d_cam
return pts3d_world, valid_mask
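# Example (illustrative sketch): unproject a toy depth batch to camera and world frames.
# The shapes, intrinsics, and identity pose below are arbitrary example values.
def _example_depthmap_unprojection():
    depthmap = torch.rand(2, 48, 64)  # BxHxW z-depth
    intrinsics = torch.tensor(
        [[50.0, 0.0, 32.0], [0.0, 50.0, 24.0], [0.0, 0.0, 1.0]]
    ).unsqueeze(0).repeat(2, 1, 1)  # Bx3x3 pinhole intrinsics
    camera_pose = torch.eye(4).unsqueeze(0).repeat(2, 1, 1)  # Bx4x4 cam2world
    pts_cam, valid_mask = depthmap_to_camera_frame(depthmap, intrinsics)
    pts_world, _ = depthmap_to_world_frame(depthmap, intrinsics, camera_pose)
    # With an identity pose, the world-frame pointmap equals the camera-frame pointmap
    assert pts_world.shape == (2, 48, 64, 3) and valid_mask.shape == (2, 48, 64)
    assert torch.allclose(pts_world, pts_cam)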
def transform_pts3d(pts3d, transformation):
"""
Transform 3D points using a 4x4 transformation matrix.
Args:
- pts3d: HxWx3 or BxHxWx3 torch tensor
- transformation: 4x4 or Bx4x4 torch tensor
Returns:
transformed points (HxWx3 or BxHxWx3 tensor)
"""
# Add batch dimension if not present
if pts3d.dim() == 3:
pts3d = pts3d.unsqueeze(0)
transformation = transformation.unsqueeze(0)
squeeze_batch_dim = True
else:
squeeze_batch_dim = False
# Convert points to homogeneous coordinates
pts3d_homo = torch.cat([pts3d, torch.ones_like(pts3d[..., :1])], dim=-1)
# Transform points
transformed_pts3d = ein.einsum(
transformation, pts3d_homo, "b i k, b h w k -> b h w i"
)
transformed_pts3d = transformed_pts3d[..., :3]
# Remove batch dimension if it was added
if squeeze_batch_dim:
transformed_pts3d = transformed_pts3d.squeeze(0)
return transformed_pts3d
def project_pts3d_to_image(pts3d, intrinsics, return_z_dim):
"""
Project 3D points to image plane (assumes pinhole camera model with no distortion).
Args:
- pts3d: HxWx3 or BxHxWx3 torch tensor
- intrinsics: 3x3 or Bx3x3 torch tensor
- return_z_dim: bool, whether to return the third dimension of the projected points
Returns:
projected points (HxWx2 or BxHxWx2 tensor; HxWx3 or BxHxWx3 if return_z_dim is True)
"""
if pts3d.dim() == 3:
pts3d = pts3d.unsqueeze(0)
intrinsics = intrinsics.unsqueeze(0)
squeeze_batch_dim = True
else:
squeeze_batch_dim = False
# Project points to image plane
projected_pts2d = ein.einsum(intrinsics, pts3d, "b i k, b h w k -> b h w i")
projected_pts2d[..., :2] /= projected_pts2d[..., 2].unsqueeze(-1).clamp(min=1e-6)
# Remove the z dimension if not required
if not return_z_dim:
projected_pts2d = projected_pts2d[..., :2]
# Remove batch dimension if it was added
if squeeze_batch_dim:
projected_pts2d = projected_pts2d.squeeze(0)
return projected_pts2d
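# Example (illustrative sketch): round trip for a toy camera - unprojecting a constant depth map
# and reprojecting the resulting pointmap should recover the original pixel grid.
def _example_project_unproject_round_trip():
    depthmap = torch.full((32, 40), 2.0)  # HxW constant z-depth
    intrinsics = torch.tensor([[40.0, 0.0, 20.0], [0.0, 40.0, 16.0], [0.0, 0.0, 1.0]])
    pts_cam, _ = depthmap_to_camera_frame(depthmap, intrinsics)
    pix = project_pts3d_to_image(pts_cam, intrinsics, return_z_dim=False)  # HxWx2
    x_grid, y_grid = torch.meshgrid(
        torch.arange(40).float(), torch.arange(32).float(), indexing="xy"
    )
    assert torch.allclose(pix[..., 0], x_grid, atol=1e-4)
    assert torch.allclose(pix[..., 1], y_grid, atol=1e-4)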
def get_rays_in_camera_frame(intrinsics, height, width, normalize_to_unit_sphere):
"""
Convert camera intrinsics to a raymap (ray origins + directions) in camera frame.
Note: Currently only supports pinhole camera model.
Args:
- intrinsics: 3x3 or Bx3x3 torch tensor
- height: int
- width: int
- normalize_to_unit_sphere: bool
Returns:
- ray_origins: (HxWx3 or BxHxWx3) tensor
- ray_directions: (HxWx3 or BxHxWx3) tensor
"""
# Add batch dimension if not present
if intrinsics.dim() == 2:
intrinsics = intrinsics.unsqueeze(0)
squeeze_batch_dim = True
else:
squeeze_batch_dim = False
batch_size = intrinsics.shape[0]
device = intrinsics.device
# Compute rays in camera frame associated with each pixel
x_grid, y_grid = torch.meshgrid(
torch.arange(width, device=device).float(),
torch.arange(height, device=device).float(),
indexing="xy",
)
x_grid = x_grid.unsqueeze(0).expand(batch_size, -1, -1)
y_grid = y_grid.unsqueeze(0).expand(batch_size, -1, -1)
fx = intrinsics[:, 0, 0].view(-1, 1, 1)
fy = intrinsics[:, 1, 1].view(-1, 1, 1)
cx = intrinsics[:, 0, 2].view(-1, 1, 1)
cy = intrinsics[:, 1, 2].view(-1, 1, 1)
ray_origins = torch.zeros((batch_size, height, width, 3), device=device)
xx = (x_grid - cx) / fx
yy = (y_grid - cy) / fy
ray_directions = torch.stack((xx, yy, torch.ones_like(xx)), dim=-1)
# Normalize ray directions to unit sphere if required (else rays will lie on unit plane)
if normalize_to_unit_sphere:
ray_directions = ray_directions / torch.norm(
ray_directions, dim=-1, keepdim=True
)
# Remove batch dimension if it was added
if squeeze_batch_dim:
ray_origins = ray_origins.squeeze(0)
ray_directions = ray_directions.squeeze(0)
return ray_origins, ray_directions
def get_rays_in_world_frame(
intrinsics, height, width, normalize_to_unit_sphere, camera_pose=None
):
"""
Convert camera intrinsics to a raymap (ray origins + directions). The rays are returned in the world frame if camera_pose is provided, otherwise in the camera frame.
Note: Currently only supports pinhole camera model.
Args:
- intrinsics: 3x3 or Bx3x3 torch tensor
- height: int
- width: int
- normalize_to_unit_sphere: bool
- camera_pose: 4x4 or Bx4x4 torch tensor
Returns:
- ray_origins: (HxWx3 or BxHxWx3) tensor
- ray_directions: (HxWx3 or BxHxWx3) tensor
"""
# Get rays in camera frame
ray_origins, ray_directions = get_rays_in_camera_frame(
intrinsics, height, width, normalize_to_unit_sphere
)
if camera_pose is not None:
# Add batch dimension if not present
if camera_pose.dim() == 2:
camera_pose = camera_pose.unsqueeze(0)
ray_origins = ray_origins.unsqueeze(0)
ray_directions = ray_directions.unsqueeze(0)
squeeze_batch_dim = True
else:
squeeze_batch_dim = False
# Convert rays from camera frame to world frame
ray_origins_homo = torch.cat(
[ray_origins, torch.ones_like(ray_origins[..., :1])], dim=-1
)
ray_directions_homo = torch.cat(
[ray_directions, torch.zeros_like(ray_directions[..., :1])], dim=-1
)
ray_origins_world = ein.einsum(
camera_pose, ray_origins_homo, "b i k, b h w k -> b h w i"
)
ray_directions_world = ein.einsum(
camera_pose, ray_directions_homo, "b i k, b h w k -> b h w i"
)
ray_origins_world = ray_origins_world[..., :3]
ray_directions_world = ray_directions_world[..., :3]
# Remove batch dimension if it was added
if squeeze_batch_dim:
ray_origins_world = ray_origins_world.squeeze(0)
ray_directions_world = ray_directions_world.squeeze(0)
else:
ray_origins_world = ray_origins
ray_directions_world = ray_directions
return ray_origins_world, ray_directions_world
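# Example (illustrative sketch): build a world-frame raymap for a toy camera and reconstruct a
# pointmap as origin + direction * depth-along-ray. The intrinsics/pose values are arbitrary.
def _example_raymap_reconstruction():
    intrinsics = torch.tensor([[30.0, 0.0, 16.0], [0.0, 30.0, 12.0], [0.0, 0.0, 1.0]])
    camera_pose = torch.eye(4)
    camera_pose[:3, 3] = torch.tensor([0.5, -0.2, 1.0])  # translated camera
    ray_origins, ray_dirs = get_rays_in_world_frame(
        intrinsics, height=24, width=32, normalize_to_unit_sphere=True, camera_pose=camera_pose
    )
    depth_along_ray = torch.full((24, 32, 1), 3.0)
    pts_world = ray_origins + ray_dirs * depth_along_ray
    assert pts_world.shape == (24, 32, 3)
    # All ray origins coincide with the camera center in world frame
    assert torch.allclose(ray_origins, camera_pose[:3, 3].expand(24, 32, 3))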
def recover_pinhole_intrinsics_from_ray_directions(
ray_directions, use_geometric_calculation=False
):
"""
Recover pinhole camera intrinsics from ray directions, supporting both batched and non-batched inputs.
Args:
ray_directions: Tensor of shape [H, W, 3] or [B, H, W, 3] containing unit normalized ray directions
use_geometric_calculation: If True, use the direct geometric calculation even for standard-resolution inputs
Returns:
Pinhole intrinsics as a 3x3 or Bx3x3 tensor (fx, fy, cx, cy filled in, zero skew)
"""
# Add batch dimension if not present
if ray_directions.dim() == 3: # [H, W, 3]
ray_directions = ray_directions.unsqueeze(0) # [1, H, W, 3]
squeeze_batch_dim = True
else:
squeeze_batch_dim = False
batch_size, height, width, _ = ray_directions.shape
device = ray_directions.device
# Create pixel coordinate grid
x_grid, y_grid = torch.meshgrid(
torch.arange(width, device=device).float(),
torch.arange(height, device=device).float(),
indexing="xy",
)
# Expand grid for all batches
x_grid = x_grid.unsqueeze(0).expand(batch_size, -1, -1) # [B, H, W]
y_grid = y_grid.unsqueeze(0).expand(batch_size, -1, -1) # [B, H, W]
# Determine if high resolution or not
is_high_res = height * width > 1000000
if is_high_res or use_geometric_calculation:
# For high-resolution cases, use direct geometric calculation
# Define key points
center_h, center_w = height // 2, width // 2
quarter_w, three_quarter_w = width // 4, 3 * width // 4
quarter_h, three_quarter_h = height // 4, 3 * height // 4
# Get rays at key points
center_rays = ray_directions[:, center_h, center_w, :].clone() # [B, 3]
left_rays = ray_directions[:, center_h, quarter_w, :].clone() # [B, 3]
right_rays = ray_directions[:, center_h, three_quarter_w, :].clone() # [B, 3]
top_rays = ray_directions[:, quarter_h, center_w, :].clone() # [B, 3]
bottom_rays = ray_directions[:, three_quarter_h, center_w, :].clone() # [B, 3]
# Normalize rays to have dz = 1
center_rays = center_rays / center_rays[:, 2].unsqueeze(1) # [B, 3]
left_rays = left_rays / left_rays[:, 2].unsqueeze(1) # [B, 3]
right_rays = right_rays / right_rays[:, 2].unsqueeze(1) # [B, 3]
top_rays = top_rays / top_rays[:, 2].unsqueeze(1) # [B, 3]
bottom_rays = bottom_rays / bottom_rays[:, 2].unsqueeze(1) # [B, 3]
# Calculate fx directly (vectorized across batch)
fx_left = (quarter_w - center_w) / (left_rays[:, 0] - center_rays[:, 0])
fx_right = (three_quarter_w - center_w) / (right_rays[:, 0] - center_rays[:, 0])
fx = (fx_left + fx_right) / 2 # Average for robustness
# Calculate cx
cx = center_w - fx * center_rays[:, 0]
# Calculate fy and cy
fy_top = (quarter_h - center_h) / (top_rays[:, 1] - center_rays[:, 1])
fy_bottom = (three_quarter_h - center_h) / (
bottom_rays[:, 1] - center_rays[:, 1]
)
fy = (fy_top + fy_bottom) / 2
cy = center_h - fy * center_rays[:, 1]
else:
# For standard resolution, use regression with sampling for efficiency
# Sample a grid of points (but more dense than for high-res)
step_h = max(1, height // 50)
step_w = max(1, width // 50)
h_indices = torch.arange(0, height, step_h, device=device)
w_indices = torch.arange(0, width, step_w, device=device)
# Extract subset of coordinates
x_sampled = x_grid[:, h_indices[:, None], w_indices[None, :]] # [B, H', W']
y_sampled = y_grid[:, h_indices[:, None], w_indices[None, :]] # [B, H', W']
rays_sampled = ray_directions[
:, h_indices[:, None], w_indices[None, :], :
] # [B, H', W', 3]
# Reshape for linear regression
x_flat = x_sampled.reshape(batch_size, -1) # [B, N]
y_flat = y_sampled.reshape(batch_size, -1) # [B, N]
# Extract ray direction components
dx = rays_sampled[..., 0].reshape(batch_size, -1) # [B, N]
dy = rays_sampled[..., 1].reshape(batch_size, -1) # [B, N]
dz = rays_sampled[..., 2].reshape(batch_size, -1) # [B, N]
# Compute ratios for linear regression
ratio_x = dx / dz # [B, N]
ratio_y = dy / dz # [B, N]
# Solve the per-batch least-squares fits using the normal equations
# For x-direction: x = cx + fx * (dx/dz)
# We can solve this using normal equations: A^T A x = A^T b
# Create design matrices
ones = torch.ones_like(x_flat) # [B, N]
A_x = torch.stack([ones, ratio_x], dim=2) # [B, N, 2]
b_x = x_flat.unsqueeze(2) # [B, N, 1]
# Compute A^T A and A^T b for each batch
ATA_x = torch.bmm(A_x.transpose(1, 2), A_x) # [B, 2, 2]
ATb_x = torch.bmm(A_x.transpose(1, 2), b_x) # [B, 2, 1]
# Solve the system for each batch
solution_x = torch.linalg.solve(ATA_x, ATb_x).squeeze(2) # [B, 2]
cx, fx = solution_x[:, 0], solution_x[:, 1]
# Repeat for y-direction
A_y = torch.stack([ones, ratio_y], dim=2) # [B, N, 2]
b_y = y_flat.unsqueeze(2) # [B, N, 1]
ATA_y = torch.bmm(A_y.transpose(1, 2), A_y) # [B, 2, 2]
ATb_y = torch.bmm(A_y.transpose(1, 2), b_y) # [B, 2, 1]
solution_y = torch.linalg.solve(ATA_y, ATb_y).squeeze(2) # [B, 2]
cy, fy = solution_y[:, 0], solution_y[:, 1]
# Create intrinsics matrices
batch_size = fx.shape[0]
intrinsics = torch.zeros(batch_size, 3, 3, device=ray_directions.device)
# Fill in the intrinsics matrices
intrinsics[:, 0, 0] = fx # focal length x
intrinsics[:, 1, 1] = fy # focal length y
intrinsics[:, 0, 2] = cx # principal point x
intrinsics[:, 1, 2] = cy # principal point y
intrinsics[:, 2, 2] = 1.0 # bottom-right element is always 1
# Remove batch dimension if it was added
if squeeze_batch_dim:
intrinsics = intrinsics.squeeze(0)
return intrinsics
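# Example (illustrative sketch): rays generated from known pinhole intrinsics should let
# recover_pinhole_intrinsics_from_ray_directions reproduce (approximately) the same intrinsics.
def _example_recover_intrinsics_round_trip():
    intrinsics = torch.tensor([[100.0, 0.0, 64.0], [0.0, 90.0, 48.0], [0.0, 0.0, 1.0]])
    _, ray_dirs = get_rays_in_camera_frame(
        intrinsics, height=96, width=128, normalize_to_unit_sphere=True
    )
    recovered = recover_pinhole_intrinsics_from_ray_directions(ray_dirs)
    assert torch.allclose(recovered, intrinsics, atol=1e-2)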
def transform_rays(ray_origins, ray_directions, transformation):
"""
Transform 6D rays (ray origins and ray directions) using a 4x4 transformation matrix.
Args:
- ray_origins: HxWx3 or BxHxWx3 torch tensor
- ray_directions: HxWx3 or BxHxWx3 torch tensor
- transformation: 4x4 or Bx4x4 torch tensor
Returns:
transformed ray_origins (HxWx3 or BxHxWx3 tensor) and ray_directions (HxWx3 or BxHxWx3 tensor)
"""
# Add batch dimension if not present
if ray_origins.dim() == 3:
ray_origins = ray_origins.unsqueeze(0)
ray_directions = ray_directions.unsqueeze(0)
transformation = transformation.unsqueeze(0)
squeeze_batch_dim = True
else:
squeeze_batch_dim = False
# Transform ray origins and directions
ray_origins_homo = torch.cat(
[ray_origins, torch.ones_like(ray_origins[..., :1])], dim=-1
)
ray_directions_homo = torch.cat(
[ray_directions, torch.zeros_like(ray_directions[..., :1])], dim=-1
)
transformed_ray_origins = ein.einsum(
transformation, ray_origins_homo, "b i k, b h w k -> b h w i"
)
transformed_ray_directions = ein.einsum(
transformation, ray_directions_homo, "b i k, b h w k -> b h w i"
)
transformed_ray_origins = transformed_ray_origins[..., :3]
transformed_ray_directions = transformed_ray_directions[..., :3]
# Remove batch dimension if it was added
if squeeze_batch_dim:
transformed_ray_origins = transformed_ray_origins.squeeze(0)
transformed_ray_directions = transformed_ray_directions.squeeze(0)
return transformed_ray_origins, transformed_ray_directions
def convert_z_depth_to_depth_along_ray(z_depth, intrinsics):
"""
Convert z-depth image to depth along camera rays.
Args:
- z_depth: HxW or BxHxW torch tensor
- intrinsics: 3x3 or Bx3x3 torch tensor
Returns:
- depth_along_ray: HxW or BxHxW torch tensor
"""
# Add batch dimension if not present
if z_depth.dim() == 2:
z_depth = z_depth.unsqueeze(0)
intrinsics = intrinsics.unsqueeze(0)
squeeze_batch_dim = True
else:
squeeze_batch_dim = False
# Get rays in camera frame
batch_size, height, width = z_depth.shape
_, ray_directions = get_rays_in_camera_frame(
intrinsics, height, width, normalize_to_unit_sphere=False
)
# Compute depth along ray
pts3d_cam = z_depth[..., None] * ray_directions
depth_along_ray = torch.norm(pts3d_cam, dim=-1)
# Remove batch dimension if it was added
if squeeze_batch_dim:
depth_along_ray = depth_along_ray.squeeze(0)
return depth_along_ray
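# Example (illustrative sketch): for the principal-point pixel the z-depth and the depth along
# the ray coincide, while off-centre rays have depth_along_ray >= z_depth.
def _example_z_depth_vs_depth_along_ray():
    intrinsics = torch.tensor([[50.0, 0.0, 32.0], [0.0, 50.0, 24.0], [0.0, 0.0, 1.0]])
    z_depth = torch.full((48, 64), 2.0)
    depth_along_ray = convert_z_depth_to_depth_along_ray(z_depth, intrinsics)
    assert torch.all(depth_along_ray >= z_depth - 1e-6)
    assert torch.isclose(depth_along_ray[24, 32], torch.tensor(2.0), atol=1e-6)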
def convert_raymap_z_depth_quats_to_pointmap(ray_origins, ray_directions, depth, quats):
"""
Convert raymap (ray origins + directions on unit plane), z-depth and
unit quaternions (representing rotation) to a pointmap in world frame.
Args:
- ray_origins: (HxWx3 or BxHxWx3) torch tensor
- ray_directions: (HxWx3 or BxHxWx3) torch tensor
- depth: (HxWx1 or BxHxWx1) torch tensor
- quats: (HxWx4 or BxHxWx4) torch tensor (unit quaternions and notation is (x, y, z, w))
Returns:
- pointmap: (HxWx3 or BxHxWx3) torch tensor
"""
# Add batch dimension if not present
if ray_origins.dim() == 3:
ray_origins = ray_origins.unsqueeze(0)
ray_directions = ray_directions.unsqueeze(0)
depth = depth.unsqueeze(0)
quats = quats.unsqueeze(0)
squeeze_batch_dim = True
else:
squeeze_batch_dim = False
batch_size, height, width, _ = depth.shape
device = depth.device
# Normalize the quaternions to ensure they are unit quaternions
quats = quats / torch.norm(quats, dim=-1, keepdim=True)
# Convert quaternions to pixel-wise rotation matrices
qx, qy, qz, qw = quats[..., 0], quats[..., 1], quats[..., 2], quats[..., 3]
rot_mat = (
torch.stack(
[
qw**2 + qx**2 - qy**2 - qz**2,
2 * (qx * qy - qw * qz),
2 * (qw * qy + qx * qz),
2 * (qw * qz + qx * qy),
qw**2 - qx**2 + qy**2 - qz**2,
2 * (qy * qz - qw * qx),
2 * (qx * qz - qw * qy),
2 * (qw * qx + qy * qz),
qw**2 - qx**2 - qy**2 + qz**2,
],
dim=-1,
)
.reshape(batch_size, height, width, 3, 3)
.to(device)
)
# Compute 3D points in local camera frame
pts3d_local = depth * ray_directions
# Rotate the local points using the quaternions
rotated_pts3d_local = ein.einsum(
rot_mat, pts3d_local, "b h w i k, b h w k -> b h w i"
)
# Compute 3D point in world frame associated with each pixel
pts3d = ray_origins + rotated_pts3d_local
# Remove batch dimension if it was added
if squeeze_batch_dim:
pts3d = pts3d.squeeze(0)
return pts3d
def quaternion_to_rotation_matrix(quat):
"""
Convert a quaternion into a 3x3 rotation matrix.
Args:
- quat: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
Returns:
- rot_matrix: 3x3 or Bx3x3 torch tensor
"""
if quat.dim() == 1:
quat = quat.unsqueeze(0)
squeeze_batch_dim = True
else:
squeeze_batch_dim = False
# Ensure the quaternion is normalized
quat = quat / quat.norm(dim=1, keepdim=True)
x, y, z, w = quat.unbind(dim=1)
# Compute the rotation matrix elements
xx = x * x
yy = y * y
zz = z * z
xy = x * y
xz = x * z
yz = y * z
wx = w * x
wy = w * y
wz = w * z
# Construct the rotation matrix
rot_matrix = torch.stack(
[
1 - 2 * (yy + zz),
2 * (xy - wz),
2 * (xz + wy),
2 * (xy + wz),
1 - 2 * (xx + zz),
2 * (yz - wx),
2 * (xz - wy),
2 * (yz + wx),
1 - 2 * (xx + yy),
],
dim=1,
).view(-1, 3, 3)
# Squeeze batch dimension if it was unsqueezed
if squeeze_batch_dim:
rot_matrix = rot_matrix.squeeze(0)
return rot_matrix
def rotation_matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as rotation matrices to quaternions.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
Returns:
quaternions with real part last, as tensor of shape (..., 4).
Quaternion order: XYZW (i.e. ijkr, scalar-last).
"""
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
batch_dim = matrix.shape[:-2]
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
matrix.reshape(batch_dim + (9,)), dim=-1
)
q_abs = _sqrt_positive_part(
torch.stack(
[
1.0 + m00 + m11 + m22,
1.0 + m00 - m11 - m22,
1.0 - m00 + m11 - m22,
1.0 - m00 - m11 + m22,
],
dim=-1,
)
)
# we produce the desired quaternion multiplied by each of r, i, j, k
quat_by_rijk = torch.stack(
[
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
],
dim=-2,
)
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
# the candidate won't be picked.
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
# forall i; we pick the best-conditioned one (with the largest denominator)
out = quat_candidates[
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
].reshape(batch_dim + (4,))
# Convert from rijk to ijkr
out = out[..., [1, 2, 3, 0]]
out = standardize_quaternion(out)
return out
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""
Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0.
"""
ret = torch.zeros_like(x)
positive_mask = x > 0
if torch.is_grad_enabled():
ret[positive_mask] = torch.sqrt(x[positive_mask])
else:
ret = torch.where(positive_mask, torch.sqrt(x), ret)
return ret
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
"""
Convert a unit quaternion to a standard form: one in which the real
part is non-negative.
Args:
quaternions: Quaternions with real part last,
as tensor of shape (..., 4).
Returns:
Standardized quaternions as tensor of shape (..., 4).
"""
return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
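# Example (illustrative sketch): rotation matrix <-> quaternion round trip with random unit
# quaternions in the module's (x, y, z, w) convention.
def _example_quaternion_matrix_round_trip():
    torch.manual_seed(0)
    quats = standardize_quaternion(F.normalize(torch.randn(8, 4), dim=-1))
    rot_mats = quaternion_to_rotation_matrix(quats)
    quats_recovered = rotation_matrix_to_quaternion(rot_mats)
    assert torch.allclose(quats_recovered, quats, atol=1e-4)
    # The converted matrices are valid rotations (orthonormal)
    assert torch.allclose(
        rot_mats @ rot_mats.transpose(-1, -2), torch.eye(3).expand(8, 3, 3), atol=1e-5
    )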
def quaternion_inverse(quat):
"""
Compute the inverse of a quaternion.
Args:
- quat: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
Returns:
- inv_quat: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
"""
# Unsqueeze batch dimension if not present
if quat.dim() == 1:
quat = quat.unsqueeze(0)
squeeze_batch_dim = True
else:
squeeze_batch_dim = False
# Compute the inverse
quat_conj = quat.clone()
quat_conj[:, :3] = -quat_conj[:, :3]
quat_norm = torch.sum(quat * quat, dim=1, keepdim=True)
inv_quat = quat_conj / quat_norm
# Squeeze batch dimension if it was unsqueezed
if squeeze_batch_dim:
inv_quat = inv_quat.squeeze(0)
return inv_quat
def quaternion_multiply(q1, q2):
"""
Multiply two quaternions.
Args:
- q1: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
- q2: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
Returns:
- qm: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
"""
# Unsqueeze batch dimension if not present
if q1.dim() == 1:
q1 = q1.unsqueeze(0)
q2 = q2.unsqueeze(0)
squeeze_batch_dim = True
else:
squeeze_batch_dim = False
# Unbind the quaternions
x1, y1, z1, w1 = q1.unbind(dim=1)
x2, y2, z2, w2 = q2.unbind(dim=1)
# Compute the product
x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
# Stack the components
qm = torch.stack([x, y, z, w], dim=1)
# Squeeze batch dimension if it was unsqueezed
if squeeze_batch_dim:
qm = qm.squeeze(0)
return qm
def transform_pose_using_quats_and_trans_2_to_1(quats1, trans1, quats2, trans2):
"""
Transform quats and translation of pose2 from absolute frame (pose2 to world) to relative frame (pose2 to pose1).
Args:
- quats1: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
- trans1: 3 or Bx3 torch tensor
- quats2: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
- trans2: 3 or Bx3 torch tensor
Returns:
- quats: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
- trans: 3 or Bx3 torch tensor
"""
# Unsqueeze batch dimension if not present
if quats1.dim() == 1:
quats1 = quats1.unsqueeze(0)
trans1 = trans1.unsqueeze(0)
quats2 = quats2.unsqueeze(0)
trans2 = trans2.unsqueeze(0)
squeeze_batch_dim = True
else:
squeeze_batch_dim = False
# Compute the inverse of view1's pose
inv_quats1 = quaternion_inverse(quats1)
R1_inv = quaternion_to_rotation_matrix(inv_quats1)
t1_inv = -1 * ein.einsum(R1_inv, trans1, "b i j, b j -> b i")
# Transform view2's pose to view1's frame
quats = quaternion_multiply(inv_quats1, quats2)
trans = ein.einsum(R1_inv, trans2, "b i j, b j -> b i") + t1_inv
# Squeeze batch dimension if it was unsqueezed
if squeeze_batch_dim:
quats = quats.squeeze(0)
trans = trans.squeeze(0)
return quats, trans
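# Example (illustrative sketch): the quaternion/translation relative-pose helper agrees with
# composing the corresponding rotation matrices directly (toy random poses).
def _example_relative_pose_from_quats():
    torch.manual_seed(0)
    quats1 = standardize_quaternion(F.normalize(torch.randn(4, 4), dim=-1))
    quats2 = standardize_quaternion(F.normalize(torch.randn(4, 4), dim=-1))
    trans1, trans2 = torch.randn(4, 3), torch.randn(4, 3)
    rel_quats, rel_trans = transform_pose_using_quats_and_trans_2_to_1(
        quats1, trans1, quats2, trans2
    )
    R1 = quaternion_to_rotation_matrix(quats1)
    R2 = quaternion_to_rotation_matrix(quats2)
    # Relative rotation is R1^T @ R2, relative translation is R1^T @ (t2 - t1)
    assert torch.allclose(
        quaternion_to_rotation_matrix(rel_quats), R1.transpose(-1, -2) @ R2, atol=1e-4
    )
    assert torch.allclose(
        rel_trans,
        ein.einsum(R1.transpose(-1, -2), trans2 - trans1, "b i j, b j -> b i"),
        atol=1e-4,
    )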
def convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
ray_directions, depth_along_ray, pose_trans, pose_quats
):
"""
Convert ray directions, depth along ray, pose translation, and
unit quaternions (representing pose rotation) to a pointmap in world frame.
Args:
- ray_directions: (HxWx3 or BxHxWx3) torch tensor
- depth_along_ray: (HxWx1 or BxHxWx1) torch tensor
- pose_trans: (3 or Bx3) torch tensor
- pose_quats: (4 or Bx4) torch tensor (unit quaternions and notation is (x, y, z, w))
Returns:
- pointmap: (HxWx3 or BxHxWx3) torch tensor
"""
# Add batch dimension if not present
if ray_directions.dim() == 3:
ray_directions = ray_directions.unsqueeze(0)
depth_along_ray = depth_along_ray.unsqueeze(0)
pose_trans = pose_trans.unsqueeze(0)
pose_quats = pose_quats.unsqueeze(0)
squeeze_batch_dim = True
else:
squeeze_batch_dim = False
batch_size, height, width, _ = depth_along_ray.shape
device = depth_along_ray.device
# Normalize the quaternions to ensure they are unit quaternions
pose_quats = pose_quats / torch.norm(pose_quats, dim=-1, keepdim=True)
# Convert quaternions to rotation matrices (B x 3 x 3)
rot_mat = quaternion_to_rotation_matrix(pose_quats)
# Get pose matrix (B x 4 x 4)
pose_mat = torch.eye(4, device=device).unsqueeze(0).repeat(batch_size, 1, 1)
pose_mat[:, :3, :3] = rot_mat
pose_mat[:, :3, 3] = pose_trans
# Compute 3D points in local camera frame
pts3d_local = depth_along_ray * ray_directions
# Compute 3D points in world frame
pts3d_homo = torch.cat([pts3d_local, torch.ones_like(pts3d_local[..., :1])], dim=-1)
pts3d_world = ein.einsum(pose_mat, pts3d_homo, "b i k, b h w k -> b h w i")
pts3d_world = pts3d_world[..., :3]
# Remove batch dimension if it was added
if squeeze_batch_dim:
pts3d_world = pts3d_world.squeeze(0)
return pts3d_world
def xy_grid(
W,
H,
device=None,
origin=(0, 0),
unsqueeze=None,
cat_dim=-1,
homogeneous=False,
**arange_kw,
):
"""
Generate a coordinate grid of shape (H,W,2) or (H,W,3) if homogeneous=True.
Args:
W (int): Width of the grid
H (int): Height of the grid
device (torch.device, optional): Device to place the grid on. If None, uses numpy arrays
origin (tuple, optional): Origin coordinates (x,y) for the grid. Default is (0,0)
unsqueeze (int, optional): Dimension to unsqueeze in the output tensors
cat_dim (int, optional): Dimension to concatenate the x,y coordinates. If None, returns tuple
homogeneous (bool, optional): If True, adds a third dimension of ones to make homogeneous coordinates
**arange_kw: Additional keyword arguments passed to np.arange or torch.arange
Returns:
numpy.ndarray or torch.Tensor: Coordinate grid where:
- output[j,i,0] = i + origin[0] (x-coordinate)
- output[j,i,1] = j + origin[1] (y-coordinate)
- output[j,i,2] = 1 (if homogeneous=True)
"""
if device is None:
# numpy
arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
else:
# torch
def arange(*a, **kw):
return torch.arange(*a, device=device, **kw)
meshgrid, stack = torch.meshgrid, torch.stack
def ones(*a):
return torch.ones(*a, device=device)
tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
grid = meshgrid(tw, th, indexing="xy")
if homogeneous:
grid = grid + (ones((H, W)),)
if unsqueeze is not None:
grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))
if cat_dim is not None:
grid = stack(grid, cat_dim)
return grid
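# Example (illustrative sketch): xy_grid yields per-pixel (x, y) coordinates; entry [j, i] is (i, j).
def _example_xy_grid():
    grid = xy_grid(4, 3)  # numpy array of shape (3, 4, 2) since no device is given
    assert grid.shape == (3, 4, 2)
    assert (grid[2, 1] == np.array([1, 2])).all()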
def geotrf(Trf, pts, ncol=None, norm=False):
"""
Apply a geometric transformation to a set of 3-D points.
Args:
Trf: 3x3 or 4x4 projection matrix (typically a Homography) or batch of matrices
with shape (B, 3, 3) or (B, 4, 4)
pts: numpy/torch/tuple of coordinates with shape (..., 2) or (..., 3)
ncol: int, number of columns of the result (2 or 3)
norm: float, if not 0, the result is projected on the z=norm plane
(homogeneous normalization)
Returns:
Array or tensor of projected points with the same type as input and shape (..., ncol)
"""
assert Trf.ndim >= 2
if isinstance(Trf, np.ndarray):
pts = np.asarray(pts)
elif isinstance(Trf, torch.Tensor):
pts = torch.as_tensor(pts, dtype=Trf.dtype)
# Adapt shape if necessary
output_reshape = pts.shape[:-1]
ncol = ncol or pts.shape[-1]
# Optimized code
if (
isinstance(Trf, torch.Tensor)
and isinstance(pts, torch.Tensor)
and Trf.ndim == 3
and pts.ndim == 4
):
d = pts.shape[3]
if Trf.shape[-1] == d:
pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
elif Trf.shape[-1] == d + 1:
pts = (
torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
+ Trf[:, None, None, :d, d]
)
else:
raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
else:
if Trf.ndim >= 3:
n = Trf.ndim - 2
assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
if pts.ndim > Trf.ndim:
# Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
elif pts.ndim == 2:
# Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
pts = pts[:, None, :]
if pts.shape[-1] + 1 == Trf.shape[-1]:
Trf = Trf.swapaxes(-1, -2) # transpose Trf
pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
elif pts.shape[-1] == Trf.shape[-1]:
Trf = Trf.swapaxes(-1, -2) # transpose Trf
pts = pts @ Trf
else:
pts = Trf @ pts.T
if pts.ndim >= 2:
pts = pts.swapaxes(-1, -2)
if norm:
pts = pts / pts[..., -1:] # DONT DO /=, it will lead to a bug
if norm != 1:
pts *= norm
res = pts[..., :ncol].reshape(*output_reshape, ncol)
return res
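# Example (illustrative sketch): geotrf applied with batched 4x4 translation-only matrices on a
# batched pointmap matches the explicit homogeneous computation.
def _example_geotrf_translation():
    torch.manual_seed(0)
    transform = torch.eye(4).repeat(2, 1, 1)
    transform[:, :3, 3] = torch.randn(2, 3)
    pts = torch.randn(2, 8, 10, 3)
    out = geotrf(transform, pts)
    assert torch.allclose(out, pts + transform[:, None, None, :3, 3], atol=1e-6)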
def inv(mat):
"""
Invert a torch or numpy matrix
"""
if isinstance(mat, torch.Tensor):
return torch.linalg.inv(mat)
if isinstance(mat, np.ndarray):
return np.linalg.inv(mat)
raise ValueError(f"bad matrix type = {type(mat)}")
def closed_form_pose_inverse(
pose_matrices, rotation_matrices=None, translation_vectors=None
):
"""
Compute the inverse of each 4x4 (or 3x4) SE(3) pose matrix in a batch.
If `rotation_matrices` and `translation_vectors` are provided, they must correspond to the rotation and translation
components of `pose_matrices`. Otherwise, they will be extracted from `pose_matrices`.
Args:
pose_matrices: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
rotation_matrices (optional): Nx3x3 array or tensor of rotation matrices.
translation_vectors (optional): Nx3x1 array or tensor of translation vectors.
Returns:
Inverted SE3 matrices with the same type and device as input `pose_matrices`.
Shapes:
pose_matrices: (N, 4, 4)
rotation_matrices: (N, 3, 3)
translation_vectors: (N, 3, 1)
"""
# Check if pose_matrices is a numpy array or a torch tensor
is_numpy = isinstance(pose_matrices, np.ndarray)
# Validate shapes
if pose_matrices.shape[-2:] != (4, 4) and pose_matrices.shape[-2:] != (3, 4):
raise ValueError(
f"pose_matrices must be of shape (N,4,4), got {pose_matrices.shape}."
)
# Extract rotation_matrices and translation_vectors if not provided
if rotation_matrices is None:
rotation_matrices = pose_matrices[:, :3, :3]
if translation_vectors is None:
translation_vectors = pose_matrices[:, :3, 3:]
# Compute the inverse of input SE3 matrices
if is_numpy:
rotation_transposed = np.transpose(rotation_matrices, (0, 2, 1))
new_translation = -np.matmul(rotation_transposed, translation_vectors)
inverted_matrix = np.tile(np.eye(4), (len(rotation_matrices), 1, 1))
else:
rotation_transposed = rotation_matrices.transpose(1, 2)
new_translation = -torch.bmm(rotation_transposed, translation_vectors)
inverted_matrix = torch.eye(4, 4)[None].repeat(len(rotation_matrices), 1, 1)
inverted_matrix = inverted_matrix.to(rotation_matrices.dtype).to(
rotation_matrices.device
)
inverted_matrix[:, :3, :3] = rotation_transposed
inverted_matrix[:, :3, 3:] = new_translation
return inverted_matrix
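# Example (illustrative sketch): the closed-form SE(3) inverse agrees with a generic matrix
# inverse on random rigid poses.
def _example_closed_form_pose_inverse():
    torch.manual_seed(0)
    poses = torch.eye(4).repeat(5, 1, 1)
    poses[:, :3, :3] = quaternion_to_rotation_matrix(F.normalize(torch.randn(5, 4), dim=-1))
    poses[:, :3, 3] = torch.randn(5, 3)
    inverted = closed_form_pose_inverse(poses)
    assert torch.allclose(inverted, torch.linalg.inv(poses), atol=1e-4)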
def relative_pose_transformation(trans_01, trans_02):
r"""
Function that computes the relative homogenous transformation from a
reference transformation :math:`T_1^{0} = \begin{bmatrix} R_1 & t_1 \\
\mathbf{0} & 1 \end{bmatrix}` to destination :math:`T_2^{0} =
\begin{bmatrix} R_2 & t_2 \\ \mathbf{0} & 1 \end{bmatrix}`.
The relative transformation is computed as follows:
.. math::
T_2^{1} = (T_1^{0})^{-1} \cdot T_2^{0}
Arguments:
trans_01 (torch.Tensor): reference transformation tensor of shape
:math:`(N, 4, 4)` or :math:`(4, 4)`.
trans_02 (torch.Tensor): destination transformation tensor of shape
:math:`(N, 4, 4)` or :math:`(4, 4)`.
Shape:
- Output: :math:`(N, 4, 4)` or :math:`(4, 4)`.
Returns:
torch.Tensor: the relative transformation between the transformations.
Example::
>>> trans_01 = torch.eye(4) # 4x4
>>> trans_02 = torch.eye(4) # 4x4
>>> trans_12 = relative_pose_transformation(trans_01, trans_02) # 4x4
"""
if not torch.is_tensor(trans_01):
raise TypeError(
"Input trans_01 type is not a torch.Tensor. Got {}".format(type(trans_01))
)
if not torch.is_tensor(trans_02):
raise TypeError(
"Input trans_02 type is not a torch.Tensor. Got {}".format(type(trans_02))
)
if trans_01.dim() not in (2, 3) or trans_01.shape[-2:] != (4, 4):
raise ValueError(
"Input must be of the shape Nx4x4 or 4x4. Got {}".format(trans_01.shape)
)
if trans_02.dim() not in (2, 3) or trans_02.shape[-2:] != (4, 4):
raise ValueError(
"Input must be of the shape Nx4x4 or 4x4. Got {}".format(trans_02.shape)
)
if not trans_01.dim() == trans_02.dim():
raise ValueError(
"Input number of dims must match. Got {} and {}".format(
trans_01.dim(), trans_02.dim()
)
)
# Convert to Nx4x4 if inputs are 4x4
squeeze_batch_dim = False
if trans_01.dim() == 2:
trans_01 = trans_01.unsqueeze(0)
trans_02 = trans_02.unsqueeze(0)
squeeze_batch_dim = True
# Compute inverse of trans_01 using closed form
trans_10 = closed_form_pose_inverse(trans_01)
# Compose transformations using matrix multiplication
trans_12 = torch.matmul(trans_10, trans_02)
# Remove batch dimension if it was added
if squeeze_batch_dim:
trans_12 = trans_12.squeeze(0)
return trans_12
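# Example (illustrative sketch): with an identity reference pose, the relative transformation is
# simply the second pose.
def _example_relative_pose_transformation():
    trans_01 = torch.eye(4)
    trans_02 = torch.eye(4)
    trans_02[:3, 3] = torch.tensor([1.0, 2.0, 3.0])
    trans_12 = relative_pose_transformation(trans_01, trans_02)
    assert torch.allclose(trans_12, trans_02)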
def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
"""
Args:
- depth (BxHxW or BxHxWxN torch tensor): z-depth map(s)
- pseudo_focal: [B,H,W], [B,2,H,W] or [B,1,H,W] torch tensor of focal lengths
Returns:
pointmap of absolute coordinates (BxHxWx3 array)
"""
if len(depth.shape) == 4:
B, H, W, n = depth.shape
else:
B, H, W = depth.shape
n = None
if len(pseudo_focal.shape) == 3: # [B,H,W]
pseudo_focalx = pseudo_focaly = pseudo_focal
elif len(pseudo_focal.shape) == 4: # [B,2,H,W] or [B,1,H,W]
pseudo_focalx = pseudo_focal[:, 0]
if pseudo_focal.shape[1] == 2:
pseudo_focaly = pseudo_focal[:, 1]
else:
pseudo_focaly = pseudo_focalx
else:
raise NotImplementedError("Error, unknown input focal shape format.")
assert pseudo_focalx.shape == depth.shape[:3]
assert pseudo_focaly.shape == depth.shape[:3]
grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]
# set principal point
if pp is None:
grid_x = grid_x - (W - 1) / 2
grid_y = grid_y - (H - 1) / 2
else:
grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]
grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]
if n is None:
pts3d = torch.empty((B, H, W, 3), device=depth.device)
pts3d[..., 0] = depth * grid_x / pseudo_focalx
pts3d[..., 1] = depth * grid_y / pseudo_focaly
pts3d[..., 2] = depth
else:
pts3d = torch.empty((B, H, W, 3, n), device=depth.device)
pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None]
pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None]
pts3d[..., 2, :] = depth
return pts3d
def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
"""
Args:
- depthmap (HxW array):
- camera_intrinsics: a 3x3 matrix
Returns:
pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
"""
camera_intrinsics = np.float32(camera_intrinsics)
H, W = depthmap.shape
# Compute 3D ray associated with each pixel
# Strong assumption: there are no skew terms
assert camera_intrinsics[0, 1] == 0.0
assert camera_intrinsics[1, 0] == 0.0
if pseudo_focal is None:
fu = camera_intrinsics[0, 0]
fv = camera_intrinsics[1, 1]
else:
assert pseudo_focal.shape == (H, W)
fu = fv = pseudo_focal
cu = camera_intrinsics[0, 2]
cv = camera_intrinsics[1, 2]
u, v = np.meshgrid(np.arange(W), np.arange(H))
z_cam = depthmap
x_cam = (u - cu) * z_cam / fu
y_cam = (v - cv) * z_cam / fv
X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
# Mask for valid coordinates
valid_mask = depthmap > 0.0
return X_cam, valid_mask
def depthmap_to_absolute_camera_coordinates(
depthmap, camera_intrinsics, camera_pose, **kw
):
"""
Args:
- depthmap (HxW array):
- camera_intrinsics: a 3x3 matrix
- camera_pose: a 3x4 or 4x4 cam2world matrix
Returns:
pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
"""
X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
X_world = X_cam # default
if camera_pose is not None:
# R_cam2world = np.float32(camera_params["R_cam2world"])
# t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze()
R_cam2world = camera_pose[:3, :3]
t_cam2world = camera_pose[:3, 3]
# Express in absolute coordinates (invalid depth values)
X_world = (
np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
)
return X_world, valid_mask
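# Example (illustrative sketch): NumPy counterpart of the unprojection utilities on a toy
# constant-depth map with an identity cam2world pose.
def _example_numpy_unprojection():
    depthmap = np.full((24, 32), 2.0, dtype=np.float32)
    intrinsics = np.array(
        [[30.0, 0.0, 16.0], [0.0, 30.0, 12.0], [0.0, 0.0, 1.0]], dtype=np.float32
    )
    camera_pose = np.eye(4, dtype=np.float32)
    pts_world, valid_mask = depthmap_to_absolute_camera_coordinates(
        depthmap, intrinsics, camera_pose
    )
    assert pts_world.shape == (24, 32, 3) and valid_mask.all()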
def get_absolute_pointmaps_and_rays_info(
depthmap, camera_intrinsics, camera_pose, **kw
):
"""
Args:
- depthmap (HxW array):
- camera_intrinsics: a 3x3 matrix
- camera_pose: a 3x4 or 4x4 cam2world matrix
Returns:
pointmap of absolute coordinates (HxWx3 array),
a mask specifying valid pixels,
ray origins of absolute coordinates (HxWx3 array),
ray directions of absolute coordinates (HxWx3 array),
depth along ray (HxWx1 array),
ray directions of camera/local coordinates (HxWx3 array),
pointmap of camera/local coordinates (HxWx3 array).
"""
camera_intrinsics = np.float32(camera_intrinsics)
H, W = depthmap.shape
# Compute 3D ray associated with each pixel
# Strong assumption: pinhole & there are no skew terms
assert camera_intrinsics[0, 1] == 0.0
assert camera_intrinsics[1, 0] == 0.0
fu = camera_intrinsics[0, 0]
fv = camera_intrinsics[1, 1]
cu = camera_intrinsics[0, 2]
cv = camera_intrinsics[1, 2]
# Get the rays on the unit plane
u, v = np.meshgrid(np.arange(W), np.arange(H))
x_cam = (u - cu) / fu
y_cam = (v - cv) / fv
z_cam = np.ones_like(x_cam)
ray_dirs_cam_on_unit_plane = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(
np.float32
)
# Compute the 3d points in the local camera coordinate system
pts_cam = depthmap[..., None] * ray_dirs_cam_on_unit_plane
# Get the depth along the ray and compute the ray directions on the unit sphere
depth_along_ray = np.linalg.norm(pts_cam, axis=-1, keepdims=True)
ray_directions_cam = ray_dirs_cam_on_unit_plane / np.linalg.norm(
ray_dirs_cam_on_unit_plane, axis=-1, keepdims=True
)
# Mask for valid coordinates
valid_mask = depthmap > 0.0
# Get the ray origins in absolute coordinates and the ray directions in absolute coordinates
ray_origins_world = np.zeros_like(ray_directions_cam)
ray_directions_world = ray_directions_cam
pts_world = pts_cam
if camera_pose is not None:
R_cam2world = camera_pose[:3, :3]
t_cam2world = camera_pose[:3, 3]
# Express in absolute coordinates
ray_origins_world = ray_origins_world + t_cam2world[None, None, :]
ray_directions_world = np.einsum(
"ik, vuk -> vui", R_cam2world, ray_directions_cam
)
pts_world = ray_origins_world + ray_directions_world * depth_along_ray
return (
pts_world,
valid_mask,
ray_origins_world,
ray_directions_world,
depth_along_ray,
ray_directions_cam,
pts_cam,
)
def adjust_camera_params_for_rotation(camera_params, original_size, k):
"""
Adjust camera parameters for rotation.
Args:
camera_params: Camera parameters [fx, fy, cx, cy, ...]
original_size: Original image size as (width, height)
k: Number of 90-degree rotations counter-clockwise (k=3 means 90 degrees clockwise)
Returns:
Adjusted camera parameters
"""
fx, fy, cx, cy = camera_params[:4]
width, height = original_size
if k % 4 == 1: # 90 degrees counter-clockwise
new_fx, new_fy = fy, fx
new_cx, new_cy = height - cy, cx
elif k % 4 == 2: # 180 degrees
new_fx, new_fy = fx, fy
new_cx, new_cy = width - cx, height - cy
elif k % 4 == 3: # 90 degrees clockwise (270 counter-clockwise)
new_fx, new_fy = fy, fx
new_cx, new_cy = cy, width - cx
else: # No rotation
return camera_params
adjusted_params = [new_fx, new_fy, new_cx, new_cy]
if len(camera_params) > 4:
adjusted_params.extend(camera_params[4:])
return adjusted_params
def adjust_pose_for_rotation(pose, k):
"""
Adjust camera pose for rotation.
Args:
pose: 4x4 camera pose matrix (camera-to-world, OpenCV convention - X right, Y down, Z forward)
k: Number of 90-degree rotations counter-clockwise (k=3 means 90 degrees clockwise)
Returns:
Adjusted 4x4 camera pose matrix
"""
# Create rotation matrices for different rotations
if k % 4 == 1: # 90 degrees counter-clockwise
rot_transform = np.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]])
elif k % 4 == 2: # 180 degrees
rot_transform = np.array([[-1, 0, 0], [0, -1, 0], [0, 0, 1]])
elif k % 4 == 3: # 90 degrees clockwise (270 counter-clockwise)
rot_transform = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
else: # No rotation
return pose
# Apply the transformation to the pose
adjusted_pose = pose
adjusted_pose[:3, :3] = adjusted_pose[:3, :3] @ rot_transform.T
return adjusted_pose
def crop_to_aspect_ratio(image, depth, camera_params, target_ratio=1.5):
"""
Crop image and depth to the largest possible crop with the target aspect ratio,
keeping the left side if the image is wider than the target ratio and the bottom if it is taller.
Args:
image: PIL image
depth: Depth map as numpy array
camera_params: Camera parameters [fx, fy, cx, cy, ...]
target_ratio: Target width/height ratio
Returns:
Cropped image, cropped depth, adjusted camera parameters
"""
width, height = image.size
fx, fy, cx, cy = camera_params[:4]
current_ratio = width / height
if abs(current_ratio - target_ratio) < 1e-6:
# Already at target ratio
return image, depth, camera_params
if current_ratio > target_ratio:
# Image is wider than target ratio, crop width
new_width = int(height * target_ratio)
left = 0
right = new_width
# Crop image
cropped_image = image.crop((left, 0, right, height))
# Crop depth
if len(depth.shape) == 3:
cropped_depth = depth[:, left:right, :]
else:
cropped_depth = depth[:, left:right]
# Adjust camera parameters
new_cx = cx - left
adjusted_params = [fx, fy, new_cx, cy] + list(camera_params[4:])
else:
# Image is taller than target ratio, crop height
new_height = int(width / target_ratio)
top = max(0, height - new_height)
bottom = height
# Crop image
cropped_image = image.crop((0, top, width, bottom))
# Crop depth
if len(depth.shape) == 3:
cropped_depth = depth[top:bottom, :, :]
else:
cropped_depth = depth[top:bottom, :]
# Adjust camera parameters
new_cy = cy - top
adjusted_params = [fx, fy, cx, new_cy] + list(camera_params[4:])
return cropped_image, cropped_depth, adjusted_params
def colmap_to_opencv_intrinsics(K):
"""
Modify camera intrinsics to follow a different convention.
Coordinates of the center of the top-left pixels are by default:
- (0.5, 0.5) in Colmap
- (0,0) in OpenCV
"""
K = K.copy()
K[0, 2] -= 0.5
K[1, 2] -= 0.5
return K
def opencv_to_colmap_intrinsics(K):
"""
Modify camera intrinsics to follow a different convention.
Coordinates of the center of the top-left pixels are by default:
- (0.5, 0.5) in Colmap
- (0,0) in OpenCV
"""
K = K.copy()
K[0, 2] += 0.5
K[1, 2] += 0.5
return K
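# Example (illustrative sketch): the COLMAP and OpenCV conventions differ by half a pixel in the
# principal point, so converting back and forth recovers the original matrix.
def _example_intrinsics_convention_round_trip():
    K = np.array([[100.0, 0.0, 64.5], [0.0, 100.0, 48.5], [0.0, 0.0, 1.0]])
    K_opencv = colmap_to_opencv_intrinsics(K)
    assert np.allclose(K_opencv[:2, 2], [64.0, 48.0])
    assert np.allclose(opencv_to_colmap_intrinsics(K_opencv), K)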
def normalize_depth_using_non_zero_pixels(depth, return_norm_factor=False):
"""
Normalize the depth by the average depth of non-zero depth pixels.
Args:
depth (torch.Tensor): Depth tensor of size [B, H, W, 1].
return_norm_factor (bool): If True, also return the per-batch norm factor.
Returns:
normalized_depth (torch.Tensor): Normalized depth tensor.
norm_factor (torch.Tensor): Norm factor tensor of size B (only returned if return_norm_factor is True).
"""
assert depth.ndim == 4 and depth.shape[3] == 1
# Calculate the sum and count of non-zero depth pixels for each batch
valid_depth_mask = depth > 0
valid_sum = torch.sum(depth * valid_depth_mask, dim=(1, 2, 3))
valid_count = torch.sum(valid_depth_mask, dim=(1, 2, 3))
# Calculate the norm factor
norm_factor = valid_sum / (valid_count + 1e-8)
while norm_factor.ndim < depth.ndim:
norm_factor.unsqueeze_(-1)
# Normalize the depth by the norm factor
norm_factor = norm_factor.clip(min=1e-8)
normalized_depth = depth / norm_factor
# Create the output tuple
output = (
(normalized_depth, norm_factor.squeeze(-1).squeeze(-1).squeeze(-1))
if return_norm_factor
else normalized_depth
)
return output
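# Example (illustrative sketch): normalizing a toy depth batch so that the mean of the valid
# (non-zero) depths becomes ~1, and recovering the per-batch scale factor.
def _example_normalize_depth():
    depth = torch.zeros(2, 4, 4, 1)
    depth[:, :2] = 5.0  # only half of the pixels carry valid depth
    normalized, norm_factor = normalize_depth_using_non_zero_pixels(
        depth, return_norm_factor=True
    )
    assert torch.allclose(norm_factor, torch.full((2,), 5.0), atol=1e-4)
    assert torch.allclose(normalized[:, :2], torch.ones(2, 2, 4, 1), atol=1e-4)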
def normalize_pose_translations(pose_translations, return_norm_factor=False):
"""
Normalize the pose translations by the average norm of the non-zero pose translations.
Args:
pose_translations (torch.Tensor): Pose translations tensor of size [B, V, 3]. B is the batch size, V is the number of views.
return_norm_factor (bool): If True, also return the per-batch norm factor.
Returns:
normalized_pose_translations (torch.Tensor): Normalized pose translations tensor of size [B, V, 3].
norm_factor (torch.Tensor): Norm factor tensor of size B (only returned if return_norm_factor is True).
"""
assert pose_translations.ndim == 3 and pose_translations.shape[2] == 3
# Compute distance of all pose translations to origin
pose_translations_dis = pose_translations.norm(dim=-1) # [B, V]
non_zero_pose_translations_dis = pose_translations_dis > 0 # [B, V]
# Calculate the average norm of the translations across all views (considering only views with non-zero translations)
sum_of_all_views_pose_translations = pose_translations_dis.sum(dim=1) # [B]
count_of_all_views_with_non_zero_pose_translations = (
non_zero_pose_translations_dis.sum(dim=1)
) # [B]
norm_factor = sum_of_all_views_pose_translations / (
count_of_all_views_with_non_zero_pose_translations + 1e-8
) # [B]
# Normalize the pose translations by the norm factor
norm_factor = norm_factor.clip(min=1e-8)
normalized_pose_translations = pose_translations / norm_factor.unsqueeze(
-1
).unsqueeze(-1)
# Create the output tuple
output = (
(normalized_pose_translations, norm_factor)
if return_norm_factor
else normalized_pose_translations
)
return output
def normalize_multiple_pointclouds(
pts_list, valid_masks=None, norm_mode="avg_dis", ret_factor=False
):
"""
Normalize multiple point clouds using a joint normalization strategy.
Args:
pts_list: List of point clouds, each with shape (..., H, W, 3) or (B, H, W, 3)
valid_masks: Optional list of masks indicating valid points in each point cloud
norm_mode: String in format "{norm}_{dis}" where:
- norm: Normalization strategy (currently only "avg" is supported)
- dis: Distance transformation ("dis" for raw distance, "log1p" for log(1+distance),
"warp-log1p" to warp points using log distance)
ret_factor: If True, return the normalization factor as the last element in the result list
Returns:
List of normalized point clouds with the same shapes as inputs.
If ret_factor is True, the last element is the normalization factor.
"""
assert all(pts.ndim >= 3 and pts.shape[-1] == 3 for pts in pts_list)
if valid_masks is not None:
assert len(pts_list) == len(valid_masks)
norm_mode, dis_mode = norm_mode.split("_")
# Gather all points together (joint normalization)
nan_pts_list = [
invalid_to_zeros(pts, valid_masks[i], ndim=3)
if valid_masks
else invalid_to_zeros(pts, None, ndim=3)
for i, pts in enumerate(pts_list)
]
all_pts = torch.cat([nan_pts for nan_pts, _ in nan_pts_list], dim=1)
nnz_list = [nnz for _, nnz in nan_pts_list]
# Compute distance to origin
all_dis = all_pts.norm(dim=-1)
if dis_mode == "dis":
pass # do nothing
elif dis_mode == "log1p":
all_dis = torch.log1p(all_dis)
elif dis_mode == "warp-log1p":
# Warp input points before normalizing them
log_dis = torch.log1p(all_dis)
warp_factor = log_dis / all_dis.clip(min=1e-8)
for i, pts in enumerate(pts_list):
H, W = pts.shape[1:-1]
pts_list[i] = pts * warp_factor[:, i * (H * W) : (i + 1) * (H * W)].view(
-1, H, W, 1
)
all_dis = log_dis
else:
raise ValueError(f"bad {dis_mode=}")
# Compute normalization factor
norm_factor = all_dis.sum(dim=1) / (sum(nnz_list) + 1e-8)
norm_factor = norm_factor.clip(min=1e-8)
while norm_factor.ndim < pts_list[0].ndim:
norm_factor.unsqueeze_(-1)
# Normalize points
res = [pts / norm_factor for pts in pts_list]
if ret_factor:
res.append(norm_factor)
return res
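# Usage sketch for normalize_multiple_pointclouds (illustrative only):
#   pts0 = torch.rand(2, 32, 32, 3)  # (B, H, W, 3)
#   pts1 = torch.rand(2, 32, 32, 3)
#   normed0, normed1, factor = normalize_multiple_pointclouds(
#       [pts0, pts1], norm_mode="avg_dis", ret_factor=True
#   )
# Both views are divided by the same per-batch average distance to the origin,
# so the relative scale between the views is preserved.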
def apply_log_to_norm(input_data):
"""
Normalize the input data and apply a logarithmic transformation based on the normalization factor.
Args:
input_data (torch.Tensor): The input tensor to be normalized and transformed.
Returns:
torch.Tensor: The transformed tensor after normalization and logarithmic scaling.
"""
org_d = input_data.norm(dim=-1, keepdim=True)
input_data = input_data / org_d.clip(min=1e-8)
input_data = input_data * torch.log1p(org_d)
return input_data
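# Worked example for apply_log_to_norm: the vector [3, 0, 0] has norm 3, so it is
# rescaled to unit length and multiplied by log1p(3) = ln(4) ~ 1.386, giving roughly
# [1.386, 0, 0]. The direction is preserved while large magnitudes are compressed
# logarithmically.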
def angle_diff_vec3(v1, v2, eps=1e-12):
"""
Compute angle difference between 3D vectors.
Args:
v1: torch.Tensor of shape (..., 3)
v2: torch.Tensor of shape (..., 3)
eps: Small epsilon value for numerical stability
Returns:
torch.Tensor: Angle differences in radians
"""
cross_norm = torch.cross(v1, v2, dim=-1).norm(dim=-1) + eps
dot_prod = (v1 * v2).sum(dim=-1)
return torch.atan2(cross_norm, dot_prod)
def angle_diff_vec3_numpy(v1: np.ndarray, v2: np.ndarray, eps: float = 1e-12):
"""
Compute angle difference between 3D vectors using NumPy.
Args:
v1 (np.ndarray): First vector of shape (..., 3)
v2 (np.ndarray): Second vector of shape (..., 3)
eps (float, optional): Small epsilon value for numerical stability. Defaults to 1e-12.
Returns:
np.ndarray: Angle differences in radians
"""
return np.arctan2(
np.linalg.norm(np.cross(v1, v2, axis=-1), axis=-1) + eps, (v1 * v2).sum(axis=-1)
)
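# Worked example for the two angle_diff_vec3 helpers above: perpendicular unit
# vectors [1, 0, 0] and [0, 1, 0] give atan2(1, 0) = pi / 2 ~ 1.5708 rad, and
# parallel vectors give ~0. Using atan2(|v1 x v2|, v1 . v2) is numerically more
# stable near 0 and pi than arccos of the normalized dot product.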
@no_warnings(category=RuntimeWarning)
def points_to_normals(
point: np.ndarray, mask: np.ndarray = None, edge_threshold: float = None
) -> np.ndarray:
"""
Calculate normal map from point map. Value range is [-1, 1].
Args:
point (np.ndarray): shape (height, width, 3), point map
mask (optional, np.ndarray): shape (height, width), dtype=bool. Mask of valid depth pixels. Defaults to None.
edge_threshold (optional, float): threshold for the angle (in degrees) between the normal and the view direction. Defaults to None.
Returns:
normal (np.ndarray): shape (height, width, 3), normal map.
normal_mask (np.ndarray): shape (height, width), dtype=bool. Only returned when mask is provided.
"""
height, width = point.shape[-3:-1]
has_mask = mask is not None
if mask is None:
mask = np.ones_like(point[..., 0], dtype=bool)
mask_pad = np.zeros((height + 2, width + 2), dtype=bool)
mask_pad[1:-1, 1:-1] = mask
mask = mask_pad
pts = np.zeros((height + 2, width + 2, 3), dtype=point.dtype)
pts[1:-1, 1:-1, :] = point
up = pts[:-2, 1:-1, :] - pts[1:-1, 1:-1, :]
left = pts[1:-1, :-2, :] - pts[1:-1, 1:-1, :]
down = pts[2:, 1:-1, :] - pts[1:-1, 1:-1, :]
right = pts[1:-1, 2:, :] - pts[1:-1, 1:-1, :]
normal = np.stack(
[
np.cross(up, left, axis=-1),
np.cross(left, down, axis=-1),
np.cross(down, right, axis=-1),
np.cross(right, up, axis=-1),
]
)
normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12)
valid = (
np.stack(
[
mask[:-2, 1:-1] & mask[1:-1, :-2],
mask[1:-1, :-2] & mask[2:, 1:-1],
mask[2:, 1:-1] & mask[1:-1, 2:],
mask[1:-1, 2:] & mask[:-2, 1:-1],
]
)
& mask[None, 1:-1, 1:-1]
)
if edge_threshold is not None:
view_angle = angle_diff_vec3_numpy(pts[None, 1:-1, 1:-1, :], normal)
view_angle = np.minimum(view_angle, np.pi - view_angle)
valid = valid & (view_angle < np.deg2rad(edge_threshold))
normal = (normal * valid[..., None]).sum(axis=0)
normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12)
if has_mask:
normal_mask = valid.any(axis=0)
normal = np.where(normal_mask[..., None], normal, 0)
return normal, normal_mask
else:
return normal
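# Usage sketch for points_to_normals on a synthetic tilted plane (illustrative only):
#   u, v = np.meshgrid(np.arange(32), np.arange(32))
#   point = np.stack([u, v, 0.5 * u], axis=-1).astype(np.float32)  # plane z = 0.5 * x
#   normal = points_to_normals(point)  # (32, 32, 3); no mask, so a single array is returned
#   normal, normal_mask = points_to_normals(point, mask=np.isfinite(point[..., 2]))
# For a plane, all interior normals are parallel (proportional to (-0.5, 0, 1) up to sign).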
def sliding_window_1d(x: np.ndarray, window_size: int, stride: int, axis: int = -1):
"""
Create a sliding window view of the input array along a specified axis.
This function creates a memory-efficient view of the input array with sliding windows
of the specified size and stride. The window dimension is appended to the end of the
output array's shape. This is useful for operations like convolution, pooling, or
any analysis that requires examining local neighborhoods in the data.
Args:
x (np.ndarray): Input array with shape (..., axis_size, ...)
window_size (int): Size of the sliding window
stride (int): Stride of the sliding window (step size between consecutive windows)
axis (int, optional): Axis to perform sliding window over. Defaults to -1 (last axis)
Returns:
np.ndarray: View of the input array with shape (..., n_windows, ..., window_size),
where n_windows = (axis_size - window_size + 1) // stride
Raises:
AssertionError: If window_size is larger than the size of the specified axis
Example:
>>> x = np.array([1, 2, 3, 4, 5, 6])
>>> sliding_window_1d(x, window_size=3, stride=2)
array([[1, 2, 3],
[3, 4, 5]])
"""
assert x.shape[axis] >= window_size, (
f"window_size ({window_size}) is larger than axis_size ({x.shape[axis]})"
)
axis = axis % x.ndim
shape = (
*x.shape[:axis],
(x.shape[axis] - window_size + 1) // stride,
*x.shape[axis + 1 :],
window_size,
)
strides = (
*x.strides[:axis],
stride * x.strides[axis],
*x.strides[axis + 1 :],
x.strides[axis],
)
x_sliding = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
return x_sliding
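# Note on the window count used above (a property of this implementation):
# n_windows = (axis_size - window_size + 1) // stride, so trailing elements that do
# not start a complete window are dropped. For example, 7 elements with
# window_size=3 and stride=2 yield only 2 windows (starting at indices 0 and 2),
# and the last two elements never appear in any window.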
def sliding_window_nd(
x: np.ndarray,
window_size: Tuple[int, ...],
stride: Tuple[int, ...],
axis: Tuple[int, ...],
) -> np.ndarray:
"""
Create sliding windows along multiple dimensions of the input array.
This function applies sliding_window_1d sequentially along multiple axes to create
N-dimensional sliding windows. This is useful for operations that need to examine
local neighborhoods in multiple dimensions simultaneously.
Args:
x (np.ndarray): Input array
window_size (Tuple[int, ...]): Size of the sliding window for each axis
stride (Tuple[int, ...]): Stride of the sliding window for each axis
axis (Tuple[int, ...]): Axes to perform sliding window over
Returns:
np.ndarray: Array with sliding windows along the specified dimensions.
The window dimensions are appended to the end of the shape.
Note:
The length of window_size, stride, and axis tuples must be equal.
Example:
>>> x = np.random.rand(10, 10)
>>> windows = sliding_window_nd(x, window_size=(3, 3), stride=(2, 2), axis=(-2, -1))
>>> # Creates 3x3 sliding windows with stride 2 in both dimensions
"""
axis = [axis[i] % x.ndim for i in range(len(axis))]
for i in range(len(axis)):
x = sliding_window_1d(x, window_size[i], stride[i], axis[i])
return x
def sliding_window_2d(
x: np.ndarray,
window_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]],
axis: Tuple[int, int] = (-2, -1),
) -> np.ndarray:
"""
Create 2D sliding windows over the input array.
Convenience function for creating 2D sliding windows, commonly used for image
processing operations like convolution, pooling, or patch extraction.
Args:
x (np.ndarray): Input array
window_size (Union[int, Tuple[int, int]]): Size of the 2D sliding window.
If int, same size is used for both dimensions.
stride (Union[int, Tuple[int, int]]): Stride of the 2D sliding window.
If int, same stride is used for both dimensions.
axis (Tuple[int, int], optional): Two axes to perform sliding window over.
Defaults to (-2, -1) (last two dimensions).
Returns:
np.ndarray: Array with 2D sliding windows. The window dimensions (height, width)
are appended to the end of the shape.
Example:
>>> image = np.random.rand(100, 100)
>>> patches = sliding_window_2d(image, window_size=8, stride=4)
>>> # Creates 8x8 patches with stride 4 from the image
"""
if isinstance(window_size, int):
window_size = (window_size, window_size)
if isinstance(stride, int):
stride = (stride, stride)
return sliding_window_nd(x, window_size, stride, axis)
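# Shape check for the docstring example above (illustrative only): a 100x100 image
# with window_size=8 and stride=4 gives (100 - 8 + 1) // 4 = 23 windows per axis,
# so `patches.shape == (23, 23, 8, 8)`, with the two window dimensions appended last.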
def max_pool_1d(
x: np.ndarray, kernel_size: int, stride: int, padding: int = 0, axis: int = -1
):
"""
Perform 1D max pooling on the input array.
Max pooling reduces the dimensionality of the input by taking the maximum value
within each sliding window. This is commonly used in neural networks and signal
processing for downsampling and feature extraction.
Args:
x (np.ndarray): Input array
kernel_size (int): Size of the pooling kernel
stride (int): Stride of the pooling operation
padding (int, optional): Amount of padding to add on both sides. Defaults to 0.
axis (int, optional): Axis to perform max pooling over. Defaults to -1.
Returns:
np.ndarray: Max pooled array with reduced size along the specified axis
Note:
- For floating point arrays, padding is done with np.nan values
- For integer arrays, padding is done with the minimum value of the dtype
- np.nanmax is used to handle NaN values in the computation
Example:
>>> x = np.array([1, 3, 2, 4, 5, 1, 2])
>>> max_pool_1d(x, kernel_size=3, stride=2)
array([3, 5])
"""
axis = axis % x.ndim
if padding > 0:
fill_value = np.nan if x.dtype.kind == "f" else np.iinfo(x.dtype).min
padding_arr = np.full(
(*x.shape[:axis], padding, *x.shape[axis + 1 :]),
fill_value=fill_value,
dtype=x.dtype,
)
x = np.concatenate([padding_arr, x, padding_arr], axis=axis)
a_sliding = sliding_window_1d(x, kernel_size, stride, axis)
max_pool = np.nanmax(a_sliding, axis=-1)
return max_pool
def max_pool_nd(
x: np.ndarray,
kernel_size: Tuple[int, ...],
stride: Tuple[int, ...],
padding: Tuple[int, ...],
axis: Tuple[int, ...],
) -> np.ndarray:
"""
Perform N-dimensional max pooling on the input array.
This function applies max_pool_1d sequentially along multiple axes to perform
multi-dimensional max pooling. This is useful for downsampling multi-dimensional
data while preserving the most important features.
Args:
x (np.ndarray): Input array
kernel_size (Tuple[int, ...]): Size of the pooling kernel for each axis
stride (Tuple[int, ...]): Stride of the pooling operation for each axis
padding (Tuple[int, ...]): Amount of padding for each axis
axis (Tuple[int, ...]): Axes to perform max pooling over
Returns:
np.ndarray: Max pooled array with reduced size along the specified axes
Note:
The length of kernel_size, stride, padding, and axis tuples must be equal.
Max pooling is applied sequentially along each axis in the order specified.
Example:
>>> x = np.random.rand(10, 10, 10)
>>> pooled = max_pool_nd(x, kernel_size=(2, 2, 2), stride=(2, 2, 2),
... padding=(0, 0, 0), axis=(-3, -2, -1))
>>> # With the (axis_size - kernel_size + 1) // stride window count used here, each dimension shrinks from 10 to 4
"""
for i in range(len(axis)):
x = max_pool_1d(x, kernel_size[i], stride[i], padding[i], axis[i])
return x
def max_pool_2d(
x: np.ndarray,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]],
padding: Union[int, Tuple[int, int]],
axis: Tuple[int, int] = (-2, -1),
):
"""
Perform 2D max pooling on the input array.
Convenience function for 2D max pooling, commonly used in computer vision
and image processing for downsampling images while preserving important features.
Args:
x (np.ndarray): Input array
kernel_size (Union[int, Tuple[int, int]]): Size of the 2D pooling kernel.
If int, same size is used for both dimensions.
stride (Union[int, Tuple[int, int]]): Stride of the 2D pooling operation.
If int, same stride is used for both dimensions.
padding (Union[int, Tuple[int, int]]): Amount of padding for both dimensions.
If int, same padding is used for both dimensions.
axis (Tuple[int, int], optional): Two axes to perform max pooling over.
Defaults to (-2, -1) (last two dimensions).
Returns:
np.ndarray: 2D max pooled array with reduced size along the specified axes
Example:
>>> image = np.random.rand(64, 64)
>>> pooled = max_pool_2d(image, kernel_size=2, stride=2, padding=0)
>>> # With the (axis_size - kernel_size + 1) // stride window count used here, the 64x64 image becomes 31x31
"""
if isinstance(kernel_size, Number):
kernel_size = (kernel_size, kernel_size)
if isinstance(stride, Number):
stride = (stride, stride)
if isinstance(padding, Number):
padding = (padding, padding)
axis = tuple(axis)
return max_pool_nd(x, kernel_size, stride, padding, axis)
@no_warnings(category=RuntimeWarning)
def depth_edge(
depth: np.ndarray,
atol: float = None,
rtol: float = None,
kernel_size: int = 3,
mask: np.ndarray = None,
) -> np.ndarray:
"""
Compute the edge mask from depth map. The edge is defined as the pixels whose neighbors have large difference in depth.
Args:
depth (np.ndarray): shape (..., height, width), linear depth map
atol (float): absolute tolerance on the local depth range
rtol (float): relative tolerance (local depth range divided by the center depth)
kernel_size (int): size of the square neighborhood used for the comparison
mask (np.ndarray, optional): shape (..., height, width), dtype=bool. Mask of valid depth pixels.
Returns:
edge (np.ndarray): shape (..., height, width), dtype=bool
"""
if mask is None:
diff = max_pool_2d(
depth, kernel_size, stride=1, padding=kernel_size // 2
) + max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2)
else:
diff = max_pool_2d(
np.where(mask, depth, -np.inf),
kernel_size,
stride=1,
padding=kernel_size // 2,
) + max_pool_2d(
np.where(mask, -depth, -np.inf),
kernel_size,
stride=1,
padding=kernel_size // 2,
)
edge = np.zeros_like(depth, dtype=bool)
if atol is not None:
edge |= diff > atol
if rtol is not None:
edge |= diff / depth > rtol
return edge
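# Usage sketch for depth_edge (illustrative only): flag pixels whose local depth
# range exceeds a relative tolerance, e.g. at occlusion boundaries.
#   depth = np.full((64, 64), 1.0)
#   depth[:, 32:] = 5.0  # synthetic depth discontinuity
#   edge = depth_edge(depth, rtol=0.5, kernel_size=3)
# Only pixels within one pixel of the jump at column 32 are flagged; constant
# regions produce a zero local range and stay False.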
def depth_aliasing(
depth: np.ndarray,
atol: float = None,
rtol: float = None,
kernel_size: int = 3,
mask: np.ndarray = None,
) -> np.ndarray:
"""
Compute a map that indicates aliasing in a depth map. Aliased pixels are those that are close to neither the maximum nor the minimum of their neighborhood.
Args:
depth (np.ndarray): shape (..., height, width), linear depth map
atol (float): absolute tolerance
rtol (float): relative tolerance
kernel_size (int): size of the square neighborhood used for the comparison
mask (np.ndarray, optional): shape (..., height, width), dtype=bool. Mask of valid depth pixels.
Returns:
edge (np.ndarray): shape (..., height, width), dtype=bool
"""
if mask is None:
diff_max = (
max_pool_2d(depth, kernel_size, stride=1, padding=kernel_size // 2) - depth
)
diff_min = (
max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2) + depth
)
else:
diff_max = (
max_pool_2d(
np.where(mask, depth, -np.inf),
kernel_size,
stride=1,
padding=kernel_size // 2,
)
- depth
)
diff_min = (
max_pool_2d(
np.where(mask, -depth, -np.inf),
kernel_size,
stride=1,
padding=kernel_size // 2,
)
+ depth
)
diff = np.minimum(diff_max, diff_min)
edge = np.zeros_like(depth, dtype=bool)
if atol is not None:
edge |= diff > atol
if rtol is not None:
edge |= diff / depth > rtol
return edge
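# depth_aliasing complements depth_edge above: it takes min(local_max - d, d - local_min)
# rather than the full local range, so only pixels lying strictly between their
# neighbors' extremes (typical of interpolated or aliased depth samples) are
# flagged, while pixels on either side of a clean step edge are not. Usage sketch:
#   aliasing = depth_aliasing(depth, atol=0.5, kernel_size=3)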
@no_warnings(category=RuntimeWarning)
def normals_edge(
normals: np.ndarray, tol: float, kernel_size: int = 3, mask: np.ndarray = None
) -> np.ndarray:
"""
Compute the edge mask from normal map.
Args:
normals (np.ndarray): shape (..., height, width, 3), normal map
tol (float): tolerance in degrees
kernel_size (int): size of the square neighborhood used for the comparison
mask (np.ndarray, optional): shape (..., height, width), dtype=bool. Mask of valid pixels.
Returns:
edge (np.ndarray): shape (..., height, width), dtype=bool
"""
assert normals.ndim >= 3 and normals.shape[-1] == 3, (
"normal should be of shape (..., height, width, 3)"
)
normals = normals / (np.linalg.norm(normals, axis=-1, keepdims=True) + 1e-12)
padding = kernel_size // 2
normals_window = sliding_window_2d(
np.pad(
normals,
(
*([(0, 0)] * (normals.ndim - 3)),
(padding, padding),
(padding, padding),
(0, 0),
),
mode="edge",
),
window_size=kernel_size,
stride=1,
axis=(-3, -2),
)
if mask is None:
angle_diff = np.arccos(
(normals[..., None, None] * normals_window).sum(axis=-3)
).max(axis=(-2, -1))
else:
mask_window = sliding_window_2d(
np.pad(
mask,
# the mask has no channel axis, so pad (and window) its last two, spatial, dimensions
(*([(0, 0)] * (mask.ndim - 2)), (padding, padding), (padding, padding)),
mode="edge",
),
window_size=kernel_size,
stride=1,
axis=(-2, -1),
)
angle_diff = np.where(
mask_window,
np.arccos((normals[..., None, None] * normals_window).sum(axis=-3)),
0,
).max(axis=(-2, -1))
angle_diff = max_pool_2d(
angle_diff, kernel_size, stride=1, padding=kernel_size // 2
)
edge = angle_diff > np.deg2rad(tol)
return edge
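# Usage sketch tying the helpers in this file together (illustrative only; `point`,
# `depth` and `mask` are assumed to come from the caller, e.g. an unprojected depth
# map and its validity mask):
#   normals, normal_mask = points_to_normals(point, mask)
#   n_edge = normals_edge(normals, tol=15, kernel_size=3, mask=normal_mask)
#   d_edge = depth_edge(depth, rtol=0.03, kernel_size=3, mask=mask)
#   discontinuities = n_edge | d_edge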