Pytorch的ConcatDataset学习

训练神经网络非常重要的一个步骤就是准备数据,甚至有时候比写训练模型的代码还要重要。然而网上的一些例程多半是使用的经典的数据集如MNIST、CIFAR-100等作为例子。这些数据集都是被人家处理好了的,加载进来放到dataloader里面就可以用。而在我们自己的实际任务中,可能数据集很大,不可能一次性把所有数据都加载到内存中,所以就需要对整个数据集划分成许多个子数据集,分别存储、加载。但是这时会有一个问题就是一个数据集中包含多少个样本呢?如何协调这个值与batch size的关系?也就是说我们有时候要对这些子数据集进行处理,如合并操作等等。那么你可以自己写一个队列用于不断地给模型提供所需数量的数据,今天我们讨论一种pytorch自带的一种合并子数据集的方式,ConcatDataset类,学习一下它的源码。

首先,它继承自Dataset类。

其次,它的构造函数要求一个包含若干个子数据集的列表L作为输入,并且这些数据集不能是可迭代的数据集(Iterable Dataset)。构造函数会计算一个cumulative size,即“把L中的第n个子数据集算上后,现在我一共有多少个样本”,这样也得到一个列表。

然后,它重写了__len__方法,返回cumulative_size[-1];重写了__getitem__,就是假如dataloader想通过索引来它这里取数据的时候,它应该返回什么.

def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

因为cumulative_size是递增的,所以它用了一个二分查找的包:bisect,找到第一个大于idx的索引dataset_idx。注意,idx是一个全局索引,意思是“我要拿到被Concat之后的这个大的数据集里面的第idx个样本”,所以具体实现的时候我们就需要知道要到第几个子数据集中去找第几个样本。而dataset_idx就是第几个子数据集,这也是后面几行代码的意思了。

 

这就是ConcatDataset的主要内容,它在连接一些小数据集的时候很有用。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值