Datawhale组队学习学习记录task2
本文仅记录在学习过程中遇到的部分问题,Datawhale教程:DATAWHALE - 一个热爱学习的社区 (linklearner.com)。
3.3数据读入
PyTorch数据读入通过Dataset+DataLoader的方式完成。
Dataset定义好数据的格式和数据变换形式,可以定义自己的Dataset类来灵活读取数据,定义的类需要继承PyTorch自身的Dataset类。
DataLoader用iterative的方式不断读入批次数据。
接下来是Dataset的两种构建方法:
1.引用官方数据集
这里的datasets是torcvision.datasets,torchvision是基于pytorch的工具集,用于处理图像视频,类似的还有NLP的torchtext,音频的torchaudio。torchvision包含一些常用数据集、模型、转换函数等,torchvision.datasets用于调用官方数据集,ImageFolder类用于读取按一定结构存储的图片数据,要求结构类似下图这样:
参数中的path应为./root,例子中即为./data/train,transform为预处理函数。
2.自定义Datasets类
需要根据自己的数据来定义,类中需要包含__init__,__getitem__,__len__三个函数。
构建好Dataset后,再使用DataLoader读入数据:
from torch.utils.data import DataLoader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=4, shuffle=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, num_workers=4, shuffle=False)
参数中:1.batch_size即一次训练所抓取的数据样本数量;2.num_workers多个进程读取数据,Windows用户设为0;3.shuffle是否将数据打乱;4.drop_last使样本最后未达到批次数的部分不再参与训练。
4基础实战--FashionMNIST
按照教程来即可,需要注意的是无GPU时运行要去掉.cuda()的几句代码,以及数据读入时,如果是直接调用官方数据需要把预处理中的transforms.ToPILImage()这行去掉。