![59d5023cb9725101cc30868437c27469.png](https://i-blog.csdnimg.cn/blog_migrate/13972191beffa86a50a00547bf890586.jpeg)
![dadeaacbc1a19d9b6519ae855e76455a.png](https://i-blog.csdnimg.cn/blog_migrate/8870c6a109c1982b9d65a0d0726f3fc2.jpeg)
在开始学习之前推荐大家可以多在FlyAI竞赛服务平台多参加训练和竞赛,以此来提升自己的能力。FlyAI是为AI开发者提供数据竞赛并支持GPU离线训练的一站式服务平台。每周免费提供项目开源算法样例,支持算法能力变现以及快速的迭代算法模型。
目录
- 1 Dataset基类
- 2 构建Dataset子类
- 2.1 Init
- 2.2 getitem
- 3 dataloader
1 Dataset基类
PyTorch 读取其他的数据,主要是通过 Dataset 类,所以先简单了解一下 Dataset 类。在看很多PyTorch的代码的时候,也会经常看到dataset这个东西的存在。Dataset类作为所有的 datasets 的基类存在,所有的 datasets 都需要继承它。
先看一下源码:
![24aa77fd8cf1e3f0abfb64ef5582be93.png](https://i-blog.csdnimg.cn/blog_migrate/73047a2a1d266194837c2ef68e315083.jpeg)
这里有一个__getitem__
函数,__getitem__
函数接收一个index,然后返回图片数据和标签,这个index通常是指一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。之后会举例子来讲解这个逻辑。
其实说着了些都没用,因为在训练代码里是感觉不到这些操作的,只会看到通过DataLoader就可以获取一个batch的数据,这是触发去读取图片这些操作的是DataLoader里的__iter__(self)
(后面再讲)。
2 构建Dataset子类
下面我们构建一下Dataset的子类,叫他MyDataset类:
import torch
from torch.utils.data import Dataset,DataLoader
class MyDataset(Dataset):
def __init__(self):
self.data = torch.tensor([[1,2,3],[2,3,4],[3,4,5],[4,5,6]])
self.label = torch.LongTensor([1,1,0,0])
def __getitem__(self,index):
return self.data[index],self.label[index]
def __len__(self):
return len(self.data)
2.1 Init
- 初始化中,一般是把数据直接保存在这个类的属性中。像是
self.data,self.label
2.2 getitem
- index是一个索引,这个索引的取值范围是要根据
__len__
这个返回值确定的,在上面的例子中,__len__
的返回值是4,所以这个index会在0,1,2,3这个范围内。
3 dataloader
从上文中,我们知道了MyDataset这个类中的__getitem__
的返回值,应该是某一个样本的数据和标签(如果是测试集的dataset,那么就只返回数据),在梯度下降的过程中,一般是需要将多个数据组成batch,这个需要我们自己来组合吗?不需要的,所以PyTorch中存在DataLoader这个迭代器(这个名词用的准不准确有待考究)。
继续上面的代码,我们接着写代码:
mydataloader = DataLoader(dataset=mydataset,
batch_size=1)
我们现在创建了一个DataLoader的实例,并且把之前实例化的mydataset作为参数输入进去,并且还输入了batch_size这个参数,现在我们使用的batch_size是1.下面来用for循环来遍历这个dataloader:
for i,(data,label) in enumerate(mydataloader):
print(data,label)
输出结果是:
![744777825eaaea63041e7744c377cdec.png](https://i-blog.csdnimg.cn/blog_migrate/5908f7acae67f770e2a58914ec7523bc.png)
意料之中的结果,总共输出了4个batch,每个batch都是只有1个样本(数据+标签),值得注意的是,这个输出过程是顺序的。
我们稍微修改一下上面的DataLoader的参数:
mydataloader = DataLoader(dataset=mydataset,
batch_size=2,
shuffle=True)
for i,(data,label) in enumerate(mydataloader):
print(data,label)
结果是:
![952d199d49b8942036d7fdc8675ddf6a.png](https://i-blog.csdnimg.cn/blog_migrate/a77a57c4cd44ae1c5d50db3e8ed6ea21.png)
可以看到每一个batch内出现了2个样本。假如我们再运行一遍上面的代码,得到:
![45ff969a67c7ca92d0efe9550cd0d1d0.png](https://i-blog.csdnimg.cn/blog_migrate/6fd91bd9e1016064e6c508ee9c663da6.png)
两次结果不同,这是因为shuffle=True
,dataset中的index不再是按照顺序从0到3了,而是乱序,可能是[0,1,2,3],也可能是[2,3,1,0]。
【个人感想】
Dataloader和Dataset两个类是非常方便的,因为这个可以快速的做出来batch数据,修改batch_size和乱序都非常地方便。有下面两个希望注意的地方:
- 一般标签值应该是Long整数的,所以标签的tensor可以用
torch.LongTensor(数据)
或者用.long()
来转化成Long整数的形式。 - 如果要使用PyTorch的GPU训练的话,一般是先判断cuda是否可用,然后把数据标签都用
to()
放到GPU显存上进行GPU加速。
device = 'cuda' if torch.cuda.is_available() else 'cpu'
for i,(data,label) in enumerate(mydataloader):
data = data.to(device)
label = label.to(device)
print(data,label)
看一下输出:
![38b6d4ddf51e2b02be77922bd13dcf22.png](https://i-blog.csdnimg.cn/blog_migrate/e71e2f6402007e075d3afd0e179e483a.png)
![2d94c1455e83021d9eb19cf6841d045e.png](https://i-blog.csdnimg.cn/blog_migrate/246eb7134b11c0aeea0b06bed7a92e1c.jpeg)
更多关于“Pytorch”的竞赛项目,大家可移步官网进行查看和参赛!
更多精彩内容请访问FlyAI-AI竞赛服务平台;为AI开发者提供数据竞赛并支持GPU离线训练的一站式服务平台;每周免费提供项目开源算法样例,支持算法能力变现以及快速的迭代算法模型。
挑战者,都在FlyAI!!!