文章目录
本文为PointNet++ CUDA代码阅读系列的第三部分,其他详见:
(一)PointNet++代码梳理
(二)PointNet++中的FPS的CUDA实现
(三)PointNet++中ball query的CUDA实现
(四)PointNet++中的Three_nn的CUDA实现
CUDA代码要在pytorch中使用,必须设置好CUDA代码与python的接口,并用python编写pytorch中的模块,这两部分详见PointNet++中的FPS的CUDA实现。本文直接看ball query的实现。
给定一个点云xyz,然后给定中心点new_xyz,给定半径和邻域内点的数量,Ball Query可以找出以new_xyz为中心的领域内包含的xyz中的点的下标。
直接看代码,仍然是用PointRCNN中的PointNet++的代码。先看在python中定义的函数,在pointnet2_utils.py中:
class BallQuery(Function):
@staticmethod
def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
"""
:param ctx:
:param radius: float, radius of the balls
:param nsample: int, maximum number of features in the balls
:param xyz: (B, N, 3) xyz coordinates of the features
:param new_xyz: (B, npoint, 3) centers of the ball query
:return:
idx: (B, npoint, nsample) tensor with the indicies of the features that form the query balls
"""
assert new_xyz.is_contiguous()
assert xyz.is_contiguous()
B, N, _ = xyz.size()
npoint = new_xyz.size(1)
idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)