【深度之眼】【Pytorch打卡第3天】:DataLoader、DataSet、Transforms+划分数据集代码、构建Dataset、读取数据

本文介绍了Pytorch中用于数据处理的关键组件,包括DataLoader的作用,Dataset的创建与使用,以及Transforms的各种图像预处理操作。DataLoader用于构建可迭代的数据加载器,设置批大小、多进程读取和数据乱序等选项。Dataset是数据源的抽象类,需要自定义以适应具体数据。Transforms提供了如数据中心化、标准化、缩放等预处理方法,加速模型训练。此外,还讲解了如何划分数据集、构建自定义Dataset以及如何读取数据。
摘要由CSDN通过智能技术生成

概括

在这里插入图片描述


DataLoader与DataSet

torch.utils.data.DataLoader:构建可迭代的数据装载器
  • dataset: Dataset类,决定数据从哪读取 及如何读取
  • batchsize : 批大小
  • num_works: 是否多进程读取数据
  • shuffle: 每个epoch是否乱序
  • drop_last:当样本数不能被batchsize整 除时,是否舍弃最后一批数据
torch.utils.data.Dataset:Dataset抽象类,所有自定义的 Dataset需要继承它,并且复写
  • getitem()
    getitem : 接收一个索引,返回一个样本

在这里插入图片描述

在这里插入图片描述


Transforms

  • torchvision.transforms : 常用的图像预处理方法
  • torchvision.datasets : 常用数据集的dataset实现,MNIST,CIFAR-10,ImageNet等
  • torchvision.model : 常用的模型预训练,AlexNet,VGG, ResNet,GoogLeNet等
transforms

torchvision.transforms : 常用的图像预处理方法
• 数据中心化
• 数据标准化
• 缩放
• 裁剪
• 旋转
• 翻转
• 填充
• 噪声添加
• 灰度变换
• 线性变换
• 仿射变换
• 亮度、饱和度及对比度变换

transforms.Normalize:加速运算
  • 功能:逐channel的对图像进行标准化 output = (input - mean) / std
    • mean:各通道的均值
    • std:各通道的标准差
    • inplace:是否原地操作
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

在这里插入图片描述

划分数据集

# -*- coding: utf-8 -*-
"""
# @file name  : 1_split_dataset.py
# @author     : xinwenhu
# @date       : 2019-09-07 10:08:00
# @brief      : 将数据集划分为训练集,验证集,测试集
"""

import os
import random
import shutil


def makedir(new_dir):
    if not os.path.exists(new_dir):
        os.makedirs(
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值