TensorFlow:top_k()和区别in_top_k()

      从字面上可以大致了解这两个函数的自用,但具体的作用,还需要查看源码,及编程实现,这样掌握和了解的比较透彻。

in_top_k(predictions, targets, k, name=None):
r"""Says whether the targets are in the top `K` predictions.
     首先说作用:返回一个布尔向量,说明目标值是否存在于预测值之中。

  参数:predicitions:输入的输入tensor,数据类型必须是以下之一:float32、float64、int32、int64、uint8、int16、int8。

              targets:tensor,数据类型是 int32 。每行目标值所在的位置,如果predicitions某行的最大值位置为n, n==targets,则该行的返回值为True

              k: 最大值的个数,k值关系返回矩阵的结果。如果k=1,最大值的位置是否在targets处。

例如:

    input = tf.constant(np.random.rand(3,4), tf.float32)
    k = 1 
    output = tf.nn.in_top_k(input, [3,3,3], k)#每一行的最大值都在第3列(0为第一列)
    with tf.Session() as sess:
        print(sess.run(input))
        print(sess.run(output))
输出:

[[ 0.46714601  0.92652822  0.16808732  0.44906664]#最大值在第1列,返回为false
 [ 0.03874864  0.55331773  0.32944077  0.84536946]#最大值在第3列,返回为false
 [ 0.80283058  0.63945484  0.07212774  0.27699497]]最大值在第1列,返回为false
[False  True False]
    如果k=3呢?

[[ 0.10950958  0.09272877  0.65265322  0.49682239]#最大值在第二列,第二个次大值在第3列,返回true
 [ 0.70769322  0.00581258  0.40589932  0.7010119 ]
 [ 0.18922156  0.57137531  0.14654963  0.26083347]]
[ True  True  True]

top_k(input, k=1, sorted=True, name=None):
"""Finds values and indices of the `k` largest entries for the last dimension.
   作用:返回 input 中每行最大k 个数的值,并且返回它们所在位置的索引。

   参数:input:输入的输入tensor,数据类型必须是以下之一:float32、float64、int32、int64、uint8、int16、int8。

例如:

    input = tf.constant(np.random.rand(3,4), tf.float32)
    k = 1  #targets对应的索引是否在最大的前k(2)个数据中
    output = tf.nn.top_k(input, k)
    with tf.Session() as sess:
        print(sess.run(input))
        print(sess.run(output))
   输出:

TopKV2(values=array([[ 0.87421292],
       [ 0.96415848],
       [ 0.54568386]], dtype=float32), indices=array([[3],
       [1],
       [2]]))#每一行的最大值,与最大值所在的位置。
如果k=2

[[ 0.98679858  0.09883292  0.19342254  0.20967487]
 [ 0.12573749  0.60547918  0.54529655  0.08391853]
 [ 0.80146015  0.38433447  0.68723434  0.04177354]]
TopKV2(values=array([[ 0.98679858,  0.20967487],
       [ 0.60547918,  0.54529655],
       [ 0.80146015,  0.68723434]], dtype=float32), indices=array([[0, 3],
       [1, 2],
       [0, 2]]))#每行中,前两个最大值,及它们所在的位置。




  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值