PyTorch深度学习入门笔记(七)DataLoader的使用

课程学习笔记,课程链接
学习笔记同步发布在我的个人网站上,欢迎来访查看。

一、dataloader简介

dataset在程序中起到的作用是告诉程序数据在哪,每个索引所对应的数据是什么。相当于一系列的存储单元,每个单元都存储了数据。这里可以类比成一幅扑克牌,一张扑克牌就是一个数据,一幅扑克牌就是一个完整的数据集。

再把神经网络的输入获取类比成手,用手去抓扑克牌,每次抓几张,用一只手去抓取,还是用两只手,这就是 dataloader 要做的事,可以通过参数进行一个设置。
在这里插入图片描述

Pytoch 官网也对 dataloader 进行了一个介绍:
在这里插入图片描述
各个参数都有详细的描述,这里就不再赘述。

二、dataloader的使用

2.1 简单测试

测试代码:

import torchvision
from torch.utils.data import DataLoader

test_data = torchvision.datasets.CIFAR10("./dataset", False, torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)

img, target = test_data[0]
print(img.shape)
print(target)

# return of dataloader
for data in test_loader:
    imgs, targets = data
    print(imgs.shape)
    print(targets)

输出:
在这里插入图片描述
可以看到,单个数据读取时,输出是

torch.Size([3, 32, 32])
3

即图片为RGB三通道,像素大小为32*32,tag为3
采用 dataloader(batch_size=4)读取时:

torch.Size([4, 3, 32, 32])
tensor([1, 7, 9, 2])
torch.Size([4, 3, 32, 32])
tensor([2, 7, 4, 7])

即4张图片,每个图片都为RGB三通道,像素大小为32*32
然后tag也打包在一起了,返回为 tensor([1, 7, 9, 2])形式。
注:Dataloader默认采用的是从数据集中进行随机抓取。

2.2 通过tensorboard显示抓取结果

示例代码:

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

test_data = torchvision.datasets.CIFAR10("./dataset", False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

img, target = test_data[0]
print(img.shape)
print(target)

# return of dataloader
writer = SummaryWriter("dataloader")
step = 0
for data in test_loader:
    imgs, targets = data
    # print(imgs.shape)
    # print(targets)
    writer.add_images("test_data_drop_last", imgs, step)
    step = step + 1
writer.close()

结果:
在这里插入图片描述
这里每次抓取64个数据,用 add_images 函数写入到 SummaryWriter实例化对象中,再进行显示:
这里当 DataLoader 的输入 drop_last设置为True时,最后一次抓取的数据若不满64,则会被丢弃。为Flase时则不会,如上图的上半部分所示,最后一次抓取了16个数据,不满64,没有丢弃。

2.3 shuffle

示例代码:

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

test_data = torchvision.datasets.CIFAR10("./dataset", False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False, num_workers=0, drop_last=True)

img, target = test_data[0]
print(img.shape)
print(target)

# return of dataloader
writer = SummaryWriter("dataloader")
step = 0
for epoch in range(2):
    for data in test_loader:
        imgs, targets = data
        # print(imgs.shape)
        # print(targets)
        writer.add_images("Epoch:{}".format(epoch), imgs, step)
        step = step + 1
writer.close()
  • shuffle为 False时,两次抓取的顺序不会进行打乱,即两次抓取的结果一样
    在这里插入图片描述

  • shuffle为 True时,两次抓取的顺序会进行打乱,即两次抓取的结果不一样
    在这里插入图片描述

dataloader 返回的 imgs 可以作为神经网络的输入,那下一篇博客将介绍如何搭建神经网络。

  • 6
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

雪天鱼

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值