04.DataLoader与Dataset;transforms与normalize


本课程来自深度之眼deepshare.net,部分截图来自课程视频。

人民币二分类

在这里插入图片描述
人民币看做自变量x,类别看做因变量y。回顾上节课的机器学习模型:
在这里插入图片描述
这节课主要是数据这个模块。数据模块分为四个部分:
数据收集:Img、Label
数据划分:train、valid、test
数据读取DataLoader:分为两个子模块:Sampler和DataSet
Sampler的主要功能是对数据生成索引:Index
DatsSet是根据索引读取数据:Img、Label
数据预处理:transforms

DataLoader与Dataset

DataLoader

torch.utils.data.DataLoader
功能:构建可迭代的数据装载器
·dataset:Dataset类,决定数据从哪读取及如何读取
·batchsize:批大小
·num_works:是否多进程读取数据:4.8.16
·shuffle:每个epoch是否乱序
·drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据
collate_fn:要对数据进行的统一处理,一般是一个自定义的函数

DataLoader(dataset, batch_size=1, 
shuffle=False, sampler=None, batch_sampler=None, 
num_workers=0, collate_fn=None, pin_memory=False,
 drop_last=False, timeout=0, worker_init_fn=None, 
 multiprocessing_context=None)

Epoch:所有训练样本都已输入到模型中,称为一个Epoch
Iteration:批样本输入到模型中,称之为一个lteration
Batchsize:批大小,决定一个Epoch有多少个lteration
例如:
样本总数:80,Batchsize:8
1Epoch=10 lteration
再例如:
样本总数:87,Batchsize:8
1 Epoch=10 lteration?drop_last=True
1 Epoch=11 lteration?drop last=False

Dataset

torch.utils.data.Dataset
功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写/实现
getitem()
getitem:
接收一个索引,返回一个样本

class Dataset(object): 
	def __getitem__(self,index): 
		raise NotImplementedError 
	def add(self,other): 
		return ConcatDataset([self,other])

PyTorch数据读取机制

1.读哪些数据?Sampler输出的Index
2.从哪读数据?Dataset中的data_dir

split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),#改大小
    transforms.RandomCrop(32, padding=4),#随机裁剪
    transforms.ToTensor(),#转为张量
    transforms.Normalize(norm_mean, norm_std),#归一化
])

3.怎么读数据?Dataset中的getitem
在这里插入图片描述

transforms 运行机制

transforms

torchvision是计算机视觉工具包(安装的时候有单独装),有三个主要模块
torchvision.transforms:常用的图像预处理方法(这节课的主要学习对象)
torchvision.datasets:常用数据集的dataset实现,MNIST,CIFAR-10,ImageNet等
torchvision.model:常用的模型预训练,AlexNet,VGG,ResNet,GoogLeNet等

torchvision.transforms:常用的图像预处理方法

·数据中心化
·数据标准化
缩放
裁剪
旋转
翻转
填充
·噪声添加
·灰度变换
·线性变换
仿射变换
·亮度、饱和度及对比度变换
在这里插入图片描述
上面就是一张图片经过数据增强之后得到的一些结果。提高模型的泛化能力
继续看上节内容中的人民币二分类的例子来看。就用到了resize、randomcorp、totensor、normalize处理。该方法在getitem中调用。
在这里插入图片描述
注意:在验证集中不需要做数据增强。
相当于练习的时候会使用各种花招,增加难度,提高训练效果,对付真正敌人的时候就直接上。

数据标准化—transforms.normalize

transforms.Normalize
功能:逐channel的对图像进行标准化,公式为:output=(input-mean)/std
·mean:各通道的均值
·std:各通道的标准差
·inplace:是否原地操作

transforms. Normalize(mean, std, inplace=False)

对图像进行标准化可以加快数据处理的速度。例子:
逻辑回归例子中,如果把原来在零点均匀分布的数据移动(改变bias)后,训练的效果的速度明显变差。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

oldmao_2000

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值