使用MATLAB的trainNetwork设计一个简单的LSTM神经网络


前言

借助MATLAB的deepNetworkDesigner搭一个简单的LSTM,数据集使用mnist手写数字识别数据集。


一、数据集

mnist数据集包括60000组训练数据和对应的标签,10000组测试数据和对应标签。每个数据都是一个28x28的矩阵,可以将其看做28x28像素的灰度图像(黑底白字)。而LSTM的输入应当是一个序列,我们可以把矩阵的每一行当做一帧,把图像分为28帧输入到LSTM。
数据集可以在我上传的资源里找到。

数据的格式是这样的:

在这里插入图片描述
XTrain,即训练图像,是一个60000x1的cell,cell的每一个元素是一个28x28的矩阵。矩阵的每一列为一帧。直接将矩阵以图片显示是这样的:

 imshow(cell2mat(XTrain(8)))

在这里插入图片描述
这不是某希腊字母,而是手写数字3。我们希望按行输入,而MATLAB按列读取,因此我做了个转置。再转置一下就能看到正常的图像:

 imshow(cell2mat(XTrain(8))')

在这里插入图片描述
标签的格式为:

在这里插入图片描述
可以直接通过categorical函数实现数值到categorical的转换,比如:

在这里插入图片描述

输入训练数据的方式不唯一,我用的只是其中一种,详情见MathWorks官网:trainNetwork

二、网络结构

使用一层128个隐藏节点的LSTM,一层全连接,输出使用softmax。网络的输入是一个序列,输出是标签,在MATLAB中,此网络可以这样描述:

layers = [ ...
    sequenceInputLayer(inputSize)                   %sequence输入
    lstmLayer(numHiddenUnits,'OutputMode','last')   %lstm
    fullyConnectedLayer(numClasses)                 %全连接
    softmaxLayer                                    %softmax
    classificationLayer];                           %label输出

三、测试程序

完整的测试程序如下:

clear
clc
%加载数据
load('.\mnist_data_mat\XTrain.mat')
load('.\mnist_data_mat\YTrain.mat')
load('.\mnist_data_mat\XTest.mat')
load('.\mnist_data_mat\YTest.mat')

%设置参数
inputSize = 28;         %28个输入节点
numHiddenUnits = 128;   %128个隐藏节点
numClasses = 10;        %10种分类结果

layers = [ ...
    sequenceInputLayer(inputSize)                   %sequence输入
    lstmLayer(numHiddenUnits,'OutputMode','last')   %lstm
    fullyConnectedLayer(numClasses)                 %全连接
    softmaxLayer                                    %softmax
    classificationLayer];                           %label输出

options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs',5, ...
    'MiniBatchSize',60, ...
    'GradientThreshold',1, ...
    'Verbose',false, ...
    'Plots','training-progress');

net=trainNetwork(XTrain,YTrain,layers, options);    %训练

Y_pred = classify(net, XTest);                      %测试
accy = sum(Y_pred == YTest) / length(YTest);        %计算准确度

准确度为97.73%
options里的参数可以修改一下,我用同样结构的网络不同的参数做出了98.74%的准确度,仍有提升空间。这里为了节省训练时间牺牲了一些精度。
训练好的网络也上传到了资源里。

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值