PointNet++ 代码阅读四 query_ball_point函数
(本文是通过阅读https://blog.csdn.net/weixin_39373480/article/details/88934146之后带来的感悟,基本内容来自以上链接,侵删)
query_ball_point函数用于寻找球形领域中的点。输入中radius为球形领域的半径,nsample为每个领域中要采样的点,new_xyz为S个球形领域的中心(由最远点采样在前面得出),xyz为所有的点云;输出为每个样本的每个球形领域的nsample个采样点集的索引[B,S,nsample]
def query_ball_point(radius, nsample, xyz, new_xyz):
"""
Input:
radius: local region radius
nsample: max sample number in local region
xyz: all points, [B, N, C]
new_xyz: query points, [B, S, C]
Return:
group_idx: grouped points index, [B, S, nsample]
"""
#将输入点的坐标写入GPU中
device = xyz.device
#读出输入点以及输出点的形状大小,注意只是形状大小而不包含内容,这里的B,N, C,S都是一个数
B, N, C = xyz.shape
_, S, _ = new_xyz.shape
#创建输出的张量,注意这里的group_idx的维度是[B, S, N],其中B即该点属于哪个样本,S为该点属于哪个圆形区域
#N为该点的坐标的索引,其中N的值为1-N
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
#sqrdists: [B, S, N] 记录中心点与所有点之间的欧几里德距离
#这里的square_distance为PointNet++定义的函数,注意这里每一个样本点就会产生一个圆形区域,
#所以sqrdists与group_idx的各个维度表示的意义一致
sqrdists = square_distance(new_xyz, xyz)
#找到所有距离大于radius^2的,其group_idx直接置为N;其余的保留原来的值
#注意,由于是批量操作,所以这里只有哪个圆形区域都不属于的点才会被赋值为N
group_idx[sqrdists > radius ** 2] = N
#做升序排列,前面大于radius^2的都是N,会是最大值,所以会直接在剩下的点中取出前nsample个点
#注意这句语句结束后的group_idx的大小与一开始发生变化,他现在只包含圆形区域内的点
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
#考虑到有可能前nsample个点中也有被赋值为N的点(即球形区域内不足nsample个点),这种点需要舍弃,直接用第一个点来代替即可
#group_first: [B, S, k], 实际就是把group_idx中的第一个点的值复制为了[B, S, K]的维度,便利于后面的替换
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
#找到group_idx中值等于N的点
#这里的mask返回的是前面判断语句为true的索引值
mask = group_idx == N
#将这些点的值替换为第一个点的值
#即用原始点替换圆形区域内的为N的点
group_idx[mask] = group_first[mask]
return group_idx