概括
DataLoader与DataSet
torch.utils.data.DataLoader:构建可迭代的数据装载器
- dataset: Dataset类,决定数据从哪读取 及如何读取
- batchsize : 批大小
- num_works: 是否多进程读取数据
- shuffle: 每个epoch是否乱序
- drop_last:当样本数不能被batchsize整 除时,是否舍弃最后一批数据
torch.utils.data.Dataset:Dataset抽象类,所有自定义的 Dataset需要继承它,并且复写
- getitem()
getitem : 接收一个索引,返回一个样本
Transforms
- torchvision.transforms : 常用的图像预处理方法
- torchvision.datasets : 常用数据集的dataset实现,MNIST,CIFAR-10,ImageNet等
- torchvision.model : 常用的模型预训练,AlexNet,VGG, ResNet,GoogLeNet等
transforms
torchvision.transforms : 常用的图像预处理方法
• 数据中心化
• 数据标准化
• 缩放
• 裁剪
• 旋转
• 翻转
• 填充
• 噪声添加
• 灰度变换
• 线性变换
• 仿射变换
• 亮度、饱和度及对比度变换
transforms.Normalize:加速运算
- 功能:逐channel的对图像进行标准化 output = (input - mean) / std
• mean:各通道的均值
• std:各通道的标准差
• inplace:是否原地操作
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
valid_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
划分数据集
# -*- coding: utf-8 -*-
"""
# @file name : 1_split_dataset.py
# @author : xinwenhu
# @date : 2019-09-07 10:08:00
# @brief : 将数据集划分为训练集,验证集,测试集
"""
import os
import random
import shutil
def makedir(new_dir):
if not os.path.exists(new_dir):
os.makedirs(