附上代码加数据地址 https://github.com/Liuyubao/transfer-learning ,欢迎参考。
一、Inception-V3模型
1.1 详细了解模型可参考以下论文:
[v1] Going Deeper with Convolutions, 6.67% test error
http://arxiv.org/abs/1409.4842
[v2] Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift, 4.8% test error
http://arxiv.org/abs/1502.03167
[v3] Rethinking the Inception Architecture for Computer Vision, 3.5% test error
http://arxiv.org/abs/1512.00567
[v4] Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning, 3.08% test error
http://arxiv.org/abs/1602.07261
1.2 CNN结构演化及Inception-V3简介
v3一个最重要的改进是分解(Factorization),将7x7分解成两个一维的卷积(1x7,7x1),3x3也是一样(1x3,3x1),这样的好处,既可以加速计算(多余的计算能力可以用来加深网络),又可以将1个conv拆成2个conv,使得网络深度进一步增加,增加了网络的非线性,还有值得注意的地方是网络输入从224x224变为了299x299,更加精细设计了35x35/17x17/8x8的模块。
二、迁移学习
2.1 What is transfer learning?
在深度学习中,所谓的迁移学习是将一个问题A上训练好的模型通过简单的调整使其适应一个新的问题B。在实际使用中,往往是完成问题A的训练出的模型有更完善的数据,而问题B的数据量偏小。而调整的过程根据现实情况决定,可以选择保留前几层卷积层的权重,以保留低级特征的提取;也可以保留全部的模型,只根据新的任务改变其fc层。
2.2 What can transfer learning do?
那么对于不同的任务,为什么不同的模型间可以做迁移呢?上面提到了,被迁移的模型往往是使用大量样本训练出来的,比如Google提供的Inception V3网络模型使用ImageNet数据集训练,而ImageNet中有120万标注图片,然后在实际应用中,很难收集到如此多的样本数据。而且收集的过程需要消耗大量的人力无力(其实深度学习解决实际问题时,最好费时间的往往不是训练的过程,而是数据标记的过程),所以一般情况下来说,问题B的数据量是较少的。
所以,同样一个模型在使用大样本很好的解决了问题A,那么有理由相信该模型中训练处的权重参数能够能够很好的完成特征提取任务(最起码前几层是这样),所以既然已经有了这样一个模型,那就拿过来用吧。
所以迁移学习具有如下优势:
- 更短的训练时间
- 更快的收敛速度
- 更精准的权重参数。
但是一般情况下如果任务B的数据量是足够的,那么迁移来的模型效果会不如训练的到,但是此时起码可以将底层的权重参数作为初始值来重新训练。
2.3 代码实现 Tensorflow 对 Inception-V3 进行迁移
(1)前期数据、模型准备
谷歌提供的训练好的Inception-v3模型:
https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip
解压后有两个文件,使用的是 .pb 文件
案例使用的数据集:
http://download.tensorflow.org/example_images/flower_photos.tgz
数据集文件解压后,包含5个子文件夹,子文件夹的名称为花的名称,代表了不同的类别。平均每一种花有734张图片,图片是RGB色彩模式,大小也不相同。
(2)导入相关工具包
# -*- coding: utf-8 -*-
"""
Created on May 31 2018
@author: 柳玉豹
"""
import glob
import os.path
import random
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
(3)模型和样本路径的设置
#模型和样本路径的设置
#inception-V3瓶颈层节点个数
BOTTLENECK_TENSO