学习pytorch数据读取机制中两个重要模块dataloader与Dataset:通过一个人民币分类实验来学习pytorch是如何从硬盘中读取数据的,并深入学习数据读取中涉及的两个模块DataSet与Dataloader;
熟悉数据预处理处理transforms方法的运行机制:数据在读取到pytorch之后通常都需要对数据进行预处理,包括尺寸缩放、转换张量、数据中心化或标准化等等,这些操作都是通过transforms进行的。
数据 :
数据收集 :img、label
数据划分:train、valid、test
数据读取:dataloader,dataloader又分为两个子模块,分别是Sampler(生成索引)和DataSet(根据索引读取图片和标签)
数据预处理:transforms
dataloader:构建可迭代的数据装载器
torch.utils.data.Dataloader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_works=0,collate_fn=None,pin_memor=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None)
下面介绍常用的几个参数:
- dataset:Dataset类,决定数据从哪读取即如何读取
- batchsize:批大小
- num_works:是否多进程读取数据
- shuffle:每个epoch是否乱序
- drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据
Epoch:指所有训练样本都已输入到模型中,称为一个epoch
Iteration:一批样本输入到模型中,称之为一个Iteration
Batchsize:批大小,决定一个epoch有多少个iteration
例:假如样本总数为80,Batchsize为8,则1个epoch == 10 iteration;若样本总数为87,Batchsize为8,则当drop_last=True时,1个epoch为10个iteration;当drop_last=False时,1个epoch为11个iteration。
torch.utils.data.Dataset()
Dataset:定义数据从哪里读取,以及如何读取的工具,pytorch中的dataset为抽象类,所有自定义的dataset需要继承它,并且复写 __getitem__(),getitem:接收一个索引,返回一个样本。
class Dataset(object):
def __getitem(self,index):
raise NotImplementedError
def __add__(self,other):
return ConcatDataset([self,other])
数据读取及划分
import os
import random
import shutil
def makedir(new_dir):
if not os.path.exists(new_dir):
os.makedirs(new_dir)
if __name__ == '__main__':
random.seed(1)
dataset_dir = os.path.join("..", "..", "data", "RMB_data")
split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")
test_dir = os.path.join(split_dir, "test")
train_pct = 0.8
valid_pct = 0.1
test_pct = 0.1
for root, dirs, files in os.walk(dataset_dir):
for sub_dir in dirs:
imgs = os.listdir(os.path.join(root, sub_dir))
imgs = list(filter(lambda x: x.endswith('.jpg'), imgs))
random.shuffle(imgs)
img_count = len(imgs)
train_point = int(img_count * train_pct)
valid_point = int(img_count * (train_pct + valid_pct))
for i in range(img_count):
if i < train_point:
out_dir = os.path.join(train_dir, sub_dir)
elif i < valid_point:
out_dir = os.path.join(valid_dir, sub_dir)
else:
out_dir = os.path.join(test_dir, sub_dir)
makedir(out_dir)
target_path = os.path.join(out_dir, imgs[i])
src_path = os.path.join(dataset_dir, sub_dir, imgs[i])
shutil.copy(src_path, target_path)
print('Class:{}, train:{}, valid:{}, test:{}'.format(sub_dir, train_point, valid_point-train_point,
img_count-valid_point))
2 transform
torchvision:计算机视觉工具包,在pytorch中三个主要的模块:
- torchvision.transforms :常用的图像预处理模块(图片标准化、翻转、缩放、裁剪等)
- torchvision.datasets:常用数据集的dataset实现,MNIST、CIFAR-10、ImageNet等
- torchvision.model:常用的模型预训练,AlexNet、VGG、ResNet、GoogleNet等
torchvision.transforms常用的图像预处理方法:数据中心化、数据标准化、缩放、裁剪、旋转、翻转、填充、噪声添加、灰度变换、线性变换、放射变换、亮度、饱和度及对比变换等。
transform.Normalize(mean,std,inplace=False)
- 功能:逐channel的对图像进行标准化 output = (input-mean)/std
- mean:各通道的均值
- std:各通道的标准差
- inplace:是否原地操作
对数据进行标准化可加快模型的收敛