系列博客是博主学习神经网络中相关的笔记和一些个人理解,仅为作者记录笔记之用,不免有很多细节不对之处。
MNIST识别率能否再提高
我的这一阶段目标是在学习完浅层BP神经网络的相关知识时,可以将手写字的识别率达到98%。在前面的几次实验中,MNIST手写字的识别率始终未达到98%以上,这样的结果是有些让人沮丧的。
今天进过艰苦奋斗,多次尝试之后终于将MNIST手写字的识别率提高到了98%以上,最高识别率达到98.39%。这次的实验是利用Matlab代码进行的,每次有放回的从训练数据中取 mini_batch_size 样本,下面是三组实验结果:
第一组:网络结构 [784,80,80,10], mini_batch_size = 100,max_iteration = 50000, eta = 1, lambda = 5,最高识别率达到98.34%,下面是测试结果曲线
第二组:网络结构 [784,100,100,10], mini_batch_size = 100,max_iteration = 50000, eta = 1, lambda = 5,最高识别率达到98.39%,下面是测试结果曲线
第三组:网络结构 [784,120,120,10], mini_batch_size = 100,max_iteration = 50000, eta = 1, lambda = 5,最高识别率达到98.37%,下面是测试结果曲线
上面挑选的是每组参数中最高识别率的结果,每次运行的结果稍有差异,但基本在30000次训练(大约相当于30000*100/50000 = 60个epoch)之后,能稳定达到98%的识别率。这个结果对我来说是十分满意的,完成了自己的学习目标。本节代码可以在这里下载到(没有积分的同学可以私信我)
以第二组参数,我对不同的隐层激活函数(Sigmoid函数和ReLU激活函数)、不同的输出层函数(Sigmoid输出函数和Softmax输出函数)和不同的代价函数(二次代价函数和交叉熵代价函数)的各种组合进行了一组测试,测试结果如下:
隐层激活函数 | 输出层函数 | 代价函数 | 校验数据最高识别率 | 测试数据识别率 |
---|---|---|---|---|
Sigmoid | Sigmoid | Quadratic | 97.30% | 96.90% |
Sigmoid | Sigmoid | Cross-entropy | 97.40% | 96.93% |
Sigmoid | Softmax | Cross-entropy | 97.23% | 97.02% |
ReLU | Sigmoid | Quadratic | 98.30% | 98.29% |
ReLU | Sigmoid | Cross-entropy | 98.38% | 98.09% |
ReLU | Softmax | Cross-entropy | 溢出 | 溢出 |
注:校验数据最高识别率,在模型训练时同时用 validation_data 进行的测试,其中最高的识别率;测试数据识别率,指利用 test_data 对最后训练的模型的测试的识别率。
总结
下面是对BP神经网络学习的一个知识点总结:
四个基本方程
δL=∇aC⊙σ′(zL) δ L = ∇ a C ⊙ σ ′ ( z L )δl=((wl+1)Tδl+1)⊙σ′(zl) δ l = ( ( w l + 1 ) T δ l + 1 ) ⊙ σ ′ ( z l )∂C∂blj=δlj ∂ C ∂ b j l = δ j l∂C∂wljk=al−1kδlj ∂ C ∂ w j k l = a k l − 1 δ j l这里 ⊙ ⊙ 表示对应元素相乘,实施算法如下:
随机梯度下降法基本更新公式
wk→w′k=wk−η∂C∂wk w k → w k ′ = w k − η ∂ C ∂ w kbl→b′l=bl−η∂C∂bl b l → b l ′ = b l − η ∂ C ∂ b l- 在实践中,通常将反向传播算法和诸如随机梯度下降这样的学习算法进行组合使用。特别地,给定一个大小为
m
m
的小批量数据,下面的算法在这个小批量数据的基础上应用一步梯度下降学习算法:
- ReLU激活函数:可以缓解梯度弥散的问题,提高学习速率
- 交叉熵代价函数:
C=−1n∑x[ylna+(1−y)ln(1−a)] C = − 1 n ∑ x [ y ln a + ( 1 − y ) ln ( 1 − a ) ]针对一个训练样本 x x 的输出误差 为δL=aL−y δ L = a L − y
- L2正则化
C=C0+λ2n∑ww2 C = C 0 + λ 2 n ∑ w w 2对于小批量数据 m m ,更新方程变为偏置的更新方程不发生变化
- L1正则化
C=C0+λn∑w|w| C = C 0 + λ n ∑ w | w |对于小批量数据 m m ,更新方程变为: