TensorFlow – tf.assign() Explained

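The tf.assign() function is very easy to misunderstand. Unless you have a thorough grasp of how TensorFlow graphs and ops work, it is easy to end up computing the wrong values.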

First, let's look at the docstring in the source code:

def assign(ref, value, validate_shape=None, use_locking=None, name=None):
  """Update 'ref' by assigning 'value' to it.
 
  This operation outputs a Tensor that holds the new value of 'ref' after
    the value has been assigned. This makes it easier to chain operations
    that need to use the reset value.
 
  Args:
    ref: A mutable `Tensor`.
      Should be from a `Variable` node. May be uninitialized.
    value: A `Tensor`. Must have the same type as `ref`.
      The value to be assigned to the variable.
    validate_shape: An optional `bool`. Defaults to `True`.
      If true, the operation will validate that the shape
      of 'value' matches the shape of the Tensor being assigned to.  If false,
      'ref' will take on the shape of 'value'.
    use_locking: An optional `bool`. Defaults to `True`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).
 
  Returns:
    A `Tensor` that will hold the new value of 'ref' after
      the assignment has completed.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign(
        ref, value, use_locking=use_locking, name=name,
        validate_shape=validate_shape)
  return ref.assign(value)

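The key sentence is the description of the return value: "A `Tensor` that will hold the new value of 'ref' after the assignment has completed." Calling tf.assign() only adds an assign op to the graph. The variable is updated only when that op, or an op that depends on it, is actually run in a session.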
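To make this behavior concrete without needing a TensorFlow install, here is a minimal pure-Python sketch. The classes `Variable` and `Assign` and the function `run` are hypothetical stand-ins, not TensorFlow APIs; they only mimic the idea that creating an assign op changes nothing, while running it updates the variable and returns the new value.

```python
class Variable:
    """Toy variable (hypothetical, not tf.Variable) holding a mutable list."""

    def __init__(self, value):
        self.value = list(value)

    def evaluate(self):
        # Reading a variable returns its current value.
        return list(self.value)


class Assign:
    """Toy assign op: building it records the new value, running it applies it."""

    def __init__(self, ref, value):
        self.ref = ref
        self.new_value = list(value)

    def evaluate(self):
        # Running the op mutates ref and returns ref's new value.
        self.ref.value = list(self.new_value)
        return list(self.ref.value)


def run(node):
    """Hypothetical stand-in for sess.run(): evaluate a single node."""
    return node.evaluate()


a = Variable([10, 20])
b = Assign(a, [20, 30])  # only builds the op; a is unchanged
before = run(a)
result = run(b)
after = run(a)
print(before, result, after)  # [10, 20] [20, 30] [20, 30]
```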

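# -*- coding: utf-8 -*-
# Written for the TensorFlow 1.x graph API (tf.Session, tf.assign).
# On TensorFlow 2.x, replace the import below with:
#   import tensorflow.compat.v1 as tf
#   tf.disable_v2_behavior()
import tensorflow as tf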
 
 
def test_1():
    a = tf.Variable([10, 20])
    b = tf.assign(a, [20, 30])
    c = a + [10, 20]
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print("test_1 run a : ", sess.run(a))  # => [10 20]
        print("test_1 run c : ", sess.run(c))  # => [10 20] + [10 20] = [20 40]; b has not been run, so a is still [10 20]
        print("test_1 run b : ", sess.run(b))  # => [20 30]; running b performs the assignment to ref a
        print("test_1 run a again : ", sess.run(a))  # => [20 30]; b has been run, so a is now [20 30]
        print("test_1 run c again : ", sess.run(c))  # => [20 30] + [10 20] = [30 50]; a is now [20 30], so c is [30 50]
 
 
def test_2():
    a = tf.Variable([10, 20])
    b = tf.assign(a, [20, 30])
    c = b + [10, 20]
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(a))  # => [10 20]
        print(sess.run(c))  # => [30 50]; c depends on b, so running c also runs b
        print(sess.run(a))  # => [20 30]
 
 
def main():
    test_1()
    test_2()
 
 
if __name__ == '__main__':
    main()

Only once you understand test_1 and test_2 above can you say you truly understand the tf.assign() op.
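The difference between the two tests is which ops the fetched tensor depends on. The toy model below (hypothetical classes `Variable`, `Assign`, `Add` and helpers `evaluate`/`run`, not TensorFlow's implementation) walks a node's dependencies on each `run()` call, computing every node at most once per call, and reproduces the values printed by test_1 and test_2.

```python
class Variable:
    """Toy variable: reading it returns its current value."""

    def __init__(self, value):
        self.value = list(value)

    def compute(self, cache):
        return list(self.value)


class Assign:
    """Toy assign op: overwrites ref and returns the new value."""

    def __init__(self, ref, value):
        self.ref = ref
        self.new_value = list(value)

    def compute(self, cache):
        self.ref.value = list(self.new_value)
        return list(self.ref.value)


class Add:
    """Toy element-wise add of a node and a constant list."""

    def __init__(self, x, const):
        self.x = x
        self.const = list(const)

    def compute(self, cache):
        xs = evaluate(self.x, cache)  # evaluates x (and its dependencies) first
        return [p + q for p, q in zip(xs, self.const)]


def evaluate(node, cache):
    # Each node is computed at most once per run() call.
    if node not in cache:
        cache[node] = node.compute(cache)
    return cache[node]


def run(node):
    """Hypothetical stand-in for sess.run(node): start a fresh evaluation."""
    return evaluate(node, {})


def toy_test_1():
    a = Variable([10, 20])
    b = Assign(a, [20, 30])
    c = Add(a, [10, 20])  # c reads a directly; b is not a dependency
    return [run(a), run(c), run(b), run(a), run(c)]


def toy_test_2():
    a = Variable([10, 20])
    b = Assign(a, [20, 30])
    c = Add(b, [10, 20])  # c depends on b, so running c also runs b
    return [run(a), run(c), run(a)]


print(toy_test_1())  # [[10, 20], [20, 40], [20, 30], [20, 30], [30, 50]]
print(toy_test_2())  # [[10, 20], [30, 50], [20, 30]]
```

One thing this toy glosses over: in real TensorFlow, if a single `sess.run()` call fetches both a variable read and an assign that the read does not depend on (for example `sess.run([a, b])` in test_1), the order of the two ops is not guaranteed; `tf.control_dependencies` can be used when the order matters.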

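Summary
  • If the assign op is never run, the value of ref is not updated.
  • The same applies to assign_add and assign_sub.
  • assign_add: adds to x and assigns the result back to x, i.e. x = x + 1 (x += 1).
  • assign_sub: subtracts from x and assigns the result back to x, i.e. x = x - 1 (x -= 1).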
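The same deferred behavior can be sketched for assign_add and assign_sub. In this toy model (hypothetical `Variable` and `DeferredUpdate` classes, not TensorFlow APIs), building the update does nothing, and every run applies the update again, so running an assign_add-like op twice adds twice.

```python
import operator


class Variable:
    """Toy stand-in for an initialized variable holding a number."""

    def __init__(self, value):
        self.value = value


class DeferredUpdate:
    """Toy op: x = combine(x, delta), applied only when run() is called."""

    def __init__(self, ref, delta, combine):
        self.ref = ref
        self.delta = delta
        self.combine = combine

    def run(self):
        # Each call reads the current value, updates it, and returns the result.
        self.ref.value = self.combine(self.ref.value, self.delta)
        return self.ref.value


x = Variable(0)
inc = DeferredUpdate(x, 1, operator.add)  # plays the role of assign_add(x, 1)
dec = DeferredUpdate(x, 1, operator.sub)  # plays the role of assign_sub(x, 1)

built_only = x.value  # 0: creating the ops did not touch x
first = inc.run()     # 1
second = inc.run()    # 2: running the same op again adds again
third = dec.run()     # 1
print(built_only, first, second, third)  # 0 1 2 1
```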