Keras 入门课6 -- 使用Inception V3模型进行迁移学习

Keras 入门课6:使用Inception V3模型进行迁移学习

本系列课程代码,欢迎star:
https://github.com/tsycnh/Keras-Tutorials

keras 请使用2.1.2版

深度学习可以说是一门数据驱动的学科,各种有名的CNN模型,无一不是在大型的数据库上进行的训练。像ImageNet这种规模的数据库,动辄上百万张图片。对于普通的机器学习工作者、学习者来说,面对的任务各不相同,很难拿到如此大规模的数据集。同时也没有谷歌,Facebook那种大公司惊人的算力支持,想从0训练一个深度CNN网络,基本是不可能的。但是好在已经训练好的模型的参数,往往经过简单的调整和训练,就可以很好的迁移到其他不同的数据集上,同时也无需大量的算力支撑,便能在短时间内训练得出满意的效果。这便是迁移学习。究其根本,就是虽然图像的数据集不同,但是底层的特征却是有大部分通用的。

迁移学习主要分为两种

  • 第一种即所谓的transfer learning,迁移训练时,移掉最顶层,比如ImageNet训练任务的顶层就是一个1000输出的全连接层,换上新的顶层,比如输出为10的全连接层,然后训练的时候,只训练最后两层,即原网络的倒数第二层和新换的全连接输出层。可以说transfer learning将底层的网络当做了一个特征提取器来使用。
  • 第二种叫做fine tune,和transfer learning一样,换一个新的顶层,但是这一次在训练的过程中,所有的(或大部分)其它层都会经过训练。也就是底层的权重也会随着训练进行调整。

一个典型的迁移学习过程是这样的。首先通过transfer learning对新的数据集进行训练,训练过一定epoch之后,改用fine tune方法继续训练,同时降低学习率。这样做是因为如果一开始就采用fine tune方法的话,网络还没有适应新的数据,那么在进行参数更新的时候,比较大的梯度可能会导致原本训练的比较好的参数被污染,反而导致效果下降。

本课,我们将尝试使用谷歌提出的Inception V3模型来对一个花朵数据集进行迁移学习的训练。

数据集为17种不同的花朵,每种有80张样本,一共1360张图像,属于典型的小样本集。数据下载地址:http://www.robots.ox.ac.uk/~vgg/data/flowers/17/
官方没有给出图像对应的label,我写了一段代码,把每张图像加上标签,同时,Keras对于数据的格式要求如下:
我写了一个脚本来做转换
https://gist.github.com/tsycnh/1b35103adec1ad2be5090c486354859f

2018年09月02日更新:
花朵命名按顺序命名为flower_A, flower_B, … , flower_Q。

data/
    train/
        class1/
            img1
            img2
            ...
        class2/
            img1
            ...
    validation/
        class1/
            img1
            img2
            ...
        class2/
            img1
            ...
    test/
        class1/
            img1
            img2
            ...
        class2/
            img1
            ...

这个脚本我将训练集划分为800张,验证集和测试集分别为260张,图片顺序做了随机打乱

如果你懒得自己转换,我已经把处理好的数据进行上传,直接下载即可:https://download.csdn.net/download/tsyccnh/10641502

请注意,这里的花朵识别仍属于最简单的单分类任务,样张如下


这里写图片描述

from keras.preprocessing.image import ImageDataGenerator
from keras.applications.inception_v3 import InceptionV3,preprocess_input
from keras.layers import GlobalAveragePooling2D,Dense
from keras.models import Model
from keras.utils.vis_utils import plot_model
from keras.optimizers import Adagrad
# 数据准备
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_inp
  • 28
    点赞
  • 142
    收藏
    觉得还不错? 一键收藏
  • 59
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 59
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值