引言
长短期记忆网络(Long Short-Term Memory,简称 LSTM)是一种递归神经网络(Recurrent Neural Network,简称 RNN)的变体,它在处理序列数据时表现出色,并且被广泛应用于时间序列分析、自然语言处理和语音识别等领域。
数据集res为3505的数值矩阵,有三个特征值,二个输出
LSTM 简介
LSTM 是一种特殊的 RNN,它通过引入门控机制来解决传统 RNN 的长期依赖问题。LSTM 的核心思想是在每个时间步上维护一个细胞状态(cell state),并通过输入门、遗忘门和输出门等门控单元来控制细胞状态的信息流动。
- 输入门(input gate):决定当前时间步的输入信息对细胞状态的影响程度。
- 遗忘门(forgetgate):决定上一个时间步的细胞状态对当前时间步的影响程度。
- 细胞状态(cell state):用于在不同时间步之间传递和存储信息。
- 输出门(output gate):决定细胞状态对当前时间步的输出影响程度。
- 隐藏状态(hiddenstate):当前时间步的输出,也是下一个时间步的输入。
LSTM 通过这些门控单元的组合和调整,能够有效地处理长序列数据,并捕捉序列中的长期依赖关系。具体门的数学公式网上有,此处只讲应用。
LSTM多输入多输出matlab实现
%% 划分训练集和测试集
temp = randperm(350);
P_train = res(temp(1: 200), 1: 3)';
T_train = res(temp(1: 200), 4: 5)';
M = size(P_train, 2);
P_test = res(temp(201: end), 1: 3)';
T_test = res(temp(201: end), 4: 5)';
N = size(P_test, 2);
%% 数据归一化
[P_train, ps_input] = mapminmax(P_train, 0, 1);
P_test = mapminmax('apply', P_test, ps_input);
[t_train, ps_output] = mapminmax(T_train, 0, 1);
t_test = mapminmax('apply', T_test, ps_output);
%% 数据平铺
P_train = double(reshape(P_train, 3, 1, 1, M));
P_test = double(reshape(P_test , 3, 1, 1, N));
t_train = t_train';
t_test = t_test' ;
%% 数据格式转换
for i = 1 : M
p_train{i, 1} = P_train(:, :, 1, i);
end
for i = 1 : N
p_test{i, 1} = P_test( :, :, 1, i);
end
%% 创建模型
layers = [
sequenceInputLayer(3) % 建立输入层
lstmLayer(64, 'OutputMode', 'last') % LSTM层
reluLayer % Relu激活层
dropoutLayer(0.2) % 添加 Dropout 层,丢弃概率为 0.2
lstmLayer(64, 'OutputMode', 'last') % LSTM层
reluLayer % Relu激活层
dropoutLayer(0.2) % 添加 Dropout 层,丢弃概率为 0.2
fullyConnectedLayer(2) % 全连接层
regressionLayer]; % 回归层
%% 参数设置
options = trainingOptions('adam', ... % Adam 梯度下降算法
'MiniBatchSize', 30, ... % 批大小
'MaxEpochs', 1200, ... % 最大迭代次数
'InitialLearnRate', 1e-2, ... % 初始学习率为
'LearnRateSchedule', 'piecewise', ... % 学习率下降
'LearnRateDropFactor', 0.5, ... % 学习率下降因子
'LearnRateDropPeriod', 800, ... % 经过 800 次训练后 学习率为 0.01 * 0.5
'Shuffle', 'every-epoch', ... % 每次训练打乱数据集
'Plots', 'training-progress', ... % 画出曲线
'Verbose', false);
%% 训练模型
net = trainNetwork(p_train, t_train, layers, options);
%% 仿真预测
t_sim1 = predict(net, p_train);
t_sim2 = predict(net, p_test );
%% 查看网络结构
analyzeNetwork(net)
%% 数据反归一化
T_sim1 = mapminmax('reverse', t_sim1', ps_output)';
T_sim2 = mapminmax('reverse', t_sim2', ps_output)';
for i = 1: 2
%% 均方根误差
error1(i, :) = sqrt(sum((T_sim1(:,i)' - T_train(i,:)).^2) ./ M);
error2(i, :) = sqrt(sum((T_sim2(:,i)' - T_test(i,:) ).^2) ./ N);
%% 绘图
figure
subplot(2, 1, 1)
plot(1: M, T_train(i,:), 'r-*', 1: M, T_sim1(:,i), 'b-o', 'LineWidth', 1)
legend('True value', 'Predicted value')
xlabel('Predict the sample')
ylabel('Predict the outcome')
string = {'Comparison of training set prediction results'; ['RMSE=' num2str(error1(i,:))]};
title(string)
xlim([1, M])
grid
subplot(2, 1, 2)
plot(1: N, T_test(i,:), 'r-*', 1: N, T_sim2(:,i), 'b-o', 'LineWidth', 1)
legend('True value', 'Predicted value')
xlabel('Predict the sample')
ylabel('Predict the outcome')
string = {'Comparison of test set prediction results'; ['RMSE=' num2str(error2(i,:))]};
title(string)
xlim([1, N])
grid
%% 分割线
disp('**************************')
disp(['下列是输出', num2str(i)])
disp('**************************')
%% 相关指标计算
% R2
R1(i,:) = 1 - norm(T_train(i,:) - T_sim1(:,i)')^2 / norm(T_train(i,:) - mean(T_train(i,:)))^2;
R2(i,:) = 1 - norm(T_test(i,:) - T_sim2(:,i)')^2 / norm(T_test(i,:) - mean(T_test(i,:) ))^2;
disp(['The R2 of the training set data is:', num2str(R1(i,:))])
disp(['The R2 of the test set data is:', num2str(R2(i,:))])
% MAE
mae1(i,:) = sum(abs(T_sim1(:,i)' - T_train(i,:))) ./ M ;
mae2(i,:)= sum(abs(T_sim2(:,i)' - T_test(i,:) )) ./ N ;
disp(['The MAE of the training set data is:', num2str(mae1(i,:))])
disp(['The MAE of the test set data is:', num2str(mae2(i,:))])
% MBE
mbe1(i,:) = sum(T_sim1(:,i)' - T_train(i,:)) ./ M ;
mbe2(i,:) = sum(T_sim2(:,i)' - T_test(i,:)) ./ N ;
disp(['The MBE of the training set data is:', num2str(mbe1(i,:))])
disp(['The MBE of the test set data is:', num2str(mbe2(i,:))])
end


文章介绍了LSTM网络的基本原理和其在处理序列数据的优势,特别是在时间序列分析中的应用。然后,通过Matlab代码展示了如何对具有三个特征和两个输出的3505样本数据集进行LSTM模型的训练、测试,包括数据预处理、模型构建、训练参数设置、预测及性能评估(如RMSE、R2、MAE和MBE)。
1209

被折叠的 条评论
为什么被折叠?



