Tensorflow 学习笔记之 共享变量(Sharing Variables)
最近两年,谷歌撑腰的深度学习框架Tensorflow发展地如日中天,虽然17年pytorch的出现略微“打压”了一些TF的势头,但TF在深度学习界的地位还是难以撼动的,github上TF的收藏量一直稳在深度学习中前二的位置。个人在4月份开始接触TF,写分类、超分辨网络不亦乐乎。然而,最近从越来越多的TF github项目中看到了人们都在使用一个叫“共享变量”的机制管理变量,已经基本学会简单TF语法的我,今天决定好好研究一下这个功能。
变量管理的问题
设想你要写一个分类网络,结构是“卷积->ReLU->Pooling->卷积->ReLU->Pooling->展平->全连接->ReLU->全连接->Softmax”。由于网络实在太简单了,写起来完全不需要过多的思考。可能你是这么写的(例子出自TF官网:http://tensorflow.org/tutorials/mnist/pros/index.html):
def weight_variable(shape):
return tf.Variable(tf.truncated_normal(shape, stddev=0.1))
def bias_variable(shape):
return tf.Variable(tf.constant(0.1, shape=shape))
W_conv1 = weight_variable([5, 5, 3, 32])
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(tf.nn.conv2d(...))
h_pool1 = tf.nn.max_pool(h_conv1,...)
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(tf.nn.conv2d(...))
h_pool2 = tf.nn.max_pool(h_conv2,...)
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_flat = tf.reshape(...)
h_fc1 = tf.nn.relu(tf.matmul(...))
从中应该可以看出来,如果需要添加卷积层或全连接层,需要额外定义相应的权重w和偏置b。因此就有了[W_conv1,b_conv1,W_conv2,b_conv2,…]这一串变量信息。
那么问题来了,如果让你写一个19层的VGG网络,甚至是上百层的Resnet呢?这种定义方法显然是行不通的,等手动把[W_conv1,b_conv1,W_conv2,b_conv2,…]这些东西输完,估计也对TF丧失兴趣了。你可能会想到这种循环的方法:
def layer(shape, ...):
w = tf.Variable(tf.truncated_normal(shape, stddev=0.1))
b = tf.Variable(tf.constant(0.1, shape=shape))
return tf.nn.relu(tf.nn.conv2d(...))
for i in range(19):
...
x = layer(shape, ...)
...
这样的确就不用一个个写[w1,w2,w3,w4,….]这些变量了,从某种程度上来看确实解放了双手。但是,如果我现在想读取第8个卷积层中w和b的数值,有没有什么简单的方法呢?再或者我想把这个网络中的参数转移到另