TensorFlow改变张量特定切片中元素的值

问题

使用TensorFlow时,有时我们希望像使用numpy那样改变张量中某些元素的值,通过使用类似下面的代码:

np_array[...] = some_value  # usually succeed
tf_tensor[...] = some_value  # usually fail

然而在TensorFlow中,Tensor对象不支持切片赋值,因此类似于上述第二行代码通常失败。这是TensorFlow的计算图逻辑决定的。在TensorFlow计算图中,Tensor是不同操作(OP)节点之间的数据流,其本质是获得该张量的操作(tf.Tensor.op)而并非数据,内存中并不存在某一区域保存某个张量的值。因此,像使用numpy那样改变张量本身某些元素的值从逻辑上是无法实现的。
然而,我们通常只是希望通过更新张量某些元素的值获取一个新的张量用于后续的计算,这样,我们便可以通过构造一个操作来实现这一过程。该操作以原张量为输入,以特定元素值更新后的张量为输出。TensorFlow官方API提供了scatter_nd_update(TensorFlow 1.x)和tensor_scatter_nd_update(TensorFlow 2.x)等函数,通过给出更新元素位置来实现上述功能。具体使用方法请参考官方API,这里不再赘述。然而,这些函数并不能支持任意切片,因此可考虑手动实现对张量任意切片的值更新。

实现

首先导入TensorFlow。

import tensorflow as tf

可以考虑以下思路实现对张量切片部分元素的修改(这样会得到一个新的张量,而并非改变原张量的“值”)。对于一个张量tensor,构造两个同形状的索引张量和二值张量,其中,索引张量用于标示更改元素的唯一索引,二值张量用于标示更改元素的位置。具体代码如下所示:

def tensor_slice_update(tensor, updates,
                        begin, end, 
                        strides=None, begin_mask=0, end_mask=0,
                        ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0,
                        name=None):
    """Create a tensor by updating values of a tensor slice.
    Arguments:
        tensor: A `Tensor` like, the tensor to be assigned value to.
        updates: A `Tensor` like, the value to update. Must has the same shape and data type
            with the tensor slice.
        begin: See `tf.strided_slice`.
        end:  See `tf.strided_slice`.
        strides:  See `tf.strided_slice`.
        begin_mask:  See `tf.strided_slice`.
        end_mask:  See `tf.strided_slice`.
        ellipsis_mask:  See `tf.strided_slice`.
        new_axis_mask:  See `tf.strided_slice`.
        shrink_axis_mask:  See `tf.strided_slice`.
        name: A `str`, `OP` name.
    Returns:
        A `Tensor` with the same shape with `tensor`.
    """
    with tf.name_scope(name or 'tensor_slice_update'):
        tensor = tf.convert_to_tensor(tensor)
        tensor_shape = tf.shape(tensor)
        tensor_size = tf.size(tensor)
        updates = tf.convert_to_tensor(updates, dtype=tensor.dtype)

        index_slice = tf.strided_slice(
            tf.reshape(tf.range(tensor_size), tensor_shape),
            begin, end, strides=strides,
            begin_mask=begin_mask, end_mask=end_mask,
            ellipsis_mask=ellipsis_mask, new_axis_mask=new_axis_mask,
            shrink_axis_mask=shrink_axis_mask)

        flattened_index_slice = tf.expand_dims(tf.keras.backend.flatten(index_slice), -1)
        flattened_value = tf.keras.backend.flatten(updates)
        scattered_value = tf.scatter_nd(flattened_index_slice, flattened_value, (tensor_size,))
        padded_value = tf.reshape(scattered_value, tensor_shape)

        flattened_ones = tf.ones_like(flattened_value, dtype=tf.bool)
        scattered_ones = tf.scatter_nd(flattened_index_slice, flattened_ones, (tensor_size,))
        padded_ones = tf.reshape(scattered_ones, tensor_shape)

        updated_tensor = tf.where(padded_ones, padded_value, tensor)
    return updated_tensor

