基于AI的计算机视觉识别在Java项目中的使用 (五) —— 深度模型的训练调优(续)

看起来容易

不是说手写数字识别这种简单任务相当于深度学习入门的“Hello World”,用最简单深度模型都可以做出比较高的识别率吗,还需要花时间去做优化?

确实,一个基本的全连接模型可以做到97%的准确率,相对于抛硬币已经算非常好了。但是要知道97%的识别率是在国外的MNIST测试集上的结果,而不是在实际生产中得到的结果。这样的模型用到国内的应用场景中表现又会怎样?

另外一个问题,即使能达到97%,客户能满意吗?比如有一个手写数字表单,每个表单十个手写数字,十张表单就有一百个数字。每识别10张表单,平均就有1-3张表单的结果存在问题!这样的结果有人愿意掏钱吗?所以97%的准确率作为学术性结论也许还行,就算99%的准确率在生产应用中也会被骂,客户要的是无限接近100%。

就是这么一个看似简单的任务,要想做到无限接近100%的准确率,也许并不十分容易。另外你还需要自己从表单上把数字都抠出来,既要保持完整,又不能夹带干扰元素。很多看起来“简单”的事情,也许我们自己不实际面对根本想不到其中的难处,但无论如何这是我们追求的目标!

影响训练效果的因素

如果一个模型的实际应用效果不及预期,原因可能是多方面的。从我自己的遇到的问题来看主要有以下几个方面:

1、训练的样本不够,达不到深度学习训练的需要

2、训练样本不少,但类型比较单一,无法覆盖生产中的主要情况

3、训练数据存在干扰,比如数字4的样本里有9,或像9的样本。如果数量少问题也许不太大,如果比较多就会影响模型的判断

4、实际应用中输入的数据干扰多,引起误判

5、深度模型设计不好,容易过耦合。过拟合的结果就是,只对我们训练中给的情况识别效果好,实际应用中情况稍微变化一点就不行了,也就是泛化能力不够

前4个问题都是数据问题。针对第3个问题在我上一篇文章里已经给出了我自己的解决办法。

针对问题1、2,我们可以首先在在网上寻找类似问题的训练数据集,没准运气好能找到。实在找不到就只能自己造数据了。

自己造数据是一个费时费力的过程,这个过程中尽量用程序来做一些重复的工作;另外我们要尽可能利用好来自不易的数据,通过适当变化组合来形成新的训练数据(数据增强);还有就是不断收集用户测试中产生的数据来丰富训练样本,不断让模型去适应新的现实情况。

针对问题4,一方面需要想办法尽可能排除干扰,另一方面可以在训练样本中模拟实际情况增加一些干扰,增强模型判断的鲁棒性。

针对模型过拟合的问题,主要是模型设计的问题,需要对模型架构进行调整。

他山之石可以攻玉

针对模型的设计问题,像我们这样的以应用而非学术研究为导向的情况,最好先寻找已经研究过的模型设计论文。很多研究论文有过程、有数据、有结论,在不涉及知识产权纠纷的前提下要充分利用。毕竟我们的精力是有限的,做模型探究实验、比较、评估需要耗费大量的准备和训练时间。

这里推荐一个网站:https://www.kaggle.com

这是一个为开发商和数据科学家提供举办机器学习竞赛、托管数据库、编写和分享代码的平台。这里有不少关于各种实际问题的代码、模型、数据资源可供参考。

比如,从这里找到一篇关于MNIST的识别模型的研究文章

《How to choose CNN Architecture MNIST》

文章分别从卷积和池化层对的数量、各层卷积核的数量、Dense层的大小、Dropout层的丢弃比例、批量归一化、数据增强等角度做了测试和探讨,内容详实比较值得借鉴。 文章给的建议模型设计如下:

model = Sequential()

model.add(Conv2D(32, kernel_size = 3, activation='relu', input_shape = (28, 28, 1)))
model.add(BatchNormalization())
model.add(Conv2D(32, kernel_size = 3, activation='relu'))
model.add(BatchNormalization())
model.add(Conv2D(32, kernel_size = 5, strides=2, padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.4))

model.add(Conv2D(64, kernel_size = 3, activation='relu'))
model.add(BatchNormalization())
model.add(Conv2D(64, kernel_size = 3, activation='relu'))
model.add(BatchNormalization())
model.add(Conv2D(64, kernel_size = 5, strides=2, padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.4))

model.add(Conv2D(128, kernel_size = 4, activation='relu'))
model.add(BatchNormalization())
model.add(Flatten())
model.add(Dropout(0.4))
model.add(Dense(10, activation='softmax'))

# COMPILE WITH ADAM OPTIMIZER AND CROSS ENTROPY COST
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

