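Today I'm writing down how to use tf.Variable and tf.get_variable. I fell deep into this pit while implementing a GAN. The results were always poor and there was no error message. It turned out I had not handled weight sharing correctly. When I finally got it working I was thrilled, and I hope this note helps you avoid the same trap.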
1. Using tf.Variable
Usage:
weights = tf.Variable(tf.constant(0.1, shape = shape), name = "weights")
2. Using tf.get_variable
Usage:
weights = tf.get_variable("weights", shape,
                          initializer=tf.truncated_normal_initializer(stddev=0.1))
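Both snippets above assume a variable `shape` is already defined. The following self-contained sketch creates one variable each way and reads back the initial values. It assumes TensorFlow 2.x with the TF1 API reached through tf.compat.v1, and it uses an explicit tf.Graph so the code runs in graph mode. The shape [2, 3] and the name weights_get are illustrative choices. The second name only avoids a clash with the first.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style API as shipped inside TensorFlow 2.x
shape = [2, 3]      # illustrative; the snippets above leave `shape` undefined

graph = tf.Graph()
with graph.as_default():
    # Style 1: tf.Variable is given the initial value itself
    w_var = tf.Variable(tf.constant(0.1, shape=shape), name="weights")
    # Style 2: tf.get_variable is given a name, a shape and an initializer
    w_get = tf1.get_variable(
        "weights_get", shape,
        initializer=tf1.truncated_normal_initializer(stddev=0.1))

    with tf1.Session(graph=graph) as sess:
        sess.run(tf1.global_variables_initializer())
        val_var, val_get = sess.run([w_var, w_get])

print(w_var.name, val_var.shape)
print(w_get.name, val_get.shape)
```

The practical difference is that tf.get_variable only needs a name, a shape and an initializer. That is what later allows it to look up an existing variable by name instead of building a new one.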
3. Differences between the two
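3.1 tf.Variable creates a new variable on every call
When the code that calls tf.Variable is run repeatedly, it automatically creates new variables under new names: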
# variabletest.py
import tensorflow as tf

def test():
    # Create a variable in the D_layer1 scope (reuse defaults to False; tf.Variable ignores it anyway)
    with tf.variable_scope('D_layer1'):
        weights1 = tf.Variable(tf.constant(0.1, shape=[5]), name="weights")
        name1 = weights1.name
    # Create a variable in the D_layer2 scope (reuse defaults to False)
    with tf.variable_scope('D_layer2'):
        weights2 = tf.Variable(tf.constant(0.1, shape=[5]), name="weights")
        name2 = weights2.name
    return name1, name2
tf.variable_scope('D_layer1') opens a scope named D_layer1, and every variable created inside it is named under that scope. Calling the function above twice gives:
# main script
import variabletest
import tensorflow as tf

name11, name12 = variabletest.test()
name21, name22 = variabletest.test()
print(name11)
print(name12)
print(name21)
print(name22)
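Output:
D_layer1/weights:0
D_layer2/weights:0
D_layer1_1/weights:0
D_layer2_1/weights:0
On the second call the scopes are entered again under the uniquified names D_layer1_1 and D_layer2_1, so tf.Variable produces two brand-new variables instead of reusing the first pair.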
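You can confirm that the two calls created four separate variables, not two shared ones, by listing the graph's global variables. The sketch below assumes TensorFlow 2.x with the TF1 API via tf.compat.v1 and an explicit tf.Graph, so it runs in graph mode. build_with_variable is a hypothetical copy of test() above.

```python
import tensorflow as tf

tf1 = tf.compat.v1


def build_with_variable():
    """Hypothetical copy of test(): plain tf.Variable inside variable scopes."""
    with tf1.variable_scope("D_layer1"):
        w1 = tf.Variable(tf.constant(0.1, shape=[5]), name="weights")
    with tf1.variable_scope("D_layer2"):
        w2 = tf.Variable(tf.constant(0.1, shape=[5]), name="weights")
    return w1, w2


with tf.Graph().as_default():
    first = build_with_variable()
    second = build_with_variable()
    # Every tf.Variable built in graph mode is registered as a global variable
    all_vars = tf1.global_variables()

print(len(all_vars))
print([v.name for v in all_vars])
```

Four entries in the collection means four sets of weights. A discriminator written this way would train a separate copy for each call.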
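3.2 Sharing variables
tf.get_variable, combined with the reuse flag of tf.variable_scope, behaves differently. Instead of creating a new name on every call, it can return the variable that already exists: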
# variabletest.py
import tensorflow as tf

def test(reuse):
    # Get "weights" in the D_layer1 scope: reuse=False creates it, reuse=True looks it up
    with tf.variable_scope('D_layer1', reuse=reuse):
        weights1 = tf.get_variable("weights", [5], initializer=tf.truncated_normal_initializer(stddev=0.1))
        name1 = weights1.name
    # Same for the D_layer2 scope
    with tf.variable_scope('D_layer2', reuse=reuse):
        weights2 = tf.get_variable("weights", [5], initializer=tf.truncated_normal_initializer(stddev=0.1))
        name2 = weights2.name
    return name1, name2
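Networks such as GANs need shared weights, which means the same forward-pass function gets called several times. tf.Variable cannot achieve this unless the variables are created in the main program and passed into the function, and that breaks encapsulation. The better option is tf.get_variable together with tf.variable_scope. The results: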
# main script
import variabletest
import tensorflow as tf

name11, name12 = variabletest.test(False)  # first call: create the variables
name21, name22 = variabletest.test(True)   # second call: reuse them
print(name11)
print(name12)
print(name21)
print(name22)
Output:
D_layer1/weights:0
D_layer2/weights:0
D_layer1/weights:0
D_layer2/weights:0
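In a GAN the discriminator is usually applied twice per step: once to real samples and once to the generator's output. Both applications must use the same weights. The sketch below shows that pattern with the reuse flag. It assumes TensorFlow 2.x with the TF1 API via tf.compat.v1 inside an explicit graph. The function discriminator, the 4-feature input, and the placeholder standing in for generator output are all hypothetical.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1


def discriminator(x, reuse):
    """Hypothetical one-layer discriminator; sizes are illustrative only."""
    with tf1.variable_scope("discriminator", reuse=reuse):
        w = tf1.get_variable(
            "weights", [4, 1],
            initializer=tf1.truncated_normal_initializer(stddev=0.1))
        b = tf1.get_variable(
            "biases", [1], initializer=tf1.constant_initializer(0.0))
        return tf.sigmoid(tf.matmul(x, w) + b)


graph = tf.Graph()
with graph.as_default():
    real = tf1.placeholder(tf.float32, [None, 4], name="real")
    fake = tf1.placeholder(tf.float32, [None, 4], name="fake")  # stands in for G(z)
    d_real = discriminator(real, reuse=False)  # first call creates the variables
    d_fake = discriminator(fake, reuse=True)   # second call reuses the same ones
    d_vars = tf1.trainable_variables(scope="discriminator")

    with tf1.Session(graph=graph) as sess:
        sess.run(tf1.global_variables_initializer())
        batch = np.ones((2, 4), dtype=np.float32)
        out_real, out_fake = sess.run(
            [d_real, d_fake], feed_dict={real: batch, fake: batch})

print([v.name for v in d_vars])
```

Both calls resolve to the same two variables, so identical batches give identical outputs. The list d_vars is what you would typically pass as var_list to the discriminator's optimizer.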
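When reuse in tf.variable_scope is False, tf.get_variable creates a new variable and raises a ValueError if one with that name already exists. When reuse is True, it looks up the existing variable and returns it, raising a ValueError if no such variable exists. This is how test(reuse) above shares its variables between calls.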
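The reuse rules above, including the two error cases, can be checked directly. This sketch uses the same assumptions as before (TensorFlow 2.x, tf.compat.v1, explicit graph). try_get is a hypothetical helper. tf.compat.v1.AUTO_REUSE is the built-in option that creates the variable if it is missing and reuses it otherwise.

```python
import tensorflow as tf

tf1 = tf.compat.v1


def try_get(reuse):
    """Hypothetical helper: return the variable name, or the error class name."""
    try:
        with tf1.variable_scope("D_layer1", reuse=reuse):
            w = tf1.get_variable(
                "weights", [5],
                initializer=tf1.truncated_normal_initializer(stddev=0.1))
            return w.name
    except ValueError as err:
        return type(err).__name__


results = {}
with tf.Graph().as_default():
    results["reuse=True, not created yet"] = try_get(True)    # nothing to reuse
    results["reuse=False, first time"] = try_get(False)       # creates D_layer1/weights
    results["reuse=False, second time"] = try_get(False)      # name already taken
    results["reuse=True, after creation"] = try_get(True)     # returns the existing one
    results["AUTO_REUSE"] = try_get(tf1.AUTO_REUSE)           # exists, so reused

for case, outcome in results.items():
    print(case, "->", outcome)
```

AUTO_REUSE spares you from threading a reuse flag through the forward-pass function. The cost is that accidental name clashes are no longer reported as errors.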