下面对上述代码逻辑进行解释。
该函数实现了对张量tensor的某个切片(通过beginend、…、shrink_axis_mask等8个参数,具体含义参考tf.strided_slice)以value进行赋值,得到一个新的张量。(1)使用tf.rangetf.reshape函数构造一个与tensor同形状的索引张量,每个元素给出该位置的唯一ID;(2)使用tf.strided_slice对索引张量取切片index_slice,表示切片位置的唯一ID,为实现赋值,赋值张量updates应当与index_slice具有相同的形状;(3)使用tf.keras.backend.flattenindex_sliceupdates拉平并进行简单的变形,从而可以作为tf.scatter_nd的输入,将updates根据索引分散到应当更新的位置;(4)类似上述过程构造一个二值张量padded_ones,用以标识张量在各个位置的元素是否应当更新;(5)最后,使用tf.where获得更新后的张量。
此外,我们有时候需要仅更新张量某一个元素的值,那么使用上述函数明显比较麻烦。可以根据上述函数完成仅更新张量某一个元素值的函数(当然,也可以基于官方API提供的scatter函数实现,这里略去):

def tensor_element_update(tensor, update, index, name=None):
    """Create a tensor by updating the value of an element.
    Arguments:
        tensor: A `Tensor` like with known rank, the tensor to be assigned value to.
        update: A `Tensor` scalar like, the value to update. Must has the same data type with `tensor`.
        index: An 1-D `int` `Tensor` like, representing the element position to assign.
        name: A `str`, `OP` name.
    Returns:
        A `Tensor` with the same shape with `tensor`.
    """
    with tf.name_scope(name or 'tensor_element_update'):
        tensor = tf.convert_to_tensor(tensor)
        update = tf.convert_to_tensor(update, dtype=tensor.dtype)
        index = tf.convert_to_tensor(index, dtype=tf.int64)
        begin = index
        end = index + tf.ones_like(index)
        shrink_axis_mask = (1 << len(tensor.shape) + 1) - 1
        updated_tensor = tensor_slice_update(tensor, update, begin, end, shrink_axis_mask=shrink_axis_mask)
    return updated_tensor

测试

我们可以测试一下函数的效果。

a = tf.constant(0., shape=(3, 4, 5))

# Be careful of the tensor shape of `updates`.
ans_1 = tensor_slice_update(a, [[[1]]], (1, 2, 3), (2, 3, 4))

# Refer to `tf.strided_slice` and `tensorflow.python.ops.gen_array_ops.strided_slice` to specify a correct mask.
ans_2 = tensor_slice_update(a, 1, (1, 2, 3), (2, 3, 4), shrink_axis_mask=7)

ans_3 = tensor_slice_update(a, [[1, 2], [3, 4]], (0, 1, 1), (1, 3, 3), shrink_axis_mask=1)
ans_4 = tensor_element_update(a, tf.constant(1.), (1, 2, 3))

if tf.__version__ >= '2.0.0':
    print('ans_1:\n', ans_1.numpy())
    print('ans_2:\n', ans_2.numpy())
    print('ans_3:\n', ans_3.numpy())
    print('ans_4:\n', ans_4.numpy())
else:
    sess = tf.Session()
    ans_1, ans_2, ans_3, ans_4 = sess.run((ans_1, ans_2, ans_3, ans_4))
    print('ans_1:\n', ans_1)
    print('ans_2:\n', ans_2)
    print('ans_3:\n', ans_3)
    print('ans_4:\n', ans_4)

输出如下:

ans_1:
[[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 1. 0.]
[0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]]
ans_2:
[[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 1. 0.]
[0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]]
ans_3:
[[[0. 0. 0. 0. 0.]
[0. 1. 2. 0. 0.]
[0. 3. 4. 0. 0.]
[0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]]
ans_4:
[[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 1. 0.]
[0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]]

可以看到,我们能够对张量切片进行更新了。对于特定的应用,函数tensor_slice_update可能略显繁琐,不妨基于该函数定制自己的函数。

总结

  1. 然而在TensorFlow中,Tensor对象不支持切片赋值,但可以通过更新张量切片的值获取一个新的张量。
  2. 优先选用官方API,scatter_nd_update(TensorFlow 1.x)和tensor_scatter_nd_update(TensorFlow 2.x)等函数。
  3. 基于tf.stride_slicetf.scatter_nd实现了对张量任意切片进行更新的函数tensor_slice_update和轻量级的单个元素更新函数tensor_element_update
  4. 根据你的需求进行定制。

对于BUGs和编辑过程中的错误,望大家批评指正!

  • 4
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值