门控循环单元(GRU)
门控循环单元(Gated Recurrent Unit,GRU)是长短期记忆(LSTM)的简化版本。GRU通过减少门控机制的数量,提高了计算效率,同时在很多任务上性能与LSTM相近。GRU由两个主要的门组成:重置门和更新门。这些门帮助GRU决定如何在每个时间步更新和传递信息。
GRU 结构
GRU 的结构相比 LSTM 更简单,没有独立的记忆细胞状态。它通过两个门(重置门和更新门)来控制信息的流动和状态的更新。
1. 重置门(Reset Gate)
重置门控制前一时刻的隐藏状态在当前时刻的信息重置程度。如果重置门的输出接近于0,意味着忘记前一时刻的状态信息;如果输出接近于1,则保留前一时刻的状态信息。
重置门的公式如下:
r t = σ ( W r ⋅ [ h t − 1 , x t ] + b r ) r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) rt=σ(Wr⋅[ht−1,xt]+br)
其中:
-
r
t
r_t
rt是重置门的输出。
-
σ
\sigma
σ是 sigmoid 激活函数。
-
W
r
W_r
Wr是权重矩阵,
b
r
b_r
br是偏置项。
-
h
t
−
1
h_{t-1}
ht−1是前一时刻的隐藏状态,
x
t
x_t
xt是当前时刻的输入。
2. 更新门(Update Gate)
更新门决定前一时刻的隐藏状态和当前时刻的新候选隐藏状态的权重比例。更新门的输出用于在新旧信息之间进行加权平均。
更新门的公式如下:
z t = σ ( W z ⋅ [ h t − 1 , x t ] + b z ) z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) zt=σ(Wz⋅[ht−1,xt]+bz)
其中:
-
z
t
z_t
zt是更新门的输出。
-
W
z
W_z
Wz是权重矩阵,
b
z
b_z
bz是偏置项。
3. 候选隐藏状态
候选隐藏状态结合了当前输入和前一时刻的隐藏状态(经过重置门调节),用于更新当前的隐藏状态。
候选隐藏状态的公式如下:
h ~ t = tanh ( W ⋅ [ r t ∗ h t − 1 , x t ] + b ) \tilde{h}_t = \tanh(W \cdot [r_t * h_{t-1}, x_t] + b) h~t=tanh(W⋅[rt∗ht−1,xt]+b)
其中:
-
h
~
t
\tilde{h}_t
h~t是候选隐藏状态。
-
r
t
∗
h
t
−
1
r_t * h_{t-1}
rt∗ht−1表示重置门控制下的前一时刻隐藏状态。
4. 更新隐藏状态
最终的隐藏状态通过更新门的输出进行加权平均:
h t = ( 1 − z t ) ∗ h t − 1 + z t ∗ h ~ t h_t = (1 - z_t) * h_{t-1} + z_t * \tilde{h}_t ht=(1−zt)∗ht−1+zt∗h~t
具体数值举例
假设我们有一个简单的输入序列: x 1 = 0.5 x_1 = 0.5 x1=0.5和 x 2 = 0.8 x_2 = 0.8 x2=0.8。我们通过 GRU 单元来计算输出。
-
初始状态:
- 初始隐藏状态 h 0 = 0 h_0 = 0 h0=0
-
权重和偏置(假设为已知值):
- W r = 0.1 , b r = 0.1 W_r = 0.1, b_r = 0.1 Wr=0.1,br=0.1
- W z = 0.2 , b z = 0.2 W_z = 0.2, b_z = 0.2 Wz=0.2,bz=0.2
- W = 0.3 , b = 0.3 W = 0.3, b = 0.3 W=0.3,b=0.3 -
第一个时间步 x 1 = 0.5 x_1 = 0.5 x1=0.5:
-
重置门:
r 1 = σ ( 0.1 ⋅ [ 0 , 0.5 ] + 0.1 ) = σ ( 0.05 + 0.1 ) = σ ( 0.15 ) ≈ 0.537 r_1 = \sigma(0.1 \cdot [0, 0.5] + 0.1) = \sigma(0.05 + 0.1) = \sigma(0.15) \approx 0.537 r1=σ(0.1⋅[0,0.5]+0.1)=σ(0.05+0.1)=σ(0.15)≈0.537 -
更新门:
z 1 = σ ( 0.2 ⋅ [ 0 , 0.5 ] + 0.2 ) = σ ( 0.1 + 0.2 ) = σ ( 0.3 ) ≈ 0.574 z_1 = \sigma(0.2 \cdot [0, 0.5] + 0.2) = \sigma(0.1 + 0.2) = \sigma(0.3) \approx 0.574 z1=σ(0.2⋅[0,0.5]+0.2)=σ(0.1+0.2)=σ(0.3)≈0.574 -
候选隐藏状态:
h ~ 1 = tanh ( 0.3 ⋅ [ 0.537 ⋅ 0 , 0.5 ] + 0.3 ) = tanh ( 0.15 ) ≈ 0.149 \tilde{h}_1 = \tanh(0.3 \cdot [0.537 \cdot 0, 0.5] + 0.3) = \tanh(0.15) \approx 0.149 h~1=tanh(0.3⋅[0.537⋅0,0.5]+0.3)=tanh(0.15)≈0.149 -
更新隐藏状态:
h 1 = ( 1 − 0.574 ) ⋅ 0 + 0.574 ⋅ 0.149 ≈ 0.086 h_1 = (1 - 0.574) \cdot 0 + 0.574 \cdot 0.149 \approx 0.086 h1=(1−0.574)⋅0+0.574⋅0.149≈0.086
- 第二个时间步 x 2 = 0.8 x_2 = 0.8 x2=0.8:
-
重置门:
r 2 = σ ( 0.1 ⋅ [ 0.086 , 0.8 ] + 0.1 ) = σ ( 0.1 ⋅ 0.886 + 0.1 ) = σ ( 0.1886 ) ≈ 0.547 r_2 = \sigma(0.1 \cdot [0.086, 0.8] + 0.1) = \sigma(0.1 \cdot 0.886 + 0.1) = \sigma(0.1886) \approx 0.547 r2=σ(0.1⋅[0.086,0.8]+0.1)=σ(0.1⋅0.886+0.1)=σ(0.1886)≈0.547 -
更新门:
z 2 = σ ( 0.2 ⋅ [ 0.086 , 0.8 ] + 0.2 ) = σ ( 0.2 ⋅ 0.886 + 0.2 ) = σ ( 0.3772 ) ≈ 0.593 z_2 = \sigma(0.2 \cdot [0.086, 0.8] + 0.2) = \sigma(0.2 \cdot 0.886 + 0.2) = \sigma(0.3772) \approx 0.593 z2=σ(0.2⋅[0.086,0.8]+0.2)=σ(0.2⋅0.886+0.2)=σ(0.3772)≈0.593 -
候选隐藏状态:
h ~ 2 = tanh ( 0.3 ⋅ [ 0.547 ⋅ 0.086 , 0.8 ] + 0.3 ) = tanh ( 0.3211 ) ≈ 0.310 \tilde{h}_2 = \tanh(0.3 \cdot [0.547 \cdot 0.086, 0.8] + 0.3) = \tanh(0.3211) \approx 0.310 h~2=tanh(0.3⋅[0.547⋅0.086,0.8]+0.3)=tanh(0.3211)≈0.310 -
更新隐藏状态:
h 2 = ( 1 − 0.593 ) ⋅ 0.086 + 0.593 ⋅ 0.310 ≈ 0.229 h_2 = (1 - 0.593) \cdot 0.086 + 0.593 \cdot 0.310 \approx 0.229 h2=(1−0.593)⋅0.086+0.593⋅0.310≈0.229
总结
通过这个具体的数值例子,我们可以看到 GRU 如何通过重置门和更新门来控制信息的流动,从而在序列建模中捕捉长时间范围内的依赖关系。相比于 LSTM,GRU 结构更简单,计算效率更高,同时在很多任务上性能与 LSTM 相近。这使得 GRU 在处理序列数据时成为一种有效的选择。