基于MATLAB的神经网络进行手写体数字识别(含鼠绘GUI / 数据集:MNIST)

基本介绍

  • 软件:Matlab R2018b
  • 数据集:MNIST手写体数字数据集
  • 网络:自建简单网络

数据准备

MNIST数据集还挺有名的,这里就不过多介绍了。数据集本身读取格式官网有给,怎么转换成图片格式网上也有很多,这里不再赘述。
官网:http://yann.lecun.com/exdb/mnist/
训练集包含60000个示例,测试集包含10000个示例。
测试集的前5000个示例来自原始的NIST训练集。 最后的5000个来自原始的NIST测试集。 前5000个比后5000个更干净点,识别起来更容易。
当然为了方便使用MATLAB,这里给出程序缺省的数据集:
链接:https://pan.baidu.com/s/1VItI8MdUa-oBhWjKUUB72w
提取码:tgv9
CSDN地址:https://download.csdn.net/download/garker/12413315
每一个数字都包含1000张图片,每张图片大小均为28×28×1,1代表单通道,即灰度图。
在这里插入图片描述

神经网络组建

因为数据集本身特征并不多,所以不需要动用常用的神经网络,这里给出一个官方的结构形式。一共有15层。
在这里插入图片描述
在这里插入图片描述
这里可以看出,三层卷积,三层归一化,是相当简单的CNN网络结构了,可以当作CNN结构的入门学习好好钻研学习。
在MATLAB中的建构代码如下:

layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

这其中,各层的参数如下:
convolution2dLayer

参数含义
FilterSize3,3卷积核尺寸
NumFilter8卷积核数量
Padding‘same’new_height = new_width = W / S (结果向上取整)
(W×W的输入矩阵,F×F的卷积核,步长为S=1)

BatchNormalizationLayer
归一化层采用默认数据

maxPooling2dLayer

参数含义
PoolSize2,2池化尺寸
Stride2,2步长

fullyConnectedLayer
全连接层输出为10(0-9共10个数字)

训练神经网络

imds = imageDatastore('train_dataset', ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
%导入数据

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.8,'randomize');
%分割数据集与测试集

options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'MaxEpochs',5, ...
    'Shuffle','every-epoch', ...
    'ValidationData',imdsValidation, ...
    'ValidationFrequency',30, ...
    'Verbose',false, ...
    'Plots','training-progress');
%设置训练参数

net = trainNetwork(imdsTrain,layers,options);
%训练神经网络

在这里插入图片描述
这里可以看出来基本上第三个世代就已经训练差不多了,最后的accuracy也能达到99.80%

测试数据集

YPred = classify(net,imdsValidation);
YValidation = imdsValidation.Labels;

accuracy = sum(YPred == YValidation)/numel(YValidation);

figure;
perm = randperm(10000,20);
for i = 1:20
    subplot(4,5,i);
    s = classify(net,imread(imds.Files{perm(i)}));
    imshow(imds.Files{perm(i)});title(string(s));
end

随机挑出来20个看看效果,没什么大问题:
在这里插入图片描述

鼠绘输入识别的GUI

GUI的代码编写不算难,直接回调函数里面编写也比较方便。这里着重讲一下鼠绘的问题,网上查了很多资料也踩了不少坑,这里按处理顺序把比较坑的细节都放一下:

鼠绘区域

在这里插入图片描述
红色区域里面只有axes1是有实际作用的,为了美观我把X、Y轴颜色改成了背景的灰色以达到隐藏的效果。此外,还需要把X、Y轴XLimMode、YLimMode设置为manual,其主要作用是锁住它们,不然在鼠绘的时候每一笔都会飘。
在这里插入图片描述
此外,对该区域的鼠绘效果显示代码如下:

figure1_WindowButtonDownFcn

