LSTM test: formula reproduction vs. the lstm function vs. lstmLayer — three approaches compared, identical results

lstmLayer is a sealed implementation, so the functions inside it cannot be inspected. Given the parameters of a network trained with lstmLayer, I want to verify that its output really is computed by the standard formulas, so that the network can be hand-ported to C.
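For reference, these are the per-step equations being verified, numbered to match the comments in the code below (σ is the logistic sigmoid and ⊙ the elementwise product, following the MathWorks `lstm` documentation linked below):

```latex
\begin{aligned}
i_t &= \sigma(W_i x_t + R_i h_{t-1} + b_i) &\quad&\text{(1) input gate}\\
f_t &= \sigma(W_f x_t + R_f h_{t-1} + b_f) &&\text{(2) forget gate}\\
o_t &= \sigma(W_o x_t + R_o h_{t-1} + b_o) &&\text{(3) output gate}\\
g_t &= \tanh(W_j x_t + R_j h_{t-1} + b_j) &&\text{(4) candidate}\\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t &&\text{(5) cell state}\\
h_t &= o_t \odot \tanh(c_t) &&\text{(6) hidden state}
\end{aligned}
```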

%% begin: 20240323  end: 20240326
%% Compare a hand-written LSTM, MATLAB's built-in lstm function, and the results of a network trained via MATLAB's app
%% https://www.mathworks.com/help/releases/R2019b/deeplearning/ref/dlarray.lstm.html#d117e136457
clc;close all;clear;
%% Parameter setup
numFeatures = 4;
numObservations = 1;
sequenceLength = 10;
numHiddenUnits = 30;
%% Data setup
X = randn(numFeatures,numObservations,sequenceLength); % input sample
%% Trained network parameters
H0 = zeros(numHiddenUnits,1); % initial hidden state H
C0 = zeros(numHiddenUnits,1); % initial cell state C
Wi = randn(numHiddenUnits,numFeatures); % input gate
Ri = randn(numHiddenUnits,numHiddenUnits);
bi = randn(numHiddenUnits,1);
Wf = randn(numHiddenUnits,numFeatures); % forget gate
Rf = randn(numHiddenUnits,numHiddenUnits);
bf = randn(numHiddenUnits,1);
Wj = randn(numHiddenUnits,numFeatures); % candidate (cell input)
Rj = randn(numHiddenUnits,numHiddenUnits);
bj = randn(numHiddenUnits,1);
Wo = randn(numHiddenUnits,numFeatures); % output gate
Ro = randn(numHiddenUnits,numHiddenUnits);
bo = randn(numHiddenUnits,1);
weights_all_lstm1 = [Wi;Wf;Wj;Wo];
recurrentWeights_all_lstm1 = [Ri;Rf;Rj;Ro];
bias_all_lstm1 = [bi;bf;bj;bo];
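For a C port, the key layout fact is the gate order in these stacked matrices: rows 1..H are the input gate, H+1..2H the forget gate, then candidate, then output. A minimal sketch of indexing into the stacked parameters (the helper name and row-major storage are my assumptions, not anything MATLAB prescribes):

```c
#include <stddef.h>

/* Gate order used by the stacked parameters above:
 * [input; forget; candidate; output], each block numHiddenUnits rows. */
enum lstm_gate { GATE_INPUT = 0, GATE_FORGET = 1, GATE_CANDIDATE = 2, GATE_OUTPUT = 3 };

/* Return a pointer to the first row of one gate's block inside a stacked,
 * row-major matrix whose rows hold row_len entries (row_len = numFeatures
 * for the input weights, numHiddenUnits for the recurrent weights, 1 for bias). */
static const float *lstm_gate_block(const float *stacked, int num_hidden,
                                    int row_len, enum lstm_gate gate)
{
    return stacked + (size_t)gate * (size_t)num_hidden * (size_t)row_len;
}
```

With numHiddenUnits = 30 and numFeatures = 4 as above, the forget-gate block of the input weights starts 30*4 floats into the stacked array.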

%% Method 1: hand-written LSTM in MATLAB
lstm_num = 0;

% sigmoid_zhaoqy is the author's logistic sigmoid helper: 1./(1+exp(-z))
for tt = 1 : sequenceLength
    X_temp = X(:, 1, tt);
    if tt == 1
        C1_t = zeros(numHiddenUnits, 1);
        H1_t = zeros(numHiddenUnits, 1);
        C1_t_matrix = C1_t; H1_t_matrix = H1_t;
    end
    in_gate1     = sigmoid_zhaoqy(Wi * X_temp + Ri * H1_t + bi); % equation (1): X_temp is one sample, a 4x1 vector; Wi is a 30x4 matrix; Ri is a 30x30 matrix
    forget_gate1 = sigmoid_zhaoqy(Wf * X_temp + Rf * H1_t + bf); % equation (2)
    out_gate1    = sigmoid_zhaoqy(Wo * X_temp + Ro * H1_t + bo); % equation (3)
    g_gate1      = tanh(Wj * X_temp + Rj * H1_t + bj);           % equation (4)
    C1_t = C1_t .* forget_gate1 + g_gate1 .* in_gate1;           % equation (5)
    H1_t = tanh(C1_t) .* out_gate1;                              % equation (6)
    C1_t_matrix = [C1_t_matrix C1_t]; % for inspection
    H1_t_matrix = [H1_t_matrix H1_t]; % output
end

%% Method 2: MATLAB's built-in lstm function (cellState = C1_t, hiddenState = H1_t)
dlX = dlarray(X,'CBT');
weights = dlarray(weights_all_lstm1,'CU');
recurrentWeights = dlarray(recurrentWeights_all_lstm1,'CU');
bias = dlarray(bias_all_lstm1,'C');
% Perform the LSTM calculation.
[dlY,hiddenState,cellState] = lstm(dlX,H0,C0,weights,recurrentWeights,bias);

%% Method 3: MATLAB's built-in lstmLayer (intermediate results differ from the first two methods until the weights are overwritten below)
X_Train = mat2cell(reshape(X,[4,10]),4,10);
YTrain = categorical(1);
layers = [
    sequenceInputLayer(numFeatures,"Name","input")
    lstmLayer(numHiddenUnits,"Name","lstm","HiddenState",H0,"CellState",C0,"OutputMode","last")
    fullyConnectedLayer(4,"Name","fc")
    softmaxLayer("Name","softmax")
    classificationLayer("Name","classification")];
options = trainingOptions('adam', ...
    'MaxEpochs',1, ...
    'GradientThreshold',1, ...
    'InitialLearnRate',0.005, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',125, ...
    'LearnRateDropFactor',0.2, ...
    'Verbose',0, ...
    'MiniBatchSize',6000, ...
    'Plots','training-progress');

net = trainNetwork(X_Train,YTrain,layers,options);
% Try overwriting the trained weights with our own values: net.Layers(2, 1).RecurrentWeights
% net = resetState(net); % added 20240328
modify_able_net = net.saveobj; % save the network as a struct
modify_able_net.Layers(2, 1).InputWeights = single(weights_all_lstm1); % input weights to overwrite
modify_able_net.Layers(2,1).RecurrentWeights = single(recurrentWeights_all_lstm1); % recurrent weights to overwrite
modify_able_net.Layers(2, 1).Bias = single(bias_all_lstm1);
Modified_net = net.loadobj(modify_able_net); % load the struct back into a network object
% Intermediate layer outputs
% features_lstm_layer1 = activations(Modified_net,X_Train,1); %
features_lstm_layer2 = activations(Modified_net,X_Train,2); % H1_t output
