前言
借助MATLAB的deepNetworkDesigner搭一个简单的LSTM,数据集使用mnist手写数字识别数据集。
一、数据集
mnist数据集包括60000组训练数据和对应的标签,10000组测试数据和对应标签。每个数据都是一个28x28的矩阵,可以将其看做28x28像素的灰度图像(黑底白字)。而LSTM的输入应当是一个序列,我们可以把矩阵的每一行当做一帧,把图像分为28帧输入到LSTM。
数据集可以在我上传的资源里找到。
数据的格式是这样的:
XTrain,即训练图像,是一个60000x1的cell,cell的每一个元素是一个28x28的矩阵。矩阵的每一列为一帧。直接将矩阵以图片显示是这样的:
imshow(cell2mat(XTrain(8)))
这不是某希腊字母,而是手写数字3。我们希望按行输入,而MATLAB按列读取,因此我做了个转置。再转置一下就能看到正常的图像:
imshow(cell2mat(XTrain(8))')
标签的格式为:
可以直接通过categorical函数实现数值到categorical的转换,比如:
输入训练数据的方式不唯一,我用的只是其中一种,详情见MathWorks官网:trainNetwork
二、网络结构
使用一层128个隐藏节点的LSTM,一层全连接,输出使用softmax。网络的输入是一个序列,输出是标签,在MATLAB中,此网络可以这样描述:
layers = [ ...
sequenceInputLayer(inputSize) %sequence输入
lstmLayer(numHiddenUnits,'OutputMode','last') %lstm
fullyConnectedLayer(numClasses) %全连接
softmaxLayer %softmax
classificationLayer]; %label输出
三、测试程序
完整的测试程序如下:
clear
clc
%加载数据
load('.\mnist_data_mat\XTrain.mat')
load('.\mnist_data_mat\YTrain.mat')
load('.\mnist_data_mat\XTest.mat')
load('.\mnist_data_mat\YTest.mat')
%设置参数
inputSize = 28; %28个输入节点
numHiddenUnits = 128; %128个隐藏节点
numClasses = 10; %10种分类结果
layers = [ ...
sequenceInputLayer(inputSize) %sequence输入
lstmLayer(numHiddenUnits,'OutputMode','last') %lstm
fullyConnectedLayer(numClasses) %全连接
softmaxLayer %softmax
classificationLayer]; %label输出
options = trainingOptions('adam', ...
'ExecutionEnvironment','cpu', ...
'MaxEpochs',5, ...
'MiniBatchSize',60, ...
'GradientThreshold',1, ...
'Verbose',false, ...
'Plots','training-progress');
net=trainNetwork(XTrain,YTrain,layers, options); %训练
Y_pred = classify(net, XTest); %测试
accy = sum(Y_pred == YTest) / length(YTest); %计算准确度
准确度为97.73%
options里的参数可以修改一下,我用同样结构的网络不同的参数做出了98.74%的准确度,仍有提升空间。这里为了节省训练时间牺牲了一些精度。
训练好的网络也上传到了资源里。