Motivation
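In point cloud recognition, it is sometimes necessary to sample the points around a given point to form a local region. This post records how PointNet++ does this with ball query. A center point and a radius define a ball, and points are sampled from inside that ball. For each center, the query returns the indices of at most `nsample` points that fall inside it.
To make this fast, the query runs on the GPU.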
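The query itself is simple. The pure-Python sketch below handles a single center and only makes the semantics concrete. The name `ball_query_single` is hypothetical. The padding rule is an assumption about how the PointNet++ CUDA kernel behaves, not something stated in this post: unused slots repeat the first index found, and an empty ball yields all zeros.

```python
from typing import List, Sequence

Point = Sequence[float]


def ball_query_single(radius: float, nsample: int,
                      points: Sequence[Point], center: Point) -> List[int]:
    """Hypothetical helper: indices of up to nsample points strictly inside the ball.

    Assumed padding rule: unused slots repeat the first index found;
    if no point is inside the ball, every slot is 0.
    """
    r2 = radius * radius
    found: List[int] = []
    for k, p in enumerate(points):
        d2 = sum((a - b) ** 2 for a, b in zip(p, center))  # squared distance
        if d2 < r2:
            found.append(k)
            if len(found) == nsample:  # stop once all slots are filled
                break
    if not found:
        return [0] * nsample
    return found + [found[0]] * (nsample - len(found))


if __name__ == "__main__":
    pts = [(0.0, 0.0, 0.0), (0.5, 0.0, 0.0), (2.0, 0.0, 0.0), (0.0, 0.3, 0.0)]
    print(ball_query_single(1.0, 5, pts, (0.0, 0.0, 0.0)))  # [0, 1, 3, 0, 0]
```

Each center is handled independently of the others, which is what makes the query easy to parallelize on the GPU. In this sketch, when more than `nsample` points lie inside the ball, the first `nsample` in index order are kept. These are not necessarily the nearest points, and this behavior is again assumed to match the kernel.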
Implementation
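- Install PointNet++

Following https://github.com/sshaoshuai/Pointnet2.PyTorch, clone the repository and build its CUDA extension. This requires PyTorch and a working CUDA toolchain:

```shell
git clone https://github.com/sshaoshuai/Pointnet2.PyTorch.git
cd Pointnet2.PyTorch/pointnet2
python setup.py install
cd ../  # back to the repository root
```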
- Usage: call `pointnet2_utils.ball_query` from the repository's `pointnet2_utils` module.
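After calling `ball_query(radius, nsample, xyz, new_xyz)`, the returned `idx` (shape `(B, npoint, nsample)`, per the docstring of the interface) is usually used to cut the local regions out of the point cloud. The sketch below does this with plain PyTorch indexing. It gathers each ball's points and translates them into a frame centered on the ball's center, which is the local-coordinate normalization used in PointNet++. The helper `group_xyz` is hypothetical and not part of the library.

```python
import torch


def group_xyz(xyz: torch.Tensor, new_xyz: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: gather ball neighborhoods around each center.

    xyz:     (B, N, 3) input points
    new_xyz: (B, npoint, 3) ball centers
    idx:     (B, npoint, nsample) point indices, e.g. from ball_query
    returns: (B, npoint, nsample, 3) coordinates relative to each center
    """
    B = xyz.size(0)
    # Batch index of shape (B, 1, 1) broadcasts against idx.
    batch = torch.arange(B, device=xyz.device).view(B, 1, 1)
    grouped = xyz[batch, idx.long()]  # (B, npoint, nsample, 3)
    # Translate every ball so that its center sits at the origin.
    return grouped - new_xyz.unsqueeze(2)


if __name__ == "__main__":
    xyz = torch.tensor([[[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [2.0, 0.0, 0.0]]])
    new_xyz = torch.tensor([[[0.0, 0.0, 0.0]]])
    idx = torch.tensor([[[0, 1, 1]]], dtype=torch.int32)  # hand-written indices
    print(group_xyz(xyz, new_xyz, idx).shape)  # torch.Size([1, 1, 3, 3])
```

The same indexing works unchanged on CUDA tensors, so the GPU output can be passed in directly. The `.long()` call converts its int32 indices into the integer type used for indexing.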
The function's interface, as defined in the repository:
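```python
import torch
from torch.autograd import Function

# Compiled CUDA extension built by setup.py; pointnet2_utils.py imports it
# under the alias `pointnet2`.
import pointnet2_cuda as pointnet2


class BallQuery(Function):

    @staticmethod
    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param radius: float, radius of the balls
        :param nsample: int, maximum number of features in the balls
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: (B, npoint, 3) centers of the ball query
        :return:
            idx: (B, npoint, nsample) tensor with the indices of the features that form the query balls
        """
        # The CUDA kernel works on raw memory, so both inputs must be contiguous.
        assert new_xyz.is_contiguous()
        assert xyz.is_contiguous()

        B, N, _ = xyz.size()
        npoint = new_xyz.size(1)
        # Output buffer: int32 on the current CUDA device, zero-initialized.
        idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()

        # The wrapper fills idx in place on the GPU.
        pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
        return idx

    @staticmethod
    def backward(ctx, a=None):
        # Indices are not differentiable: one None per forward input
        # (radius, nsample, xyz, new_xyz).
        return None, None, None, None


ball_query = BallQuery.apply
```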
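`BallQuery.forward` only runs on CUDA tensors, so a CPU reference can be handy for checking results or for debugging on a machine without the extension. The sketch below is a vectorized pure-PyTorch stand-in with the same argument order and output shape. `ball_query_reference` is a hypothetical name. The sketch assumes the kernel's padding rule: unused slots repeat the first in-ball index, and an empty ball gives all zeros. Check that rule against the CUDA source in the repository before relying on it.

```python
import torch


def ball_query_reference(radius: float, nsample: int,
                         xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
    """Hypothetical pure-PyTorch stand-in for BallQuery.forward.

    xyz: (B, N, 3), new_xyz: (B, npoint, 3) -> idx: (B, npoint, nsample), int32.
    """
    B, N, _ = xyz.shape
    npoint = new_xyz.size(1)
    # Squared distance from every center to every point: (B, npoint, N).
    d2 = ((new_xyz.unsqueeze(2) - xyz.unsqueeze(1)) ** 2).sum(dim=-1)
    inside = d2 < radius ** 2
    # In-ball points keep their index k, others become k + N, so an ascending
    # sort lists in-ball indices first, in their original order.
    ar = torch.arange(N, device=xyz.device).expand_as(d2)
    keys = torch.where(inside, ar, ar + N).sort(dim=-1).values
    k = min(nsample, N)
    idx = keys[..., :k]
    if k < nsample:  # fewer points than slots: pad, overwritten below
        idx = torch.cat([idx, idx.new_zeros(B, npoint, nsample - k)], dim=-1)
    cnt = inside.sum(dim=-1, keepdim=True)  # (B, npoint, 1) hits per ball
    # Assumed padding rule: repeat the first hit; empty balls get index 0.
    first = torch.where(cnt > 0, idx[..., :1], torch.zeros_like(idx[..., :1]))
    slot = torch.arange(nsample, device=xyz.device)
    idx = torch.where(slot < cnt, idx, first)
    return idx.int()


if __name__ == "__main__":
    xyz = torch.tensor([[[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 0.3, 0.0]]])
    new_xyz = torch.tensor([[[0.0, 0.0, 0.0], [5.0, 5.0, 5.0]]])
    print(ball_query_reference(1.0, 5, xyz, new_xyz).tolist())
    # [[[0, 1, 3, 0, 0], [0, 0, 0, 0, 0]]]
```

This version builds the full `(B, npoint, N)` distance matrix, so memory grows with `npoint * N`. That is fine for tests, but the per-center loop on the GPU does not need it. On a CUDA machine with the extension installed, the output could be compared with `ball_query(radius, nsample, xyz.cuda().contiguous(), new_xyz.cuda().contiguous())`. Points lying almost exactly on the sphere may still disagree because of floating-point rounding.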