This time we use slim to build a somewhat larger network: VGG16. VGG16 and VGG19 are structurally very similar, so this tutorial uses VGG16 for a 5-class flower classification task.
Compared with earlier networks such as LeNet and AlexNet, VGG introduces the following ideas:
1. Stacks of small 3×3 convolution kernels replace the larger 5×5 and 7×7 kernels.
Although a 5×5 kernel has a larger receptive field, it also has more parameters. Two stacked 3×3 convolutions cover the same receptive field as one 5×5 convolution, while applying two non-linear transformations instead of one. In short, compared with large kernels, stacked small kernels both reduce the parameter count and add more non-linear mappings, increasing the network's expressive power.
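To make the parameter saving concrete, here is a quick back-of-the-envelope calculation in plain Python (biases ignored), comparing one 5×5 convolution with two stacked 3×3 convolutions, both mapping C input channels to C output channels; C=64 is just an example value:

```python
def conv_params(kernel, in_ch, out_ch):
    """Weight count of one conv layer (biases ignored)."""
    return kernel * kernel * in_ch * out_ch

C = 64  # example channel count
single_5x5 = conv_params(5, C, C)       # 25 * C^2
stacked_3x3 = 2 * conv_params(3, C, C)  # 18 * C^2

print(single_5x5, stacked_3x3)  # 102400 73728
```

The stacked version uses 18C² weights versus 25C² for the single large kernel, a saving of 28%, while reaching the same 5×5 receptive field.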
2. The network is deeper. Setting aside the difficulties of training deep networks (e.g. vanishing gradients), in terms of feature abstraction and expressive power a deeper network can fit the data distribution better than a shallow one.
3. The original VGG paper also used tricks such as data augmentation and image preprocessing.
Now for the code. The project consists of three files:
vgg.py: builds the 16-layer VGG network.
import tensorflow as tf
import tensorflow.contrib.slim as slim

def build_vgg(rgb, num_classes, keep_prob, train=True):
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        activation_fn=tf.nn.relu, normalizer_fn=slim.batch_norm):
        # block_1
        net = slim.repeat(rgb, 2, slim.conv2d, 64, [3, 3], padding='SAME', scope='conv1')
        net = slim.max_pool2d(net, [2, 2], scope='pool1')
        # block_2
        net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], padding='SAME', scope='conv2')
        net = slim.max_pool2d(net, [2, 2], scope='pool2')
        # block_3
        net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], padding='SAME', scope='conv3')
        net = slim.max_pool2d(net, [2, 2], scope='pool3')
        # block_4
        net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], padding='SAME', scope='conv4')
        net = slim.max_pool2d(net, [2, 2], scope='pool4')
        # block_5
        net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], padding='SAME', scope='conv5')
        net = slim.max_pool2d(net, [2, 2], scope='pool5')
        # flatten
        feature_shape = net.get_shape()
        flattened_shape = feature_shape[1].value * feature_shape[2].value * feature_shape[3].value
        pool5_flatten = tf.reshape(net, [-1, flattened_shape])
        # fc6
        net = slim.fully_connected(pool5_flatten, 4096, scope='fc6')
        if train:
            net = slim.dropout(net, keep_prob=keep_prob, scope='dropout6')
        # fc7
        net = slim.fully_connected(net, 4096, scope='fc7')
        if train:
            net = slim.dropout(net, keep_prob=keep_prob, scope='dropout7')
        # fc8: return raw logits -- the softmax is applied inside the loss
        # (softmax_cross_entropy_with_logits) in train.py, so applying it
        # here as well would softmax twice; also skip batch norm on the output
        net = slim.fully_connected(net, num_classes, activation_fn=None,
                                   normalizer_fn=None, scope='fc8')
        return net
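As a sanity check on the flatten step: each of the five 2×2/stride-2 max-pools halves the spatial resolution, so a 224×224 input reaches pool5 at 7×7×512. A tiny sketch, no TensorFlow needed:

```python
def vgg_spatial_sizes(size=224, num_pools=5):
    """Spatial size after each 2x2/stride-2 max-pool in VGG16."""
    sizes = []
    for _ in range(num_pools):
        size //= 2
        sizes.append(size)
    return sizes

sizes = vgg_spatial_sizes(224)
print(sizes)      # [112, 56, 28, 14, 7]
flattened = sizes[-1] * sizes[-1] * 512
print(flattened)  # 25088 -- the input width of fc6
```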
tfrecords.py: encodes and decodes the dataset. Unlike previous posts, which fed data to the network with feed_dict, this tutorial encodes the dataset in TensorFlow's own TFRecord format.
import tensorflow as tf
import numpy as np
import os
import glob
from PIL import Image

path_tfrecord = '/home/danny/chenwei/CSDN_blog/VGG/tfrecords/'

def convert_to_tfrecord(images, labels, filename):
    print("Converting data into %s ..." % filename)
    writer = tf.python_io.TFRecordWriter(path_tfrecord + filename)
    for index, img in enumerate(images):
        img_raw = Image.open(img)
        if img_raw.mode != "RGB":
            continue
        img_raw = img_raw.resize((256, 256))
        img_raw = img_raw.tobytes()
        label = int(labels[index])
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            "img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
        }))
        writer.write(example.SerializeToString())
    writer.close()

def read_and_decode(filename, is_train=None):
    filename_queue = tf.train.string_input_producer([filename], num_epochs=400)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw': tf.FixedLenFeature([], tf.string),
                                       })
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [256, 256, 3])
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    if is_train:
        # training-time augmentation: random crop, flip, brightness, contrast
        img = tf.random_crop(img, [224, 224, 3])
        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_brightness(img, max_delta=63)
        img = tf.image.random_contrast(img, lower=0.2, upper=1.8)
        img = tf.image.per_image_standardization(img)
    else:
        # evaluation: deterministic center crop/pad, no augmentation
        img = tf.image.resize_image_with_crop_or_pad(img, 224, 224)
        img = tf.image.per_image_standardization(img)
    label = tf.cast(features['label'], tf.int32)
    return img, label

def get_file(path):
    cate = [path + x for x in os.listdir(path) if os.path.isdir(path + x)]
    images = []
    labels = []
    for idx, folder in enumerate(cate):
        for img in glob.glob(folder + '/*.jpg'):
            print('reading the images:%s' % (img))
            images.append(img)
            labels.append(idx)
    image_list = np.asarray(images, np.string_)
    label_list = np.asarray(labels, np.int32)
    # shuffle images and labels with the same permutation
    num_example = image_list.shape[0]
    arr = np.arange(num_example)
    np.random.shuffle(arr)
    image_list = image_list[arr]
    label_list = label_list[arr]
    # divide into train_data and val_data (80/20)
    split = int(num_example * 0.8)
    train_images = image_list[:split]
    train_labels = label_list[:split]
    val_images = image_list[split:]
    val_labels = label_list[split:]
    return train_images, train_labels, val_images, val_labels

if __name__ == '__main__':
    train_images, train_labels, val_images, val_labels = get_file('/home/danny/chenwei/CSDN_blog/VGG/datasets/')
    convert_to_tfrecord(images=train_images, labels=train_labels, filename="train.tfrecords")
    convert_to_tfrecord(images=val_images, labels=val_labels, filename="test.tfrecords")
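The shuffle-and-split logic inside get_file can be exercised on its own with dummy data. This standalone numpy sketch (the file names are made up for illustration) mirrors that logic: shuffle image paths and labels with one shared permutation so they stay aligned, then take the first 80% for training:

```python
import numpy as np

# hypothetical file list standing in for the flower dataset
images = np.array(['img_%02d.jpg' % i for i in range(10)])
labels = np.arange(10) % 5  # 5 classes, as in the tutorial

# one shared permutation keeps images and labels aligned
perm = np.arange(len(images))
np.random.shuffle(perm)
images, labels = images[perm], labels[perm]

# 80/20 train/validation split
split = int(len(images) * 0.8)
train_images, val_images = images[:split], images[split:]
train_labels, val_labels = labels[:split], labels[split:]

print(len(train_images), len(val_images))  # 8 2
```

Note the split happens after the shuffle; splitting first would put whole class folders into one side, since get_file reads the dataset folder by folder.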
train.py: the training script. The difference from previous posts is that it pulls data through multi-threaded input queues instead of reading it in the Python loop.
# -*- coding: utf-8 -*-
import tensorflow as tf
from utils.tfrecords import *
from model.vgg import *

tf.app.flags.DEFINE_integer('num_classes', 5, 'classification number.')
tf.app.flags.DEFINE_integer('crop_width', 224, 'width of input image after cropping.')
tf.app.flags.DEFINE_integer('crop_height', 224, 'height of input image after cropping.')
tf.app.flags.DEFINE_integer('channels', 3, 'channel number of image.')
tf.app.flags.DEFINE_integer('batch_size', 2, 'num of each batch')
tf.app.flags.DEFINE_integer('num_epochs', 400, 'number of epochs')
tf.app.flags.DEFINE_bool('continue_training', False, 'whether to continue training')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate')
tf.app.flags.DEFINE_string('dataset_path', './datasets/', 'path of dataset')
tf.app.flags.DEFINE_string('checkpoints', './checkpoints/model.ckpt', 'path of checkpoints')
tf.app.flags.DEFINE_string('train_tfrecords', '/home/danny/chenwei/CSDN_blog/VGG/tfrecords/train.tfrecords', 'train tfrecord')
tf.app.flags.DEFINE_string('test_tfrecords', '/home/danny/chenwei/CSDN_blog/VGG/tfrecords/test.tfrecords', 'test tfrecord')
FLAGS = tf.app.flags.FLAGS

def main(_):
    # data process
    train_images, train_labels = read_and_decode(FLAGS.train_tfrecords, True)
    val_images, val_labels = read_and_decode(FLAGS.test_tfrecords, False)
    train_labels = tf.one_hot(indices=tf.cast(train_labels, tf.int32), depth=FLAGS.num_classes)
    train_images_batch, train_labels_batch = tf.train.shuffle_batch(
        [train_images, train_labels], batch_size=FLAGS.batch_size,
        capacity=20000, min_after_dequeue=3000, num_threads=16)  # reader thread count
    val_labels = tf.one_hot(indices=tf.cast(val_labels, tf.int32), depth=FLAGS.num_classes)
    val_images_batch, val_labels_batch = tf.train.shuffle_batch(
        [val_images, val_labels], batch_size=FLAGS.batch_size,
        capacity=20000, min_after_dequeue=3000, num_threads=16)  # reader thread count
    # define network input; 224x224 to match the crops produced by read_and_decode
    input = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.crop_height, FLAGS.crop_width, FLAGS.channels], name='input')
    output = tf.placeholder(tf.int32, shape=[FLAGS.batch_size, FLAGS.num_classes], name='output')
    # control GPU resource utilization
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    # build network
    logits = build_vgg(input, FLAGS.num_classes, 0.5, True)
    # loss
    cross_entropy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=output))
    regularization_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    loss = cross_entropy_loss + regularization_loss
    # optimizer
    train_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(loss)
    # calculate accuracy
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(output, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    with sess.as_default():
        # init all parameters (local variables are needed for the epoch counter
        # inside string_input_producer)
        saver = tf.train.Saver(max_to_keep=1000)
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())
        # restore weights
        if FLAGS.continue_training:
            saver.restore(sess, FLAGS.checkpoints)
        # begin training
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        step = 0
        try:
            while not coord.should_stop():
                train_images, train_labels = sess.run([train_images_batch, train_labels_batch])
                _, err, acc = sess.run([train_op, loss, accuracy], feed_dict={input: train_images, output: train_labels})
                print("[Train] Step: %d, loss: %.4f, accuracy: %.4f%%" % (step, err, acc))
                step += 1
                if step % 10 == 0 or (step + 1) == FLAGS.num_epochs:
                    val_images, val_labels = sess.run([val_images_batch, val_labels_batch])
                    val_err, val_acc = sess.run([loss, accuracy], feed_dict={input: val_images, output: val_labels})
                    print("[validation] Step: %d, loss: %.4f, accuracy: %.4f%%" % (step, val_err, val_acc))
                if (step + 1) == FLAGS.num_epochs:
                    saver.save(sess, save_path=FLAGS.checkpoints, global_step=step)
        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            coord.request_stop()
        coord.join(threads)
        sess.close()

if __name__ == '__main__':
    tf.app.run()
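For reference, here is what the loss and accuracy ops above compute, expressed in numpy on a toy batch (the logit values are made up; 2 samples and 5 classes match batch_size=2 and num_classes=5):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=1, keepdims=True))  # shift for numerical stability
    return e / e.sum(axis=1, keepdims=True)

# toy batch: 2 samples, 5 classes
logits = np.array([[2.0, 0.1, 0.1, 0.1, 0.1],
                   [0.1, 0.1, 3.0, 0.1, 0.1]])
labels = np.eye(5)[[0, 2]]  # one-hot, true classes 0 and 2

# softmax cross-entropy, averaged over the batch
# (this is what tf.nn.softmax_cross_entropy_with_logits + reduce_mean do)
probs = softmax(logits)
loss = -np.mean(np.sum(labels * np.log(probs), axis=1))

# accuracy: fraction of samples whose argmax matches the label's argmax
acc = np.mean(np.argmax(logits, 1) == np.argmax(labels, 1))
print(round(loss, 4), acc)
```

Because the loss takes raw logits and applies the softmax internally, the network's final layer must not apply a softmax of its own.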
Training results: validation accuracy settles at roughly 96%.
[Train] Step: 19985, loss: 1.1098, accuracy: 1.0000%
[Train] Step: 19986, loss: 1.1302, accuracy: 1.0000%
[Train] Step: 19987, loss: 1.1232, accuracy: 1.0000%
[Train] Step: 19988, loss: 1.1299, accuracy: 1.0000%
[Train] Step: 19989, loss: 1.1220, accuracy: 1.0000%
[validation] Step: 19990, loss: 1.1634, accuracy: 0.9688%