Tensorflow:tf.assign()函数的使用方法及易错点

引言:
当大家在使用tf.assign()这个函数时,如果不是很了解这个函数的用法,很容易出错,而且似乎对应不同的tf版本其操作结果也会有细微的差别,本文是基于1.9.0版本的tf进行描述的,对于更新的版本而言应该结论是一样的,但对于比较旧的版本,可能就会有细微差别。


首先我们看一下源码中的返回值说明:

update = tf.assign(ref, new_value)    # 平时的使用写法
--------------------------------------------------------------------
Returns:
    A `Tensor` that will hold the new value of 'ref' after
      the assignment has completed.

也就是说,只有当这个赋值被完成时,该旧值ref才会被修改成new_value。不过这样描述还是太抽象了,那到底什么叫赋值被完成呢?下面我给大家放两个简单的例子,帮助大家理解

import tensorflow as tf 

ref_a = tf.Variable(tf.constant(1))
ref_b = tf.Variable(tf.constant(2))
update = tf.assign(ref_a, 10)
ref_sum = tf.add(ref_a, ref_b)

with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
	print(sess.run(ref_sum))
------------------------------------------------------
输出结果:3

然后你就会感到奇怪,这里与往常的直觉不一样,理论上ref_a应该已经被修改为10了?带着疑问,我们看第二个例子

import tensorflow as tf 

ref_a = tf.Variable(tf.constant(1))
ref_b = tf.Variable(tf.constant(2))
update = tf.assign(ref_a, 10)
ref_sum = tf.add(ref_a, ref_b)

with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
	sess.run(update)  # 唯一修改的地方
	print(sess.run(ref_sum))
------------------------------------------------------
输出结果:12

看到这里,是不是大家就明白了。所谓的赋值被完成其实指得是需要对tf.assign()函数的返回值执行一下sess.run()操作后,才能保证正常更新。

在明白了这个易错的地方后,我再介绍两种方法,来达到同样的目的。


方法一:采用ref_a = tf.assign(ref_a, 10)操作,我们看一下代码和运行结果

import tensorflow as tf 

ref_a = tf.Variable(tf.constant(1))
ref_b = tf.Variable(tf.constant(2))
ref_a = tf.assign(ref_a, 10)
ref_sum = tf.add(ref_a, ref_b)

with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
	print(sess.run(ref_sum))
------------------------------------------------------
输出结果:12

事实上,tf.assign(ref, new_value)函数返回的结果就是参数中的new_value,因此我们只需要用ref来接收返回值也可以达到直接更新的效果

方法二:使用tf.control_dependencies()函数,我们也同样来看一下代码和结果

import tensorflow as tf 

ref_a = tf.Variable(tf.constant(1))
ref_b = tf.Variable(tf.constant(2))
update = tf.assign(ref_a, 10)

with tf.control_dependencies([update]):
	ref_sum = tf.add(ref_a, ref_b)

with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
	print(sess.run(ref_sum))
------------------------------------------------------
输出结果:12

可以发现,结果也为我们预期想要达到的效果,该函数保证其辖域中的操作必须要在该函数所传递的参数中的操作完成后再进行。简单地说,就是实际在运行时,会先执行该函数传递的参数update,再执行其辖域中的操作ref_sum = tf.add(ref_a, ref_b)


如果觉得我有地方讲的不好的或者有错误的欢迎给我留言,谢谢大家阅读(点个赞我可是会很开心的哦)~

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值