CV:Pytorch数据集制作,数据增强处理

Pytorch数据集制作,数据增强处理

1. Pytorch自带数据集及读取

import torch
import torchvision
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

# 1. 数据集定义
"""
一般的,我们使用torchvision.transforms中的函数来实现数据增强,并用transforms.Compose将所要进行的变换操作都组合在一起,
其变换操作的顺序按照在transforms.Compose中出现的先后顺序排列
"""
custom_transform = transforms.Compose([
    transforms.Resize((64, 64)),  # 缩放到指定大小 64*64
    transforms.ColorJitter(0.2, 0.2, 0.2),  # 随机颜色变换
    transforms.RandomRotation(5),  # 随机旋转
    transforms.Normalize([0.485, 0.456, 0.406],  # 对图像像素进行归一化
                         [0.229, 0.224, 0.225])])

# 读取训练集.transform:定义数据预处理,target_transform:目标检测标注的预处理,分类任务不常用,download是选择是否从网上下载
train_data = torchvision.datasets.CIFAR10('D:/Datasets/CIFAR10', train=True, transform=None,
                                          target_transform=None, download=True)
# 读取测试集
test_data = torchvision.datasets.CIFAR10('D:/Datasets/CIFAR10', train=False, transform=None,
                                         target_transform=None, download=True)

# 2. 数据集加载,shuffle设置为True在装载过程中为随机乱序,num_workers>=1表示多进程读取数据,在Win下num_workers只能设置为0,否则会报错。
train_loader = torch.utils.data.DataLoader(train_data, batch_size=2, shuffle=True, num_workers=0)
print(len(train_loader), type(train_loader), sep='\n')

2. Pytorch自制数据集及读取

import os
import torchvision.datasets.mnist as mnist
from skimage import io

# 1. 数据文件读取
root = r'D:/Datasets/MNIST'
train_set = (
    mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
)

test_set = (
    mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
)

# 数据量展示
print(train_set[0].size(), test_set[0].size(), train_set[1].size(), type(test_set), sep='\n')


# 2. 制作训练集,测试集
def convert_to_img(save_path, train=True):
    """
    将图片存储在本地,并制作索引文件
    :param save_path: 图像保存路径,将在路径下创建train、test文件夹分别存储训练集和测试集
    :param train: 默认True,本地存储训练集图像,否则本地存储测试集图像
    """

    if train:
        f = open(save_path + 'train.txt', 'w')
        data_path = save_path + 'train/'
        if not os.path.exists(data_path):
            os.makedirs(data_path)
        for i, (img, label) in enumerate(zip(train_set[0], train_set[1])):
            img_path = data_path + str(i) + '.jpg'
            io.imsave(img_path, img.numpy())  # 第一个参数表示保存的路径和名称,第二个参数表示需要保存的数组变量
            int_label = str(label).replace('tensor(', '')
            int_label = int_label.replace(')', '')
            f.write(str(i) + '.jpg' + ',' + str(int_label) + '\n')
        f.close()
    else:
        f = open(save_path + 'test.txt', 'w')
        data_path = save_path + '/test/'
        if not os.path.exists(data_path):
            os.makedirs(data_path)
        for i, (img, label) in enumerate(zip(test_set[0], test_set[1])):
            img_path = data_path + str(i) + '.jpg'
            io.imsave(img_path, img.numpy())
            int_label = str(label).replace('tensor(', '')
            int_label = int_label.replace(')', '')
            f.write(str(i) + '.jpg' + ',' + str(int_label) + '\n')
        f.close()


save_path = 'D:/DataSets/MNIST/mnist_data/'
convert_to_img(save_path)
convert_to_img(save_path, False)

3. 构建自己的Dataset及加载数据

import pandas as pd
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader
"""
需要注意的是,当 Dataset 创建好后并没有将数据生产出来,我们只是定义了数据及标签生产的流水线,只有在真正使用时,
如手动调用 next(iter(train_dataset)),或被 DataLoader调用,才会触发数据集内部的 __getitem__() 函数来读取数据
Dataset是对本地数据读取逻辑的定义;而DataLoader是对Dataset对象的封装,执行调度,将一个batch size的图像数据组装在一起,实现批量读取数据。
"""

# 1. 创建自己的train_dataset,test_dataset
class MnistDataset(Dataset):

    def __init__(self, image_path, image_label, transform=None):
        super(MnistDataset, self).__init__()
        self.image_path = image_path  # 初始化图像路径列表
        self.image_label = image_label  # 初始化图像标签列表
        self.transform = transform  # 初始化数据增强方法

    def __getitem__(self, index):
        """
        获取对应index的图像,并视情况进行数据增强
        """
        image = Image.open(self.image_path[index])
        image = np.array(image)
        label = float(self.image_label[index])

        if self.transform is not None:
            image = self.transform(image)

        return image, torch.tensor(label)

    def __len__(self):
        return len(self.image_path)


