tensorflow微调vgg16 程序代码汇总

一、【深度学习图像识别课程】tensorflow迁移学习系列:VGG16花朵分类

转自:https://blog.csdn.net/weixin_41770169/article/details/80330581

花朵数据库介绍

种类5种:daisy雏菊,dandelion蒲公英,rose玫瑰,sunflower向日葵,tulips郁金香

数量:     633,            898,                   641,          699,                799             

总数量:3670

实战:VGGNet实现花朵分类

1、读入VGG16模型
[python] view plain copy
  1. from urllib.request import urlretrieve  
  2. from os.path import isfile, isdir  
  3. from tqdm import tqdm  
  4.   
  5. vgg_dir = 'tensorflow_vgg/'  
  6. # Make sure vgg exists  
  7. if not isdir(vgg_dir):  
  8.     raise Exception("VGG directory doesn't exist!")  
  9.   
  10. class DLProgress(tqdm):  
  11.     last_block = 0  
  12.   
  13.     def hook(self, block_num=1, block_size=1, total_size=None):  
  14.         self.total = total_size  
  15.         self.update((block_num - self.last_block) * block_size)  
  16.         self.last_block = block_num  
  17. if not isfile(vgg_dir + "vgg16.npy"):  
  18.     with DLProgress(unit='B', unit_scale=True, miniters=1, desc='VGG16 Parameters') as pbar:  
  19.         urlretrieve(  
  20.             'https://s3.amazonaws.com/content.udacity-data.com/nd101/vgg16.npy',  
  21.             vgg_dir + 'vgg16.npy',  
  22.             pbar.hook)  
  23. else:  
  24.     print("Parameter file already exists!")  

下载了如下标亮文件:vgg16.npy


2、读入图像库
[python] view plain copy
  1. import tarfile  
  2.   
  3. dataset_folder_path = 'flower_photos'  
  4.   
  5. class DLProgress(tqdm):  
  6.     last_block = 0  
  7.   
  8.     def hook(self, block_num=1, block_size=1, total_size=None):  
  9.         self.total = total_size  
  10.         self.update((block_num - self.last_block) * block_size)  
  11.         self.last_block = block_num  
  12.   
  13. if not isfile('flower_photos.tar.gz'):  
  14.     with DLProgress(unit='B', unit_scale=True, miniters=1, desc='Flowers Dataset') as pbar:  
  15.         urlretrieve(  
  16.             'http://download.tensorflow.org/example_images/flower_photos.tgz',  
  17.             'flower_photos.tar.gz',  
  18.             pbar.hook)  
  19.   
  20. if not isdir(dataset_folder_path):  
  21.     with tarfile.open('flower_photos.tar.gz') as tar:  
  22.         tar.extractall()  
  23.         tar.close()  


下载如下高亮文件:flower_photos.tar.gz



3、卷积代码

参考的源码:[html] view plain cop


[html] view plain copy
  1. self.conv1_1 = self.conv_layer(bgr, "conv1_1")  
  2. self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2")  
  3. self.pool1 = self.max_pool(self.conv1_2, 'pool1')  
  4.   
  5. self.conv2_1 = self.conv_layer(self.pool1, "conv2_1")  
  6. self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2")  
  7. self.pool2 = self.max_pool(self.conv2_2, 'pool2')  
  8.   
  9. self.conv3_1 = self.conv_layer(self.pool2, "conv3_1")  
  10. self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2")  
  11. self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3")  
  12. self.pool3 = self.max_pool(self.conv3_3, 'pool3')  
  13.   
  14. self.conv4_1 = self.conv_layer(self.pool3, "conv4_1")  
  15. self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2")  
  16. self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3")  
  17. self.pool4 = self.max_pool(self.conv4_3, 'pool4')  
  18.   
  19. self.conv5_1 = self.conv_layer(self.pool4, "conv5_1")  
  20. self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2")  
  21. self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3")  
  22. self.pool5 = self.max_pool(self.conv5_3, 'pool5')  
  23.   
  24. self.fc6 = self.fc_layer(self.pool5, "fc6")  
  25. self.relu6 = tf.nn.relu(self.fc6)  
  26.   
  27. with tf.Session() as sess:  
  28.     vgg = vgg16.Vgg16()  
  29.     input_ = tf.placeholder(tf.float32, [None, 224, 224, 3])  
  30.     with tf.name_scope("content_vgg"):  
  31.         vgg.build(input_)  
  32.   
  33. feed_dict = {input_: images}  
  34. codes = sess.run(vgg.relu6, feed_dict=feed_dict)  

tensorflow中vgg_16采用的上述结构。本项目代码如下:

[python] view plain copy
  1. import os  
  2.   
  3. import numpy as np  
  4. import tensorflow as tf  
  5.   
  6. from tensorflow_vgg import vgg16  
  7. from tensorflow_vgg import utils  
