【Pytorch】RNN网络中pack和pad操作实践

详解RNN网络中文本的pack和pad操作中曾介绍了packpad的相关机制和相关函数的基本用法,本文在前文基础上,通过实例来进一步演示文本数据张量在整个计算过程中的变化。

1. 不进行pack操作的示例
import torch
import torch.nn as nn

a = torch.tensor([[1,0], [4,5]])    # B*L

torch.manual_seed(0)
ebd = nn.Embedding(10,3,padding_idx=0)  # 显示进行embedding的mask
lstm = nn.LSTM(input_size=3, hidden_size=3,batch_first=True)
embedding = ebd(a)
out, (h, c) = lstm(embedding)

embeding的结果中体现了padding_idx的作用。
在这里插入图片描述
但是由于RNN的递归式结构,其隐藏层结果不光取决于此时的输入,还取决于上一时间步的隐藏层,因此此时padding位置处的token的隐藏层结果仍有值(见下面的输出结果,分别表示outh张量)
在这里插入图片描述
在这里插入图片描述
由上可见,在不进行pack操作时进行RNN计算,不光增大了计算量,而且其输出结果并非可真实采信的。当然,在后续网络中,可通过人工的mask进行数据的屏蔽,但其比较费力,仍建议采用如下的pack操作。

2. 进行pack操作的示例
import torch
import torch.nn as nn

a = torch.tensor([[1,0], [4,5]])    # B*L

torch.manual_seed(0)
ebd = nn.Embedding(10,3)   # 此时完全可以不设置padding_index

lstm = nn.LSTM(input_size=3, hidden_size=3, batch_first=True)
embedding = ebd(a)
embedding_pack = pack_padded_sequence(embedding, lengths=(a!=0).sum(dim=-1), enforce_sorted=False, batch_first=True)

out, (h, c) = lstm(embedding_pack)

out2 = pad_packed_sequence(out2, batch_first=True)

embedding_pack的结果如下::
在这里插入图片描述
注意在最近版本的pytorch中,pack_padded_sequence函数添加了enforce_sorted参数,因此并不一定要对batch内的长度进行降序排序。

out的结果也是个PackedSequence对象:
在这里插入图片描述
我们将其利用pad_packed_sequence展开:
在这里插入图片描述
返回值为一个元组,第一个元素为展开的输出张量,其中padding位置处的out数据变为了0,第二个元素为各序列的实际文本长度。

对照下h,可见其截断到非padding位置处,与我们的需要相符和。
在这里插入图片描述

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值