DataLoader官方文档: https://pytorch.org/docs/stable/data.html
DataLoader与DataSet是干嘛的
DataLoader是Pytorch用来加载数据的一个类,其实就是一个迭代器,而迭代的数据从哪来?就需要用到DataSet了。
DataSet就是用来封装数据的类,主要用来对数据进行相关的自定义操作(比如图片的裁剪、标签的定义等),通过__getitem__
函数返回所需要的数据。
DataSet类介绍
一般来说,需要重新定义一个新的类来继承DataSet类,然后再通过DataLoader来加载器数据。
继承DataSet类一般需要重写其中的__getitem__
函数,该函数用于返回第index个数据。其中常常也会重写__len__
函数,用于返回整个数据集的大小。
举个栗子:
class MyData(Dataset):
def __init__(self,imag_path):
self.imag_path = imag_path
self.imag_path_list = os.listdir(imag_path)
def __getitem__(self, item):
imag_name = self.imag_path_list[item]
imag_item_path = os.path.join(self.imag_path,imag_name)
imag = Image.open(imag_item_path)
label = imag_name
return imag,label # 返回的第item项的图片以及对应的标签
def __len__(self):
return len(self.imag_path_list)
DataLoader类介绍
DataLoader一般通过torch.utils.data.DataLoader直接调用即可。DataLoader就是对DataSet中的数据进行迭代,通过__getiem__
函数来获取DataSet对应数据集中的第item项数据,然后组合成batch,给程序进行训练。
DataLoader不需要继承,直接拿来用就行,DataLoader类如下:
其参数介绍如下:
这里主要对常用的几个参数进行介绍。
- dataset: 所需要加载数据的数据集
- batch_size:batch的大小。默认为1,None代表禁用批处理
- shuffle:是否随机抽取样本。一般用于训练数据集
- num_workers:为整数,代表多线程加载数据。默认为单线程加载数据。
- drop_last:代表是否删除最后一个不完整batch。
举个栗子:
# 加载数据集
train_dataset = DataLoader(train_folder)
# 初始化DataLoader
train_batch = torch.utils.data.DataLoader(train_dataset, batch_size = 5,
shuffle=True, num_workers=4, drop_last=True)
# 使用DataLoader
for k,(img,label) in enumerate(train_batch):
print(k,img,label)
注意
- 在windows中使用多线程加载数据时,需要加上以下代码:
if __name__ == '__main__':
2020-1024=?,ಥ_ಥ