Keras实现LeNet识别mnist

转载自:https://www.jianshu.com/p/d39e9e75a410

Keras是一款特别友好的基于Python的深度学习库,甚至比Tensorflow还友好。关于Keras的介绍和配置,可以看我之前的文章Keras的介绍与配置,也可以直接查看官网中文文档

接下来我们要做被誉为机器学习届的Hello World的手写数字识别。真的掌握了这个,就已经把Keras掌握得七七八八了。剩下的就是算法方面的问题了。

我们知道,机器学习的工作,比起别的编程工作,有两个特别大的痛点。

一个是数据集的问题。因为收集数据集是一个很繁琐的活,尤其是某些更加复杂,无法自己生成的数据集。例如我要做一个汽车的分类问题。算法其实可以直接在很多多分类问题的算法基础上改,但是最大的问题是从哪里获得那么多汽车的图片。而且数据集这种东西,毋庸置疑是越全面越广泛越好的。数据量太少了就容易过拟合,或者说碰到一个新的情况就不知道怎么办了。

另一个则是训练的问题。这点最近也搞得我很头疼。因为模型复杂起来,训练就需要特别久。最可气的是,在模型一开始的时候,往往收敛比较缓慢,甚至出现震荡。这个时候你很难知道模型到底是不收敛还只是震荡。尤其是对苦逼的学生党,一跑就是几天,可能loss一直保持不变。

作为HelloWord,最关键在于简单。简单到很快就可以看到效果,给新手一种我也写出了东西的感觉。手写数字识别就是这样的东西。

手写数字识别,可以说是卷积神经网络(CNN)第一次发挥出作用,可以说是CNN的开端。当时诞生了经典的LeNet模型,可以对手写数字识别做到非常好的效果。并且,已经五脏俱全,现代CNN网络的基本组件都已经出现了。但是因为当时的计算机能力有限,加上像SVM等机器学习的方法在手写数字识别上面,也可以取得比较好的效果。CNN的锋芒就被掩盖了。直到AlexNet出现的时候,计算机运算速度大大增加,出现了多GPU运算和大数据,使得CNN开始大放异彩。近十几年又开始火了起来。因此,手写数字识别算是CNN解决的第一个,也是最简单的一个问题。

前面两个痛点在这个问题上面完全不存在。手写数字识别有一个非常经典的库,叫做mnist。已经成为了这个问题的非常经典的库了。里面有60000个28*28像素的灰度图作为训练集,还有10000个测试集。由于LeNet是最早的CNN,完全不用担心数据跑很久的问题,基本是跑一两个epoch就会收敛了。

mnist中的样例

下面介绍LeNet。LeNet的模型也很简单,具体结构如下图。两层卷积+池化层,两个连接层就没有了。至于什么是卷积层,什么是连接层,请看我关于卷积神经网络的介绍的文章(其实还没有写)。

LeNet模型

说了这么多,我们试着写一下吧。

首先引入我们将用到的数据包:

import numpy as np
import keras
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense, Activation, Conv2D, MaxPooling2D, Flatten
from keras.optimizers import Adam

在Keras中,像mnist这样经典的数据集,都是直接给了。我们只需要一行代码:

(X_train, y_train), (X_test, y_test) = mnist.load_data()

接着我们要对数据进行预处理。首先把X换成的数据换成图片对应的形式,接着对y进行独热编码。

X_train = X_train.reshape(-1, 28, 28, 1)  # normalize
X_test = X_test.reshape(-1, 28, 28, 1)      # normalize
X_train = X_train / 255
X_test = X_test / 255
y_train = np_utils.to_categorical(y_train, num_classes=10)
y_test = np_utils.to_categorical(y_test, num_classes=10)

reshape函数代表了把数据变成什么样的矩阵(张量)。就比如我们的数据集,就是60000*28*28*1的张量。前面的-1代表了自动推导。我们告诉它,这些数据是28*28的灰度图,它就可以自动推导出这里面有多少数据了。需要再次强调的是,如果是Theano作为后端,则变成了1*28*28

对数据进行归一化,即把数据变成(0,1)区间的小数。原来的数据是0~255的,这样不方便处理,收敛的效果也相对不好。而把数据归一化,则更加方便计算机处理,可以加快收敛速率。

对数字进行独热编码,原因就比较复杂了。首先,我们的数据很可能不像这里,是0-9,而有可能是,猫狗鼠。为了让计算机知道我们在做什么,我们就需要给他们附上数字。例如0代表猫,1代表狗,2代表鼠。但是,即使转化为数字表示后,上述数据也不能直接拿来用,因为分类器往往默认数据数据是连续的,并且是有序的。因此我们还需要对数据进行独热编码即 One-Hot 编码。把001代表0,010代表1,100代表2。这样,我们的分类器总算不会认为他们有什么关系了。同时也扩充了特征。

接着我们就开始建立我们的模型了。在Keras里面,模型分为Sequential和Functional,即贯序模型和函数式模型。贯序模型,顾名思义比较适合一条道走到黑的模型,而函数式模型则比较厉害,可以胜任任何更加复杂的模型。鉴于我们的LeNet就是典型的一条道走到黑的模型,所以我们就使用Sequential就行了。

接下来就是代码了。代码很容易,看着前面的LeNet的图,一个一个数据比对就知道代码是什么意思了(当然可能还需要一点英文水平)。

model = Sequential()
model.add(Conv2D(input_shape=(28, 28, 1), kernel_size=(5, 5), filters=20, activation='relu'))
model.add(MaxPooling2D(pool_size=(2,2), strides=2, padding='same'))

model.add(Conv2D(kernel_size=(5, 5), filters=50,  activation='relu', padding='same'))
model.add(MaxPooling2D(pool_size=(2,2), strides=2, padding='same'))

model.add(Flatten())
model.add(Dense(500, activation='relu'))
model.add(Dense(10, activation='softmax'))
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

这就是为什么说Keras简单易用了。我们完全是按照上面的图一层一层地接起来的,就好像接水管一样。可以把心思都放在模型的搭建上面,而不管代码的其他问题。

这里的padding表示补0策略,有“valid”和“same” 。“valid”代表只进行有效的卷积,即对边界数据不处理。“same”代表保留边界处的卷积结果,通常会导致输出shape与输入shape相同。默认是valid,但是考虑到几层卷积下来,可能卷积连接有问题,一般会用same。

至于激活函数,当然是用经典的relu。优化器用的rmsprop,学习率是默认的学习率。优化器等我以后有机会再讲。学习率则是梯度下降的速率。通常来说,如果学习率比较高,则有可能产生震荡,而学习率比较低可能收敛比较慢。通常我们采用的是除三法,即从0.01开始,尝试0.003,0.001……直到找到收敛比较快的点。

我们的水管接好了,接下来就是通水了。

print('Training')
model.fit(X_train, y_train, epochs=2, batch_size=32)

print('\nTesting')
loss, accuracy = model.evaluate(X_test, y_test)

print('\ntest loss: ', loss)
print('\ntest accuracy: ', accuracy)

这里的batch_size是32。因为我们一次遍历所有数据算一次损失函数,开销是比较大的。因此我们把数据分成若干份,按批次更新参数,然后跑完一个epoch在合起来做一次。batch_size也是很玄学的东西,大了开销大,小了不准确。

我们只跑了2个epoch。训练的效果就可以达到98%左右了。当然我们可以多跑几个epoch,效果肯定更好。

至此,我们的CNN就已经入门了。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值