torch中的DataLoader主要是用来将给定数据集中的样本打包成一个一个batch的,那么它具体是怎么工作的呢?对于给定的数据集又有什么要求呢?
1.流程讲解,实例1
from torch.utils.data import DataLoader
class show_how_dataloader_work():
def __init__(self,x):
self.x = x
def __len__(self): #必须要有!
return len(self.x)
def __getitem__(self,index): #必须要有!
print('index是{},即dataloader取出了第{}个元素'.format(index,index+1))
return self.x[index]
a = show_how_dataloader_work(['wyb','xz','zql','wx','hjy'])
a_batch = DataLoader(a,batch_size=2,shuffle=True)
#a就是相当于给定的数据集。
list(a_batch)
index是3,即dataloader取出了第4个元素
index是1,即dataloader取出了第2个元素
index是0,即dataloader取出了第1个元素
index是4,即dataloader取出了第5个元素
index是2,即dataloader取出了第3个元素
[['wx', 'xz'], ['wyb', 'hjy'], ['zql']]
具体来说,dataloader是如何工作的呢?
- 首先得到length,length是__len(self)__的返回值,即一共有多少个样本
- 如果shuffle=True,那么就随机从range(length)中取出batch_size个数,依次作为index,放到__getitem__(self,index)中,然后得到batch_size个__getitem__(self,index)的返回值,并打包在一起
- 如果shuffle=False,那么就依次从range(length)中取出batch_size个数,作为index,放到__getitem__(self,index)中,然后得到batch_size个__getitem__(self,index)的返回值,并打包在一起
因此对于给定数据集的最基本要求就是要有__len__(self)函数和__getitem__(self,index)函数
带到上述例子中具体分析:
我们的self.x是['wyb','xz','zql','wx','hjy'],__len__(self)的返回值是len(self.x),依次length=5
shuffle=True,batch_size是2
第一次,取两个,index分别是3,1.那么调用__getitem__(self,index)函数两次,得到返回值self.x[3]即‘wx’和self.x[1]即‘xz’,并将它们打包在一起
第二次,取两个,index分别是0,4.那么调用__getitem__(self,index)函数两次,得到返回值self.x[0]即‘wyb’和self.x[4]即‘hjy’,并将它们打包在一起
PS:**打包方式** 下文继续详解
最后一次,只剩一个了,index是2,那么调用__getitem__(self,index)函数一次,得到返回值self.x[2]即‘zql’
至此,length个返回值已经全部取完
2.流程讲解,实例2
再来举一个例子:
class

本文详细介绍了PyTorch的DataLoader的工作流程,包括如何根据`shuffle`参数进行样本选取,以及在处理numpy数组、tensor及非numpy/tensor数据时的打包方式。通过实例分析了`__len__`和`__getitem__`函数的重要性,强调在处理非numpy/tensor数据时需谨慎使用DataLoader。
最低0.47元/天 解锁文章
394

被折叠的 条评论
为什么被折叠?



