假设您想交换第二维中的元素，同时第一维（行）的顺序既可以保持不变，也可以按自定义顺序重新排列。思路是为 `tf.scatter_nd` 构造完整的（行，列）索引对：行索引沿第二维复制，列索引给出每行元素在该行内的新位置。
import tensorflow as tf
# TF 1.x API: creating an InteractiveSession makes it the default session,
# which is what allows the bare `.eval()` calls further down to run.
# NOTE(review): under TF 2.x this needs tf.compat.v1 with eager execution
# disabled — confirm which TF version this example targets.
sess = tf.InteractiveSession()
def prepare_fd(fd_indices, sd_dims):
    """Repeat each first-dimension index across the second dimension.

    Turns a 1-D tensor ``fd_indices`` of length N into an (N, sd_dims)
    tensor whose row i is ``fd_indices[i]`` repeated ``sd_dims`` times.
    Stacking the result with a per-element column-index tensor of the same
    shape yields the (row, col) index pairs that ``tf.scatter_nd`` expects.

    Args:
        fd_indices: 1-D integer tensor of target row indices (rank 1 is
            required: after ``expand_dims`` the tensor must be rank 2 to
            match the two-element ``multiples`` passed to ``tf.tile``).
        sd_dims: size of the second dimension; a Python int or a scalar
            integer tensor.

    Returns:
        Integer tensor of shape (N, sd_dims).
    """
    fd_indices = tf.expand_dims(fd_indices, 1)  # (N,) -> (N, 1)
    return tf.tile(fd_indices, [1, sd_dims])  # (N, 1) -> (N, sd_dims)
# Values to scatter: a 3x4 integer matrix.
updates = tf.constant([[11, 12, 13, 14],
                       [21, 22, 23, 24],
                       [31, 32, 33, 34]])

# Dynamic shape of `updates`; the scatter outputs have this same shape.
shape = tf.shape(updates)
sd_dims = shape[1]  # size of the second dimension

# Per-row permutation of the second dimension: element j of row i is
# written to column sd_indices[i][j] of its target row.
sd_indices = tf.constant([[1, 0, 2, 3], [0, 2, 1, 3], [0, 1, 3, 2]])

# Target rows: identity order (rows stay in place) ...
fd_indices_range = tf.range(0, limit=shape[0])
# ... or a custom order (row i of `updates` goes to row fd_indices_custom[i]).
fd_indices_custom = tf.constant([2, 0, 1])

# Full (row, col) index pairs, shape (rows, cols, 2), for tf.scatter_nd.
indices1 = tf.stack((prepare_fd(fd_indices_range, sd_dims), sd_indices), axis=2)
indices2 = tf.stack((prepare_fd(fd_indices_custom, sd_dims), sd_indices), axis=2)

# updates[i, j] is written to output[indices[i, j, 0], indices[i, j, 1]].
# The indices must form a permutation: scatter_nd sums values that hit the
# same cell and leaves cells that nothing targets at zero.
scatter1 = tf.scatter_nd(indices1, updates, shape)
scatter2 = tf.scatter_nd(indices2, updates, shape)

print(scatter1.eval())
# [[12 11 13 14]
#  [21 23 22 24]
#  [31 32 34 33]]
print(scatter2.eval())
# [[21 23 22 24]
#  [31 32 34 33]
#  [12 11 13 14]]
希望这个例子对您有帮助。