topk 是 PyTorch 中的一个函数,用于从张量中选取最大(或最小)的 k 个元素及其对应的索引。其定义如下:
values, indices = torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)
参数说明
- input (Tensor): 输入张量。
- k (int): 要选取的最大(或最小)元素的数量。
- dim (int, 可选): 指定沿着哪个维度进行操作。默认为 None,此时沿着最后一个维度进行操作。
- largest (bool, 可选): 如果为True,则选取最大的 k 个元素;如果为 False,则选取最小的 k 个元素。默认为 True。
- sorted (bool, 可选): 如果为 True,则返回的值是排序过的(即最大的值排在前面)。如果为 False,则返回的值是按照它们在原张量中的顺序排列。默认为 True。
- out (tuple, 可选): 可以指定一个元组来存储输出结果。元组应该包含两个张量,分别用于存储值和索引。默认为None。
代码片段赏析:
def get_embedding_indices(self, points):
r"""Compute the indices of pair-wise distance embedding and triplet-wise angular embedding.
Args:
points: torch.Tensor (B, N, 3), input point cloud
Returns:
d_indices: torch.FloatTensor (B, N, N), distance embedding indices
a_indices: torch.FloatTensor (B, N, N, k), angular embedding indices
"""
batch_size, num_point, _ = points.shape
dist_map = torch.sqrt(pairwise_distance(points, points)) # (B, N, N)
d_indices = dist_map / self.sigma_d
k = self.angle_k
##! largest=False的含义是选择K个距离最小的点, dim=2代表从dist_map的第二个维度来选择
knn_indices = dist_map.topk(k=k + 1, dim=2, largest=False)[1][:, :, 1:] # 这里将 dist_map.topk会返回values和indices, 用索引1来取出indices后, 再从所有的knn中去掉自身, 所以取[:, :, 1:]