torchvision的使用

torchvision

torchvision是一个和torch配合使用的Python包。提供了几个常用数据集,几种已经搭建好的经典网络模型,以及一些图像数据处理方面的工具(主要是预处理阶段用的)。

可以通过pip或者conda下载。


pip install torchvision

其官方文档地址是https://pytorch.org/docs/stable/torchvision/index.html

数据篇

数据

torchvision.datasets提供了几个常用的经典图像数据集。虽然这么说,但是这个模块中本身不包括数据集的文件。它的工作方式是先从网络上把数据集下载到指定目录,然后再用它的加载器把数据集加载到内存中。然后把这个加载后的数据集作为对象返回给用户。

这里我们主要以MNIST为例简单说明一下(其他几个数据集类似,只不过调用的时候把“MNIST”换成其他名字)。

torchvision.datasets.MNIST是一个类。我们直接调用它的构造函数就能返回一个MNIST数据集对象。有一个必填参数root和几个选填参数。

root的值是一个字符串。用于指定数据集保存的路径。注意只需要制定到文件夹一级就可以了,不需要指定具体的数据集文件如何命名。

download表示是否下载数据。默认是False。如果指定为True,那么在指定的路径中没有数据文件时,就会自动启动下载进程,从网上下载数据。如果指定的路径中已经存在指定的数据集文件。那么就不会下载,而是直接使用已经存在的数据集。因此就算这项一直被设定为True也不会在每次启动的时候都把数据集重新下载一次。

为了节约时间,你也可以把下载好的文件拷贝一份。以后创建新项目的时候直接把这些数据文件复制到指定路径。这样就不需要每创建一个新项目就下载一次。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-B1EBkdSW-1662021310081)(torchvision.assets/1542962295195.png)]

图1:下载数据集

为了让加载的数据集具有统一的格式,下载好的数据集往往还要进行分隔或者格式重整。这个过程会在指定的路径下产生一些临时文件。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-CIzJ1ZyL-1662021310083)(torchvision.assets/1542962719582.png)]

图2:下载好的数据集和一些临时文件

你可能注意到图1中我还使用了train参数。这和MNIST数据集的组织方法有关系。这个数据集在提供给使用者的时候就已经分成了训练集和测试集两部分。如果trainTrue,那么就加载训练集,否则加载测试集。其他的数据集不一定在发布时就做了这样的分割,所以加载其他数据集时train不一定是个有效参数。具体要参考官方文档。

介绍了这些参数,我们自然而然就能得到一个最简单的加载数据集的语句(这里加载的是MNIST的测试集):


testset = torchvision.datasets.MNIST(root='./data', train=False, download=True)

那么我们得到的是一个什么东西呢?用type查看一下testset的类型,可以得到torchvision.datasets.mnist.MNIST。而这个类以及其他datasets中的类都是torch.utils.data.Dataset的子类。

要想查看其中的内容,我们不妨把testset转换成列表:


L1=list(testset)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ctFk6AUO-1662021310084)(torchvision.assets/1542963672580.png)]

图3:转换成列表的数据集对象

从图3中可以看出。转换后数据集对象变成了一个元组列表。每个元组有两个成员,第一个成员是图像,第二个成员是标签。

MNIST数据集是手写数字数据集。包含一些裁剪好的单个手写数字的图片,和这些图片对应的数字标签。

这里图像成员是以PIL.Image.Image的形式存在的。这种对象在Jupyter中是可以直接显示出来的。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-tohM5i6s-1662021310085)(torchvision.assets/1542963912094.png)]

图4:显示一条数据

除了几个能自动下载的相对较小的数据集。torchvision中还提供了一个类torchvision.datasets.DatasetFolder用于加载其他数据集。只要数据集按照它约定的方式保存,它就能像加载MNIST那样加载它们。具体用法见官方手册。

关于图像加载

torchvision默认使用的图像加载器是PIL。也可以通过torchvision.set_image_backend设置为其他加载器。

torchvision.get_image_backend可以获取当前使用的加载器名称。

PIL是一个Python的图像处理库。

特别注意,加载器是在加载数据集的时候torchvision自己调用的。不需要用户写一些额外的代码来调用。

图像转换工具

之前我们直接加载出来的数据集显然还不能用于神经网络的训练。因为图像都是PIL中的对象,而不是torch的Tensor。这时候我们就需要torchvision.transforms中的转换器。

另外,我们加载图像后,往往要做一些简单的变换。比如裁切边框,调整图像比例和大小,调整图片亮度,灰度重整,正则化等等。这些操作也可以用torchvision.transforms中的工具完成。

这些工具有很多使用方式,这里我仅仅提一下函数式的用法。

transforms.ToTensor是一个函数生成器。它生成的函数能把PIL图像转换成torch的Tensor。

transforms.Normalize也是一个函数生成器。它接受meanstd两个参数。这里会按照这两个参数来正则化数据。(按照正则化的公式,这里的mean应当是数据平均值,std应当是标准差。正则化结果就是每一个数据点减去平均值的差再除以标准差。也就是 x − m e a n s t d \frac{x-mean} {std} stdxmean。至于为什么要这么正则化是数学上讨论的,今天就不讨论这么多了。)

注意,这里meanstd都是元组。因为要对应于图像的多个通道。例如我要把RGB图像按每个通道均值均为 0.5 0.5 0.5,偏差均为 0.2 0.2 0.2的形式整理。就要用transforms.Normalize((0.5,0.5,0.5),(0.2,0.2,0.2))。当然也可能出现三个通道均值或者标准差不一样的情况。比如三个通道的均值依次是(0.2,0.5,0.1)。这时候第一个通道按均值 0.2 0.2 0.2算,第二个通道按 0.5 0.5 0.5算,第三个按 0.1 0.1 0.1算。

而对于灰度图像,由于只有一个通道所以用transforms.Normalize((0.5,),(0.2,))就可以了。注意0.5,中那个逗号不能省略。如果省略了,传入参数就变成了浮点数。而要求必须是元组(哪怕元组中只有一个成员)。

不过经过我的测试,你拿transforms.Normalize((0.5,0.5,0.5),(0.2,0.2,0.2))去处理MNIST也是没有问题的(MNIST是灰度图)。只不过,转换的时候只会用meanstd的第一个成员。

另外,按道理来说正则化时用的meanstd应当是根据数据集直接求平均值和标准差算出来。我为了让后边的例程看起来简单点,直接以常数方式规定了meanstd

这里简单列举了两类转换器。官网手册上能查到更多。如果官网上列举的现成的转换器不够用,你还可以自己定制。比如通过torchvision.transforms.Lambda来订制自己的转换器。这个函数的作用类似于线性表处理中的map函数。

现在我们把转换用的函数都准备好了。但是要完成一张图的转换需要依次调用所有的转换函数非常麻烦。接下来我们就用transforms.Compose把它们打包成一个函数,方便使用。transforms.Compose接受一个函数的列表作为参数,把列表中的函数按顺序打包成一个函数然后返回这个函数。所以我们需要写:


