File size: 1,008 Bytes
2fd6166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from torch.autograd import Function

from .backend import _backend

# Public API of this module: only the functional alias below is exported.
__all__ = ['grouping']


class Grouping(Function):
    """Autograd Function that gathers neighbor features for each center.

    Both directions are delegated to the compiled extension:
    ``_backend.grouping_forward`` computes the grouped output and
    ``_backend.grouping_backward`` computes the gradient for ``features``.
    """

    @staticmethod
    def forward(ctx, features, indices):
        """
        Gather the features of every center's neighbors.

        :param ctx: autograd context; stores what ``backward`` needs
        :param features: features of points, FloatTensor[B, C, N]
        :param indices: neighbor indices of centers, IntTensor[B, M, U],
            M is #centers, U is #neighbors
        :return:
            grouped_features: grouped features, FloatTensor[B, C, M, U]
        """
        # The backend receives contiguous inputs only.
        feats = features.contiguous()
        idx = indices.contiguous()
        # backward hands both the saved indices and the point count N to
        # _backend.grouping_backward (presumably to size the [B, C, N]
        # gradient -- confirm against the kernel).
        ctx.save_for_backward(idx)
        ctx.num_points = feats.size(-1)
        return _backend.grouping_forward(feats, idx)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Propagate the incoming gradient back to ``features``.

        :param ctx: autograd context populated by ``forward``
        :param grad_output: gradient w.r.t. the output of ``forward``
        :return: (gradient w.r.t. ``features``, None); ``indices`` receives
            no gradient
        """
        (saved_idx,) = ctx.saved_tensors
        contiguous_grad = grad_output.contiguous()
        grad_features = _backend.grouping_backward(
            contiguous_grad, saved_idx, ctx.num_points
        )
        # NOTE(review): if the backend kernel is not itself differentiable,
        # decorating this method with @once_differentiable would make double
        # backward fail loudly instead of silently -- confirm before adding.
        return grad_features, None

# Functional entry point: ``grouping(features, indices)`` runs
# ``Grouping.forward`` under autograd and returns the grouped features.
grouping = Grouping.apply