CNN in Practice (Convolutional Neural Networks Hands-On), Part 2

This article walks through the code for image classification on the CIFAR-10 dataset.

1. Splitting the data into training, validation, and test sets
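The split below assumes that x_train, y_train, x_test, and y_test are already in memory, as prepared in part one of this series. A minimal loading sketch, assuming the CIFAR-10 loader built into Keras (the rescaling to [0, 1] is a common preprocessing step, not something done in the original code):

from keras.datasets import cifar10

# load the CIFAR-10 images (50,000 training / 10,000 test, 32x32 RGB)
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# rescale pixel values to [0, 1]
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255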

import numpy as np
from keras.utils import to_categorical

# one-hot encode the labels
num_classes = len(np.unique(y_train))
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)

# break training set into training and validation sets
(x_train, x_valid) = x_train[5000:], x_train[:5000]
(y_train, y_valid) = y_train[5000:], y_train[:5000]

# print shape of training set
print('x_train shape:', x_train.shape)

# print number of training, validation, and test images
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
print(x_valid.shape[0], 'validation samples')
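With the standard CIFAR-10 split (50,000 training and 10,000 test images) and 5,000 images held out for validation, the printout should look roughly like this:

x_train shape: (45000, 32, 32, 3)
45000 train samples
10000 test samples
5000 validation samples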

2. Defining the network architecture

To learn how to design a network architecture, it helps to study some popular architectures (AlexNet, ResNet, ...).

This article uses a small AlexNet-style network.

from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

model = Sequential()
model.add(Conv2D(filters=16, kernel_size=2, padding='same', activation='relu', 
                        input_shape=(32, 32, 3)))
model.add(MaxPooling2D(pool_size=2))
model.add(Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=2))
model.add(Conv2D(filters=64, kernel_size=2, padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=2))
model.add(Dropout(0.3))
model.add(Flatten())
model.add(Dense(500, activation='relu'))
model.add(Dropout(0.4))
model.add(Dense(10, activation='softmax'))

model.summary()
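As a quick sanity check on the summary, the parameter counts can be worked out by hand. With padding='same', each convolution keeps the spatial size, and the three 2x2 max-pools shrink the 32x32 input to a 4x4x64 feature map before flattening:

# Conv2D(16):  (2*2*3  + 1) * 16   =     208
# Conv2D(32):  (2*2*16 + 1) * 32   =   2,080
# Conv2D(64):  (2*2*32 + 1) * 64   =   8,256
# Dense(500):  (4*4*64) * 500 + 500 = 512,500
# Dense(10):   500 * 10 + 10        =   5,010
# total trainable parameters        = 528,054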

A detailed walkthrough of this code and the summary output was already given in the previous chapter, so it is not repeated here.

Let's go straight to training the model.

3. Training the model

from keras.callbacks import ModelCheckpoint

# compile the model
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

# save the weights from the epoch with the best validation loss
checkpointer = ModelCheckpoint(filepath='model.weights.best.hdf5', verbose=1, save_best_only=True)

hist = model.fit(x_train, y_train, batch_size=32, epochs=100,
          validation_data=(x_valid, y_valid), callbacks=[checkpointer], 
          verbose=2, shuffle=True)
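Because save_best_only=True, the checkpoint file ends up holding only the weights from the epoch with the lowest val_loss. A minimal sketch of restoring those weights after training and evaluating on the test set, assuming the same model object and the preprocessed x_test/y_test from above:

# reload the weights from the epoch with the best validation loss
model.load_weights('model.weights.best.hdf5')

# evaluate on the held-out test set
score = model.evaluate(x_test, y_test, verbose=0)
print('Test accuracy:', score[1])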

Output (first 10 of the 100 epochs shown):

Train on 45000 samples, validate on 5000 samples
Epoch 1/100
 - 19s - loss: 1.6015 - accuracy: 0.4201 - val_loss: 1.4069 - val_accuracy: 0.4974

Epoch 00001: val_loss improved from inf to 1.40692, saving model to model.weights.best.hdf5
Epoch 2/100
 - 26s - loss: 1.2940 - accuracy: 0.5354 - val_loss: 1.1676 - val_accuracy: 0.5960

Epoch 00002: val_loss improved from 1.40692 to 1.16764, saving model to model.weights.best.hdf5
Epoch 3/100
 - 27s - loss: 1.1701 - accuracy: 0.5862 - val_loss: 1.1453 - val_accuracy: 0.5996

Epoch 00003: val_loss improved from 1.16764 to 1.14532, saving model to model.weights.best.hdf5
Epoch 4/100
 - 29s - loss: 1.0884 - accuracy: 0.6146 - val_loss: 1.0601 - val_accuracy: 0.6232

Epoch 00004: val_loss improved from 1.14532 to 1.06011, saving model to model.weights.best.hdf5
Epoch 5/100
 - 28s - loss: 1.0522 - accuracy: 0.6304 - val_loss: 1.0739 - val_accuracy: 0.6268

Epoch 00005: val_loss did not improve from 1.06011
Epoch 6/100
 - 29s - loss: 1.0232 - accuracy: 0.6437 - val_loss: 0.9668 - val_accuracy: 0.6626

Epoch 00006: val_loss improved from 1.06011 to 0.96683, saving model to model.weights.best.hdf5
Epoch 7/100
 - 30s - loss: 1.0141 - accuracy: 0.6500 - val_loss: 1.1182 - val_accuracy: 0.6114

Epoch 00007: val_loss did not improve from 0.96683
Epoch 8/100
 - 29s - loss: 1.0096 - accuracy: 0.6526 - val_loss: 0.9595 - val_accuracy: 0.6762

Epoch 00008: val_loss improved from 0.96683 to 0.95946, saving model to model.weights.best.hdf5
Epoch 9/100
 - 31s - loss: 1.0088 - accuracy: 0.6552 - val_loss: 0.9655 - val_accuracy: 0.6728

Epoch 00009: val_loss did not improve from 0.95946
Epoch 10/100
 - 31s - loss: 1.0156 - accuracy: 0.6583 - val_loss: 1.1407 - val_accuracy: 0.6328

Ideally, val_loss would keep falling and val_acc would keep rising, but in practice that is not what happens: at epoch 7, val_loss jumped back up to 1.1182, meaning the network did not improve in that epoch.

So we need some directions for tuning the results:

1. If val_loss keeps oscillating, lower the learning rate hyperparameter (see the sketch after this list).

2. If val_loss never goes down, the model may be too simple (underfitting); add more hidden layers or capacity.

3. If the training loss keeps falling while val_loss stops falling, the model is overfitting; introduce techniques such as dropout to counter it.
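A sketch of how these adjustments might look in Keras code. The callback settings below are illustrative values rather than tuned ones, and depending on the Keras version the learning-rate argument may be named lr instead of learning_rate:

from keras.optimizers import RMSprop
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping

# (1) re-compile with a lower learning rate (RMSprop defaults to 0.001)
model.compile(loss='categorical_crossentropy',
              optimizer=RMSprop(learning_rate=1e-4),
              metrics=['accuracy'])

# (1)/(3) callbacks: halve the learning rate when val_loss plateaus,
# and stop training once it has not improved for a while
callbacks = [
    ModelCheckpoint(filepath='model.weights.best.hdf5', verbose=1, save_best_only=True),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, verbose=1),
    EarlyStopping(monitor='val_loss', patience=10, verbose=1),
]

# (2) a deeper model would simply add more Conv2D/Dense layers to the
#     Sequential definition above before re-running this step
hist = model.fit(x_train, y_train, batch_size=32, epochs=100,
                 validation_data=(x_valid, y_valid),
                 callbacks=callbacks, verbose=2, shuffle=True)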

Reference: Chapter 3 of 《深度学习计算机视觉》 (Deep Learning for Computer Vision). I am still learning deep learning myself, so corrections are welcome; let's improve together.
