2020-12-24

深度学习,复习,PyTorch dataset
转 Datawhale

from PIL import Image
import torch
import torchvision
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms
import pandas as pd
import os
from skimage import io
import torchvision.datasets.mnist as mnist
import numpy as np
from torch.utils.data import DataLoader

from torchvision.datasets import ImageFolder

'''
dataset_dir = '../../../dataset/'
torchvision.datasets.CIFAR10(dataset_dir, train=True, transform=None, target_transform=None, download=False) 

dataset_dir:存放数据集的路径。
train(bool,可选)–如果为True,则构建训练集,否则构建测试集。
transform:定义数据预处理,数据增强方案都是在这里指定。
target_transform:标注的预处理,分类任务不常用。
download:是否下载,若为True则从互联网下载,如果已经在dataset_dir下存在,就不会再次下载
'''

# 读取示例1(从网上下载)
# 读取训练集
train_data = torchvision.datasets.CIFAR10('../../../dataset',
                                          train=True,
                                          transform=None,
                                          target_transform=None,
                                          download=True)
# 读取测试集
test_data = torchvision.datasets.CIFAR10('../../../dataset',
                                         train=True,
                                         transform=None,
                                         target_transform=None,
                                         download=True)

# 读取示例2(示例1基础上附带数据增强)
'''
图像进行各种变换来增加数据的丰富性称为数据增强.使用torchvision.transforms中的函数来实现数据增强,
并用transforms.Compose将所要进行的变换操作都组合在一起,其变换操作的顺序按照在transforms.Compose中出现的先后顺序排列。
在transforms中有很多实现好的数据增强方法,在这里我们尝试使用缩放,随机颜色变换、随机旋转、图像像素归一化等组合变换。
'''

# 读取数据集
custom_transform = transforms.Compose([transforms.Resize((64, 64)),  # 缩放到指定大小 64*64
                                       transforms.ColorJitter(0.2, 0.2, 0.2),  # 随机颜色变换
                                       transforms.RandomRotation(5),  # 随机旋转
                                       transforms.Normalize([0.485, 0.486, 0.406],  # 对图像像素进行归一化
                                                            [0.229, 0.224, 0.225])])
train_data = torchvision.datasets.CIFAR10('../../../dataset',
                                          train=True,
                                          transform=custom_transform,
                                          target_transform=None,
                                          download=False)

# Pytorch提供DataLoader来完成对于数据集的加载,并且支持多进程并行读取。
# DataLoader使用示例

# 读取数据集
train_data = torchvision.datasets.CIFAR10('../../../dataset',
                                          train=True,
                                          transform=None,
                                          target_transform=None,
                                          download=True)
# 实现数据批量读取
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=2,
                                           shuffle=True,
                                           num_workers=4)
# batch_size设置了批量大小,shuffle设置为True在装载过程中为随机乱序,
# num_workers>=1表示多进程读取数据,在Win下num_workers只能设置为0,否则会报错

# 自定义数据集读取方法
# 首先,我们要确定是否包含标签文件,如果没有就要自己先创建标签文件
# 对pytorch读取数据一般化pipeline的描述
# 图像数据 ➡ 图像索引文件 ➡ 使用Dataset构建数据集 ➡ 使用DataLoader读取数据
# 图像数据是训练测试模型使用的图片;索引文件指的就是记录数据标注信息的文件,告诉程序哪个图片对应哪些标注信息

# 图像索引文件制作
'''图像索引文件只要能够合理记录标注信息即可,内容可以简单也可以复杂,
但有一条要注意:内容是待读取图像的名称(或路径)及标签,并且读取后能够方便实现索引。该文件可以是txt文件,csv文件等多种形式,
甚至是一个list都可以,只要是能够被Dataset类索引到即可
'''

# 数据文件读取
root = r'./MNIST'  # MNIST解压文件根目录
train_set = (
    mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))

)
test_set = (
    mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
)

# 数据量展示
print('train set:', train_set[0].size())
print('test set:', test_set[0].size())


