FPS(最远点采样算法)原理与代码详解
介绍
最远点采样(Farthest Point Sampling)是一种非常常用的采样算法,相较于随机采样算法,FPS能够保证对样本采样的点集更均匀,缺点是计算量太大,耗时较多。
原理
最远点采样本质就是寻找与其他已采样点最远的点作为下一个采样点,尽量保证采样均匀,尽量能够保证采样之后的数据依旧能够看出物体的大致轮廓。
最远采样算的主要分为以下几个步骤:
-
对于一堆需要采样的点集(1000个点,按序编号1-1000),首先随机选取一个点(1)作为采样点。(好了,我们现在有了第一个采样点(1)以及未采样点集(2-1000))
-
依次计算每个点到采样点之间的距离,一般取欧式距离,选取距离最大的点作为下一个采样点。(这里假设点(38)距离采样点(1)距离最远,所以点(38)成为第二个采样点,现在采样点集{1,38})
-
依次计算每个点(未采样点)与采样点集(1,38)的距离,保留最小距离(ps:点3与采样点集(点1,点38)的距离分别为x1,x2(x1>x2),保留x1),全部计算完成后,得到(x1,y1……),取其中最大值的点为第三个采样点。
-
按照第三步,不断的保留点与采样点集的最小值,直到采样足够的点数。
# 最远点采样规则是离已采样集中任何点的距离的最远,比如数字1它离点集{0, 8}的距离就是它与0的距离,千万不要理解为离已采样集合所有点距离之和。
# 故选取点集时所谓的最远距离就是点与采样点集中最小值列表里面的最大值
def farthest_point_sample(xyz, npoint):
"""
Input:
xyz: pointcloud data, [B, N, 3]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [B, npoint]
"""
device = xyz.device
B, N, C = xyz.shape
# 创建大小为(B, npoint)的零张量,用于存储采样点的索引
centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
# 创建大小为(B, N)的张量,并初始化为一个较大的值,用于存储每个点到采样点的距离
distance = torch.ones(B, N).to(device) * 1e10
# 从点云中随机选择一个点作为初始采样点
farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
batch_indices = torch.arange(B, dtype=torch.long).to(device)
for i in range(npoint):#循环采样npoint个点
# 将当前最远点的索引添加到centroids中
centroids[:, i] = farthest
# 获取上一个采样点的坐标,形状为(B, 1, 3)
centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
# 计算每个点与上一个采样点的距离平方,-1表示沿着最后一个维度(坐标维度)求欧式距离
dist = torch.sum((xyz - centroid) ** 2, -1)
# 建立一个mask,如果dist中的元素小于distance矩阵中保存的距离值,则更新distance中的对应值
# 随着迭代的继续,distance矩阵中的值会慢慢变小,
# 其相当于记录着某个Batch中每个点距离所有已出现的采样点的最小距离
# 每次迭代都需要计算未选择的点与已采样点集中所有点的最小距离的最大值,但采样点是每次迭代中产生的,
# 且点的位置不会随意变动,所以每次迭代保留未选择点与已采样点集中的最小距离即可,如果是更小的距离就需要更新,如果不是就不需要,这样每次迭代时就计算了每个点与上一个采样点的距离,每次迭代只需要保留最小值即可,节省了内存开销。
mask = dist < distance
distance[mask] = dist[mask]
# 选择距离最远的点作为下一个采样点
farthest = torch.max(distance, -1)[1]
return centroids```