pytorch中topk()用法的测试与个人理解

参数介绍:

直接官网的介绍topk()
在这里插入图片描述

  • input:就是输入的tensor,也就是要取topk的张量
  • k:就是取前k个最大的值。
  • dim:就是在哪一维来取这k个值。
  • lagest:默认是true表示取前k大的值,false则表示取前k小的值
  • sorted:是否按照顺序输出,默认是true。
  • out : 可选输出张量 (Tensor, LongTensor)

直接上代码:


首先研究一下dim和k这两个最重要的参数:

import torch
seed = 0
torch.manual_seed(seed)
a = torch.randint(1,10,(3,4,4))

print(a)
tensor([[[9, 1, 3, 7],
         [8, 7, 8, 2],
         [2, 1, 9, 3],
         [7, 4, 2, 3]],

        [[1, 1, 6, 4],
         [9, 3, 9, 3],
         [9, 6, 8, 9],
         [7, 1, 2, 1]],

        [[9, 7, 2, 7],
         [2, 9, 9, 8],
         [3, 4, 8, 8],
         [2, 6, 5, 8]]])
values , indices = a.topk(2,dim=0)
print(values.shape)
print(values)
print(indices.shape)
print(indices)
torch.Size([2, 4, 4])
tensor([[[9, 7, 6, 7],
         [9, 9, 9, 8],
         [9, 6, 9, 9],
         [7, 6, 5, 8]],

        [[9, 1, 3, 7],
         [8, 7, 9, 3],
         [3, 4, 8, 8],
         [7, 4, 2, 3]]])
torch.Size([2, 4, 4])
tensor([[[0, 2, 1, 0],
         [1, 2, 1, 2],
         [1, 1, 0, 1],
         [0, 2, 2, 2]],

        [[2, 0, 0, 2],
         [0, 0, 2, 1],
         [2, 2, 1, 2],
         [1, 0, 0, 0]]])

**这里说一下我的理解 **

首先dim=0这个参数表示在某一维取topk,在我的代码中就是取前2个。首先看输出的values和indices的张量形状:(2,4,4)这里可以结合下面两个实验总结出规律,dim取几,输出结果的形状就是:其他维度不变,对应维度变成k。

现在dim=0最后的输出就是要变成(2,4,4)也就是之前第一维中保留两个最大的。

看values的值,讲一下第一行元素[9,7,6,7]是如何得来的:
因为dim=0所以要从第0维来看,将数据分成3份,分别是:

1. [[9, 1, 3, 7],
    [8, 7, 8, 2],
    [2, 1, 9, 3],
    [7, 4, 2, 3]]
2. [[1, 1, 6, 4],
    [9, 3, 9, 3],
    [9, 6, 8, 9],
    [7, 1, 2, 1]]
3. [[9, 7, 2, 7],
    [2, 9, 9, 8],
    [3, 4, 8, 8],
    [2, 6, 5, 8]]

要以这三个tensor为单位进行topk的筛选,首先比较每一个tensor的第一行,因为参数k为2,所以就要找到这3组元素中的最大值和次大值,作为最后的输出。因此最大值就是[9,7,6,7],次大值为:[9,1,3,7]这样就完成了筛选。索引值也就是当前位置处的元素,是来自于这三个元素中的哪一个。我认为把这个看懂后面就可以迎刃而解,大家可以仔细理解一下不太懂的话也没关系,看完后面两个可能这个就懂了。

values1 , indices1 = a.topk(2,dim=1)
print(values1)
print(indices1)
torch.Size([3, 2, 4])
tensor([[[9, 7, 9, 7],
         [8, 4, 8, 3]],

        [[9, 6, 9, 9],
         [9, 3, 8, 4]],

        [[9, 9, 9, 8],
         [3, 7, 8, 8]]])
torch.Size([3, 2, 4])
tensor([[[0, 1, 2, 0],
         [1, 3, 1, 2]],

        [[1, 2, 1, 2],
         [2, 1, 2, 0]],

        [[0, 1, 1, 2],
         [2, 0, 2, 3]]])

这个例子是dim=1时,类比于dim=0的情况。这里是对第一维进行筛选操作。需要注意的是这里第0维的三个元素是分开操作的。这里我提供一种我自己的理解思路大家借鉴。首先还是按照第0维将tensor分为3块

1. [[9, 1, 3, 7],
    [8, 7, 8, 2],
    [2, 1, 9, 3],
    [7, 4, 2, 3]]
2. [[1, 1, 6, 4],
    [9, 3, 9, 3],
    [9, 6, 8, 9],
    [7, 1, 2, 1]]
3. [[9, 7, 2, 7],
    [2, 9, 9, 8],
    [3, 4, 8, 8],
    [2, 6, 5, 8]]

