pytorch:数据读取操作

工欲善其事,必先利其器。事实上,一遍深度学习从头到脚都需要数据的支持。因此,数据集的读取是第一步。而在Pytorch中,官方给我们封装好了一个提取训练集、测试集的一个虚类。所谓虚类,即需要我们继承。在下面,我们将着重介绍如何通过继承Dataset虚类来完成对数据的读取。

Dataset虚类

训练集有了,苦于不知如何将其转化为代码?先问自己:这些数据集哪里来的?

通过torchvision官方自带的dataset

如果是想通过torchvision的接口(以CIFAR10为例):

import torchvision

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=None)
  1. root:表示cifar10数据的加载的相对目录
  2. train:true时加载数据库的训练集,false时加载测试集
  3. download:表示是否自动下载cifar数据集
  4. transform:表示是否需要对数据进行预处理,none为不进行预处理

由于美帝路途遥远,靠命令台进程下载100多M的数据速度很慢,所以我们可以自己去到cifar10的官网上把CIFAR-10 python version下载下来,然后解压为cifar-10-batches-py文件夹,并复制到相对目录./data下。(若设置download=True,则程序会自动从网上下载cifar10数据到相对目录./data下, 这样小伙伴们可能要等一个世纪了),并对训练集进行加载(train=True)[1]。

利用ImageFolder

还有一个比较讨巧的方法,就是使用torchvision.datasets.ImageFolder 。这个ImageFolder可以自动地读取自己路径之下的数据集,它是一个通用的数据加载器,它要求我们以下面这种格式来组织数据集的训练、验证或者测试图片:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

对于上面的root,假设data文件夹在.py文件的同级目录中,那么root一般都是如下这种形式:./data/train./data/valid[2]。
在这里插入图片描述

参数如下:

dataset=torchvision.datasets.ImageFolder(
                       root, transform=None, 
                       target_transform=None, 
                       loader=<function default_loader>, 
                       is_valid_file=None)

参数列表:

  • root:图片存储的根目录,即各类别文件夹所在目录的上一级目录。
  • transform:对图片进行预处理的操作(函数),原始图片作为输入,返回一个转换后的图片。
  • target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…
  • loader:表示数据集加载方式,通常默认加载方式即可。
  • is_valid_file:获取图像文件的路径并检查该文件是否为有效文件的函数(用于检查损坏文件)

返回的dataset都有以下三种属性:

  • self.classes:用一个 list 保存类别名称
  • self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
  • self.imgs:保存(img-path, class) tuple的 list

自己写一个dataset

这样子下来,代码中的trainset 已经是一个Dataset虚类的实例了,已经无须再实现。但如若数据是在自己本地的,便需要自己写一个接口。重写一个MyDataset类:

class DealDataset(Dataset):

		def __init__(self):
        # 你的操作
				# 比如用文件流、读取数据集等。需要给你的数据集一个属性
				
    
    def __getitem__(self, index):
				# 

    def __len__(self):
        # 
				return self.len

这里需要说明Dataset的__getitem__原理:更高层的方法通过__getitem__ 魔法方法来读取一个元素的。因此,其参数需要有一个index,好索引其对应的元素。并且,__getitem__ 的返回可以是标签、也可以是特征、抑或是数据本身。

Dataloader

在使用 PyTorch 训练模型的过程中,需要将原始的数据转换为张量格式,为了方便的进行数据的批量处理,PyTorch 定义了一系列工具来对这个过程进行包装。PyTorch 的数据载入一般使用的是 torch.utils.data.DataLoader 类[3]。

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

主要参数

  • dataset

就是 torch.utils.data.Dataset 类的实例。也就是说为了使用 DataLoader 类,需要先定义一个 torch.utils.data.Dataset 类的实例(即上面我们所定义的)。

  • batch_size

每一个批次需要加载的训练样本个数。

  • shuffle

如果设置为 True 表示训练样本数据会被随机打乱,默认值为 False。一般会设置为 True 。

  • sampler

