# map-anything: mapanything/utils/wai/intersection_check.py
# aknapitsch (user)
# Commit message: "simpler inference and refactoring"
# Commit: 37de32d
import torch
from einops import rearrange, repeat
from tqdm import tqdm
def create_frustum_from_intrinsics(
intrinsics: torch.Tensor,
near: torch.Tensor | float,
far: torch.Tensor | float,
) -> torch.Tensor:
r"""
Create a frustum from camera intrinsics.
Args:
intrinsics (torch.Tensor): Bx3x3 Intrinsics of cameras.
near (torch.Tensor or float): [B] Near plane distance.
far (torch.Tensor or float): [B] Far plane distance.
Returns:
frustum (torch.Tensor): Bx8x3 batch of frustum points following the order:
5 ---------- 4
|\ /|
6 \ / 7
\ 1 ---- 0 /
\| |/
2 ---- 3
"""
fx, fy = intrinsics[:, 0, 0], intrinsics[:, 1, 1]
cx, cy = intrinsics[:, 0, 2], intrinsics[:, 1, 2]
# Calculate the offsets at the near plane
near_x = near * (cx / fx)
near_y = near * (cy / fy)
far_x = far * (cx / fx)
far_y = far * (cy / fy)
# Define frustum vertices in camera space
near_plane = torch.stack(
[
torch.stack([near_x, near_y, near * torch.ones_like(near_x)], dim=-1),
torch.stack([-near_x, near_y, near * torch.ones_like(near_x)], dim=-1),
torch.stack([-near_x, -near_y, near * torch.ones_like(near_x)], dim=-1),
torch.stack([near_x, -near_y, near * torch.ones_like(near_x)], dim=-1),
],
dim=1,
)
far_plane = torch.stack(
[
torch.stack([far_x, far_y, far * torch.ones_like(far_x)], dim=-1),
torch.stack([-far_x, far_y, far * torch.ones_like(far_x)], dim=-1),
torch.stack([-far_x, -far_y, far * torch.ones_like(far_x)], dim=-1),
torch.stack([far_x, -far_y, far * torch.ones_like(far_x)], dim=-1),
],
dim=1,
)
return torch.cat([near_plane, far_plane], dim=1)
def _frustum_to_triangles(frustum: torch.Tensor) -> torch.Tensor:
"""
Convert frustum to triangles.
Args:
frustums (torch.Tensor): Bx8 batch of frustum points.
Returns:
frustum_triangles (torch.Tensor): Bx3x3 batch of frustum triangles.
"""
triangle_inds = torch.tensor(
[
[0, 1, 2],
[0, 2, 3],
[0, 3, 7],
[0, 7, 4],
[1, 2, 6],
[1, 6, 5],
[1, 4, 5],
[1, 0, 4],
[2, 6, 7],
[2, 3, 7],
[6, 7, 4],
[6, 5, 4],
]
)
frustum_triangles = frustum[:, triangle_inds]
return frustum_triangles
def segment_triangle_intersection_check(
    start_points: torch.Tensor,
    end_points: torch.Tensor,
    triangles: torch.Tensor,
) -> torch.Tensor:
    """
    Check if segments (lines with starting and end point) intersect triangles in 3D using the
    Moller-Trumbore algorithm.

    The segment is turned into a unit-direction ray starting at ``start_points`` whose
    maximum length is the segment length, and the test is delegated to
    ``ray_triangle_intersection_check``. Zero-length segments produce NaN directions and
    are therefore reported as non-intersecting.

    Args:
        start_points (torch.Tensor): Bx3 Starting points of the segments.
        end_points (torch.Tensor): Bx3 End points of the segments.
        triangles (torch.Tensor): Bx3x3 Vertices of the triangles.

    Returns:
        intersects (torch.Tensor): B Boolean tensor indicating if each segment intersects its
            corresponding triangle.
    """
    directions = end_points - start_points
    lengths = torch.norm(directions, dim=1)
    # Normalising makes the ray parameter t a metric distance, so the segment end
    # corresponds to t == length.
    return ray_triangle_intersection_check(
        start_points,
        directions / lengths[:, None],
        triangles,
        max_lengths=lengths,
    )
