【Pytorch】topk函数

topk 是 PyTorch 中的一个函数,用于从张量中选取最大(或最小)的 k 个元素及其对应的索引。其定义如下:

values, indices = torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)

参数说明

  • input (Tensor): 输入张量。
  • k (int): 要选取的最大(或最小)元素的数量。
  • dim (int, 可选): 指定沿着哪个维度进行操作。默认为 None,此时沿着最后一个维度进行操作。
  • largest (bool, 可选): 如果为True,则选取最大的 k 个元素;如果为 False,则选取最小的 k 个元素。默认为 True。
  • sorted (bool, 可选): 如果为 True,则返回的值是排序过的(即最大的值排在前面)。如果为 False,则返回的值是按照它们在原张量中的顺序排列。默认为 True。
  • out (tuple, 可选): 可以指定一个元组来存储输出结果。元组应该包含两个张量,分别用于存储值和索引。默认为None。

代码片段赏析:

    def get_embedding_indices(self, points):
        r"""Compute the indices of pair-wise distance embedding and triplet-wise angular embedding.

        Args:
            points: torch.Tensor (B, N, 3), input point cloud

        Returns:
            d_indices: torch.FloatTensor (B, N, N), distance embedding indices
            a_indices: torch.FloatTensor (B, N, N, k), angular embedding indices
        """
        batch_size, num_point, _ = points.shape

        dist_map = torch.sqrt(pairwise_distance(points, points))  # (B, N, N)
        d_indices = dist_map / self.sigma_d

        k = self.angle_k
        ##! largest=False的含义是选择K个距离最小的点, dim=2代表从dist_map的第二个维度来选择
        knn_indices = dist_map.topk(k=k + 1, dim=2, largest=False)[1][:, :, 1:]  # 这里将 dist_map.topk会返回values和indices, 用索引1来取出indices后, 再从所有的knn中去掉自身, 所以取[:, :, 1:] 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

steptoward

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值