def get_path_label(img_root, label_file_path):
    """
    获取数字图像的路径和标签并返回对应列表
    @para: img_root: 保存图像的根目录
    @para:label_file_path: 保存图像标签数据的文件路径 .csv 或 .txt 分隔符为','
    @return: 图像的路径列表和对应标签列表
    """
    data = pd.read_csv(label_file_path, names=['img', 'label'])
    data['img'] = data['img'].apply(lambda x: img_root + x)
    return data['img'].tolist(), data['label'].tolist()


# 获取训练集路径列表和标签列表
train_data_root = 'D:/Datasets/MNIST/mnist_data/train/'
train_label = 'D:/Datasets/MNIST/mnist_data/train.txt'
train_img_list, train_label_list = get_path_label(train_data_root, train_label)
# 训练集dataset
train_dataset = MnistDataset(train_img_list,
                             train_label_list,
                             transform=transforms.Compose([transforms.ToTensor()]))

# 获取测试集路径列表和标签列表
test_data_root = 'D:/Datasets/MNIST/mnist_data/test/'
test_label = 'D:/Datasets/MNIST/mnist_data/test.txt'
test_img_list, test_label_list = get_path_label(test_data_root, test_label)
# 测试集dataset
test_dataset = MnistDataset(test_img_list,
                            test_label_list,
                            transform=transforms.Compose([transforms.ToTensor()]))

# 观察结果,注意下面两句代码,可以用来查看结果
# train_iter = iter(train_dataset)
# print(next(train_iter))
# img, label = train_dataset[0]
# print(len(train_dataset), img.size(), label, sep='\n')

"""
2. 数据加载,num_works默认为0,这里就不写了,shuffle:True表示打乱数据顺序,再分batch,
train_loader 已经将原来训练集中的60000张图像重新“洗牌”后按照每3张一个batch划分完成(test_loader同理)
"""
train_loader = DataLoader(train_dataset, batch_size=3, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=3, shuffle=True)

"""
观察结果,经过DataLoader的封装,每3(一个batch_size数量)张图像数据及对应的标签被封装为一个二元元组,第一个元素为四维的tensor形式,第二个元素为对应的图像标签数据
"""
# loader = iter(train_loader)
# print(next(loader))
# enumerate中的0表示可迭代对象索引从0开始,如果是1那么从1开始
# for i, img_data in enumerate(train_loader, 0):
#     images, labels = img_data
#     print('batch{}:images shape info-->{} labels-->{}'.format(i, images.size(), labels))

4. 数据增强处理

from PIL import Image
from matplotlib import pyplot as plt
import torchvision.transforms as transforms

# 原始图像
im = Image.open('./fish.jpg')


# 中心裁剪
center_crop = transforms.CenterCrop([200, 200])
# 随机裁剪
random_crop = transforms.RandomCrop([200, 200])
# 随机长宽比裁剪
random_resized_crop = transforms.RandomResizedCrop(200,
                                                   scale=(0.08, 1.0),
                                                   ratio=(0.75, 1.55),
                                                   interpolation=2)(im)
# 依概率p水平翻转
h_flip = transforms.RandomHorizontalFlip(0.7)
# 依概率p垂直翻转
v_flip = transforms.RandomVerticalFlip(0.8)
# 随机旋转
random_rotation = transforms.RandomRotation(30)

# 图像填充
pad = transforms.Pad(10, fill=0, padding_mode='constant')

# 调整亮度、对比度和饱和度
color_jitter = transforms.ColorJitter(brightness=1,
                                      contrast=0.5,
                                      saturation=0.5,
                                      hue=0.4)
# 转成灰度图
gray = transforms.Grayscale(1)

# 仿射变换
random_affine = transforms.RandomAffine(45, (0.5, 0.7), (0.5, 0.8), 3)

# 尺寸缩放
resize = transforms.Resize([100, 200])

# # 转Tensor、标准化和转换为PILImage
mean = [0.45, 0.5, 0.5]
std = [0.3, 0.6, 0.5]
transform = transforms.Compose([transforms.ToTensor(),  # 转Tensor
                                random_affine,
                                transforms.ToPILImage()  # 这里是为了可视化,故将其再转为 PIL
                                ])

img_tansform = transform(im)
plt.imshow(img_tansform)  # plt.imshow()函数负责对图像进行处理,并显示其格式,而plt.show()则是将plt.imshow()处理后的函数显示出来
plt.show()

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值