networkstream读取时获取传入的长度_pytorch的dataloader如何读取变长数据

v2-3dc948753cfe09eb77102625b2561573_1440w.jpg?source=172ae18b

最近在做一个新的声学模型,其中遇到一个点就是每个sentence的长度不一样的花,直接用dataloader的读取是有问题的。查了下中文资料,大家大多数这个问题都是趋于用torch.nn.utils.rnn.PackedSequence来打包的,这个在dataloader里面其实就不太适用,pytorch论坛上提到用dataloader的collate_fn来处理的,所以想写个资料总结下 。

pytorch里面dataset的工作逻辑:

pytorch的数据载入主要是这么几个逻辑,从底层一步步来讲,我用h5矩阵,图片和音频三个方面来举例,首先是逻辑层次是,首先把data装进用torch.utils.data.Dataset装进一个dataset的对象里面,然后在把dataset这个对象传递给一个torch.utils.data.DataLoader

dataset的工作逻辑

数据集的切分一般在dataset这个对象上做处理,支持随机切分等,详见torch.utils.data - PyTorch master documentation,一般来讲,我都是写一个torch.utils.data.Dataset的子类,里面就三个成员函数,初始化,长度和读取,一般在读取你自己定义的读取方法,我习惯的是h5矩阵的话,就读一段(子矩阵),图片就是一张图,或者一段音频。

这里面有个很关键的点,就是dataset的逻辑是一次读一个item,最好不要在dataset层面一次slice一段,slice这个层面的事情交给dataloader来做,原因我一会说

记住dataset的逻辑在于装和item读取,预处理,其他都不要做。

dataloader的工作逻辑

dataloader层面主要就是slice读取数据,shuffle也是在这个层面来做。

dataloader有几个关键点,很多地方都零零碎碎的提到过,我总结下,

  1. 是稀松平常的batch_size, sampler, shuffle这几个稀松平常的不提,shuffle是在dataset的item层面做混洗,
  2. 注意,num_workers是一个多线程的读取,当batchsize>1的时候,多线程读取item, 然后各个item调用一个collate_fn合并成新的tensor,其中h5依然是个坑,anaconda安装的h5是不支持多线程的,请参考并行 HDF5 和 h5py安装并行h5,至于num_worker以及pin_memoru的具体使用,参考云梦:Pytorch 提速指南,不重复造轮子。
  3. 关于这个collatefn是重点,当开启多线程了一个,多线程先后读取了dataset里面batch_size个item以后,生成了一个list,里面每个元素就是batchsize个item,然后用collatefn合并,如果没有指定的collatefns的话,就直接合并成一个高一维的tensor。

collatefns的工作逻辑

coolatefns的输入是个list,长度为batchsize,其中各个元素是各个item,函数的目的就是合并。

当各个item变长时,不指定collatefns合并就会报错,懒人方法就是把在dataset里面的读取函数把tensor加到最长,就可以直接merge。

当使用collatefns时,pytorch论坛上有人写了一个函数,我贴过来,大家配合注释看看:

def pad_tensor(vec, pad, dim):
    """
    args:
        vec - tensor to pad
        pad - the size to pad to
        dim - dimension to pad

    return:
        a new tensor padded to 'pad' in dimension 'dim'
    """
    pad_size = list(vec.shape)
    pad_size[dim] = pad - vec.size(dim)
    return torch.cat([vec, torch.zeros(*pad_size)], dim=dim)


class PadCollate:
    """
    a variant of callate_fn that pads according to the longest sequence in
    a batch of sequences
    """

    def __init__(self, dim=0):
        """
        args:
            dim - the dimension to be padded (dimension of time in sequences)
        """
        self.dim = dim

    def pad_collate(self, batch):
        """
        args:
            batch - list of (tensor, label)

        reutrn:
            xs - a tensor of all examples in 'batch' after padding
            ys - a LongTensor of all labels in batch
        """
        # find longest sequence
        max_len = max(map(lambda x: x[0].shape[self.dim], batch))
        # pad according to max_len
        batch = map(lambda (x, y):
                    (pad_tensor(x, pad=max_len, dim=self.dim), y), batch)
        # stack all
        xs = torch.stack(map(lambda x: x[0], batch), dim=0)
        ys = torch.LongTensor(map(lambda x: x[1], batch))
        return xs, ys

    def __call__(self, batch):
        return self.pad_collate(batch)

调用使用:

train_loader = DataLoader(ds, ..., collate_fn=PadCollate(dim=0))

来源:DataLoader for various length of data

对于读取了以后的数据,在rnn中的工作逻辑,pytorch的文档也提到过

total_length is useful to implement the packsequence->recurrentnetwork->unpacksequence pattern in a Module wrapped in DataParallel. See this FAQ section for details.

来源:torch.nn - PyTorch master documentation

关于读取到了的padding的变长数据,如何pack,请参考 @尹相楠的:

尹相楠:PyTorch 训练 RNN 时,序列长度不固定怎么办?​zhuanlan.zhihu.com

我就多说一句,pack_padded_sequence返回的packsequence虽说rnn都可以用,但我是cnn啊。。。。这个有人在pytorch论坛里提过,但是至今没有消息。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值