[python] view plain copy
  1. data_dir = 'flower_photos/'  
  2. contents = os.listdir(data_dir)  
  3. classes = [each for each in contents if os.path.isdir(data_dir + each)]  

将图像批量batches通过VGG模型,将输出作为新的输入:

[python] view plain copy
  1. # Set the batch size higher if you can fit in in your GPU memory  
  2. batch_size = 10  
  3. codes_list = []  
  4. labels = []  
  5. batch = []  
  6.   
  7. codes = None  
  8.   
  9. with tf.Session() as sess:  
  10.     vgg = vgg16.Vgg16()  
  11.     input_ = tf.placeholder(tf.float32, [None2242243])  
  12.     with tf.name_scope("content_vgg"):  
  13.         vgg.build(input_)  
  14.   
  15.     for each in classes:  
  16.         print("Starting {} images".format(each))  
  17.         class_path = data_dir + each  
  18.         files = os.listdir(class_path)  
  19.         for ii, file in enumerate(files, 1):  
  20.             # Add images to the current batch  
  21.             # utils.load_image crops the input images for us, from the center  
  22.             img = utils.load_image(os.path.join(class_path, file))  
  23.             batch.append(img.reshape((12242243)))  
  24.             labels.append(each)  
  25.               
  26.             # Running the batch through the network to get the codes  
  27.             if ii % batch_size == 0 or ii == len(files):  
  28.                 images = np.concatenate(batch)  
  29.   
  30.                 feed_dict = {input_: images}  
  31.                 codes_batch = sess.run(vgg.relu6, feed_dict=feed_dict)  
  32.                   
  33.                 # Here I'm building an array of the codes  
  34.                 if codes is None:  
  35.                     codes = codes_batch  
  36.                 else:  
  37.                     codes = np.concatenate((codes, codes_batch))  
  38.                   
  39.                 # Reset to start building the next batch  
  40.                 batch = []  
  41.                 print('{} images processed'.format(ii))  



4、模型建立和测试
图像处理代码和标签:
[python] view plain copy
  1. # read codes and labels from file  
  2. import csv  
  3.   
  4. with open('labels') as f:  
  5.     reader = csv.reader(f, delimiter='\n')  
  6.     labels = np.array([each for each in reader if len(each) > 0]).squeeze()  
  7. with open('codes') as f:  
  8.     codes = np.fromfile(f, dtype=np.float32)  
  9.     codes = codes.reshape((len(labels), -1))  
4.1 图像预处理
[python] view plain copy
  1. from sklearn.preprocessing import LabelBinarizer  
  2.   
  3. lb = LabelBinarizer()  
  4. lb.fit(labels)  
  5.   
  6. labels_vecs = lb.transform(labels)  

对标签进行one-hot编码:daisy雏菊  dandelion蒲公英  rose玫瑰  sunflower向日葵 tulips郁金香

                    daisy雏菊        1                0                        0                 0                     0

         dandelion蒲公英        0                1                        0                 0                     0

                     rose玫瑰        0                0                        1                 0                     0

          sunflower向日葵        0                0                        0                 1                     0

                 tulips郁金香        0                0                        0                 0                     1

随机拆分数据集(之前那种直接把集中的部分图像拿出来验证/测试不管用,这里的数据集是每个种类集中放的,如果直接拿出其中的一部分,会导致验证集或者测试集是同一种花)。scikit-learn中的函数StratifiedShuffleSplit可以做到。我们这里,随机拿出20%的图像用来验证和测试,然后验证集和测试集再各占一半。

[python] view plain copy
  1. from sklearn.model_selection import StratifiedShuffleSplit  
  2.   
  3. ss = StratifiedShuffleSplit(n_splits=1, test_size=0.2)  
  4.   
  5. train_idx, val_idx = next(ss.split(codes, labels))  
  6.   
  7. half_val_len = int(len(val_idx)/2)  
  8. val_idx, test_idx = val_idx[:half_val_len], val_idx[half_val_len:]  
  9.   
  10. train_x, train_y = codes[train_idx], labels_vecs[train_idx]  
  11. val_x, val_y = codes[val_idx], labels_vecs[val_idx]  
  12. test_x, test_y = codes[test_idx], labels_vecs[test_idx]  
[python] view plain copy
  1. print("Train shapes (x, y):", train_x.shape, train_y.shape)  
  2. print("Validation shapes (x, y):", val_x.shape, val_y.shape)  
  3. print("Test shapes (x, y):", test_x.shape, test_y.shape)  

总数量:3670,则训练图像:3670*0.8=2936,验证图像:3670*0.2*0.5=367,测试图像:3670*0.2*0.5=367。


4.2 层

在上述vgg的基础上,增加一个256个元素的全连接层,最后加上一个softmax层,计算交叉熵进行最后的分类。

