背景
GRU是循环神经网络中一个非常具有性价比的工具,学习了解是非常有必要的。在本博客中,就将简要介绍一下GRU的原理及其使用Pytorch实现的具体代码。
参考资料
https://pytorch.org/docs/stable/generated/torch.nn.GRUCell.html#torch.nn.GRUCell
https://github.com/YoYo000/MVSNet/blob/master/mvsnet/convgru.py
https://zhuanlan.zhihu.com/p/32481747
1、GRU原理
网上讲解GRU原理的博客实在是太多了,看的人都有点头晕目眩。对于一个合格的程序员来说,其实这些更多看起来花里胡哨的,反而导致搞不清楚啥情况。
1.1 GRU公式
所以在本博客中,倾向首先用PyTorch文档中简单扼要的公式来展示,如下图所示,其中x
和h
是输入,h'
是输出,其他的W
和b
是各种权重,另外的就是一些中间结果了。
到了这一步,对于程序员来说,实现就没啥问题了,后边的代码基本上就是对此公式的一个具体展示。
额外提醒
:上述公式的最后一行,有时候也可以写成h'=z*n+(1-z)*h
,不过显然这两个是等价的。
1.2 GRU公式背后
GRU的原理其实是非常推荐知乎上的人人都能看懂的GRU,用一个清楚明了的图展示了GRU的整个过程,这里借用一下其中的图片,并展示GRU的各个流程。
1.2.1 计算reset,用于选择合适的信息
对于GRU而言,第一件事情就是计算reset参数(对应公式1),用于从历史状态 h t − 1 h^{t-1} ht−1和当前输入 x t x^t xt中选择合适的信息进行长期记忆(对应公式3)。
1.2.2 计算update,用于保留需要的信息
第二件事情是计算update参数(对应公式2),将选择出的长期记忆进行保留,并遗忘一定部分的历史记忆(对应公式4)。
2、GRUCell实现
GRUCell是GRU的基本结构,本节简单介绍一下,并做代码实现。
1、基于nn.Linear
的GRU实现
# https://github.com/emadRad/lstm-gru-pytorch/blob/master/lstm_gru.ipynb
class GRUCell(nn.Module):
"""
An implementation of GRUCell.
"""
def __init__(self, input_size, hidden_size, bias=True):
super(GRUCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias
self.x2h = nn.Linear(input_size, 3 * hidden_size, bias=bias)
self.h2h = nn.Linear(hidden_size, 3 * hidden_size, bias=bias)
self.reset_parameters()
def reset_parameters(self):
std = 1.0 / math.sqrt(self.hidden_size)
for w in self.parameters():
w.data.uniform_(-std, std)
def forward(self, x, hidden):
x = x.view(-1, x.size(1))
gate_x = self.x2h(x)
gate_h = self.h2h(hidden)
gate_x = gate_x.squeeze()
gate_h = gate_h.squeeze()
i_r, i_i, i_n = gate_x.chunk(3, 1)
h_r, h_i, h_n = gate_h.chunk(3, 1)
# 公式1
resetgate = F.sigmoid(i_r + h_r)
# 公式2
inputgate = F.sigmoid(i_i + h_i)
# 公式3
newgate = F.tanh(i_n + (resetgate * h_n))
# 公式4,不过稍微调整了一下公式形式
hy = newgate + inputgate * (hidden - newgate)
return hy
2、基于nn.Conv2d
的GRU实现
class GRUConvCell(nn.Module):
def __init__(self, input_channel, output_channel):
super(GRUConvCell, self).__init__()
# filters used for gates
gru_input_channel = input_channel + output_channel
self.output_channel = output_channel
self.gate_conv = nn.Conv2d(gru_input_channel, output_channel * 2, kernel_size=3, padding=1)
self.reset_gate_norm = nn.GroupNorm(1, output_channel, 1e-6, True)
self.update_gate_norm = nn.GroupNorm(1, output_channel, 1e-6, True)
# filters used for outputs
self.output_conv = nn.Conv2d(gru_input_channel, output_channel, kernel_size=3, padding=1)
self.output_norm = nn.GroupNorm(1, output_channel, 1e-6, True)
self.activation = nn.Tanh()
# 公式1,2
def gates(self, x, h):
# x = N x C x H x W
# h = N x C x H x W
# c = N x C*2 x H x W
c = torch.cat((x, h), dim=1)
f = self.gate_conv(c)
# r = reset gate, u = update gate
# both are N x O x H x W
C = f.shape[1]
r, u = torch.split(f, C // 2, 1)
rn = self.reset_gate_norm(r)
un = self.update_gate_norm(u)
rns = torch.sigmoid(rn)
uns = torch.sigmoid(un)
return rns, uns
# 公式3
def output(self, x, h, r, u):
f = torch.cat((x, r * h), dim=1)
o = self.output_conv(f)
on = self.output_norm(o)
return on
def forward(self, x, h = None):
N, C, H, W = x.shape
HC = self.output_channel
if(h is None):
h = torch.zeros((N, HC, H, W), dtype=torch.float, device=x.device)
r, u = self.gates(x, h)
o = self.output(x, h, r, u)
y = self.activation(o)
# 公式4
return u * h + (1 - u) * y
额外提醒
:上述两个实现都没有输出,当前主要是GRU并没有输出门,通常的做法是再添加了一个计算单元(例如nn.Linear
)将公式4结果转换成输出。
3、GRU网络示例
本节简单介绍一个GRU网络
class GRUNet(nn.Module):
def __init__(self, hidden_size=64):
super(GRUNet,self).__init__()
self.gru_1 = GRUConvCell(input_channel=4, output_channel=hidden_size)
self.gru_2 = GRUConvCell(input_channel=hidden_size,output_channel=hidden_size)
self.gru_3 = GRUConvCell(input_channel=hidden_size,output_channel=hidden_size)
self.fc = nn.Conv2d(in_channels=hidden_size,out_channels=1,kernel_size=3,padding=1)
def forward(self, x, h):
if h is None:
h = [None,None,None]
h1 = self.gru_1( x,h[0])
h2 = self.gru_2(h1,h[1])
h3 = self.gru_3(h2,h[2])
o = self.fc(h3)
return o,[h1,h2,h3]
if __name__ == '__main__':
from utils import *
device = 'cuda'
x = torch.rand(1,1,10,20).to(device)
grunet=GRUNet()
grunet=grunet.to(device)
grunet.eval()
h = None
o,h_n = grunet(x,h)
4、结论
上述简单总结了一下GRU的基本结构。