tf.argmax()
作用:寻找张量中的最大值,并返回起下标,主要参数为:input, axis
input:输入的张量
axis:决定了在第几纬度找最大值,例如,axis = 0,决定了第一纬度找最大值,axis = 1 ,第二维度找最大值(此时第一维度的大小不变),axis 不可以超过张量的最大维度(shape = (1,5),那么axis 选择只有0和1,shape为(1,2,5),那么axis只有选择0,1,2,依次类推)
返回:如果axis比较的时最后一个纬度,那么仅返回最后维度的下标,之前的属性不变,如果axis比较的非最后一个纬度,那么返回值为 之前纬度属性不变,比较的纬度消失,之后的纬度也不变,如:
shape(2,3,5)
axis = 0, --> new shape = (3,5)
axis = 1, --> new shape = (2,5)
axis = 2, --> new shape = (2,3)
shape(1,5),axis = 0,代码:
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tensor = sess.run(tf.random_uniform((1,5)))
print(tensor)
print(sess.run(tf.argmax(tensor,0)))
结果,因为第一维度只有属性为1,因此,当只有一个子项即第二维数据进行比对时,结果[0,0,0,0,0]:
[[0.5847366 0.554988 0.20085955 0.9588026 0.55059564]]
[0 0 0 0 0]
shape(2,5),axis = 0,代码:
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tensor = sess.run(tf.random_uniform((2,5)))
print(tensor)
print(sess.run(tf.argmax(tensor,0)))
结果,第一维度属性为2,有两个子项,因此,当axis=0时,第一维度数据进行互相比较,结果为[1 1 1 0 1],返回的是一组对应维度数据的比较后的下标,如0.32242548比0.8556894小,因此第一个下标为1,依次比较类推:
[[0.32242548 0.35847282 0.5762752 0.91375875 0.2024641 ]
[0.8556894 0.3678167 0.64506817 0.04399645 0.58732665]]
[1 1 1 0 1]
当shape为(3,5)时,axis = 0,结果如下:
[[0.5915636 0.8889439 0.72812104 0.94132113 0.0546453 ]
[0.8918692 0.19301915 0.22154689 0.04286265 0.52505565]
[0.8999599 0.02016377 0.2340877 0.22576177 0.31773365]]
[2 0 0 0 1]