在模型训练过程中有时会遇到过拟合的问题,其中一个解决方法就是对训练数据做数据增强处理。通过对数据进行特定的处理,如图像的裁剪、翻转、调整亮度等处理,以增加样本的多样性,从而增强模型的泛化能力。
1、paddle.vision.transforms 介绍
飞桨框架在 paddle.vision.transforms 下内置了数十种图像数据处理方法,可以通过以下代码查看
paddle.vision.transforms.__all__
图像数据处理方法: ['BaseTransform', 'Compose', 'Resize', 'RandomResizedCrop', 'CenterCrop', 'RandomHorizontalFlip', 'RandomVerticalFlip', 'Transpose', 'Normalize', 'BrightnessTransform', 'SaturationTransform', 'ContrastTransform', 'HueTransform', 'ColorJitter', 'RandomCrop', 'Pad', 'RandomRotation', 'Grayscale', 'ToTensor', 'to_tensor', 'hflip', 'vflip', 'resize', 'pad', 'rotate', 'to_grayscale', 'crop', 'center_crop', 'adjust_brightness', 'adjust_contrast', 'adjust_hue', 'normalize']
-
单个使用
from paddle.vision.transforms import Resize
# 定义一个待使用的数据处理方法,这里定义了一个调整图像大小的方法
transform = Resize(size=28)
-
多个组合使用
from paddle.vision.transforms import Compose, RandomRotation
# 定义待使用的数据处理方法,这里包括随机旋转、改变图片大小两个组合处理
transform = Compose([RandomRotation(10), Resize(size=32)])
2、在数据集中应用数据预处理操作
-
在框架内置数据集中应用
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
-
在自定义的数据集中应用
对于自定义的数据集,可以在数据集中将定义好的数据处理方法传入 __init__
函数,将其定义为自定义数据集类的一个属性,然后在 __getitem__
中将其应用到图像上,如下述代码所示:
import os
import cv2
import numpy as np
from paddle.io import Dataset
class MyDataset(Dataset):
"""
步骤一:继承 paddle.io.Dataset 类
"""
def __init__(self, data_dir, label_path, transform=None):
"""
步骤二:实现 __init__ 函数,初始化数据集,将样本和标签映射到列表中
"""
super().__init__()
self.data_list = []
with open(label_path,encoding='utf-8') as f:
for line in f.readlines():
image_path, label = line.strip().split('\t')
image_path = os.path.join(data_dir, image_path)
self.data_list.append([image_path, label])
###############################################
###############################################
# 2. 传入定义好的数据处理方法,作为自定义数据集类的一个属性
self.transform = transform
###############################################
###############################################
def __getitem__(self, index):
"""
步骤三:实现 __getitem__ 函数,定义指定 index 时如何获取数据,并返回单条数据(样本数据、对应的标签)
"""
image_path, label = self.data_list[index]
image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
image = image.astype('float32')
# 3. 应用数据处理方法到图像上
###############################################
###############################################
if self.transform is not None:
image = self.transform(image)
###############################################
###############################################
label = int(label)
return image, label
def __len__(self):
"""
步骤四:实现 __len__ 函数,返回数据集的样本总数
"""
return len(self.data_list)
###############################################
###############################################
# 1. 定义随机旋转和改变图片大小的数据处理方法
transform = Compose([RandomRotation(10), Resize(size=32)])
custom_dataset = MyDataset('mnist/train','mnist/train/label.txt', transform)
###############################################
###############################################