目录
- 模型结构详解
- 数学原理与推导
- 代表性变体及改进
- 应用场景与优缺点
- PyTorch代码示例
1. 模型结构详解
1.1 核心结构
RNN通过时间步循环处理序列数据,核心组件包括:
输入序列 → 隐藏状态传递 → 输出序列
- 输入:序列数据 X=[x1,x2,...,xT],每个时间步输入 xt∈Rd
- 隐藏状态:ht∈Rh,存储历史信息
- 输出:yt∈Rk,每个时间步可输出(如序列标注)或仅最后一步输出(如文本分类)
结构示意图
1.2 激活函数
- 隐藏层激活:通常使用 tanh(梯度稳定)
- 输出层激活:根据任务选择(如Softmax用于分类)
1.3 输入输出形式
任务类型 | 输入输出形式 | 示例 |
---|---|---|
序列到序列 | 每个时间步均有输入和输出 | 机器翻译(中→英) |
序列到单值 | 仅最后时间步输出 | 文本情感分类 |
单值到序列 | 仅第一时间步输入,后续自生成输出 | 文本生成 |
2. 数学原理与推导
2.1 前向传播公式
- 隐藏状态更新:
- 输出计算:
2.2 反向传播(BPTT算法)
损失函数 L 对参数 Whh 的梯度计算:
其中 依赖所有历史时间步,导致梯度消失/爆炸。
3. 代表性变体及改进
3.1 LSTM(长短期记忆网络)
核心结构
- 门控机制:
- 输入门 it:控制新信息流入
- 遗忘门 ft:控制历史信息遗忘
- 输出门 ot:控制当前状态输出
- 细胞状态 Ct:长期记忆存储
公式源码
- 输入门:
- 遗忘门:
- 细胞状态更新:
- 输出门:
3.2 GRU(门控循环单元)
核心简化
- 合并门控:更新门 zt 替代输入门和遗忘门
- 重置门 rt:控制历史信息保留比例
公式源码
- 更新门:
- 重置门:
- 候选状态:
- 隐藏状态:
3.3 双向RNN(Bi-RNN)
- 结构:前向RNN + 后向RNN → 拼接输出
- 公式:
4. 应用场景与优缺点
4.1 应用场景
- 自然语言处理:机器翻译、文本生成
- 时间序列预测:股票价格、气象数据
- 语音识别:音频序列转文本
4.2 优缺点对比
优点 | 缺点 |
---|---|
处理变长序列 | 梯度消失/爆炸问题(基础RNN) |
捕捉时序依赖 | 计算效率低(无法并行) |
灵活结构设计 | 长序列记忆能力有限(需LSTM/GRU) |
5. PyTorch代码示例
import torch
import torch.nn as nn
class SimpleRNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.hidden_size = hidden_size
self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
h0 = torch.zeros(1, x.size(0), self.hidden_size) # 初始隐藏状态
out, _ = self.rnn(x, h0)
out = self.fc(out[:, -1, :]) # 取最后时间步输出
return out
# 示例:序列分类任务
model = SimpleRNN(input_size=10, hidden_size=20, output_size=2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# 输入数据格式:[batch_size, seq_length, input_size]
inputs = torch.randn(3, 5, 10) # 3个样本,序列长度5,特征维度10
labels = torch.tensor([1, 0, 1]) # 分类标签
# 训练循环
outputs = model(inputs)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
核心总结
- 结构本质:通过时间步循环传递隐藏状态,建模序列依赖
- 核心缺陷:基础RNN存在梯度消失/爆炸,需LSTM/GRU优化
- 工程价值:语音、文本等时序任务的基础架构