Torchvision:对数据进行操作

本文介绍了Torchvision在PyTorch中的作用,包括数据集的读取和数据处理。通过Dataset类和DataLoader类,阐述了如何自定义数据集并进行批量加载。同时提到了Torchvision提供的数据集接口和图像处理工具,用于数据预处理和增强,如ToTensor()等。
摘要由CSDN通过智能技术生成

Torchvision:数据读取,训练开始的第一步

如果将模型看作一辆汽车,那么它的开发过程就可以看作是一套完整的生产流程,环环相扣、缺一不可。这些环节包括数据的读取、网络的设计、优化方法与损失函数的选择以及一些辅助的工具等。未来你将尝试构建自己的豪华汽车,或者站在巨人的肩膀上对前人的作品进行优化。

试想一下,如果你对这些基础环节所使用的方法都不清楚,你还能很好地进行下去吗?所以通过这个模块,我们的目标是先把基础打好。通过这模块的学习,对于 PyTorch 都为我们提供了哪些丰富的 API,你就会了然于胸了。

Torchvision 是一个和 PyTorch 配合使用的 Python 包,包含很多图像处理的工具。我们先从数据处理入手,开始 PyTorch 的学习的第一步。我会先介绍 Torchvision 的常用数据集及其读取方法,在后面的文章里,我再带你了解常用的图像处理方法与Torchvision 其它有趣的功能。

PyTorch 中的数据读取

训练开始的第一步,首先就是数据读取。PyTorch 为我们提供了一种十分方便的数据读取机制,即使用 Dataset 类与 DataLoader 类的组合,来得到数据迭代器。在训练或预测时,数据迭代器能够输出每一批次所需的数据,并且对数据进行相应的预处理与数据增强操作。下面我们分别来看下 Dataset 类与 DataLoader 类。

Dataset 类

PyTorch 中的 Dataset 类是一个抽象类,它可以用来表示数据集。我们通过继承 Dataset类来自定义数据集的格式、大小和其它属性,后面就可以供 DataLoader 类直接使用。

其实这就表示,无论使用自定义的数据集,还是官方为我们封装好的数据集,其本质都是继承了 Dataset 类。而在继承 Dataset 类时,至少需要重写以下几个方法:

__init__():构造函数,可自定义数据读取方法以及进行数据预处理;
__len__():返回数据集大小;
__getitem__():索引数据集中的某一个数据。

光看原理不容易理解,下面我们来编写一个简单的例子,看下如何使用 Dataset 类定义一个Tensor 类型的数据集。

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    # 构造函数
    def __init__(self, data_tensor, target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    # 返回数据集大小
    def __len__(self):
        return self.data_tensor.size(0)

    # 返回索引的数据与标签
    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]

结合代码可以看到,我们定义了一个名字为 MyDataset 的数据集,在构造函数中,传入Tensor 类型的数据与标签;在 __len__ 函数中,直接返回 Tensor 的大小;在__getitem__ 函数中返回索引的数据与标签。

下面,我们来看一下如何调用刚才定义的数据集。首先随机生成一个 10*3 维的数据Tensor,然后生成 10 维的标签 Tensor,与数据 Tensor 相对应。利用这两个 Tensor,生成一个 MyDataset 的对象。查看数据集的大小可以直接用 len() 函数,索引调用数据可以直接使用下标。

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    # 构造函数
    def __init__(self, data_tensor, target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    # 返回数据集大小
    def __len__(self):
        return self.data_tensor.size(0)

    # 返回索引的数据与标签
    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]

# 生成数据
data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,)) # 标签是0或1

# 将数据封装成Dataset
my_dataset = MyDataset(data_tensor, target_tensor)

# 查看数据集大小
print('Dataset size:', len(my_dataset))
'''
输出:
Dataset size: 10
'''
# 使用索引调用数据
print('tensor_data[0]: ', my_dataset[0])

'''
输出:
tensor_data[0]:
(tensor([ 0.4931, -0.0697,
0.4171]), tensor(0))
'''

DataLoader 类

在实际项目中,如果数据量很大,考虑到内存有限、I/O 速度等问题,在训练过程中不可能一次性的将所有数据全部加载到内存中,也不能只用一个进程去加载,所以就需要多进程、迭代加载,而 DataLoader 就是基于这些需要被设计出来的。

DataLoader 是一个迭代器,最基本的使用方法就是传入一个 Dataset 对象,它会根据参数 batch_size 的值生成一个 batch 的数据,节省内存的同时,它还可以实现多进程、数据打乱等处理。

DataLoader 类的调用方式如下:

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    # 构造函数
    def __init__(self, data_tensor, target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    # 返回数据集大小
    def __len__(self):
        return self.data_tensor.size(0)

    # 返回索引的数据与标签
    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]

# 生成数据
data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,)) # 标签是0或1

