共享变量
1. tf.name_scope
和tf.variable_scope
的区别
- 对于使用tf.Variable来说,
tf.name_scope
和tf.variable_scope
功能一样,都是给变量加前缀,相当于分类管理,模块化。 - 对于tf.get_variable来说,tf.name_scope对其无效,也就是说tf认为当你使用tf.get_variable时,你只归属于tf.variable_scope来管理共享与否。
with tf.name_scope('name_sp1') as sp1:
with tf.variable_scope('var_sp2') as sp2:
with tf.name_scope('name_sp3') as sp3:
a = tf.Variable('a')
b = tf.get_variable(name='b', shape=[2], dtype=tf.float32)
print(a.name) # name_sp1/var_sp2/name_sp3/a:0
print(b.name) # var_sp2/b:0
2. tf.Variable
和tf.get_variable
的区别
import tensorflow as tf
with tf.name_scope("s1"):
initializer = tf.constant_initializer(value=1)
var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32, initializer=initializer)
var2 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32)
var21 = tf.Variable(name='var2', initial_value=[2.1], dtype=tf.float32)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
print(var1.name) # var1:0
print(var2.name) # s1/var2:0
print(var21.name) # s1/var2_1:0
使用 tf.Variable()
定义的时候, 虽然 name
都一样, 但是为了不重复变量名, Tensorflow 输出的变量名并不是一样的. 所以, 本质上 var2
, var21
并不是一样的变量. 而另一方面, 使用tf.get_variable()
定义的变量不会被tf.name_scope()
当中的名字所影响.
3. 实现共享变量
实现共享变量,我们就要使用tf.variable_scope()
和tf.get_variable()
import tensorflow as tf
with tf.variable_scope('var_scp2') as scp2:
a = tf.get_variable(name='a', shape=[2], dtype=tf.float32)
scp2.reuse_variables() # 必须使用,否则会报错
b = tf.get_variable(name='a')
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print(a.name) # var_scp2/a:0
print(b.name) # var_scp2/a:0
可以看出,本质上,变量a
,b
是同一个变量,达到参数共享的目的。