一、tf.scatter_update
tf.scatter_update(
ref,
indices,
updates,
use_locking=True,
name=None
)
Applies sparse updates to a variable reference.
将稀疏更新应用于变量引用。
This operation computes
该函数的计算过程如下:
# Scalar indices
ref[indices, ...] = updates[...]
# Vector indices (for each i)
ref[indices[i], ...] = updates[i, ...]
# High rank indices (for each i, ..., j)
ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]
This operation outputs ref
after the update is done. This makes it easier to chain operations that need to use the reset value.
更新完成后,此操作输出ref。 这样可以更容易地链接需要使用重置值的操作。
If values in ref
is to be updated more than once, because there are duplicate entries in indices
, the order at which the updates happen for each value is undefined.
如果ref
中的值要多次更新,因为索引中存在重复条目,则每个值的更新发生顺序是不确定的。
Requires updates.shape = indices.shape + ref.shape[1:]
.
二、参数
参数 | ||
---|---|---|
ref | A Variable . | 一个Variable |
indices | A Tensor . Must be one of the following types: int32 , int64 . A tensor of indices into the first dimension of ref . | 一个Tensor ,必须为以下的数据类型: int32 , int64 。表示在ref 的第一维中的索引的张量。 |
updates | A Tensor . Must have the same type as ref . A tensor of updated values to store in ref . | 一个Tensor ,必须和ref 拥有相同的数据类型。表示一个要存储在ref 中的更新值的张量。 |
use_locking | An optional bool . Defaults to True . If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. | 一个可选的bool 值。 默认为True 。 如果为True,则分配将受锁保护; 否则行为未定义,但可能表现出较少的争用。 |
name | A name for the operation (optional). | 名称,可选。 |
返回值
Same as ref
. Returned as a convenience for operations that want to use the updated values after the update is done.
与ref
有相同数据类型。 返回以方便那些在更新完成后要使用更新值的操作。(注:返回的是一个Variable
,而不是和tf.scatter_nd_update
一样返回是一个Tensor
。!!)
三、代码
该函数的是给定需要待更新的矩阵的第一维索引和需要更新的数据然后根据这些数据进行更新。
根据这个公式updates.shape = indices.shape + ref.shape[1:]
,可以看出该函数和tf.scatter_nd_update
函数最大的区别,前者只作用于矩阵的第一维,后者可以作用于矩阵的任意多个维度。
综上,理解了上述的公式,这个函数就很容易理解了。
import tensorflow as tf
import tensorflow.contrib.eager as tfe
tf.enable_eager_execution()
ref = tfe.Variable(initial_value=[[0, 0, 0, 0], [0, 0, 0, 0]])
indices = tf.constant([0])
updates = tf.constant([[1, 98, 20, 102]])
update = tf.scatter_update(ref, indices, updates)
print(update)
结果如下:
<tf.Variable '' shape=(2, 4) dtype=int32, numpy=
array([[ 1, 98, 20, 102],
[ 0, 0, 0, 0]])>
三维矩阵更新:
import tensorflow as tf
import tensorflow.contrib.eager as tfe
import numpy as np
tf.enable_eager_execution()
ref = tfe.Variable(np.zeros(shape=[4, 4, 3], dtype=np.float32))
indices = tf.constant([0])
updates = tf.constant([np.random.random(size=[4, 3])], dtype=tf.float32)
update = tf.scatter_update(ref, indices, updates)
print(update)
结果如下:
<tf.Variable '' shape=(4, 4, 3) dtype=float32, numpy=
array([[[0.13318056, 0.83603495, 0.8232899 ],
[0.02964316, 0.8545541 , 0.27696434],
[0.4880769 , 0.23017927, 0.64292145],
[0.07073301, 0.10755321, 0.347981 ]],
[[0. , 0. , 0. ],
[0. , 0. , 0. ],
[0. , 0. , 0. ],
[0. , 0. , 0. ]],
[[0. , 0. , 0. ],
[0. , 0. , 0. ],
[0. , 0. , 0. ],
[0. , 0. , 0. ]],
[[0. , 0. , 0. ],
[0. , 0. , 0. ],
[0. , 0. , 0. ],
[0. , 0. , 0. ]]], dtype=float32)>