PyTorch BiLSTM参数及packed形式的输出的组成

先回忆一下LSTM,直接上pytorch官网的公式截图:

它可以处理变长序列,每个rnn cell参数是一样的,共享,就是上面列出的那些W..。这里需要说明的是,PyTorch里面将W_{i*}统一放到了'weight_ih_l0'变量,将W_{h*}统一放到了'weight_hh_l0'变量。

 

BiLSTM包含一个从左到右和一个从右到左的并列的的序列计算。需要注意的是,两个方向使用的模型参数是不一样的,PyTorch里面这样表示:['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0', 'weight_ih_l0_reverse', 'weight_hh_l0_reverse', 'bias_ih_l0_reverse', 'bias_hh_l0_reverse']。

还需要注意一个点:BiLSTM的输出output是最后一层两个方向结果的concat形式,可以使用output.view(seq_len, batch,num_directions, hidden_size)还原两个方向分别的结果,num_directions=0或者1 分别表示前向结果和反向结果。

并且,序列最左边一个结果y_1是由正向第一个结果(最左边)和反向最后一个结果(最左边)concat而成,以此类推,如下图所示。但是,参数h_n表示的是正向最后一个结果(最右边)和反向最后一个结果(最左边)(没有concat),也就是说反向的顺序是按照计算顺序排列的,跟实际位置顺序相反!

h_n包含的是全部的正向信息和反向信息,体现的是计算顺序的信息;但是output[:, -1, :]只包含了正向全部信息和反向第一个的信息,不全面,体现的更多的是序列位置信息。

Image result for BiLSTM

这是很容易引起混淆的问题,需搞清楚。

相关验证代码如下:

In [1]: import torch

In [2]: import torch.nn as nn

In [3]: net=nn.LSTM(3,4,bidirectional=True,batch_first=True)    # 隐层尺度为4

In [4]: x=torch.rand(1,5,3)    # 序列长度为5,输入尺度为3

In [5]: net(x)
Out[5]:
(tensor([[[-0.0788,  0.0382, -0.2163, -0.1578, -0.0292,  0.1069, -0.0143,
            0.1737],
          [-0.1368,  0.0638, -0.2676, -0.2137, -0.0240,  0.0565, -0.0175,
            0.1829],
          [-0.1758,  0.0794, -0.2617, -0.2254, -0.0237,  0.0470, -0.0173,
            0.1671],
          [-0.1961,  0.0550, -0.3815, -0.2830, -0.0395,  0.0477,  0.0075,
            0.0992],
          [-0.2055,  0.0889, -0.4298, -0.3324, -0.0534,  0.0210,  0.0222,
           -0.0028]]], grad_fn=<TransposeBackward0>),
 (tensor([[[-0.2055,  0.0889, -0.4298, -0.3324]],

          [[-0.0292,  0.1069, -0.0143,  0.1737]]], grad_fn=<StackBackward>),
  tensor([[[-0.3894,  0.2452, -0.8997, -0.6065]],

          [[-0.1179,  0.2379, -0.0273,  0.3738]]], grad_fn=<StackBackward>)))

In [6]: x1=x[:,:4,:]    # 去掉一些右边信息,但左起信息保留

In [7]: net(x1)    # 从结果比较,可以看出concat顺序是先正向再反向,h_n正向部分是正向最后计算结果
Out[7]:
(tensor([[[-0.0788,  0.0382, -0.2163, -0.1578, -0.0271,  0.1041, -0.0193,
            0.1708],
          [-0.1368,  0.0638, -0.2676, -0.2137, -0.0191,  0.0523, -0.0238,
            0.1790],
          [-0.1758,  0.0794, -0.2617, -0.2254, -0.0145,  0.0400, -0.0265,
            0.1614],
          [-0.1961,  0.0550, -0.3815, -0.2830, -0.0188,  0.0366, -0.0101,
            0.0908]]], grad_fn=<TransposeBackward0>),
 (tensor([[[-0.1961,  0.0550, -0.3815, -0.2830]],

          [[-0.0271,  0.1041, -0.0193,  0.1708]]], grad_fn=<StackBackward>),
  tensor([[[-0.4375,  0.1530, -0.6968, -0.5101]],

          [[-0.1089,  0.2317, -0.0369,  0.3664]]], grad_fn=<StackBackward>)))

In [8]: x2=x[:,2:,:]    # 去掉一些左边信息,但右起信息保留

In [9]: net(x2)    # 通过结果比较,发现concat是按照实际序列位置进行的,且h_n的反向部分是反向最后计算结果
Out[9]:
(tensor([[[-0.0444,  0.0530, -0.1410, -0.1179, -0.0237,  0.0470, -0.0173,
            0.1671],
          [-0.1135,  0.0405, -0.3151, -0.2288, -0.0395,  0.0477,  0.0075,
            0.0992],
          [-0.1452,  0.0820, -0.4029, -0.3136, -0.0534,  0.0210,  0.0222,
           -0.0028]]], grad_fn=<TransposeBackward0>),
 (tensor([[[-0.1452,  0.0820, -0.4029, -0.3136]],

          [[-0.0237,  0.0470, -0.0173,  0.1671]]], grad_fn=<StackBackward>),
  tensor([[[-0.2708,  0.2182, -0.7877, -0.5484]],

          [[-0.0945,  0.1083, -0.0384,  0.3695]]], grad_fn=<StackBackward>)))

 

  • 6
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值