1、tensorflow中所有的定义都只是声明,只有在session中run的时候,才会被执行。
谨记:对于模型中所有的参数都必须要使用variable来定义。可以使用tf.truncated_normal()来定义随机初始话,但是必须将随机初始化的值赋给variable。不然,每次需要访问参数的时候,都会驱动tf.truncated_normal()。
正确的写法:
import numpy as np
import tensorflow as tf
sess = tf.Session()
params = tf.Variable(tf.truncated_normal([4, 5]))
indices = tf.constant([2, 0])
output = tf.gather(params, indices)
sess.run(tf.global_variables_initializer())
print (sess.run(params))
print (sess.run(output))
sess.close()
错误的写法:
import numpy as np
import tensorflow as tf
sess = tf.Session()
params = tf.truncated_normal([4, 5])
indices = tf.constant([2, 0])
output = tf.gather(params, indices)
print (sess.run(params))
print (sess.run(output))
sess.close()
说明:param其实也只是生成随机数的操作,这个操作被驱动了2次,一次是sess.run(params),,一次是sess.run(output)。