Pytorch 中的 torch.topk()函数

torch.topk函数用于在张量中提取指定维度上元素的前k个最大或最小值。该函数接受参数如输入张量(input),要选取的元素数量(k),排序的维度(dim)以及是否返回最大值(largest)和是否保持排序(sorted)。返回结果是一个包含值和对应索引的命名元组。示例展示了如何在张量的列上找到最小的两个元素。
摘要由CSDN通过智能技术生成

1 作用

对一个 tensor 中的元素取它的前 K 个元素 (从大到小排列)

2 参数介绍

import torch

torch.topk(input, k, dim = None, largest = True, sorted = True, *, out = None)
  • input(Tensor) : 输入的张量

  • k(int) : 前 k 个大小中的 k

  • dim(int, optional) : 需要进行排序的维度, dim = 0 表示按照列来排序, dim = 1 表示按照行来排序, 默认情况下, dim = 1

  • largest(bool, optional) : 控制是否返回最大值或最小值

  • sorted(bool, optional) : 控制是否对元素进行排序后再返回

  • out(tuple,可选):(Tensor,LongTensor)的输出元组,可以可选地指定用作输出缓冲区

3 注意事项

  • 返回给定维度上给定输入张量的k个最大元素。

  • 如果未给出dim,则选择输入的最后一个维度。

  • 如果maximum为False,则返回k个最小元素。

  • 返回一个(值,索引)的命名元组,其中包含给定维度dim中输入张量每行的最大k个元素的值和索引。

  • 如果为True,则布尔选项将确保返回的k个元素本身已排序

4 示例

import torch

a = torch.randn(3, 3)
print("a : ", a)
b = torch.topk(a, 2, dim = 1, largest = False)
print("b :", b)

>>> a :  tensor([[-0.6474, -0.0939,  1.3639],
        [-0.0297,  0.6471, -0.2255],
        [-1.2431, -0.3386,  1.9692]])
>>> b : torch.return_types.topk(
>>> values=tensor([[-0.6474, -0.0939],
        [-0.2255, -0.0297],
        [-1.2431, -0.3386]]),
>>> indices=tensor([[0, 1],
        [2, 0],
        [0, 1]]))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值