tf.argmax返回的是最大值的坐标,它与np的argmax是一样的。
参数
- input:输入Tensor
- axis:0表示按列,1表示按行
- name:名称
- dimension:和axis功能一样,默认axis取值优先。新加的字段
关于返回shape比较容易混淆,可以简单地记为axis指向输入input的shape的第几维,返回值的shape就是输入shape减去这一维
如
test = np.array([[1, 2, 3], [2, 3, 4], [5, 4, 3], [8, 7, 2]])
np.argmax(test, 0) #输出:array([3, 3, 1])
np.argmax(test, 1) #输出:array([2, 2, 0, 0])
test的shape为(4,3),当axis=0时,输出为(3);当axis=1时,输出为(4)
在语义分割的输出层中一般会讲dimension设置为最后一维。
例:当前一层为[None,224,224,10]时,10的意义是语义分割的类别数。可以用tf.argmax(input, dimension=3)来转化,此时输出shape为[None,224,224],数值为0~9。