Long Short-Term Memory (LSTM) networks are a special kind of recurrent neural network (RNN) designed to address the vanishing and exploding gradient problems that standard RNNs suffer from when processing long sequences. By introducing a gating mechanism to control the flow of information, an LSTM can capture dependencies over much longer time spans.
LSTM Cell Structure
An LSTM cell consists of three main gates: the forget gate, the input gate, and the output gate. Each gate uses a sigmoid activation function to decide how much information passes through. The cell state $C_t$ can be thought of as a conveyor belt: it flows directly on to the next time step with only a few linear interactions, which allows gradients to propagate well. The structure and formulas in detail:
1. Forget Gate
The forget gate decides how much of the previous cell state $C_{t-1}$ to discard:

$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$

where:
- $\sigma$ is the sigmoid activation function.
- $W_f$ is a weight matrix and $b_f$ is a bias term.
- $h_{t-1}$ is the hidden state from the previous time step, and $x_t$ is the input at the current time step.
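As a minimal sketch, the forget gate can be computed like this. For simplicity, $W_f$ is taken here as a single scalar applied to the sum $h_{t-1} + x_t$ (matching the simplified numerical example later in this article), rather than a matrix multiplying the concatenation:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Simplified scalar forget gate: f_t = sigmoid(W_f * (h_prev + x_t) + b_f)
W_f, b_f = 0.1, 0.1       # toy weight and bias
h_prev, x_t = 0.0, 0.5    # previous hidden state and current input
f_t = sigmoid(W_f * (h_prev + x_t) + b_f)
print(round(f_t, 3))      # 0.537
```

A value near 1 keeps most of $C_{t-1}$; a value near 0 erases it.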
2. Input Gate
The input gate decides how much new information to write into the cell state:

$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$

The candidate cell state is:

$\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)$

where:
- $W_i$ and $W_C$ are weight matrices, and $b_i$ and $b_C$ are bias terms.
- $\tanh$ is the hyperbolic tangent activation function.
3. Updating the Cell State
Combining the forget gate and the input gate, the cell state is updated as:

$C_t = f_t \cdot C_{t-1} + i_t \cdot \tilde{C}_t$
4. Output Gate
The output gate decides what the current hidden state $h_t$ outputs:

$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$

The final hidden state is:

$h_t = o_t \cdot \tanh(C_t)$
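The equations above can be combined into a single forward step. Below is a sketch in NumPy; the function and variable names are my own, and real frameworks typically fuse the four weight matrices into one matrix multiply for speed:

```python
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def lstm_step(x_t, h_prev, c_prev, W_f, W_i, W_c, W_o, b_f, b_i, b_c, b_o):
    """One LSTM forward step. Each W_* has shape (hidden, hidden + input_dim)
    and multiplies the concatenation [h_{t-1}, x_t]."""
    z = np.concatenate([h_prev, x_t])     # [h_{t-1}, x_t]
    f_t = sigmoid(W_f @ z + b_f)          # forget gate
    i_t = sigmoid(W_i @ z + b_i)          # input gate
    c_tilde = np.tanh(W_c @ z + b_c)      # candidate cell state
    c_t = f_t * c_prev + i_t * c_tilde    # C_t = f_t*C_{t-1} + i_t*~C_t
    o_t = sigmoid(W_o @ z + b_o)          # output gate
    h_t = o_t * np.tanh(c_t)              # h_t = o_t * tanh(C_t)
    return h_t, c_t

# Example: hidden size 3, input size 2, random weights
rng = np.random.default_rng(0)
hidden, inp = 3, 2
Ws = [rng.normal(size=(hidden, hidden + inp)) for _ in range(4)]
bs = [np.zeros(hidden) for _ in range(4)]
h, c = np.zeros(hidden), np.zeros(hidden)
h, c = lstm_step(rng.normal(size=inp), h, c, *Ws, *bs)
print(h.shape, c.shape)  # (3,) (3,)
```

Note that $|h_t| < 1$ always holds, since both $o_t$ and $\tanh(C_t)$ lie strictly inside $(-1, 1)$.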
A Concrete Numerical LSTM Example
Suppose we have a simple input sequence $x_1, x_2$ with values $x_1 = 0.5$ and $x_2 = 0.8$, and we run it through an LSTM cell. To keep the arithmetic simple, each weight below is a single scalar applied to the sum $h_{t-1} + x_t$ rather than a matrix multiplying the concatenation $[h_{t-1}, x_t]$.

- Initial state:
  - Initial hidden state $h_0 = 0$
  - Initial cell state $C_0 = 0$
- Weights and biases (assumed known):
  - $W_f = 0.1, b_f = 0.1$
  - $W_i = 0.2, b_i = 0.2$
  - $W_C = 0.3, b_C = 0.3$
  - $W_o = 0.4, b_o = 0.4$
- First time step, $x_1 = 0.5$:
  - Forget gate: $f_1 = \sigma(0.1 \cdot (0 + 0.5) + 0.1) = \sigma(0.15) \approx 0.537$
  - Input gate: $i_1 = \sigma(0.2 \cdot (0 + 0.5) + 0.2) = \sigma(0.3) \approx 0.574$
  - Candidate cell state: $\tilde{C}_1 = \tanh(0.3 \cdot (0 + 0.5) + 0.3) = \tanh(0.45) \approx 0.422$
  - Updated cell state: $C_1 = 0.537 \cdot 0 + 0.574 \cdot 0.422 \approx 0.242$
  - Output gate: $o_1 = \sigma(0.4 \cdot (0 + 0.5) + 0.4) = \sigma(0.6) \approx 0.645$
  - Hidden state: $h_1 = 0.645 \cdot \tanh(0.242) \approx 0.645 \cdot 0.237 \approx 0.153$
- Second time step, $x_2 = 0.8$:
  - Forget gate: $f_2 = \sigma(0.1 \cdot (0.153 + 0.8) + 0.1) = \sigma(0.195) \approx 0.548$
  - Input gate: $i_2 = \sigma(0.2 \cdot (0.153 + 0.8) + 0.2) = \sigma(0.391) \approx 0.596$
  - Candidate cell state: $\tilde{C}_2 = \tanh(0.3 \cdot (0.153 + 0.8) + 0.3) = \tanh(0.586) \approx 0.528$
  - Updated cell state: $C_2 = 0.548 \cdot 0.242 + 0.596 \cdot 0.528 \approx 0.133 + 0.315 \approx 0.448$
  - Output gate: $o_2 = \sigma(0.4 \cdot (0.153 + 0.8) + 0.4) = \sigma(0.781) \approx 0.686$
  - Hidden state: $h_2 = 0.686 \cdot \tanh(0.448) \approx 0.686 \cdot 0.42 \approx 0.288$
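The two time steps above can be reproduced with a short script. As in the walkthrough, each weight is a scalar applied to $h_{t-1} + x_t$; the last digit of $C_2$ differs slightly from the text because the text rounds intermediate values:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Scalar (W, b) pairs for the forget, input, candidate, and output parts
W_f, b_f = 0.1, 0.1
W_i, b_i = 0.2, 0.2
W_c, b_c = 0.3, 0.3
W_o, b_o = 0.4, 0.4

def step(x_t, h_prev, c_prev):
    s = h_prev + x_t                    # simplified stand-in for [h_{t-1}, x_t]
    f = sigmoid(W_f * s + b_f)          # forget gate
    i = sigmoid(W_i * s + b_i)          # input gate
    c_tilde = math.tanh(W_c * s + b_c)  # candidate cell state
    c = f * c_prev + i * c_tilde        # new cell state
    o = sigmoid(W_o * s + b_o)          # output gate
    h = o * math.tanh(c)                # new hidden state
    return h, c

h, c = 0.0, 0.0                         # h_0 and C_0
for x in (0.5, 0.8):
    h, c = step(x, h, c)
    print(f"h = {h:.3f}, C = {c:.3f}")
# h = 0.153, C = 0.242
# h = 0.288, C = 0.447
```

Running the full-precision values through gives $C_2 \approx 0.447$ rather than 0.448, a harmless difference from rounding $f_2$, $i_2$, and $\tilde{C}_2$ to three digits in the hand calculation.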
Summary
This concrete numerical example shows how an LSTM updates its hidden state and cell state through the forget, input, and output gates, allowing it to capture long-range dependencies in sequence modeling. This structure effectively mitigates the vanishing gradient problem of standard RNNs, which is why LSTMs perform well on tasks such as language modeling and machine translation.