argmax函数:torch.argmax(input, dim=None, keepdim=False) 返回指定维度最大值的序号,dim给定的定义是:the demention to reduce,就是把dim这个维度,变成这个维度的最大值的index。
1)dim表示不同维度。特别的在dim=0表示二维矩阵中的列,dim=1在二维矩阵中的行。广泛的来说,我们不管一个矩阵是几维的,比如一个矩阵维度如下:(d0,d1,…,dn−1) ,那么dim=0就表示对应到d0 也就是第一个维度,dim=1表示对应到也就是第二个维度,以此类推。
2)知道dim的值是什么意思还不行,还要知道函数中这个dim给出来会发生什么。
例子一:二维数组
import torch
x = torch.randn(2, 4)
print(x)
'''
tensor([[ 1.2864, -0.5955, 1.5042, 0.5398],
[-1.2048, 0.5106, -2.0288, 1.4782]])
'''
# y0表示矩阵dim=0维度上(每一列)张量最大值的索引
y0 = torch.argmax(x, dim=0)
print(y0)
'''
tensor([0, 1, 0, 1])
'''
# y1表示矩阵dim=1维度上(每一行)张量最大值的索引
y1 = torch.argmax(x, dim=1)
print(y1)
'''
tensor([2, 3])
'''
例子二:三维数组
x = torch.randn(2, 4, 5)
print(x)
'''
tensor([[[-1.2204, -0.6428, -0.2278, 0.5589, 1.1589],
[ 0.4235, 1.9663, 0.5055, -1.3472, 1.3523],
[ 1.4220, 0.7886, -1.0821, 0.6268, -0.9465],
[-0.3950, 1.3275, 0.3369, 1.0224, -0.9944]],
[[ 0.6024, -0.2604, -0.8631, 0.8113, -0.3140],
[ 0.3487, -0.1941, -0.3955, -0.1719, -1.3734],
[ 0.2467, -0.4268, -1.3428, 0.7346, 1.0932],
[-0.5799, 0.0976, -1.9403, -0.2643, 0.7657]]])
'''
# dim=0,将第一个维度消除,也就是将两个[4*5]矩阵只保留一个,因此要在上下两个[4*5]的矩阵分别在对应位置上比较
y0 = torch.argmax(x, dim=0)
print(y0)
'''
tensor([[1, 1, 0, 1, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 1, 1],
[0, 0, 0, 0, 1]])
'''
# dim=1,将第二个维度消除,也就是将四个[2*5]矩阵只保留一个
y1 = torch.argmax(x, dim=1)
print(y1)
'''
tensor([[2, 1, 1, 3, 1],
[0, 3, 1, 0, 2]])
'''
y2 = torch.argmax(x, dim=2)
print(y2)
'''
tensor([[4, 1, 0, 1],
[3, 0, 4, 4]])
'''