这个识别模型我们可以直接使用,也可以根据我们的实际需求在这个基础上再优化迭代。包括模型中层的调整、参数调整、训练次数的调整。调试迭代过程中应该对数据、模型、训练结果做好记录备份。选取其中最好的作为生产用的模型。

本期到此为止。《基于AI的计算机视觉识别在Java项目中的使用》专题将按下列章节展开,欢迎关注我的个人公众号( TuXiang )和CSDN( TuXiang++ )。

一、《基于AI的计算机视觉识别在Java项目中的使用 —— 背景》

二、《基于AI的计算机视觉识别在Java项目中的使用 —— OpenCV的使用》

三、《基于AI的计算机视觉识别在Java项目中的使用 —— 搭建基于Docker的深度学习训练环境》

四、《基于AI的计算机视觉识别在Java项目中的使用 —— 准备深度学习训练数据》

五、《基于AI的计算机视觉识别在Java项目中的使用 —— 深度模型的训练调优》

六、《基于AI的计算机视觉识别在Java项目中的使用 —— 深度模型在Java环境中的部署》

 

ImageComparerUI——基于Java语言实现的相似图像识别,基于直方图比较算法。 import java.awt.BorderLayout; import java.awt.Color; import java.awt.Dimension; import java.awt.FlowLayout; import java.awt.Font; import java.awt.Graphics; import java.awt.Graphics2D; import java.awt.Image; import java.awt.MediaTracker; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.awt.image.BufferedImage; import java.io.File; import java.io.IOException; import javax.imageio.ImageIO; import javax.swing.JButton; import javax.swing.JComponent; import javax.swing.JFileChooser; import javax.swing.JFrame; import javax.swing.JPanel; public class ImageComparerUI extends JComponent implements ActionListener { /** * */ private static final long serialVersionUID = 1L; private JButton browseBtn; private JButton histogramBtn; private JButton compareBtn; private Dimension mySize; // image operator private MediaTracker tracker; private BufferedImage sourceImage; private BufferedImage candidateImage; private double simility; // command constants public final static String BROWSE_CMD = "Browse..."; public final static String HISTOGRAM_CMD = "Histogram Bins"; public final static String COMPARE_CMD = "Compare Result"; public ImageComparerUI() { JPanel btnPanel = new JPanel(); btnPanel.setLayout(new FlowLayout(FlowLayout.LEFT)); browseBtn = new JButton("Browse..."); histogramBtn = new JButton("Histogram Bins"); compareBtn = new JButton("Compare Result"); // buttons btnPanel.add(browseBtn); btnPanel.add(histogramBtn); btnPanel.add(compareBtn); // setup listener... browseBtn.addActionListener(this); histogramBtn.addActionListener(this); compareBtn.addActionListener(this); mySize = new Dimension(620, 500); JFrame demoUI = new JFrame("Similiar Image Finder"); demoUI.getContentPane().setLayout(new BorderLayout()); demoUI.getContentPane().add(this, BorderLayout.CENTER); demoUI.getContentPane().add(btnPanel, BorderLayout.SOUTH); demoUI.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); demoUI.pack(); demoUI.setVisible(true); } public void paint(Graphics g) { Graphics2D g2 = (Graphics2D) g; if(sourceImage != null) { Image scaledImage = sourceImage.getScaledInstance(300, 300, Image.SCALE_FAST); g2.drawImage(scaledImage, 0, 0, 300, 300, null); } if(candidateImage != null) { Image scaledImage = candidateImage.getScaledInstance(300, 330, Image.SCALE_FAST); g2.drawImage(scaledImage, 310, 0, 300, 300, null); } // display compare result info here Font myFont = new Font("Serif", Font.BOLD, 16); g2.setFont(myFont); g2.setPaint(Color.RED); g2.drawString("The degree of similarity : " + simility, 50, 350); } public void actionPerformed(ActionEvent e) { if(BROWSE_CMD.equals(e.getActionCommand())) { JFileChooser chooser = new JFileChooser(); chooser.showOpenDialog(null); File f = chooser.getSelectedFile(); BufferedImage bImage = null; if(f == null) return; try { bImage = ImageIO.read(f); } catch (IOException e1) { e1.printStackTrace(); } tracker = new MediaTracker(this); tracker.addImage(bImage, 1); // blocked 10 seconds to load the image data try { if (!tracker.waitForID(1, 10000)) { System.out.println("Load error."); System.exit(1); }// end if } catch (InterruptedException ine) { ine.printStackTrace(); System.exit(1); } // end catch if(sourceImage == null) { sourceImage = bImage; }else if(candidateImage == null) { candidateImage = bImage; } else { sourceImage = null; candidateImage = null; }
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值