【代码阅读】PointNet++中ball query的CUDA实现

本文是PointNet++ CUDA代码阅读系列的第三部分,聚焦于Ball Query的CUDA实现。介绍了如何在CUDA中找到点云中以指定中心点为中心、半径内的点的下标,涉及pointnet2.ball_query_wrapper的python定义和对应的cpp及cuda代码详解。
摘要由CSDN通过智能技术生成

文章目录

本文为PointNet++ CUDA代码阅读系列的第三部分,其他详见:
(一)PointNet++代码梳理
(二)PointNet++中的FPS的CUDA实现
(三)PointNet++中ball query的CUDA实现
(四)PointNet++中的Three_nn的CUDA实现


CUDA代码要在pytorch中使用,必须设置好CUDA代码与python的接口,并用python编写pytorch中的模块,这两部分详见PointNet++中的FPS的CUDA实现。本文直接看ball query的实现。

给定一个点云xyz,然后给定中心点new_xyz,给定半径和邻域内点的数量,Ball Query可以找出以new_xyz为中心的领域内包含的xyz中的点的下标。

直接看代码,仍然是用PointRCNN中的PointNet++的代码。先看在python中定义的函数,在pointnet2_utils.py中:

class BallQuery(Function):

    @staticmethod
    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param radius: float, radius of the balls
        :param nsample: int, maximum number of features in the balls
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: (B, npoint, 3) centers of the ball query
        :return:
            idx: (B, npoint, nsample) tensor with the indicies of the features that form the query balls
        """
        assert new_xyz.is_contiguous()
        assert xyz.is_contiguous()

        B, N, _ = xyz.size()
        npoint = new_xyz.size(1)
        idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()

        pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
`scipy.spatial.cKDTree.query_ball_point`的`count`参数是一个布尔值,用于指是否只返回满足条件的点的数量,而不返回点索引。 当`count_only为`True`时,函数将返回满足条件的点的数量,而不返回点的索引。这在你只关心满足条件的点的数量而不关心具体点的情况下很有用。 当`count_only`为`False`时,函数将返回满足条件的点的索引列表。这对于需要知道具体满足条件的点的索引的情况很有用。 下面是一个示例: ```python import numpy as np from scipy.spatial import cKDTree # 创建一个包含10个二维点的数组 points = np.random.rand(10, 2) # 创建一个cKDTree对象 tree = cKDTree(points) # 查询距离原点(0, 0)距离小于0.5的点的数量 count = tree.query_ball_point([0, 0], 0.5, count_only=True) print("满足条件的点的数量:", count) # 查询距离原点(0, 0)距离小于0.5的点的索引列表 indices = tree.query_ball_point([0, 0], 0.5, count_only=False) print("满足条件的点的索引列表:", indices) ``` 输出结果示例: ``` 满足条件的点的数量: 3 满足条件的点的索引列表: [0, 2, 6] ``` 在上面的示例,我们创建了一个包含10个二维点的数组,并使用`cKDTree`构建了一个树。然后,我们使用`query_ball_point`方法查询距离原点(0, 0)距离小于0.5的点。通过设置`count_only`参数为`True`,我们只获取满足条件的点的数量;通过设置为`False`,我们获取满足条件的点的索引列表。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值