最近在学神经网络,顺便做了一个mnist数据集的识别程序,采用cnn实现。许多内容采用的是matlab例程中的程序。
mnist数据集需要下载,解压后和下面的程序放在同一目录下,网址为:http://yann.lecun.com/exdb/mnist/.
多次测试基本都能达到99%的识别率,不排除小于99%的识别率。
下面是程序:
% 实现MNIST数据集的识别,在matlab2018运行正常。需要先下载数据集,
% 网址为:http://yann.lecun.com/exdb/mnist/。
clear all;
close all;
N_sample = 60000;
N_test=10000;
XTrain = zeros(28,28,1,N_sample);
YTrain=zeros(N_sample,1);
% Please dowload the MNIST data set from http://yann.lecun.com/exdb/mnist/
% and unzip.
fidimg1=fopen('train-images.idx3-ubyte','rb');
fidimg2=fopen('train-labels.idx1-ubyte','rb');
[img,count]=fread(fidimg1,16); % table head
[imgInd,count1]=fread(fidimg2,8); %table head
for k=1:N_sample
[im,~]=fread(fidimg1,[28,28]);
ind=fread(fidimg2,1);
XTrain(:,:,1,k)=im';
YTrain(k)=ind;
end
fclose(fidimg1);
fclose(fidimg2);
YTrain=categorical(YTrain);
XTest = zeros(28,28,1,N_test);
YTest=zeros(N_test,1);
fidimg1=fopen('t10k-images.idx3-ubyte','rb');
fidimg2=fopen('t10k-labels.idx1-ubyte','rb');
[img,count]=fread(fidimg1,16);
[imgInd,count1]=fread(fidimg2,8);
for k=1:N_test
[im,~]=fread(fidimg1,[28,28]);
ind=fread(fidimg2,1);
XTest(:,:,1,k)=im';
YTest(k)=ind;
end
fclose(fidimg1);
fclose(fidimg2);
YTest=categorical(YTest);
imageAugmenter = imageDataAugmenter( ...
'RandRotation',[-20,20], ...
'RandXTranslation',[-3 3], ...
'RandYTranslation',[-3 3])
imageSize = [28 28 1];
augimds = augmentedImageDatastore(imageSize,XTrain,YTrain,'DataAugmentation',imageAugmenter);
layers=[imageInputLayer([28 28 1],'Name','input')
convolution2dLayer(3,6,'Padding','same')
reluLayer
batchNormalizationLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,16)
reluLayer
batchNormalizationLayer
maxPooling2dLayer(2,'Stride',2)
fullyConnectedLayer(120,'name','f1')
reluLayer
fullyConnectedLayer(84,'name','f2')
reluLayer
fullyConnectedLayer(10,'name','f3')
softmaxLayer
classificationLayer];
options = trainingOptions('adam','MaxEpochs',25,'LearnRateSchedule' ,'piecewise','LearnRateDropPeriod',15,'LearnRateDropFactor' ,0.1);
tic;
net = trainNetwork(augimds,layers,options);
toc;
YPred = classify(net,XTest);
accuracy = sum(YTest==YPred)/numel(YTest)
在mathworks网站也刚上传了一个,两者基本一致。mathworks
下面是一个截图。
附注:
这几天用matlab的help里面DAGNetwork的例程中的net试了一下,20个epoch达到99.4%以上的准确率也是很容易的。网络结构如下: