variable共享
tf的Variable是什么就不说了,我们只说下Variable的复用,如下例所示,我们要为图像建立两层卷积网络:
def conv_relu(input, kernel_shape, bias_shape):
# Create variable named "weights".
weights = tf.Variable(tf.random_normal([5, 5, 32, 32]),
name="conv1_weights")
# Create variable named "biases".
biases = tf.Variable(tf.zeros([32]), name="conv1_biases")
conv = tf.nn.conv2d(input_images, conv1_weights,
strides=[1, 1, 1, 1], padding='SAME')
return tf.nn.relu(conv + biases)
def my_image_filter(input_images):
relu1 = conv_relu(input_images, [5, 5, 32, 32], [32])
relu2 = conv_relu(relu1, [5, 5, 32, 32], [32])
如果有两个图像需要处理,那么我们调用:
# 创建一套参数
result1 = my_image_filter(image1)
# 在此创建一套参数
result2 = my_image_filter(image2)
我们知道同样作用同样维度的参数我们分别在result1和result2中创建了一次,这样对内存的占用就多了两倍!如何在处理多个图像时仍然使用一套参数呢(也就是variable共享)??一种选择是建立一套参数的词典,使用唯一的key来确定参数,但这样比较麻烦,并且也不符合封装的思想。有没有一种方法,使得我们可以在代码内部就可以分配给Variable一个类似于key功能的元素? 这时候我们就需要引入get_variable的name属性和variable_scope。
get_variable
我们知道,tf.Variable的name属性是和其对象本身一一对应的。也就是说,如果确定Variable的名字,也就可以确定Variable对象本身了。可能有人会问,那使用tf.Varible(name='xxx')创建Variable对象时,不是给Variable确定的名字了吗?调用相同的创建语句时,岂不是使用相同的名字又创建了一个不同的Variable?看了下面的代码就明白了:
conv1_weights = tf.Variable(tf.random_normal([5, 5, 32, 32]),
name="conv_weights")
print (conv1_weights.name)
conv1_weights = tf.Variable(tf.random_normal([5, 5, 32, 32]),
name="conv_weights")
print (conv1_weights.name)
# 输出:
# conv_weights:0
# conv_weights_1:0
看到没,同样的创建语句,name参数设置也一样,实际的名字属性却不同!也就是说,tf.Variable()创建方法在会自动在相同name属性下面加序号以方便区别!
如果要使用自己设立一个固定的name属性当如何?
v = tf.get_variable("v", [1])
print (v.name)
v = tf.get_variable("v", [1])
print (v.name)
# 输出:v:0
# ValueError
可以看出,tf.get_variable()可以实现这样的功能。
variable_scope
我们知道参数的数量常常是巨大的,给参数命名很耗时(给参数起很长的名字并且打出来也很耗时),借助variable_scope我们可以在当前scope下面给所有的variable加上前缀“scope”,这无疑带来了。我们可以看下格式:
with tf.variable_scope("foo"):
v = tf.get_variable("w", [1])
print (v.name)
# 输出:foo/w:0
reuse参数,tf.get_variable_scope().reuse_variables()以及
回到我们最初的问题--variable共享,怎么在多次调用代码段的同时保持使用同一套Variable呢??我们有两种方式,一种是设置variable_scope的参数reuse,
1. reuse==Ture,该scope下强制使用曾经创建过的Variable。一旦发现get_variable的name不是已经创建的或者shape和已经创建的过的variable有不同,则报错。
2. reuse==tf.AUTO_REUSE,如果name已经存在,则复用。如果不存在,则创建。
3. reuse==tf.False,默认为False,该情况下不复用,换句话说,如果name重复了会报错。
另外 tf.get_variable_scope().reuse_variables()和reuse = True的功能是类似的。
参考文献: