tf.gather, tf.gather_nd和tf.slice

∇ \nabla 联系方式:

e-mail: FesianXu@gmail.com

QQ: 973926198

github: https://github.com/FesianXu

知乎专栏: 计算机视觉/计算机图形理论与应用

微信公众号
qrcode


tensorflowpytorch中,都有一些函数是用于对张量进行索引的,因为pytorch是基于动态图的,它的张量可以动态的进行操作,不需要在对话(session)中进行,因此可能采用这些方法的必要性没有那么足,但是因为tensorflow是基于静态图的,因此如果在构建计算图的过程中需要对张量进行索引,一般是不能进行简单的slice的,因此引入这些函数就很有必要了,我们这里集中介绍下这些函数。

tf.gather

tf.gather(
    params,  # 需要被索引的张量
    indices,  # 索引
    validate_indices=None,
    name=None,
    axis=0
)

其作用很简单,就是根据提供的indicesaxis这个轴上对params进行索引,拼接成一个新的张量,其示意图如下所示:
gather_tf
其中的indices可以是标量张量,向量张量或者是更高阶的张量,但是其元素都必须是整数类型,比如int32,int64等,而且注意检查不要越界了,因为如果越界了,如果使用的CPU,则会报错,如果在GPU上进行操作的,那么相应的输出值将会被置为0,而不会报错,因此认真检查是否越界。
(Note that on CPU, if an out of bound index is found, an error is returned. On GPU, if an out of bound index is found, a 0 is stored in the corresponding output value.)

给个代码例子:

params = tf.constant(np.random.normal(size=(5,6)))
index = tf.constant([1,3, 6])
out = tf.gather(params, index)
with tf.Session() as sess:
    print(sess.run(out).shape)
    print(out.eval())
    print('-----------------------------------------')
    print(params.eval())

# ------ output
(3, 6)
[[ 0.33662994 -0.33750725  1.5710436   0.78380586 -0.39254751  0.15269514]
 [ 0.35278309  0.59719792 -0.86332759 -0.88997637  0.18247123 -1.46388103]
 [ 0.          0.          0.          0.          0.          0.        ]]
-----------------------------------------
[[ 0.98987616 -1.62505873  0.86819471  2.3316706  -1.17362956  0.57264237]
 [ 0.33662994 -0.33750725  1.5710436   0.78380586 -0.39254751  0.15269514]
 [ 0.62367272  0.83599086  1.60433217 -1.4376806   0.61179675  2.07843436]
 [ 0.35278309  0.59719792 -0.86332759 -0.88997637  0.18247123 -1.46388103]
 [-1.76421914  0.85159247  0.18397565  0.69460119 -0.0651779   0.14643597]]

可以看到越界的地方全部置为了0。

params = tf.constant(np.random.normal(size=(5,6)))
index = tf.constant([[[0,0],[2,0]]])
out = tf.gather(params, index)
with tf.Session() as sess:
    print(sess.run(out).shape)
    print(index.eval().shape)
    print(out.eval())
    print('-----------------------------------------')
    print(params.eval())

# ----- output
(1, 2, 2, 6)
(1, 2, 2)
[[[[-1.52187772 -0.86557915  1.15487405  0.94154816 -0.55900061
     0.28165203]
   [-1.52187772 -0.86557915  1.15487405  0.94154816 -0.55900061
     0.28165203]]

  [[-0.18865467 -0.00902652  1.03637009 -1.14683883  0.20450098
     0.11977745]
   [-1.52187772 -0.86557915  1.15487405  0.94154816 -0.55900061
     0.28165203]]]]
-----------------------------------------
[[-1.52187772 -0.86557915  1.15487405  0.94154816 -0.55900061  0.28165203]
 [-1.55437505  0.1800381  -0.06461522  0.85310958 -1.25240021 -0.64335453]
 [-0.18865467 -0.00902652  1.03637009 -1.14683883  0.20450098  0.11977745]
 [-0.46992654 -0.07142709  0.35710407  0.24846814 -0.13490197 -1.31201887]
 [-0.24019091  0.54458599  0.3623213  -0.10113704 -1.15469468 -0.33837456]]

如果索引indices为矩阵或者更高阶的张量,其索引出来的尺寸就根据索引而定了,如上例子所示。

tf.gather_nd

tf.gather_nd类似于tf.gather,不过后者只能在一个维度上进行索引,而前者可以在多个维度上进行索引,其API为:

tf.gather_nd(
    params, # 待索引输入张量
    indices, # 索引
    name=None
)

例子如:

params = tf.constant(np.random.normal(size=(5,6)))
index = tf.constant([[[0,0],[2,0]]])
out = tf.gather_nd(params, index)
with tf.Session() as sess:
    print(sess.run(out).shape)
    print(index.eval().shape)
    print(out.eval())
    print('-----------------------------------------')
    print(params.eval())
# -------------- output
(1, 2)
(1, 2, 2)
[[-0.20841454 -0.12850639]]
-----------------------------------------
[[-0.20841454  0.13953242  2.26913464  1.60226729  1.51834021 -0.41044839]
 [-1.3215913  -0.15647683  0.17277501 -0.60070571  0.80504465  0.85216738]
 [-0.12850639 -0.34113574  0.16567085 -1.61315021 -0.29158798  0.03378417]
 [ 0.67928455  0.16899565  1.0065189  -1.05949544  1.87377367  0.71542472]
 [-0.2858821  -1.29338336 -0.41112208 -0.21772644  0.17873804 -0.0703971 ]]

tf.slice

用来进行切片操作,实现在python中的a[:,2:3,5:6]类似的操作,例子如:

t = tf.constant([[[1, 1, 1], [2, 2, 2]],
                 [[3, 3, 3], [4, 4, 4]],
                 [[5, 5, 5], [6, 6, 6]]])
# shape = (3,2,3)
tf.slice(t, [1, 0, 0], [1, 1, 3])  # [[[3, 3, 3]]], shape (1,1,3)
tf.slice(t, [1, 0, 0], [1, 2, 3])  # [[[3, 3, 3],
                                   #   [4, 4, 4]]] shape (1,2,3)
tf.slice(t, [1, 0, 0], [2, 1, 3])  # [[[3, 3, 3]],
                                   #  [[5, 5, 5]]] shape (2,1,3)
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

FesianXu

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值