问题
使用TensorFlow时,有时我们希望像使用numpy那样改变张量中某些元素的值,通过使用类似下面的代码:
np_array[...] = some_value # usually succeed
tf_tensor[...] = some_value # usually fail
然而在TensorFlow中,Tensor对象不支持切片赋值,因此类似于上述第二行代码通常失败。这是TensorFlow的计算图逻辑决定的。在TensorFlow计算图中,Tensor是不同操作(OP)节点之间的数据流,其本质是获得该张量的操作(tf.Tensor.op
)而并非数据,内存中并不存在某一区域保存某个张量的值。因此,像使用numpy那样改变张量本身某些元素的值从逻辑上是无法实现的。
然而,我们通常只是希望通过更新张量某些元素的值获取一个新的张量用于后续的计算,这样,我们便可以通过构造一个操作来实现这一过程。该操作以原张量为输入,以特定元素值更新后的张量为输出。TensorFlow官方API提供了scatter_nd_update
(TensorFlow 1.x)和tensor_scatter_nd_update
(TensorFlow 2.x)等函数,通过给出更新元素位置来实现上述功能。具体使用方法请参考官方API,这里不再赘述。然而,这些函数并不能支持任意切片,因此可考虑手动实现对张量任意切片的值更新。
实现
首先导入TensorFlow。
import tensorflow as tf
可以考虑以下思路实现对张量切片部分元素的修改(这样会得到一个新的张量,而并非改变原张量的“值”)。对于一个张量tensor
,构造两个同形状的索引张量和二值张量,其中,索引张量用于标示更改元素的唯一索引,二值张量用于标示更改元素的位置。具体代码如下所示:
def tensor_slice_update(tensor, updates,
begin, end,
strides=None, begin_mask=0, end_mask=0,
ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0,
name=None):
"""Create a tensor by updating values of a tensor slice.
Arguments:
tensor: A `Tensor` like, the tensor to be assigned value to.
updates: A `Tensor` like, the value to update. Must has the same shape and data type
with the tensor slice.
begin: See `tf.strided_slice`.
end: See `tf.strided_slice`.
strides: See `tf.strided_slice`.
begin_mask: See `tf.strided_slice`.
end_mask: See `tf.strided_slice`.
ellipsis_mask: See `tf.strided_slice`.
new_axis_mask: See `tf.strided_slice`.
shrink_axis_mask: See `tf.strided_slice`.
name: A `str`, `OP` name.
Returns:
A `Tensor` with the same shape with `tensor`.
"""
with tf.name_scope(name or 'tensor_slice_update'):
tensor = tf.convert_to_tensor(tensor)
tensor_shape = tf.shape(tensor)
tensor_size = tf.size(tensor)
updates = tf.convert_to_tensor(updates, dtype=tensor.dtype)
index_slice = tf.strided_slice(
tf.reshape(tf.range(tensor_size), tensor_shape),
begin, end, strides=strides,
begin_mask=begin_mask, end_mask=end_mask,
ellipsis_mask=ellipsis_mask, new_axis_mask=new_axis_mask,
shrink_axis_mask=shrink_axis_mask)
flattened_index_slice = tf.expand_dims(tf.keras.backend.flatten(index_slice), -1)
flattened_value = tf.keras.backend.flatten(updates)
scattered_value = tf.scatter_nd(flattened_index_slice, flattened_value, (tensor_size,))
padded_value = tf.reshape(scattered_value, tensor_shape)
flattened_ones = tf.ones_like(flattened_value, dtype=tf.bool)
scattered_ones = tf.scatter_nd(flattened_index_slice, flattened_ones, (tensor_size,))
padded_ones = tf.reshape(scattered_ones, tensor_shape)
updated_tensor = tf.where(padded_ones, padded_value, tensor)
return updated_tensor
下面对上述代码逻辑进行解释。
该函数实现了对张量tensor
的某个切片(通过begin
、end
、…、shrink_axis_mask
等8个参数,具体含义参考tf.strided_slice
)以value
进行赋值,得到一个新的张量。(1)使用tf.range
和tf.reshape
函数构造一个与tensor
同形状的索引张量,每个元素给出该位置的唯一ID;(2)使用tf.strided_slice
对索引张量取切片index_slice
,表示切片位置的唯一ID,为实现赋值,赋值张量updates
应当与index_slice
具有相同的形状;(3)使用tf.keras.backend.flatten
将index_slice
与updates
拉平并进行简单的变形,从而可以作为tf.scatter_nd
的输入,将updates
根据索引分散到应当更新的位置;(4)类似上述过程构造一个二值张量padded_ones
,用以标识张量在各个位置的元素是否应当更新;(5)最后,使用tf.where
获得更新后的张量。
此外,我们有时候需要仅更新张量某一个元素的值,那么使用上述函数明显比较麻烦。可以根据上述函数完成仅更新张量某一个元素值的函数(当然,也可以基于官方API提供的scatter函数实现,这里略去):
def tensor_element_update(tensor, update, index, name=None):
"""Create a tensor by updating the value of an element.
Arguments:
tensor: A `Tensor` like with known rank, the tensor to be assigned value to.
update: A `Tensor` scalar like, the value to update. Must has the same data type with `tensor`.
index: An 1-D `int` `Tensor` like, representing the element position to assign.
name: A `str`, `OP` name.
Returns:
A `Tensor` with the same shape with `tensor`.
"""
with tf.name_scope(name or 'tensor_element_update'):
tensor = tf.convert_to_tensor(tensor)
update = tf.convert_to_tensor(update, dtype=tensor.dtype)
index = tf.convert_to_tensor(index, dtype=tf.int64)
begin = index
end = index + tf.ones_like(index)
shrink_axis_mask = (1 << len(tensor.shape) + 1) - 1
updated_tensor = tensor_slice_update(tensor, update, begin, end, shrink_axis_mask=shrink_axis_mask)
return updated_tensor
测试
我们可以测试一下函数的效果。
a = tf.constant(0., shape=(3, 4, 5))
# Be careful of the tensor shape of `updates`.
ans_1 = tensor_slice_update(a, [[[1]]], (1, 2, 3), (2, 3, 4))
# Refer to `tf.strided_slice` and `tensorflow.python.ops.gen_array_ops.strided_slice` to specify a correct mask.
ans_2 = tensor_slice_update(a, 1, (1, 2, 3), (2, 3, 4), shrink_axis_mask=7)
ans_3 = tensor_slice_update(a, [[1, 2], [3, 4]], (0, 1, 1), (1, 3, 3), shrink_axis_mask=1)
ans_4 = tensor_element_update(a, tf.constant(1.), (1, 2, 3))
if tf.__version__ >= '2.0.0':
print('ans_1:\n', ans_1.numpy())
print('ans_2:\n', ans_2.numpy())
print('ans_3:\n', ans_3.numpy())
print('ans_4:\n', ans_4.numpy())
else:
sess = tf.Session()
ans_1, ans_2, ans_3, ans_4 = sess.run((ans_1, ans_2, ans_3, ans_4))
print('ans_1:\n', ans_1)
print('ans_2:\n', ans_2)
print('ans_3:\n', ans_3)
print('ans_4:\n', ans_4)
输出如下:
ans_1:
[[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 1. 0.]
[0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]]
ans_2:
[[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 1. 0.]
[0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]]
ans_3:
[[[0. 0. 0. 0. 0.]
[0. 1. 2. 0. 0.]
[0. 3. 4. 0. 0.]
[0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]]
ans_4:
[[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 1. 0.]
[0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]]
可以看到,我们能够对张量切片进行更新了。对于特定的应用,函数tensor_slice_update
可能略显繁琐,不妨基于该函数定制自己的函数。
总结
- 然而在TensorFlow中,Tensor对象不支持切片赋值,但可以通过更新张量切片的值获取一个新的张量。
- 优先选用官方API,
scatter_nd_update
(TensorFlow 1.x)和tensor_scatter_nd_update
(TensorFlow 2.x)等函数。 - 基于
tf.stride_slice
和tf.scatter_nd
实现了对张量任意切片进行更新的函数tensor_slice_update
和轻量级的单个元素更新函数tensor_element_update
。 - 根据你的需求进行定制。
对于BUGs和编辑过程中的错误,望大家批评指正!