pytorch collate_fn函数实现变长序列 - dynamical padding

注意:这里的batch指的是mini-batch

两种实现序列(文本、日志)批处理的方法

  1. 固定长度的序列(uniform length sequences in batches)
    所有batch内序列的长度一样。比如seqs = [[1,2,3,3,4,5,6,7], [1,2,3], [2,4,1,2,3], [1,2,4,1]]
    batch_size = 2
    那么最大序列长度取8,如果不足8用0填充到该长度
batch1 = [[1, 2, 3, 3, 4, 5, 6, 7], [1, 2, 3, 0, 0, 0, 0, 0]], 
batch2 = [[2, 4, 1, 2, 3, 0, 0, 0], [1, 2, 4, 1, 0, 0, 0, 0]]
  1. 变长的序列(variable length sequences in batches)
    每个batch的序列长度一致,不同batch之间的序列的长度可能不同。比如上面的例子,如果是变长的,那么先对序列长度排序,再按照每个batch内序列最大长度padding
batch1 = [[1, 2, 3, 0], [1, 2, 4, 1]]。 # len = 4
batch2 =  [[2, 4, 1, 2, 3, 0, 0, 0], [1, 2, 3, 3, 4, 5, 6, 7]] #len = 8

为什么要采用变长的序列呢?

如果训练数据中有非常短的序列,那么用一个统一的长度padding,会造成数据过于稀疏。有可能影响训练时间,以及模型的预测效果。

pytorch中如何实现变长的序列?

答:dynamical padding(动态填充)
根据前面的思路,实现动态填充主要两步

  1. 先根据序列长度排序
  2. 在每个batch里,选择序列最大长度,或者四分之三分为点的长度(防止极端情况,最大值非常的大)作为该batch的固定长度。

collate_fn 参数

collate_fn是DataLoader的一个属性,用来处理批次数据,官网介绍

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

代码实现

from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np

class MyDataset(Dataset):
    def __init__(self, seq, label):
        self.seq = seq
        self.label = label

    def __len__(self):
        return len(self.label)

    def __getitem__(self, index):
        return self.seq[index], self.label[index]


def collate_fn(batch):
    """
    args:
        batch: [[input_vector, label_vector] for seq in batch]

    return:
        [[output_vector]] * batch_size, [[label]]*batch_szie
    """


    percentile = 100
    dynamical_pad = True
    max_len = 50
    pad_index = 0

    lens = [len(dat[0]) for dat in batch]

    # find the max len in each batch
    if dynamical_pad:
        # dynamical padding
        seq_len = min(int(np.percentile(lens, percentile)), max_len)
        # or seq_len = max(lens)
    else:
        # fixed length padding
        seq_len = max_len
    print("collate_fn seq_len", seq_len)

    output = []
    out_label = []
    for dat in batch:
        seq = dat[0][:seq_len]
        label = dat[1][:seq_len]

        padding = [pad_index for _ in range(seq_len - len(seq))]
        seq.extend(padding)
        label.extend(padding)

        output.append(seq)
        out_label.append(label)

    output = torch.tensor(output, dtype=torch.long)
    out_label = torch.tensor(out_label, dtype=torch.long)

    return output, out_label


batch_size = 2
seqs = np.array([[1,0,3,3,4,5,6,0], [1,0,3], [2,4,0,2,3], [1,2,0,1]])
label = np.array([[0,2,0,0,0,0,0,7], [0,2,0], [0,0,1,0,0], [0,0,4,0]])

lens = np.array(list(map(len, seqs)))
len_index = np.argsort(-1 * lens)
seqs = seqs[len_index]
label = label[len_index]

mydataset = MyDataset(seqs, label)
dl = DataLoader(mydataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
for batch in dl:
    print(batch)


代码参考[1]

如果你用到了rnn-based的模型,那么可以了解一下pack_padded_sequence和 pad_packed_sequence这两个函数,作用是不将padding的值传入到模型训练。
pack_padded_sequence将填充过的模型压紧,然后将数据传入到模型训练。模型的结果用pad_packed_sequence恢复原来的维度

简单实现

#pack_padded_sequence so that padded items in the sequence won't be shown to the LSTM
X = torch.nn.utils.rnn.pack_padded_sequence(x, X_lengths, batch_first=True)

# now run through LSTM
X, self.hidden = self.lstm(X, self.hidden)

# undo the packing operation
X, _ = torch.nn.utils.rnn.pad_packed_sequence(X, batch_first=True)

中文参考见[2]
英文参考见[3],里面还提到用了padding怎么计算loss,推荐一看。pytorch自带的一些loss 函数比如NLLLoss,可以指定忽略padding 值。

参考资料

[1]https://www.kaggle.com/evilpsycho42/pytorch-batch-dynamic-padding-sort-pack
[2]https://blog.csdn.net/u011550545/article/details/89529977?utm_medium=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param&depth_1-utm_source=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param
[3]https://towardsdatascience.com/taming-lstms-variable-sized-mini-batches-and-why-pytorch-is-good-for-your-health-61d35642972e

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值