LSTM(长短期记忆神经网络)

一、什么是LSTM

长短时记忆网络(Long Short-Term Memory,LSTM)是一种循环神经网络(RNN)的变体,旨在解决传统RNN在处理长序列时的梯度消失和梯度爆炸问题。LSTM引入了一种特殊的存储单元和门控机制,以更有效地捕捉和处理序列数据中的长期依赖关系。

二、为什么更倾向于使用LSTM?

 最基础版本的RNN,我们可以看到,每一时刻的隐藏状态都不仅由该时刻的输入决定,还取决于上一时刻的隐藏层的值,如果一个句子很长,到句子末尾时,它将记不住这个句子的开头的内容详细内容

个例子来说明这个问题:

假设我们有一个简单的RNN用于生成文本,下面是一个包含长句子的例子:

"The cat sat on the mat. It was fluffy and had a long tail. In the corner of the room, a little mouse peeked out."

在训练RNN时,模型会逐个时间步处理每个单词。当模型处理到句子的末尾时,它的隐藏状态包含了整个句子的信息。然而,由于梯度消失的问题,模型可能在处理过程中忘记了句子开头的重要信息。

例如,在处理"peeked out"这个短语时,如果模型已经遗忘了"The cat sat on the mat."的信息,那么它可能无法正确理解"peeked out"的上下文,因为它缺乏前文的语境。

这就是为什么在处理长序列时,特别是在自然语言处理等任务中,更先进的模型如长短时记忆网络(LSTM)和门控循环单元(GRU)等被提出来,以解决RNN的梯度问题,更好地捕捉长期依赖关系。这些模型通过引入门控机制来控制信息的流动,从而提高了对长序列的建模能力。

三、LSTM的结构与普通RNN有何不同

LSTM是RNN的一种变体,更高级的RNN,那么它的本质还是一样的,还记得RNN的特点吗,可以有效的处理序列数据,当然LSTM也可以,还记得RNN是如何处理有效数据的吗,是不是每个时刻都会把隐藏层的值存下来,到下一时刻的时候再拿出来用,这样就保证了,每一时刻含有上一时刻的信息,如图,我们把存每一时刻信息的地方叫做Memory Cell,中文就是记忆细胞,可以这么理解。

LSTM和普通RNN的区别在于RNN什么信息它都存下来,因为它没有挑选的能力,而LSTM不一样,它会选择性的存储信息,因为它能力强,它有门控装置,它可以尽情的选择。如下图,普通RNN只有中间的Memory Cell用来存所有的信息,而从下图我们可以看到,LSTM多了三个Gate,也就是三个门,什么意思呢?在现实生活中,门就是用来控制进出的,门关上了,你就进不去房子了,门打开你就能进去,同理,这里的门是用来控制每一时刻信息记忆与遗忘的。

那么这三个门分别都是用来干什么的?

  1. Input Gate:中文是输入门,在每一时刻从输入层输入的信息会首先经过输入门,输入门的开关会决定这一时刻是否会有信息输入到Memory Cell。
  2. Output Gate:中文是输出门,每一时刻是否有信息从Memory Cell输出取决于这一道门。
  3. Forget Gate:中文是遗忘门,每一时刻Memory Cell里的值都会经历一个是否被遗忘的过程,就是由该门控制的,如果打卡,那么将会把Memory Cell里的值清除,也就是遗忘掉。

那么我们就可以总结出这个过程:先经过输入门,看是否有信息输入,再判断遗忘门是否选择遗忘Memory Cell里的信息,最后再经过输出门,判断是否将这一时刻的信息进行输出。

下面来看一下LSTM的内部结构把!!!

先来看一下这个符号:代表一个激活函数,LSTM里常用的激活函数有两个,一个是tanh,一个是sigmoid。

然后通过下面的图来说明四个输入分别是什么

                                    

首先解释一下,经过这个sigmod激活函数后,得到的 都是在0到1之间的数值,1表示该门完全打开,0表示该门完全关闭,

其中 是最为普通的输入,可以从上图中看到,是通过该时刻的输入 和上一时刻存在memory cell里的隐藏层信息 向量拼接,再与权重参数向量 点积,得到的值经过激活函数tanh最终会得到一个数值,也就是  ,注意只有 的激活函数是tanh,因为 是真正作为输入的,其他三个都是门控装置。

再来看  ,input gate的缩写i,所以也就是输入门的门控装置, 同样也是通过该时刻的输入  和上一时刻隐藏状态,也就是上一时刻存下来的信息 向量拼接,在与权重参数向量 点积(注意每个门的权重向量都不一样,这里的下标i代表input的意思,也就是输入门)。得到的值经过激活函数sigmoid的最终会得到一个0-1之间的一个数值,用来作为输入门的控制信号。

以此类推,就不详细讲解  了,分别是缩写forget和output的门控装置,原理与上述输入门的门控装置类似。

四、LSTM的训练方法

LSTM网络模型的训练方法与传统的RNN相似,都是采用反向传播算法。在反向传播算法中,我们需要计算损失函数对网络参数的梯度。但是由于LSTM网络模型中存在门控单元,导致梯度的计算比较复杂。为了解决这个问题,我们可以采用一种称之为“反向传播加权”的方法。

