Pytorch(3)-数据载入接口:Dataloader、datasets

1.数据载入概况

数据是机器学习算法的驱动力, Pytorch提供了方便的数据载入和处理接口. 数据载入流程为:

step1: 指定要使用的数据集dataset
step2: 使用Dataloader载入数据

dataloader实质是一个可迭代对象,不能使用next()访问。但如果使用iter()封装,返回一个迭代器,可以使用.next()操作。

Dataloader 是啥

来自官网document的描述:

Dataloader. Combines a dataset and a sampler, and provides an iterable over the given dataset.
The DataLoader supports both map-style and iterable-style datasets with single- or multi-process loading, 
customizing loading order and optional automatic batching (collation) and memory pinning.
See torch.utils.data documentation page for more details.

大概就是说:用来对数据集进行(小批次)迭代 载入的接口,所能够载入的数据集要么支持map-style操作,要么支持 iterable-style操作。

(这两种操作只有在编写用户数据类时才需要考虑,使用内置公开数据集和.ImageFolder不需要管这两者是啥东西,开发者已经帮你写好了)

2.支持的三类数据集

1.torchvision.datasets–内置了许多常见的公开数据集

2.torchvision.datasets.ImageFolder–用户定制数据集1(只要自己的数据集满足ImageFolder要求的格式,提供数据集所在的地址即可)

3.定制数据集–需要编写自己的dataset 类

2.1 torchvision.datasets.xxx

一些常用的公开数据集合,可以在torchvision.datasets接口中找到。

例如–MNIST、Fashion-MNIST、KMNIST、EMNIST、FakeData、COCO、Captions、Detection、LSUN、ImageFolder、DatasetFolder、ImageNet、CIFAR、STL10、SVHN、PhotoTour、SBU、Flickr、VOC、Cityscapes、SBD等常用数据集合。

torchvision.datasets在使用一个新的数据集合前,需要保证本地拥有该数据集合(符合pytorch内部编码格式)。最简单额方式是第一次使用时,将download=True将默认将该数据集下载到指定的root 目录中。

CIFAR10数据集使用的例子

transform = transforms.Compose( [transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=2)
默认值:train=True

step1 数据集选择与图片处理方式选择

trainset = torchvision.datasets.CIFAR10(root=’./data’, train=True,download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root=’./data’, train=False,download=True, transform=transform)

参数解释:
1.root=’./data’
数据集的保存目录,各种数据集有自己的文件格式,其中MNIST是以training.pt和test.pt的保存图像数据信息(具体看一下文件应该怎么存,读入之后的列表和迭代器各是什么内容

2.train =True
处理MNIST时从training.pt读取训练数据,=False 从test.pt读取测试数据。仔细观察,上面两句话只有在train这个选项处不同.

3.download =True
会从网上下载对应的数据集文件,MNIST对应.pt文件,如果存在 .pt 文件,这个参数可以设置为False

4.transform
设置一组对图像进行处理的操作,这一组操作由Compose组成,这一组compose 的顺序还很重要按如下顺序编写:
transforms.Resize()
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),

step2 数据载入接口

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

参数解释
1.将刚刚生成的trainset列表传入 torch.utils.data.DataLoader()

2.batch_size=4 设定图像数据批次大小

3.shuffle=True 每一个epoch过程中会打乱数据顺序,重新随机选择

4.导入数据时的线程数目,默认为0,主线程导入数据

2.2 torchvision.datasets.ImageFolder

当数据集超出1中所提供数据集的范围时,Pytorch还提供了ImageFolder数据集导入方式。只要将数据按照一定的要求存放,就能如方式1一样方便取用。

数据集合格式要求:同类别的图像放在一个文件夹下,用类别名称/标号来命名文件夹。要自己手工设计训练集合、测试集合

x=torch.datasets.ImageFolder(root="图像集合中文件夹路径”)

x是一个ImageFolder格式的数据:
在这里插入图片描述
其中重要主要成员为:
class_to_idx ={dict} 是字典数据,以“文件夹名字:分配类别序号”作为键值的字典
classes ={list} 包含所有文件夹名字的一个序列
imgs={list} 列表元素为–(图像路径,对应文件夹名)

使用torch.utils.data.DataLoader载入数据:

trainloader = torch.utils.data.DataLoader(x, batch_size=4, shuffle=True, num_workers=4)

参考网址:
https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

2.3 写自己的数据类,读入定制化数据

当用户数据个格式不能用以上两种方式读取时,可以尝试写自己的数据类

所有的datasets都是torch.utils.data.Dataset的子类,方法1中使用的是torchvision.datasets.数据集,方法 2中使用的是torchvision.datasets.ImageFolder。当我们在编写自己的数据类时,也需要继承Dataset类。

2.3.1 数据类的编写

在介绍Dataloader 使提到过,其载入的数据类需要满足两者操作中的一个(map-style操作/iterable-style操作)

map-style范式

Map-style 操作范式数据类的核心:实现了 getitem() 和 len()方法,通过data[index]获取数据样本和相应的标签。

猜测:DataLoader 在导入minibatch数据时,随机采样一批index(通过len确认index 的采样范围), 然后在经过getitem获取相应的数据

class MyDataset:
    def __init__(self, gentor: object, batchSize: int, imgSize: int):
		# 从源数据中读取数据列表,或者能操作数据的名称列表
    def __len__(self):
    	# 返回数据集样本的数量
        return sample_map_num
    def __getitem__(self, idx:int):
   		# 通过idx获取数据data
   		data = get(idx)   // get 依据不同的数据集定制
   		// 进行一些tansform操作在返回
        return data

官方实践demo:https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

iterable-style 范式

Iterable-style 操作范式数据类 是 IterableDataset的子类,实现了__iter__()方法。当随机读取非常耗时/无法实现时。(数据流,实时记录的数据)
有机会实践一下

2.3.2 DataLoader 导入数据类

编写好了自己的数据类之后,同其他两种数据类一样使用DataLoader导入数据即可。

    train_set = MyDataset()
    data = train_set[0]       # idx 读取某一个数据
    trainloader = DataLoader(train_set, batch_size=64, shuffle=True)     # 封装成dataloader的形式
    print(len(trainloader))
    for _, data in enumerate(trainloader):
		....

下面提供一些可供参考的博文:
https://www.jianshu.com/p/220357ca3342
https://www.cnblogs.com/devilmaycry812839668/p/10122148.html
https://ptorch.com/news/215.html

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值