【Pytorch基础知识】DataLoader与DataSet基本使用方法

DataLoader官方文档: https://pytorch.org/docs/stable/data.html

DataLoader与DataSet是干嘛的

DataLoader是Pytorch用来加载数据的一个类,其实就是一个迭代器,而迭代的数据从哪来?就需要用到DataSet了。
DataSet就是用来封装数据的类,主要用来对数据进行相关的自定义操作(比如图片的裁剪、标签的定义等),通过__getitem__函数返回所需要的数据。

DataSet类介绍

一般来说,需要重新定义一个新的类来继承DataSet类,然后再通过DataLoader来加载器数据。
继承DataSet类一般需要重写其中的__getitem__函数,该函数用于返回第index个数据。其中常常也会重写__len__函数,用于返回整个数据集的大小。
举个栗子:

class MyData(Dataset):
    def __init__(self,imag_path):
        self.imag_path = imag_path
        self.imag_path_list = os.listdir(imag_path)

    def __getitem__(self, item):
        imag_name = self.imag_path_list[item]
        imag_item_path = os.path.join(self.imag_path,imag_name)
        imag = Image.open(imag_item_path)
        label = imag_name
        return imag,label   # 返回的第item项的图片以及对应的标签

    def __len__(self):
        return len(self.imag_path_list)

DataLoader类介绍

DataLoader一般通过torch.utils.data.DataLoader直接调用即可。DataLoader就是对DataSet中的数据进行迭代,通过__getiem__函数来获取DataSet对应数据集中的第item项数据,然后组合成batch,给程序进行训练。

DataLoader不需要继承,直接拿来用就行,DataLoader类如下:
在这里插入图片描述
其参数介绍如下:
在这里插入图片描述
这里主要对常用的几个参数进行介绍。

  • dataset: 所需要加载数据的数据集
  • batch_size:batch的大小。默认为1,None代表禁用批处理
  • shuffle:是否随机抽取样本。一般用于训练数据集
  • num_workers:为整数,代表多线程加载数据。默认为单线程加载数据。
  • drop_last:代表是否删除最后一个不完整batch。

举个栗子:

# 加载数据集
train_dataset = DataLoader(train_folder)

# 初始化DataLoader
train_batch = torch.utils.data.DataLoader(train_dataset, batch_size = 5,
                                  shuffle=True, num_workers=4, drop_last=True)

# 使用DataLoader
for k,(img,label) in enumerate(train_batch):
    print(k,img,label)

注意

  • 在windows中使用多线程加载数据时,需要加上以下代码:
    if __name__ == '__main__':

2020-1024=?,ಥ_ಥ

  • 4
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值