torch.argmax()与torch.max()的使用方法及区别

1.torch.max()
分类神经网络的输出是所有类别对应的概率值,要返回标签的话就需要用到将概率值与标签对应。torch.max()返回tensor数据最大值和索引,输出的值有两个参数,第一个参数是最大值,第二个参数是最大值的索引(也就是分类label),主要用于神经网络输出与label的匹配。
代码如下:

import torch
output=torch.tensor([[-0.7403, -0.6481],
        [-0.6869, -0.6994],
        [-0.6569, -0.7307],
        [-0.6623, -0.7250],
        [-0.7715, -0.6205],
        [-0.6643, -0.7229],
        [-0.5958, -0.8010],
        [-0.5925, -0.8051],
        [-0.7938, -0.6017],
        [-0.6400, -0.7493]])
output
out:
tensor([[-0.7403, -0.6481],
        [-0.6869, -0.6994],
        [-0.6569, -0.7307],
        [-0.6623, -0.7250],
        [-0.7715, -0.6205],
        [-0.6643, -0.7229],
        [-0.5958, -0.8010],
        [-0.5925, -0.8051],
        [-0.7938, -0.6017],
        [-0.6400, -0.7493]])
data=torch.max(output,dim=1)
data

out:
torch.return_types.max(
values=tensor([-0.6481, -0.6869, -0.6569, -0.6623, -0.6205, -0.6643, -0.5958, -0.5925,
        -0.6017, -0.6400]),
indices=tensor([1, 0, 0, 0, 1, 0, 0, 0, 1, 0]))

可以看到,返回的值里面有两个参数,一个是最大值values,另一个是最大值的索引indices,实际应用中取输出标签可以这样写:

values, predictions = torch.max(outputs.data, 1)
 
values:
tensor([-0.6481, -0.6869, -0.6569, -0.6623, -0.6205, -0.6643, -0.5958, -0.5925,-0.6017, -0.6400])

predictions:
tensor([1, 0, 0, 0, 1, 0, 0, 0, 1, 0])

2.torch.argmax()
torch.argmax()的作用与前面类似,我们只想要神经网络最终的标签,它输出的概率值并不关心,那么就可以直接用torch.argmax()返回tensor数据最大值的索引,代码示例如下:

predictions =torch.argmax(output,dim=1)
predictions

out:
tensor([1, 0, 0, 0, 1, 0, 0, 0, 1, 0])

3.补充另外几点
准确率计算:

correct_counts = predictions.eq(labels.data.view_as(predictions))
acc = torch.mean(correct_counts.type(torch.FloatTensor))
  • 15
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值