Spaces:
Running
Running
| // Copyright (c) OpenMMLab. All rights reserved | |
| using namespace parrots; | |
| void chamfer_distance_forward_cuda_parrots(CudaContext& ctx, | |
| const SSElement& attr, | |
| const OperatorBase::in_list_t& ins, | |
| OperatorBase::out_list_t& outs) { | |
| auto xyz1 = buildATensor(ctx, ins[0]); | |
| auto xyz2 = buildATensor(ctx, ins[1]); | |
| auto dist1 = buildATensor(ctx, outs[0]); | |
| auto dist2 = buildATensor(ctx, outs[1]); | |
| auto idx1 = buildATensor(ctx, outs[2]); | |
| auto idx2 = buildATensor(ctx, outs[3]); | |
| chamfer_distance_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); | |
| } | |
| void chamfer_distance_backward_cuda_parrots(CudaContext& ctx, | |
| const SSElement& attr, | |
| const OperatorBase::in_list_t& ins, | |
| OperatorBase::out_list_t& outs) { | |
| auto xyz1 = buildATensor(ctx, ins[0]); | |
| auto xyz2 = buildATensor(ctx, ins[1]); | |
| auto idx1 = buildATensor(ctx, ins[2]); | |
| auto idx2 = buildATensor(ctx, ins[3]); | |
| auto graddist1 = buildATensor(ctx, ins[4]); | |
| auto graddist2 = buildATensor(ctx, ins[5]); | |
| auto gradxyz1 = buildATensor(ctx, outs[0]); | |
| auto gradxyz2 = buildATensor(ctx, outs[1]); | |
| chamfer_distance_backward(xyz1, xyz2, idx1, idx2, graddist1, graddist2, | |
| gradxyz1, gradxyz2); | |
| } | |
| PARROTS_EXTENSION_REGISTER(chamfer_distance_forward) | |
| .input(2) | |
| .output(4) | |
| .apply(chamfer_distance_forward_cuda_parrots) | |
| .done(); | |
| PARROTS_EXTENSION_REGISTER(chamfer_distance_backward) | |
| .input(6) | |
| .output(2) | |
| .apply(chamfer_distance_backward_cuda_parrots) | |
| .done(); | |