介绍
门控循环单元(Gated Recurrent Unit, GRU)是一种循环神经网络(Recurrent Neural Network, RNN)的变体,由Cho等人在2014年提出。相比于传统的RNN,GRU引入了门控机制,使得网络能够更好地捕捉长期依赖性,同时减少了梯度消失的问题。
本文将介绍GRU的方法历史、优点以及与其他方法的不同之处,并给出详细的理论推导过程和计算步骤。最后,我们将用PyTorch给出一个GRU的例子。
方法历史
在RNN中,每个时间步的输出都依赖于前一时刻的状态。然而,由于梯度消失的问题,传统的RNN很难处理长期依赖性。因此,一些方法被提出,如长短时记忆网络(Long Short-Term Memory, LSTM)和门控循环单元(GRU)。
GRU是由Cho等人在2014年提出的,它引入了门控机制,使得网络能够更好地捕捉长期依赖性,同时减少了梯度消失的问题。GRU的设计灵感来自于LSTM,但是GRU只有两个门(重置门和更新门),而LSTM有三个门(输入门、遗忘门和输出门),因此GRU的计算量更小。
方法优点
相比于传统的RNN,GRU有以下优点:
- GRU引入了门控机制,使得网络能够更好地捕捉长期依赖性,同时减少了梯度消失的问题。
- GRU的计算量比LSTM更小,因为它只有两个门(重置门和更新门)。
与其他方法的不同之处
相比于LSTM,GRU只有两个门(重置门和更新门),而LSTM有三个门(输入门、遗忘门和输出门)。因此,GRU的计算量更小,但是LSTM的表现可能更好。
理论推导过程
GRU的公式如下:
r t = σ ( W r x t + U r h t − 1 + b r ) z t = σ ( W z x t + U z h t − 1 + b z ) h ~ t = tanh ( W x t + r t ⊙ U h t − 1 + b ) h t = ( 1 − z t ) ⊙ h t − 1 + z t ⊙ h ~ t \begin{aligned} r_t &= \sigma(W_r x_t + U_r h_{t-1} + b_r) \\ z_t &= \sigma(W_z x_t + U_z h_{t-1} + b_z) \\ \tilde{h}_t &= \tanh(W x_t + r_t \odot U h_{t-1} + b) \\ h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \end{aligned} rtzth~tht=σ(Wrxt+Urht−1+br)=σ(Wzxt+Uzht−1+bz)=tanh(Wxt+rt⊙Uht−1+b)=(1−zt)⊙ht−1+zt⊙h~t
其中, x t x_t xt是输入, h t h_t ht是输出, r t r_t rt是重置门, z t z_t zt是更新门, h ~ t \tilde{h}_t h~t是候选隐藏状态, ⊙ \odot ⊙表示逐元素相乘, σ \sigma σ表示sigmoid函数, tanh \tanh tanh表示双曲正切函数, W r , U r , b r , W z , U z , b z , W , U , b W_r, U_r, b_r, W_z, U_z, b_z, W, U, b Wr,Ur,br,Wz,Uz,bz,W,U,b是可学习的参数。
我们可以将上述公式分解为以下步骤:
- 计算重置门 r t r_t rt:
r t = σ ( W r x t + U r h t − 1 + b r ) r_t = \sigma(W_r x_t + U_r h_{t-1} + b_r) rt=σ(Wrxt+Urht−1+br)
其中, W r , U r , b r W_r, U_r, b_r Wr,Ur,br是可学习的参数。
- 计算更新门 z t z_t zt:
z t = σ ( W z x t + U z h t − 1 + b z ) z_t = \sigma(W_z x_t + U_z h_{t-1} + b_z) zt=σ(Wzxt+Uzht−1+bz)
其中, W z , U z , b z W_z, U_z, b_z Wz,Uz,bz是可学习的参数。
- 计算候选隐藏状态 h ~ t \tilde{h}_t h~t:
h ~ t = tanh ( W x t + r t ⊙ U h t − 1 + b ) \tilde{h}_t = \tanh(W x_t + r_t \odot U h_{t-1} + b) h~t=tanh(Wxt+rt⊙Uht−1+b)
其中, W , U , b W, U, b W,U,b是可学习的参数。
- 计算隐藏状态 h t h_t ht:
h t = ( 1 − z t ) ⊙ h t − 1 + z t ⊙ h ~ t h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t ht=(1−zt)⊙ht−1+zt⊙h~t
其中, ⊙ \odot ⊙表示逐元素相乘。
计算步骤
下面我们将给出一个GRU的计算步骤:
-
初始化 h 0 h_0 h0为零向量。
-
对于每个时间步 t t t,执行以下操作:
-
计算重置门 r t r_t rt:
r t = σ ( W r x t + U r h t − 1 + b r ) r_t = \sigma(W_r x_t + U_r h_{t-1} + b_r) rt=σ(Wrxt+Urht−1+br)
-
计算更新门 z t z_t zt:
z t = σ ( W z x t + U z h t − 1 + b z ) z_t = \sigma(W_z x_t + U_z h_{t-1} + b_z) zt=σ(Wzxt+Uzht−1+bz)
-
计算候选隐藏状态 h ~ t \tilde{h}_t h~t:
h ~ t = tanh ( W x t + r t ⊙ U h t − 1 + b ) \tilde{h}_t = \tanh(W x_t + r_t \odot U h_{t-1} + b) h~t=tanh(Wxt+rt⊙Uht−1+b)
-
计算隐藏状态 h t h_t ht:
h t = ( 1 − z t ) ⊙ h t − 1 + z t ⊙ h ~ t h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t ht=(1−zt)⊙ht−1+zt⊙h~t
-
-
返回所有时间步的隐藏状态 h 1 , h 2 , . . . , h T h_1, h_2, ..., h_T h1,h2,...,hT。
PyTorch实现
下面我们将用PyTorch给出一个GRU的例子:
import torch
import torch.nn as nn
# 定义GRU模型
class GRU(nn.Module):
def __init__(self, input_size, hidden_size):
super(GRU, self).__init__()
self.hidden_size = hidden_size
self.reset_gate = nn.Linear(input_size + hidden_size, hidden_size)
self.update_gate = nn.Linear(input_size + hidden_size, hidden_size)
self.candidate = nn.Linear(input_size + hidden_size, hidden_size)
def forward(self, input, hidden):
combined = torch.cat((input, hidden), 1)
reset = torch.sigmoid(self.reset_gate(combined))
update = torch.sigmoid(self.update_gate(combined))
combined = torch.cat((input, reset * hidden), 1)
candidate = torch.tanh(self.candidate(combined))
output = update * hidden + (1 - update) * candidate
return output
# 定义输入和隐藏状态
input_size = 3
hidden_size = 2
input = torch.randn(5, 3)
hidden = torch.zeros(1, 2)
# 初始化GRU模型
gru = GRU(input_size, hidden_size)
# 计算输出
output = []
for i in range(input.shape[0]):
hidden = gru(input[i], hidden)
output.append(hidden)
output = torch.cat(output, 0)
print(output)
结构图
下面是GRU的结构图
其中, x t x_t xt是输入, h t h_t ht是输出, r t r_t rt是重置门, z t z_t zt是更新门, h ~ t \tilde{h}_t h~t是候选隐藏状态, W r , U r , b r , W z , U z , b z , W , U , b W_r, U_r, b_r, W_z, U_z, b_z, W, U, b Wr,Ur,br,Wz,Uz,bz,W,U,b是可学习的参数。