pytorch学习四:pytorch数据读取机制D

学习pytorch数据读取机制中两个重要模块dataloader与Dataset:通过一个人民币分类实验来学习pytorch是如何从硬盘中读取数据的,并深入学习数据读取中涉及的两个模块DataSet与Dataloader;

熟悉数据预处理处理transforms方法的运行机制:数据在读取到pytorch之后通常都需要对数据进行预处理,包括尺寸缩放、转换张量、数据中心化或标准化等等,这些操作都是通过transforms进行的。

数据 :

数据收集 :img、label

数据划分:train、valid、test

数据读取:dataloader,dataloader又分为两个子模块,分别是Sampler(生成索引)和DataSet(根据索引读取图片和标签)

数据预处理:transforms

dataloader:构建可迭代的数据装载器

torch.utils.data.Dataloader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_works=0,collate_fn=None,pin_memor=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None)

下面介绍常用的几个参数:

  • dataset:Dataset类,决定数据从哪读取即如何读取
  • batchsize:批大小
  • num_works:是否多进程读取数据
  • shuffle:每个epoch是否乱序
  • drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据

Epoch:指所有训练样本都已输入到模型中,称为一个epoch

Iteration:一批样本输入到模型中,称之为一个Iteration

Batchsize:批大小,决定一个epoch有多少个iteration

例:假如样本总数为80,Batchsize为8,则1个epoch == 10 iteration;若样本总数为87,Batchsize为8,则当drop_last=True时,1个epoch为10个iteration;当drop_last=False时,1个epoch为11个iteration。

torch.utils.data.Dataset()

Dataset:定义数据从哪里读取,以及如何读取的工具,pytorch中的dataset为抽象类,所有自定义的dataset需要继承它,并且复写 __getitem__(),getitem:接收一个索引,返回一个样本。

class Dataset(object):

    def __getitem(self,index):
        raise NotImplementedError

    def __add__(self,other):
        return ConcatDataset([self,other])

数据读取及划分

import os
import random
import shutil


def makedir(new_dir):
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)


if __name__ == '__main__':

    random.seed(1)

    dataset_dir = os.path.join("..", "..", "data", "RMB_data")
    split_dir = os.path.join("..", "..", "data", "rmb_split")
    train_dir = os.path.join(split_dir, "train")
    valid_dir = os.path.join(split_dir, "valid")
    test_dir = os.path.join(split_dir, "test")

    train_pct = 0.8
    valid_pct = 0.1
    test_pct = 0.1

    for root, dirs, files in os.walk(dataset_dir):
        for sub_dir in dirs:

            imgs = os.listdir(os.path.join(root, sub_dir))
            imgs = list(filter(lambda x: x.endswith('.jpg'), imgs))
            random.shuffle(imgs)
            img_count = len(imgs)

            train_point = int(img_count * train_pct)
            valid_point = int(img_count * (train_pct + valid_pct))

            for i in range(img_count):
                if i < train_point:
                    out_dir = os.path.join(train_dir, sub_dir)
                elif i < valid_point:
                    out_dir = os.path.join(valid_dir, sub_dir)
                else:
                    out_dir = os.path.join(test_dir, sub_dir)

                makedir(out_dir)

                target_path = os.path.join(out_dir, imgs[i])
                src_path = os.path.join(dataset_dir, sub_dir, imgs[i])

                shutil.copy(src_path, target_path)

            print('Class:{}, train:{}, valid:{}, test:{}'.format(sub_dir, train_point, valid_point-train_point,
                                                                 img_count-valid_point))

2 transform

torchvision:计算机视觉工具包,在pytorch中三个主要的模块:

  • torchvision.transforms :常用的图像预处理模块(图片标准化、翻转、缩放、裁剪等)
  • torchvision.datasets:常用数据集的dataset实现,MNIST、CIFAR-10、ImageNet等
  • torchvision.model:常用的模型预训练,AlexNet、VGG、ResNet、GoogleNet等

torchvision.transforms常用的图像预处理方法:数据中心化、数据标准化、缩放、裁剪、旋转、翻转、填充、噪声添加、灰度变换、线性变换、放射变换、亮度、饱和度及对比变换等。

transform.Normalize(mean,std,inplace=False)

  • 功能:逐channel的对图像进行标准化 output = (input-mean)/std
  • mean:各通道的均值
  • std:各通道的标准差
  • inplace:是否原地操作

对数据进行标准化可加快模型的收敛

 

 

 

 

 

 

 

 

 

 

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值