Several simple examples of tf.scatter_nd_update are given in the official TensorFlow documentation (accessible via https://www.usetensorflow.com/api_docs/python/tf/scatter_nd_update). However, they only show its usage on a one-dimensional tensor. It took me quite some time to get it working on multi-dimensional tensors, and few examples of that case can be found on the web. Below are three examples I have worked through successfully. First, though, here is the example from the official documentation.
Example 1 (from the TensorFlow documentation):
For example, say we want to update 4 scattered elements in a rank-1 tensor with 8 elements. In Python, that update would look like this:
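import tensorflow as tf  # TensorFlow 1.x API (tf.Session)

ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
update = tf.scatter_nd_update(ref, indices, updates)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # ref must be initialized before it is updated
    print(sess.run(update))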
The resulting update to ref would look like this:
[1, 11, 3, 10, 9, 6, 7, 12]
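A minimal pure-Python sketch of what the call does may help when moving to more dimensions. It models the semantics on nested lists only; it is not TensorFlow's implementation, and the helper name scatter_nd_update_py is hypothetical. Each row of indices selects a slice of ref (a single element when the row is as long as ref's rank), and that slice is replaced by the matching entry of updates.

```python
import copy


def scatter_nd_update_py(ref, indices, updates):
    """Hypothetical pure-Python model of tf.scatter_nd_update on nested lists.

    Each entry of `indices` is a list of k integers selecting the slice
    ref[i0][i1]...[i(k-1)]; that slice is replaced by the matching entry of
    `updates`. Returns a new nested list and leaves `ref` untouched.
    """
    out = copy.deepcopy(ref)
    for index, value in zip(indices, updates):
        *prefix, last = index
        target = out
        for i in prefix:  # walk down to the parent of the selected slice
            target = target[i]
        target[last] = copy.deepcopy(value)  # replace an element or a whole sub-array
    return out


ref = [1, 2, 3, 4, 5, 6, 7, 8]
indices = [[4], [3], [1], [7]]
updates = [9, 10, 11, 12]
print(scatter_nd_update_py(ref, indices, updates))  # [1, 11, 3, 10, 9, 6, 7, 12]
```

Because every index in Example 1 has a single component, each update is a scalar. With a two-component index on a rank-3 ref, the same loop replaces a whole innermost row instead.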
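Next are three examples of my own on multi-dimensional tensors. They assume an interactive Python session in which import tensorflow as tf and sess = tf.InteractiveSession() have already been run.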
Example 2:
>>> ref = tf.Variable(tf.ones([2,3],tf.int32))
>>> updates = tf.constant([[0,0,0]])
>>> update = tf.scatter_nd_update(ref,[[0]],updates)
>>> init = tf.global_variables_initializer()
>>> sess.run(init)
>>> sess.run(update)
array([[0, 0, 0],
       [1, 1, 1]], dtype=int32)
Example 3:
>>> ref = tf.Variable(tf.ones([2,3,3],tf.int32))
>>> indices = tf.constant([[0,1]])
>>> # updates = tf.constant([0,0,0])  # wrong: shape [3]
>>> updates = tf.constant([[0,0,0]])  # correct: shape [1, 3]
>>> update = tf.scatter_nd_update(ref,indices,updates)
>>> init = tf.global_variables_initializer()
>>> sess.run(init)
>>> print(ref.eval())
[[[1 1 1]
  [1 1 1]
  [1 1 1]]

 [[1 1 1]
  [1 1 1]
  [1 1 1]]]
>>> sess.run(update)
array([[[1, 1, 1],
        [0, 0, 0],
        [1, 1, 1]],

       [[1, 1, 1],
        [1, 1, 1],
        [1, 1, 1]]], dtype=int32)
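The "wrong" and "correct" lines in Example 3 follow from the shape rule stated in the TensorFlow documentation: updates must have shape indices.shape[:-1] + ref.shape[K:], where K = indices.shape[-1] is the index depth. Here ref has shape [2, 3, 3] and indices has shape [1, 2], so updates needs shape [1, 3], which [[0,0,0]] has and [0,0,0] (shape [3]) does not. A small sketch of that rule (the helper name expected_updates_shape is hypothetical):

```python
def expected_updates_shape(ref_shape, indices_shape):
    """Hypothetical helper: the shape `updates` needs for scatter_nd_update.

    Rule: indices.shape[:-1] + ref.shape[K:], with K = indices.shape[-1].
    K is how many leading dimensions of ref each index consumes; the rest of
    ref's shape is the shape of the slice each update must supply.
    """
    k = indices_shape[-1]
    if k > len(ref_shape):
        raise ValueError("index depth exceeds the rank of ref")
    return list(indices_shape[:-1]) + list(ref_shape[k:])


print(expected_updates_shape([8], [4, 1]))        # Example 1: [4]
print(expected_updates_shape([2, 3], [1, 1]))     # Example 2: [1, 3]
print(expected_updates_shape([2, 3, 3], [1, 2]))  # Example 3: [1, 3]
print(expected_updates_shape([2, 3, 3], [1, 3]))  # Example 4: [1]
```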
Example 4 (reusing ref from Example 3; running the initializer again resets it to all ones):
>>> updates = tf.constant([0])
>>> indices = tf.constant([[1,0,1]])
>>> init = tf.global_variables_initializer()
>>> sess.run(init)
>>> update = tf.scatter_nd_update(ref,indices,updates)
>>> sess.run(update)
array([[[1, 1, 1],
        [1, 1, 1],
        [1, 1, 1]],

       [[1, 0, 1],
        [1, 1, 1],
        [1, 1, 1]]], dtype=int32)
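When an index row has as many entries as ref has dimensions, as in Example 4, each update is a single scalar, so the call behaves like the 1-D Example 1 applied to a row-major (C-order) flattening of ref. The sketch below illustrates only that indexing correspondence, not how TensorFlow implements the op; the helper name flat_offset is hypothetical.

```python
def flat_offset(index, shape):
    """Hypothetical helper: row-major offset of a full-depth index into `shape`."""
    if len(index) != len(shape):
        raise ValueError("index must have one entry per dimension")
    offset = 0
    for i, dim in zip(index, shape):
        if not 0 <= i < dim:
            raise IndexError("index out of range")
        offset = offset * dim + i  # Horner-style accumulation of the row-major offset
    return offset


shape = [2, 3, 3]
print(flat_offset([1, 0, 1], shape))  # 10

# Example 4 redone as a 1-D update on the flattened all-ones tensor.
flat = [1] * (2 * 3 * 3)
flat[flat_offset([1, 0, 1], shape)] = 0
rows = [flat[r * 3:(r + 1) * 3] for r in range(2 * 3)]  # innermost rows, in order
print(rows[3])  # [1, 0, 1], i.e. block 1, row 0
```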