先提供一个官网解读
https://pytorch.org/docs/1.0.1/nn.html#torch.nn.utils.rnn.pack_padded_sequence
- 在使用深度学习特别是LSTM进行文本分析时,经常会遇到文本长度不一样的情况,此时就需要对同一个batch中的不同文本使用padding的方式进行文本长度对齐,方便将训练数据输入到LSTM模型进行训练,同时为了保证模型训练的精度,应该同时告诉LSTM相关padding的情况,此时,pytorch中的pack_padded_sequence就有了用武之地。
- 通常pading的位置向量都是0,我们需要使用pack_padded_sequence() 把数据压紧,即去掉pading的部分,减少冗余。然后再输入网络中,如lstm等。通过网络后的结果也是压紧的,需要通过pad_packed_sequence()还原。
torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=False)
input (Tensor) – padded batch of variable length sequences.
lengths (Tensor) – list of sequences lengths of each batch element.
batch_first (bool, optional) – if True, the input is expected in B x T x * format.
- pad_packed_sequence 解压
代码示例
import torch
import torch.nn as nn
import torch.nn.utils as utils
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# 定义一个双向lstm网络层
lstm = nn.LSTM(4, 100, num_layers=1, batch_first=True, bidirectional=True)
# 定义一个有padding的序列数据,也就是有冗余的0
x = torch.tensor([[[1,2,3,4],
[2,3,4,5],
[2,5,6,0]],
[[1,2,1,1],
[1,6,7,9],
[0,0,0,0]],
[[1,2,3,4],
[1,1,1,1],
[0,0,0,0]],
[[1,2,3,4],
[0,0,0,0],
[0,0,0,0]],
])
x = x.float()
# 压紧数据,去掉冗余
packed = pack_padded_sequence(x, torch.tensor([3, 2, 2,1]), batch_first=True) # 打包,压缩
# 通过lstm进行计算,得到的结果也是压紧的
output, hidden = lstm(packed)
# 解压
encoder_outputs, lenghts = pad_packed_sequence(output, batch_first=True)