PyTorch中的topk函数详解

听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index。

用法

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

  • input:一个tensor数据
  • k:指明是得到前k个数据以及其index
  • dim: 指定在哪个维度上排序, 默认是最后一个维度
  • largest:如果为True,按照大到小排序; 如果为False,按照小到大排序
  • sorted:返回的结果按照顺序返回
  • out:可缺省,不要

topk最常用的场合就是求一个样本被网络认为前k个最可能属于的类别。我们就用这个场景为例,说明函数的使用方法。
假设一个tensor F ∈ R N × D F \in R^{N \times D} FRN×D,N是样本数目,一般等于batch size, D是类别数目。我们想知道每个样本的最可能属于的那个类别,其实可以用torch.max得到。如果要使用topk,则k应该设置为1。

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(1, dim=1, largest=True, sorted=True)
print(indices)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=1, keepdim=True)

print(indices_max == indices)
# pred
tensor([[-0.1480, -0.9819, -0.3364,  0.7912, -0.3263],
        [-0.8013, -0.9083,  0.7973,  0.1458, -0.9156],
        [-0.2334, -0.0142, -0.5493,  0.0673,  0.8185],
        [-0.4075, -0.1097,  0.8193, -0.2352, -0.9273]])
# indices, shape为 【4,1】,
tensor([[3],   #【0,0】代表 第一个样本最可能属于第一类别
        [2],   # 【1, 0】代表第二个样本最可能属于第二类别
        [4],
        [2]])
# indices_max等于indices
tensor([[True],
        [True],
        [True],
        [True]])

现在在尝试一下k=2

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(2, dim=1, largest=True, sorted=True)  # k=2
print(indices)
# pred
tensor([[-0.2203, -0.7538,  1.8789,  0.4451, -0.2526],
        [-0.0413,  0.6366,  1.1155,  0.3484,  0.0395],
        [ 0.0365,  0.5158,  1.1067, -0.9276, -0.2124],
        [ 0.6232,  0.9912, -0.8562,  0.0148,  1.6413]])
# indices
tensor([[2, 3],
        [2, 1],
        [2, 1],
        [4, 1]])

可以发现indices的shape变成了【4, k】,k=2。
其中indices[0] = [2,3]。其意义是说明第一个样本的前两个最大概率对应的类别分别是第3类和第4类。

大家可以自行print一下values。可以发现values的shape和indices的shape是一样的。indices描述了在values中对应的值在pred中的位置。

  • 51
    点赞
  • 111
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
PyTorchtopk函数是用于返回输入张量指定维度上的前k个最大值及其对应的索引。它的函数签名为torch.topk(input, k, dim=None, largest=True, sorted=True, out=None),返回一个元组,包含最大的k个值组成的张量和它们在输入张量的索引组成的长整型张量。其,input是输入张量,k是要返回的最大值的个数,dim是指定的维度,largest决定是否返回最大值(默认为True),sorted决定是否返回排序的结果(默认为True),out是输出的张量。 例如,如果我们有一个输入张量input为[5, 9, 3, 2, 7],我们想要找出其最大的3个值及其索引,我们可以使用torch.topk(input, 3)。这将返回一个包含[9, 7, 5]的张量和一个包含[1, 4, 0]的长整型张量,分别表示最大的3个值和它们在输入张量的索引。 在具体的代码,maxk = max(topk)用于获取topk列表的最大值,而output.topk(maxk, 1, True, True)则是对output进行topk操作,返回最大值和对应的索引。这种用法可以帮助我们在代码获取最大的k个值及其索引。 总结来说,PyTorchtopk函数可以帮助我们在指定维度上找出输入张量的最大值及其对应的索引。这在许多机器学习和深度学习任务非常有用。如果想要了解更多关于topk函数的用法,可以参考PyTorch官方文文档或者一篇介绍topk函数用法的文章。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* [pytorch topk函数](https://blog.csdn.net/u012505617/article/details/103711019)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] - *3* [PyTorchtopk函数的用法详解](https://download.csdn.net/download/weixin_38628150/12856649)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值