前面我们介绍了猫狗大战的数据集的整理已经转换成Tensorflow专用格式https://blog.csdn.net/nvidiacuda/article/details/83413837
这一篇介绍一下VGG16模型的修改
Step 1: 对模型的修改
首先是对模型的修改(VGG16_model.py文件),在这里原先的输出结果是对1000个不同的类别进行判定,而在此是对2个图像,也就是猫和狗的判断,因此首先第一步就是修改输出层的全连接数据。
def fc_layers(self):
self.fc6 = self.fc("fc6", self.pool5, 4096,trainable=False)#语句变动
self.fc7 = self.fc("fc7", self.fc6, 4096,trainable=False)#语句变动
self.fc8 = self.fc("fc8", self.fc7, 2)
这里是最后一层的输出通道被设置成2,而对于其他部分,定义创建卷积层和全连接层的方法则无需做出太大改动。
def conv(self,name, input_data, out_channel):
in_channel = input_data.get_shape()[-1]
with tf.variable_scope(name):
kernel = tf.get_variable("weights", [3, 3, in_channel, out_channel], dtype=tf.float32, trainable=False) #语句变动
biases = tf.get_variable("biases", [out_channel], dtype=tf.float32, trainable=False) #语句变动
conv_res = tf.nn.conv2d(input_data, kernel, [1, 1, 1, 1], padding="SAME")
res = tf.nn.bias_add(conv_res, biases)
out = tf.nn.relu(res, name=name)
self.parameters += [kernel, biases]
return out
def fc(self, name, input_data, out_channel, trainable=True):
shape = input_data.get_shape().as_list()
if len(shape) == 4:
size = shape[-1] * shape[-2] * shape[-3]
else:size = shape[1]
input_data_flat = tf.reshape(input_data,[-1,size])
with tf.variable_scope(name):
weights = tf.get_variable(name="weights",shape=[size,out_channel],dtype=tf.float32,trainable=trainable) #语句变动
biases = tf.get_variable(name="biases",shape=[out_channel],dtype=tf.float32, trainable=trainable) #语句变动
res = tf.matmul(input_data_flat,weights)
out = tf.nn.relu(tf.nn.bias_add(res,biases))
self.parameters += [weights, biases]
return out
在这里读者可能已经注意到,在介绍全连接层的修改时,就有一个额外的输入参数:
trainable=False
而在卷积层和全连接层的定义中,也添加了这个参数:
def fc(self, name, input_data, out_channel, trainable=True):
直接的解释就是,在进行Finetuning对模型重新训练时,对于部分不需要训练的层可以通过设置trainable=False来确保其在训练过程中不会被修改权值。
下面还有一个非常重要的函数是VGGNet权重的载入,前文已经有所介绍,具体如下:
def load_weights(self, weight_file, sess):
weights = np.load(weight_file)
keys = sorted(weights.keys())
for i, k in enumerate(keys):
if i not in [30,31]:
sess.run(self.parameters[i].assign(weights[k]))
print("-----------all done---------------")
可以看到,这里使用了一个if函数对序号进行剔除,即对于最后一层的权重不要载入。
完整代码:VGG16_model.py文件
import numpy as np
import tensorflow as tf
class vgg16:
def __init__(self, imgs):
self.parameters = []
self.imgs = imgs
self.convlayers()
self.fc_layers()
self.probs = self.fc8
def saver(self):
return tf.train.Saver()
def maxpool(self,name,input_data, trainable):
out = tf.nn.max_pool(input_data,[1,2,2,1],[1,2,2,1],padding="SAME",name=name)
return out
def conv(self,name, input_data, out_channel, trainable):
in_channel = input_data.get_shape()[-1]
with tf.variable_scope(name):
kernel = tf.get_variable("weights", [3, 3, in_channel, out_channel], dtype=tf.float32,trainable=False)
biases = tf.get_variable("biases", [out_channel], dtype=tf.float32,trainable=False)
conv_res = tf.nn.conv2d(input_data, kernel, [1, 1, 1, 1], padding="SAME")
res = tf.nn.bias_add(conv_res, biases)
out = tf.nn.relu(res, name=name)
self.parameters += [kernel, biases]
return out
def fc(self,name,input_data,out_channel,trainable = True):
shape = input_data.get_shape().as_list()
if len(shape) == 4:
size = shape[-1] * shape[-2] * shape[-3]
else:size = shape[1]
input_data_flat = tf.reshape(input_data,[-1,size])
with tf.variable_scope(name):
weights = tf.get_variable(name="weights",shape=[size,out_channel],dtype=tf.float32,trainable = trainable)
biases = tf.get_variable(name="biases",shape=[out_channel],dtype=tf.float32,trainable = trainable)
res = tf.matmul(input_data_flat,weights)
out = tf.nn.relu(tf.nn.bias_add(res,biases))
self.parameters += [weights, biases]
return out
def convlayers(self):
# zero-mean input
#conv1
self.conv1_1 = self.conv("conv1re_1",self.imgs,64,trainable=False)
self.conv1_2 = self.conv("conv1_2",self.conv1_1,64,trainable=False)
self.pool1 = self.maxpool("poolre1",self.conv1_2,trainable=False)
#conv2
self.conv2_1 = self.conv("conv2_1",self.pool1,128,trainable=False)
self.conv2_2 = self.conv("convwe2_2",self.conv2_1,128,trainable=False)
self.pool2 = self.maxpool("pool2",self.conv2_2,trainable=False)
#conv3
self.conv3_1 = self.conv("conv3_1",self.pool2,256,trainable=False)
self.conv3_2 = self.conv("convrwe3_2",self.conv3_1,256,trainable=False)
self.conv3_3 = self.conv("convrew3_3",self.conv3_2,256,trainable=False)
self.pool3 = self.maxpool("poolre3",self.conv3_3,trainable=False)
#conv4
self.conv4_1 = self.conv("conv4_1",self.pool3,512,trainable=False)
self.conv4_2 = self.conv("convrwe4_2",self.conv4_1,512,trainable=False)
self.conv4_3 = self.conv("conv4rwe_3",self.conv4_2,512,trainable=False)
self.pool4 = self.maxpool("pool4",self.conv4_3,trainable=False)
#conv5
self.conv5_1 = self.conv("conv5_1",self.pool4,512,trainable=False)
self.conv5_2 = self.conv("convrwe5_2",self.conv5_1,512,trainable=False)
self.conv5_3 = self.conv("conv5_3",self.conv5_2,512,trainable=False)
self.pool5 = self.maxpool("poorwel5",self.conv5_3,trainable=False)
def fc_layers(self):
self.fc6 = self.fc("fc6", self.pool5, 4096,trainable=False)
self.fc7 = self.fc("fc7", self.fc6, 4096,trainable=False)
self.fc8 = self.fc("fc8", self.fc7, 2)
def load_weights(self, weight_file, sess):
weights = np.load(weight_file)
keys = sorted(weights.keys())
for i, k in enumerate(keys):
if i not in [30,31]:
sess.run(self.parameters[i].assign(weights[k]))
print("-----------all done---------------")
可以看到,对于每个卷积层和全连接层中,不需要训练的权重全部被设置为trainable=False。