基于AI的计算机视觉识别在Java项目中的使用 (五) —— 深度模型的训练调优

换一个方式使用mnist

工作中的具体代码不适合公开发表,这里只使用minst这个常用的手写数字库,和自己业余时间的一些思考总结和代码来作为这篇文章讨论的基础。

要知道阿拉伯数字在不同的国家书写方式都有一些区别的。mnist作为一个手写数字库,它的书写不是以中国的习惯数字书写方式为基础的。所以要作为我们在实际应用中的训练数据会产生一些问题。

另外,每个人的书写习惯不尽相同,有些书写特别不规范,甚至会和其他数字产生混淆。那么这样的样本作为训练样本,在实际使用时就会让模型的识别产生偏差。像下面这些:

当然mnist库里大多数样本还是正常的。像下面这些:

所以不同于一些使用mnist来做深度学习入门介绍的 “Hello World!” 文章。本文会站在实际应用的角度,在使用它来训练深度模型的过程中,会加入一点自己的数据清洗方法。

关于如何数据清洗这些的基本思路,在我上一篇文章中有所提及。下面我会以代码来具体说明。

读取数据并创建模型

首先读取mnist数据

# 加载 mnist 训练数据
from keras.datasets import mnist

# 加载 mnist 训练数据方法
def loadData():
    #加载数据
    (x_train, y_train_label), (x_test, y_test_label) = mnist.load_data()
    
    # 将原始图片数据按记录总数、图片宽度、长度、黑色深度进行维度转换
    img_x, img_y = x_train.shape[1], x_train.shape[2]
    x_train = x_train.reshape(x_train.shape[0], img_x, img_y, 1)
    x_test = x_test.reshape(x_test.shape[0], img_x, img_y, 1)
    
    # 将原图片像素值由 0 - 255 的整形转换为 0 - 1 之间的浮点型数字
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255
    
    # 将0-9的数字标签转为one-hot编码用于训练
    y_train = np_utils.to_categorical(y_train_label, num_classes=10)
    y_test = np_utils.to_categorical(y_test_label, num_classes=10)
        
    # 简单浏览样本原始图片和对应标签是否正确
    showSample(x_train, y_train_label, 100, 10, 10)
    
    # 训练样本, 训练标签one-hot编码, 训练样本标签, 验证样本, 验证标签one-hot编码,图像宽度, 图像长度
    return x_train, y_train, y_train_label, x_test, y_test, img_x, img_y

x_train, y_train, y_train_label, x_test, y_test, img_x, img_y = loadData()

用keras定义卷积神经网络

from keras.layers import Conv2D, MaxPool2D, SpatialDropout2D
from keras.layers import Dense, Flatten

# 方法定义一个基本卷积神经网络
def createModel():
    model = Sequential()
    model.add(Conv2D(32, kernel_size=(5,5), activation='relu', input_shape=(img_x, img_y, 1)))
    model.add(MaxPool2D(pool_size=(2,2), strides=(2,2)))
    model.add(Conv2D(64, kernel_size=(5,5), activation='relu'))
    model.add(MaxPool2D(pool_size=(2,2), strides=(2,2)))
    model.add(Flatten())
    model.add(Dense(1000, activation='relu'))
    model.add(Dense(10, activation='softmax'))
    return model

通过各种变形对训练数据进行数据增强

对mnist中的训练数据进行转换,生成多样化的训练数据样本

from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()

# 训练数据变形转换
def data_transform(x_train, y_train, y_train_label):
    # 定义对原始样本进行随机小角度旋转、左右位移和缩放方法
    resize_and_rotation = tf.keras.Sequential([
        # 顺时针、逆时针旋转
        tf.keras.layers.RandomRotation(factor=(-0.01, 0.01)), 
        # 左右、上下位移
        tf.keras.layers.RandomTranslation(
        (-0.25, 0.25), (-0.25, 0.25), fill_mode='constant',
            interpolation='bilinear', seed=None, fill_value=0.0
        ),
        # 缩放
        tf.keras.layers.RandomZoom(
            (-0.1, 0.6), fill_mode='constant',
            interpolation='bilinear', seed=None, fill_value=0.0
        ) 
    ])

    # 调用方法对训练集进行变换操作
    x_train_trans = resize_and_rotation(x_train)

    # 将图形变换后的EagerTensor型的训练数据变回到ndarray类型
    x_train_trans = x_train_trans.numpy() 
    
    # 将变换后的数据和变换前原始数据合并到一起进行训练
    x_train_combine = np.vstack((x_train_trans, x_train))
    y_train_combine = np.vstack((y_train, y_train))
    y_train_label_combine = np.hstack((y_train_label, y_train_label))
    
    return x_train_combine, y_train_combine, y_train_label_combine


