一文搞懂TensorFlow的变量创建
在TensorFlow中创建变量,主要使用的函数是:tf.Variable()和tf.get_variable()。这两个函数都比较常用,看懂这两个函数后,在调试代码时会更顺利一些。下面简单介绍一下这两个函数。
两个函数的简单介绍
(1)tf.Variable(initial value, [name]) ([] - 表示可选)
功能:生成一个新的初始值为initial value的变量。根据函数形参的设置,必须指定变量初始值。
(2)tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None, regularizer=None, trainable=True, collections=None)
功能:获取已存在的变量(名字,shape等完全一致),如果不存在该变量,则新建一个。根据该函数形参的设置,必须指定变量名,但可以不指定确切的初始值,可以使用各种初始化方法。
两者的重要区别
(1) tf.get_variable()的初始化更方便;
(2) tf.get_variable()共享变量更方便,而tf.Variable()则会新建一个变量。
(3) 一般而言,使用tf.get_variable()需要结合tf.variable_scope()和reuse的设置,而tf.Variable()不需要。
针对上面的三点,给出相应的例子说明。
(1)初始化方式对比
initial = tf.truncated_normal([1], stddev=0.02) # 具体的初始值
initial_2 = tf.contrib.layers.xavier_initializer()
vv1 = tf.get_variable("vv1", initializer=initial)
vv2 = tf.get_variable("vv2", [1], initializer=initial_2)
# xx1 = tf.Variable(initial_2) # 会报错,因为没有shape
xx2 = tf.Variable(initial)
print(vv1)
print(vv2)
print(xx2)
显然,tf.get_variable()的初始化方式更为灵活。
(2)、(3) 变量创建以及共享。tf.Variable()难以共享,tf框架会将设置相同的name的Variable按一定的规则修改不同的name。而tf.get_variable()则会首先检测是否相同的name,没有才会创建新的变量,同时结合variable_scope可实现变量的共享(即在相同的scope中多次使用该变量,值可传递)。
注:若允许变量共享,则需设置reuse=True,否则重复使用会报错。
with tf.variable_scope("scope1"):
v = tf.get_variable("v", [1])
with tf.variable_scope("scope1", reuse=True):
b = tf.get_variable("v")
e = tf.Variable([1.0], name="v")
d = tf.Variable([1.0], name="v") # 每次系统都会新建一个变量,这样测试的时候,有相同的name则难以测试
assert v == b # 没有报错,证明v和b是完全一样的对象
print(v)
print(e)
print(d) # 打印的结果显示d的name并不是'v',而是被修改为'v_1'
当测试时有使用到name的话,而Variable创建的变量的name存在重复,此时系统会自动修改name,而用户不知道具体修改之后的name,则测试时直接使用自己设置的name,就会很容易出错。
实战应用
def weight_variable(shape, stddev=0.02, name=None):
initial = tf.truncated_normal(shape, stddev=stddev)
if name is None:
return tf.Variable(initial)
else:
return tf.get_variable(name, initializer=initial)
结语
综上所述,如果需要快速实现,而且网络结构不大,则可使用tf.Variable();如果使网络结构可读性更强(使用variable_scope),需要共享变量的,则可使用tf.get_variable()。