Tensorflow学习笔记-变量管理
当一个神经网络比较复杂、参数比较多时,就比较需要一个比较好的方式来传递和管理这些参数。而Tensorflow提供了通过变量名称来创建或者获取变量的机制。通过这个机制,可以在不同的函数中直接通过变量的名称来使用变量,而不需要将变量通过参数进行传递。
关于变量管理使用,请参考LeNet5模型的实现:http://blog.csdn.net/lovelyaiq/article/details/78631593。
实现这种机制的函数是tf.variable_scope和tf.get_variable。
Tenssorflow创建变量时,一般都是通过tf.Variable,但这和tf.get_variable是等价的。例如:
v= tf.Variable(tf.constant(1.0,shape=[1]),name='v')
v = tf.get_variable('v',shape=[1],initializer=tf.constant_initializer(1.0))
tf.get_variable的函数定义如下:
def get_variable(name,#这个是必须要填写的
shape=None,
dtype=None,
initializer=None,
regularizer=None,
trainable=True,
collections=None,
caching_device=None,
partitioner=None,
validate_shape=True,
use_resource=None,
custom_getter=None):
tf.get_variable的第一个参数是必须要填写的,那就是变量的名称。 tf.get_variable的作用不仅在于创建变量,它还可以通过变量名称获取变量,但在创建变量时,如果创建失败(变量已经存在),那么就会报错。因此,就需要一个上下文管理器tf.variable_scope,可以更好的管理变量,并且可以增加程序的可读性。
参数reuse=True时,通过tf.get_variable只能获取已经存在的变量名。
with tf.variable_scope('foo'):
w = tf.get_variable('w', shape=[1], initializer=tf.constant_initializer(1.0))
# reuse=True,如果获取的变量不存在,则会报错。
# reuse=None,如果获取的变量存在,则会报错。
# ValueError: Variable foo/w already exists, disallowed
with tf.variable_scope('foo',reuse=None):
w = tf.get_variable('w', shape=[1])
w1 = tf.get_variable('w1', shape=[1],initializer=tf.constant_initializer(1.0)
上下文管理器也可以进行嵌套。
with tf.variable_scope('root'):
# False
print(tf.get_variable_scope().reuse)
with tf.variable_scope('foo',reuse=True):
# True
print(tf.get_variable_scope().reuse)
with tf.variable_scope('bar'):
# True
print(tf.get_variable_scope().reuse)
# False
print(tf.get_variable_scope().reuse)
上下文管理器在嵌套的时候,变量名称的前面会加上上下文管理器的名称。
v1 = tf.Variable('v',[1])
# Variable:0
print(v1.name)
with tf.variable_scope('foo'):
v2 = tf.get_variable('v',[1])
# foo/v:0
print(v2.name)
with tf.variable_scope('foo'):
with tf.variable_scope('bar'):
v3 = tf.get_variable('v',[1])
# foo/bar/v:0
print(v3.name)
v4 = tf.get_variable('v1',[1])
# foo/v1:0
print(v4.name)
with tf.variable_scope('',reuse=True):
v5 = tf.get_variable('foo/bar/v',[1])
# True
print(v5 == v3)
v6 = tf.get_variable('foo/v1',[1])
# True
print(v6 == v4)
TensorFlow中有两个变量管理器,另外一个是name_scope,它与variable_scope的区别,请参考:http://blog.csdn.net/u012436149/article/details/53081454