Tensorflow深度学习之三十三:tf.scatter_update

一、tf.scatter_update

tf.scatter_update(
    ref,
    indices,
    updates,
    use_locking=True,
    name=None
)

   Applies sparse updates to a variable reference.
   将稀疏更新应用于变量引用。

   This operation computes
   该函数的计算过程如下:

	# Scalar indices
	ref[indices, ...] = updates[...]

	# Vector indices (for each i)
	ref[indices[i], ...] = updates[i, ...]

	# High rank indices (for each i, ..., j)
	ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]

   This operation outputs ref after the update is done. This makes it easier to chain operations that need to use the reset value.
   更新完成后,此操作输出ref。 这样可以更容易地链接需要使用重置值的操作。

   If values in ref is to be updated more than once, because there are duplicate entries in indices, the order at which the updates happen for each value is undefined.
   如果ref中的值要多次更新,因为索引中存在重复条目,则每个值的更新发生顺序是不确定的。

   Requires updates.shape = indices.shape + ref.shape[1:].
图片来源自Tensor Flow官方网站

二、参数

参数
refA Variable.一个Variable
indicesA Tensor. Must be one of the following types: int32, int64. A tensor of indices into the first dimension of ref.一个Tensor,必须为以下的数据类型: int32int64表示在ref的第一维中的索引的张量。
updatesA Tensor. Must have the same type as ref. A tensor of updated values to store in ref.一个Tensor,必须和ref拥有相同的数据类型。表示一个要存储在ref中的更新值的张量。
use_lockingAn optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.一个可选的bool值。 默认为True。 如果为True,则分配将受锁保护; 否则行为未定义,但可能表现出较少的争用。
nameA name for the operation (optional).名称,可选。

返回值
   Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.
   与ref有相同数据类型。 返回以方便那些在更新完成后要使用更新值的操作。(注:返回的是一个Variable,而不是和tf.scatter_nd_update一样返回是一个Tensor。!!)

三、代码
   该函数的是给定需要待更新的矩阵的第一维索引和需要更新的数据然后根据这些数据进行更新。
   根据这个公式updates.shape = indices.shape + ref.shape[1:],可以看出该函数和tf.scatter_nd_update 函数最大的区别,前者只作用于矩阵的第一维,后者可以作用于矩阵的任意多个维度。
   综上,理解了上述的公式,这个函数就很容易理解了。

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

ref = tfe.Variable(initial_value=[[0, 0, 0, 0], [0, 0, 0, 0]])
indices = tf.constant([0])
updates = tf.constant([[1, 98, 20, 102]])
update = tf.scatter_update(ref, indices, updates)

print(update)

   结果如下:

<tf.Variable '' shape=(2, 4) dtype=int32, numpy=
array([[  1,  98,  20, 102],
       [  0,   0,   0,   0]])>

   三维矩阵更新:

import tensorflow as tf
import tensorflow.contrib.eager as tfe
import numpy as np

tf.enable_eager_execution()

ref = tfe.Variable(np.zeros(shape=[4, 4, 3], dtype=np.float32))
indices = tf.constant([0])
updates = tf.constant([np.random.random(size=[4, 3])], dtype=tf.float32)
update = tf.scatter_update(ref, indices, updates)

print(update)

   结果如下:

<tf.Variable '' shape=(4, 4, 3) dtype=float32, numpy=
array([[[0.13318056, 0.83603495, 0.8232899 ],
        [0.02964316, 0.8545541 , 0.27696434],
        [0.4880769 , 0.23017927, 0.64292145],
        [0.07073301, 0.10755321, 0.347981  ]],

       [[0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        ]],

       [[0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        ]],

       [[0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        ]]], dtype=float32)>
  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值