当每个训练数据为 sequence 的时候,我们第一反应是采用 RNN 以及其各种变体。这时新手们(我也是刚弄明白)往往会遇到这样的问题:训练数据 sequence 长度是变化的,难以采用 mini-batch 训练,这时应该怎么办,难道只能一个 sequence 一个 sequence 地训练吗?针对这一问题,本文记录 PyTorch 给出的解决方案。
需要用到的函数如下:
torch.nn.utils.rnn.pad_sequence()
torch.nn.utils.rnn.pack_padded_sequence()
torch.nn.utils.rnn.pad_packed_sequence()
pad_sequence
我们构造如下的训练数据,其中每条训练数据长度都不同。
import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
train_x = [torch.tensor([1, 1, 1, 1, 1, 1, 1]),
torch.tensor([2, 2, 2, 2, 2, 2]),
torch.tensor([3, 3, 3, 3, 3]),
torch.tensor([4, 4, 4, 4]),
torch.tensor([5, 5, 5]),
torch.tensor([6, 6]),
torch.tensor([7])]
x = rnn_utils.pad_sequence(train_x, batch_first=True)
x 将变成:
tensor([[1, 1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 0],
[3, 3, 3, 3, 3, 0, 0],
[4, 4, 4, 4, 0, 0, 0],
[5, 5, 5, 0, 0, 0, 0],
[6, 6, 0, 0, 0, 0, 0],
[7, 0, 0, 0, 0, 0, 0]])
我们发现,这个函数会把长度小于最大长度的 sequences 用 0 填充,并且把 list 中所有的元素拼成一个 tensor。这样做的主要目的是为了让 DataLoader 可以返回 batch,因为 batch 是一个高维的 tensor,其中每个元素的数据必须长度相同。
为了证明这一点,我们完整地写一个数据类,用 dataloader 按 batch 的形式读取数据,代码如下:
import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data import DataLoader
import torch.utils.data as data
train_x = [torch.tensor([1, 1, 1, 1, 1, 1, 1]),
torch.tensor([2, 2, 2, 2, 2, 2]),
torch.tensor([3, 3, 3, 3, 3]),
torch.tensor([4, 4, 4, 4]),
torch.tensor([5, 5, 5]),
torch.tensor([6, 6]),
torch.tensor([7])]
x = rnn_utils.pad_sequence(train_x, batch_first=True)
class MyData(data.Dataset):
def __init__(self, data_seq):
self.data_seq = data_seq
def __len__(self):
return len(self.data_seq)
def __getitem__(self, idx):
return self.data_seq[idx]
if __name__=='__main__':
data = MyData(train_x)
data_loader = DataLoader(data, batch_size=2, shuffle=True)
batch_x = iter(data_loader).next()
print('END')
我们将会收到如下报错:
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension
0. Got 3 and 7 in dimension 1 at
/pytorch/aten/src/TH/generic/THTensorMoreMath.cpp:1333
报错的原因是,不同的数据长度不同,无法组成一个 batch tensor。
DataLoader
中有个参数 collate_fn
,专门用来把 Dataset 类的返回值拼接成 tensor,我们不设置的时候,会调用 default 的函数,这次我们的训练数据长度不一,default 函数就 hold 不住了,因此我们要自定义一个 collate_fn
,并在 DataLoader 中设置这个参数,再运行就不会报错了(注意代码中对 data 先按照长度降序排列了一下,后面会讲到原因)。
def collate_fn(data):
data.sort(key=lambda x: len(x), reverse=True)
data = rnn_utils.pad_sequence(data, batch_first=True, padding_value=0)
return data
if __name__=='__main__':
data = MyData(train_x)
data_loader = DataLoader(data, batch_size=3, shuffle=True,
collate_fn=collate_fn)
batch_x = iter(data_loader).next()
print('END')
运行结果如下:
batch_x
Out[2]:
tensor([[1, 1, 1, 1, 1, 1, 1],
[3, 3, 3, 3, 3, 0, 0],
[6, 6, 0, 0, 0, 0, 0]])
正是我们想要的。
pack_padded_sequence
我们通过 pad_sequence
得到了 padded_sequence
,那么直接扔进 RNN 训练不就完了吗?为啥还要用 pack_padded_sequence
?这个 pack 又是什么意思呢?
我们回忆一下 RNN 是如何训练的,首先考虑单个训练数据,也就是batch_size
=1 的情况:每次网络吃进一个 time step 的数据+该数据对应的 hidden state,然后输出,再继续吃进去第二个 time step 的数据 + hidden state,再输出,以此类推;如果换成 mini-batch 的训练模式则是:网络每次吃进去一组同样 time step 的数据,也就是mini-batch 中所有 sequence 中相同下标的数据,加上它们对应的 hidden state,获得一个 mini-batch 的输出,然后再移到下一个 time step,再读入 mini-batch 中所有该 time step 的数据,再输出……
因此,以上面 pad_sequence
的输出为例,数据将会按照如图所示的方式读取:
![0caa994e643ad72980c39fd2827d9432.png](https://i-blog.csdnimg.cn/blog_migrate/16fdaa340485537f85e84bb568d14de4.jpeg)
网络读取数据的顺序是:[1, 3, 6],[1, 3, 6],[1, 3, 0],[1, 3, 0],[1, 3, 0],[1, 0, 0],[1, 0, 0]。而该 mini-batch 中的 0 是没有意义的 padding,只是为了用来让它和最长的数据对齐而已,显然这种做法浪费了大量计算资源。因此,我们将用到 pack_padded_sequence
。即,不光要 padd,还要 pack。
pack_padded_sequence
有三个参数:input, lengths, batch_first
。input
是上一步加过 padding 的数据,lengths
是各个 sequence 的实际长度,batch_first
是数据各个 dimension 按照 [batch_size, sequence_length, data_dim]
顺序排列。
上面例子中,batch_x 为:
batch_x
Out[2]:
tensor([[1, 1, 1, 1, 1, 1, 1],
[3, 3, 3, 3, 3, 0, 0],
[6, 6, 0, 0, 0, 0, 0]])
因此应该设置 lengths=[7, 5, 2]
rnn_utils.pack_padded_sequence(batch_x, [7,5,2], batch_first=True)
Out[3]: PackedSequence(
data=tensor([1., 3., 6., 1., 3., 6., 1., 3., 1., 3., 1., 3., 1., 1.]),
batch_sizes=tensor([3, 3, 2, 2, 2, 1, 1]))
我们发现,它的输出有两部分,分别是 data
和 batch_sizes
,第一部分为原来的数据按照 time step 重新排列,而 padding 的部分,直接空过了。batch_sizes
则是每次实际读入的数据量,也就是说,RNN 把一个 mini-batch sequence 又重新划分为了很多小的 batch,每个小 batch 为所有 sequence 在当前 time step 对应的值,如果某 sequence 在当前 time step 已经没有值了,那么,就不再读入填充的 0,而是降低 batch_size
。batch_size
相当于是对训练数据的重新划分。这也是为什么前面在 collate_fn
中我们要对 mini-batch 中的 sequence 按照长度降序排列,是为了方便我们取每个 time step 的batch,防止中间夹杂着 padding。 而每个 mini-batch 中 sequence 的真实 length 又如何获得呢?这就要重新修改 collate_fn
了,我们在其中加入data_length=[len(sq) for sq in data]
修改后的代码如下:
def collate_fn(data):
data.sort(key=lambda x: len(x), reverse=True)
data_length = [len(sq) for sq in data]
data = rnn_utils.pad_sequence(data, batch_first=True, padding_value=0)
return data, data_length
if __name__=='__main__':
data = MyData(train_x)
data_loader = DataLoader(data, batch_size=3, shuffle=True,
collate_fn=collate_fn)
batch_x, batch_x_len = iter(data_loader).next()
batch_x_pack = rnn_utils.pack_padded_sequence(batch_x,
batch_x_len, batch_first=True)
pad_packed_sequence
一看名字就知道,这个函数和前面的函数是一对。有点像西游记里的奔波儿灞和灞波儿奔。
上文的例子中,我们为了直观,没有考虑到 RNN 对数据维度的要求,因此在这里我们要重新改写 collate_fn
使其返回的数据符合 [batch, sequence_len, input_size]
的格式(我们设置网络为 batch_first
的模式,更符合习惯)。在例子中,每个 sequence 的元素维度都是1,因此只需要在 tensor 末尾加一维就好了,即对返回的数据 unsqueeze(-1) 一下(也可以在数据库的类中,对 _getitem_
的返回值 unsqueeze)。
def collate_fn(data):
data.sort(key=lambda x: len(x), reverse=True)
data_length = [len(sq) for sq in data]
data = rnn_utils.pad_sequence(data, batch_first=True, padding_value=0)
return data.unsqueeze(-1), data_length
修改后,batch_x
和batch_x_pack
分别为:
batch_x
Out[2]:
tensor([[[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.]],
[[3.],
[3.],
[3.],
[3.],
[3.],
[0.],
[0.]],
[[6.],
[6.],
[0.],
[0.],
[0.],
[0.],
[0.]]])
batch_x_pack
Out[3]:
PackedSequence(data=tensor([
[1.],
[3.],
[6.],
[1.],
[3.],
[6.],
[1.],
[3.],
[1.],
[3.],
[1.],
[3.],
[1.],
[1.]]), batch_sizes=tensor([3, 3, 2, 2, 2, 1, 1]))
符合我们的预期。
接下来,我们随机初始化 hidden state 和 cell state (维度为:num_layers * num_directions, batch, hidden_size
), 和batch_x_pack
一起送入LSTM中。
if __name__=='__main__':
data = MyData(train_x)
data_loader = DataLoader(data, batch_size=3, shuffle=True,
collate_fn=collate_fn)
batch_x, batch_x_len = iter(data_loader).next()
batch_x_pack = rnn_utils.pack_padded_sequence(batch_x,
batch_x_len, batch_first=True)
net = nn.LSTM(1, 10, 2, batch_first=True)
h0 = torch.rand(2, 3, 10)
c0 = torch.rand(2, 3, 10)
out, (h1, c1) = net(batch_x_pack, (h0, c0))
print('END')
其中 LSTM 输入为 1 维,hidden size 为 10 ,总共两层。经过一次前向传播,我们得到 out
。out
和 batch_x_pack
一样,分为两部分: data
和 batch_sizes
。观察一下它这两部分:
out.data.shape
Out[5]: torch.Size([14, 10])
batch_x_pack.data.shape
Out[6]: torch.Size([14, 1])
out.batch_sizes
Out[7]: tensor([3, 3, 2, 2, 2, 1, 1])
batch_x_pack.batch_sizes
Out[8]: tensor([3, 3, 2, 2, 2, 1, 1])
输入的 mini-batch 中,统计所有 time step 共有 14 个非零的数据,而 LSTM 的 hidden unit 有10维,故 out.data.shape
为 torch.Size([14, 10])
。而out.batch_sizes
则和 batch_x_pack.batch_sizes
相同,都是 tensor([3, 3, 2, 2, 2, 1, 1])
。
pad_packed_sequence
执行的是 pack_padded_sequence
的逆操作,执行下面的代码,观察输出。
out, (h1, c1) = net(batch_x_pack, (h0, c0))
out_pad, out_len = rnn_utils.pad_packed_sequence(out, batch_first=True)
out_pad.shape
Out[2]: torch.Size([3, 7, 10])
out.data.shape
Out[3]: torch.Size([14, 10])
out_len
Out[4]: tensor([7, 5, 2])
我们发现,经过这样的操作后out_pad
形状变成了[3, 7, 10]
,仿佛我们直接输入加了padding 的 mini-batch ,mini-batch 中有 3 个 sequence,每个 sequence 有 7 个 time step,每个 time step 数据从输入的 1 维,映射成 LSTM 的 10 维,此外它还输出了 out_len
,为 [7, 5, 2]
,即每个 sequence 的真实长度。 为了放心,我们再看一下out_pad[1]
是什么:
out_pad[1].shape
Out[11]: torch.Size([7, 10])
out_pad[1]
Out[12]:
tensor([[ 0.0027, -0.0135, 0.1366, -0.0420, 0.3269, 0.0726, -0.0872, -0.0409,
0.1267, 0.2546],
[-0.0365, -0.0574, 0.0436, -0.0346, 0.2652, -0.0088, -0.0881, -0.0700,
0.1753, 0.2102],
[-0.0557, -0.0865, -0.0048, -0.0317, 0.1738, -0.0366, -0.0858, -0.0805,
0.1873, 0.1898],
[-0.0684, -0.1015, -0.0261, -0.0301, 0.1045, -0.0485, -0.0828, -0.0843,
0.1961, 0.1764],
[-0.0769, -0.1085, -0.0354, -0.0294, 0.0567, -0.0542, -0.0807, -0.0857,
0.2019, 0.1671],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000]], grad_fn=<SelectBackward>)
下标为 1 的 sequence 真实长度是 5 ,第 6 、7 个 time step 是填充的 0,因此它对应的输出第 6 、7 行都是 0,符合我们的预期。
总结
torch.nn.utils.rnn.pad_sequence()
torch.nn.utils.rnn.pack_padded_sequence()
torch.nn.utils.rnn.pad_packed_sequence()
上面三个函数相互配合,可以在 sequence 长度变化时,成批读入数据,训练 RNN。第一个函数用于给 mini-batch 中的数据加 padding,让 mini-batch 中所有 sequence 的长度等于该 mini-batch 中最长的那个 sequence 的长度。
第二、三个函数,用于提高效率,避免 LSTM 前向传播时,把加入在训练数据中的 padding 考虑进去。因此第二、三个函数理论上可以不用,但为了提高效率最好还是用。
除此之外,本文还介绍了 DataLoader
的collate_fn
参数,用于把 Dataset
类的 __getitem__
方法的返回的 batchsize 个值拼接成一个 tensor。
全部代码如下:
import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data import DataLoader
import torch.utils.data as data
train_x = [torch.Tensor([1, 1, 1, 1, 1, 1, 1]),
torch.Tensor([2, 2, 2, 2, 2, 2]),
torch.Tensor([3, 3, 3, 3, 3]),
torch.Tensor([4, 4, 4, 4]),
torch.Tensor([5, 5, 5]),
torch.Tensor([6, 6]),
torch.Tensor([7])
]
x = rnn_utils.pad_sequence(train_x, batch_first=True)
class MyData(data.Dataset):
def __init__(self, data_seq):
self.data_seq = data_seq
def __len__(self):
return len(self.data_seq)
def __getitem__(self, idx):
return self.data_seq[idx]
def collate_fn(data):
data.sort(key=lambda x: len(x), reverse=True)
data_length = [len(sq) for sq in data]
data = rnn_utils.pad_sequence(data, batch_first=True, padding_value=0)
return data.unsqueeze(-1), data_length
if __name__=='__main__':
data = MyData(train_x)
data_loader = DataLoader(data, batch_size=3, shuffle=True,
collate_fn=collate_fn)
batch_x, batch_x_len = iter(data_loader).next()
batch_x_pack = rnn_utils.pack_padded_sequence(batch_x,
batch_x_len, batch_first=True)
net = nn.LSTM(1, 10, 2, batch_first=True)
h0 = torch.rand(2, 3, 10)
c0 = torch.rand(2, 3, 10)
out, (h1, c1) = net(batch_x_pack, (h0, c0))
out_pad, out_len = rnn_utils.pad_packed_sequence(out, batch_first=True)
print('END')