已经有人解释的比较详细了,参考了一些我写在下面:
总结就是一句话:以batch为一个单位,对补齐后的batch进行压缩计算最后解压。减少这个batch中大量pad对输出的影响。
参考:pytorch中如何处理RNN输入变长序列padding - 知乎
https://www.cnblogs.com/lindaxin/p/8052043.html
如果已经看过类似解释的同学,可以直接跳到最后的代码展示。
当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练样例长度不同的情况,这样我们就会很自然的进行padding,将短句子padding为跟最长的句子一样。
比如向下图这样:
但是这会有一个问题,什么问题呢?比如上图,句子“Yes”只有一个单词,但是padding了5的pad符号,这样会导致LSTM对它的表示通过了非常多无用的字符,这样得到的句子表示就会有误差,更直观的如下图:
那么我们正确的做法应该是怎么样呢?
这就引出pytorch中RNN需要处理变长输入的需求了。在上面这个例子,我们想要得到的表示仅仅是LSTM过完单词"Yes"之后的表示,而不是通过了多个无用的“Pad”得到的表示。
主要是用函数torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这两个函数的用法。
torch.nn.utils.rnn.pack_padded_sequence():
这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)
输入的形状可以是(T×B×* )。T是最长序列长度,B是batch size,*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)。
Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后(特别注意需要进行排序)。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。
参数说明:
input (Variable) – 变长序列 被填充后的 batch
lengths (list[int]) – Variable 中 每个序列的长度。(知道了每个序列的长度,才能知道每个序列处理到多长停止)
batch_first (bool, optional) – 如果是True,input的形状应该是B*T*size。
返回值:
一个PackedSequence 对象。
这里的[]4, 3, 2] 是序列长度排序。
out_pad = torch.nn.utils.rnn.pack_padded_sequence(out, torch.tensor([4, 3, 2]), batch_first=True)
返回的output是PackedSequence类型的,可以使用:
encoder_outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
下面直接从代码层面解释这两个怎么使用:
假设有三段文本,我们已经把他补齐,作为一个bath传递进去,里面有大量的pad补齐字符。
import torch
input_tensor = torch.tensor([[1, 3, 5, 6, 2, 0, 0],
[1, 3, 5, 0, 0, 0, 0],
[1, 3, 0, 0, 0, 0, 0]])
接下来是embedding层。
embe = torch.nn.Embedding(10, 6)
out = embe(input_tensor)
print(out.shape)
# 结果
torch.Size([3, 7, 6])
正常的GRU层(可以与下面对比一下结果):
gru = torch.nn.GRU(6, 8, batch_first=True)
hidden_normal = torch.zeros(1, 3, 8)
output_normal, _ = gru(out, hidden_normal)
print(output_normal.shape)
print(output_normal)
# 结果
torch.Size([3, 7, 8])
tensor([[[-0.3121, 0.0188, -0.1041, -0.1437, -0.4423, 0.2555, 0.3690,
0.2136],
[ 0.1832, -0.2063, -0.0339, -0.3196, -0.6962, 0.2769, 0.3495,
0.0115],
[ 0.3326, -0.3881, 0.0615, -0.2771, -0.4755, 0.2857, 0.3597,
-0.4412],
[ 0.1384, -0.0065, 0.2262, -0.4853, -0.6944, -0.0467, 0.5761,
-0.3320],
[ 0.2038, 0.0938, -0.1772, -0.4974, -0.5730, -0.3191, 0.6605,
-0.3210],
[ 0.2620, 0.1287, -0.4169, -0.4849, -0.5390, -0.4803, 0.6889,
-0.2553],
[ 0.3085, 0.1449, -0.5499, -0.4641, -0.5323, -0.5623, 0.6914,
-0.1874]],
[[-0.3121, 0.0188, -0.1041, -0.1437, -0.4423, 0.2555, 0.3690,
0.2136],
[ 0.1832, -0.2063, -0.0339, -0.3196, -0.6962, 0.2769, 0.3495,
0.0115],
[ 0.3326, -0.3881, 0.0615, -0.2771, -0.4755, 0.2857, 0.3597,
-0.4412],
[ 0.3520, -0.0209, -0.2198, -0.3820, -0.4272, -0.1440, 0.5585,
-0.3779],
[ 0.3638, 0.0925, -0.4146, -0.4223, -0.4525, -0.3903, 0.6477,
-0.2921],
[ 0.3740, 0.1334, -0.5358, -0.4313, -0.4838, -0.5156, 0.6777,
-0.2139],
[ 0.3830, 0.1506, -0.6060, -0.4272, -0.5031, -0.5743, 0.6832,
-0.1545]],
[[-0.3121, 0.0188, -0.1041, -0.1437, -0.4423, 0.2555, 0.3690,
0.2136],
[ 0.1832, -0.2063, -0.0339, -0.3196, -0.6962, 0.2769, 0.3495,
0.0115],
[ 0.2611, 0.0369, -0.2880, -0.3972, -0.4885, -0.1832, 0.5467,
-0.1531],
[ 0.3099, 0.1091, -0.4588, -0.4242, -0.4651, -0.4192, 0.6352,
-0.1683],
[ 0.3431, 0.1368, -0.5622, -0.4278, -0.4822, -0.5314, 0.6664,
-0.1414],
[ 0.3657, 0.1500, -0.6208, -0.4224, -0.4980, -0.5815, 0.6743,
-0.1111],
[ 0.3810, 0.1572, -0.6524, -0.4151, -0.5059, -0.6024, 0.6746,
-0.0875]]], grad_fn=<TransposeBackward1>)
加入pack_padded_sequence,这个函数得到的输出是一个特殊的类,相当于压缩了的序列,pytorch中可以直接传递给下一个模块,我们要使用可以利用pad_packed_sequence将其解压出来。
这个结果可以看到,以最大长度(非补齐输入序列)补齐输出长度。这样减少了大量《pad》字符对输出的影响,直接将输出部分置0,只计算真正有输入的单词。
hidden = torch.zeros(1, 3, 8)
out_pad = torch.nn.utils.rnn.pack_padded_sequence(out, torch.tensor([4, 3, 2]), batch_first=True)
output, _ = gru(out_pad, hidden)
encoder_outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
print(encoder_outputs.shape)
print(encoder_outputs)
# 结果
torch.Size([3, 4, 8])
tensor([[[-0.3147, 0.2937, 0.3170, 0.0374, 0.0856, 0.1972, 0.1793,
-0.1815],
[-0.1413, -0.2737, 0.4023, -0.0043, -0.1145, 0.0961, 0.0909,
-0.1149],
[-0.2327, 0.0745, 0.5349, 0.0076, 0.1540, 0.1582, 0.2454,
-0.2582],
[-0.1467, -0.2010, 0.4935, 0.0996, -0.3427, 0.2260, 0.0455,
0.0056]],
[[-0.3147, 0.2937, 0.3170, 0.0374, 0.0856, 0.1972, 0.1793,
-0.1815],
[-0.1413, -0.2737, 0.4023, -0.0043, -0.1145, 0.0961, 0.0909,
-0.1149],
[-0.2327, 0.0745, 0.5349, 0.0076, 0.1540, 0.1582, 0.2454,
-0.2582],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000]],
[[-0.3147, 0.2937, 0.3170, 0.0374, 0.0856, 0.1972, 0.1793,
-0.1815],
[-0.1413, -0.2737, 0.4023, -0.0043, -0.1145, 0.0961, 0.0909,
-0.1149],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000]]], grad_fn=<TransposeBackward0>)