unction figure1_WindowButtonDownFcn(hObject, eventdata, handles)
% hObject    handle to figure1 (see GCBO)
% eventdata  reserved - to be defined in a future version of MATLAB
% handles    structure with handles and user data (see GUIDATA)
global draw_enable;
global x;
global y;
draw_enable=1;
if draw_enable
    position=get(gca,'currentpoint');
    x(1)=position(1);
    y(1)=position(3);
end

figure1_WindowButtonMotionFcn

function figure1_WindowButtonMotionFcn(hObject, eventdata, handles)
% hObject    handle to figure1 (see GCBO)
% eventdata  reserved - to be defined in a future version of MATLAB
% handles    structure with handles and user data (see GUIDATA)
global draw_enable;
global x;
global y;
if draw_enable
    position=get(gca,'currentpoint');
    x(2)=position(1);
    y(2)=position(3);
    h1 = line(x,y,'EraseMode','xor','LineWidth',5,'color','black');
    x(1)=x(2);
    y(1)=y(2);
end

figure1_WindowButtonUpFcn

function figure1_WindowButtonUpFcn(hObject, eventdata, handles)
% hObject    handle to figure1 (see GCBO)
% eventdata  reserved - to be defined in a future version of MATLAB
% handles    structure with handles and user data (see GUIDATA)
global draw_enable
draw_enable=0;

特别特别需要注意的是,这三个回调函数都是在整个GUI默认的整体面板上来的,也就是figure1。具体找到这个回调函数的如下图所示:
在这里插入图片描述
没错,就是点击GUI编辑面板空白区域!

识别

识别按钮的回调函数很简单这里就不赘述了,需要特别提醒的是:
从绘制区域直接得到的并不是可直接使用图像数据,这里直接保存到默认目录一份正好也做备份用;
再者,保存好的图像的手写数据部分是深色的,背景部分是浅色的,这与我们之前的训练数据是不符的,直接用来识别肯定不会出现正确的答案,所以把这个数据读取之后再取反色,部分代码如下:

h=getframe(handles.axes1);
imwrite(h.cdata,'output.jpg','jpg');
img = imread('output.jpg');
img = imresize(img,[28,28]);
img = rgb2gray(img);
img = 255 - img; %取反色

在这里插入图片描述

  • 62
    点赞
  • 434
    收藏
    觉得还不错? 一键收藏
  • 42
    评论
实现OMP算法和稀疏字典学习需要使用MATLAB中的一些工具箱,比如Sparse Coding Toolbox和Image Processing Toolbox等。下面是一个基本的实现步骤: 1. 加载MNIST手写数字数据集,将每个数字图像转换为向量形式。 2. 随机生成一个字典,并将其归一化。 3. 对每个图像向量进行稀疏编码,使用OMP算法求解。 4. 利用稀疏编码结果和字典重构图像。 5. 对重构的图像进行分类,比较分类结果和真实标签的差异。 具体实现代码如下: ```matlab % 加载MNIST手写数字数据集 load mnist.mat % 将每个数字图像转换为向量形式 X = double(reshape(mnist.train_X, [28*28, 60000])); Y = mnist.train_labels; % 随机生成一个字典 dict_size = 100; D = randn(28*28, dict_size); D = normc(D); % 设置OMP算法参数 k = 10; tol = 1e-6; % 对每个图像向量进行稀疏编码,使用OMP算法求解 A = zeros(dict_size, size(X,2)); for i = 1:size(X,2) [a, ~, ~] = OMP(D, X(:,i), k, tol); A(:,i) = a; end % 利用稀疏编码结果和字典重构图像 X_recon = D * A; % 对重构的图像进行分类,比较分类结果和真实标签的差异 Y_recon = knnsearch(X_recon', X', 'K', 1); % 计算分类准确率 acc = sum(Y_recon' == Y) / length(Y); disp(['Classification accuracy: ' num2str(acc)]); ``` 其中OMP算法的实现可以使用MATLAB自带的`OMP`函数,也可以自己实现。稀疏编码结果的重构可以使用线性组合的方法,即$X_{recon} = DA$。分类使用了KNN算法,可以使用`knnsearch`函数实现。
评论 42
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值