tf.gather

参考  tf.gather - 云+社区 - 腾讯云

tf.gather(
    params,
    indices,
    validate_indices=None,
    name=None,
    axis=0
)

根据索引从params坐标轴中收集切片。indices是任何维度(通常是0-维或1-维)的整数张量。产生一个带有形状参数的输出张量,其中:  params.shape[:axis] + indices.shape + params.shape[axis + 1:]。

# Scalar indices (output is rank(params) - 1).
output[a_0, ..., a_n, b_0, ..., b_n] =
   params[a_0, ..., a_n, indices, b_0, ..., b_n]

# Vector indices (output is rank(params)).
output[a_0, ..., a_n, i, b_0, ..., b_n] =
   params[a_0, ..., a_n, indices[i], b_0, ..., b_n]

# Higher rank indices (output is rank(params) + rank(indices) - 1).
output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] =
   params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n]

                

注意,在CPU上,如果发现一个out of bound索引,将返回一个错误。在GPU上,如果发现一个out of bound索引,则在相应的输出值中存储一个0。

参数:

  • params:  一个张量。用来收集值的张量。秩必须至少是axis + 1
  • indices:  一个张量。必须是下列类型之一:int32、int64。指数张量。必须在range [0, params.shape[axis]]中
  • axis:  张量,必须是下列类型之一:int32、int64。以参数为单位的轴,用来收集指标。默认为第一个维度。支持负索引
  • name:  操作的名称(可选)

返回值:

  • 具有与params相同的类型的张量。

例:

import tensorflow as tf

a = tf.Variable([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]])
index_a = tf.Variable([0, 2])

b = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
index_b = tf.Variable([2, 4, 6, 8])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.gather(a, index_a)))
    print(sess.run(tf.gather(b, index_b)))


Output:
--------------------
[[ 1  2  3  4  5]
 [11 12 13 14 15]]
[3 5 7 9]
--------------------

原链接:  https://tensorflow.google.cn/versions/r1.9/api_docs/python/tf/gather?hl=en

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Wanderer001

ROIAlign原理

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值