当每个训练数据为 sequence 的时候,我们第一反应是采用 RNN 以及其各种变体。这时新手们(我也是刚弄明白)往往会遇到这样的问题:训练数据 sequence 长度是变化的,难以采用 mini-batch 训练,这时应该怎么办,难道只能一个 sequence 一个 sequence 地训练吗?针对这一问题,本文记录 PyTorch 给出的解决方案。
需要用到的函数如下:
torch.nn.utils.rnn.pad_sequence()
torch.nn.utils.rnn.pack_padded_sequence()
torch.nn.utils.rnn.pad_packed_sequence()
pad_sequence
我们构造如下的训练数据,其中每条训练数据长度都不同。
import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
train_x = [torch.tensor([1, 1, 1, 1, 1, 1, 1]),
torch.tensor([2, 2, 2, 2, 2, 2]),
torch.tensor([3, 3, 3, 3, 3]),
torch.tensor([4, 4, 4, 4]),
torch.tensor([5, 5, 5]),
torch.tensor([6, 6]),
torch.tensor([7])]
x = rnn_utils.pad_sequence(train_x, batch_first=True)
x 将变成:
tensor([[1, 1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 0],
[3, 3, 3, 3, 3, 0, 0],
[4, 4, 4, 4, 0, 0, 0],
[5, 5, 5, 0, 0, 0, 0],
[6, 6, 0, 0, 0, 0, 0],
[7, 0, 0, 0, 0, 0, 0]])
我们发现,这个函数会把长度小于最大长度的 sequences 用 0 填充,并且把 list 中所有的元素拼成一个 tensor。这样做的主要目的是为了让 DataLoader 可以返回 batch,因为 batch 是一个高维的 tensor,其中每个元素的数据必须长度相同。
为了证明这一点,我们完整地写一个数据类,用 dataloader 按 batch 的形式读取数据,代码如下:
import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data import DataLoader
import torch.utils.data as data
train_x = [torch.tensor([1, 1, 1, 1, 1, 1, 1]),
torch.tensor([2, 2, 2, 2, 2, 2]),
torch.tensor([3, 3, 3, 3, 3]),
torch.tensor([4, 4, 4, 4]),
torch.tensor([5, 5, 5]),
torch.tensor([6, 6]),
torch.tensor([7])]
x = rnn_utils.pad_sequence(train_x, batch_first=True)
class MyData(data.Dataset):
def __init__(self, data_seq):
self.data_seq = data_seq
def __len__(self):
return len(self.data_seq)
def __getitem__(self, idx):
return self.data_seq[idx]
if __name__=='__main__':
data = MyData(train_x)
data_loader = DataLoader(data, batch_size=2, shuffle=True)
batch_x = iter(data_loader).next()
print('END')
我们将会收到如下报错:
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension
0. Got 3 and 7 in dimension 1 at
/pytorch/aten/src/TH/generic/THTensorMoreMath.cpp:1333
报错的原因是,不同的数据长度不同,无法组成一个 batch tensor。
DataLoader
中有个参数 collate_fn
,专门用来把 Dataset 类的返回值拼接成 tensor,我们不