[python] view plain copy
  1. inputs_ = tf.placeholder(tf.float32, shape=[None, codes.shape[1]])  
  2. labels_ = tf.placeholder(tf.int64, shape=[None, labels_vecs.shape[1]])  
  3.   
  4. fc = tf.contrib.layers.fully_connected(inputs_, 256)  
  5.       
  6. logits = tf.contrib.layers.fully_connected(fc, labels_vecs.shape[1], activation_fn=None)  
  7. cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=labels_, logits=logits)  
  8. cost = tf.reduce_mean(cross_entropy)  
  9.   
  10. optimizer = tf.train.AdamOptimizer().minimize(cost)  
  11.   
  12. predicted = tf.nn.softmax(logits)  
  13. correct_pred = tf.equal(tf.argmax(predicted, 1), tf.argmax(labels_, 1))  
  14. accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))  

4.3 训练:batches和epoches
[python] view plain copy
  1. def get_batches(x, y, n_batches=10):  
  2.     """ Return a generator that yields batches from arrays x and y. """  
  3.     batch_size = len(x)//n_batches  
  4.       
  5.     for ii in range(0, n_batches*batch_size, batch_size):  
  6.         # If we're not on the last batch, grab data with size batch_size  
  7.         if ii != (n_batches-1)*batch_size:  
  8.             X, Y = x[ii: ii+batch_size], y[ii: ii+batch_size]   
  9.         # On the last batch, grab the rest of the data  
  10.         else:  
  11.             X, Y = x[ii:], y[ii:]  
  12.         # I love generators  
  13.         yield X, Y  
[python] view plain copy
  1. epochs = 10  
  2. iteration = 0  
  3. saver = tf.train.Saver()  
  4. with tf.Session() as sess:  
  5.       
  6.     sess.run(tf.global_variables_initializer())  
  7.     for e in range(epochs):  
  8.         for x, y in get_batches(train_x, train_y):  
  9.             feed = {inputs_: x,  
  10.                     labels_: y}  
  11.             loss, _ = sess.run([cost, optimizer], feed_dict=feed)  
  12.             print("Epoch: {}/{}".format(e+1, epochs),  
  13.                   "Iteration: {}".format(iteration),  
  14.                   "Training loss: {:.5f}".format(loss))  
  15.             iteration += 1  
  16.               
  17.             if iteration % 5 == 0:  
  18.                 feed = {inputs_: val_x,  
  19.                         labels_: val_y}  
  20.                 val_acc = sess.run(accuracy, feed_dict=feed)  
  21.                 print("Epoch: {}/{}".format(e, epochs),  
  22.                       "Iteration: {}".format(iteration),  
  23.                       "Validation Acc: {:.4f}".format(val_acc))  
  24.     saver.save(sess, "checkpoints/flowers.ckpt")  



验证集的正确率达到90%,很高了已经。


4.4 测试
[python] view plain copy
  1. with tf.Session() as sess:  
  2.     saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))  
  3.       
  4.     feed = {inputs_: test_x,  
  5.             labels_: test_y}  
  6.     test_acc = sess.run(accuracy, feed_dict=feed)  
  7.     print("Test accuracy: {:.4f}".format(test_acc))  


[python] view plain copy
  1. %matplotlib inline  
  2.   
  3. import matplotlib.pyplot as plt  
  4. from scipy.ndimage import imread  
[python] view plain copy
  1. test_img_path = 'flower_photos/roses/10894627425_ec76bbc757_n.jpg'  
  2. test_img = imread(test_img_path)  
  3. plt.imshow(test_img)  

[python] view plain copy
  1. with tf.Session() as sess:  
  2.     input_ = tf.placeholder(tf.float32, [None2242243])  
  3.     vgg = vgg16.Vgg16()  
  4.     vgg.build(input_)  


[python] view plain copy
  1. with tf.Session() as sess:  
  2.     img = utils.load_image(test_img_path)  
  3.     img = img.reshape((12242243))  
  4.   
  5.     feed_dict = {input_: img}  
  6.     code = sess.run(vgg.relu6, feed_dict=feed_dict)  
  7.           
  8. saver = tf.train.Saver()  
  9. with tf.Session() as sess:  
  10.     saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))  
  11.       
  12.     feed = {inputs_: code}  
  13.     prediction = sess.run(predicted, feed_dict=feed).squeeze()  
[python] view plain copy
  1. plt.imshow(test_img)  

[python] view plain copy
  1. plt.barh(np.arange(5), prediction)  
  2. _ = plt.yticks(np.arange(5), lb.classes_)  

上图的花最有可能是Rose,有小概率是Tulips。




阅读更多
文章标签: tensorflow vgg16 微调
个人分类: 深度学习
想对作者说点什么? 我来说一句

没有更多推荐了,返回首页

关闭
关闭
关闭