【pytorch】nn.GRU的使用

官方文档在这里

GRU具体不做介绍了,本篇只做pytorch的API使用介绍.


torch.nn.GRU(*args, **kwargs)

公式

公式
下面公式忽略bias,由于输入向量的长度和隐藏层特征值长度不一致,所以每个公式的W都按x和h分开。这跟理论公式部分有一些具体的实践上区别。

  • reset gate, 重置门
    r t = σ ( W i r x t + W h r h t − 1 ) r_t = \sigma(W_{ir}x_t+W_{hr}h_{t-1}) rt=σ(Wirxt+Whrht1) GRU里的参数是 W i r W_{ir} Wir W i r W_{ir} Wir
  • update gate,更新门
    z t = σ ( W i z x t + W h z h t − 1 ) z_t = \sigma(W_{iz}x_t+W_{hz}h_{t-1}) zt=σ(Wizxt+Whzht1) GRU里的参数是 W i z W_{iz} Wiz W h z W_{hz} Whz
  • 更新状态阈值
    n t = t a n h ( W i n x t + r t ( W h n h t − 1 ) ) n_t = tanh (W_{in}x_t+r_t(W_{hn} h_{t-1})) nt=tanh(Winxt+rt(Whnht1)) GRU里的参数是 W i n W_{in} Win W h n W_{hn} Whn
    这里同LSTM里的 g ( t ) g(t) g(t)函数,只是多了重置门对 h t − 1 h_{t-1} ht1的影响
  • 更新 h t h_t ht
    h t = ( 1 − z t ) n t + z t h t − 1 h_t = (1-z_t)n_t + z_t h_{t-1} ht=(1zt)nt+ztht1

GRU Cell图片

所以从输入张量和隐藏层张量来说,一共有两组参数(忽略bias参数)

  1. input 组 { W i r W_{ir} Wir W i z W_{iz} Wiz W i n W_{in} Win}
  2. hidden组 { W i r W_{ir} Wir W h z W_{hz} Whz W h n W_{hn} Whn }

官网参数
因为hidden size为隐藏层特征输出长度,所以每个参数第一维度都是hidden size;然后每一组是把3个张量按照第一维度拼接,所以要乘以3

举例代码

from torch import nn

gru = nn.GRU(input_size=3, hidden_size=5, num_layers=1, bias=False)

print('weight_ih_l0.shape = ', gru.weight_ih_l0.shape, ', weight_hh_l0.shape = ' , gru.weight_hh_l0.shape)

样例代码

双向GRU

如果要实现双向的GRU,只需要增加参数bidirectional=True

但是参数并没有增加。

from torch import nn

gru = nn.GRU(input_size=3, hidden_size=5, num_layers=1, bidirectional=True, bias=False)

print('weight_ih_l0.shape = ', gru.weight_ih_l0.shape, ', weight_ih_l0_reverse.shape = ', gru.weight_ih_l0_reverse.shape,
      '\nweight_hh_l0.shape = ' , gru.weight_hh_l0.shape, ', weight_hh_l0_reverse.shape = ', gru.weight_hh_l0_reverse.shape)

双向GRU

多层的概念

可以参考这里 https://blog.csdn.net/mimiduck/article/details/119975080

  • 8
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值