cnn神经网络可以用于数据拟合吗_使用Keras搭建卷积神经网络进行手写识别的入门(包含代码解读)...

这篇博客介绍了如何利用Keras训练一个手写方程式识别模型,包括数据预处理、特征提取、构建CNN模型及模型训练。作者提供了一个数据集,包含手写数字和符号的图像,通过轮廓提取技术提取特征,再用这些特征训练一个CNN模型,最终达到98.46%的识别精度。文章还提供了相关代码示例和模型保存方法。
摘要由CSDN通过智能技术生成

本文是发在Medium上的一篇博客:《Handwritten Equation Solver using Convolutional Neural Network》。本文是原文的翻译。这篇文章主要教大家如何使用keras训练手写字符的识别,并保存训练好的模型到本地,以及未来如何调用保存到模型来预测。本文对一些不太明确的地方做了一点小注释。论文使用了cv2的包做预处理,并使用keras搭建了卷积网络。

dcc9a2cbee43596dbe47557234513560.png

主要目录如下

  • 介绍
  • 获取训练数据
  • 1.下载数据集
  • 2.抽取特征
  • 使用卷积神经网络训练数据
  • 1.构建卷积神经网络
  • 2.将模型拟合到数据
  • 测试模型

介绍

随着技术的进步,机器学习和深度学习在当今时代起着至关重要的作用。现在,机器学习和深度学习技术正被用于手写识别,机器人技术,人工智能以及更多领域。开发此类系统需要使用数据训练我们的机器,使其能够学习并进行必要的预测。本文介绍了一种手写方程求解器,通过手写数字和数学符号训练,使用卷积神经网络和一些图像处理技术,实现98.46%的精度。

获取训练数据

1.下载数据集

我们可以从这个链接下载数据集。解压缩zip文件。不同的文件夹包含不同数学符号的图像。为简单起见,我们在本次学习中仅仅使用0-9数字图像,以及“+”、“-”和“×”三个符号图像。在观察我们的数据集时,我们可以看到它对某些数字/符号有偏差,因为某些符号包含了12000个图像,但是其他符号只包含了3000个图像。要消除此偏差,请将每个文件夹中的图像数量减少到约 4000左右。

a81dbb4d57b7541e9e9e5b35a2eb8445.png

2.抽取特征

我们可以使用轮廓提取(contour extraction)技术来获得特征(注意,这里作者并没有明确说明轮廓提取的方法,其实源代码中主要使用cv2的包来获取图片的轮廓)。主要步骤包括:

  1. 反转图像然后将其转换为二进制图像,因为当对象为白色且周围为黑色时,轮廓提取会产生最佳结果。
  2. 要查找轮廓,请使用“findContour”功能。对于特征,我们使用’boundingRect’函数获得轮廓的边界矩形(边界矩形是包围整个轮廓的最小水平矩形)。
  3. 由于数据集中的每个图像只包含一个符号/数字,因此我们只需要最大尺寸的边界矩形。为此,我们计算每个轮廓的边界矩形的面积,并选择具有最大面积的矩形。
  4. 现在,将最大区域边界矩形的大小调整为28乘以28,然后压平,变成784乘以1。因此,现在将有784像素值或特征。现在,给它相应的标签(例如,对于0-9图像与其数字相同的标签,对于 - 指定标签10,对于+指定标签11,“×”指定标签12)。所以现在我们的数据集包含784个特征列和一个标签列。提取特征后,将数据保存为CSV文件。

原文没有任何代码,我从他的Github上找到了相关代码,供大家参考。主要代码包括:

def load_images_from_folder(folder): train_data=[] for filename in os.listdir(folder): img = cv2.imread(os.path.join(folder,filename),cv2.IMREAD_GRAYSCALE) img=~img if img is not None: # 这个方法主要是像素高于阈值时,给像素赋予新值,否则,赋予另外一种颜色,这个操作是为了让轮廓提取效果更好 ret,thresh=cv2.threshold(img,127,255,cv2.THRESH_BINARY)# 这个方法就是提取轮廓ret,ctrs,ret=cv2.findContours(thresh,cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_NONE) cnt=sorted(ctrs, key=lambda ctr: cv2.boundingRect(ctr)[0]) w=int(28) h=int(28) maxi=0 for c in cnt: x,y,w,h=cv2.boundingRect(c) maxi=max(w*h,maxi) if maxi==w*h: x_max=x y_max=y w_max=w h_max=h im_crop= thresh[y_max:y_max+h_max+10, x_max:x_max+w_max+10] im_resize = cv2.resize(im_crop,(28,28)) im_resize=np.reshape(im_resize,(784,1)) train_data.append(im_resize) return train_data

使用卷积神经网络训练数据

由于卷积神经网络在二维数据上工作,我们的数据集是785乘1的形式。因此,我们需要重塑它。 首先,将数据集中的标签列分配给变量y_train。 然后从数据集中删除标签列,然后将数据集重新变成28乘28.现在,我们的数据集已准备好用于CNN。

1.构建卷积神经网络

要制作CNN,请导入所有必需的库。

import pandas as pdimport numpy as npimport picklenp.random.seed(1212)import kerasfrom keras.models import Modelfrom keras.layers import *from keras import optimizersfrom keras.layers import Input, Densefrom keras.models import Sequentialfrom keras.layers import Densefrom keras.layers import Dropoutfrom keras.layers import Flattenfrom keras.layers.convolutional import Conv2Dfrom keras.layers.convolutional import MaxPooling2Dfrom keras.utils import np_utilsfrom keras import backend as KK.set_image_dim_ordering('th')from keras.utils.np_utils import to_categoricalfrom keras.models import model_from_json

使用“to_categorical”函数将y_train数据转换为分类数据。创建模型的代码如下:

model = Sequential()model.add(Conv2D(30, (5, 5), input_shape=(1 , 28, 28), activation='relu'))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Conv2D(15, (3, 3), activation='relu'))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Dropout(0.2))model.add(Flatten())model.add(Dense(128, activation='relu'))model.add(Dense(50, activation='relu'))model.add(Dense(13, activation='softmax'))# Compile modelmodel.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

2.将模型拟合到数据

要使CNN拟合数据,请使用以下代码行。

model.fit(np.array(l), cat, epochs=10, batch_size=200,shuffle=True,verbose=1)

训练我们的模型需要大约三个小时,准确率为98.46%。 经过训练,我们可以将我们的模型保存为json文件以备将来使用,这样我们就不必训练模型并每次等待三个小时。 为了保存我们的模型,我们可以使用以下代码行。

model_json = model.to_json()with open("model_final.json
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值