import numpy as np
import torch
from torch.autograd import Function

from .backend import _backend

__all__ = ['gather', 'furthest_point_sample', 'logits_mask']
class Gather(Function):
    @staticmethod
    def forward(ctx, features, indices):
        """
        Gather
        :param ctx:
        :param features: features of points, FloatTensor[B, C, N]
        :param indices: centers' indices in points, IntTensor[B, M]
        :return:
            centers_features: gathered features of the sampled centers, FloatTensor[B, C, M]
        """
        features = features.contiguous()
        indices = indices.int().contiguous()
        ctx.save_for_backward(indices)
        ctx.num_points = features.size(-1)
        return _backend.gather_features_forward(features, indices)

    @staticmethod
    def backward(ctx, grad_output):
        indices, = ctx.saved_tensors
        grad_features = _backend.gather_features_backward(grad_output.contiguous(), indices, ctx.num_points)
        return grad_features, None


gather = Gather.apply
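
# Illustrative usage sketch (not part of the original module): gather features
# at the given center indices. Shapes follow the docstring above; the random
# tensors are placeholder data, and a CUDA device is assumed since the
# _backend kernels are GPU extensions.
#
#   features = torch.rand(2, 64, 1024).cuda()                 # [B, C, N]
#   indices = torch.randint(0, 1024, (2, 128)).int().cuda()   # [B, M]
#   centers = gather(features, indices)                       # [B, C, M] == [2, 64, 128]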
def furthest_point_sample(coords, num_samples):
    """
    Uses iterative furthest point sampling to select num_samples points that have the largest
    minimum distance to the already sampled point set
    :param coords: coordinates of points, FloatTensor[B, 3, N]
    :param num_samples: int, M
    :return:
        centers_coords: coordinates of sampled centers, FloatTensor[B, 3, M]
    """
    coords = coords.contiguous()
    indices = _backend.furthest_point_sampling(coords, num_samples)
    return gather(coords, indices)
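
# Illustrative usage sketch (assumption, not from the original file): pick
# M = 512 well-spread centers from a point cloud; the backend kernel is
# expected to run on CUDA tensors.
#
#   coords = torch.rand(2, 3, 4096).cuda()               # [B, 3, N]
#   centers_coords = furthest_point_sample(coords, 512)  # [B, 3, 512]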
def logits_mask(coords, logits, num_points_per_object):
    """
    Use logits to sample points
    :param coords: coords of points, FloatTensor[B, 3, N]
    :param logits: binary classification logits, FloatTensor[B, 2, N]
    :param num_points_per_object: M, #points per object after masking, int
    :return:
        selected_coords: FloatTensor[B, 3, M]
        masked_coords_mean: mean coords of selected points, FloatTensor[B, 3]
        mask: mask to select points, BoolTensor[B, N]
    """
    batch_size, _, num_points = coords.shape
    mask = torch.lt(logits[:, 0, :], logits[:, 1, :])  # [B, N], True where the positive-class logit wins
    num_candidates = torch.sum(mask, dim=-1, keepdim=True)  # [B, 1]
    masked_coords = coords * mask.view(batch_size, 1, num_points)  # [B, C, N]
    masked_coords_mean = torch.sum(masked_coords, dim=-1) / torch.max(
        num_candidates, torch.ones_like(num_candidates)).float()  # [B, C], clamp denominator to avoid division by zero
    selected_indices = torch.zeros((batch_size, num_points_per_object), device=coords.device, dtype=torch.int32)
    for i in range(batch_size):
        current_mask = mask[i]  # [N]
        current_candidates = current_mask.nonzero().view(-1)
        current_num_candidates = current_candidates.numel()
        if current_num_candidates >= num_points_per_object:
            # enough candidates: sample M of them without replacement
            choices = np.random.choice(current_num_candidates, num_points_per_object, replace=False)
            selected_indices[i] = current_candidates[choices]
        elif current_num_candidates > 0:
            # too few candidates: repeat each candidate evenly, then fill the remainder randomly
            choices = np.concatenate([
                np.arange(current_num_candidates).repeat(num_points_per_object // current_num_candidates),
                np.random.choice(current_num_candidates, num_points_per_object % current_num_candidates,
                                 replace=False)
            ])
            np.random.shuffle(choices)
            selected_indices[i] = current_candidates[choices]
        # if there are no candidates at all, the indices stay zero (all selections point at index 0)
    selected_coords = gather(masked_coords - masked_coords_mean.view(batch_size, -1, 1), selected_indices)
    return selected_coords, masked_coords_mean, mask
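
# Illustrative usage sketch (assumption, not from the original file): mask a
# point cloud with per-point binary logits and sample a fixed number of points
# per object. The random logits here are placeholders; in practice they come
# from a segmentation head.
#
#   coords = torch.rand(2, 3, 4096).cuda()   # [B, 3, N]
#   logits = torch.rand(2, 2, 4096).cuda()   # [B, 2, N]
#   selected_coords, coords_mean, mask = logits_mask(coords, logits, num_points_per_object=1024)
#   # selected_coords: [B, 3, 1024], zero-centered by coords_mean; mask: [B, N]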