PaddlePaddle飞桨(学习笔记四——Transforms 数据预处理)

在模型训练过程中有时会遇到过拟合的问题,其中一个解决方法就是对训练数据做数据增强处理。通过对数据进行特定的处理,如图像的裁剪、翻转、调整亮度等处理,以增加样本的多样性,从而增强模型的泛化能力。

1、paddle.vision.transforms 介绍

飞桨框架在 paddle.vision.transforms 下内置了数十种图像数据处理方法,可以通过以下代码查看

paddle.vision.transforms.__all__

 图像数据处理方法: ['BaseTransform', 'Compose', 'Resize', 'RandomResizedCrop', 'CenterCrop', 'RandomHorizontalFlip', 'RandomVerticalFlip', 'Transpose', 'Normalize', 'BrightnessTransform', 'SaturationTransform', 'ContrastTransform', 'HueTransform', 'ColorJitter', 'RandomCrop', 'Pad', 'RandomRotation', 'Grayscale', 'ToTensor', 'to_tensor', 'hflip', 'vflip', 'resize', 'pad', 'rotate', 'to_grayscale', 'crop', 'center_crop', 'adjust_brightness', 'adjust_contrast', 'adjust_hue', 'normalize']

  • 单个使用

from paddle.vision.transforms import Resize

# 定义一个待使用的数据处理方法,这里定义了一个调整图像大小的方法
transform = Resize(size=28)
  • 多个组合使用

from paddle.vision.transforms import Compose, RandomRotation

# 定义待使用的数据处理方法,这里包括随机旋转、改变图片大小两个组合处理
transform = Compose([RandomRotation(10), Resize(size=32)])

2、在数据集中应用数据预处理操作

  • 在框架内置数据集中应用

train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
  • 在自定义的数据集中应用

对于自定义的数据集,可以在数据集中将定义好的数据处理方法传入 __init__ 函数,将其定义为自定义数据集类的一个属性,然后在 __getitem__ 中将其应用到图像上,如下述代码所示:

import os
import cv2
import numpy as np
from paddle.io import Dataset

class MyDataset(Dataset):
    """
    步骤一:继承 paddle.io.Dataset 类
    """
    def __init__(self, data_dir, label_path, transform=None):
        """
        步骤二:实现 __init__ 函数,初始化数据集,将样本和标签映射到列表中
        """
        super().__init__()
        self.data_list = []
        with open(label_path,encoding='utf-8') as f:
            for line in f.readlines():
                image_path, label = line.strip().split('\t')
                image_path = os.path.join(data_dir, image_path)
                self.data_list.append([image_path, label])
        ###############################################
        ###############################################
        # 2. 传入定义好的数据处理方法,作为自定义数据集类的一个属性
        self.transform = transform
        ###############################################
        ###############################################

    def __getitem__(self, index):
        """
        步骤三:实现 __getitem__ 函数,定义指定 index 时如何获取数据,并返回单条数据(样本数据、对应的标签)
        """
        image_path, label = self.data_list[index]
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        image = image.astype('float32')
        # 3. 应用数据处理方法到图像上
        ###############################################
        ###############################################
        if self.transform is not None:
            image = self.transform(image)
        ###############################################
        ###############################################
        label = int(label)
        return image, label

    def __len__(self):
        """
        步骤四:实现 __len__ 函数,返回数据集的样本总数
        """
        return len(self.data_list)
###############################################
###############################################
# 1. 定义随机旋转和改变图片大小的数据处理方法
transform = Compose([RandomRotation(10), Resize(size=32)])

custom_dataset = MyDataset('mnist/train','mnist/train/label.txt', transform)
###############################################
###############################################

  • 9
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

一级piaopiao虎

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值