def triangle_intersection_check(
    triangles1: torch.Tensor,
    triangles2: torch.Tensor,
) -> torch.Tensor:
    """
    Check if paired triangles intersect.

    ``triangles1[b]`` is tested against ``triangles2[b]``. Every edge of one triangle is
    tested as a segment against the other triangle (in both directions) with
    ``segment_triangle_intersection_check``; the pair intersects if any edge hits.
    Because that test rejects segments (near-)parallel to the triangle plane, coplanar
    overlapping triangles are not reported as intersecting.

    Args:
        triangles1 (torch.Tensor): Bx3x3 Vertices of the first batch of triangles.
        triangles2 (torch.Tensor): Bx3x3 Vertices of the second batch of triangles.

    Returns:
        triangle_intersection (torch.Tensor): B Boolean tensor indicating if triangles intersect.
    """
    num_pairs, n = triangles1.shape[:2]

    def _edges(tris):
        # Edge k runs from vertex k to vertex (k + 1) % n; flattened (B, edge)-major.
        starts = rearrange(tris, "B N C -> (B N) C")
        ends = rearrange(torch.roll(tris, shifts=-1, dims=1), "B N C -> (B N) C")
        return starts, ends

    start_points1, end_points1 = _edges(triangles1)
    start_points2, end_points2 = _edges(triangles2)

    # Repeat each opposing triangle once per edge so rows line up with the edges.
    hits_1_2 = segment_triangle_intersection_check(
        start_points1, end_points1, repeat(triangles2, "B N C -> (B E) N C", E=n)
    )
    hits_2_1 = segment_triangle_intersection_check(
        start_points2, end_points2, repeat(triangles1, "B N C -> (B E) N C", E=n)
    )

    return hits_1_2.reshape(num_pairs, n).any(dim=1) | hits_2_1.reshape(
        num_pairs, n
    ).any(dim=1)
def frustum_intersection_check(
    frustums: torch.Tensor,
    check_inside: bool = True,
    chunk_size: int = 500,
    device: str | None = None,
) -> torch.Tensor:
    """
    Check if any pair of the frustums intersect with each other.

    Each frustum is tiled into triangles (``_frustum_to_triangles``) and every triangle of
    one frustum is tested against every triangle of the other with
    ``triangle_intersection_check``.

    Args:
        frustums (torch.Tensor): Bx8x3 batch of frustum corner points.
        check_inside (bool): If True, the result is additionally OR-ed with
            ``frustums_in_frustum_check`` so that containment without surface crossing
            is also reported. Defaults to True.
        chunk_size (int): Number of frustums per chunk along each axis of the BxB
            matrix; bounds peak memory. Defaults to 500.
        device (Optional[str]): Device on which the exhaustive BxB intersection matrix is
            stored. Defaults to the device of ``frustums``.

    Returns:
        frustum_intersection (torch.Tensor): BxB symmetric Boolean tensor indicating if
            any pair of frustums intersect with each other.
    """
    B = frustums.shape[0]
    if device is None:
        device = frustums.device
    frustum_triangles = _frustum_to_triangles(frustums)  # [B, T, 3, 3]
    T = frustum_triangles.shape[1]

    # Perform frustum in frustum check if required
    if check_inside:
        frustum_intersection = frustums_in_frustum_check(
            frustums=frustums, chunk_size=chunk_size, device=device
        )
    else:
        frustum_intersection = torch.zeros((B, B), dtype=torch.bool, device=device)

    # Check triangle intersections in chunks
    for i in tqdm(range(0, B, chunk_size), desc="Checking triangle intersections"):
        i_end = min(i + chunk_size, B)
        chunk_i_size = i_end - i
        tri_i = frustum_triangles[i:i_end].reshape(chunk_i_size * T, 3, 3)
        # triangle_intersection_check tests edges in both directions, so it is
        # symmetric in its arguments: only tiles with j >= i are computed and the
        # transposed tile is filled by mirroring.
        for j in range(i, B, chunk_size):
            j_end = min(j + chunk_size, B)
            chunk_j_size = j_end - j
            tri_j = frustum_triangles[j:j_end].reshape(chunk_j_size * T, 3, 3)
            # Pair every triangle of chunk i with every triangle of chunk j
            tri_i_exp = repeat(tri_i, "bt i j -> (bt bj_t) i j", bj_t=chunk_j_size * T)
            tri_j_exp = repeat(tri_j, "bt i j -> (bi_t bt) i j", bi_t=chunk_i_size * T)
            batch_intersect = triangle_intersection_check(tri_i_exp, tri_j_exp)
            # A frustum pair intersects if any of its triangle pairs does
            batch_intersect = (
                batch_intersect.reshape(chunk_i_size, T, chunk_j_size, T)
                .any(dim=(1, 3))
                .to(device)
            )
            frustum_intersection[i:i_end, j:j_end] |= batch_intersect
            frustum_intersection[j:j_end, i:i_end] |= batch_intersect.T
    return frustum_intersection
