注意:这里的batch指的是mini-batch
两种实现序列(文本、日志)批处理的方法
- 固定长度的序列(uniform length sequences in batches)
所有batch内序列的长度一样。比如seqs = [[1,2,3,3,4,5,6,7], [1,2,3], [2,4,1,2,3], [1,2,4,1]]
batch_size = 2
那么最大序列长度取8,如果不足8用0填充到该长度
batch1 = [[1, 2, 3, 3, 4, 5, 6, 7], [1, 2, 3, 0, 0, 0, 0, 0]],
batch2 = [[2, 4, 1, 2, 3, 0, 0, 0], [1, 2, 4, 1, 0, 0, 0, 0]]
- 变长的序列(variable length sequences in batches)
每个batch的序列长度一致,不同batch之间的序列的长度可能不同。比如上面的例子,如果是变长的,那么先对序列长度排序,再按照每个batch内序列最大长度padding
batch1 = [[1, 2, 3, 0], [1, 2, 4, 1]]。 # len = 4
batch2 = [[2, 4, 1, 2, 3, 0, 0, 0], [1, 2, 3, 3, 4, 5, 6, 7]] #len = 8
为什么要采用变长的序列呢?
如果训练数据中有非常短的序列,那么用一个统一的长度padding,会造成数据过于稀疏。有可能影响训练时间,以及模型的预测效果。
pytorch中如何实现变长的序列?
答:dynamical padding(动态填充)
根据前面的思路,实现动态填充主要两步
- 先根据序列长度排序
- 在每个batch里,选择序列最大长度,或者四分之三分为点的长度(防止极端情况,最大值非常的大)作为该batch的固定长度。
collate_fn 参数
collate_fn是DataLoader的一个属性,用来处理批次数据,官网介绍
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
代码实现
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
class MyDataset(Dataset):
def __init__(self, seq, label):
self.seq = seq
self.label = label
def __len__(self):
return len(self.label)
def __getitem__(self, index):
return self