def convert_to_img(save_path, train=True):
    '''
    将图片存储在本地,并制作所应文件
    :param save_path: 图像保存路径,将在路径下创建train、test文件夹分别存储训练集和测试机
    :param train: 默认True,本低存储训练集图像,否则本地存储测试集图像
    :return: 
    '''

    if train:
        f = open(save_path + 'train.txt', 'w')
        data_path = save_path + '/train/'
        if (not os.path.exists(data_path)):
            os.makedirs(data_path)
        for i, (img, label) in enumerate(zip(train_set[0], test_set[1])):
            img_path = data_path + str(i) + '.jpg'
            io.imsave(img_path, img.numpy())
            int_label = str(label).replace('tensor(', '')
            int_label = int_label.replace(')', '')
            f.write(str(i) + '.jpg' + ',' + str(int_label) + '\n')
        f.close()
    else:
        f = open(save_path + 'test.txt', 'w')
        data_path = save_path + '/test/'
        if (not os.path.exists(data_path)):
            os.makedirs(data_path)
            for i, (img, label) in enumerate(zip(test_set[0], test_set[1])):
                img_path = data_path + str(i) + '.jpg'
                io.imsave(img_path, img.numpy())
                int_label = str(label).replace('tensor(', '')
                int_label = int_label.replace(')', '')
                f.write(str(i) + '.jpg' + ',' + str(int_label) + '\n')
            f.close()


# 根据需求本地存储训练集或测试集
save_path = r'./MNIST/mnist_data/'
convert_to_img(save_path, True)
convert_to_img(save_path, False)

# 构建自己的Dataset
'''
想要读取我们自己数据集中的数据,就需要写一个Dataset的子类来定义我们的数据集,
并必须对 __init__、__getitem__ 和 __len__ 方法进行重载
'''


class MyDataset(Dataset):  # 继承Dataset
    def __init__(self):
        # 初始化图像文件路径或图像文件名列表等
        pass

    def __getitem__(self, index):
        # 1.根据索引index从文件中读取一个数据 (例如,使用numpy,fromfile,PIL.Image.open, cv2.imread)
        # 2.预处理数据(例如 torchvision.Transform)
        # 3.返回数据对(例如图像标签)
        pass

    def __len__(self):
        count = 0
        return count  # 返回数据量


'''
__init__() : 初始化模块,初始化该类的一些基本参数
__getitem__() : 该函数接收一个index,也就是索引值。只要是具有索引的数据类型都能够被读取,如list,Series,Dataframe等形式。
为了方便,我们一般采用list形式将文件代入函数中,该list中的每一个元素包含了图片的路径或标签等信息,以方便index用来逐一读取单一样本数据。
在__getitem__() 函数内部,我们可以选择性的对图像和标签进行预处理等操作,最后返回图像数据和标签。
__len__() : 返回所有数据的数量
'''


class MnistDataset(Dataset):
    def __init__(self, image_path, image_label, transform=None):
        super(MnistDataset, self).__init__()
        self.image_path = image_path  # 初始化图像路径列表
        self.image_label = image_label  # 初始化图像标签列表
        self.transform = transform  # 初始化数据增强方法

    def __getitem__(self, index):
        # 获取对应 index 的图像,并视情况进行数据增强
        image = Image.open(self.image_label[index])
        image = np.array(image)
        label = float(self.image_label[index])

        if self.transform is not None:
            image = self.transform(image)

        return image, torch.tensor(label)

    def __len__(self):
        return len(self.image_path)


def get_path_label(img_root, label_file_path):
    '''
    获取数字图像的路径和标签并返回对应列表
    :param img_root: 保存图像的根目录
    :param label_file_path: 保存图像标签数据的文件路径 .csv 或 .txt 分割为','
    :return: 图像的路径列表和对应标签列表
    '''
    data = pd.read_csv(label_file_path, names=['img', 'label'])
    data['img'] = data['img'].apply(lambda x: img_root + x)
    return data['img'].tolist(), data['label'].tolist()


# 获取训练集路径列表和标签列表
train_data_root = './dataset/MNIST/mnist_data/train/'
train_label = './dataset/MNIST/mnist_data/train.txt'
train_img_list, train_lable_list = get_path_label(train_data_root, train_label)
# 训练集dataset
train_dataset = MnistDataset(train_img_list,
                             train_lable_list,
                             transform=transforms.Compose([transforms.ToTensor()]))

# 获取测试机路径列表和标签列表
test_data_root = './dataset/MNIST/mnist_data/test/'
test_label = './dataset/MNIST/mnist_data/test.txt'
test_img_list, test_label_list = get_path_label(test_data_root, test_label)
# 测试集dataset
test_dataset = MnistDataset(test_img_list,
                            test_label_list,
                            transform=transforms.Compose([transforms.ToTensor()]))

# 使用DataLoader批量读取数据
'''DataLoader(dataset, 
           batch_size=1, 
           shuffle=False, 
           sampler=None, 
           num_workers=0, 
           collate_fn=default_collate, 
           pin_memory=False, 
           drop_last=False)
           '''
