这两个函数主要是用在RNN中处理变长序列的
我们来看下面的例子:
import torch
inputs=torch.LongTensor([[1,2,0,0],[7,5,4,0],[9,0,0,0],[2,5,8,7]])
inputs_length=torch.LongTensor([2,3,1,4])
其中inputs是我们假设的输入数据,值是0的位置代表该位置是pad的。
inputs_length是输入数据的实际长度。相当于说我们输入了4个句子,第一个句子长度为2,第二个句子长度为3,第三个句子长度为1,第四个句子长度是4
由于句子中有pad的位置,这些位置是不应该计算的,所以需要pack_padded_sequence,将pad位置去掉,它的原理是将输入数据排序,句子长的在前面
接下来我们把输入句子排序
print(inputs)
sorted_inputs_length,sorted_sequence_ids=torch.sort(inputs_length,descending=True)
print(sorted_inputs_length,sorted_sequence_ids)
sorted_inputs=inputs.index_select(0,sorted_sequence_ids)
print(sorted_inputs)
其中 sorted_sequence_ids是指每一个句子在整个batch_size中的索引,[3,1,0,2]是指
- sorted_inputs中的第一个句子是原批次句子中的第四个句子
- sorted_inputs中的第二个句子是原批次句子中的第二个句子
- sorted_inputs中的第三个句子是原批次句子中的第一个句子
- sorted_inputs中的第四个句子是原批次句子中的第三个句子
我们从图中也可以看到sorted_inputs和inputs之间的关系
接下来我们把排好序的sorted_inputs传给pack_padded_sequence
packed_inputs=torch.nn.utils.rnn.pack_padded_sequence(input=sorted_inputs,lengths=sorted_inputs_length,batch_first=True,enforce_sorted=True)
在打印结果之前我们要知道pack_padded_sequence具体是怎么做的。
我们可以看到sorted_inputs里面的每一个值
- [2,5,8,7]代表的是第一个句子,有4个时间步
- [7,5,4,0]代表的是第二个句子,有3个时间步
- [1,2,0,0]代表的是第三个句子,有2个时间步
- [9,0,0,0]代表最后一个句子,只有一个时间步
pack_padded_sequence统计每一个时间步对应有多少个batch,也就是说以每一列为单位,然后展开
具体的就是
- 第一个时间步有四个batch,即[2,7,1,9]
- 第二个时间步有三个batch,即[5,5,2]
- 第三个时间步有两个batch,即[8,4]
- 第四个时间步有一个batch,即[7]
所以最后的结果为[2,7,1,9,5,5,2,8,4,7]
结果正是这样,其中的batch_sizes代表的就是每一个时间步有多少的batch
接下来就是把pack_inputs输入给RNN。
由于这里只是展示两个函数的作用,就不列出RNN这一步了
接下来我们把packed_inputs还原回补全的样子
padded_inputs,padded_inputs_length=torch.nn.utils.rnn.pad_packed_sequence(packed_inputs,batch_first=True)
print(padded_inputs)
print(padded_inputs_length)
我们可以看到padded_inputs就是sorted_inputs,padded_inputs_length就是sorted_inputs_length
最后我们把padded_inputs也就是sorted_inputs还原回原来的顺序
print(sorted_sequence_ids)
_,original_sequence_ids=torch.sort(sorted_sequence_ids)
print(original_sequence_ids)
original_inputs=sorted_inputs.index_select(0,original_sequence_ids)
print(original_inputs)