PyTorch中的Dataset

注:本文源码基于PyTorch1.0,目前在PyTorch0.4下没有发现错误。

PyTorch中的Dataset是一个抽象类,我们可以通过继承Dataset来将数据集的源文件、规模和其他非必要的功能打包,从而供DataLoader使用。无论是官方给出的数据集如torchvision.datasets.MNIST等,还是我们在做实验时需要使用自己的数据集,都要继承Dataset类,在继承过程中,须重载的函数包括:

  1. __init__():构造函数,略过不说。
  2. __getitem__():_DataLoaderIter()类中有调用:
# https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html
batch = self.collate_fn([self.dataset[i] for i in indices])
  1. __len__():sampler(如SequentialSampler()类)中有调用len()函数:
# https://pytorch.org/docs/stable/_modules/torch/utils/data/sampler.html
class SequentialSampler(Sampler):
    """Samples elements sequentially, always in the same order.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)

这三个类是继承时必须重载的函数,我们也可以加入self.loader和self.transform等变量以方便后续处理。需要注意的是,Dataset类只相当于一个打包工具,包含了数据的地址。真正把数据读入内存的过程是由Dataloader进行批迭代输入的时候进行的。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值