其实tensorflow里面如何改变tensor的值一直是很令人苦恼的事儿,不过我最近发现 tf.assign 函数可以对 tensor 进行整体的赋值。
调用方式
import tensorflow as tf
w = tf.Variable(initial_value=[[1,1], [1,1]], dtype = tf.float32)
update = tf.assign(w, [[1,2],[1,2]])
x = w * 2
注意点 – 和控制流搭配使用
先看一个实例,如果直接输出 tensor w 的值:
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(x))
然而输出的 w 的值是:
这就证明了似乎没有执行 update = tf.assign(w, [[1,2],[1,2]]) 这一步。如何确保一定执行了这一步呢?
显式确保调用 update:控制流
with tf.control_dependencies([update]):
x = w * 2
这样就能够保证在执行 x = w * 2 之前一定执行了 update 那一步。这样再执行就会得到正常的结果了: