tf实现用二维的索引从二维数组获取对应值 tf.gather_nd

a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
inds = tf.constant([[0, 2], [2, 1], [1, 1]])

#目的是实现 从[1,2,3]获取index为[0,2]的值也就是[1,3]作为第一行,
从[4,5,6]获取index为[2,1]的值也就是[6,5]作为第二行, 
从[7,8,9]获取index[1,1]的值作为第三行,也就是输出是
[[1 3]
 [6 5]
 [8 8]]






这种需求应该很常见,但是想通过look_up_table好像不行,以及想通过tf.gather_fn似乎可以但是也不好写

本文提供一种写法:

import tensorflow as tf

def gather_batch(v, inds):
    return tf.gather(v, inds)

def test2():
    a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    inds = tf.constant([[0, 2], [2, 1], [1, 1]])
    vs = tf.map_fn(fn=lambda x: gather_batch(x[:3], x[3:]), elems=tf.concat([a, inds], 1))

    with tf.Session() as sess:
        print(sess.run(vs))
 

if __name__ == '__main__':
    # test1()
    test2()

 

但是上面写法还是用了循环 会很慢 所以更好写法

def test3():
    a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    inds = tf.constant([[0, 2], [2, 1], [1, 1]])
    batch_size = inds.shape[0]
    cnt = inds.shape[1]
    left_inds = tf.tile(
        tf.expand_dims(tf.range(batch_size), 1),
        [1, cnt]
    )
    ind = tf.squeeze(
        tf.stack(
            [
                tf.expand_dims(left_inds, 2),
                tf.expand_dims(inds, 2),
            ],
            2
        )
        ,-1
    )

    vs = tf.gather_nd(a, ind)
    with tf.Session() as sess:
        # print(sess.run(ind))
        print(sess.run(vs))

 

 

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值