Saving and Restoring Variables
tf.train.Saver
tf.train.Saver(var_lise=None)
tf.get_variable()/tf.variable_scope
tf.get_variable(name, shape=None, initializer=None)
tf.variable_scope(name_or_scope, reuse=None)
a = tf.get_variable('a', shape=[2,5], initializer=tf.truncated_normal_initializer())
a = tf.Variable(tf.truncated_normal(shape=[2,5]), name='a')
以上两个定义等价。
tf.get_variable()与tf.Variable()最大区别在于参数name,前者是必填参数,后者是可选参数。
tf.variable_scope()可以控制tf.get_variable()的功能,当reuse=True生成上下文管理器时,tf.get_variable()根据name的值直接获取已创建的变量,如果变量不存在(没有name那个值的变量)则会报错。当reuse=None或reuse=False创建上下文管理器,tf.get_variable创建新的变量,如果有同名变量(name一样),则会报错。
with tf.variable_scope('layer1'):
w = tf.get_variable('w', shape=[2,5], initializer=tf.truncated_normal_initializer())
# w的name为'layer1/w',而下面创建的w1的name也为
# 'layer1/w',所以会报错。
with tf.variable_scope('layer1'):
w1 = tf.get_variable('w', shape=[2,5])
with tf.variable_scope('layer1', reuse=True):
w2 = tf.get_variable('w', [2,5])
print(w2 == w) # => True
# w2和w指向的是同一个东西,即'layer1/w'.
with tf.variable_scope('layer1', reuse=True):
w3 = tf.get_variable('b', [2,5])
# 由于命名空间中没有'layer1/b',且reuse=True,tf.get_variable是获取非创建新变量,故报错。
with tf.variable_scope('layer1', reuse=True):
w4 = tf.get_variable('w', [2])
# 虽然命名空间中有'layer1/w',但是shape不一样,故报错。
with tf.variable_scope('layer1', reuse=True):
w5 = tf.get_variable('w', [2, 5], initializer=tf.constant_intializer())
# 虽然此时创建的initializer不一样,但不会报错,并且此时的initializer没有意义,依然是之前w的initializer.
tf.variable_scope()嵌套使用时,reuse参数的取值见如下代码,
with tf.variable_scope('layer1', reuse=True):
print(tf.get_variable_scope().reuse) # => True
with tf.variable_scope('layer11'):
print(tf.get_variable_scope().reuse)
# 不指定reuse,内层的reuse取值会应用外一层的,此时输出为True。若外层reuse=FALSE,此时输出False。
tf.variable_scope()生成的上下文管理器去创建Tensorflow的命名空间还有一个用法,见下代码。
with tf.variable_scope("", reuse=True):
w9 = tf.get_variable('layer1/w2', [2,5])
# 创建一个空的命名空间,w9可以用来获得其他命名空间中的变量'layer1/w2'.
tf.Variable()
tf.Variable(initial_value=None, trainable=True, name=None)
a = tf.Variable(...,trainable=False,...)
参数trainable=False表示不将变量a加入到tf.GrapsKeys.TRAINABLE_VARIABLES集合中,即变量a不会出现在tf.trainable_variables()函数返回的变量列表中。