函数原型
tf.gather(
params, indices, validate_indices=None, axis=None, batch_dims=0, name=None
)
函数说明
根据索引从参数轴上收集切片。
参数params表示用来收集数值的一个张量。
参数indices表示待收集张量的索引。
参数axis表示索引所对应的轴。(官网说不推荐使用,而应多用参数batch_dims)
参数batch_dims表示批处理的维度。
函数使用
1、一维张量的切片
>>> params = [1, 3, 5, 6, 8]
>>> indices = [0, 1, 2]
>>> tf.gather(params, indices)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 3, 5])>
2、二维张量的切片
>>> params = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> params
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])>
>>> indices = tf.constant([0, 2])
>>> tf.gather(params, indices)
<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[1, 2, 3],
[7, 8, 9]])>
>>> params = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> indices = tf.constant([[0, 1], [1, 2], [0, 2]])
>>> tf.gather(params, indices, batch_dims=1)
<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[1, 2],
[5, 6],
[7, 9]])>
3、三维张量的切片
>>> params = tf.constant([[[1, 1], [2, 2], [3, 3]],
[[4, 4], [5, 5], [6, 6]],
[[7, 7], [8, 8], [9, 9]]])
>>> indices = tf.constant([0, 2])
>>> tf.gather(params, indices)
<tf.Tensor: shape=(2, 3, 2), dtype=int32, numpy=
array([[[1, 1],
[2, 2],
[3, 3]],
[[7, 7],
[8, 8],
[9, 9]]])>
>>> params = tf.constant([[[1, 1], [2, 2], [3, 3]],
[[4, 4], [5, 5], [6, 6]],
[[7, 7], [8, 8], [9, 9]]])
>>> indices = tf.constant([[0, 1], [1, 2], [0, 2]])
>>> indices
<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[0, 1],
[1, 2],
[0, 2]])>
>>> tf.gather(params, indices)
<tf.Tensor: shape=(3, 2, 3, 2), dtype=int32, numpy=
array([[[[1, 1],
[2, 2],
[3, 3]],
[[4, 4],
[5, 5],
[6, 6]]],
[[[4, 4],
[5, 5],
[6, 6]],
[[7, 7],
[8, 8],
[9, 9]]],
[[[1, 1],
[2, 2],
[3, 3]],
[[7, 7],
[8, 8],
[9, 9]]]])>
>>> params = tf.constant([[[1, 1], [2, 2], [3, 3]],
[[4, 4], [5, 5], [6, 6]],
[[7, 7], [8, 8], [9, 9]]])
>>> indices = tf.constant([[0, 1], [1, 2], [0, 2]])
# 当设置batch_dims=1时表示进行批处理的维度为1
# 此时再根据索引进行切片,注意有batch_dims和没有batch_dims的区别。
>>> tf.gather(params, indices, batch_dims=1)
<tf.Tensor: shape=(3, 2, 2), dtype=int32, numpy=
array([[[1, 1],
[2, 2]],
[[5, 5],
[6, 6]],
[[7, 7],
[9, 9]]])>