x_train_combine, y_train_combine, y_train_label_combine = data_transform(x_train, y_train, y_train_label)

对比查看样本中随机变形的数字图片和原始数字图片

import random
randomnumber = random.randint(1,60000)
images = []
labels = []

for i in range(25):
    images.append(x_train_combine[i+randomnumber])
    labels.append(y_train_label_combine[i+randomnumber])
    images.append(x_train_combine[i+randomnumber+60000])
    labels.append(y_train_label_combine[i+randomnumber+60000])
    images.append(x_train_combine[i+randomnumber+120000])
    labels.append(y_train_label_combine[i+randomnumber+120000])

showSample(images, labels, 75, 5, 15)

在训练中清洗数据

使用训练过的预测模型,从给定的样本中寻找易混淆的样本所对应的索引

def findMisleadingDataIndexFromDataset(X, Y, probability_model):
    misleadingIdxes =  []
    batchSize = 10000
    count = X.shape[0]
    idx = 0
    
    # 分批次从样本中找到混淆项
    while idx * batchSize < count:
        startIdx = idx * batchSize
        endIdx = (idx + 1) * batchSize
        if endIdx > count:
            endIdx = count

        images = X[startIdx:endIdx]
        labels = Y[startIdx:endIdx]

        predictions = probability_model.predict(images)

        i = 0
        for result in predictions:
            if(np.argmax(result) != labels[i]):
                misleadingIdxes.append(startIdx + i)
            i = i+1

        idx = idx + 1            
            
        prograss = idx * batchSize * 100 /count
        prograss = min([int(prograss), 100])
        print("\r", end="")
        print("Finding Misleading Samples:{}% - found:{} ".format(prograss, len(misleadingIdxes)), "=" * (prograss // 2), end="")
        
    return misleadingIdxes

训练并在过程中逐渐去除样本中引起混淆的项, 使用更纯净的样本追加训练

def removeMisleadingSamplesAndTraining(model, x_train, y_train, y_label, x_test, y_test, threshold, training_batch):
    mnist_epochs = 1
    if(model == None):
        model = createModel()
        model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
        mnist_epochs = 2
    model.fit(x_train, y_train, batch_size=128, epochs=mnist_epochs)
    #评估模型
    score = model.evaluate(x_test, y_test)[1]
    print('acc', score)
    #使用训练后的模型来寻找存在的易混淆项
    probability_model = tf.keras.Sequential([model, tf.keras.layers.Softmax(name='result')])
    misleadingIdxes = findMisleadingDataIndexFromDataset(x_train, y_label, probability_model)
    # 如果找到的混淆项数量大于给定阈值,则继续删除混淆项并训练
    if(len(misleadingIdxes) > threshold or score < 0.99):
        showIndexSample(misleadingIdxes, x_train, y_label)
        #删除引起混淆的样本
        x_train = np.delete(x_train, misleadingIdxes, axis=0)
        y_train = np.delete(y_train, misleadingIdxes, axis=0)
        y_label = np.delete(y_label, misleadingIdxes, axis=None)
        print(x_train.shape)
        print(y_train.shape)
        print(y_label.shape)
        return removeMisleadingSamplesAndTraining(model, x_train, y_train, y_label, x_test, y_test, threshold, training_batch + 1)
    # 如果找到的混淆项数量小于等于给定阈值,则停止并返回
    else:
        return model, x_train, y_train, y_label

model, x_train, y_train, y_label = removeMisleadingSamplesAndTraining(None, x_train_combine, y_train_combine, y_train_label_combine, x_test, y_test, 100, 0)

在训练的过程中,我会设定一个阈值,当从当前剩余样本中找到的混淆项少于这个值时,停止过程并返回当前模型。每次我删除一部分混淆项后,会使用更纯净的数据来追加训练。让模型参数不断向着更正确的方向移动。实际测试中,可以看到经过9轮训练和淘汰,达到目标并退出训练。被淘汰的数量也从最开始的3418 逐步降低到94。

从训练被淘汰的样本中抽取的一部分样本中可以看到,很多都是容易造成混淆的项,另外也有一部分是比较正常的项被剔除。但因为我们增强后的样本比较多,所以即使一部分正常项被淘汰也不会有什么影响。

写在最后

总结一下,这里的数据清洗方式是基于这样一个思路:初步训练的模型对数据有最基本的认识,他所形成的判断模式是基于所有训练数据的最大公约数之上的。虽然不能很好地拟合各种特别情况,但正是这样的属性,让哪些有悖于公约数特性的样本被找出来,而这些样本也往往是异常样本。通过用更纯净的样本来追加训练,可以让模型的判断更倾向于正确的样本模式,逐步冲淡混淆样本早期对模型的负面影响。

在实际应用中,总体样本中的某些异常样本参与训练,会对实际应用中的判断造成干扰。而本文使用的剔除干扰项的方法在其他分类模型的应用中也可以使用。当然,我这里提供的代码只是一些初步的代码。实际使用中还有很多优化空间,比如可以用更优质的数据去引导上文所说的这个排除过程等等。

本期到此为止。《基于AI的计算机视觉识别在Java项目中的使用》专题将按下列章节展开,欢迎关注我的个人公众号( TuXiang )和CSDN( TuXiang++ )。

一、《基于AI的计算机视觉识别在Java项目中的使用 —— 背景》

二、《基于AI的计算机视觉识别在Java项目中的使用 —— OpenCV的使用》

三、《基于AI的计算机视觉识别在Java项目中的使用 —— 搭建基于Docker的深度学习训练环境》

四、《基于AI的计算机视觉识别在Java项目中的使用 —— 准备深度学习训练数据》

五、《基于AI的计算机视觉识别在Java项目中的使用 —— 深度模型的训练调优》

六、《基于AI的计算机视觉识别在Java项目中的使用 —— 深度模型在Java环境中的部署》

行数:257

字数:2367

主题:橙心

ImageComparerUI——基于Java语言实现的相似图像识别,基于直方图比较算法。 import java.awt.BorderLayout; import java.awt.Color; import java.awt.Dimension; import java.awt.FlowLayout; import java.awt.Font; import java.awt.Graphics; import java.awt.Graphics2D; import java.awt.Image; import java.awt.MediaTracker; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.awt.image.BufferedImage; import java.io.File; import java.io.IOException; import javax.imageio.ImageIO; import javax.swing.JButton; import javax.swing.JComponent; import javax.swing.JFileChooser; import javax.swing.JFrame; import javax.swing.JPanel; public class ImageComparerUI extends JComponent implements ActionListener { /** * */ private static final long serialVersionUID = 1L; private JButton browseBtn; private JButton histogramBtn; private JButton compareBtn; private Dimension mySize; // image operator private MediaTracker tracker; private BufferedImage sourceImage; private BufferedImage candidateImage; private double simility; // command constants public final static String BROWSE_CMD = "Browse..."; public final static String HISTOGRAM_CMD = "Histogram Bins"; public final static String COMPARE_CMD = "Compare Result"; public ImageComparerUI() { JPanel btnPanel = new JPanel(); btnPanel.setLayout(new FlowLayout(FlowLayout.LEFT)); browseBtn = new JButton("Browse..."); histogramBtn = new JButton("Histogram Bins"); compareBtn = new JButton("Compare Result"); // buttons btnPanel.add(browseBtn); btnPanel.add(histogramBtn); btnPanel.add(compareBtn); // setup listener... browseBtn.addActionListener(this); histogramBtn.addActionListener(this); compareBtn.addActionListener(this); mySize = new Dimension(620, 500); JFrame demoUI = new JFrame("Similiar Image Finder"); demoUI.getContentPane().setLayout(new BorderLayout()); demoUI.getContentPane().add(this, BorderLayout.CENTER); demoUI.getContentPane().add(btnPanel, BorderLayout.SOUTH); demoUI.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); demoUI.pack(); demoUI.setVisible(true); } public void paint(Graphics g) { Graphics2D g2 = (Graphics2D) g; if(sourceImage != null) { Image scaledImage = sourceImage.getScaledInstance(300, 300, Image.SCALE_FAST); g2.drawImage(scaledImage, 0, 0, 300, 300, null); } if(candidateImage != null) { Image scaledImage = candidateImage.getScaledInstance(300, 330, Image.SCALE_FAST); g2.drawImage(scaledImage, 310, 0, 300, 300, null); } // display compare result info here Font myFont = new Font("Serif", Font.BOLD, 16); g2.setFont(myFont); g2.setPaint(Color.RED); g2.drawString("The degree of similarity : " + simility, 50, 350); } public void actionPerformed(ActionEvent e) { if(BROWSE_CMD.equals(e.getActionCommand())) { JFileChooser chooser = new JFileChooser(); chooser.showOpenDialog(null); File f = chooser.getSelectedFile(); BufferedImage bImage = null; if(f == null) return; try { bImage = ImageIO.read(f); } catch (IOException e1) { e1.printStackTrace(); } tracker = new MediaTracker(this); tracker.addImage(bImage, 1); // blocked 10 seconds to load the image data try { if (!tracker.waitForID(1, 10000)) { System.out.println("Load error."); System.exit(1); }// end if } catch (InterruptedException ine) { ine.printStackTrace(); System.exit(1); } // end catch if(sourceImage == null) { sourceImage = bImage; }else if(candidateImage == null) { candidateImage = bImage; } else { sourceImage = null; candidateImage = null; }
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值