官网介绍:
输入是两个点云P1,P2(也可以是同一个,比如都是P1),
举个例子:
import torch
from pytorch3d.ops import ball_query
# 示例点云数据
N = 2 # 批次大小
P1 = 5 # 源点云中的点数
P2 = 3 # 查询点云中的点数
p1 = torch.rand(N, P1, 3) # 形状为 (N, P1, 3) 的随机源点云
p2 = torch.rand(N, P2, 3) # 形状为 (N, P2, 3) 的随机查询点云
# 球查询参数
radius = 0.5
K = 2
# 执行球查询
#要做的是以p1的点为中心,在p2中找邻居点,最大邻居可以找k个,不够K个的res.idx用-1填充
#res.dist用0填充
#返回res.idx是(N,P1,K)的shape
res = ball_query(p1, p2, radius=radius, K=K)
print("源点云 p1:", p1)
print("查询点云 p2:", p2)
print("邻居点索引 idx:", res.idx)
现在p1是(N,P1, 3)的3维点云,p2是(N, P2, 3)的3维点云。
现在以p1的点为中心,在p2里面找每个p1点的邻居,
邻居的上限限制在K,
找到的邻居不一定是距离最近的,但是是指定半径radius内的前K个。
如果找到的邻居数不够K个,那么res.dist里面会用0填充,res.idx会用-1填充。
返回res.idx的shape为(N,P1, K), 和p1的shape是一样的,也就是每个p1中的点找到的K个邻居。
引申:
看下面的代码
cutout_point_cloud = original_pcd[cutout_mask > 0]
for i in range(grow_iter):
num_points_in_seed = seed_pcd.shape[0]
res = pytorch3d.ops.ball_query(
cutout_point_cloud.unsqueeze(0),
seed_pcd.unsqueeze(0),
K=1,
radius=thresh,
return_nn=False
).idx
mask = (res != -1).sum(-1) != 0 #对最后一维求和,找到的邻居数量
mask = mask.squeeze()
seed_pcd = cutout_point_cloud[mask, :]
这段代码的初始pcd是cutout_point_cloud,要查询的pcd是seed_pcd(分割出的目标点云),
它要做的是以cutout_point_cloud的每个点为中心,在seed_pcd找K个邻居,
如果找到了邻居, 说明它不是孤立点,在分割目标附近,那么把这个点也加到分割目标里面,
迭代grow_iter次。
可以看作一个区域点的生长算法。