系统学习Pytorch笔记三:Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)

Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)

背景

学习知识先有框架(至少先知道有啥东西)然后再通过实战(各个东西具体咋用)来填充这个框架。 而这个系列的目的就是在脑海中先建一个Pytorch的基本框架出来。关于系统学习Pytorch,逻辑上就是按照机器学习的那五大步骤进行的, 步骤为:数据模块 -> 模型模块 -> 损失函数 -> 优化器 -> 迭代训练

基于上次的学习Pytorch的动态图、自动求导及逻辑回归进行整理,这次主要是学习Pytorch的数据读取机制DataLoader和Dataset的运行机制,然后学习图像的预处理模块transforms的原理,最后基于上面的所学玩一个人民币二分类的任务。

Pytorch的数据读取机制

机器模型学习的五大模块,分别是数据,模型,损失函数,优化器,迭代训练

在这里插入图片描述
这里的数据读取机制,很显然是位于数据模块的一个小分支,下面看一下数据模块的详细内容
在这里插入图片描述

数据模块中,又可以大致分为上面不同的子模块, 而今天学习的DataLoaderDataSet也就是数据读取子模块中的核心机制

DataLoader

torch.utils.data.DataLoader(): 构建可迭代的数据装载器, 我们在训练的时候,每一个for循环,每一次iteration,就是从DataLoader中获取一个batch_size大小的数据的。
在这里插入图片描述
DataLoader的参数很多,但我们常用的主要有5个:

  • dataset: Dataset类, 决定数据从哪读取以及如何读取
  • bathsize: 批大小
  • num_works: 是否多进程读取机制
  • shuffle: 每个epoch是否乱序
  • drop_last: 当样本数不能被batchsize整除时, 是否舍弃最后一批数据

Epoch, Iteration和Batchsize的概念

  • Epoch: 所有训练样本都已输入到模型中,称为一个Epoch
  • Iteration: 一批样本输入到模型中,称为一个Iteration
  • Batchsize: 一批样本的大小, 决定一个Epoch有多少个Iteration

举个例子, 假设样本总数80, Batchsize是8, 那么1Epoch=10 Iteration。 假设样本总数是87, Batchsize是8, 如果drop_last=True, 那么1Epoch=10Iteration, 如果等于False, 那么1Epoch=11Iteration, 最后1个Iteration有7个样本。

Dataset

torch.utils.data.Dataset(): Dataset抽象类, 所有自定义的Dataset都需要继承它,并且必须复写__getitem__()这个类方法。
在这里插入图片描述

数据读取机制及Dataset、DataLoader的用法

上面只是介绍了两个数据读取机制用到的两个类,那么具体怎么用呢? 我们就以人民币二分类的任务进行具体查看, 但是查看之前我们要带着关于数据读取的三个问题去看:

1、读哪些数据? 我们每一次迭代要去读取一个batch_size大小的样本,那么读哪些样本呢?
2、从哪读数据? 也就是在硬盘当中该怎么去找数据,在哪设置这个参数。
3、怎么读数据?

人民币分类的任务,数据集是1块的图片100张,100的图片100张,我们的任务就是训练一个模型,来帮助我们对这两类图片进行分类。

#==========================================step 1/5 准备数据===============================

# 数据的路径
split_dir = os.path.join('data', 'rmb_split')
train_dir = os.path.join(split_dir, 'train')
valid_dir = os.path.join(split_dir, 'valid')

## transforms模块,进行数据预处理
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

## 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# 构建DataLoader
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# print(train_loader)

首先一开始,是路径部分, 也就是训练集和测试集的位置,这个其实就是我们上面的第二个问题从哪读数据,然后是transforms图像数据的预处理部分, 这个不用管, 后面会介绍transforms这个模块

MyDataset实例还有后面的DataLoader,这个才是我们这次介绍的重点。

我们从 train_data = RMBDataset(data_dir=train_dir, transform=train_transform) 开始, 这一句话里面的核心就是RMBDataset,这个是我们自己写的一个类,继承了上面的抽象类Dataset,并且重写了__getitem__()方法, 这个类的目的就是传入数据的路径,和预处理部分(看参数),然后给我们返回数据.

