LSTM模型
LSTM(Long Short-Term Memory)也称长短时记忆结构, 它是传统RNN的变体, 与经典RNN相比能够有效捕捉长序列之间的语义关联, 缓解梯度消失或爆炸现象
LSTM核心结构
-
遗忘门
-
输入门
-
细胞状态
-
输出门
LSTM的内部结构图
- 结构解释图:
遗忘门
遗忘门部分结构图与计算公式
遗忘门结构分析
与传统RNN的内部结构计算非常相似, 首先将当前时间步输入x(t)与上一个时间步隐含状态h(t-1)拼接, 得到[x(t), h(t-1)], 然后通过一个全连接层做变换, 最后通过sigmoid函数进行激活得到f(t), 我们可以将f(t)看作是门值, 好比一扇门开合的大小程度, 门值都将作用在通过该扇门的张量。
遗忘门门值将作用的上一层的细胞状态上, 代表遗忘过去的多少信息, 又因为遗忘门门值是由x(t), h(t-1)计算得来的, 因此整个公式意味着根据当前时间步输入和上一个时间步隐含状态h(t-1)来决定遗忘多少上一层的细胞状态所携带的过往信息.
遗忘门内部结构过程演示
激活函数sigmiod的作用
- 用于帮助调节流经网络的值, sigmoid函数将值压缩在0和1之间.
输入门
输入门部分结构图与计算公式
输入门结构分析
我们看到输入门的计算公式有两个, 第一个就是产生输入门门值的公式, 它和遗忘门公式几乎相同, 区别只是在于它们之后要作用的目标上. 这个公式意味着输入信息有多少需要进行过滤. 输入门的第二个公式是与传统RNN的内部结构计算相同. 对于LSTM来讲, 它得到的是当前的细胞状态, 而不是像经典RNN一样得到的是隐含状态.
输入门内部结构过程演示
细胞更新状态
细胞状态更新图与计算公式
细胞状态更新分析
细胞更新的结构与计算公式非常容易理解, 这里没有全连接层, 只是将刚刚得到的遗忘门门值与上一个时间步得到的C(t-1)相乘, 再加上输入门门值与当前时间步得到的未更新C(t)相乘的结果. 最终得到更新后的C(t)作为下一个时间步输入的一部分. 整个细胞状态更新过程就是对遗忘门和输入门的应用.
细胞状态更新过程演示
输出门
输出门部分结构图与计算公式
输出门结构分析
输出门部分的公式也是两个, 第一个即是计算输出门的门值, 它和遗忘门,输入门计算方式相同. 第二个即是使用这个门值产生隐含状态h(t), 他将作用在更新后的细胞状态C(t)上, 并做tanh激活, 最终得到h(t)作为下一时间步输入的一部分. 整个输出门的过程, 就是为了产生隐含状态h(t).
输出门内部结构过程演示
举例
'''
Description: lstm举例
Autor: 365JHWZGo
Date: 2021-12-09 19:20:23
LastEditors: 365JHWZGo
LastEditTime: 2021-12-09 19:28:08
'''
import torch
import torch.nn as nn
torch.manual_seed(1)
TIME_STEP = 1
INPUT_SIZE = 5
HIDDEN_LAYER = 2
HIDDEN_SIZE = 6
BATCH_SIZE = 3
input_data = torch.randn(TIME_STEP,BATCH_SIZE,INPUT_SIZE)
h0 = torch.randn(HIDDEN_LAYER,BATCH_SIZE,HIDDEN_SIZE)
c0 = torch.randn(HIDDEN_LAYER,BATCH_SIZE,HIDDEN_SIZE)
lstm = nn.LSTM(INPUT_SIZE,HIDDEN_SIZE,HIDDEN_LAYER)
output,(h_,c_) = lstm(input_data,(h0,c0))
print(f'output:\t{output}')
print(f'h_:\t{h_}')
print(f'c_:\t{c_}')