pytorch中的torch.argmax函数

转载自:pytorch中的torch.argmax函数 - 知乎

x = torch.randn(3, 5)
print(x)
print(torch.argmax(x))
print(torch.argmax(x, dim=0))
print(torch.argmax(x, dim=-2))
print(torch.argmax(x, dim=1))
print(torch.argmax(x, dim=-1))



output:
tensor([[-1.0214,  0.7577, -0.0481, -1.0252,  0.9443],
        [ 0.5071, -1.6073, -0.6960, -0.6066,  1.6297],
        [-0.2776, -1.3551,  0.0036, -0.9210, -0.6517]])
tensor(9)
tensor([1, 0, 2, 1, 1])
tensor([1, 0, 2, 1, 1])
tensor([4, 4, 2])
tensor([4, 4, 2])

 结论:dim的取值为[-2, 1]之间,只能取整,有四个数,0和-2对应,得到的是每一列的最大值,1和-1对应,得到的是每一行的最大值。如果参数中不写dim,则得到的是张量中最大的值对应的索引(从0开始)。

注意:

(1)就是dim等于几,就是表明删除那一维。比如x = torch.randn(3, 5),三行五列的矩阵,就是当dim=1时,就是说明结果是3个数字了,dim=0时,就是结果为5个数字

(2)此外不指定dim维度的时候,是直接按照顺序对所有元素进行遍历。返回的索引是基于所有维度铺平时的索引,比如上面:

print(torch.argmax(x))

结果就是9,这个就是遍历下面的矩阵:

tensor([[-1.0214,  0.7577, -0.0481, -1.0252,  0.9443],
        [ 0.5071, -1.6073, -0.6960, -0.6066,  1.6297],
        [-0.2776, -1.3551,  0.0036, -0.9210, -0.6517]])

1.6297这个元素就是最大的,其索引等于平铺后的15个元素中的第8个,即索引为9

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值