pytorch collate_fn函数实现变长序列 - dynamical padding

本文介绍了在PyTorch中处理变长序列批处理的方法,特别是通过使用`collate_fn`进行动态填充。文章探讨了为何采用变长序列的原因,避免数据稀疏性影响训练效果,并提供了代码实现和相关参考资料,包括如何利用`pack_padded_sequence`和`pad_packed_sequence`处理RNN模型中的填充序列。
摘要由CSDN通过智能技术生成

注意:这里的batch指的是mini-batch

两种实现序列(文本、日志)批处理的方法

  1. 固定长度的序列(uniform length sequences in batches)
    所有batch内序列的长度一样。比如seqs = [[1,2,3,3,4,5,6,7], [1,2,3], [2,4,1,2,3], [1,2,4,1]]
    batch_size = 2
    那么最大序列长度取8,如果不足8用0填充到该长度
batch1 = [[1, 2, 3, 3, 4, 5, 6, 7], [1, 2, 3, 0, 0, 0, 0, 0]], 
batch2 = [[2, 4, 1, 2, 3, 0, 0, 0], [1, 2, 4, 1, 0, 0, 0, 0]]
  1. 变长的序列(variable length sequences in batches)
    每个batch的序列长度一致,不同batch之间的序列的长度可能不同。比如上面的例子,如果是变长的,那么先对序列长度排序,再按照每个batch内序列最大长度padding
batch1 = [[1, 2, 3, 0], [1, 2, 4, 1]]。 # len = 4
batch2 =  [[2, 4, 1, 2, 3, 0, 0, 0], [1, 2, 3, 3, 4, 5, 6, 7]] #len = 8

为什么要采用变长的序列呢?

如果训练数据中有非常短的序列,那么用一个统一的长度padding,会造成数据过于稀疏。有可能影响训练时间,以及模型的预测效果。

pytorch中如何实现变长的序列?

答:dynamical padding(动态填充)
根据前面的思路,实现动态填充主要两步

  1. 先根据序列长度排序
  2. 在每个batch里,选择序列最大长度,或者四分之三分为点的长度(防止极端情况,最大值非常的大)作为该batch的固定长度。
collate_fn 参数

collate_fn是DataLoader的一个属性,用来处理批次数据,官网介绍

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)
代码实现
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np

class MyDataset(Dataset):
    def __init__(self, seq, label):
        self.seq = seq
        self.label = label

    def __len__(self):
        return len(self.label)

    def __getitem__(self, index):
        return self
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值