基础理论
argmax中的axis参数表示在该维度上比较各元素。并且,张量各维度对换,不影响在该维度取argmax()的结果。
a = tf.constant([[[1, 2, 3], [3, 2, 2]], [[10, 11, 12], [4, 5, 6]]]) # a是个2*2*3的tensor
b = tf.argmax(a, axis=1, output_type=tf.int32)
at = tf.transpose(a, [0, 2, 1]) # 将DIM1和DIM2对换,at变成了2*3*2
c = tf.argmax(at, axis=2, output_type=tf.int32)
with tf.Session() as sess:
print(sess.run(b))
print(sess.run(c))
print("")
输出结果
[[1 0 0]
[0 0 0]]
[[1 0 0]
[0 0 0]]
tf.argmax(a, axis=1)相当于是在a的DIM1上比较,也就是1和3,2和2,3和2,以及10和4,11和5,12和6比较。如果改成tf.argmax(a, axis=0),相当于是a在DIM0上比较,也就是1和10,2和11,3和12,以此类推。
应用场景
比如,目前有分子特征张量input,维度为SampNum × AtomNum × FeatNum,那么,argmax(input, axis=1)将得到维度为SampNum × FeatNum的Tensor,其元素表示各样本分子的各种向量值表征、同种向量的最大者所对应的原子id。
同样的,再来一个,argmax(input, axis=2)将得到维度为SampNum × AtomNum的Tensor,其元素表示各样本分子的各原子的FeatNum种特征中,最大的特征值所对应的特征id。