Variable
Both tf.Variable and tf.get_variable produce variables — which one should you use?
tf.Variable(initial_value=None, trainable=True, collections=None, validate_shape=True, caching_device=None, name=None, variable_def=None, dtype=None, expected_shape=None, import_scope=None, constraint=None, use_resource=None, synchronization=VariableSynchronization.AUTO, aggregation=VariableAggregation.NONE)
tf.get_variable(name, shape=None, dtype=None, initializer=None, regularizer=None, trainable=None, collections=None, caching_device=None, partitioner=None, validate_shape=True, use_resource=None, custom_getter=None, constraint=None, synchronization=VariableSynchronization.AUTO, aggregation=VariableAggregation.NONE)
Difference: with tf.Variable, a name collision is handled automatically (the framework uniquifies the name); with tf.get_variable(), a name collision raises an error unless the variable is explicitly shared by setting reuse=True on the enclosing tf.variable_scope. So whenever you need to share variables, use tf.get_variable().
import tensorflow as tf

with tf.variable_scope("scope1"):
    w1 = tf.get_variable("w1", shape=[])
    w2 = tf.Variable(0.0, name="w2")
    w2_1 = tf.Variable(0.0, name="w2")
with tf.variable_scope("scope1", reuse=True):
    w1_p = tf.get_variable("w1", shape=[])
    w2_p = tf.Variable(1.0, name="w2")
print(w2.name)
print(w2_1.name)
print(w1 is w1_p, w2 is w2_p)
# Output:
# scope1/w2:0
# scope1/w2_1:0
# True False
tf.Variable() always creates a new object, so reuse=True has no effect on it. tf.get_variable() returns the existing variable object if one with that name has already been created in the scope; otherwise it creates a new one.
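The lookup-or-create behavior of get_variable() versus the always-create behavior of Variable() can be mimicked with a small pure-Python sketch. This is only an analogy for the semantics, not TensorFlow's actual implementation — the FakeVar class and _store dictionary are made up for illustration:

```python
# Toy analogy for get_variable vs. Variable semantics (not real TF internals).

class FakeVar:
    def __init__(self, name):
        self.name = name

_store = {}  # scope-qualified name -> variable object

def get_variable(scope, name, reuse=False):
    """Return the existing variable under reuse, or create one; error on conflict."""
    full = f"{scope}/{name}"
    if full in _store:
        if not reuse:
            raise ValueError(f"Variable {full} already exists")
        return _store[full]  # reuse: hand back the same object
    if reuse:
        raise ValueError(f"Variable {full} does not exist")
    _store[full] = FakeVar(full)
    return _store[full]

def variable(scope, name):
    """Always create a new object; uniquify the name on collision."""
    full = f"{scope}/{name}"
    suffix = 1
    while full in _store:
        full = f"{scope}/{name}_{suffix}"
        suffix += 1
    _store[full] = FakeVar(full)
    return _store[full]

w1 = get_variable("scope1", "w1")
w1_p = get_variable("scope1", "w1", reuse=True)
w2 = variable("scope1", "w2")
w2_p = variable("scope1", "w2")
print(w1 is w1_p)  # True  - get_variable returned the existing object
print(w2 is w2_p)  # False - variable made a new one named scope1/w2_1
```

Calling get_variable("scope1", "w1") again without reuse=True would raise ValueError, mirroring the name-conflict error described above.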
Reuse can also be switched on partway through a scope with scope.reuse_variables():

with tf.variable_scope('V1') as scope:
    a1 = tf.get_variable(name='a1', shape=[1], initializer=tf.constant_initializer(1))
    scope.reuse_variables()
    a1_p = tf.get_variable(name='a1')  # returns the existing V1/a1
Scope
TF has two kinds of scopes:
name scope, created with tf.name_scope or tf.op_scope;
variable scope, created with tf.variable_scope or tf.variable_op_scope.
For variables created with tf.Variable(), the two scope types behave identically: the scope name is prepended to the variable name. For variables created with tf.get_variable(), only the variable scope name is prepended; a name scope is never used as a prefix.
with tf.variable_scope('V1'):
    a1 = tf.get_variable(name='a1', shape=[1], initializer=tf.constant_initializer(1))
    a2 = tf.Variable(tf.random_normal(shape=[2, 3], mean=0, stddev=1), name='a2')
with tf.variable_scope('V2'):
    a3 = tf.get_variable(name='a1', shape=[1], initializer=tf.constant_initializer(1))
    a4 = tf.Variable(tf.random_normal(shape=[2, 3], mean=0, stddev=1), name='a2')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(a1.name)
    print(a2.name)
    print(a3.name)
    print(a4.name)
# Output:
# V1/a1:0
# V1/a2:0
# V2/a1:0
# V2/a2:0
(Note that reuse=True cannot be set on V2 here: V2/a1 does not exist yet, so get_variable would raise an error.) tf.name_scope() has no reuse parameter, so this kind of sharing cannot be expressed inside it.
with tf.name_scope("my_name_scope"):
    v1 = tf.get_variable("var1", [1], dtype=tf.float32)
    v2 = tf.Variable(1, name="var2", dtype=tf.float32)
    a = tf.add(v1, v2)
print(v1.name)
print(v2.name)
print(a.name)
# Output:
# var1:0
# my_name_scope/var2:0
# my_name_scope/Add:0
Summary:
tf.variable_scope() adds a prefix to the names of all variables (no matter how you create them), ops, and constants. On the other hand, tf.name_scope() ignores variables created with tf.get_variable(), because it assumes you know which variable, in which scope, you want to use.
In practice, the combination of tf.get_variable() + tf.variable_scope() is the more common pattern.
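As a sketch of that common pattern (written against the tf.compat.v1 API so it also runs under TensorFlow 2.x; the scope name "shared" and the [2, 2] shape are arbitrary choices for illustration): a layer-building function wrapped in a variable_scope, so that calling it twice with the same scope name returns the same weights. reuse=tf.AUTO_REUSE creates the variable on the first call and reuses it on later calls, avoiding manual reuse=True bookkeeping:

```python
# Sketch of weight sharing with get_variable + variable_scope.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

def linear(x, scope_name):
    # AUTO_REUSE: create "w" on the first call with this scope name,
    # return the existing variable on every later call.
    with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
        w = tf.get_variable("w", shape=[2, 2],
                            initializer=tf.ones_initializer())
        return tf.matmul(x, w), w

x1 = tf.placeholder(tf.float32, [None, 2])
x2 = tf.placeholder(tf.float32, [None, 2])
y1, w_a = linear(x1, "shared")
y2, w_b = linear(x2, "shared")  # same weights, no name-conflict error

print(w_a.name)    # shared/w:0
print(w_a is w_b)  # True: both calls returned one variable object
```

Without AUTO_REUSE, the second call would need an explicit reuse=True on the scope, exactly as in the examples above.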