反向传播加权的核心思想是将门控单元的梯度乘以一个权重,从而使其对梯度的贡献更大。具体来说,我们可以将门控单元的输出与门控单元的输入相乘,从而得到一个权重,将其乘以门控单元的梯度即可。

下面进行举例: 

以LSTM中的遗忘门(forget gate)为例,它的输出(记为)是由sigmoid激活函数处理的,其输入包括当前时间步的输入(​)和前一时间步的隐藏状态(​)。

遗忘门的输出计算如下:

其中, 是遗忘门的权重和偏置参数,表示将隐藏状态和输入拼接在一起。

在反向传播时,我们需要计算相对于损失函数的遗忘门输出的梯度。如果我们直接采用传统的反向传播算法,计算遗忘门的梯度时会遇到梯度消失的问题。为了解决这个问题,可以采用反向传播加权的方法,即通过引入一个权重矩阵(记为 )来调整梯度。

其中,gradientgradient 表示损失函数对遗忘门输出的梯度。这个权重矩阵 可以根据具体的问题和实验来调整。

这个思想可以类似地应用到其他门控单元,比如输入门和输出门。总体来说,反向传播加权的方法是通过在门控单元的梯度计算中引入额外的权重来缓解梯度问题。

  • 33
    点赞
  • 46
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
LSTM长短期记忆神经网络是一种适用于序列数据的深度学习模型,常用于时间序列预测任务。下面介绍如何使用Matlab实现LSTM长短期记忆神经网络多变量时间序列预测。 1. 准备数据 首先,需要准备多变量时间序列数据,即多个变量随时间变化的数据。例如,可以使用Matlab自带的airline数据集作为示例数据。将数据集导入Matlab,然后将其转换为时间序列对象。 ```matlab data = readtable('airline.csv'); data = table2timetable(data); ``` 2. 数据预处理 接下来,需要对数据进行预处理,以便用于模型训练。首先,将数据集分为训练集和验证集。 ```matlab train_data = data(1:120,:); val_data = data(121:end,:); ``` 然后,对每个变量进行归一化处理,以使其值在0到1之间。 ```matlab data_normalized = normalize(data,'zscore'); ``` 最后,将数据序列转换为输入和输出序列。对于每个时间步,将前面的几个时间步作为输入,预测下一个时间步的输出。这里将前10个时间步作为输入,预测下一个时间步的输出。 ```matlab XTrain = []; YTrain = []; for i=1:110 XTrain(:,:,i) = data_normalized{i:i+9,:}; YTrain(i,:) = data_normalized{i+10,:}; end ``` 同样地,对验证集进行相同的操作。 ```matlab XVal = []; YVal = []; for i=1:14 XVal(:,:,i) = data_normalized{110+i:119+i,:}; YVal(i,:) = data_normalized{129+i,:}; end ``` 3. 构建LSTM模型 接下来,需要构建LSTM模型。这里使用Matlab自带的LSTM层和FullyConnected层构建模型。输入序列的长度为10,输出序列的长度为1。模型中包含两个LSTM层和两个FullyConnected层,每个LSTM层和FullyConnected层的节点数为64。 ```matlab numFeatures = size(XTrain,2); numResponses = size(YTrain,2); numHiddenUnits = 64; layers = [ ... sequenceInputLayer(numFeatures) lstmLayer(numHiddenUnits,'OutputMode','sequence') lstmLayer(numHiddenUnits,'OutputMode','last') fullyConnectedLayer(64) dropoutLayer(0.5) fullyConnectedLayer(numResponses) regressionLayer]; options = trainingOptions('adam', ... 'MaxEpochs',100, ... 'GradientThreshold',1, ... 'InitialLearnRate',0.005, ... 'LearnRateSchedule','piecewise', ... 'LearnRateDropFactor',0.2, ... 'LearnRateDropPeriod',20, ... 'ValidationData',{XVal,YVal}, ... 'ValidationFrequency',5, ... 'Plots','training-progress', ... 'Verbose',false); net = trainNetwork(XTrain,YTrain,layers,options); ``` 4. 模型预测 训练完成后,可以使用模型对测试集进行预测。首先将测试集数据归一化处理,然后将其转换为输入序列。 ```matlab data_test_normalized = normalize(data(121:end,:),'zscore'); XTest = []; for i=1:14 XTest(:,:,i) = data_test_normalized{i:i+9,:}; end ``` 最后,使用模型对测试集进行预测,并将预测结果反归一化处理。 ```matlab YPred = predict(net,XTest); YPred = YPred .* std(data{121:end,:}) + mean(data{121:end,:}); ``` 5. 结果可视化 最后,将模型预测结果与测试集真实值进行比较,以评估模型的预测性能。 ```matlab figure plot(data{121:end,:}) hold on plot(YPred,'.-') hold off legend(["Observed" "Predicted"]) ylabel("Passengers") title("Forecast") ``` 通过可视化结果,可以评估模型的预测性能。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值