PyTorch学习笔记(17)--torchvision.transforms用法介绍

PyTorch学习笔记(17)–torchvision.transforms用法介绍

    本博文是PyTorch的学习笔记,第17次内容记录,主要记录了torchvision.transforms的使用方法。

1.问题来源

    在读ResNet的应用代码时,遇到下面这一小段代码,这段代码出现在读取图片信息之前,这段代码的具体功能是什么呢?对于初学者来说很有必要弄清楚这段代码的具体含义

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

2.torchvision.transforms具体用法

    PyTorch框架中有一个非常重要且好用的包:torchvision,该包主要由3个子包组成,分别是:torchvision.datasets、torchvision.models、torchvision.transforms。而上面这段代码就用到了torchvision.transforms这个包。

    这里用到的 torchvision 工具库是 pytorch 框架下常用的图像处理包,可以用来生成图片和视频数据集(torchvision.datasets),做一些图像预处理(torchvision.transforms),导入预训练模型(torchvision.models),以及生成图和保存图像(torchvision.utils)。
    其中,transforms函数对图像做预处理可以是:归一化(normalize)尺寸剪裁(resize)翻转(flip) 等。
    上面的这些步骤实际操作起来往往是一系列的,此时可以用compose将这些图像预处理操作连起来。
    如上面的代码,这里做的操作是:
    transforms.ToTensor() ,将一个PIL图像转换为tensor。即, ( H ∗ W ∗ C ) (H\ast W\ast C) (HWC)范围在[0,255]的PIL图像 转换为 ( C ∗ H ∗ W ) (C\ast H\ast W) (CHW)范围在[0,1]的torch.tensor。
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ,用均值[0.485, 0.456, 0.406]和标准差[0.229, 0.224, 0.225]对图像做归一化处理。

3.torchvision.transforms其他的用法

    transforms函数另外的功能还包括:

    Resize:把给定的图片resize到给定的尺寸。

    ToPILImage: 将torch.tensor 转换为PIL图像。

    CenterCrop:以输入图的中心点为中心做指定size的裁剪操作。

    RandomCrop:以输入图的随机位置为中心做指定size的裁剪操作。

    RandomHorizontalFlip:以0.5概率水平翻转给定的PIL图像。

    RandomVerticalFlip:以0.5概率竖直翻转给定的PIL图像。

    RandomResizedCrop:将给定图像随机裁剪为不同的大小和宽高比,然后缩放所裁剪得到的图像为制定的大小(有一个参数n)。

    Grayscale:将给定图像转换为灰度图像。

    RandomGrayscale:将图像以指定的概率转换为灰度图像。

    FiveCrop: 从一张输入图像中裁剪出5张指定size的图像,包括4个角的图像和一个中心。

    TenCrop:剪出10张指定size的图像。做法是在FiveCrop的基础上,再将输入图像进行水平或竖直翻转,然后进行FiveCrop操作,这样一张图像可得到10张crop图像。

    Pad:对给定图像的所有边用的“padding”个像素用“fill”值填充。

    ColorJitter:修改图像的亮度,对比度,饱和度和色度。

    Lambda:做其参数指定的变换。

    上述四个包及其具体函数的详细介绍参考Pytorch的中文文档

    代码实现可以参考github的代码实现

4.补充torchvision模块的其他功能

    torchvision 是独立于 PyTorch 的关于图像操作的一个工具库,目前包括六个模块:

    1)torchvision.datasets:几个常用视觉数据集,可以下载和加载,以及如何编写自己的 Dataset。

     2)torchvision.models:经典模型,例如 AlexNet、VGG、ResNet 等,以及训练好的参数。

     3)torchvision.transforms:常用的图像操作,例随机切割、旋转、数据类型转换、tensor 与 numpy 和 PIL Image 的互换等。

     4)torchvision.ops:提供 CV 中常用的一些操作,比如 NMS、ROI_Align、ROI_Pool 等。

     5)torchvision.io:提供输入输出的一些操作,目前针对的是视频的写入写出。

     6)torchvision.utils:其他工具,比如产生一个图像网格等。

5.运行错误解决

    问题1:数据集为彩色图像,通道数为3,但是模型中输入通道数为1,也就是接收灰色图像,这时,训练模型时会报错,具体错误为:

RuntimeError: Given groups=1, weight of size 32 3 3 3, expected input[1, 4, 416, 416] to have 3 channels

    解决输入通道数的问题,也就是要将3通道的彩色图像修改成1通道的灰色图像,这时的修改方式为:

修改前:
train_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=True,
                                          transform=torchvision.transforms.torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=False,
                                         transform=torchvision.transforms.ToTensor(),
                                         download=True)
修改后:
train_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=True,
                                          transform=torchvision.transforms.Compose([
                                              torchvision.transforms.Grayscale(),
                                              torchvision.transforms.ToTensor()]),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=False,
                                         transform=torchvision.transforms.Compose([
                                             torchvision.transforms.Grayscale(),
                                             torchvision.transforms.ToTensor()]),
                                         download=True)

    也就是增加一个torchvision.transforms.Grayscale()的操作。
    问题1:

  • 3
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
torchvision.transforms.v2是一个Python库,它提供了一系列的数据预处理操作,可以用于对图像数据进行处理和转换。其中一些常见的预处理操作包括: 1. transforms.CenterCrop(size):将给定的图像进行中心切割,得到给定的size大小的图像。size可以是一个tuple,表示目标图像的高度和宽度;也可以是一个整数,表示切出来的图像是正方形。 2. transforms.RandomCrop(size, padding=0):随机选取图像的中心点位置进行切割。size可以是一个tuple,也可以是一个整数。 3. transforms.RandomHorizontalFlip():随机水平翻转给定的图像,概率为0.5,即有50%的概率进行翻转。 4. transforms.RandomSizedCrop(size, interpolation=2):先随机切割图像(尺寸不定),然后再将切割后的图像resize成给定的size大小。 5. transforms.Pad(padding, fill=0):将给定的图像的所有边用给定的填充值进行填充。padding表示要填充多少像素,fill表示用什么值进行填充。 6. transforms.Normalize(mean, std):使用给定的均值和标准差进行归一化操作。归一化公式为:channel = (channel - mean) / std。其中mean和std分别表示每个通道的均值和标准差。 以上是torchvision.transforms.v2库中的一些常见的预处理操作,可以根据需要选择合适的操作来对图像数据进行处理。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *3* [torchvision.transforms](https://blog.csdn.net/qq_33254870/article/details/103364028)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *2* [pytorch-nyuv2:PyTorch NYUv2数据集类](https://download.csdn.net/download/weixin_42100188/18378138)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值