PyTorch中的 Dataset、DataLoader 和 enumerate()

PyTorch:关于Dataset,DataLoader 和 enumerate()

本博文主要参考了 Pytorch中DataLoader的使用方法详解pytorch:关于enumerate,Dataset和Dataloader 两篇文章进行总结和归纳。

DataLoader 隶属 PyTorch 中 torch.utils.data 下的一个类,任何继承 torch.utils.data.Data 类的子类均需要重载__getitem__()及__len__()两个函数,且子类在__init__()函数产生的数据路径,将作为 DataLoader 参数 DataSets 的实参。该类将自定义的 Dataset 根据 batch size 大小、是否 shuffle 等封装成一个 Batch Size 大小的 Tensor,用于后面的训练。

Dataset 类构建

在构建数据集类时,除了__init__(self),还要有__len__(self)与__getitem__(self,item)两个方法,这三个是必不可少的,至于其它用于数据处理的函数,可以任意定义。这里的 Dateset 可以指整个数据集,也可以是训练集,测试集等。

class Dataset:
    def __init__(self,...):
        ...
    def __len__(self,...):
        return n
    def __getitem__(self,item):
        return data[item]

正常情况下,该数据集是要继承 Pytorch 中 Dataset 类的,但实际操作中,即使不继承,数据集类构建后仍可以用 Dataloader() 加载的。

在dataset类中,len(self)返回数据集中数据的总个数,getitem(self,item)表示每次返回第 item 条(个)数据。
①__init__:传入数据,或者像下面一样直接在函数里加载数据
②__len__:返回这个数据集一共有多少个 item
③__getitem__:返回一条(个)训练样本的数据,并将其转换成 tensor

在 dataset 实例化时一般要传入数据集的路径,一般在__init__() 函数中指定数据集路径等相关信息(可以通过相关路径读取包含图像名称、标签等相关信息的 json 或者 csv 等类型的文件);通过__getitem__(self,item) 得到对应的图像并将进行 transform 转换(缩放、裁剪、转换成 tensor 等操作),最终以 tensor 的形式返回。

DataLoader 使用

在构建 Dataset 类后,即可使用 DataLoader 加载。DataLoader 中常用参数如下:

  1. dataset:需要载入的数据集,如前面构造的 dataset 类。
  2. batch_size:批大小,在神经网络训练时我们很少逐条数据训练,而是几条数据作为一个 batch 进行训练。
  3. shuffle:是否在打乱数据集样本顺序。True 为打乱,False 反之。
  4. num_workers:这个参数决定了有几个进程来处理 data loading。0 意味着所有的数据都会被 load 进主进程。(默认num_workers=0,在 Windows 系统下需要设置为 0
  5. drop_last:是否舍去最后一个batch的数据(很多情况下数据总数 N 与 batch size 不整除,导致最后一个 batch 不为 batch size)。True 为舍去,False 反之。

注意:使用 DataLoader 读取数据时,为了加快效率,所以使用了多个线程,即 num_workers 不为0,在 windows 系统下报如下的错误。
RuntimeError: Couldn’t open shared file mapping: <torch_16716_3565374679>, error code: <1455>

DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 

参照 DataLoader windows平台下 多线程读数据报错 | BrokenPipeError: [Errno 32] Broken pipe | freeze_support() 教程中提到,在 https://github.com/pytorch/pytorch/pull/5585 中给出了一些官方解释,应该是 Windows下的一些线程文件读写的问题。
在 Windows 上,FileMapping 对象应必须在所有相关进程都关闭后,才能释放。启用多线程处理时,子进程将创建 FileMapping,然后主进程将打开它。 之后当子进程将尝试释放它的时候,因为父进程还在引用,所以它的引用计数不为零,无法释放。 但是当前代码没有提供在可能的情况下再次关闭它的机会。这个版本官方说 num_workers=1 是可以用的,更多的线程还在解决,不过现在即便是用 2 个子进程也已经可以了。

加载数据的过程

pytorch 中加载数据的顺序是:

  1. 创建一个 dataset 对象
  2. 创建一个 dataloader 对象
  3. 循环 dataloader 对象,将 data, label 拿到模型中去训练

enumerate() 函数

在对 Dataloader 进行读取时,通常使用 enumerate() 函数,enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标。调用 enumerate(dataloader) 时每次都会读出一个 batch_size 大小的数据。例如,数据集中总共包含 245 张图像,train_loader = dataloader(dataset, batch_size=32, drop_last=True) 被实例化时,经过以下代码后输出的 count 为 224(正好等于32*7),而多出来的 245-224=21 张图像不够一个 batch 因此被 drop 掉了。下面展示了如何从 dataloader 中通过 enumerate() 返回一个batch_size的数据。

for k, images, target in enumerate(dataloader):

其中,k代表下标值,images, target 代表可遍历的数据对象。因为 enumerate(dataloader) 一次会返回一个 batch 的数据,所以返回的 images 为 batch_size 长度的list,target 也为 batch_size 长度的 list。

通常,dataloader 里包含很多个数据对象,那么我们应该怎么保证 batch 就是我们所需要的数据呢?通过 Dataset 的定义可以实现我们需要的数据。Dataset 是用来定义数据从哪里读取,以及如何读取的问题,通过重写 Dataset 抽象类的__getitem__()函数。enumerate(dataloader) 得到的数据就是 __getitem__() 函数返回的数据,只不过 enumerate(dataloader) 一次会得到 batch_size 个不同 item 的数据组成的 list。

def __getitem__(self, item):
	images = self.data[item]
	target = self.label[item]
	return images, target

返回 item 对应的数据,就是 enumerate(dataloader) 得到的数据的一部分。

def __len__(self):
	return len(self.data)

返回 dataset 中总的数据个数,用于控制返回多少个 batch 的数据,enumerate(dataloader) 一次会返回 batch_size 大小的 list。

Reference

Pytorch中DataLoader的使用方法详解
pytorch:关于enumerate,Dataset和Dataloader
DataLoader windows平台下 多线程读数据报错 | BrokenPipeError: [Errno 32] Broken pipe | freeze_support()

好的,我来回答你的问题。 首先,我们需要定义一个数据集类,包括数据读取和预处理。这里我们使用MNIST手写数字数据集作为示例。 ```python import torch from torch.utils.data import Dataset from torchvision import datasets, transforms class MNISTDataset(Dataset): def __init__(self, train=True): self.data = datasets.MNIST('./data', train=train, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])) def __getitem__(self, index): x, y = self.data[index] return x, y def __len__(self): return len(self.data) ``` 接下来,我们需要定义一个数据加载器,用于批量加载数据。 ```python from torch.utils.data import DataLoader batch_size = 16 train_dataset = MNISTDataset(train=True) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) ``` 然后,我们可以搭建一个全连接神经网络,用于对手写数字图片进行分类。 ```python import torch.nn as nn class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.fc1 = nn.Linear(784, 256) self.fc2 = nn.Linear(256, 128) self.fc3 = nn.Linear(128, 10) def forward(self, x): x = x.view(-1, 784) x = nn.functional.relu(self.fc1(x)) x = nn.functional.relu(self.fc2(x)) x = self.fc3(x) return x ``` 最后,我们可以定义损失函数和优化器,并开始训练模型。 ```python net = Net() criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9) for epoch in range(10): running_loss = 0.0 for i, data in enumerate(train_loader, 0): inputs, labels = data optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() if i % 2000 == 1999: print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000)) running_loss = 0.0 ``` 这样,我们就成功地实现了使用PytorchDataloader搭建全连接神经网络进行分类任务的过程,其batch_size=16。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值