一、变量初始化
1.1 tf.global_variables_initializer()
如果计算图中有tf.Variable()或者tf.get_variable()
,则需要通过tf.global_variables_initializer()对变量进行初始化。 如果只有tf.Constant
则无需variables_initializer()。
注意:tf.global_variables_initializer()是在模型加载后调用的.
with tf.Graph().as_default():
[...模型定义...]
“预训练模型加载.”
tvars = tf.trainable_variables()
assignment_map, initialized_variable_names = get_assignment_map_from_checkpoint(tvars, FLAGS.checkpoint_path)
tf.train.init_from_checkpoint(FLAGS.checkpoint_path, assignment_map)
# init是在预训练模型加载后.
init = tf.global_variables_initializer()
1.2 tf.get_variable()
网络中声明可训练变量,可以这样声明,并初始化。
X = tf.get_variable(shape=[1, 8, 4], dtype=tf.float32, name="input",
initializer=tf.contrib.layers.xavier_initializer())
1.3 tf.Variable()
这个好像没法initializer.
二、查看变量
1. tf.trainable_variables()
查看所有可训练的变量.
2. tf.all_variables()
查看所有变量.
3. tf.global_variables()
查看全局变量.