Day7 DataLoader + Tensorboard

参考了两篇总结:点击跳转1

点击跳转2


前言——回顾Pytorch的数据读取机制

之前学习过Dataset读取数据,其实,DataLoader和DataSet就是数据读取子模块中的核心机制。引用一张图来说明:
在这里插入图片描述

torch.utils.data.Dataset(): Dataset抽象类,
所有自定义的Dataset都需要继承它,并且必须重写getitem、init、len这个类的方法。

getitem方法是Dataset的核心,作用是接收一个索引(将路径对应数据通过os.dirlist合并为一个列表并返回), 返回一个样本,参数里面接收index,然后编写如何通过这个索引去读取数据部分。接下来正文部分介绍DataLoader。


一、DataLaoder介绍

torch.utils.data.DataLoader():构建可迭代的数据装载器,,我们在训练的时候,每一次for循环,每一次iteration,就是从DataLoader中获取一个batch_size大小的数据的。

DataLoader的参数很多,但我们常用的主要有5个:

  • dataset:Dataset类, 决定数据从哪读取以及如何读取
  • bathsize:批大小
  • num_works:是否多进程读取机制
  • shuffle:每个epoch是否乱序
  • drop_last:当样本数不能被batchsize整除时, 是否舍弃最后一批数据

二、具体过程

1.准备数据

代码如下:

#导入包
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

#下载数据集,以CIFAR10为例
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())

#加载数据集
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=True)

注意:

  1. 加载数据时的参数dataset指向下载的数据集
  2. 每次循环取batch_size=4张数据
  3. shuffle为true表示随机打乱
  4. drop_last表示最后一组数据不足4张时候舍弃。

下载数据集完成:

在这里插入图片描述

2.在控制台查看加载好的数据

代码如下:

for data in test_loader:
    imgs,targets=data
    print(imgs.shape)
    print(targets)

也可以直接查看加载好的第一张图片信息:

# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape)  # torch.Size([3, 32, 32]) 单张图片的尺寸和通道数
print(target)  # 输出为 3

3.通过tensorboard加载多组图片信息

writer = SummaryWriter("dataloader")
# 测试数据集上所有的图片 imgs 是复数
for epoch in range(2):  # 进行两轮,上面的 shuffle,是对这个位置有影响,而不是 for data 那个循环有影响
    step = 0
    for data in test_loader:  # 这个loader,返回的内容,就已经是包含了 img 和 target 两个值了,这个在 cifar 数据集的 getitem 函数里,写了
        imgs, targets = data
        # print(imgs.shape)   # torch.Size([4, 3, 32, 32]) 这个输出的结果,其中的 4 ,是 batch_size 设定的值, 后面的 3, 32, 32 是单张图片的尺寸和通道数
        #  print(targets)  # tensor([2, 8, 0, 2])  这4个数字,是对 target 的打包,是随机的,数值代表所在的分类;debug一下,可以看到 sampler中的取值方式,是 RandomSampler
        #  随机从 Data 中,抓取 4 个数据
        writer.add_images("Epoch: {}".format(epoch), imgs, step)
        step = step + 1

writer.close()

三、运行结果展示

1.Epoch0

在这里插入图片描述

2.Epoch1

在这里插入图片描述

四、总结

在本文中回顾了Dataset加载数据的方式,以及学习了DataLoader的机制、简单操作过程,并以读取CIFAR10中的数据为例,借助Tensorboard的展示读取到的数据,更好的理解了dataloader的参数含义,能为后续神经网络的训练奠定基础,同时更好的学习和使用pytorch。
发现了一篇关于PyTorch读取数据机制的文章,非常适合学习,链接附上:
点击跳转

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值