PyTorch学习笔记(3)dataset与dataloader

pytorch 官网有大量数据集,可以通过函数调用的方式直接下载并使用,避免了繁琐的数据集搜集与整理工作。

在官方文档中有详细的 API 说明与数据集介绍:https://pytorch.org/docs/stable/index.html


torchvision中数据集的使用

下载与查看

这里下载 CIFAR10 数据集(用于图像识别分类任务)。

root 为数据集要保存的根目录,train=True 表示下载训练集download=True 表示如果本地没有该数据集,才从网上下载。

import torchvision

train_set = torchvision.datasets.CIFAR10(root="./dataset_CIFAR10",train=True,download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset_CIFAR10",train=False,download=True)

可以看到,正在下载相应的数据集。如果网速过慢,可以粘贴链接到迅雷进行下载。

image-20220811164506430

下载完毕,我们打个断点 debug 看看得到了个什么东西。

image-20220811170317277

其中,classes 为该数据集的所有类别targets 为所有图片对应的类别的索引

image-20220811170647569

通过 target ,我们就能知道一张图片属于哪个类别

import torchvision
from PIL import Image

train_set = torchvision.datasets.CIFAR10(root="./dataset_CIFAR10",train=True,download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset_CIFAR10",train=False,download=True)

img,target = test_set[0]
print(img)
print(target)

print(test_set.classes)
print(test_set.classes[target])
img.show()
image-20220811171830358

在tensorboard中显示

import torchvision
from torch.utils.tensorboard import SummaryWriter

# 定义数据集要进行的变换
dataset_trans = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

train_set = torchvision.datasets.CIFAR10(root="./dataset_CIFAR10",train=True,
                                         transform=dataset_trans,download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset_CIFAR10",train=False,
                                        transform=dataset_trans,download=True)

writer = SummaryWriter("p10")

for i in range(10):
    img,target = train_set[i]
    print(i)
    print(img)
    writer.add_image("前十张图片",img,i)

writer.close()

如果出现图片无法显示的问题,那就删掉日志文件,重新运行并输入 tensorboard --logdir=p10 命令。

如果发现 step 不连续,有缺失,很正常,默认只显示十张。

如果想要显示更多图片,输入以下命令:

tensorboard --logdir=p10 --samples_per_plugin=images=100

dataloader的使用

dataset 决定了数据从哪里读取以及如何读取,dataloader 构建可迭代的数据装载器,进一步确定如何加载 dataset 里的数据。

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

test_set = torchvision.datasets.CIFAR10(root="./dataset_CIFAR10",train=False,transform=dataset_transform)

# batch_size=4 表示随机四张图片打包
# num_workers=0 表示仅一个主线程
# shuffle=True 表示每次打包的四张图片顺序不同
# drop_last=False 表示打包无法整除时,最后的那几张图片留下
test_loader = DataLoader(test_set,batch_size=4,shuffle=True,num_workers=0,drop_last=False)
print(test_loader)

writer = SummaryWriter("dataloader")
step = 0
for data in test_loader:
    imgs,targets = data # imgs将4张图片打包,3表示RGB  targets表示4张图片的所有target
    print(imgs.shape)
    print(targets)
    writer.add_images("test_data",imgs,step)
    step += 1
    pass
writer.close()
  • batch_size=4 表示随机四张图片打包
  • num_workers=0 表示仅一个主线程
  • shuffle=True 表示每次打包的四张图片顺序不同
  • drop_last=False 表示打包无法整除时,最后的那几张图片留下
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch中,数据读取是构建深度学习模型的重要一环。为了高效处理大规模数据集,PyTorch提供了三个主要的工具:DatasetDataLoader和TensorDatasetDataset是一个抽象类,用于自定义数据集。我们可以继承Dataset类,并重写其中的__len__和__getitem__方法来实现自己的数据加载逻辑。__len__方法返回数据集的大小,而__getitem__方法根据给定的索引返回样本和对应的标签。通过自定义Dataset类,我们可以灵活地处理各种类型的数据集。 DataLoader是数据加载器,用于对数据集进行批量加载。它接收一个Dataset对象作为输入,并可以定义一些参数例如批量大小、是否乱序等。DataLoader能够自动将数据集划分为小批次,将数据转换为Tensor形式,然后通过迭代器的方式供模型训练使用。DataLoader在数据准备和模型训练的过程中起到了桥梁作用。 TensorDataset是一个继承自Dataset的类,在构造时将输入数据和目标数据封装成Tensor。通过TensorDataset,我们可以方便地处理Tensor格式的数据集。TensorDataset可以将多个Tensor按行对齐,即将第i个样本从各个Tensor中取出,构成一个新的Tensor作为数据集的一部分。这对于处理多输入或者多标签的情况非常有用。 总结来说,Dataset提供了自定义数据集的接口,DataLoader提供了批量加载数据集的能力,而TensorDataset则使得我们可以方便地处理Tensor格式的数据集。这三个工具的配合使用可以使得数据处理变得更加方便和高效。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值