myTransform = transforms.Compose(

    [transforms.ToTensor(),

     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

这样调用myTransform就相当于顺序调用了transforms.ToTensor()transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

实际上我们也不需要亲自调用myTransform,直接把它当作参数,在加载数据的时候传递进去就好了(对应的形参名应该是transform)。所以其实我们定义好转换器后这样加载数据,就能在加载的同时做好转换工作:


testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=myTransform)

实际上加载数据这件事是一个惰性的程序。只要你不真正的读数据,数据就不会被加载,也不会被转换。假如你写错了一个转换函数(比如转换器和数据的类型不匹配)。在调用MNIST返回一个数据集对象的时候不一定会报错。而当你把这个数据集转化成List或者输入神经网络求结果的时候就会报错。

图像预览

在用Jupyter做神经网络的实验的时候。我们经常想看看一个图像被处理成什么样子了。直接用一个代码单元返回一个PIL图像对象的做法是一种办法,但是当要显示的图像比较多的时候还是挺麻烦的。而且这种方法离开Jupyter就用不了。

一般来说我们使用matplotlib中的绘图函数来显示图像。但是matplotlib的绘图函数一次只接收一个numpy矩阵。如果要排版的话会比较麻烦。我们可以调用torchvision.utils中的工具来快速解决图像预览的问题。

这里我们假设testset是刚刚那个带转换器的MNIST构造函数返回的对象。为了后边方便使用,我们先把它转换成列表:


L1 = list(testset)

print(L1[0])

print(L1[0][0].size())

L1中每个成员都是一个元组。元组中第一个成员是一个torch的3阶张量,表示一张图像。第二个成员是一个torch的0阶张量(其实就是一个标量,但是torch中所有变量都是张量),表示一个数字标签。

之所以要用3阶张量表示图像是因为图像除了高和宽,还有一个色彩纬度。RGB彩图中用三个通道表示三原色。有的格式中会使用更多的通道,比如加一个通道表示黄色,或者加一个α通道表示透明度。灰度图色彩维度上只有一个取值——也就是只有一个灰度通道。

其实灰度图完全可以压缩成2阶张量来存储。但是为了方便和其他类型的图进行转换,所有灰度图也用3阶张量(哪怕在色彩维度上只有一个可选的索引值)。

我们选择其中的几个数据,然后把图像拿出来(不要标签):


temp = list(map(lambda x: x[0],L1[:5])) # 现在的temp是一个成员为3阶张量的List。

imgs = torch.stack(temp,0) # 之后的操作需要用张量,而不能用List。所以这里整理为一个4阶张量。

然后做一些额外的工作,比如解正则化。之前我们正则化了图片。但是正则化的结果不一定适合输出显示。所以我们可以把除掉的东西乘回来,把减去的东西加回来。之前我们正则化的时候是 x − 0.5 0.5 \frac{x-0.5}{0.5} 0.5x0.5。现在就要反过来算:


imgs=imgs*0.5+0.5

一会要用matplotlib进行显示。而matplotlib希望灰度值(或者色彩值)在[0,1]区间内上。我们这里解正则化就是为了把灰度值还原到这个区间上。不然显示效果会有问题。

使用torchvision.utils.make_grid把一个图片组(4阶张量)转化为一张联结起来的大图片(3阶张量)。(注意,这个函数在转化的时候,即使输入灰度图也会自动输出三通道彩图。但是由于三个通道的值一致,所以显示效果还是黑白的。)


bigImg = torchvision.utils.make_grid(imgs)

print(bigImg.size())

接下来我们就调用matplotlib显示这张大图。


import matplotlib.pyplot as plt

import numpy as np

plt.imshow(np.transpose(bigImg.numpy(), (1, 2, 0)))

这里需要说明几点:

  • matplotlib接收numpy矩阵但是不接受torch张量,所以要先转换成numpy矩阵。

  • 我们的bigImg中第一个维度就是颜色通道。后两个维度是高和宽(“颜色通道,高,宽”)。相当于认为一张彩图是三张灰度图(分别对应三原色强度)合成的。而matplotlib中认为一张图片是由RGB三元组组成的矩阵,所以维度排序是“高,宽,颜色通道”。因此我们调用了np.transpose来调整维度的顺序。

函数torchvision.utils.save_imagetorchvision.utils.make_grid作用类似。只不过它不是让图片用来显示,而是把打包好的大图直接保存。

模型篇

torchvision中除了常用数据集。还有一些经典模型的类。直接实例化出来就可以进行训练或者使用。

首先来说说不带训练参数的。如果我们直接把一个模型的类实例化,就会得到一个网络模型。这个网络模型和你自己定义class搭建,然后实例化的网络模型是类似的。只不过它们的结构是按照经典模型的设计弄出来的。你自己按照经典模型的论文写一个类实例化一下,和从torchvision中直接实例化一个效果是一样的。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-aGZa5bWT-1662021310085)(torchvision.assets/1542976763951.png)]

图5:实例化一个VGG16网络

有了这个网络之后,接下来按照torch的一般用法训练就行了。

如果在实例化网络的时候把参数pretrained设为True。那么这个网络加载好之后就是经过预训练的。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-S3EUBNsa-1662021310086)(torchvision.assets/1542977095634.png)]

图6:下载中的网络参数

如果之前没有加载过带预训练参数的网络。需要在加载网络前下载(和下载数据集一样是自动下载的)。下载一次后不需要反复下载(只不过网络参数默认保存在torch的文件路径中,而不是由用户指定路径)。

你也可以自己从网站上下载好,然后复制到指定路径中。

那么带预训练参数的网络有什么用呢?可以用来做“fine-tune”。也就是微调网络。简单来说就是先有人在一个比较普遍,宽泛的数据集上进行了大量训练得出了一套参数。然后你把这套参数(连带网络模型)抄过来,再在你要处理的问题的数据集上训练。经过预训练的网络参数拿来训练,要比完全随机化一些参数然后训练更好。好在哪里呢?训练更块,网络更容易收敛,小一点的训练集也能取得比较好效果。

那么为啥用了预训练参数就能有这些好处呢?因为我们相信两个处理类似问题的网络中,网络的参数也具有某种相似性。因此我们相信,把一个已经训练得很好的网络中的参数移植到另一个网络上,就能起到八九不离十的作用。即使两个网络的工作不完全一样,在这套预训练参数的基础上,再经过微调(fine-tune)性质的训练(而不需要大动干戈地从头训练),也能得到不错的效果。更详细的解释参看这篇知乎答案:

迁移学习与fine-tuning有什么区别? - 蒋竺波的回答 - 知乎

预训练网络也是训练过的。训练就要数据集。而torchvision中自动下载的数据集都是在ImageNet数据集上训练的。

预训练参数的加载方式就是这样。加载了带预训练参数的网络之后,还要根据你的具体问题调整输出层输出数量,或者调整输入图像的大小。但是这些操作就不再这里详述了,按照一般的修改网络,修改图像大小的流程走就行了。 作者:知则 https://www.bilibili.com/read/cv1723691/ 出处:bilibili

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值