最远点采样(FPS)可以设置采样的点数,比起网格采样等方法更加实用,被应用在点云处理方法中(例如PointNet++)。缺点是每次要计算两个集合中所有点的相对距离,计算量较大。但实际上,采取矩阵运算的实现方式,计算速率也能接受。
下面是FPS的Python实现,主要包括计算索引和映射点两步骤,度量使用欧式距离。
def FarthestPointSampling_ForBatch(xyz, npoint):
B, N, C = xyz.shape
centroids = np.zeros((B, npoint))
distance = np.ones((B, N)) * 1e10
batch_indices = np.arange(B)
barycenter = np.sum((xyz), 1)
barycenter = barycenter/xyz.shape[1]
barycenter = barycenter.reshape(B, 1, 3)
dist = np.sum((xyz - barycenter) ** 2, -1)
farthest = np.argmax(dist,1)
for i in range(npoint):
print("-------------------------------------------------------")
print("The %d farthest pts %s " % (i, farthest))
centroids[:, i] = farthest
centroid = xyz[batch_indices, farthest, :].reshape(B, 1, 3)
dist = np.sum((xyz - centroid) ** 2, -1)
mask = dist < distance
distance[mask] = dist[mask]
farthest = np.argmax(distance, -1)
return centroids.astype('int')
def index_points(points, idx):
batch_indices = np.arange(points.shape[0])
batch_indices = batch_indices.repeat(idx.shape[1]).reshape((-1,idx.shape[1])) #(batch_size,npoints)
print(batch_indices.shape)
new_points = points[batch_indices,idx, :]
return new_points