关于tf.Variable()与tf.get_variable()的区别,很多博客都在罗列传入参数的不同,然后推荐使用tf.get_variable(),因为更适合复用。本人才疏学浅,看不出参数不同怎么就会影响复用。通过阅读这篇博客,加以自己的理解,阐述两者的区别和为何tf.get_variable()更适合参数复用。
tf.Variable()
对于tf.Variable()定义的变量,如x = tf.Variable(3, name="x")
。当多次调用x时,系统会自动给变量编号x,x_1,x_2
来解决命名冲突。例如,定义如下的加法:
def add_function():
x = tf.Variable(3, name="x_scalar")
y = tf.Variable(2, name="y_scalar")
addition = tf.add(x, y, name="add_function")
print("=== checking Variables ===")
print("x:", x, "\ny:", y, "\n")
return addition
并分两次调用该函数:
result1 = add_function()
result2 = add_function()
得到的结果是:
=== checking Variables ===
x: <tf.Variable 'x_scalar:0' shape=() dtype=int32_ref>
y: <tf.Variable 'y_scalar:0' shape=() dtype=int32_ref>
=== checking Variables ===
x: <tf.Variable 'x_scalar_1:0' shape=() dtype=int32_ref>
y: <tf.Variable 'y_scalar_1:0' shape=() dtype=int32_ref>
可以看到x第一次变量名是x_scalar:0
,第二次是x_scalar_1:0
。其中:
前是系统赋予的变量名,也就是说,在系统里创建了两个变量x_scalar
和x_scalar_1
,如果调用该函数100次,那么总共会创建100个x变量:x_scalar - x_scalar_99
这显然太占用内存了,他们分明指的都是一个变量,如果能复用就能省很大空间。
tf.Variable()能复用吗?可以就是太麻烦了
办法是先用一个dict存储这些变量,然后定义的函数再去调用,如下:
# 先定义好变量dict
variables_dict = {"x_scalar": tf.Variable(3, name="x_scalar"), "y_scalar": tf.Variable(2, name="y_scalar")}
# 再定义函数
def add_function(x, y):
addition = tf.add(x, y, name="add_function")
print("=== checking Variables ===")
print("x:", x, "\ny:", y, "\n")
return addition
# 通过调用来复用变量
result1 = add_function(variables_dict["x_scalar"], variables_dict["y_scalar"])
result2 = add_function(variables_dict["x_scalar"], variables_dict["y_scalar"])
有结果:
=== checking Variables ===
x: <tf.Variable 'x_scalar_2:0' shape=() dtype=int32_ref>
y: <tf.Variable 'y_scalar_2:0' shape=() dtype=int32_ref>
=== checking Variables ===
x: <tf.Variable 'x_scalar_2:0' shape=() dtype=int32_ref>
y: <tf.Variable 'y_scalar_2:0' shape=() dtype=int32_ref>
可以看到,此时的变量x和y复用了,都为变量x_scalar_2
和y_scalar_2
。然而,这还是太麻烦了,也不优雅,需要先定义dict,再多次调用dict,这种大量的重复一定是不能忍的。tf提供了一种优雅解决的方式,就是tf.get_variable()
tf.get_variable()
首先它与tf.Variable()最大的不同是,tf.Variable()重复调用一个定义的变量时,系统会自动编号解决重复命名的问题,但是对于 tf.get_variable()来说,这种方式是会报错的,如下:
x = tf.get_variable("x_scalar", [])
y = tf.get_variable("x_scalar", [])
当把变量都取为x_scalar
时会出现如下的报错:
ValueError: Variable x_scalar already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope?
所以,当需要共享变量时,需要在定义时写上reuse=True
,并且是在它所在的scope
定义,而非对这个变量定义,如下,当对这个变量定义reuse
,会报错:
x = tf.get_variable("x_scalar", [] , reuse = True)
y = tf.get_variable("x_scalar", [] , reuse = True)
TypeError: get_variable() got an unexpected keyword argument 'reuse'
当对这个scope
定义reuse
,那么就可以得到正确的定义:
x = tf.get_variable("test_scalar", [] )
tf.get_variable_scope().reuse_variables()
y = tf.get_variable("test_scalar", [] )
scope的定义有点像文件夹,来指明一个变量的所属,允许嵌套定义,通常长成这个样子xxx/xxx/var
。例如:
with tf.variable_scope("test", reuse=tf.AUTO_REUSE):
v = tf.get_variable("v_var", [2]) # Variable scope "scope/variable name" -> test/v_var
print("=== variable scope ===") # === variable scope ===
print("v.name:", v.name) # v.name: test/v_var:0
print("v.op.name:", v.op.name) # v.op.name: test/v_var
得到如下结果:
=== variable scope ===
v.name: test/v_var:0
v.op.name: test/v_var
回到最开始的例子,用tf.get_variable()配合tf.variable_scope()轻轻松松复用:
def add_function():
with tf.variable_scope("test" , reuse = tf.AUTO_REUSE):
x = tf.get_variable("x_scalar1", [2])
y = tf.get_variable("y_scalar1", [2])
addition = tf.add(x, y, name="add_function1")
print("=== checking Variables ===")
print("x:", x, "\ny:", y)
return addition
result1 = add_function()
result2 = add_function()
得到结果:
=== checking Variables ===
x: <tf.Variable 'test/x_scalar1:0' shape=(2,) dtype=float32_ref>
y: <tf.Variable 'test/y_scalar1:0' shape=(2,) dtype=float32_ref>
=== checking Variables ===
x: <tf.Variable 'test/x_scalar1:0' shape=(2,) dtype=float32_ref>
y: <tf.Variable 'test/y_scalar1:0' shape=(2,) dtype=float32_ref>