tf.scatter_update和tf.batch_scatter_update

tf.scatter_update

函数定义:

tf.scatter_update(
    ref,
    indices,
    updates,
    use_locking=True,
    name=None
)

需要说明的是,updates.shape = [*indices.shape, *ref.shape[1:]], upadtes的shape不一定与ref的shape相等。

测试:

a = tf.Variable([[1, 2, 3, 4], [5, 6, 7, 8]])
indices = [[0, 1], [1, 0]]
updates = [[[1, 1, 1, 1], [2, 3, 4, 5]], [[2, 2, 2, 2], [3, 3, 3, 3]]]
b = tf.scatter_update(a, indices, updates)

sess = tf.InteractiveSession()
print(sess.run([b]))

结果:

[array([[3, 3, 3, 3],
        [2, 2, 2, 2]])]

结果说明:

重点是解读indices的含义

indices的值指定ref中被替换的对象,以上面测试为例,indices中的0,1分别指定a中的a[0]、a[1]将被替换。

indices的值对应的index指定ref中的相应替代值为update[index],以上为例,indices[0][0]为1,则a[1]将被替换为updates[0][0]。

此外indices中重复出现的值将被多次替换,至于结果是不确定的。


tf.batch_scatter_update

函数定义:

与tf.scatter_update相同

测试:

d = tf.Variable([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
indices = [[1, 1], [1, 0]]
updates = [[[1, 1], [2, 2]], [[3, 3], [4, 4]]]
e = tf.batch_scatter_update(d, indices, updates)
sess.run(tf.global_variables_initializer())
print(sess.run([e]))

结果:

[array([[[0, 1],
         [2, 2]],
 
        [[4, 4],
         [3, 3]]])]

结果说明:

tf.batch_scatter_update与tf.scatter_update类似,只是在进行值替换时,tf.scatter_update中ref替换对象由indices的值指定,而在tf.batch_scatter_update中由indices的值和对应的index[:-1]共同指定。

以上面为例 :indices[0][1]为1,则d[0][1]]替换为updates[0][1],其中d[0][1]中的1是indices[0][1]的值。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值