PyTorch之Dataset和TensorDataset

Deep Learning系列 @cxx

Dataset v.s. TensorDataset

使用PyTorch搭建过Neural Network的小伙伴们都知道,在数据准备步骤里,我们需要把训练集的x和y分装在dataset里,然后将dataset分装到DataLoader中去,便于之后在搭建好的模型中训练。
简言之,dataset是用来做打包和预处理(比如输入资料路径自动读取);DataLoader则是将整个资料集(dataset)按照batch进行迭代分装或者shuffle(可以得到一个iterator以利于for循环读取)。

Dataset

如果使用继承Dataset的方式,那么在自定义的dataset类中必须给予__len__和__getitem__的定义。
进行图片处理的时候,可以定义一个transforms来随机旋转训练图片,将图片格式变成tensor等
(这里有一个坑)

假设我们读取了一个有如下格式的图片
在这里插入图片描述
将图片分装到dataset里,再放到dataloader里

from torch.utils.data import TensorDataset
batch_size = 128
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),]
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值