看起来容易
不是说手写数字识别这种简单任务相当于深度学习入门的“Hello World”,用最简单深度模型都可以做出比较高的识别率吗,还需要花时间去做优化?
确实,一个基本的全连接模型可以做到97%的准确率,相对于抛硬币已经算非常好了。但是要知道97%的识别率是在国外的MNIST测试集上的结果,而不是在实际生产中得到的结果。这样的模型用到国内的应用场景中表现又会怎样?
另外一个问题,即使能达到97%,客户能满意吗?比如有一个手写数字表单,每个表单十个手写数字,十张表单就有一百个数字。每识别10张表单,平均就有1-3张表单的结果存在问题!这样的结果有人愿意掏钱吗?所以97%的准确率作为学术性结论也许还行,就算99%的准确率在生产应用中也会被骂,客户要的是无限接近100%。
就是这么一个看似简单的任务,要想做到无限接近100%的准确率,也许并不十分容易。另外你还需要自己从表单上把数字都抠出来,既要保持完整,又不能夹带干扰元素。很多看起来“简单”的事情,也许我们自己不实际面对根本想不到其中的难处,但无论如何这是我们追求的目标!
影响训练效果的因素
如果一个模型的实际应用效果不及预期,原因可能是多方面的。从我自己的遇到的问题来看主要有以下几个方面:
1、训练的样本不够,达不到深度学习训练的需要
2、训练样本不少,但类型比较单一,无法覆盖生产中的主要情况
3、训练数据存在干扰,比如数字4的样本里有9,或像9的样本。如果数量少问题也许不太大,如果比较多就会影响模型的判断
4、实际应用中输入的数据干扰多,引起误判
5、深度模型设计不好,容易过耦合。过拟合的结果就是,只对我们训练中给的情况识别效果好,实际应用中情况稍微变化一点就不行了,也就是泛化能力不够
前4个问题都是数据问题。针对第3个问题在我上一篇文章里已经给出了我自己的解决办法。
针对问题1、2,我们可以首先在在网上寻找类似问题的训练数据集,没准运气好能找到。实在找不到就只能自己造数据了。
自己造数据是一个费时费力的过程,这个过程中尽量用程序来做一些重复的工作;另外我们要尽可能利用好来自不易的数据,通过适当变化组合来形成新的训练数据(数据增强);还有就是不断收集用户测试中产生的数据来丰富训练样本,不断让模型去适应新的现实情况。
针对问题4,一方面需要想办法尽可能排除干扰,另一方面可以在训练样本中模拟实际情况增加一些干扰,增强模型判断的鲁棒性。
针对模型过拟合的问题,主要是模型设计的问题,需要对模型架构进行调整。
他山之石可以攻玉
针对模型的设计问题,像我们这样的以应用而非学术研究为导向的情况,最好先寻找已经研究过的模型设计论文。很多研究论文有过程、有数据、有结论,在不涉及知识产权纠纷的前提下要充分利用。毕竟我们的精力是有限的,做模型探究实验、比较、评估需要耗费大量的准备和训练时间。
这里推荐一个网站:https://www.kaggle.com
这是一个为开发商和数据科学家提供举办机器学习竞赛、托管数据库、编写和分享代码的平台。这里有不少关于各种实际问题的代码、模型、数据资源可供参考。
比如,从这里找到一篇关于MNIST的识别模型的研究文章
《How to choose CNN Architecture MNIST》
文章分别从卷积和池化层对的数量、各层卷积核的数量、Dense层的大小、Dropout层的丢弃比例、批量归一化、数据增强等角度做了测试和探讨,内容详实比较值得借鉴。 文章给的建议模型设计如下:
model = Sequential()
model.add(Conv2D(32, kernel_size = 3, activation='relu', input_shape = (28, 28, 1)))
model.add(BatchNormalization())
model.add(Conv2D(32, kernel_size = 3, activation='relu'))
model.add(BatchNormalization())
model.add(Conv2D(32, kernel_size = 5, strides=2, padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.4))
model.add(Conv2D(64, kernel_size = 3, activation='relu'))
model.add(BatchNormalization())
model.add(Conv2D(64, kernel_size = 3, activation='relu'))
model.add(BatchNormalization())
model.add(Conv2D(64, kernel_size = 5, strides=2, padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.4))
model.add(Conv2D(128, kernel_size = 4, activation='relu'))
model.add(BatchNormalization())
model.add(Flatten())
model.add(Dropout(0.4))
model.add(Dense(10, activation='softmax'))
# COMPILE WITH ADAM OPTIMIZER AND CROSS ENTROPY COST
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
这个识别模型我们可以直接使用,也可以根据我们的实际需求在这个基础上再优化迭代。包括模型中层的调整、参数调整、训练次数的调整。调试迭代过程中应该对数据、模型、训练结果做好记录备份。选取其中最好的作为生产用的模型。
本期到此为止。《基于AI的计算机视觉识别在Java项目中的使用》专题将按下列章节展开,欢迎关注我的个人公众号( TuXiang )和CSDN( TuXiang++ )。
一、《基于AI的计算机视觉识别在Java项目中的使用 —— 背景》
二、《基于AI的计算机视觉识别在Java项目中的使用 —— OpenCV的使用》
三、《基于AI的计算机视觉识别在Java项目中的使用 —— 搭建基于Docker的深度学习训练环境》
四、《基于AI的计算机视觉识别在Java项目中的使用 —— 准备深度学习训练数据》
五、《基于AI的计算机视觉识别在Java项目中的使用 —— 深度模型的训练调优》
六、《基于AI的计算机视觉识别在Java项目中的使用 —— 深度模型在Java环境中的部署》