examples of scatter_nd_update

Several simple examples showing the usage of scatter_nd_update is provided in the tensor flow official document( accessible via https://www.usetensorflow.com/api_docs/python/tf/scatter_nd_update). However, this example only shows its usage on 1 dimensional tensor. It cost me quite a time to use it on multi dimensional tensor. Meanwhile, few examples about its usage on multi dimensional tensor can be found on the web. Following shows three examples I have successfully finished. Before that, first shows the example from the tensor flow official document.

Example 1:( from tensorflow documentation)

For example, say we want to update 4 scattered elements to a rank-1 tensor to 8 elements. In Python, that update would look like this:

ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
indices = tf.constant([[4], [3], [1] ,[7]])
updates = tf.constant([9, 10, 11, 12])
update = tf.scatter_nd_update(ref, indices, updates)
with tf.Session() as sess:
  print sess.run(update)

The resulting update to ref would look like this:

[1, 11, 3, 10, 9, 6, 7, 12]

Next are two examples written by me.

Example 2:

>>> ref = tf.Variable(tf.ones([2,3],tf.int32)) 
>>> updates = tf.constant([[0,0,0]])
>>> update = tf.scatter_nd_update(ref,[[0]],updates) 
>>> init = tf.global_variables_initializer()
>>> sess.run(init)
>>> sess.run(update)
array([[0, 0, 0],
   [1, 1, 1]], dtype=int32)

Example 3:

>>> ref = tf.Variable(tf.ones([2,3,3],tf.int32))
>>> indices = tf.constant([[0,1]])
#>>> updates = tf.constant([0,0,0]) #wrong
>>> updates = tf.constant([[0,0,0]])#correct
>>> update = tf.scatter_nd_update(ref,indices,updates) 
>>> init = tf.global_variables_initializer()
>>> sess.run(init)
>>> print(ref.eval())
[[[1 1 1]
  [1 1 1]
  [1 1 1]]

 [[1 1 1]
  [1 1 1]
  [1 1 1]]]
>>> sess.run(update)
array([[[1, 1, 1],
    [0, 0, 0],
    [1, 1, 1]],

   [[1, 1, 1],
    [1, 1, 1],
    [1, 1, 1]]], dtype=int32)

Example 4:

>>> updates = tf.constant([0])
>>> indices = tf.constant([[1,0,1]])
>>> init = tf.global_variables_initializer()
>>> sess.run(init)
>>> update = tf.scatter_nd_update(ref,indices,updates)
>>> sess.run(update)
array([[[1, 1, 1],
    [1, 1, 1],
    [0, 0, 0]],

   [[1, 0, 1],
    [1, 1, 1],
    [1, 1, 1]]], dtype=int32)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值