class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        rmb面额分类任务的Dataset
        :param data_dir: str, 数据集所在路径
        :param transform: torch.transform,数据预处理
        """
        self.label_name = {"1": 0, "100": 1}
        self.data_info = self.get_img_info(data_dir)  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')     # 0~255

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label

    def __len__(self):
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))

                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = rmb_label[sub_dir]
                    data_info.append((path_img, int(label)))

        return data_info

重点就是__getitem__()这个方法的实现了,我们说过从这里面,我们要拿到我们的训练样本.我们只要给定了index, 通过data_info[index] 代码进行获取样本的。
主要调用静态方法,get_img_info(data_dir)。函数的输入为数据在的路径,输出为是一个list, 而list的每个元素是元组,格式就是[(样本1_loc, label_1), (样本2_loc, label_2), …(样本n_loc, label_n)]。这个其实就是data_info拿到的一个list。 有了这个list,然后又给了data_info一个index,那么取数据不就很容易了吗? data_info[index] 不就取出了某个(样本i_loc, label_i)。

Pytorch的图像预处理transforms

transforms是常用的图像预处理方法, 这个在torchvision计算机视觉工具包中,我们在安装Pytorch的时候顺便安装了这个torchvision。 在torchvision中,有三个主要的模块:

  • torchvision.transforms: 常用的图像预处理方法, 比如标准化,中心化缩放,裁剪,旋转,翻转,填充,噪声添加,灰度变换,线性变换,仿射变换,亮度、饱和度及对比度变换等操作
  • trochvision.datasets: 常用的数据集的dataset实现, MNIST, CIFAR-10, ImageNet等
  • torchvision.models: 常用的模型预训练, AlexNet, VGG, ResNet, GoogLeNet等。

二分类任务中用到的transforms的方法

人民币二分类任务中用到的图像预处理的方法

在这里插入图片描述

  • transforms.Compose方法是将一系列的transforms方法进行有序的组合包装,具体实现的时候,依次的用包装的方法对图像进行操作。
  • transforms.Resize方法改变图像大小
  • transforms.RandomCrop方法对图像进行裁剪(这个在训练集里面用,验证集就用不到了)
  • transforms.ToTensor方法是将图像转换成张量,同时会进行归一化的一个操作,将张量的值从0-255转到0-1
  • transforms.Normalize方法是将数据进行标准化

了解了图像处理的transforms机制,我们下面学习一个比较常用的数据预处理机制,叫做数据标准化:
transforms.Normalize: 逐channel的对图像进行标准化。 output = (input - mean) / std。Normalize的处理作用就是有利于加快模型的收敛速度

transforms的其他图像增强方法

1、数据增强
数据增强又称为数据增广, 数据扩增,是对训练集进行变换,使训练集更丰富,从而让模型更具泛化能力

2、图像裁剪

  • transforms.CenterCrop(size): 图像中心裁剪图片, size是所需裁剪的图片尺寸,如果比原始图像大了, 会默认填充0。
  • transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant): 从图片中位置随机裁剪出尺寸为size的图片, size是尺寸大小,padding设置填充大小(当为a, 上下左右均填充a个像素, 当为(a,b), 上下填充b个,左右填充a个,当为(a,b,c,d), 左,上,右,下分别填充a,b,c,d个), pad_if_need: 若图像小于设定的size, 则填充。 padding_mode表示填充模型, 有4种,constant像素值由fill设定, edge像素值由图像边缘像素设定,reflect镜像填充, symmetric也是镜像填充, 这俩镜像是怎么做的看官方文档吧。镜像操作就类似于复制图片的一部分进行填充。
  • transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(3/4, 4/3), interpolation): 随机大小,长宽比裁剪图片。 scale表示随机裁剪面积比例,ratio随机长宽比, interpolation表示插值方法。
  • FiveCrop, TenCrop: 在图像的上下左右及中心裁剪出尺寸为size的5张图片,后者还在这5张图片的基础上再水平或者垂直镜像得到10张图片,具体使用这里就不整理了。

3、图像的翻转和旋转

  • RandomHorizontalFlip(p=0.5), RandomVerticalFlip(p=0.5): 依概率水平或者垂直翻转图片, p表示翻转概率
  • RandomRotation(degrees, resample=False, expand=False, center=None):随机旋转图片, degrees表示旋转角度 , resample表示重采样方法, expand表示是否扩大图片,以保持原图信息。

4、图像变换

  • transforms.Pad(padding, fill=0, padding_mode=‘constant’): 对图片边缘进行填充
  • transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0):调整亮度、对比度、饱和度和色相, 这个是比较实用的方法, brightness是亮度调节因子, contrast对比度参数, saturation饱和度参数, hue是色相因子。
  • transfor.RandomGrayscale(num_output_channels, p=0.1): 依概率将图片转换为灰度图, 第一个参数是通道数, 只能1或3, p是概率值,转换为灰度图像的概率
  • transforms.RandomAffine(degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0): 对图像进行仿射变换, 反射变换是二维的线性变换, 由五中基本原子变换构成,分别是旋转,平移,缩放,错切和翻转。 degrees表示旋转角度, translate表示平移区间设置,scale表示缩放比例,fill_color填充颜色设置, shear表示错切
  • transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False): 这个也比较实用, 对图像进行随机遮挡, p概率值,scale遮挡区域的面积, ratio遮挡区域长宽比。 随机遮挡有利于模型识别被遮挡的图片。value遮挡像素。 这个是对张量进行操作,所以需要先转成张量才能做
  • transforms.Lambda(lambd): 用户自定义的lambda方法, lambd是一个匿名函数。lambda [arg1 [, arg2…argn]]: expression

transforms的选择操作

对几个transforms的操作进行选择,使得图像预处理更加的灵活。

  • transforms.RandomChoice([transforms1, transforms2, transforms3]): 从一系列transforms方法中随机选一个
  • transforms.RandomApply([transforms1, transforms2, transforms3], p=0.5): 依据概率执行一组transforms操作
  • transforms.RandomOrder([transforms1, transforms2, transforms3]): 对一组transforms操作打乱顺序

到这里,关于Pytorch的transforms操作基本上就搞定, 上面只是整理了一些常用的函数,如果真的需要,具体细节还得去看官方文档。 虽然Pytorch提供了很多的transforms方法, 但是在实际工作中,可能需要自己的项目去自定义一些transforms方法

自定义transforms

在Compose这个类里面调用了一系列的transforms方法。对Compose里面的这些transforms方法执行一个for循环,每次挑取一个方法进行执行。 也就是transforms方法仅接收一个参数,返回一个参数,然后就是for循环中,上一个transforms的输出正好是下一个transforms的输入,所以数据类型要注意匹配。 这就是自定义transforms的两个要素。

所以,这里注意一下, 对于图像增强来说,我们给定一张图片,然后通过一系列图片增强,什么裁剪,旋转,变化啥的,最终得到的是一张图片,也就是上面的一系列图片增强技术是串联执行的.
之前不是说图像增强可以丰富数据集,增加图片数量,并且可以增强模型的鲁棒性,我一张图片经过一系列猛操作之后,就只得到了一张图片,也没增强啥东西呀?
如果如果想进行这种丰富数据集,增加图片数量的这种图像增强,我的建议有两个:
1、原图片事先多复制几张。 比如一张相同的图片,复制5张放到同一个目录中,这样构建dataloader,读取图片的时候,相当于这5张都有机会做增强,才会出来一张图片有不同形态的效果,实现丰富数据集,增强模型鲁棒的目的。
2、数据增强过程和构建dataloader的过程耦合开, 先通过transform技术做一些数据增强的图片,比如对于一张图片, 我经过裁剪,旋转, 变换等得到多张图片,直接保存到数据目录中去。 然后再读取。

下面给出一个自定义transforms的结构:

在这里插入图片描述
上面就是整个transforms的图像增强处理的技术了.但是我们如何去选择图像增强的策略呢?

数据增强策略原则: 让训练集与测试集更接近

  • 空间位置上: 可以选择平移
  • 色彩上: 灰度图,色彩抖动
  • 形状: 仿射变换
  • 上下文场景: 遮挡,填充

总结

首先是整理了Pytorch的数据读取机制, 学习到了两个数据读取的关键DataLoader和Dataset,并通过一个人民币二分类的例子具体看了下这两个是如何使用的。

然后又学习了Pytorch的图像处理模块transforms, 这一模块主要是整理了各种图像处理的方法,transforms的选择操作,并且从战术的角度看了一下这些方法到底什么时候用。 至于这些方法的细节,具体用到的时候查看官方文档。

关于Pytorch的数据模块,到这里就基本结束, 我们的逻辑就是按照机器学习的那五大步骤进行的查看, 数据模块 -> 模型模块 -> 损失函数 -> 优化器 -> 训练等。

参考:
[1]: https://blog.csdn.net/wuzhongqiang/article/details/105499476
[2]: https://pytorch.org/docs/stable/torch.html?
[3]: https://mermaidjs.github.io/
[4]: http://adrai.github.io/flowchart.js/

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值