这里每一块中的第0维就是总体tensor的第一维,从第0维来看就是4个14的向量,因此就是对这4向量取最大值和次大值。也就是在这个44的张量中选出对应位置的最大值和次大值。例如第一块中筛选出的结果就是[9,7,9,7]和[8,4,8,3]其他同理,索引值表示当前位置处的值是来自哪一个向量。

values2 , indices2 = a.topk(2,dim=2)
print(values2.shape)
print(values2)
print(indices2.shape)
print(indices2)
torch.Size([3, 4, 2])
tensor([[[9, 7],
         [8, 8],
         [9, 3],
         [7, 4]],

        [[6, 4],
         [9, 9],
         [9, 9],
         [7, 2]],

        [[9, 7],
         [9, 9],
         [8, 8],
         [8, 6]]])
torch.Size([3, 4, 2])
tensor([[[0, 3],
         [0, 2],
         [2, 3],
         [0, 1]],

        [[2, 3],
         [0, 2],
         [3, 0],
         [0, 2]],

        [[0, 1],
         [1, 2],
         [2, 3],
         [3, 1]]])

类比前两种情况的思考方式,这里的操作就是对整个张量最内层做的操作,也就是整体张量形状(3,4,4)中的4这个4就是最内层每一个一维向量中的4个元素,取对应的最大值和次大值,应该也容易理解。大家可以对比着三种情况的输入输出加以理解。
另外,k参数默认是最后一维

然后研究一下lagest参数:
直接用最后一维

values2 , indices2 = a.topk(2,dim=2,largest=False)
print(values2.shape)
print(values2)
print(indices2.shape)
print(indices2)
torch.Size([3, 4, 2])
tensor([[[1, 3],
         [2, 7],
         [1, 2],
         [2, 3]],

        [[1, 1],
         [3, 3],
         [6, 8],
         [1, 1]],

        [[2, 7],
         [2, 8],
         [3, 4],
         [2, 5]]])
torch.Size([3, 4, 2])
tensor([[[1, 2],
         [3, 1],
         [1, 0],
         [2, 3]],

        [[1, 0],
         [1, 3],
         [1, 2],
         [1, 3]],

        [[2, 3],
         [0, 3],
         [0, 1],
         [0, 2]]])

很明显,只是取最小和次小


下面是sorted参数
依然用dim=2进行测试

values2 , indices2 = a.topk(2,dim=2,sorted=False)
print(values2.shape)
print(values2)
print(indices2.shape)
print(indices2)
torch.Size([3, 4, 2])
tensor([[[9, 7],
         [8, 8],
         [9, 3],
         [7, 4]],

        [[6, 4],
         [9, 9],
         [9, 9],
         [7, 2]],

        [[9, 7],
         [9, 9],
         [8, 8],
         [8, 6]]])
torch.Size([3, 4, 2])
tensor([[[0, 3],
         [0, 2],
         [2, 3],
         [0, 1]],

        [[2, 3],
         [0, 2],
         [3, 0],
         [0, 2]],

        [[0, 1],
         [1, 2],
         [2, 3],
         [3, 1]]])

这里和不sorted=True好像并没有区别,不知道要怎么理解,网上也没找到类似的解释,希望有知道的大佬可以多多指教!!

如有错误请多指正

  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorchtopk函数是用于返回输入张量指定维度上的前k个最大值及其对应的索引。它的函数签名为torch.topk(input, k, dim=None, largest=True, sorted=True, out=None),返回一个元组,包含最大的k个值组成的张量和它们在输入张量的索引组成的长整型张量。其,input是输入张量,k是要返回的最大值的个数,dim是指定的维度,largest决定是否返回最大值(默认为True),sorted决定是否返回排序的结果(默认为True),out是输出的张量。 例如,如果我们有一个输入张量input为[5, 9, 3, 2, 7],我们想要找出其最大的3个值及其索引,我们可以使用torch.topk(input, 3)。这将返回一个包含[9, 7, 5]的张量和一个包含[1, 4, 0]的长整型张量,分别表示最大的3个值和它们在输入张量的索引。 在具体的代码,maxk = max(topk)用于获取topk列表的最大值,而output.topk(maxk, 1, True, True)则是对output进行topk操作,返回最大值和对应的索引。这种用法可以帮助我们在代码获取最大的k个值及其索引。 总结来说,PyTorchtopk函数可以帮助我们在指定维度上找出输入张量的最大值及其对应的索引。这在许多机器学习和深度学习任务非常有用。如果想要了解更多关于topk函数的用法,可以参考PyTorch官方文文档或者一篇介绍topk函数用法的文章。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* [pytorch topk函数](https://blog.csdn.net/u012505617/article/details/103711019)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] - *3* [PyTorchtopk函数的用法详解](https://download.csdn.net/download/weixin_38628150/12856649)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值