Pytorch学习:torch.argmax(input, dim, keepdim=False)详解

torch.argmax() 返回输入中所有元素的最大值的索引,与torch.max()中返回(values, indices)中的indices类似,它也常被用于深度学习中的分类问题。

在下面程序中,使用torch.argmax()

import torch

a = torch.tensor([[1, 2, 3, 4],
                  [4, 1, 2, 3],
                  [6, 2, 3, 4],
                  [3, 4, 5, 9]])

print(torch.argmax(a))

在这里插入图片描述
最大值在tensor(15)的位置

接下来引入dim属性,dim=0代表消去维数dim=0(行),即求每列最大值的索引。

print(torch.argmax(a, dim=0))

在这里插入图片描述
在这里插入图片描述
dim=1代表消去维数dim=1(列),即求每行最大值的索引。

print(torch.argmax(a, dim=1))

在这里插入图片描述在这里插入图片描述
再接下来引入keepdim属性,默认为False
它表示是否保留要消去的维数,用上面的程序来示范keepdim=True的情况,它保留了要消去的列。

print(torch.argmax(a, dim=1, keepdim=True))

在这里插入图片描述

在深度学习中,我们常用argmax来预测分类的标签,例如:

import torch

outputs = torch.tensor([[0.1, 0.2],
                        [0.3, 0.4]])

preds = outputs.argmax(1)
targets = torch.tensor([0, 1])
print((preds == targets).sum().item())

在这里插入图片描述
在这里插入图片描述

  1. 假设上面outputs是深度学习模型预测的概率值分布
  2. argmax(1)代表预测最大概率所在的标签
  3. 通过预测标签与真实标签相比,如果相等代表预测正确,否则相反,用来表示模型预测的正确率从而评估模型

下面是某深度学习模型在刚开始训练时所预测的标签与真实标签的差异,随着训练的进行,准确率也会不断上升。

在这里插入图片描述
在这里插入图片描述

官方文档torch.argmax(input,dim,keepdim=False)
主要参数:

  • input(Tensor)-输入张量。
  • dim(int)-要减少的维度。如果为 None ,则返回展平输入的argmax。
  • keepdim(bool)-输出张量是否保留了 dim 。如果 dim=None ,则忽略。
  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值