torch.max()函数解释

在MNIST数据集分类任务中,test类中有这样一个语句:

_, predicted = torch.max(outputs, dim=1)

先解释torch.max():

import torch

# 创建一个3x4的张量
x = torch.tensor([[[[-0.2979, -0.8524]],

         [[ 0.5737,  0.3848]],

         [[ 0.0707, -1.0534]],

         [[ 1.1824, -1.5657]]],


        [[[ 0.4247, -0.1413]],

         [[-0.6301, -0.4760]],

         [[ 0.3165, -0.0945]],

         [[ 0.7529,  1.4418]]],


        [[[-0.7834,  1.1841]],

         [[ 0.0056, -0.0068]],

         [[ 1.2475,  0.1757]],

         [[-0.0916, -0.6443]]]])

print(torch.max(x))
print(x.shape)
print(torch.max(x,dim=0))
'''输出'''
tensor(1.4418)
torch.Size([3, 4, 1, 2])
torch.return_types.max(
values=tensor([[[0.4247, 1.1841]],

        [[0.5737, 0.3848]],

        [[1.2475, 0.1757]],

        [[1.1824, 1.4418]]]),
indices=tensor([[[1, 2]],

        [[0, 0]],

        [[2, 2]],

        [[0, 1]]]))

torch.max(x):返回x中所有元素的最大值,一维的tensor

torch.max(x,dim=0)  沿第0维寻找其他每个维度最大值,返回[4,1,2]尺寸的Tensor,这个Tensor是每个维度最大值的组合。

value, idx = torch.max(outputs, dim=1)

这种赋值方式,值赋给value,索引赋给idx,由于我们不关心预测概率最大的值,只关心其索引,使用

_, predicted = torch.max(outputs, dim=1)

这里 predicted 将仅接收最大值的索引,而 _ 是一个占位符,用于忽略不需要使用的最大概率值。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值