重点说一下 argmax输出的是什么东西
test = np.array([[1, 2, 3],
[2, 3, 4],
[5, 4, 3],
[8, 7, 2]]) # 构建一个4X3的矩阵
out = np.argmax(test, axis=1) # axis=1:按行查找最大元素 axis=0:按列查找最大元素
print(out)
输出是:
[2 2 0 0] # 按行查找出的最大元素的索引号
输出值第一个元素为什么是2 ---> 按行索引
--->第一行[1, 2, 3] 最大值是3
--->3的索引值是2(数组,从0开始)
--->第二行[2, 3, 4] 最大值是4
--->4的索引值是2
--->第三行[5, 4, 3] 最大值是5
--->5的索引值是0
--->第四行[8, 7, 2] 最大值是8
--->8的索引值是0
----->所以输出值是[2 2 0 0]
按列索引同理
那么用处在哪呢?就我用到的地方就是使用MNIST数据集的时候,因为这个数据集用一个矩阵表示了一张图片表示数字几。
[[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]]
也就像这样一个矩阵。第一行在索引5处有个1,说明这行数据表示对应的图片是手写的数字5....以此类推
使用argmax按行取出最大值的索引值就可以得到一个数字的矩阵了。
actuals = np.argmax(test, axis=1)
[5 5 2 3 7 3] # 使用argmax对上面的数据集的输出