def ray_triangle_intersection_check(
ray_origins: torch.Tensor,
ray_vectors: torch.Tensor,
triangles: torch.Tensor,
max_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Check if rays intersect triangles in 3D using the Moller-Trumbore algorithm, considering the
finite length of rays.
Args:
ray_origins (torch.Tensor): Bx3 Origins of the rays.
ray_vectors (torch.Tensor): Bx3 Direction vectors of the rays.
triangles (torch.Tensor): Bx3x3 Vertices of the triangles.
max_lengths Optional[torch.Tensor]: B Maximum lengths of the rays.
Returns:
intersects (torch.Tensor): B Boolean tensor indicating if each ray intersects its
corresponding triangle.
"""
vertex0 = triangles[:, 0, :]
vertex1 = triangles[:, 1, :]
vertex2 = triangles[:, 2, :]
edge1 = vertex1 - vertex0
edge2 = vertex2 - vertex0
h = torch.cross(ray_vectors, edge2, dim=1)
a = (edge1 * h).sum(dim=1)
epsilon = 1e-6
mask = torch.abs(a) > epsilon
f = torch.zeros_like(a)
f[mask] = 1.0 / a[mask]
s = ray_origins - vertex0
u = f * (s * h).sum(dim=1)
q = torch.cross(s, edge1, dim=1)
v = f * (ray_vectors * q).sum(dim=1)
t = f * (edge2 * q).sum(dim=1)
# Check conditions
intersects = (u >= 0) & (u <= 1) & (v >= 0) & (u + v <= 1) & (t >= epsilon)
if max_lengths is not None:
intersects &= t <= max_lengths
return intersects
#### Checks for frustums
def _frustum_to_planes(frustums: torch.Tensor) -> torch.Tensor:
r"""
Converts frustum parameters to plane representation.
Args:
frustums (torch.Tensor): Bx8 batch of frustum points following the order:
5 ---------- 4
|\ /|
6 \ / 7
\ 1 ---- 0 /
\| |/
2 ---- 3
Returns:
planes (torch.Tensor): Bx6x4 where 6 represents the six frustum planes and
4 represents plane parameters [a, b, c, d].
"""
planes = []
for inds in [[0, 1, 3], [1, 6, 2], [0, 3, 7], [2, 6, 3], [0, 5, 1], [6, 5, 4]]:
normal = torch.cross(
frustums[:, inds[1]] - frustums[:, inds[0]],
frustums[:, inds[2]] - frustums[:, inds[0]],
dim=1,
)
normal = normal / torch.norm(normal, dim=1, keepdim=True)
d = -torch.sum(normal * frustums[:, inds[0]], dim=1, keepdim=True)
planes.append(torch.cat([normal, d], -1))
return torch.stack(planes, 1)
def points_in_frustum_check(
    frustums: torch.Tensor,
    points: torch.Tensor,
    chunk_size: int | None = None,
    device: str | None = None,
):
    """
    Check if points are inside frustums.

    Each frustum is converted to six planes by ``_frustum_to_planes``; a point counts as
    inside when its signed distance ``n . p + d`` is non-negative for every plane
    (points lying exactly on a face are therefore inside).

    Args:
        frustums (torch.Tensor): Bx8x3 batch of frustum corner points.
        points (torch.Tensor): BxNx3 batch of points; points[b] is tested against frustums[b].
        chunk_size (Optional[int]): Number of frustums processed per chunk, to bound peak
            memory for large B. If None, all frustums are processed at once.
            Defaults to None.
        device (Optional[str]): Device the result is moved to. The computation itself
            runs on the device of ``frustums``. Defaults to the device of ``frustums``.

    Returns:
        inside (torch.Tensor): BxN Boolean tensor, True where points[b, n] is inside
            frustums[b].
    """
    if device is None:
        device = frustums.device

    def _signed_distances(frustum_chunk, point_chunk):
        # [b, 6, 4] plane parameters [a, b, c, d]
        planes = _frustum_to_planes(frustum_chunk)
        normals, offsets = planes[:, :, :-1], planes[:, :, -1]
        # [b, N, 6] signed distance of every point to every plane of its frustum
        return torch.einsum("bij,bnj->bni", normals, point_chunk) + offsets[:, None, :]

    num_frustums = frustums.shape[0]
    if chunk_size is None or num_frustums <= chunk_size:
        # Move the whole sum (not just one operand) so a foreign `device` works.
        point_plane_direction = _signed_distances(frustums, points).to(device)
    else:
        # Split computation into chunks to avoid OOM errors for large batch sizes
        point_plane_direction = torch.cat(
            [
                _signed_distances(
                    frustums[start : start + chunk_size],
                    points[start : start + chunk_size],
                ).to(device)
                for start in range(0, num_frustums, chunk_size)
            ]
        )
    return (point_plane_direction >= 0).all(-1)
def frustums_in_frustum_check(
    frustums: torch.Tensor,
    chunk_size: int,
    device: str | None = None,
    use_double_chunking: bool = True,
):
    """
    Check if frustums are contained within other frustums.

    Containment is approximated with a single vertex: frustum ``b`` is treated as lying
    inside frustum ``a`` when vertex 0 of ``b`` passes ``points_in_frustum_check`` against
    ``a``. The matrix is then symmetrised, so entry ``[a, b]`` is True if either frustum's
    vertex 0 lies inside the other.

    Args:
        frustums (torch.Tensor): Bx8x3 batch of frustum corner points.
        chunk_size (int): With double chunking, the number of frustums per chunk along each
            axis of the BxB matrix. Without it, the chunk size forwarded to
            ``points_in_frustum_check`` over the B*B flattened frustum pairs.
        device (Optional[str]): Device on which the exhaustive BxB containment matrix is
            stored. Defaults to the device of ``frustums``.
        use_double_chunking (bool): If True, fill the BxB matrix tile by tile to avoid
            OOM errors. Defaults to True.

    Returns:
        frustum_contained (torch.Tensor): BxB symmetric Boolean matrix indicating if
            frustums are inside other frustums.
    """
    num_frustums = frustums.shape[0]
    if device is None:
        device = frustums.device

    def _anchor_inside(outer, inner, point_chunk_size=None):
        # Pair each frustum of `outer` with vertex 0 of each frustum of `inner`;
        # entry [a, b] says whether inner[b]'s vertex 0 lies inside outer[a].
        n_outer = outer.shape[0]
        n_inner = inner.shape[0]
        tiled_outer = repeat(outer, "B ... -> (B B2) ...", B2=n_inner)
        tiled_anchor = repeat(inner[:, :1], "B ... -> (B2 B) ...", B2=n_outer)
        flags = points_in_frustum_check(
            tiled_outer, tiled_anchor, chunk_size=point_chunk_size
        )[:, 0]
        return rearrange(flags, "(B B2) -> B B2", B=n_outer).to(device)

    if not use_double_chunking:
        # One flattened B*B pass, chunked inside points_in_frustum_check.
        contained = _anchor_inside(frustums, frustums, point_chunk_size=chunk_size)
        return contained | contained.T

    frustum_contained = torch.zeros(
        (num_frustums, num_frustums), dtype=torch.bool, device=device
    )
    for row_start in tqdm(
        range(0, num_frustums, chunk_size), desc="Checking frustum containment"
    ):
        row_stop = min(row_start + chunk_size, num_frustums)
        outer = frustums[row_start:row_stop]
        for col_start in range(0, num_frustums, chunk_size):
            col_stop = min(col_start + chunk_size, num_frustums)
            block = _anchor_inside(outer, frustums[col_start:col_stop])
            # Write the tile and its mirror to keep the relation symmetric.
            frustum_contained[row_start:row_stop, col_start:col_stop] |= block
            frustum_contained[col_start:col_stop, row_start:row_stop] |= block.transpose(
                0, 1
            )
    return frustum_contained