代码
返回采样点索引
def farthest_point_sample1(xyz, npoint):
    """Pick ``npoint`` point indices per batch item by farthest point sampling.

    Starting from a randomly drawn point in every batch item, each iteration
    selects the point whose squared Euclidean distance to the already
    selected set is largest.

    Input:
        xyz: pointcloud data, [B, N, C] (C is 3 for plain xyz coordinates;
            other channel counts are accepted and all channels take part
            in the distance)
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint], dtype torch.long
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Running squared distance from every point to its nearest selected
    # centroid; 1e10 is a large sentinel so the first real distance wins.
    distance = torch.full((B, N), 1e10, device=device)
    # Random starting index per batch item (drawn with the default generator,
    # then moved to the input's device).
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    for i in range(npoint):
        centroids[:, i] = farthest
        # Use C rather than a literal 3 so any channel count reshapes cleanly.
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        distance = torch.min(distance, dist)
        # Next centroid: the point farthest from everything chosen so far.
        farthest = torch.max(distance, -1)[1]
    return centroids
返回采样点坐标
def farthest_point_sample(xyz, npoint):
    """Return ``npoint`` points of a single cloud chosen by farthest point sampling.

    Starting from a randomly drawn point, each iteration selects the point
    whose squared Euclidean distance to the already selected set is largest.

    :param xyz: pointcloud data, [N, C] (documented by the original as [N, 3])
    :param npoint: number of samples
    :return:
        result: sampled pointcloud points, [npoint, C], gathered from ``xyz``
    """
    device = xyz.device
    N, _ = xyz.shape
    # Indices are stored as int64: a float32 buffer cannot represent every
    # integer above 2**24, which would silently corrupt large indices.
    centroids = torch.zeros(npoint, dtype=torch.long, device=device)
    # Running squared distance from every point to its nearest selected
    # centroid; 1e10 is a large sentinel so the first real distance wins.
    distance = torch.full((N,), 1e10, device=device)
    # Random starting index (drawn with the default generator, then moved to
    # the input's device).
    farthest = torch.randint(0, N, ()).to(device)
    for i in range(npoint):
        centroids[i] = farthest
        centroid = xyz[farthest, :]
        # Cast to float so the values can be written into ``distance``.
        dist = torch.sum((xyz - centroid) ** 2, -1).float()
        # Keep, per point, the smaller of the old and the new distance.
        mask = dist < distance
        distance[mask] = dist[mask]
        # Next centroid: the point farthest from everything chosen so far.
        farthest = torch.max(distance, -1)[1]
    result = xyz[centroids]
    return result