基于VGG16预训练网络特征提取在小型训练集上的应用(kaggle - 猫狗分类)(《python深度学习》)

0 前言

     在之前的例子中,我们采用一个从头开始训练的卷积神经网络在训练样本数共2000个的kaggle猫狗训练集上训练,哪怕采用了4组Conv2D和MaxPooling组合、随机图片变换、全连接输入层高达50%的dropout以及多达100次的循环最后通过绘图分析发现还是不可避免的较早的出现过拟合,最后在测试集上的正确率仅仅为75%。归根结底是因为样本数太少,即使通过图片的随机变换仍然只是在训练集上抓取特征,模型的泛化能力太低。因此在样本数量不能提高的情况下我们常采用预训练网络来行之有效的构造模型。本例采用的预训练模型架构是VGG16。

1 预训练网络——VGG16

     VGG16是2014年开发的在ImageNet(多达140万张标记图像,1000种不同类别)训练好的大型卷积神经网络,因为ImageNet1000种分类包含猫狗的分类,因此可以认为基于此训练的VGG16能在本例上有良好的表现。

     要想将VGG16应用到本例我们需要的是其已经训练好的卷积核,至于后面的全连接层丢弃不用,而卷积神经网络之所以能起到辨识的能力重点就在于全连接层之前的卷积核。全连接层分类器往往针对的是模型训练的类别,其中仅包括某个类别在整体中出现的概率。此外,密集连接层忽略了图像的位置信息,而这一信息是由卷积特征图描述的,如果位置对于问题很重要,那么密集连接层特征很大程度上是没用的。因此我们的任务就是将VGG16的卷积核移植到我们自行设计的全连接层上,移植的方法有两种:特征提取与微调模型,本例采用的是特征提取。

2 VGG16特征提取

2.1 特征提取简介

   所谓的特征提取就是将要训练的数据输入到之前学习过的模型中,然后从中提取有用的特征,随后将这些特征输入到新的分类器进行训练。这里需要注意的是训练的仅仅是后面的分类模型也就是全连接层的参数,而卷积核的数据是不参与学习的,示意图如下:

2.2 keras实例化

     VGG16模型内置于Keras中输入如下代码导入:

from keras.applications import VGG16

  此外还要对输入类型等实例化:

conv_base = VGG16(
    weights = 'imagenet',
    include_top = False,
    input_shape = (150,150,3))

这里的参数含义为:

  1. weights 指定模型的初始化权重检查点。
  2. include_top True为连接密集连接层,本例设置为False。
  3. input_shape 指定输入数据类型。

查看一下实例化后卷积基模型:

可见最后一层的数据类型为(4,4,512),在本层后面添加一个全连接层的分类器。

3 全连接层分类器的添加与编译

3.1 添加方法综述

    添加全连接层分类器有两种方法:

  1. 在数据集上独立先运行卷积基,将特征结果(4,4,512)以numpy数组的形式保存,随后再输入到全连接层。
  2. 在全连接层顶部添加VGG模型(模型数据锁住)。

   两种方法各有利弊,总的来说,第一种方法的优点是速度快,计算代价小。因为i在卷积神经网络中最为费时的就是卷积计算的过程,没有一个好的GPU协同计算计算的时间将非常大。而第一种方法只过一遍卷积计算,随后的全连接层的计算就十分的快了。然而,正是因为只过了一遍的卷积计算没有经过数据加强,过拟合的风险就比较高。

   相反的,第二种方法因为将模型添加在全连接层之前,这样就可以像普通的卷积神经网络一样,输入经过数据增强的图像数据,大大降低过拟合几率。由于每次数据输入都要运行卷积计算,那么运算的代价就变得十分的高。下面就分别采用两种方法进行实验。

3.2 无数据增强的快速特征提取

3.2.1 特征的提取

   提取特征数据是在将数据先输入conv_base后得到的结果数据,然后再将其保存为numpy数组。这里获得输出结果用的是predict方法。具体实现如下:

def extract_features(directory,sample_count):
    features = np.zeros(shape = (sample_count,4,4,512))
    labels = np.zeros(shape = (sample_count))
    
    generator = datagen.flow_from_directory(
        directory,
        target_size = (150,150),
        batch_size = batch_size,
        class_mode = 'binary')
    
    i = 0
    for input_batch,labels_batch in generator:
 
  • 3
    点赞
  • 40
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值