网上大多数对max的解释只停留在二维数据,在三维及以上就没有详述,我将对二维数据和三维数据进行详细解释,让你不再有疑虑
参考文章
torch.max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor)
在分类问题中,通常使用max()函数对softmax函数的输出值进行操作,求出预测值索引
参数
- input:softmax函数输出的一个tensor
- dim:是max函数索引的维度 0 0 0或 1 1 1, 0 0 0指每列的最大值, 1 1 1指每行的最大值
输出
- 函数会返回两个tensor,第一个tensor是每行的最大值,softmax的输出中最大的是1,索引第一个tensor是全1的tensor;第二个tensor是每行最大值的索引
二维数据详细讲述
>>>import torch
>>>a = torch.tensor([[1,5,62,54], [2,6,2,6], [2,65,2,6]])
>>>print(a)
tensor([[ 1, 5, 62, 54],
[ 2, 6, 2, 6],
[ 2, 65, 2, 6]])
dim = 0
torch.max(a,0)
torch.return_types.max(
values=tensor([ 2, 65, 62, 54]),
indices=tensor([1, 2, 0, 0]))
这个计算过程是:
- a[dim][0],dim会从0遍历到2,也就是[1,2,2],得到第一个最大值2,index为1
- a[dim][1],对[5,6,65],最大值为65,index为2
- …
- 最终得到上图结果
dim = 1
torch.max(a, 1)
torch.return_types.max(
values=tensor([62, 6, 65]),
indices=tensor([2, 1, 1]))
这个计算过程是:
- a[0][dim],dim会从0遍历到3,也就是[1,5,62,54],得到第一个最大值62,index为2
- a[1][dim],对[2,6,2,6],最大值为6,index为1
- …
- 最终得到上图结果
三维数据详述
a = [1,2,13,4,5,6,27,8,9,0,11,12]
a = np.array(a).reshape(3,2,2)
a = torch.Tensor(a)
print(a)
tensor([[[ 1., 2.],
[13., 4.]],
[[ 5., 6.],
[27., 8.]],
[[ 9., 0.],
[11., 12.]]])
dim = 0
torch.max(a,dim=0)
torch.return_types.max(
values=tensor([[ 9., 6.],
[27., 12.]]),
indices=tensor([[2, 1],
[1, 2]]))
计算过程:
- a[dim][0][0],dim会从0遍历到2,其他维数值不变,也就是[1,5,9],得到第一个最大值9,index为2
- a[dim][0][1],dim会从0遍历到2,其他维数值不变,对[2,6,0]遍历,最大值为6,index为1
- a[dim][1][0],dim会从0遍历到2,其他维数值不变,对[13,27,11]遍历,最大值为27,index为1
- a[dim][1][1],dim会从0遍历到2,其他维数值不变,对[13,27,11]遍历,最大值为27,index为1
- 最终得到上面结果
dim = 1
torch.max(a,dim=1)
torch.return_types.max(
values=tensor([[13., 4.],
[27., 8.],
[11., 12.]]),
indices=tensor([[1, 1],
[1, 1],
[1, 1]]))
计算过程:
- a[0][dim][0],dim会从0遍历到1,其他维数值不变,也就是[1,13],得到第一个最大值13,index为1
- a[0][dim][1],dim会从0遍历到1,其他维数值不变,对[2,4]遍历,最大值为4,index为1
- a[1][dim][0],dim会从0遍历到1,其他维数值不变,对[5,27]遍历,最大值为27,index为1
- a[1][dim][1],dim会从0遍历到1,其他维数值不变,对[6,8]遍历,最大值为8,index为1
- …
- 最终得到上面结果
dim = 2
torch.max(a,dim=2)
torch.return_types.max(
values=tensor([[ 2., 13.],
[ 6., 27.],
[ 9., 12.]]),
indices=tensor([[1, 0],
[1, 0],
[0, 1]]))
计算过程:
- a[0][0][dim],dim会从0遍历到1,其他维数值不变,也就是[1,2],得到第一个最大值2,index为1
- a[0][1][dim],dim会从0遍历到1,其他维数值不变,对[13,4]遍历,最大值为13,index为0
- a[1][0][dim],dim会从0遍历到1,其他维数值不变,对[5,6]遍历,最大值为6,index为1
- a[1][1][dim],dim会从0遍历到1,其他维数值不变,对[27,8]遍历,最大值为27,index为1
- …
- 最终得到上面结果