Keras对多维Tensor的argmax()解析

基础理论

argmax中的axis参数表示在该维度上比较各元素。并且,张量各维度对换,不影响在该维度取argmax()的结果。

a = tf.constant([[[1, 2, 3], [3, 2, 2]], [[10, 11, 12], [4, 5, 6]]])  # a是个2*2*3的tensor
b = tf.argmax(a, axis=1, output_type=tf.int32)
at = tf.transpose(a, [0, 2, 1])  # 将DIM1和DIM2对换,at变成了2*3*2
c = tf.argmax(at, axis=2, output_type=tf.int32)

with tf.Session() as sess:
    print(sess.run(b))
    print(sess.run(c))
print("")

输出结果

[[1 0 0]
 [0 0 0]]
[[1 0 0]
 [0 0 0]]

tf.argmax(a, axis=1)相当于是在a的DIM1上比较,也就是1和3,2和2,3和2,以及10和4,11和5,12和6比较。如果改成tf.argmax(a, axis=0),相当于是a在DIM0上比较,也就是1和10,2和11,3和12,以此类推。

应用场景

比如,目前有分子特征张量input,维度为SampNum × AtomNum × FeatNum,那么,argmax(input, axis=1)将得到维度为SampNum × FeatNum的Tensor,其元素表示各样本分子的各种向量值表征、同种向量的最大者所对应的原子id。
同样的,再来一个,argmax(input, axis=2)将得到维度为SampNum × AtomNum的Tensor,其元素表示各样本分子的各原子的FeatNum种特征中,最大的特征值所对应的特征id。

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值