自定义从数据集中取样本的策略,如果指定这个参数,那么 shuffle 必须为 False 。如果指定了该参数,同时 shuffle 设定为 True, DataLoader 的 __init__ 函数就会抛出一个异常 。

  • batch_sampler

与 sampler 类似,但是一次只返回一个 batch 的 indices(索引),需要注意的是,一旦指定了这个参数,那么 batch_size, shuffle, sampler, drop_last 就不能再指定了。源码中同样做了限制。

  • num_workers

表示会使用多少个线程来加载训练数据;默认值为 0,表示数据加载直接在主线程中进行。

  • collate_fn

对每一个 batch 的数据做一些你想要的操作,比如 padding 之类的。

  • pin_memory

把数据转移到和 GPU 相关联的 CPU 内存,加速 GPU 载入数据的速度。

  • drop_last

这个是对最后的少于 batch_size 的数据来说的。比如你的batch_size设置为 32,而一个 epoch 只有 100 个样本;如果设置为 True,那么训练的时候后面的 4 个就被扔掉了。如果为 False(默认),那么会继续正常执行,只是最后的 batch_size 会小一点。

  • timeout

加载一个 batch 数据的超时时间。

  • worker_init_fn

指定每个数据加载线程的入口函数。

Dataloader 作为迭代器,最基本的使用就是传入一个 Dataset 对象,通过 Dataset 类里面的 __getitem__ 函数获取单个的数据,根据参数 batch_size 的值组合成一个 batch 的数据;然后传递给 collate_fn 所指定的函数对这个 batch 做一些操作,比如 padding 之类的。

使用

如果是需要输出一份数据看看,则使用enumerate 来遍历。

for i, data in enumerate(dataLoader):

参考文献

[1] https://www.jianshu.com/p/8da9b24b2fb6

[2] https://blog.csdn.net/qq_39507748/article/details/105394808

[3] https://zhuanlan.zhihu.com/p/339675188

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
PyTorch提供了许多方便的工具和类来进行数据读取和预处理。下面是一个基本的数据读取和预处理流程的示例: 1. 导入必要的库: ```python import torch from torchvision import transforms from torch.utils.data import DataLoader ``` 2. 定义数据集类: ```python class CustomDataset(torch.utils.data.Dataset): def __init__(self, data, targets, transform=None): self.data = data self.targets = targets self.transform = transform def __getitem__(self, index): x = self.data[index] y = self.targets[index] if self.transform: x = self.transform(x) return x, y def __len__(self): return len(self.data) ``` 在上面的代码中,`CustomDataset` 是一个自定义的数据集类,其中 `data` 是输入数据,`targets` 是对应的标签。`transform` 是一个可选的数据预处理函数。 3. 数据预处理: 可以使用 `torchvision.transforms` 中的函数来对数据进行常见的预处理操作,例如缩放、裁剪、标准化等。下面是一个示例: ```python transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5]) ]) ``` 在上面的代码中,我们使用 `transforms.Compose` 将多个预处理操作连接在一起。示例中使用了 `ToTensor` 将数据转换为张量,然后使用 `Normalize` 进行标准化。 4. 创建数据集实例: 使用定义的数据集类和预处理操作,创建数据集实例: ```python dataset = CustomDataset(data, targets, transform=transform) ``` 其中 `data` 和 `targets` 是输入数据和标签,`transform` 是之前定义的数据预处理操作。 5. 创建数据加载器: 使用 `torch.utils.data.DataLoader` 创建数据加载器,可以指定批次大小、是否打乱数据等参数。 ```python dataloader = DataLoader(dataset, batch_size=32, shuffle=True) ``` 6. 迭代数据集: 现在可以使用数据加载器来迭代数据集,并进行模型训练或评估。 ```python for batch_data, batch_targets in dataloader: # 在这里执行模型训练或评估的操作 pass ``` 以上是一个基本的PyTorch数据读取和预处理的示例。你可以根据实际需求进行调整和扩展。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值