Pytorch图片数据集预处理与增强

27 篇文章 8 订阅


一、前言

  1. 图片增强
    提高图片泛化度,包括 旋转、翻转、拉伸、色彩抖动等处理,需要根据具体图片类型来决定,比如,我做猫狗二分类,那么旋转、拉伸、翻转、抖动都可以,但是我如果做的是比较严谨的分类比如医学相关的,那么翻转、拉伸、色彩抖动就别整了或者参数调小点
  2. 归一化与标准化
    <1>图片像素值统一除以255,归一化到 [0,1] 之间
    <2>再将归一化的结果减去0.5,除以0.5,标准化到 [-1, 1] 之间
  3. 训练数据可以进行增强和归一化,但是预测数据只进行归一化即可,训练数据增强是为了让模型适应更多不确定性的环境,但是预测的时候就不要把图片转来转去拉来拉去给自己找麻烦了(狗头),这样也能保证输出结果稳定、唯一

二、预处理与增强

1.针对训练数据

  1. 单个图片的增强与加载
from PIL import Image
from torchvision import transforms


def get_transform_for_predict():
    '''
    图片数据转换
    :return:
    '''
    return transforms.Compose([
        transforms.Resize(size=(224, 224)),                                       # 图片拉成 224*224
        transforms.RandomHorizontalFlip(p=0.3),                                   # 将三成图片水平翻转
        transforms.RandomVerticalFlip(p=0.3),                                     # 将三成图片垂直翻转
        transforms.RandomPerspective(distortion_scale=0.3, p=0.3),                # 将三成图片不规则拉伸,拉伸力度0.3
        transforms.RandomRotation(degrees=(0, 180)),                              # 图片随机旋转,0-180度
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),                               # 图片随机抖动,四个对应的值分别是 亮度、对比度、饱和度、色调
        transforms.ToTensor(),                                                    # 1.归一化处理,所有像素值除以255,归一化到[0,1];2.通道维度提前 HWC->CHW
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),          # 标准化处理,归一化后的三个通道维度上的值分别减去 mean 除以 std,将所有像素值标准化到[-1,1]
    ])


if __name__ == '__main__':

    # 加载图片
    img_path = 'dataset/train/dog/dog.0.jpg'
    img = Image.open(img_path).convert('RGB')

    # 初始化数据增强方法
    tsf = get_transform_for_predict()
    # 图片增强
    X = tsf(img)

    # # show
    # plt.imshow(X.permute(1, 2, 0).numpy())
    # plt.show()
  1. 批量增强与加载
    创建如下目录结构
datasets						# 根路径
	|
	|----train					# 训练集
		   |
		   |----cat				# 样本目录1
		   |     |
		   | 	 |----01.jpg	# 图片数据1
		   |   	 |----02.jpg	# 图片数据2
		   |   	 |----...		
		   |
		   |
		   |----dog				# 样本目录2
		   	     |
			     |----01.dog	# 图片数据1
			     |----02.dog	# 图片数据2
			     |----...

# 加载后 cat、dog...等样本目录 会依次变成 0、1... 等标签,分别对应样本目录下的图片数据
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt


def get_transform_for_train():
    '''
    图片数据转换
    :return:
    '''
    return transforms.Compose([
        transforms.Resize(size=(224, 224)),                                       # 图片拉成 224*224
        transforms.RandomHorizontalFlip(p=0.3),                                   # 将三成图片水平翻转
        transforms.RandomVerticalFlip(p=0.3),                                     # 将三成图片垂直翻转
        transforms.RandomPerspective(distortion_scale=0.3, p=0.3),                # 将三成图片不规则拉伸,拉伸力度0.3
        transforms.RandomRotation(degrees=(0, 180)),                              # 图片随机旋转,0-180度
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),                               # 图片随机抖动,四个抖动因子分别是 亮度、对比度、饱和度、色调
        transforms.ToTensor(),                                                    # 1.归一化处理,所有像素值除以255,归一化到[0,1];2.通道维度提前 HWC->CHW
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),          # 标准化处理,归一化后的三个通道维度上的值分别减去 mean 除以 std,将所有像素值标准化到[-1,1]
    ])

def split_data_to_train_and_valid(datasets, validation_split=0.1):
    '''
    将数据集分为训练集和验证集
    :param datasets:
    :param size:
    :return:
    '''
    train_size = int((1-validation_split) * len(datasets))                                      # 训练集size
    validation_size = len(datasets) - train_size                                                # 验证集size
    train_dataset, validation_dataset = random_split(datasets, [train_size, validation_size])   # 拆分
    return train_dataset, validation_dataset

