Pytorch学习(三)定义自己的数据集及加载训练

对于如何定义自己的Datasets我讲从以下几个方面进行解说
**

1.什么是Datasets?
2.为什么要定义Datasets?
3.如何定义Datasets?

定义Datasets分为以下几个板块:

1)Datasets的源代码及解说

2)Datasets的整体框架及解说

3)自己的Datasets框架及解说

4)DataLoader的使用

5)如何生成txt文件

什么是Datasets?

Datasets是我们用的数据集的库,我们知道pytorch自带多种数据集列如Cifar10数据集就是在pytorch的Datasets的库中的。

为什么要定义Datasets?

Pytorch中有工具函数torch.utils.Data.DataLoader,通过这个函数我们在准备加载数据集使用mini-batch的时候可以使用多线程并行处理,这样可以加快我们准备数据集的速度。Datasets就是构建这个工具函数的实例参数之一。

如何定义Datasets?

Dataset类
Dataset类是Pytorch中图像数据集中最为重要的一个类,也是Pytorch中所有数据集加载类中应该继承的父类。其中父类中的两个私有成员函数必须被重载,否则将会触发错误提示:

def getitem(self, index):
def len(self):

其中__len__应该返回数据集的大小,而__getitem__应该编写支持数据集索引的函数
这里重点看 getitem函数,getitem接收一个index,然后返回图片数据和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。

然而,如何制作这个list呢,通常的方法是将图片的路径和标签信息存储在一个txt中,然后从该txt中读取。
那么读取自己数据的基本流程就是:

制作存储了图片的路径和标签信息的txt
将这些信息转化为list,该list每一个元素对应一个样本
通过getitem函数,读取数据和标签,并返回数据和标签

定义自己的数据集类

1)Datasets的源代码及解说

All datasets are subclasses of torch.utils.data.Dataset i.e,
 they have __getitem__ and __len__ methods implemented. 
 Hence, they can all be passed to a torch.
 utils.data.DataLoader which can load multiple samples parallelly using torch.multiprocessing workers. 

[源代码地址(https://pytorch.org/docs/stable/torchvision/datasets.html)
从源代码我们可以看出继承Datasets必须继承__init_()和__getitim__()
首先继承上面的dataset类。然后在__init__()方法中得到图像的路径,然后将图像路径组成一个数组,这样在__getitim__()中就可以直接读取.

2)Datasets的整体框架及解说

class FirstDataset(data.Dataset):#需要继承data.Dataset
    def __init__(self):
        # TODO
        # 1. 初始化文件路径或文件名列表。
        #也就是在这个模块里,我们所做的工作就是初始化该类的一些基本参数。
        pass
    def __getitem__(self, index):
        # TODO

        #1。从文件中读取一个数据(例如,使用numpy.fromfile,PIL.Image.open)。
         #2。预处理数据(例如torchvision.Trans
评论 68
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值