- tf官方定义的top_k函数
def top_k(input, k=1, sorted=True, name=None): # pylint: disable=redefined-builtin
"""
找到张量(k>1)(向量,k=1)中每行最大的k个值,并返回值和对应的索引。
Args:
input: 1-D 向量 or 高维张量.
k: 0-D `int32` `Tensor`. 最大的k个值
sorted: 如果 true,从大到小排序;反之,从小到大。
name: 可选择的操作的名字
Returns:
values: 最大的k个值
indices: 值对应的索引
"""
return gen_nn_ops.top_kv2(input, k=k, sorted=sorted, name=name)
- 例子
# 选出每一行的最大的前两个数字
# 返回的是最大的k个数字,同时返回的是最大的k个数字在最后的一个维度的下标
import tensorflow as tf
import numpy as np
a = tf.constant(np.random.rand(4, 3))
b = tf.nn.top_k(a, k=2)
with tf.Session() as sess:
print(sess.run(a))
print(sess.run(b))
print结果
[[0.86173437 0.53670843 0.40364604]
[0.08836313 0.00513431 0.08697128]
[0.68419199 0.44204247 0.93483738]
[0.28708127 0.08542102 0.9693912 ]]
TopKV2(values=array([[0.86173437, 0.53670843],
[0.08836313, 0.08697128],
[0.93483738, 0.68419199],
[0.9693912 , 0.28708127]]), indices=array([[0, 1],
[0, 2],
[2, 0],
[2, 0]]))