LSTM是一种特殊的循环神经网络,用于处理序列数据中长期依赖关系。通过使用门控机制和记忆单元,缓解传统RNN梯度消失的问题,有效的捕捉时间序列中的长期模式。
一 传统RNN局限性
1.1梯度消失问题
长序列训练过程中,通过时间反向传播时,梯度可能会爆炸或者消失,原因之前讲残差网络的解释过,100层的每层梯度0.2,反向传播梯度相乘0.2的100次方梯度就消失了,要是梯度是2,就梯度爆炸了。
LSTM解决办法:通过细胞状态的线性传播和门控机制的自适应调节,为梯度提供了更稳定的传播路径。
1.2短期记忆瓶颈
传统RNN的隐藏状态,就是包含历史信息的,主要依赖的是最近的输入,而并非长期的输入。不能记忆长期的信息可能就会造成训练的长期可靠性就不高。
LSTM解决办法:LSTM通过细胞状态(长期记忆库)和门控机制(动态信息筛选),实现了对关键信息的长期保留与无关信息的主动遗忘。
二 LSTM的核心组成
LSTM通过三个门和一个记忆单元控制信息流动
符号 | 含义 |
---|---|
当前时间步(例如,序列中的第 t 个时刻) | |
时间步 ttt 的输入向量(维度例如为 d) | |
上一时间步的隐藏状态(隐藏层输出,维度例如为 h) | |
上一时间步的记忆单元(长期状态,维度同 h) |
2.1遗忘门(Forget Gate)
功能:决定从上一时刻的记忆单元 中丢弃多少信息。
公式:
符号 | 含义 |
---|---|
遗忘门输出,值在0到1之间( | |
遗忘门的权重矩阵(将 | |
遗忘门的偏置项 | |
Sigmoid激活函数,将输入压缩到0到1之间 |
输出:0到1之间的数值,0表示完全遗忘,1表示完全保留。
2.2输入门(Input Gate)
功能:确定将当前输入 的哪些信息存入记忆单元。
候选记忆生成:
输入门控:
符号 | 含义 |
---|---|
输入门输出,控制当前候选记忆的权重 | |
候选记忆,基于当前输入和隐藏状态生成的新信息(通过 | |
输入门和候选记忆的权重矩阵 | |
输入门和候选记忆的偏置项 | |
双曲正切激活函数,输出范围在-1到1之间 |
输出:更新后的记忆单元部分 ,其中
为逐元素乘法。
2.3更新记忆单元(Cell State)
综合遗忘与输入:
符号 | 含义 |
---|---|
更新后的记忆单元,结合遗忘门和输入门的结果 | |
逐元素乘法(即对应位置相乘) |
作用:长期记忆在此更新,保留关键历史信息。
2.4输出门(Output Gate)
功能:控制记忆单元 对当前输出的影响。
门控信号:
符号 | 含义 |
---|---|
输出门信号,决定记忆单元对当前输出的影响 | |
当前隐藏状态,作为LSTM的输出传递到下一时间步 | |
输出门的权重矩阵 | |
输出门的偏置项 |
最终输出:隐藏状态 将传递到下一时间步。
示意图:
输入: ──┬──────────────┬──────────┐
隐藏状态: ─┘ | |
↓ ↓ ↓
[,
]拼接 [
,
]拼接 [
,
]拼接
↓ ↓ ↓
↓ ↓ ↓
σ (Sigmoid) → σ (Sigmoid) →
tanh →
↓ ↓ ↓
├─────────┐ │ │
↓ ↓ ↓ ↓
更新记忆单元:
↓
↑__________│
↓
↓
→
三 LSTM工作流程
时间步处理:每个时间步依次计算遗忘门、输入门、记忆单元更新、输出门。
信息流动:
存储长期状态,通过门控机制选择性地保留/遗忘信息。
作为短期记忆,用于当前预测并传递至下一时间步。
梯度保护机制:
记忆单元 的累加形式使得梯度可直接传递,减少梯度消失(如
时,历史信息长期保留)。
灵活的信息控制:门控机制动态调节信息流,适应不同序列模式。
四 LSTM变体
双向LSTM(Bi-LSTM):结合前向和后向LSTM,捕捉上下文依赖。
深度LSTM:堆叠多层LSTM,增强模型表达能力。
窥视孔连接(Peephole):允许门控单元查看记忆单元 的状态。
LSTM vs GRU(GRU是LSTM的简版)
特性 | LSTM | GRU |
---|---|---|
参数数量 | 较多(3个门 + 记忆单元) | 较少(2个门) |
记忆单元 | 显式存储 CtC_tCt | 隐式合并 hth_tht |
训练速度 | 较慢 | 较快 |
应用场景 | 需精细控制长期依赖(如文本生成) | 资源受限任务(如实时预测) |
import torch
import torch.nn as nn
# 定义LSTM模型
class LSTMModel(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1):
super(LSTMModel, self).__init__()
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
# 初始化隐藏状态和记忆单元
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).requires_grad_()
c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).requires_grad_()
# 前向传播
out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))
# 取最后一个时间步输出
out = self.fc(out[:, -1, :])
return out
# 参数设置
input_dim = 10 # 输入特征维度
hidden_dim = 20 # 隐藏层维度
output_dim = 1 # 输出维度(如回归任务)
model = LSTMModel(input_dim, hidden_dim, output_dim)
五 LSTM应用场景
时间序列预测:股票价格、天气预测。
自然语言处理:机器翻译、文本生成。
语音识别:音频序列到文本的转换。
异常检测:监测时序数据中的异常模式。
LSTM通过门控机制精细调控信息流,有效解决了长期依赖的学习难题。尽管在Transformer崛起后面临竞争,其在处理中等长度序列、资源受限场景中仍具实用价值。理解LSTM的门控逻辑及其与记忆单元的交互,是掌握现代序列建模技术的重要基础。