def get_data_iter(data_path, batch_size, validation_split):
    '''
    获取训练张量
    :return:
    '''
    datasets = ImageFolder(
        root=data_path,                         # 数据集路径
        transform=get_transform_for_train()     # 图片增强
    )

    # 拆分数据集和验证集
    train_dataset, validation_dataset = split_data_to_train_and_valid(datasets, validation_split)

    # 分别加入迭代器
    train_iter = DataLoader(
        train_dataset,              # 训练集
        batch_size=batch_size,      # 批量大小
        shuffle=True,               # 是否乱序
        num_workers=4               # 加载时使用的进程并发数
    )
    validation_iter = DataLoader(
        validation_dataset,         # 训练集
        batch_size=batch_size,      # 批量大小
        shuffle=True,               # 是否乱序
        num_workers=4               # 加载时使用的进程并发数
    )

    return train_iter, validation_iter


if __name__ == '__main__':

    data_path = r'E:\数据集\猫狗数据集\kaggle_Dog&Cat\train'

    # 加载数据并拆分为训练集和验证集
    train_iter, validation_iter = get_data_iter(
        data_path=data_path,            # 数据集位置
        batch_size=64,                  # 批量大小
        validation_split=0.3            # 三成数据作为验证集
    )
	
	# 到此为止就可以送入网络训练了,下面的是打印检查数据 ------------
	
    # 打印每个batch的训练集和验证集长度
    print(len(train_iter), len(validation_iter))
    # 274 118
    # 我这里准备了25000张图片
    # 训练集每批数据量 = 25000 / 64 * (1 - 0.3) = 274
    # 验证集每批数据量 = 25000 / 64 * 0.3 = 118

    # 打开训练集的第一张图片康康
    for index, batch_data in enumerate(train_iter):
        # 打印索引、训练集尺寸、标签尺寸
        print(index, batch_data[0].shape, batch_data[1].shape)
        # 0 torch.Size([64, 3, 224, 224]) torch.Size([64])

        # 打印第一批数据的第一张图的标签和样本
        print(batch_data[0][0], batch_data[1][0])
        # 样本
        # tensor([[[-0.9765, -0.9765, -0.9765,  ..., -0.9765, -0.9765, -0.9765],
        #          [-0.9765, -0.9765, -0.9765,  ..., -0.9765, -0.9765, -0.9765],
        #          [-0.9765, -0.9765, -0.9765,  ..., -0.9765, -0.9765, -0.9765],
        #          ...,
        #          [-0.9765, -0.9765, -0.9765,  ..., -0.9765, -0.9765, -0.9765],
        #          [-0.9765, -0.9765, -0.9765,  ..., -0.9765, -0.9765, -0.9765],
        #          [-0.9765, -0.9765, -0.9765,  ..., -0.9765, -0.9765, -0.9765]]])
        # 标签
        # tensor(1)

        for data in batch_data[0]:
            data = data.permute(1, 2, 0)        # 通道维度放到最后
            plt.imshow(data.numpy())            # 转成numpy展示图片
            plt.show()
            break
        break

    # 打开验证集的第一张图片康康
    for index, batch_data in enumerate(validation_iter):
        # 打印索引、训练集尺寸、标签尺寸
        print(index, batch_data[0].shape, batch_data[1].shape)
        # 0 torch.Size([64, 3, 224, 224]) torch.Size([64])

        # 打印第一批数据的第一张图的标签和样本
        print(batch_data[0][0], batch_data[1][0])
        # 样本
        # tensor([[[-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922],
        #          [-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922],
        #          [-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922],
        #          ...,
        #          [-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922],
        #          [-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922],
        #          [-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922]]])
        # 标签
        # tensor(0)

        for data in batch_data[0]:
            data = data.permute(1, 2, 0)        # 通道维度放到最后
            plt.imshow(data.numpy())            # 转成numpy展示图片
            plt.show()
            break
        break

2.针对预测数据

from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import torch


def get_transform_for_predict():
    '''
    图片数据转换
    :return:
    '''
    return transforms.Compose([
        transforms.Resize(size=(224, 224)),                                       # 图片拉成 224*224
        # transforms.RandomHorizontalFlip(p=0.3),                                   # 将三成图片水平翻转
        # transforms.RandomVerticalFlip(p=0.3),                                     # 将三成图片垂直翻转
        # transforms.RandomPerspective(distortion_scale=0.3, p=0.3),                # 将三成图片不规则拉伸,拉伸力度0.3
        # transforms.RandomRotation(degrees=(0, 180)),                              # 图片随机旋转,0-180度
        # transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),                               # 图片随机抖动,四个对应的值分别是 亮度、对比度、饱和度、色调
        transforms.ToTensor(),                                                    # 1.归一化处理,所有像素值除以255,归一化到[0,1];2.通道维度提前 HWC->CHW
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),          # 标准化处理,归一化后的三个通道维度上的值分别减去 mean 除以 std,将所有像素值标准化到[-1,1]
    ])


if __name__ == '__main__':

    # 加载图片
    img_path = 'dataset/train/dog/狗砸.png'
    img = Image.open(img_path).convert('RGB')

    # 初始化数据增强方法
    tsf = get_transform_for_train()
    # 图片增强
    X = tsf(img)

    # # show
    # plt.imshow(X.permute(1, 2, 0).numpy())
    # plt.show()

评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

什么都干的派森

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值