pytorch数据预处理——4. Pytorch DataLoader类

本专题主要是解决Pytorch框架下项目的数据预处理工作
Table of Contents:
     1. HDF5文件简介
     2. Python中的_, __, __xx__区别
     3. Dataset类
     4. DataLoader类

DataLoader类是 torch.utils.data 库下的一个类,这是不用用户自定义的,直接调用即可,现在来简略谈谈这个类的功能,与一些关键代码。

先看看如何调用这个类的:

    dataset_train_h5 = H5Dataset("./data_train", mode='train')
    # dataset_h5 is what ?
    trainloader = utilsdata.DataLoader(dataset_h5, batch_size=5, shuffle=True)
    # trainloader is what ?
    dataset_val_h5 = H5Dataset("./data_val", mode='val')
    valloader = utilsdata.DataLoader(dataset_val_h5, batch_size=1, shuffle=False)

1. DataLoader类功能

由前面几节内容可知,直接调用 dataset[i] 不就可以返回训练样本了,为什么还要使用 DataLoader类呢?因为 dataset[i] 功能比较单一,或者说功能有限,而 DataLoader类可以将 dataset 装饰成迭代器,并且可以返回一个 batch 的数据。因此就可以用 enumerate 得到一个 batch 的数据data,由(images, labels)组成。


迭代器是访问集合元素的一种方式。迭代器对象从集合的第一个元素开始访问,知道所有的元素被访问完结束。
迭代器有两个基本方法:
__next__方法:返回迭代器的下一个(批次)元素
__iter__方法:返回迭代器对象本身(用来将一个可迭代对象转换为迭代器,“迭代器”指的是 iter 所返回的一个支持 next() 的对象)


2. DataLoader类关键代码

2.1 如何实现返回一个 batch 的数据呢?

在这里插入图片描述
前面说过,DataLoader类可以将 dataset 装饰成迭代器,然后再用 next() 或者 for 遍历数据。
过程:

  1. container = iter(list)
  2. container.next() # for

上面函数可以看出,next() 或 for 遍历数据时就是调用私有函数 def __next__():的。

其实,这些细节你不想深究也行,只要知道接口处传一个实参 batch_size 就返回一个迭代器,并且这个迭代器在遍历的时候一次会返回一个 batch_size 的样本数据

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
PyTorch中,数据预处理通常涉及以下几个步骤: 1. 加载数据集:使用PyTorch数据加载器(如`torchvision.datasets`)加载数据集。可以是常见的图像数据集(如MNIST、CIFAR10)或自定义数据集。 2. 转换数据:使用`torchvision.transforms`模块中的转换函数对数据进行预处理。常见的转换包括缩放、裁剪、旋转、归一化等。可以根据需求组合多个转换操作。 3. 创建数据加载器:将转换后的数据集传递给`torch.utils.data.DataLoader`来创建一个数据加载器。数据加载器可以指定批处理大小、并发加载等参数。 下面是一个简单的示例,演示如何使用PyTorch进行数据预处理: ```python import torch import torchvision import torchvision.transforms as transforms # 1. 加载数据集 train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True) # 2. 转换数据 transform = transforms.Compose([ transforms.ToTensor(), # 将图像转换为Tensor transforms.Normalize((0.5,), (0.5,)) # 归一化到[-1, 1]范围 ]) train_dataset = train_dataset.transform(transform) # 3. 创建数据加载器 train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True) ``` 在这个示例中,我们加载了MNIST数据集,并将图像转换为Tensor,并进行了归一化处理。然后使用`DataLoader`创建了一个批处理大小为64的数据加载器,同时打乱了数据的顺序。 这只是一个简单的例子,根据具体需求,你可能需要进行更复杂的数据预处理操作。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值