【PyTorch修炼】二、带你详细了解并使用Dataset以及DataLoader

一、前言

最近开始重新记载我学习的pytorch笔记。

今天讲的是加载数据的模块,为什么要写这个模块呢?

因为我最近自己生成了个全新的数据集,需要加载,所以顺便把这个部分复习整理一下,列出了我觉得需要知道的一些技术点。

提醒:这篇文章可能会出现大量的代码。

二、初时DataSet

研究事情咱们还是要归于本身,所以我们直接先看这个类的源码部分。

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

注释里说得清清楚楚,全部的其他数据集类都要继承这个类,里面的__len____getitem__是必须要重写的,不重写直接报错,也是可以想想看,你做个数据集还不带整体长度,也不带按照index得到你的每条信息,那要你这个dataset干啥呢,也就没不存在bachsize,遍历等等啦。

这里有个直接给出的TensorDataset我们可以先简单用下,这也算是官方给的一个案例了。

class TensorDataset(Dataset):
    """Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """

    def __init__(self, *tensors):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)

这里我采用一个我经常拿来做时序demo的数据集,北京pm2.5数据集,这里为了显示方便只取了128个数据,正好两个batch.

因为代码很easy,这里我也不做很多的注释了,会发现我在这里做了很多的类型的print,无非就是想告诉大家利用TensorDataset的时候传入的应该是tensor类型,如果是df需要先转换成numpy.array在转换成tensor,输出的也是tensor,事情其实可以分为以下三步:

  1. 加载数据,提取出feature和label,并转换成tensor

  2. 传入TensorDataset中,实例化TensorDataset为datsset

  3. 再将dataset传入到Dataloader中,最后通过enumerate输出我们想要的经过shuffle的bachsize大小的feature和label数据

ps:这里DataLoader的源码就先不介绍了,比较多,如果读者们想一起读这个源码的可以留言,我闲下来再写一篇一起学习交流,共同进步。

三、自定义DataSet

我个人觉得结构化的数据和图片维度的数据dataset写法稍有不同,这里做个简单的分类,并给出我写的demo

3.1 结构化数据

这里的话 主要在__init__做的事情差不多就是在前面所写的输入到TensorDataset之前做的一些事情,不过多了一个数据的len,之后在__getitem__中返回通过index得到的数据(X, Y)

这里补充几句,结构化的相对于图片的要好写很多,并且其实写这个类很自由的,如果你是对应于特定的数据集的话你就可以仿效我写的这个demo这样,直接在里面就表明你的路径了,当然你也可以把path写成传参的,如果你不写成特定的,那就在类外面进行一波操作,转换成tensor之后传入到MyDataset的这个类中,其实还是看具体你的需求了。

3.2 图片类数据

这里我挑了个简单一些的图片分类数据集,自己稍微删了删文件夹中的数据,为了方便演示,我只留了很少的数目。

代码demo如下:但是会报这个错误这个错误的主要原因:

这种错误有两种可能:(来自于 https://www.cnblogs.com/zxj9487/p/11531888.html)

1.你输入的图像数据的维度不完全是一样的,比如是训练的数据有100组,其中99组是256×256,但有一组是384×384,这样会导致Pytorch的检查程序报错

2.比较隐晦的batchsize的问题,Pytorch中检查你训练维度正确是按照每个batchsize的维度来检查的,比如你有1000组数据(假设每组数据为三通道256px×256px的图像),batchsize为4,那么每次训练则提取(4,3,256,256)维度的张量来训练,刚好250个epoch解决(250×4=1000)。但是如果你有999组数据,你继续使用batchsize为4的话,这样999和4并不能整除,你在训练前249组时的张量维度都为(4,3,256,256)但是最后一个批次的维度为(3,3,256,256),Pytorch检查到(4,3,256,256) != (3,3,256,256),维度不匹配,自然就会报错了,这可以称为一个小bug。

这里出的错误仔细一读,65和85,应该是图片的大小不一致造成的,因为我一共才留了,20多张图片。我也随手打开我的图片文件夹,发现

两个图片肉眼可见,像素应该不一样,看一下图片信息详情确实是这样。

解决方案这里我加了一个resize就好了。

输出,看来现在已经不存在上面所说的第二种问题了,不是整除也可以一样输出,问题不大。

我们把shuffle=True看下结果

除了上面那个错误之后还有个常见的错误

错误在于你在MyDataset中通过index得到的数据不是tensor类型,这里进行转换一下就好了,说白了就是dataloader输出的每个bachsize大小的数据最好是张量,我也建议大家这么写,这样也方便直接拿出来去模型训练了。

四、总结

  1. 首先我们要去构建自己继承Dataset的MyDataSet

  2. 传入到Dataloader中,最后进行enumerate遍历每个batchsize

  3. Dataset通过index输出的最好是tensor

  4. 整体的Dataset和Dataloader中,基本上的工作过程是Dataloader每次给你返回一个shuffle过的index(如果shuffle=True,如果是false则返回的index就是按照顺序依次得到batchsize大小的index集合),当然这个index是在你的__len__之内的范围中的,之后以这些index遍历数据集,通过 getitem(self, index)返回一组你要的(feature, label)

公众号后台回复:dataloader    返回这篇文章所有的代码以及数据连接

更多精彩内容(请点击图片进行阅读)

公众号:AI蜗牛车

保持谦逊、保持自律、保持进步

个人微信

备注:昵称+学校/公司+方向

如果没有备注不拉群!

拉你进AI蜗牛车交流群

点个在看,么么哒!
  • 5
    点赞
  • 29
    收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
©️2022 CSDN 皮肤主题:Age of Ai 设计师:meimeiellie 返回首页
评论

打赏作者

AI蜗牛车

你的鼓励将是我创作的最大动力

¥2 ¥4 ¥6 ¥10 ¥20
输入1-500的整数
余额支付 (余额:-- )
扫码支付
扫码支付:¥2
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值