有时我们会遇到tensor域下的数组排序,比如按照一定规则对输入排序。
import tensorflow as tf
import numpy as np
a = tf.placeholder(tf.int32, shape=(3,2))
# bb = tf.constant(a) # the array
reordered = tf.gather(a, tf.nn.top_k(a[:, 0], k=3).indices) # 按照输入的第一个维度排序,选取top3的值
value_ = tf.nn.top_k(a[:, 0], k=3).values
indices_ = tf.nn.top_k(a[:, 0], k=3).indices
'''
tf.nn.top_k(
input,
k=1,
sorted=True,
name=None
)
top_k返回值:
top_k(...).values: The k largest elements along each last dimensional slice.(返回对应的值)
top_k(...).indices: The indices of values within the last dimension of input(返回索引)
-----------------
tf.gather(
params,
indices,
validate_indices=None,
name=None,
axis=0
)
根据对应索引indices把对应元素取出来
'''
feed_dict = {a:np.array([[1, 2], [3, 4], [2, 2]])}
sess = tf.Session()
_in, out, v, i = sess.run([a, reordered, value_, indices_], feed_dict=feed_dict)
print('in:\n',_in, '\nout:\n', out, '\nvalue:\n', v, '\nindices:\n', i)
'''
>>>in:
[[1 2]
[3 4]
[2 2]]
out:
[[3 4]
[2 2]
[1 2]]
value:
[3 2 1]
indices:
[1 2 0]
'''