pyTorch初识(3)—— DataLoader

Dataset告诉程序数据集在什么位置,DataLoader是一个加载器,从Dataset中取数据加载到神经网络中,取多少怎么取都由DataLoader中的参数控制,官网中的定义如下:

一、每次使用需要引用包

二、参数说明

CLASS torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False, pin_memory_device='')

其中大多数参数都有默认值,只有 dataset 没有,这个 dataset 需要自定义,附上链接:(38条消息) pytorch 基础认识——Dataset_堂小仙儿的博客-CSDN博客,主要告诉我们数据在什么位置,包括第一第二数据集有多少数据,具体可以看官网里 Dataset 有什么方法和类。

三、常见参数设置

batch_size(int,optional):例如有一摞牌,batch_size=1,每次抓牌抓一张,等于2,每次抓就抓两张;

shuffle(bool,optional):就是洗牌,每一局牌结束就要洗牌,如果和第一次牌的顺序一致,就返回 False,否则返回 True;

sampler (Sampler or Iterable, optional):随机抓取,官网中的定义是这样的:

sampler: Union[Sampler, Iterable, None] = None

一般情况下默认随即抓取

batch_sampler(Sampler or Iterable, optional): batch_sampler是一个批数据采样器,通过某种限制得到输入数据集的子集,即批数据;

num_workers(int, optional):多进程,默认为0,表示只用主进程(注:有时候不为零会出现报错,可以该成零看看能不能解决);

drop_last(bool, optional):50张牌每次抓三张牌,抓到最后还剩一张牌,这张牌是否舍弃,True 表示舍弃,False 表示保留;

举例:

import torchvision
from torch.utils.data import DataLoader

# 准备的测试集
test_data = torchvision.datasets.CIFAR10(root='F:\\pycharm\\text\\dataset',train=False,transform=torchvision.transforms.ToTensor())
# 最后一句 transform 是自定义,表示对数据做处理,都变成tensor类型的
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0,drop_last=False)
# 要处理的数据集是 test_data,每次处理4张图片,处理结束,顺序和处理之前一致,单线程处理,多出来的图片保留

img, target = test_data[0]
print(img.shape)  # 输出:torch.Size([3, 32, 32])  3通道,大小为32*32
print(target)   # 输出:3  在数据集内标签为3

for data in test_loader:
    imgs, targets = data
    print(imgs.shape)
    print(targets)
# for循环说的是,从 test_loader中取出数据,test_loader 实例化 DataLoader 这个类,引用了 dataset ,这个类里面需要重置 __getitem__方法,给的 test_data 就是重置的目标,返回值包括图片和标签,然后图片和标签被 dataloader 按照 batchsize 的大小分开,具体如下所示:

h

整个运行结果如下所示(只拿部分举例):

torch.Size([3, 32, 32])
3 #一二两句是单张图片的数据,3通道的32*32大小的图片,在数据集内的标签是3
# 当batchsize是4的时候,就从数据据里一次拿四张,都是3通道32*32的图片,每张图片的标签被打包了,如第五句所示,这里的采样时随机的,因为 sample 时
torch.Size([4, 3, 32, 32])
tensor([5, 0, 1, 4])
torch.Size([4, 3, 32, 32])
tensor([4, 9, 3, 6])

附上一篇博客:(38条消息) 系统学习Pytorch笔记三:Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)_pytorch dataloader读取数据_翻滚的小@强的博客-CSDN博客

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值