Pytorch(1.4.0+):GRU原理及实现

背景

GRU是循环神经网络中一个非常具有性价比的工具,学习了解是非常有必要的。在本博客中,就将简要介绍一下GRU的原理及其使用Pytorch实现的具体代码。


参考资料

https://pytorch.org/docs/stable/generated/torch.nn.GRUCell.html#torch.nn.GRUCell
https://github.com/YoYo000/MVSNet/blob/master/mvsnet/convgru.py
https://zhuanlan.zhihu.com/p/32481747


1、GRU原理

网上讲解GRU原理的博客实在是太多了,看的人都有点头晕目眩。对于一个合格的程序员来说,其实这些更多看起来花里胡哨的,反而导致搞不清楚啥情况。

1.1 GRU公式

所以在本博客中,倾向首先用PyTorch文档中简单扼要的公式来展示,如下图所示,其中xh是输入,h'是输出,其他的Wb是各种权重,另外的就是一些中间结果了。
到了这一步,对于程序员来说,实现就没啥问题了,后边的代码基本上就是对此公式的一个具体展示。
在这里插入图片描述
额外提醒:上述公式的最后一行,有时候也可以写成h'=z*n+(1-z)*h,不过显然这两个是等价的。

1.2 GRU公式背后

GRU的原理其实是非常推荐知乎上的人人都能看懂的GRU,用一个清楚明了的图展示了GRU的整个过程,这里借用一下其中的图片,并展示GRU的各个流程。
在这里插入图片描述

1.2.1 计算reset,用于选择合适的信息

对于GRU而言,第一件事情就是计算reset参数(对应公式1),用于从历史状态 h t − 1 h^{t-1} ht1和当前输入 x t x^t xt中选择合适的信息进行长期记忆(对应公式3)。

1.2.2 计算update,用于保留需要的信息

第二件事情是计算update参数(对应公式2),将选择出的长期记忆进行保留,并遗忘一定部分的历史记忆(对应公式4)。


2、GRUCell实现

GRUCell是GRU的基本结构,本节简单介绍一下,并做代码实现。
1、基于nn.Linear的GRU实现

# https://github.com/emadRad/lstm-gru-pytorch/blob/master/lstm_gru.ipynb
class GRUCell(nn.Module):

    """
    An implementation of GRUCell.

    """

    def __init__(self, input_size, hidden_size, bias=True):
        super(GRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.x2h = nn.Linear(input_size, 3 * hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 3 * hidden_size, bias=bias)
        self.reset_parameters()



    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.hidden_size)
        for w in self.parameters():
            w.data.uniform_(-std, std)
    
    def forward(self, x, hidden):
        
        x = x.view(-1, x.size(1))
        
        gate_x = self.x2h(x) 
        gate_h = self.h2h(hidden)
        
        gate_x = gate_x.squeeze()
        gate_h = gate_h.squeeze()
        
        i_r, i_i, i_n = gate_x.chunk(3, 1)
        h_r, h_i, h_n = gate_h.chunk(3, 1)
        
        # 公式1
        resetgate = F.sigmoid(i_r + h_r)
        # 公式2
        inputgate = F.sigmoid(i_i + h_i)
        # 公式3
        newgate = F.tanh(i_n + (resetgate * h_n))
        # 公式4,不过稍微调整了一下公式形式
        hy = newgate + inputgate * (hidden - newgate)
        
        
        return hy

2、基于nn.Conv2d的GRU实现

class GRUConvCell(nn.Module):

    def __init__(self, input_channel, output_channel):

        super(GRUConvCell, self).__init__()

        # filters used for gates
        gru_input_channel = input_channel + output_channel
        self.output_channel = output_channel

        self.gate_conv = nn.Conv2d(gru_input_channel, output_channel * 2, kernel_size=3, padding=1)
        self.reset_gate_norm = nn.GroupNorm(1, output_channel, 1e-6, True)
        self.update_gate_norm = nn.GroupNorm(1, output_channel, 1e-6, True)

        # filters used for outputs
        self.output_conv = nn.Conv2d(gru_input_channel, output_channel, kernel_size=3, padding=1)
        self.output_norm = nn.GroupNorm(1, output_channel, 1e-6, True)

        self.activation = nn.Tanh()

	# 公式1,2
    def gates(self, x, h):

        # x = N x C x H x W
        # h = N x C x H x W

        # c = N x C*2 x H x W
        c = torch.cat((x, h), dim=1)
        f = self.gate_conv(c)

        # r = reset gate, u = update gate
        # both are N x O x H x W
        C = f.shape[1]
        r, u = torch.split(f, C // 2, 1)

        rn = self.reset_gate_norm(r)
        un = self.update_gate_norm(u)
        rns = torch.sigmoid(rn)
        uns = torch.sigmoid(un)
        return rns, uns

    # 公式3
    def output(self, x, h, r, u):

        f = torch.cat((x, r * h), dim=1)
        o = self.output_conv(f)
        on = self.output_norm(o)
        return on

    def forward(self, x, h = None):

        N, C, H, W = x.shape
        HC = self.output_channel
        if(h is None):
            h = torch.zeros((N, HC, H, W), dtype=torch.float, device=x.device)
        r, u = self.gates(x, h)
        o = self.output(x, h, r, u)
        y = self.activation(o)
	    
	    # 公式4
        return u * h + (1 - u) * y

额外提醒:上述两个实现都没有输出,当前主要是GRU并没有输出门,通常的做法是再添加了一个计算单元(例如nn.Linear)将公式4结果转换成输出。


3、GRU网络示例

本节简单介绍一个GRU网络

class GRUNet(nn.Module):

    def __init__(self, hidden_size=64):

        super(GRUNet,self).__init__()

        self.gru_1 = GRUConvCell(input_channel=4,          output_channel=hidden_size)
        self.gru_2 = GRUConvCell(input_channel=hidden_size,output_channel=hidden_size)
        self.gru_3 = GRUConvCell(input_channel=hidden_size,output_channel=hidden_size)

        self.fc = nn.Conv2d(in_channels=hidden_size,out_channels=1,kernel_size=3,padding=1)

    def forward(self, x, h):

        if h is None:
            h = [None,None,None]

        h1 = self.gru_1( x,h[0])
        h2 = self.gru_2(h1,h[1])
        h3 = self.gru_3(h2,h[2])

        o = self.fc(h3)

        return o,[h1,h2,h3]
 

if __name__ == '__main__':

    from utils import *
    
    device = 'cuda'

    x = torch.rand(1,1,10,20).to(device)

    grunet=GRUNet()
    grunet=grunet.to(device)
    grunet.eval()

    h = None
    o,h_n = grunet(x,h)


4、结论

上述简单总结了一下GRU的基本结构。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值