Pytorch学习记录(3) Dataloader的使用
这篇博客记录了本人Pytorch的过程,主要目的是为了今后的学习以及复习,仅用于个人学习,不为其他目的。
本次学习主要研究了Dataloader的使用。
Dataloader与Datasets的区别:
torch.utils.data.Dataset是代表了这一数据的抽象类,可以通过继承与重写这个抽象类来实现自己的数据类,只需要定义__len__和__getitem__这两个函数即可。
例如如下代码展示示范:
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
class MyDataset(Dataset):
""" my dataset."""
# Initialize your data, download, etc.
def __init__(self):
# 读取csv文件中的数据
xy = np.loadtxt('data-diabetes.csv', delimiter=',', dtype=np.float32)
self.len = xy.shape[0]
# 除去最后一列为数据位,存在x_data中
self.x_data = torch.from_numpy(xy[:, 0:-1])
# 最后一列为标签为,存在y_data中
self.y_data = torch.from_numpy(xy[:, [-1]])
def __getitem__(self, index):
# 根据索引返回数据和对应的标签
return self.x_data[index], self.y_data[index]
def __len__(self):
# 返回文件数据的数目
return self.len
通过覆写__getitem__与__len__两个函数,使得MyDataset类继承了Dataset类,并且可以根据需要获得对应的类参数。
Dataloader是Pytorch中用来处理模型输入数据的一个工具类,组合了数据集(dataset)+采样器(sampler),并在数据集上提供单线程或多线程(num_workers)的可迭代对象。
函数原型如下所示:
torch.utils.data.DataLoader(dataset, batch_size=1,
shuffle=False, sampler=None,
batch_sampler=None, num_workers=0,
collate_fn=None, pin_memory=False,
drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None)
其中常用的几个参数介绍如下:
1)dataset:决定数据从哪里读取
2)batch_size:批大小,决定每批传入多少个样本到模型中,因此batch_size决定了每个epoch中有多少个iteration,即:迭代次数(iteration次数)= 样本总数(epoch传入总数)/ batch_size。
3)shuffle:每一个epoch中是否为乱序
4)num_workers:读取数据时,是否采用多进程进行读取
5)drop_last:当样本数不能被batchsize整除时,最后一批数据是否舍弃
6)pin_memory:如果为True会将数据放置到GPU上去
更多关于Dataset与Dataloader的介绍具体可以参考该博文:https://blog.csdn.net/He3he3he/article/details/105441083
Dataloader的使用:
2.1 创建Dataset和Dataloader:
首先引入torchvision,并读取本地的CIFAR10文件,并创建对应的datasets对象,并转换成Tensor类型:
![](https://img-blog.csdnimg.cn/img_convert/e1ade112212e61c9ef5a5b19421e7bd5.png)
![](https://img-blog.csdnimg.cn/img_convert/28a071b7bc660ff78931b76bee505165.png)
接下来引入Dataloader类,通过torch.utils.data引入,并同样创建一个Dataloader对象,其中将batch_size设置为6,shuffle设置为True,drop_last设置为False:
![](https://img-blog.csdnimg.cn/img_convert/b2f2170321ceb223fb214fd98125ade3.png)
![](https://img-blog.csdnimg.cn/img_convert/f86a5fbf5152543c132fae10de3ae168.png)
可以首先查看Dataset的每一行的格式,由于转化为了tensor类型,因此可以观察到其shape和target为:
![](https://img-blog.csdnimg.cn/img_convert/5c2fd4a58f2905776e93652b9d5c8b77.png)
接下来查看Dataloader的格式,由于batch_size为6,因此所有的图像被分为了每6个一组的形式,通过for循环可以查看每个batch的shape和target:
![](https://img-blog.csdnimg.cn/img_convert/a873d7f4fcc8ef6b4940126ea8c17289.png)
其中最后一个batch由于凑不齐6个图像,并且在Dataloader的参数中drop_last设置为False,该组batch只有4个图像。
接下来通过Tensorboard对于每一组batch的图像进行查看,为了便于观察,将每一组batch大小设置为64,剩余参数不做修改:
![](https://img-blog.csdnimg.cn/img_convert/5b88a641cb557b564438276b6039148c.png)
打开Tensorboard,可以查看除最后一组batch外,每一组batch中都含有64个图像:
![](https://img-blog.csdnimg.cn/img_convert/26973c65792f29456128d64b22052a73.png)
最后一组batch仅有16张图像:
![](https://img-blog.csdnimg.cn/img_convert/dfb86e52aa6e8a5a1d4f55de6bb45b9f.png)
2.2 drop_last参数的使用:
接下来对于drop_last参数进行调整,将其调整至True,即舍弃最后不足以构成64张图的剩下所有图:
![](https://img-blog.csdnimg.cn/img_convert/cca247a50da21fd7f16a3e22f1054bb4.png)
通过Tensorboard进行查看:
![](https://img-blog.csdnimg.cn/img_convert/c902ca29fc1568ae2c982beeda728b00.png)
可以清晰地观察到,相较于上述的test_data,test_data_droplast少了一组,并且其最后一组的图像数量为64,而并非16,因此可以确定修改了参数之后,最后16个图像被舍弃了。
2.3 shuffle参数的使用:
接下来对shuffle参数进行修改,将其调整为False,因此每次在使用该Dataloader进行读取数据时,每个epoch并不会对所有的图像进行打乱,因此每个epoch得到的结果理论上应该是相同的:
![](https://img-blog.csdnimg.cn/img_convert/ff8780e1248185a7dfbc38a15a73afd1.png)
通过Tensorboard进行查看:
![](https://img-blog.csdnimg.cn/img_convert/ab1d80ba7048994993565f5ae796317e.png)
可以观察到epoch0与epoch1对应的每一个batch的图像都是相同的,也印证了我们上述的结果。
3. 总结
在本文中总结了DataLoader的使用方法,并通过读取CIFAR10中的数据,借助Tensorboard的展示各种参数的功能,能为后续神经网络的训练奠定基础,同时也能更好的理解pytorch。