# 将数据封装成Dataset
my_dataset = MyDataset(data_tensor, target_tensor)

# 查看数据集大小
print('Dataset size:', len(my_dataset))
'''
输出:
Dataset size: 10
'''
# 使用索引调用数据
print('tensor_data[0]: ', my_dataset[0])

'''
输出:
tensor_data[0]:
(tensor([ 0.4931, -0.0697,
0.4171]), tensor(0))
'''

from torch.utils.data import DataLoader
tensor_dataloader = DataLoader(dataset=my_dataset, # 传入的数据集, 必须参数
                               batch_size=2, # 输出的batch大小
                               shuffle=True, # 数据是否打乱
                               num_workers=0) # 进程数, 0表示只有主进程

# 以循环形式输出
for data, target in tensor_dataloader:
    print(data, target)

# 输出一个batch
print('One batch tensor data: ', iter(tensor_dataloader).next())

结合代码,我们梳理一下 DataLoader 中的几个参数,它们分别表示:

dataset:Dataset 类型,输入的数据集,必须参数;

batch_size:int 类型,每个 batch 有多少个样本;

shuffle:bool 类型,在每个 epoch 开始的时候,是否对数据进行重新打乱;

num_workers:int 类型,加载数据的进程数,0 意味着所有的数据都会被加载进主进
程,默认为 0。

什么是 Torchvision

PyTroch 官方为我们提供了一些常用的图片数据集,如果你需要读取这些数据集,那么无需自己实现,只需要利用 Torchvision 就可以搞定。

Torchvision 是一个和 PyTorch 配合使用的 Python 包。它不只提供了一些常用数据集,还提供了几个已经搭建好的经典网络模型,以及集成了一些图像数据处理方面的工具,主要供数据预处理阶段使用。简单地说,Torchvision 库就是常用数据集 + 常见网络模型 +常用图像处理方法。

Torchvision 的安装方式同样非常简单,可以使用 conda 安装,命令如下:

conda install torchvision -c pytorch

或使用 pip 进行安装,命令如下:

pip install torchvision

Torchvision 中默认使用的图像加载器是 PIL,因此为了确保 Torchvision 正常运行,我们还需要安装一个 Python 的第三方图像处理库——Pillow 库。Pillow 提供了广泛的文件格式支持,强大的图像处理能力,主要包括图像储存、图像显示、格式转换以及基本的图像处理操作等。

使用 conda 安装 Pillow 的命令如下:

conda install pillow

使用 pip 安装 Pillow 的命令如下:

pip install pillow

安装好 Torchvision 之后,我们再来接着看看。Torchvision 库为我们读取数据提供了哪些支持。

Torchvision 库中的torchvision.datasets包中提供了丰富的图像数据集的接口。常用的图像数据集,例如 MNIST、COCO 等,这个模块都为我们做了相应的封装。

下表中列出了torchvision.datasets包所有支持的数据集。各个数据集的说明与接口,详见链接:Datasets — Torchvision 0.15 documentation

 

这里我想提醒你注意,torchvision.datasets这个包本身并不包含数据集的文件本身,它的工作方式是先从网络上把数据集下载到用户指定目录,然后再用它的加载器把数据集加载到内存中。最后,把这个加载后的数据集作为对象返回给用户。

Torchvision:数据增强,让数据更加多样性

上面,我们一同迈出了训练开始的第一步——数据读取,初步认识了 Torchvision,学习了如何利用 Torchvision 读取数据。不过仅仅将数据集中的图片读取出来是不够的,在训练的过程中,神经网络模型接收的数据类型是 Tensor,而不是 PIL 对象,因此我们还需要对数据进行预处理操作,比如图像格式的转换。

与此同时,加载后的图像数据可能还需要进行一系列图像变换与增强操作,例如裁切边框、调整图像比例和大小、标准化等,以便模型能够更好地学习到数据的特征。这些操作都可以使用torchvision.transforms工具完成。

今天我们就来学习一下,利用 Torchvision 如何进行数据预处理操作,如何进行图像变换与增强。

图像处理工具之 torchvision.transforms

Torchvision 库中的torchvision.transforms包中提供了常用的图像操作,包括对Tensor 及 PIL Image 对象的操作,例如随机切割、旋转、数据类型转换等等。


按照torchvision.transforms 的功能,大致分为以下几类:数据类型转换、对PIL.Image 和 Tensor 进行变化和变换的组合。下面我们依次来学习这些类别中的操作。

数据类型转换

上面,我们学习了读取数据集中的图片,读取到的数据是 PIL.Image 的对象。而在模型训练阶段,需要传入 Tensor 类型的数据,神经网络才能进行运算。


那么如何将 PIL.Image 或 Numpy.ndarray 格式的数据转化为 Tensor 格式呢?这需要用到transforms.ToTensor() 类。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

repinkply

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值