Detailed PointNet code: https://github.com/yanx27/Pointnet_Pointnet2_pytorch
Today I'll focus on what the mask in this code actually does.
1. Let i be the step index of the centroid selection. The centroid chosen at step i is cen(i), and d(i,j) is the distance from cen(i) to the j-th point x(j). Computing the distance from cen(i) to all points yields a 1-D array dist(i).
2. On the first update, the distance matrix is entirely overwritten with dist(1), because it records the shortest distance seen so far. On the second update, after every point's distance to cen(2) is computed, dist(2) is compared with dist(1) point by point and the shorter d(i,j) is kept.
(In other words: if some point x(j) has d(1,j) shorter than d(2,j), keep d(1,j); if d(2,j) is shorter than d(1,j), keep d(2,j).)
And so on: each refresh of distance keeps the shorter distances, and the point with the maximum value among these shortest distances becomes the next centroid cen(i).
3. Note that on every distance computation, cen(i) coincides with one of the points x(j); for that point d(i,j) = 0, which immediately overwrites distance at that position (you can think of that point as now being marked as a centroid, so the argmax will never pick it again).
4. In effect, each point keeps the distance to its nearest centroid cen(i); the final distance matrix holds every point's distance to its closest centroid.
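Points 2–4 boil down to an element-wise minimum. Here is a minimal toy sketch (the tensor values are made up for illustration, not taken from the repo) of what the mask update does in one iteration:

```python
import torch

# "best distance to any centroid so far" for 4 points
distance = torch.tensor([4.0, 9.0, 1.0, 16.0])
# squared distances of the same 4 points to the newly chosen centroid;
# the last point IS the new centroid, so its distance is 0
dist = torch.tensor([7.0, 2.0, 5.0, 0.0])

mask = dist < distance       # True where the new centroid is closer
distance[mask] = dist[mask]  # keep the shorter distance for each point
print(distance)                    # tensor([4., 2., 1., 0.])
print(torch.max(distance, -1)[1])  # tensor(0) -- the next centroid
```

The argmax over the per-point minimum distances picks the point farthest from all centroids chosen so far, which is exactly the "farthest point" in farthest point sampling.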
import torch

def farthest_point_sample(xyz, npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10
    # one random initial centroid per batch
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        # record the index of the i-th centroid
        centroids[:, i] = farthest
        # gather the coordinates of the current centroid
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        # squared Euclidean distance (Pythagorean theorem)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        # the mask keeps the shorter distance for every point
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids
torch.manual_seed(3407)
xyz = torch.randn(1, 512, 3)
npoint = 128
centroids = farthest_point_sample(xyz, npoint)
# print all sampled centroid indices
print(centroids)
tensor([[444, 294, 243, 99, 423, 488, 250, 494, 333, 15, 286, 306, 129, 67,
132, 332, 9, 112, 460, 271, 421, 79, 220, 121, 230, 12, 222, 464,
316, 262, 120, 285, 298, 367, 233, 255, 317, 147, 497, 185, 409, 190,
412, 303, 202, 217, 273, 443, 325, 42, 414, 462, 184, 216, 283, 390,
4, 291, 245, 371, 503, 471, 145, 52, 223, 211, 51, 164, 295, 68,
314, 287, 483, 224, 365, 388, 457, 165, 486, 476, 274, 92, 463, 205,
10, 482, 16, 320, 411, 109, 77, 410, 41, 351, 123, 134, 207, 70,
24, 455, 467, 375, 127, 369, 334, 304, 73, 509, 310, 116, 176, 117,
153, 253, 3, 282, 401, 432, 236, 296, 427, 126, 93, 359, 248, 382,
194, 292]])
The goal is to keep the samples spread out evenly and to prevent the sampling from piling up around the initial point and the second point.
Here is a control experiment of mine:
def farthest_point_sample(xyz, npoint):
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    # change the initial comparison distance to 0
    distance = torch.zeros(B, N).to(device)
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        # changed to keep the longer distance instead
        mask = dist > distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids
Same input, same random seed, compared against the previous version:
tensor([[444, 294, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
15, 15]])
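A quick way to quantify the collapse (my own check, not from the original post) is to put both variants behind a hypothetical `keep_min` flag and count how many distinct indices each returns on the same input and RNG state. Once a far point's large distance is recorded under `dist > distance`, it is never overwritten, so the argmax keeps returning the same index:

```python
import torch

def fps(xyz, npoint, keep_min=True):
    """Farthest point sampling; keep_min=False reproduces the control experiment."""
    B, N, _ = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long)
    distance = torch.full((B, N), 1e10) if keep_min else torch.zeros(B, N)
    farthest = torch.randint(0, N, (B,), dtype=torch.long)
    batch_indices = torch.arange(B, dtype=torch.long)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance if keep_min else dist > distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids

torch.manual_seed(3407)
xyz = torch.randn(1, 512, 3)
state = torch.get_rng_state()   # so both runs draw the same initial centroid
good = fps(xyz, 128, keep_min=True)
torch.set_rng_state(state)
bad = fps(xyz, 128, keep_min=False)
print(good.unique().numel())  # 128 -- every sampled centroid is distinct
print(bad.unique().numel())   # only a handful -- sampling has collapsed
```

With the min-distance update, every already-chosen centroid sits at distance 0, so the argmax always lands on a new point and all 128 indices are distinct; with the max-distance update, most of the 128 slots repeat the same stuck index, matching the output above.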