提示:如果本文对您有帮助,请点赞支持!
前言
在写AI算法的Demo时,偶然间出现了一个bug,发现是我不小心将tf.argmax()写成了np.argmax(),正好闲来无事,辨析下两个API的使用
一、np.argmax()的使用
np.argmax()是Python的第三方库numpy中的一个常见API,经常用来获取数组中的最大值所在的索引。所以使用该API要先导入该库:
import numpy as np
该API的完全定义如下:
def argmax(a, axis=None, out=None)# 第1个参数是输入的np数组;第2个参数是所获取的轴,取值为整数0,1等;第3个参数是输出的np数组,一般用不到
# 定义一个一维数组
y1 = np.array([1, 2, 3, 7, 8, 9])
print("result: {}".format(np.argmax(y1,axis=None))) #result: 5
print("result: {}".format(np.argmax(y1, axis=0))) # result: 5
print("result: {}".format(np.argmax(y1, axis=1))) # numpy.AxisError: axis 1 is out of bounds for array of dimension 1
接下来定义一个二维数组来进行测试:
# 定义一个二维数组
y2 = np.array([[1, 9, 3], [7, 8, 9]])
print("result: {}".format(np.argmax(y2,axis=None))) # result: 1
print("result: {}".format(np.argmax(y2,axis=0))) # result: [1 0 1]
print("result: {}".format(np.argmax(y2,axis=1))) # result: [1 2]
最后定义一个三维数组来进行测试:
# 定义一个三维数组
y3 = np.array([[[1, 9, 3],
[7, 8, 9]],
[[2, 6, 3],
[7, 6, 9]],
])
print("result: {}".format(np.argmax(y3, axis=None))) # result: 1返回最大值所在的索引,类型为整型,如果有多个相同的最大值,则返回第一个
print("result: {}".format(np.argmax(y3, axis=0))) # result: [[1 0 0][0 0 0]] z最大的
print("result: {}".format(np.argmax(y3, axis=1))) # result: [[1 0 1] [1 0 1]]y最大的
print("result: {}".format(np.argmax(y3, axis=2))) # result: [[1 2 ] [1 2]] x最大的
总结上述实例,我们可以总结出如下规律:
当axis=None时,我们将n维数组降为一维数组,取该数组里面最大值的索引,若存在多个最大值则返回第一个最大值所在的索引,所以返回的是一个0维的整型数字;
例如上述的三维数组被看成[1, 9, 3,7, 8, 9,2, 6, 3,7, 6, 9],第一个最大值所在索引是1
当axis=0时,取第0维中的每个元素的对应位置取最大值索引,若存在多个最大值则返回第一个最大值所在的索引,返回的是n-1维数组;
例如上述的三维数组在第0维方向上,我们的作用对象是[[1, 9, 3],[7, 8, 9]]和[[2, 6, 3],[7, 6, 9]],此时作用后则变成了 [[1 0 0][0 0 0]]
当axis=1时,取第1维中的每个元素的对应位置取最大值索引,若存在多个最大值则返回第一个最大值所在的索引,返回的是n-1维数组;
例如上述的三维数组在第1维方向的,我们的第1个作用对象是[1, 9, 3]和[7, 8, 9],其作用后是[1 0 1];第2个作用对象是[2, 6, 3],[7, 6, 9],其作用后是[1 0 1],则最终结果变成了 [[1 0 1][0 0 1]]
当axis=2时,取第2维中的每个元素的对应位置取最大值索引,若存在多个最大值则返回第一个最大值所在的索引,返回的是n-1维数组;
例如上述的三维数组在第2维方向的,则我们的第1个作用对象是[1, 9, 3],其作用结果是1;第1个作用对象是[7, 8, 9],其作用结果是2,此时合起来是[1,2];第3个作用对象是[2, 6, 3],其作用结果是1;第4个作用对象是[7, 6, 9],其作用结果是2,此时合起来是[1,2];;则最终结果变成了 [[1 2][1 2]]
当axis>多维数组的秩-1时,则报错:numpy.AxisError: axis 1 is out of bounds for array of dimension 1
二、tf.argmax()的使用
tf.argmax()是TensorFlow的一个常见API,也是经常用来获取数组中的最大值所在的索引,其内部也是用numpy来实现的。所以使用该API要先导入该库:
import tensorflow as tf
该API的完全定义如下:
argmax(input,axis=None,name=None,dimension=None,output_type=dtypes.int64)# 第1个参数是输入的tf张量;第2个参数是所获取的轴,取值为整数0,1等;第3个参数是返回的tf张量的名字;剩下的参数一般不常用
该API返回的是tf张量,这一点和numpy返回np数组不同。
因为返回的是tf张量,所以直接输出tf张量,只会查看该张量对象的一些基本信息,例如:
print("result: {}".format(tf.argmax(y1))) # result: Tensor("ArgMax_2:0", shape=(), dtype=int64)
在TensorFlow中要用会话Session来输出tf张量,所以正确的打印如下:
with tf.Session() as sess:
result = sess.run(tf.argmax(y1, 0))
print("result: {}".format(result))# result: 5
其他用法和上述的np.argmax()相同。