门控循环单元(GRU)

37 篇文章 1 订阅
27 篇文章 1 订阅

介绍

门控循环单元(Gated Recurrent Unit, GRU)是一种循环神经网络(Recurrent Neural Network, RNN)的变体,由Cho等人在2014年提出。相比于传统的RNN,GRU引入了门控机制,使得网络能够更好地捕捉长期依赖性,同时减少了梯度消失的问题。

本文将介绍GRU的方法历史、优点以及与其他方法的不同之处,并给出详细的理论推导过程和计算步骤。最后,我们将用PyTorch给出一个GRU的例子。

方法历史

在RNN中,每个时间步的输出都依赖于前一时刻的状态。然而,由于梯度消失的问题,传统的RNN很难处理长期依赖性。因此,一些方法被提出,如长短时记忆网络(Long Short-Term Memory, LSTM)和门控循环单元(GRU)。

GRU是由Cho等人在2014年提出的,它引入了门控机制,使得网络能够更好地捕捉长期依赖性,同时减少了梯度消失的问题。GRU的设计灵感来自于LSTM,但是GRU只有两个门(重置门和更新门),而LSTM有三个门(输入门、遗忘门和输出门),因此GRU的计算量更小。

方法优点

相比于传统的RNN,GRU有以下优点:

  • GRU引入了门控机制,使得网络能够更好地捕捉长期依赖性,同时减少了梯度消失的问题。
  • GRU的计算量比LSTM更小,因为它只有两个门(重置门和更新门)。

与其他方法的不同之处

相比于LSTM,GRU只有两个门(重置门和更新门),而LSTM有三个门(输入门、遗忘门和输出门)。因此,GRU的计算量更小,但是LSTM的表现可能更好。

理论推导过程

GRU的公式如下:

r t = σ ( W r x t + U r h t − 1 + b r ) z t = σ ( W z x t + U z h t − 1 + b z ) h ~ t = tanh ⁡ ( W x t + r t ⊙ U h t − 1 + b ) h t = ( 1 − z t ) ⊙ h t − 1 + z t ⊙ h ~ t \begin{aligned} r_t &= \sigma(W_r x_t + U_r h_{t-1} + b_r) \\ z_t &= \sigma(W_z x_t + U_z h_{t-1} + b_z) \\ \tilde{h}_t &= \tanh(W x_t + r_t \odot U h_{t-1} + b) \\ h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \end{aligned} rtzth~tht=σ(Wrxt+Urht1+br)=σ(Wzxt+Uzht1+bz)=tanh(Wxt+rtUht1+b)=(1zt)ht1+zth~t

其中, x t x_t xt是输入, h t h_t ht是输出, r t r_t rt是重置门, z t z_t zt是更新门, h ~ t \tilde{h}_t h~t是候选隐藏状态, ⊙ \odot 表示逐元素相乘, σ \sigma σ表示sigmoid函数, tanh ⁡ \tanh tanh表示双曲正切函数, W r , U r , b r , W z , U z , b z , W , U , b W_r, U_r, b_r, W_z, U_z, b_z, W, U, b Wr,Ur,br,Wz,Uz,bz,W,U,b是可学习的参数。

我们可以将上述公式分解为以下步骤:

  1. 计算重置门 r t r_t rt

r t = σ ( W r x t + U r h t − 1 + b r ) r_t = \sigma(W_r x_t + U_r h_{t-1} + b_r) rt=σ(Wrxt+Urht1+br)

其中, W r , U r , b r W_r, U_r, b_r Wr,Ur,br是可学习的参数。

  1. 计算更新门 z t z_t zt

z t = σ ( W z x t + U z h t − 1 + b z ) z_t = \sigma(W_z x_t + U_z h_{t-1} + b_z) zt=σ(Wzxt+Uzht1+bz)

其中, W z , U z , b z W_z, U_z, b_z Wz,Uz,bz是可学习的参数。

  1. 计算候选隐藏状态 h ~ t \tilde{h}_t h~t

h ~ t = tanh ⁡ ( W x t + r t ⊙ U h t − 1 + b ) \tilde{h}_t = \tanh(W x_t + r_t \odot U h_{t-1} + b) h~t=tanh(Wxt+rtUht1+b)

其中, W , U , b W, U, b W,U,b是可学习的参数。

  1. 计算隐藏状态 h t h_t ht

h t = ( 1 − z t ) ⊙ h t − 1 + z t ⊙ h ~ t h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t ht=(1zt)ht1+zth~t

其中, ⊙ \odot 表示逐元素相乘。

计算步骤

下面我们将给出一个GRU的计算步骤:

  1. 初始化 h 0 h_0 h0为零向量。

  2. 对于每个时间步 t t t,执行以下操作:

    1. 计算重置门 r t r_t rt

      r t = σ ( W r x t + U r h t − 1 + b r ) r_t = \sigma(W_r x_t + U_r h_{t-1} + b_r) rt=σ(Wrxt+Urht1+br)

    2. 计算更新门 z t z_t zt

      z t = σ ( W z x t + U z h t − 1 + b z ) z_t = \sigma(W_z x_t + U_z h_{t-1} + b_z) zt=σ(Wzxt+Uzht1+bz)

    3. 计算候选隐藏状态 h ~ t \tilde{h}_t h~t

      h ~ t = tanh ⁡ ( W x t + r t ⊙ U h t − 1 + b ) \tilde{h}_t = \tanh(W x_t + r_t \odot U h_{t-1} + b) h~t=tanh(Wxt+rtUht1+b)

    4. 计算隐藏状态 h t h_t ht

      h t = ( 1 − z t ) ⊙ h t − 1 + z t ⊙ h ~ t h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t ht=(1zt)ht1+zth~t

  3. 返回所有时间步的隐藏状态 h 1 , h 2 , . . . , h T h_1, h_2, ..., h_T h1,h2,...,hT

PyTorch实现

下面我们将用PyTorch给出一个GRU的例子:

import torch
import torch.nn as nn

# 定义GRU模型
class GRU(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(GRU, self).__init__()
        self.hidden_size = hidden_size
        self.reset_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.update_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.candidate = nn.Linear(input_size + hidden_size, hidden_size)

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        reset = torch.sigmoid(self.reset_gate(combined))
        update = torch.sigmoid(self.update_gate(combined))
        combined = torch.cat((input, reset * hidden), 1)
        candidate = torch.tanh(self.candidate(combined))
        output = update * hidden + (1 - update) * candidate
        return output

# 定义输入和隐藏状态
input_size = 3
hidden_size = 2
input = torch.randn(5, 3)
hidden = torch.zeros(1, 2)

# 初始化GRU模型
gru = GRU(input_size, hidden_size)

# 计算输出
output = []
for i in range(input.shape[0]):
    hidden = gru(input[i], hidden)
    output.append(hidden)
output = torch.cat(output, 0)

print(output)

结构图

下面是GRU的结构图

x_t
W_r
h_t-1
U_r
r_t
b_r
W_z
U_z
z_t
b_z
W
U
tilde_h_t
h_t

其中, x t x_t xt是输入, h t h_t ht是输出, r t r_t rt是重置门, z t z_t zt是更新门, h ~ t \tilde{h}_t h~t是候选隐藏状态, W r , U r , b r , W z , U z , b z , W , U , b W_r, U_r, b_r, W_z, U_z, b_z, W, U, b Wr,Ur,br,Wz,Uz,bz,W,U,b是可学习的参数。

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值