返回最内层一维(也就是最后一维)的前k个最大的元素,以及它所对应的索引。返回的值除了最后一维维度为k之外,其它维度维持原样。
#coding:utf-8
import tensorflow as tf
import numpy as np
data = np.array([[[1, 4, 1], [5, 2, 8]],
[[3, 2, 3], [9, 4, 4]],
[[1, 5, 5], [6, 7, 6]]])
data=tf.constant(data)
top_k=tf.nn.top_k(data,2)
print top_k.values
with tf.Session() as sess:
print sess.run(top_k)
print sess.run(top_k.values)
print sess.run(top_k.indices[0,0,0])
结果:
Tensor("TopKV2:0", shape=(3, 2, 2), dtype=int64)
TopKV2(values=array([[[4, 1],
[8, 5]],
[[3, 3],
[9, 4]],
[[5, 5],
[7, 6]]]), indices=array([[[1, 0],
[2, 0]],
[[0, 2],
[0, 1]],
[[1, 2],
[1, 0]]], dtype=int32))
[[[4 1]
[8 5]]
[[3 3]
[9 4]]
[[5 5]
[7 6]]]
1