'''
dataset:加载的数据集(Dataset对象)
batch_size:一个批量数目大小
shuffle::是否打乱数据顺序
sampler: 样本抽样方式
num_workers:使用多进程加载的进程数,0代表不使用多进程
collate_fn: 将多个样本数据组成一个batch的方式,一般使用默认的拼接方式,可以通过自定义这个函数来完成一些特殊的读取逻辑。
pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些
drop_last:为True时,dataset中的数据个数不是batch_size整数倍时,将多出来不足一个batch的数据丢弃
'''

# 训练数据加载
train_loader = DataLoader(dataset=train_dataset,  # 加载的数据集(Dataset对象)
                          batch_size=3,  # 一个批量大小
                          shuffle=True,  # 是否打乱数据顺序
                          num_workers=4)  # 使用多进程的进程数,0代表不适用多进程(win系统建议改成0)
# 测试数据加载
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=3,
                         shuffle=False,
                         num_workers=4)
'''
经过DataLoader的封装,每3(一个batch_size数量)张图像数据及对应的标签被封装为一个二元元组,
第一个元素为四维的tensor形式,第二个元素为对应的图像标签数据。

DataLoader与Dataset分别处理后的数据比较可以发现出两者的不同:Dataset是对本地数据读取逻辑的定义;
而DataLoader是对Dataset对象的封装,执行调度,将一个batch size的图像数据组装在一起,实现批量读取数据。
'''

'''
图像分类问题,torchvision还提供了一种文件目录组织形式可供调用,即ImageFolder,因为利用了分类任务的特性,
此时就不用再另行创建一份标签文件了。这种文件目录组织形式,要求数据集已经自觉按照待分配的类别分成了不同的文件夹,
一种类别的文件夹下面只存放同一种类别的图片。
'''

# train & test root
train_root = r'./sample/train/'
test_root = r'./sample/test'

# transform
train_transform = transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
test_transform = transforms.Compose([transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# train dataset
train_dataset = torchvision.datasets.ImageFolder(root=train_root,
                                                 transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)

# test dataset
test_dataset = torchvision.datasets.ImageFolder(root=test_root,
                                                transform=test_transform, )
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

# 数据增强
'''
图像的增广是通过对训练图像进行一系列变换,产生相似但不同于主体图像的训练样本,来扩大数据集的规模的一种常用技巧。
另一方面,随机改变训练样本降低了模型对特定数据进行记忆的可能,有利于增强模型的泛化能⼒,提高模型的预测效果,
因此可以说数据增强已经不算是一种优化技巧,而是CNN训练中默认要使用的标准操作。
在常见的数据增广方法中,一般会从图像颜色、尺寸、形态、亮度/对比度、噪声和像素等角度进行变换。
当然不同的数据增广方法可以自由进行组合,得到更加丰富的数据增广方法。

在torchvision.transforms中,提供了Compose类来快速控制图像增广方式:我们只需将要采用的数据增广方式存放在一个list中,
并传入到Compose中,便可按照数据增广方式出现的先后顺序依次处理图像。如下面的样例所示:
'''

# 数据预处理
transform = transforms.Compose([transforms.CenterCrop(10),
                                transforms.ToTensor()])

# 中心裁剪
center_crop = transforms.CenterCrop([200, 200])
# 随机裁剪
random_crop = transforms.RandomCrop([200, 200])
# 随机长宽比裁剪
random_resized_crop = transforms.RandomResizedCrop(200,
                                                   scale=(0.08, 1.0),
                                                   ratio=(0.75, 1.55),
                                                   interpolation=2)

# 依概率P水平翻转
h_flip = transforms.RandomHorizontalFlip(0.7)
# 依概率p垂直翻转
v_flip = transforms.RandomVerticalFlip(0.8)
# 随机翻转
random_rotation = transforms.RandomRotation(30)

# 图像填充
pad = transforms.Pad(10, fill=0, padding_mode='constant')
# 调整亮度、对比度、饱和度、色调
color_jitter = transforms.ColorJitter(brightness=1,
                                      contrast=0.5,
                                      saturation=0.5,
                                      hue=0.4)
# 转成灰度图
gray = transforms.Grayscale(1)
# 仿射变换
random_affine = transforms.RandomAffine(45, (0.5, 0.7), (0.8, 0.5), 3)
# 尺寸缩放
resize = transforms.Resize([100, 200])
# 转Tensor、标准化和转化为PILImage
mean = [0.45, 0.5, 0.5]
std = [0.3, 0.6, 0.5]
transforms = transforms.Compose([transforms.ToTensor(),              # 转Tensor
                                 transforms.Normalize(mean, std),
                                 transforms.ToPILImage()])           # 这里是为了可视化,故将其再转为 ptl
img_transform = transform



  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值