Matlab深度学习实践之手写体识别(含详细注释)

Matlab这几年在人工智能这块儿也越做越好了,最近为了熟悉matlab如何搭建神经网络,自己做了一个手写体识别实验,记录一下。

实验任务非常简单,网络搭的也非常随意,不合理的地方也懒得改,旨在走通matlab搭建神经网络的流程。

首先,数据集为MNIST数据集
在这里插入图片描述
我已经把数据按类别分好,分为train和test,底下又都有十个子文件夹存放手写体图像。

网络训练代码如下:

clear;close all;clc;
%% 数据读取、增强
   %读取训练集
path_train = 'D:\work\过期文件\手写体识别\MNIST\train';  %训练集路径
folders_train = fullfile(path_train,{'0' '1' '2' '3' '4' '5' '6' '7' '8' '9'}); %读取子目录
imds_train = imageDatastore(folders_train,'FileExtensions','.jpg',...
                            'LabelSource','foldernames');                 %读取所有图像路径
[imdsTrain,imdsValidation] = splitEachLabel(imds_train,0.9,0.1);          %拆分出验证集
   %读取测试集
path_test = 'D:\work\过期文件\手写体识别\MNIST\test';
folders_test = fullfile(path_test,{'0' '1' '2' '3' '4' '5' '6' '7' '8' '9'});
imds_test = imageDatastore(folders_test,'FileExtensions','.jpg',...
                           'LabelSource','foldernames');
   %图像增强
pixelRange = [-2 2];   %平移范围
scaleRange = [0.9 1.1];  %缩放范围
imageAugmenter = imageDataAugmenter( ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange, ...
    'RandXScale',scaleRange, ...
    'RandYScale',scaleRange);     %定义图像增强器
augimdsTrain = augmentedImageDatastore([28,28],imds_train, ...
    'DataAugmentation',imageAugmenter);  %图像增强
%% 设计(或者读取)网络
layers = [
    imageInputLayer([28 28 1],"Name","imageinput")
    convolution2dLayer([5 5],32,"Name","conv_1","Padding","same","Stride",[2 2])
    reluLayer("Name","relu_1")
    batchNormalizationLayer("Name","batchnorm_1")
    convolution2dLayer([3 3],32,"Name","conv_2","Padding","same")
    reluLayer("Name","relu_2")
    fullyConnectedLayer(512,"Name","fc_1")
    batchNormalizationLayer("Name","batchnorm_2")
    reluLayer("Name","relu_3")
    fullyConnectedLayer(10,"Name","fc_2")
    softmaxLayer("Name","softmax")
    classificationLayer("Name","classoutput")];
% analyzeNetwork(layers)  %分析网络

%% 训练网络
options = trainingOptions('sgdm', ...
    'MiniBatchSize',512, ...
    'MaxEpochs',1, ...
    'InitialLearnRate',1e-2, ...
    'Shuffle','every-epoch', ...
    'ValidationData',imdsValidation, ...
    'ValidationFrequency',3, ...
    'Verbose',1, ...
    'Plots','training-progress');   %设置训练策略
trainedNet = trainNetwork(augimdsTrain,layers,options);  %训练

%% 测试模型
[YPred,probs] = classify(trainedNet,imds_test); 
accuracy = mean(YPred == imds_test.Labels)

这里面,用到了一些函数,一些重要的用法我都写在其他博客里了,这儿只大致说一下有什么用

  • fullfile:读取文件夹下的所有子文件夹
  • imageDatastore:读取数据集,这个函数比较重要,后边很多函数都在调用它
  • splitEachLabel:拆分imageDatastore读取的数据
  • imageDataAugmenter:图像增强器,定义如何增强图像
  • augmentedImageDatastore:进行图像增强
  • 关于如何搭建网络,我写在了这里,看完后发现一些简单任务用不着敲一行代码
  • analyzeNetwork:分析网络
  • trainingOptions:定义训练策略,比如学习率,优化器之类的
  • trainNetwork:训练网络
  • classify:将网络用于分类

训练结果:
在这里插入图片描述
如果需要处理好的数据集,可以留下邮箱~

听取评论区建议,直接放上网盘,应该不会被吞贴吧2333
链接:https://pan.baidu.com/s/1htpPayQ2m0B3C5xOk3Quqg
提取码:pdb6

最后在说明一下,网络是随便搭的,不要用!!只是学习MATLAB用的

以上这些希望会对你有所帮助

评论 52
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值