首先聊一聊个人对于Pytorch为什么使用dataloder这一机制的理解:
在没有用pytorch之前,我读取数据一般时写一个load_data的函数,在里面导入数据,做一些数据预处理,这一部分就显得很烦索。对于深度学习来说,还得考虑batch的读取、GPU的使用、数据增强、数据乱序读取等等,所以需要有一个模块来集中解决这些事情,所以就有了data_loader的机制
本篇文章主要解决以下三个问题:
-
如何最快地加载torch官方已有的数据集并生成一个dataloader ?
-
如何加载本地的数据集并生成一个dataloader?
-
如何添加一些你想要的预处理的操作?
1.直接加载torch官方的数据集
分三步:
- 生成实例化对象
- 生成dataloader
- 从dataloader里读数据
以CIFAR-10数据集为例,演示如何读取这个数据集
import torchvision
import torch
def load_data():
# 从 torch自带的CIFAR10类 中 生成一个实例化对象trainset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
'''
参数:
trainset:上一步生成的数据集对象
batch_size:决定一次取多少份的数据
shuffle: True则乱序取数据,false则固定顺序取数据
'''
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
return trainloader
def train():
trainloader=load_data()
for epoch in range(max_epoch):
for i, data in enumerate(trainloader, 0):
inputs, labels = data
#这样就能快乐地对data进行操作了
这里具体说说DataLoader里的shuffle这个参数
前面说到shuffle为True则乱序取数据,false则固定顺序读取,且每次读取的起点都相同
我把shuffle设为false,观察每个epoch的前两个batch的图片,发现每个epoch都一样;但是为True的话就不一样
那么它们有什么区别呢?
- shuffle设为true的优点:随机梯度下降,一个batch一个batch地学习,假设前面的几个batch影响更大,那么处于数据集前端的数据先验的有更强的影响,随机地打乱顺序就是避免这种数据偏向性
- shuffle为false 的优点:在特殊的场景下,假设你要统计所有epoch每个样本的loss变化, 那么你希望每个epoch样本的顺序是固定的,这样比较好编程序
2. 制作自己的数据类
我们首先来分析一下第一部分内容的内部机制:
- 首先通过torchvision.datasets.CIFAR10得到数据,里面含有图片序列imgs,和类别标签序列labels。
- 在使用 “for i, data in enumerate(trainloader, 0)”时,其实在做的事情是根据索引获取某一个batch的数据,而这与一个函数相关,就是_getitem_, 返回imgs[i] , 和 labels[i].所以上面程序中的data就含有img和label
但是,你会发现它的使用范围非常有限,当你想使用torch没有的数据集或者你想添加一些其他的功能时,你就要学会编写自己的数据类。
制作自己的数据类分为四个步骤:
- 继承 torch.utils.data.Dataset,含有一些基本方法和属性
- 编写__init__方法:添加一些功能,比如数据切片、数据增强,或者给标签加噪声等等。
- 重载__getitem__方法:根据Index返回数据的相应内容,下面代码会给详细解析
- 重载__len__方法:返回长度
代码示例:
import torch.utils.data.Dataset as DataSet
#写一个自己的数据类
def MyDataSet(DataSet): #继承DataSet类
def __init__(self,参数...):
#以读取本地数据+数据切片为例
data=np.loadtxt("traindata0.csv",delimiter=',',dtype=np.float32)
self.len=data.shape[0]
self.x_data=torch.from_numpy(data[:,0:-1])
self.y_data=torch.from_numpy(data[:,[-1]])
def __getitem__(self,index):
return self.x_data[index],self.y_data[index]
#注意:想返回几个参数自己定,比如我可以return data, self.x_data[index],self.y_data[index]供三个参数,如果你含有颜色,曲率、对应关系等数据,都可以在这里返回
def __len__(self):
return self.len
#接下来同上一段代码
#生成实例化对象
#生成dataloader
#从dataloader里读数据
其实,关于__getitem__方法,可以在里面进行数据预处理操作,这时要用到transform模块。
下面给出一个示例
import torchvision.transforms as transforms
transform=transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()]) #将多个操作联合起来,关于这些操作的解释,可以参考https://blog.csdn.net/u013925378/article/details/103363232
#向MyDataSet传入参数transform
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,transform,
shuffle=True, num_workers=2)
def MyDataSet(DataSet):
self.transform=transform
......
def __getitem__(self,index):
img = self.transform(img)
......