一般用法见:https://blog.csdn.net/zlrai5895/article/details/80551056
简单地说,scatter_nd
根据indices来将updates中的元素“散布”到形状为shape并初始化为0的tensor中(称其为output),所以一般来说,shape
的某个维度大小会比updates
的对应维度的大小更大(小的tensor散布到大的tensor中),如下图所示:
然而在某个Github项目中见到一个特殊用法,shape
的某个维度大小比updates
的对应维度的大小更小,此时很难理解怎么“scatter(散开)”updates中的元素到一个更小的tensor中。如下面的例子,将1到12分散到只有4个位置的tensor中(解释见后文)。
indices = tf.constant([[0],[1],[2],[3],[0],[1],[2],[3],[0],[1],[2],[3]])
updates = tf.constant([1,2,3,4,5,6,7,8,9,10,11,12])
shape = tf.constant([4])
scatter = tf.scatter_nd(indices, updates, shape)
with tf.Session() as sess:
print(sess.run(scatter))
# output is [15 18 21 24], 15=1+5+9, 18=2+6+10 and so on
实际上,由于updates的维度更大,故会有多个元素会被分到shape的同一个位置,而最后的结果是将分到同一个位置的tensor加起来,作为该位置的最终结果。例如,上面例子中,updates中的1
、5
、9
对应indices的下标都是0,因此都被分到了结果中的位置0,因此15=1+5+9才是该位置的最终结果。同理,结果中下标为1的18是由indices中下标为1对应的updates中的元素的求和,即18=2+6+10。其他同理。