Matlab这几年在人工智能这块儿也越做越好了,最近为了熟悉matlab如何搭建神经网络,自己做了一个手写体识别实验,记录一下。
实验任务非常简单,网络搭的也非常随意,不合理的地方也懒得改,旨在走通matlab搭建神经网络的流程。
首先,数据集为MNIST数据集
我已经把数据按类别分好,分为train和test,底下又都有十个子文件夹存放手写体图像。
网络训练代码如下:
clear;close all;clc;
%% 数据读取、增强
%读取训练集
path_train = 'D:\work\过期文件\手写体识别\MNIST\train'; %训练集路径
folders_train = fullfile(path_train,{'0' '1' '2' '3' '4' '5' '6' '7' '8' '9'}); %读取子目录
imds_train = imageDatastore(folders_train,'FileExtensions','.jpg',...
'LabelSource','foldernames'); %读取所有图像路径
[imdsTrain,imdsValidation] = splitEachLabel(imds_train,0.9,0.1); %拆分出验证集
%读取测试集
path_test = 'D:\work\过期文件\手写体识别\MNIST\test';
folders_test = fullfile(path_test,{'0' '1' '2' '3' '4' '5' '6' '7' '8' '9'});
imds_test = imageDatastore(folders_test,'FileExtensions','.jpg',...
'LabelSource','foldernames');
%图像增强
pixelRange = [-2 2]; %平移范围
scaleRange = [0.9 1.1]; %缩放范围
imageAugmenter = imageDataAugmenter( ...
'RandXTranslation',pixelRange, ...
'RandYTranslation',pixelRange, ...
'RandXScale',scaleRange, ...
'RandYScale',scaleRange); %定义图像增强器
augimdsTrain = augmentedImageDatastore([28,28],imds_train, ...
'DataAugmentation',imageAugmenter); %图像增强
%% 设计(或者读取)网络
layers = [
imageInputLayer([28 28 1],"Name","imageinput")
convolution2dLayer([5 5],32,"Name","conv_1","Padding","same","Stride",[2 2])
reluLayer("Name","relu_1")
batchNormalizationLayer("Name","batchnorm_1")
convolution2dLayer([3 3],32,"Name","conv_2","Padding","same")
reluLayer("Name","relu_2")
fullyConnectedLayer(512,"Name","fc_1")
batchNormalizationLayer("Name","batchnorm_2")
reluLayer("Name","relu_3")
fullyConnectedLayer(10,"Name","fc_2")
softmaxLayer("Name","softmax")
classificationLayer("Name","classoutput")];
% analyzeNetwork(layers) %分析网络
%% 训练网络
options = trainingOptions('sgdm', ...
'MiniBatchSize',512, ...
'MaxEpochs',1, ...
'InitialLearnRate',1e-2, ...
'Shuffle','every-epoch', ...
'ValidationData',imdsValidation, ...
'ValidationFrequency',3, ...
'Verbose',1, ...
'Plots','training-progress'); %设置训练策略
trainedNet = trainNetwork(augimdsTrain,layers,options); %训练
%% 测试模型
[YPred,probs] = classify(trainedNet,imds_test);
accuracy = mean(YPred == imds_test.Labels)
这里面,用到了一些函数,一些重要的用法我都写在其他博客里了,这儿只大致说一下有什么用
- fullfile:读取文件夹下的所有子文件夹
- imageDatastore:读取数据集,这个函数比较重要,后边很多函数都在调用它
- splitEachLabel:拆分imageDatastore读取的数据
- imageDataAugmenter:图像增强器,定义如何增强图像
- augmentedImageDatastore:进行图像增强
- 关于如何搭建网络,我写在了这里,看完后发现一些简单任务用不着敲一行代码
- analyzeNetwork:分析网络
- trainingOptions:定义训练策略,比如学习率,优化器之类的
- trainNetwork:训练网络
- classify:将网络用于分类
训练结果:
如果需要处理好的数据集,可以留下邮箱~
听取评论区建议,直接放上网盘,应该不会被吞贴吧2333
链接:https://pan.baidu.com/s/1htpPayQ2m0B3C5xOk3Quqg
提取码:pdb6
最后在说明一下,网络是随便搭的,不要用!!只是学习MATLAB用的
以上这些希望会对你有所帮助