lsh呵呵的专栏

埋在一座城,关了所有灯

keras分类猫狗数据(下)迁移学习

keras分类猫狗数据(上)数据预处理
keras分类猫狗数据(中)使用CNN分类模型
keras分类猫狗数据(下)迁移学习
keras分类猫狗数据(番外篇)深度学习CNN连接SVM分类

1 .使用keras.applications中的vgg16网络模型进行特征提取,并自定义两个全连接层输出分类。

from keras.applications import VGG16
from keras import models,layers,optimizers
from keras.callbacks import TensorBoard

conv_base=VGG16(weights='imagenet',include_top=False,input_shape=(128,128,3))

model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))

conv_base.trainable=False

model.summary()

model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['acc'])

import catvsdogs.morph as mp

model.fit_generator(
      mp.train_flow,
      steps_per_epoch=32,
      epochs=50,
      validation_data=mp.test_flow,
      validation_steps=32,callbacks=[TensorBoard(log_dir='logs/3')])
model.save_weights('outputs/weights_vgg16_use.h5')

这里写图片描述
这里写图片描述
在30多轮迭代后,测试正确率达到88%。

2 . 微调,使vgg16模型的最后一个卷积层也参与训练,本次使用上文保存的训练权重集weights_vgg16_use.h5加速训练过程,并使用较小的学习率。

from keras.applications import VGG16
from keras import models,layers,optimizers
from keras.callbacks import TensorBoard

conv_base=VGG16(weights='imagenet',include_top=False,input_shape=(128,128,3))

model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))

model.load_weights('outputs/weights_vgg16_use.h5')

conv_base.trainable=True
trainable=False
for layer in conv_base.layers:
    if layer.name=='block5_conv1':
        trainable=True
    layer.trainable=trainable
model.summary()

model.compile(optimizer=optimizers.adam(lr=1e-5),loss='binary_crossentropy',metrics=['acc'])

import catvsdogs.morph as mp

history = model.fit_generator(
      mp.train_flow,
      steps_per_epoch=32,
      epochs=50,
      validation_data=mp.test_flow,
      validation_steps=32,callbacks=[TensorBoard(log_dir='logs/4')])

这里写图片描述
这里写图片描述
上图蓝色为本文过程1的,红色为过程2的,正确率到达90%。本文只使用了2000+1000的数据,迭代次数较少,如果想打算更高的识别率,可以简单修改。

阅读更多
版权声明:本文为博主原创文章,转载请注明出处。 https://blog.csdn.net/nima1994/article/details/79952368
个人分类: python
所属专栏: 机器学习入门与放弃
上一篇keras分类猫狗数据(中)使用CNN分类模型
下一篇keras分类猫狗数据(番外篇)深度学习CNN连接SVM分类
想对作者说点什么? 我来说一句

深度学习解决(Kaggle)猫/狗分类问题

2015年06月03日 14.13MB 下载

猫狗识别数据

2018年03月10日 65B 下载

没有更多推荐了,返回首页

关闭
关闭