tf.gather(
params,
indices,
validate_indices=None,
name=None,
axis=0
)
根据索引从params坐标轴中收集切片。indices
是任何维度(通常是0-维或1-维)的整数张量。产生一个带有形状参数的输出张量,其中: params.shape[:axis] + indices.shape + params.shape[axis + 1:]。
# Scalar indices (output is rank(params) - 1).
output[a_0, ..., a_n, b_0, ..., b_n] =
params[a_0, ..., a_n, indices, b_0, ..., b_n]
# Vector indices (output is rank(params)).
output[a_0, ..., a_n, i, b_0, ..., b_n] =
params[a_0, ..., a_n, indices[i], b_0, ..., b_n]
# Higher rank indices (output is rank(params) + rank(indices) - 1).
output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] =
params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n]
注意,在CPU上,如果发现一个out of bound索引,将返回一个错误。在GPU上,如果发现一个out of bound索引,则在相应的输出值中存储一个0。
参数:
params
: 一个张量。用来收集值的张量。秩必须至少是axis
+ 1indices
: 一个张量。必须是下列类型之一:int32、int64。指数张量。必须在range [0, params.shape[axis]]中axis
: 张量,必须是下列类型之一:int32、int64。以参数为单位的轴,用来收集指标。默认为第一个维度。支持负索引- name: 操作的名称(可选)
返回值:
- 具有与params相同的类型的张量。
例:
import tensorflow as tf
a = tf.Variable([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]])
index_a = tf.Variable([0, 2])
b = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
index_b = tf.Variable([2, 4, 6, 8])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(tf.gather(a, index_a)))
print(sess.run(tf.gather(b, index_b)))
Output:
--------------------
[[ 1 2 3 4 5]
[11 12 13 14 15]]
[3 5 7 9]
--------------------
原链接: https://tensorflow.google.cn/versions/r1.9/api_docs/python/tf/gather?hl=en