PyTorch|Transforms运行机制

01 | Transforms运行机制

机器学习本质上属于数据驱动的智能学习,因而数据的分布与属性对于模型训练质量至关重要。一般Anaconda包中会提供诸如数据标准化、归一化等多种数据预处理方式,以便于在模型训练时取得更好的训练效果。

无独有偶,PyTorch中也提供了计算机视觉工具包“torchvision”,其中常用的有三个模块:

  • torchvision.transforms:提供常用图像处理方法,如缩放、裁剪、翻转等
  • torchvision.datasets:提供常用数据集的dataset实现,如MINST、CIFAR-10、ImageNet等数据集的实现工具
  • torchvision.model:提供常用模型预训练,如AlexNet、VGG、ResNet、GoogleNet等

对于“torchvision.transforms”而言,其提供了常用的图像处理方法,典型如:数据中心化、数据标准化、缩放、裁剪、翻转、跳转、填充、噪声添加、灰度变换、线性变换、仿射变换以及亮度、饱和度和对比度变换等。

数据变换(Transform)的一大作用在于数据增强以提升泛化能力。如同我们曾经做过的“5年高考3年模拟”,历届高考题类似于标准训练集,而模拟题则是在标准高考题上的演化和改变,若高考时遇到类似题目,则可以较好地的作答,从而实现了“泛化”能力。

从学习观的角度,模拟题的意义在于改变原始高考题的考查形式,从而对比凸显出不变的考点和解题方法。

下侧图片为同一张图片各种变换后的结果,虽然表面样式多变,但是不变的是图像主体,从而可以帮助AI模型更好地去除不必要变量,发现不变量因子。


从代码上看,“torchvision.transforms”的使用方法大致如下:其中借助“compose”方法将诸多图形处理方法按顺序集合到一起,从而需要时提取图片按顺序执行排列的预处理操作。下述代码中依次进行了图片缩放、裁剪、转换为张量以及标准化处理。

train_transform = transforms.compose([
  	transforms.Resize((32, 32)), # 图片缩放
  	transforms.RandomCrop(32, padding=4), # 图片随机裁剪
  	transforms.ToTensor(), # 图片转变为张量
  	transforms.Normalize(norm_mean, norm_std) # 图片标准化
])

从运行机制上看,transforms模块调用在getitem之后,即当按要求挑选出一张图片后,便按顺序调用compose中的处理工序。整体机制如下图:

在这里插入图片描述

02 | 数据标准化之transforms.normalize

PyTorch提供了“transforms.normalize”方法对图片进行逐通道标准化,其方法即为:

o u t p u t = i n p u t − m e a n s t d output = \frac{input - mean}{std} output=stdinputmean

其中主要参数为:

  • mean:各通道的均值
  • std:各通道的标准差
  • inplace:是否原位操作,默认为False

为何要对数据进行标准化呢?

因为许多模型初始假设均值为0,因而在实际训练时,若数据分布也在(mean=0)附近,则模型可以相对更快迭代到数据分界线,表现为更快的收敛训练速度;反之,则可能需要更长时间迭代。

以之前实现的逻辑回归为例,默认情况下,初始模型假设与数据标准化后,其均值都为0,表现为初始模型与数据分布均在(0,0)附近。因此大约在迭代400次时,即可达到99.50%的准确率。

标准化后迭代非标准化迭代起始状态
非标准化迭代终止状态

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值