循环神经网络(RNN)深度解析

目录

  1. ​模型结构详解​
  2. ​数学原理与推导​
  3. ​代表性变体及改进​
  4. ​应用场景与优缺点​
  5. ​PyTorch代码示例​

1. 模型结构详解

1.1 核心结构

RNN通过​​时间步循环​​处理序列数据,核心组件包括:

输入序列 → 隐藏状态传递 → 输出序列

  • ​输入​​:序列数据 X=[x1​,x2​,...,xT​],每个时间步输入 xt​∈Rd
  • ​隐藏状态​​:ht​∈Rh,存储历史信息
  • ​输出​​:yt​∈Rk,每个时间步可输出(如序列标注)或仅最后一步输出(如文本分类)
结构示意图

1.2 激活函数

  • ​隐藏层激活​​:通常使用 ​​tanh​​(梯度稳定)

    \tanh(z) = \frac{e^z - e^{-z}}{e^z + e^{-z}}
  • ​输出层激活​​:根据任务选择(如Softmax用于分类)

1.3 输入输出形式

任务类型输入输出形式示例
序列到序列每个时间步均有输入和输出机器翻译(中→英)
序列到单值仅最后时间步输出文本情感分类
单值到序列仅第一时间步输入,后续自生成输出文本生成

2. 数学原理与推导

2.1 前向传播公式

  • ​隐藏状态更新​​:

    h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h)
  • ​输出计算​​:
    y_t = \text{Softmax}(W_{hy} h_t + b_y)

2.2 反向传播(BPTT算法)

损失函数 L 对参数 Whh​ 的梯度计算:

\frac{\partial \mathcal{L}}{\partial W_{hh}} = \sum_{t=1}^T \frac{\partial \mathcal{L}}{\partial h_t} \cdot \frac{\partial h_t}{\partial W_{hh}}
其中 依赖所有历史时间步,导致​​梯度消失/爆炸​​。


3. 代表性变体及改进

3.1 LSTM(长短期记忆网络)

核心结构
  • ​门控机制​​:
    • 输入门 it​:控制新信息流入
    • 遗忘门 ft​:控制历史信息遗忘
    • 输出门 ot​:控制当前状态输出
  • ​细胞状态 Ct​​​:长期记忆存储
公式源码
  • 输入门:
    i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)
  • 遗忘门:
    f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f)
  • 细胞状态更新:
    \tilde{C}t = \tanh(W{xc} x_t + W_{hc} h_{t-1} + b_c) C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t
  • 输出门:
    o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o) h_t = o_t \odot \tanh(C_t)

3.2 GRU(门控循环单元)

核心简化
  • ​合并门控​​:更新门 zt​ 替代输入门和遗忘门
  • ​重置门 rt​​​:控制历史信息保留比例
公式源码
  • 更新门:
    z_t = \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z)
  • 重置门:
    r_t = \sigma(W_{xr} x_t + W_{hr} h_{t-1} + b_r)
  • 候选状态:
    \tilde{h}t = \tanh(W{xh} x_t + W_{hh} (r_t \odot h_{t-1}) + b_h)
  • 隐藏状态:
    h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t

3.3 双向RNN(Bi-RNN)

  • ​结构​​:前向RNN + 后向RNN → 拼接输出
  • ​公式​​:
    h_t^{\rightarrow} = \text{RNN}(x_t, h_{t-1}^{\rightarrow})
    h_t^{\leftarrow} = \text{RNN}(x_t, h_{t+1}^{\leftarrow})
    y_t = W_y [h_t^{\rightarrow}; h_t^{\leftarrow}] + b_y

4. 应用场景与优缺点

4.1 应用场景

  • ​自然语言处理​​:机器翻译、文本生成
  • ​时间序列预测​​:股票价格、气象数据
  • ​语音识别​​:音频序列转文本

4.2 优缺点对比

优点缺点
处理变长序列梯度消失/爆炸问题(基础RNN)
捕捉时序依赖计算效率低(无法并行)
灵活结构设计长序列记忆能力有限(需LSTM/GRU)

5. PyTorch代码示例

import torch
import torch.nn as nn

class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h0 = torch.zeros(1, x.size(0), self.hidden_size)  # 初始隐藏状态
        out, _ = self.rnn(x, h0)
        out = self.fc(out[:, -1, :])  # 取最后时间步输出
        return out

# 示例:序列分类任务
model = SimpleRNN(input_size=10, hidden_size=20, output_size=2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 输入数据格式:[batch_size, seq_length, input_size]
inputs = torch.randn(3, 5, 10)  # 3个样本,序列长度5,特征维度10
labels = torch.tensor([1, 0, 1])  # 分类标签

# 训练循环
outputs = model(inputs)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()

核心总结

  • ​结构本质​​:通过时间步循环传递隐藏状态,建模序列依赖
  • ​核心缺陷​​:基础RNN存在梯度消失/爆炸,需LSTM/GRU优化
  • ​工程价值​​:语音、文本等时序任务的基础架构
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值