Tensorflow–tf.assign()详解

Tensorflow–tf.assign()详解

转自:http://www.soaringroad.com/?p=194

tf.assgin()函数很具有被误解的潜质,如果不是非常透彻理解tensorflow graph 和 op 的概念的话,一不小心就会计算错误……. 先来看下源代码:

 
  1. def assign(ref, value, validate_shape=None, use_locking=None, name=None):

  2. """Update 'ref' by assigning 'value' to it.

  3.  
  4. This operation outputs a Tensor that holds the new value of 'ref' after

  5. the value has been assigned. This makes it easier to chain operations

  6. that need to use the reset value.

  7.  
  8. Args:

  9. ref: A mutable `Tensor`.

  10. Should be from a `Variable` node. May be uninitialized.

  11. value: A `Tensor`. Must have the same type as `ref`.

  12. The value to be assigned to the variable.

  13. validate_shape: An optional `bool`. Defaults to `True`.

  14. If true, the operation will validate that the shape

  15. of 'value' matches the shape of the Tensor being assigned to. If false,

  16. 'ref' will take on the shape of 'value'.

  17. use_locking: An optional `bool`. Defaults to `True`.

  18. If True, the assignment will be protected by a lock;

  19. otherwise the behavior is undefined, but may exhibit less contention.

  20. name: A name for the operation (optional).

  21.  
  22. Returns:

  23. A `Tensor` that will hold the new value of 'ref' after

  24. the assignment has completed.

  25. """

  26. if ref.dtype._is_ref_dtype:

  27. return gen_state_ops.assign(

  28. ref, value, use_locking=use_locking, name=name,

  29. validate_shape=validate_shape)

  30. return ref.assign(value)

A `Tensor` that will hold the new value of ‘ref’ after the assignment has completed. 只有当assign()被执行了才会返回新值 下面两个例子看一下就明白了:

 
  1. # --*== UTF-8 --*--

  2. import tensorflow as tf

  3.  
  4.  
  5. def test_1():

  6. a = tf.Variable([10, 20])

  7. b = tf.assign(a, [20, 30])

  8. c = a + [10, 20]

  9. with tf.Session() as sess:

  10. sess.run(tf.global_variables_initializer())

  11. print("test_1 run a : ",sess.run(a)) # => [10 20]

  12. print("test_1 run c : ",sess.run(c)) # => [10 20]+[10 20] = [20 40] 因为b没有被run所以a还是[10 20]

  13. print("test_1 run b : ",sess.run(b)) # => ref:a = [20 30] 运行b,对a进行assign

  14. print("test_1 run a again : ",sess.run(a)) # => [20 30] 因为b被run过了,所以a为[20 30]

  15. print("test_1 run c again : ",sess.run(c)) # => [20 30] + [10 20] = [30 50] 因为b被run过了,所以a为[20,30], 那么c就是[30 50]

  16.  
  17.  
  18. def test_2():

  19. a = tf.Variable([10, 20])

  20. b = tf.assign(a, [20, 30])

  21. c = b + [10, 20]

  22. with tf.Session() as sess:

  23. sess.run(tf.global_variables_initializer())

  24. print(sess.run(a)) # => [10 20]

  25. print(sess.run(c)) # => [30 50] 运行c的时候,由于c中含有b,所以b也被运行了

  26. print(sess.run(a)) # => [20 30]

  27.  
  28.  
  29. def main():

  30. test_1()

  31. test_2()

  32.  
  33.  
  34. if __name__ == '__main__()':

  35. main()

如果把上面两个test弄明白了,那就真的理解了assign的操作了

总结

assign未被执行,ref值不更新

assign_add 、assign_sub 也是一样的

assign_add(加后分配值给x,如x=x+1/x-=1)

assign_sub(减后分配值给x,x=x-1/x-=1)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值