Pytorch学习笔记(深度之眼)(2)之DataLoader and Dataset

21 篇文章 2 订阅
17 篇文章 1 订阅

1.DataLoader and Dataset

在这里插入图片描述数据模块又可以细分为 4 个部分:

数据收集:样本和标签。
数据划分:训练集、验证集和测试集
数据读取:对应于PyTorch 的 DataLoader。其中 DataLoader 包括 Sampler 和 DataSet。Sampler 的功能是生成索引, DataSet 是根据生成的索引读取样本以及标签。
数据预处理:对应于 PyTorch 的 transforms
在这里插入图片描述在这里插入图片描述在这里插入图片描述
功能:Dataset 是抽象类,所有自定义的 Dataset 都需要继承该类,并且重写__getitem()方法和__len()方法 。__getitem()方法的作用是接收一个索引,返回索引对应的样本和标签,这是我们自己需要实现的逻辑。len()方法是返回所有样本的数量。
在这里插入图片描述在这里插入图片描述
首先在 for 循环中遍历DataLoader,然后根据是否采用多进程,决定使用单进程或者多进程的DataLoaderIter。在DataLoaderIter里调用Sampler生成Index的 list,再调用DatasetFetcher根据index获取数据。在DatasetFetcher里会调用Dataset的__getitem
()方法获取真正的数据。这里获取的数据是一个 list,其中每个元素是 (img, label) 的元组,再使用 collate_fn()函数整理成一个 list,里面包含两个元素,分别是 img 和 label 的tenser。

2.transforms

在这里插入图片描述在这里插入图片描述在这里插入图片描述

# 设置训练集的数据增强和转化
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),# 缩放
    transforms.RandomCrop(32, padding=4), #裁剪
    transforms.ToTensor(), # 转为张量,同时归一化
    transforms.Normalize(norm_mean, norm_std),# 标准化
])

 设置验证集的数据增强和转化,不需要 RandomCrop
valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

当我们需要多个transforms操作时,需要作为一个list放在transforms.Compose中。需要注意的是transforms.ToTensor()是把图片转换为张量,同时进行归一化操作,把每个通道 0~255 的值归一化为 0~1。在验证集的数据增强中,不再需要transforms.RandomCrop()操作。然后把这两个transform操作作为参数传给Dataset,在Dataset的__getitem__()方法中做图像增强。
在这里插入图片描述对数据进行均值为 0,标准差为 1 的标准化,可以加快模型的收敛。

3.数据增强 在这里插入图片描述在这里插入图片描述在这里插入图片描述在这里插入图片描述在这里插入图片描述在这里插入图片描述在这里插入图片描述 在这里插入图片描述在这里插入图片描述在这里插入图片描述在这里插入图片描述  在这里插入图片描述在这里插入图片描述

4. code

数据集划分

# -*- coding: utf-8 -*-
"""
# @file name  : 1_split_dataset.py
# @author     : tingsongyu
# @date       : 2019-09-07 10:08:00
# @brief      : 将数据集划分为训练集,验证集,测试集
"""
import os
import random
import shutil


def makedir(new_dir):
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)             # 用来创建多层目录(单层请用os.mkdir)


if __name__ == '__main__':

    random.seed(1)

    dataset_dir = os.path.join("G:\\", "hello", "data", "Cat_dog_data")   # 路径拼接
    split_dir = os.path.join("G:\\", "hello", "data", "cat_dog_split")
    train_dir = os.path.join(split_dir, "train")
    valid_dir = os.path.join(split_dir, "valid")
    test_dir = os.path.join(split_dir, "test")

    train_pct = 0.8
    valid_pct = 0.1
    test_pct = 0.1

    for root, dirs, files in os.walk(dataset_dir):
        for sub_dir in dirs:

            imgs = os.listdir(os.path.join(root, sub_dir))
            imgs = list(filter(lambda x: x.endswith('.jpg'), imgs))
            random.shuffle(imgs)
            img_count = len(imgs)

            train_point = int(img_count * train_pct)
            valid_point = int(img_count * (train_pct + valid_pct))

            for i in range(img_count):
                if i < train_point:
                    out_dir = os.path.join(train_dir, sub_dir)
                elif i < valid_point:
                    out_dir = os.path.join(valid_dir, sub_dir)
                else:
                    out_dir = os.path.join(test_dir, sub_dir)

                makedir(out_dir)

                target_path = os.path.join(out_dir, imgs[i])
                src_path = os.path.join(dataset_dir, sub_dir, imgs[i])

                shutil.copy(src_path, target_path)

            print('Class:{}, train:{}, valid:{}, test:{}'.format(sub_dir, train_point, valid_point - train_point,
                                                                 img_count - valid_point))

dataset

# -*- coding: utf-8 -*-
"""
# @file name  : dataset.py
# @author     : yts3221@126.com
# @date       : 2019-08-21 10:08:00
# @brief      : 各数据集的Dataset定义
"""

import os
import random
from PIL import Image
from torch.utils.data import Dataset

random.seed(1)
rmb_label = {"cat": 0, "dog": 1}


class CATDataset(Dataset):
    def __init__(self, data_dir, transform=None):         # 初始化
        """
        rmb面额分类任务的Dataset
        :param data_dir: str, 数据集所在路径
        :param transform: torch.transform,数据预处理
        """
        self.label_name = {"cat": 0, "dog": 1}
        # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
        self.data_info = self.get_img_info(data_dir)
        self.transform = transform

    # 根据索引index返回图像及标签,即获取图像
    def __getitem__(self, index):
        # 通过self.data_info函数得到图像路径和标签
        path_img, label = self.data_info[index]
        # 通过Image.open得到img
        img = Image.open(path_img).convert('RGB')     # 0~255

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label

    # 查看数据长度,即样本的数量,数据集的数量
    def __len__(self):
        return len(self.data_info)

    # 自定义的函数
    # 用于获取路径和标签
    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))

                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = rmb_label[sub_dir]
                    data_info.append((path_img, int(label)))

        return data_info


  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值