运行环境 : python 3.6.0
第三方库 : tensorflow 1.9.0
用 tensorflow 做 CNN_TEXT 文本分类时,看到这个API,然后去官网查了一下,再看了一下别的资料,算是明白它的处理方式了。
tf.argmax 可以认为就是 np.argmax 。应该说 tensorflow 使用 numpy 实现的这个 API 。简单的说 , tf.argmax 就是返回最大的那个值所对应的下标 。
tf.argmax 有个很重要的参数 axis , 像这样 tf.argmax(test, 0) 和 tf.argmax(test, 0) 有什么区别呢?
看示例 :
# -*- encoding: utf-8 -*-
import numpy as np
test = np.array([
[1, 2, 3],
[2, 3, 4],
[5, 4, 3],
[8, 7, 2]])
print("axis=0", np.argmax(test, 0))
print("axis=1", np.argmax(test, 1))
# 输出
"""
axis=0 [3 3 1]
axis=1 [2 2 0 0]
"""
有明白这是啥意思没 ?
答案是这样的 :
当 axis=0 的时候 , 这时候是求出 每一列 最大值 所对应的 索引数字 , 然后 将 这几个数字 输出一个 列表 并返回
当 axis=1 的时候 , 这时候是求出 每一行 最大值 所对应的 索引数字 , 然后 将 这几个数字 输出一个 列表 并返回
- axis = 0 :
- axis = 1 :
以上情况为数组长度一致的情况 , 如果数组禅古不一致 , axis 最大值 为 最小的数组长度 -1 , 超过则报错。
当不一致的时候 , axis = 0 的比较也就变成了每个数组的和的比较 。