【Datawhale组队学习记录】through-pytorch-task2

Datawhale组队学习学习记录task2

本文仅记录在学习过程中遇到的部分问题,Datawhale教程:DATAWHALE - 一个热爱学习的社区 (linklearner.com)

3.3数据读入

PyTorch数据读入通过Dataset+DataLoader的方式完成。

Dataset定义好数据的格式和数据变换形式,可以定义自己的Dataset类来灵活读取数据,定义的类需要继承PyTorch自身的Dataset类。

DataLoader用iterative的方式不断读入批次数据。

接下来是Dataset的两种构建方法:

1.引用官方数据集

这里的datasets是torcvision.datasets,torchvision是基于pytorch的工具集,用于处理图像视频,类似的还有NLP的torchtext,音频的torchaudio。torchvision包含一些常用数据集、模型、转换函数等,torchvision.datasets用于调用官方数据集,ImageFolder类用于读取按一定结构存储的图片数据,要求结构类似下图这样:

 

 参数中的path应为./root,例子中即为./data/train,transform为预处理函数。

2.自定义Datasets类

需要根据自己的数据来定义,类中需要包含__init__,__getitem__,__len__三个函数。

构建好Dataset后,再使用DataLoader读入数据:

from torch.utils.data import DataLoader

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=4, shuffle=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, num_workers=4, shuffle=False)

参数中:1.batch_size即一次训练所抓取的数据样本数量;2.num_workers多个进程读取数据,Windows用户设为0;3.shuffle是否将数据打乱;4.drop_last使样本最后未达到批次数的部分不再参与训练。

4基础实战--FashionMNIST

按照教程来即可,需要注意的是无GPU时运行要去掉.cuda()的几句代码,以及数据读入时,如果是直接调用官方数据需要把预处理中的transforms.ToPILImage()这行去掉。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值