在MNIST数据集分类任务中,test类中有这样一个语句:
_, predicted = torch.max(outputs, dim=1)
先解释torch.max():
import torch
# 创建一个3x4的张量
x = torch.tensor([[[[-0.2979, -0.8524]],
[[ 0.5737, 0.3848]],
[[ 0.0707, -1.0534]],
[[ 1.1824, -1.5657]]],
[[[ 0.4247, -0.1413]],
[[-0.6301, -0.4760]],
[[ 0.3165, -0.0945]],
[[ 0.7529, 1.4418]]],
[[[-0.7834, 1.1841]],
[[ 0.0056, -0.0068]],
[[ 1.2475, 0.1757]],
[[-0.0916, -0.6443]]]])
print(torch.max(x))
print(x.shape)
print(torch.max(x,dim=0))
'''输出'''
tensor(1.4418)
torch.Size([3, 4, 1, 2])
torch.return_types.max(
values=tensor([[[0.4247, 1.1841]],
[[0.5737, 0.3848]],
[[1.2475, 0.1757]],
[[1.1824, 1.4418]]]),
indices=tensor([[[1, 2]],
[[0, 0]],
[[2, 2]],
[[0, 1]]]))
torch.max(x):返回x中所有元素的最大值,一维的tensor
torch.max(x,dim=0) 沿第0维寻找其他每个维度最大值,返回[4,1,2]尺寸的Tensor,这个Tensor是每个维度最大值的组合。
value, idx = torch.max(outputs, dim=1)
这种赋值方式,值赋给value,索引赋给idx,由于我们不关心预测概率最大的值,只关心其索引,使用
_, predicted = torch.max(outputs, dim=1)
这里 predicted
将仅接收最大值的索引,而 _
是一个占位符,用于忽略不需要使用的最大概率值。