Pytorch中的pack_padded_sequence与pad_packed_sequence

这两个函数主要是用在RNN中处理变长序列的

我们来看下面的例子:

import torch
inputs=torch.LongTensor([[1,2,0,0],[7,5,4,0],[9,0,0,0],[2,5,8,7]])
inputs_length=torch.LongTensor([2,3,1,4])

其中inputs是我们假设的输入数据,值是0的位置代表该位置是pad的。

inputs_length是输入数据的实际长度。相当于说我们输入了4个句子,第一个句子长度为2,第二个句子长度为3,第三个句子长度为1,第四个句子长度是4

由于句子中有pad的位置,这些位置是不应该计算的,所以需要pack_padded_sequence,将pad位置去掉,它的原理是将输入数据排序,句子长的在前面

接下来我们把输入句子排序

print(inputs)
sorted_inputs_length,sorted_sequence_ids=torch.sort(inputs_length,descending=True)
print(sorted_inputs_length,sorted_sequence_ids)
sorted_inputs=inputs.index_select(0,sorted_sequence_ids)
print(sorted_inputs)

在这里插入图片描述

其中 sorted_sequence_ids是指每一个句子在整个batch_size中的索引,[3,1,0,2]是指

  • sorted_inputs中的第一个句子是原批次句子中的第四个句子
  • sorted_inputs中的第二个句子是原批次句子中的第二个句子
  • sorted_inputs中的第三个句子是原批次句子中的第一个句子
  • sorted_inputs中的第四个句子是原批次句子中的第三个句子
    我们从图中也可以看到sorted_inputs和inputs之间的关系

接下来我们把排好序的sorted_inputs传给pack_padded_sequence

packed_inputs=torch.nn.utils.rnn.pack_padded_sequence(input=sorted_inputs,lengths=sorted_inputs_length,batch_first=True,enforce_sorted=True)

在打印结果之前我们要知道pack_padded_sequence具体是怎么做的。
我们可以看到sorted_inputs里面的每一个值

  • [2,5,8,7]代表的是第一个句子,有4个时间步
  • [7,5,4,0]代表的是第二个句子,有3个时间步
  • [1,2,0,0]代表的是第三个句子,有2个时间步
  • [9,0,0,0]代表最后一个句子,只有一个时间步

pack_padded_sequence统计每一个时间步对应有多少个batch,也就是说以每一列为单位,然后展开

具体的就是

  • 第一个时间步有四个batch,即[2,7,1,9]
  • 第二个时间步有三个batch,即[5,5,2]
  • 第三个时间步有两个batch,即[8,4]
  • 第四个时间步有一个batch,即[7]
    所以最后的结果为[2,7,1,9,5,5,2,8,4,7]

在这里插入图片描述
结果正是这样,其中的batch_sizes代表的就是每一个时间步有多少的batch

接下来就是把pack_inputs输入给RNN。
由于这里只是展示两个函数的作用,就不列出RNN这一步了
接下来我们把packed_inputs还原回补全的样子

padded_inputs,padded_inputs_length=torch.nn.utils.rnn.pad_packed_sequence(packed_inputs,batch_first=True)
print(padded_inputs)
print(padded_inputs_length)

在这里插入图片描述
我们可以看到padded_inputs就是sorted_inputs,padded_inputs_length就是sorted_inputs_length

最后我们把padded_inputs也就是sorted_inputs还原回原来的顺序

print(sorted_sequence_ids)
_,original_sequence_ids=torch.sort(sorted_sequence_ids)
print(original_sequence_ids)
original_inputs=sorted_inputs.index_select(0,original_sequence_ids)

print(original_inputs)

在这里插入图片描述

转了一圈回来了

  • 5
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值