indices里的二维数组表明更新的位置,如[[0,0,0], [0,2,1]]指更新第1面第一行第一个和第一面第三行第二个元素,updates里的列表中的元素顺序即为更改位置的对应元素.更改位置表示几维,updates列表里的元素就应该设置为几维.
1
indices = tf.constant([[0,0,0], [0,2,1]])
updates = tf.constant([5,6])
shape = tf.constant([4, 4, 4])
scatter = tf.scatter_nd(indices, updates, shape)
print(scatter)
tf.Tensor(
[[[5 0 0 0]
[0 0 0 0]
[0 6 0 0]
[0 0 0 0]]
[[0 0 0 0]
[0 0 0 0]
[0 0 0 0]
[0 0 0 0]]
[[0 0 0 0]
[0 0 0 0]
[0 0 0 0]
[0 0 0 0]]
[[0 0 0 0]
[0 0 0 0]
[0 0 0 0]
[0 0 0 0]]], shape=(4, 4, 4), dtype=int32)
2
indices = tf.constant([[0,0], [0,2]])
updates = tf.constant([[5,5,5,5],[6,6,6,6]])
shape = tf.constant([4, 4, 4])
scatter = tf.scatter_nd(indices, updates, shape)
print(scatter)
tf.Tensor(
[[[5 5 5 5]
[0 0 0 0]
[6 6 6 6]
[0 0 0 0]]
[[0 0 0 0]
[0 0 0 0]
[0 0 0 0]
[0 0 0 0]]
[[0 0 0 0]
[0 0 0 0]
[0 0 0 0]
[0 0 0 0]]
[[0 0 0 0]
[0 0 0 0]
[0 0 0 0]
[0 0 0 0]]], shape=(4, 4, 4), dtype=int32)
3
indices = tf.constant([[0], [2]])
updates = tf.constant([[[5,5,5,5],[6,6,6,6],
[7,7,7,7],[8,8,8,8]],
[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]]])
shape = tf.constant([4, 4, 4])
scatter = tf.scatter_nd(indices, updates, shape)
print(scatter)
tf.Tensor(
[[[5 5 5 5]
[6 6 6 6]
[7 7 7 7]
[8 8 8 8]]
[[0 0 0 0]
[0 0 0 0]
[0 0 0 0]
[0 0 0 0]]
[[5 5 5 5]
[6 6 6 6]
[7 7 7 7]
[8 8 8 8]]
[[0 0 0 0]
[0 0 0 0]
[0 0 0 0]
[0 0 0 0]]], shape=(4, 4, 4), dtype=int32)