tf.Variable()与tf.get_variable()的区别

关于tf.Variable()与tf.get_variable()的区别,很多博客都在罗列传入参数的不同,然后推荐使用tf.get_variable(),因为更适合复用。本人才疏学浅,看不出参数不同怎么就会影响复用。通过阅读这篇博客,加以自己的理解,阐述两者的区别和为何tf.get_variable()更适合参数复用。

tf.Variable()

对于tf.Variable()定义的变量,如x = tf.Variable(3, name="x")。当多次调用x时,系统会自动给变量编号x,x_1,x_2来解决命名冲突。例如,定义如下的加法:

def add_function():
    x = tf.Variable(3, name="x_scalar")
    y = tf.Variable(2, name="y_scalar")
    addition = tf.add(x,  y, name="add_function")
    print("=== checking Variables ===")
    print("x:", x, "\ny:", y, "\n")
    return addition

并分两次调用该函数:

result1 = add_function()
result2 = add_function()

得到的结果是:

=== checking Variables ===
x: <tf.Variable 'x_scalar:0' shape=() dtype=int32_ref> 
y: <tf.Variable 'y_scalar:0' shape=() dtype=int32_ref> 

=== checking Variables ===
x: <tf.Variable 'x_scalar_1:0' shape=() dtype=int32_ref> 
y: <tf.Variable 'y_scalar_1:0' shape=() dtype=int32_ref>

可以看到x第一次变量名是x_scalar:0,第二次是x_scalar_1:0。其中:前是系统赋予的变量名,也就是说,在系统里创建了两个变量x_scalarx_scalar_1,如果调用该函数100次,那么总共会创建100个x变量:x_scalar - x_scalar_99这显然太占用内存了,他们分明指的都是一个变量,如果能复用就能省很大空间。

tf.Variable()能复用吗?可以就是太麻烦了
办法是先用一个dict存储这些变量,然后定义的函数再去调用,如下:

# 先定义好变量dict 
variables_dict = {"x_scalar": tf.Variable(3, name="x_scalar"),  "y_scalar": tf.Variable(2, name="y_scalar")}
# 再定义函数                
def add_function(x, y):
    addition = tf.add(x,  y, name="add_function")
    print("=== checking Variables ===")
    print("x:", x, "\ny:", y, "\n")
    return addition
# 通过调用来复用变量
result1 = add_function(variables_dict["x_scalar"], variables_dict["y_scalar"])
result2 = add_function(variables_dict["x_scalar"], variables_dict["y_scalar"]) 

有结果:

=== checking Variables ===
x: <tf.Variable 'x_scalar_2:0' shape=() dtype=int32_ref> 
y: <tf.Variable 'y_scalar_2:0' shape=() dtype=int32_ref> 

=== checking Variables ===
x: <tf.Variable 'x_scalar_2:0' shape=() dtype=int32_ref> 
y: <tf.Variable 'y_scalar_2:0' shape=() dtype=int32_ref>

可以看到,此时的变量x和y复用了,都为变量x_scalar_2y_scalar_2。然而,这还是太麻烦了,也不优雅,需要先定义dict,再多次调用dict,这种大量的重复一定是不能忍的。tf提供了一种优雅解决的方式,就是tf.get_variable()

tf.get_variable()

首先它与tf.Variable()最大的不同是,tf.Variable()重复调用一个定义的变量时,系统会自动编号解决重复命名的问题,但是对于 tf.get_variable()来说,这种方式是会报错的,如下:

x = tf.get_variable("x_scalar", [])
y = tf.get_variable("x_scalar", [])

当把变量都取为x_scalar时会出现如下的报错:

ValueError: Variable x_scalar already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? 

所以,当需要共享变量时,需要在定义时写上reuse=True,并且是在它所在的scope定义,而非对这个变量定义,如下,当对这个变量定义reuse,会报错:

x = tf.get_variable("x_scalar", [] , reuse = True)
y = tf.get_variable("x_scalar", [] , reuse = True)
TypeError: get_variable() got an unexpected keyword argument 'reuse'

当对这个scope定义reuse,那么就可以得到正确的定义:

x = tf.get_variable("test_scalar", [] )
tf.get_variable_scope().reuse_variables()
y = tf.get_variable("test_scalar", [] )

scope的定义有点像文件夹,来指明一个变量的所属,允许嵌套定义,通常长成这个样子xxx/xxx/var。例如:

with tf.variable_scope("test", reuse=tf.AUTO_REUSE):
    v = tf.get_variable("v_var", [2]) # Variable scope "scope/variable name" -> test/v_var
print("=== variable scope ===") # === variable scope ===
print("v.name:", v.name) # v.name: test/v_var:0
print("v.op.name:", v.op.name) # v.op.name: test/v_var

得到如下结果:

=== variable scope ===
v.name: test/v_var:0
v.op.name: test/v_var

回到最开始的例子,用tf.get_variable()配合tf.variable_scope()轻轻松松复用:

def add_function():
    with tf.variable_scope("test" , reuse = tf.AUTO_REUSE):
        x = tf.get_variable("x_scalar1", [2])
        y = tf.get_variable("y_scalar1", [2])
        addition = tf.add(x,  y, name="add_function1")
    print("=== checking Variables ===")
    print("x:", x, "\ny:", y)
    return addition
    
result1 = add_function()
result2 = add_function()

得到结果:

=== checking Variables ===
x: <tf.Variable 'test/x_scalar1:0' shape=(2,) dtype=float32_ref> 
y: <tf.Variable 'test/y_scalar1:0' shape=(2,) dtype=float32_ref>
=== checking Variables ===
x: <tf.Variable 'test/x_scalar1:0' shape=(2,) dtype=float32_ref> 
y: <tf.Variable 'test/y_scalar1:0' shape=(2,) dtype=float32_ref>
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值