tensorflow 的argmax

Tensorflow 的 argmax 接口可以返回一阶以上张量最大值所对应的分量索引。

比如 tensorflow.argmax ( [1,2,3,10,1])  返回 10对应的索引3 。

对于超过一阶的张量,需要指定要搜索的是第几维的元素,这个维是以0开始的。比如

对于a = [[[10.0,25.0,3.0,4.0] , ]  这样一个张量,想找最里层的元素的最大值,可以tensorflow.argmax ( a , 2 ) 来获取。

下面是分别对三阶和一阶张量找最大值的例子。

import tensorflow as tf


def findMaxFromRank3() :
    """
    找出一个三阶张量第三维(索引是2)的最大值
    """
    a =tf.Variable( [[[10.0,25.0,3.0,4.0] , [10.0,251.0,35.0,4.0]] , [[100.0,25.0,3.0,4.0] , [10.0,250.0,3500.0,4.0]] ] )
    b = tf.argmax(  a  , 2 )
    se = tf.Session()
    init = tf.global_variables_initializer()
    se.run( init ) 
    r = se.run( b )
    ar = se.run( a )
    se.close()
    print( r )
    print ( type (r )  )
    print( ar[0][0][r[0][0]] , "," , ar[0][1][ r[0][1] ] )
    print( ar[1][0][r[1][0]] , "," , ar[1][1][ r[1][1] ] )


def findMaxFromRank1() :
    """
    找出一个一阶张量第一维(索引是0)的最大值
    """
    a =tf.Variable( [10.0,250.0,3510.0,4.0 ] )
    b = tf.argmax(  a  , 0 )
    se = tf.Session()
    init = tf.global_variables_initializer()
    se.run( init ) 
    r = se.run( b )
    ar = se.run( a )
    se.close()
    print( r )
    print ( type (r )  )
    print( ar[r] )
    


findMaxFromRank3()
findMaxFromRank1()  

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值