pytorch基础语法学习:数据预处理transforms模块

来源:投稿 作者:阿克西
编辑:学姐

建议搭配视频食用

视频链接:https://ai.deepshare.net/detail/p_5df0ad9a09d37_qYqVmt85/6

系列其他文章传送门:

pytorch基础语法学习:数据读取机制Dataloader与Dataset

pytorch基础语法(一)

pytorch基础语法(二)

1.transforms运行机制

torchvision是pytorch的计算机视觉工具包,主要有以下三个模块:

  • torchvision.transforms:提供了常用的一系列图像预处理方法,例如数据的标准化,中心化,旋转,翻转等。

  • torchvision.datasets:定义了一系列常用的公开数据集的datasets,比如MNIST,CIFAR-10,ImageNet等。

  • torchvision.model:提供了常用的预训练模型,例如AlexNet,VGG,ResNet,GoogLeNet等。

torchvision.transforms:常用的图像预处理方法

  • 数据中心化,数据标准化

  • 缩放,裁剪,旋转,翻转,填充

  • 噪声添加,灰度变换,线性变换,仿射变换

  • 亮度、饱和度及对比度变换

深度学习是由数据驱动的,数据的数量以及分布对模型的优劣起到决定性作用,所以需要对数据进行一定的预处理以及数据增强,用来提升模型的泛化能力。

上图是1张原始图片经过数据增强之后生成的一系列数据,一共有64张图片。对图片进行数据增强可以丰富训练数据,提高模型的泛化能力。因为如果数据增强生成了与测试样本很相似的图片,那么模型的泛化能力自然可以得到提高。

使用上一节中介绍的人民币二分类实验的代码的数据预处理部分:

2.断点调试

# ============================ step 1/5 数据 ============================
# 这部分设置数据的路径
split_dir = os.path.join("C:/Users/10530/Desktop/pytorch/rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

# 设置数据标准化的均值和标准差
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# transforms.Compose的功能是将一系列的transforms方法进行有序的组合包装,
# 在具体实现的时候,会依次按顺序对图像进行操作
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),  # 将图像缩放到32*32的大小
    transforms.RandomCrop(32, padding=4),  # 对数据进行随机的裁剪
    # 将图片转成张量的形式同时会进行归一化操作,把像素值的区间从0-255归一化到0-1
    transforms.ToTensor(),
    # 标准化操作,将数据的均值变为0,标准差变为1
    transforms.Normalize(norm_mean, norm_std),
])

# 验证集的预处理的方法,对比训练集,少了RandomCrop这一部分,
# 因为在验证集中是不需要对数据进行数据增强的
valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

同样,在模型训练样本读取位置设置断点,进行debug:

点击step into按键,在跳转后的代码中进行一个是否采用多进程的判断:

点击step over,选择单进程的运行机制,再点击step into按键,进入dataloader.py界面:

光标设置在index = self._next_index() # may raise StopIteration这一行,点击Run to Cursor,程序就会运行到光标所在的行。这一步的作用是获取Index,也就是要读取哪些数据。点击step over,得到Index就可以进入dataset_fetcher.fetch(index),根据索引去获取数据。点击step into进入到fetch函数:

在fetch函数中,代码data = [self.dataset[idx] for idx in possibly_batched_index]使用了列表生成式,调用了dataset,接着点击step over与step into进入dataset所在的代码位置,dataset代码位于类RMBDataset(Dataset)中的__getitem__()函数:

在getitem()中根据索引去获取图片的路径以及标签,然后采用代码img = Image.open(path_img).convert('RGB') # 0~255打开图片,读取进来的图片是一个PIL的数据类型,然后在getitem中调用transform()进行图像预处理操作,在代码处img = self.transform(img)通过step into进入transforms.py中的def 「call」()函数

「call」()函数是一个for循环,也就是依次有序地从compose中去调用预处理方法,第一个预处理方法是t(img),其功能是是Resize缩放;第二个功能是裁剪,第三个功能是进行张量操作,第四个功能是进行归一化;对compose的四个功能循环结束之后,就会返回代码处img = self.transform(img)。

transform是在__getitem__()中调用,并且在__getitem__()中实现数据预处理,然后通过__getitem__返回一个样本。

执行step out操作返回fetch()函数,接着就是不断地循环index获取一个batch_size大小的数据,最后在return的时候调用collate_fn()函数,将数据整理成一个batch_data的形式。

然后执行step out操作返回到dataloader.py中的__next__()函数中,然后再执行执行step out操作回到训练代码中,接着数据就读取进来了。这就是pytorch数据读取和transforms的运行机制。

回顾上面的数据读取流程图,transforms是在getitem中使用的,在getitem中读取一张图片,然后对这一张图片进行一系列预处理,返回图片以及标签。

了解了transforms的机制,现在学习一个比较常用的预处理方法,数据的标准化transforms.Normalize。

3.数据标准化transforms.normalize

3.1 定义

功能:逐channel的对图像进行标准化,即数据的均值变为0,标准差变为1。

计算公式:output =\frac{(input - mean)}{std}

  • mean:各通道的均值

  • std:各通道的标准差

  • inplace:是否原位操作

transform.Normalize(mean,
                    std,
                    inplace=False)

3.2 断点调试

回到代码中看一下normalize的具体实现方法,transform是在dataset的getitem中实现的,所以可以直接去dataset的getitem函数中设置断点:

进行debug操作,点击step into进入详细代码环境,进入了transforms.py中的call()函数中,在call函数中循环transforms。

点击step over执行多次,到normalize实现

接着点击step into查看normalize的实现,来到了normalize()类中的__call__()函数中,代码只有一行,实际上这行代码是调用了pytorch中的function中normalize方法。pytorch的function提供了很多常用的函数。

接着使用step into查看normalize中的具体实现。

def normalize(tensor, mean, std, inplace=False):
    """Normalize a tensor image with mean and standard deviation.

    .. note::
        This transform acts out of place by default, i.e., it does not mutates the input tensor.

    See :class:`~torchvision.transforms.Normalize` for more details.

    Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace(bool,optional): Bool to make this operation inplace.

    Returns:
        Tensor: Normalized Tensor image.
    """
    if not _is_tensor_image(tensor):  # 输入的合法性判断
        raise TypeError('tensor is not a torch image.')

    if not inplace:       # 判断是否需要原地操作
        tensor = tensor.clone()

    dtype = tensor.dtype
    # 获取均值与标准差,将list形式转变为张量形式
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    tensor.sub_(mean[:, None, None]).div_(std[:, None, None]) # 归一化公式
    return tensor

首先是输入的合法性判断,输入的是tensor,也就是原始的图像,接着判断是否要原地操作,如果不是inplace就需要将张量复制一份到新的内存空间中。下面的代码就是获取数据的均值和标准差,并将数据转换为张量。注意在sub_和div_后面有下划线,意思是进行原位操作,这样就完成了数据标准化的操作。

3.3 标准化作用

对数据进行标准化之后可以加快模型的收敛。

之前的逻辑回归代码bias=1,发现迭代次数360次即可得到99%的准确率,损失loss=0.05。

当修改bias=5时,发现需要迭代960次模型才能收敛,loss=0.14,得到99%的准确率。

原因:模型初始化一般有0均值,需要逐渐靠近最优分类平面。

bias=5的初始化距离分类平面较远

可以看出,如果训练数据有良好的分布或者权重有良好的初始化,可以加速模型的训练。

点击下方卡片《学姐带你玩AI》🚀🚀🚀

关注回复“500”领取300+经典论文合集&讲解视频

码字不易,欢迎大家点赞评论收藏!

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值