来源https://sodocumentation.net/tensorflow/topic/2511/tensor-indexing
为Tensor指定非连续的部分分片,并提取相关特征值
使用到的函数为tf.gather()
给出对应的index,以及对象input,使用tf.gather(input,index)取出对应的tensor
示例
# data is [[0, 1, 2, 3, 4, 5],
# [6, 7, 8, 9, 10, 11],
# ...
# [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
params = tf.constant(data)
indices = tf.constant([0, 3])
selected = tf.gather(params, indices)
结果
[[ 0 1 2 3 4 5]
[18 19 20 21 22 23]]
可以根据不同的需求构建index,获取不同shape的tensor,更多示例可参见上述链接