6.初识Pytorch之torchvision中的数据集使用

  • 首先从下图——Pytorch官网可以看出,在torchvision提供的数据库很多,用红色框出。
    在这里插入图片描述
  • 选择CIFAR-10数据库进行实验,其对应的官方文档如下:
    其参数有
root:CIFAR10 在文件中,放置CIFAR10数据库的位置.
train:是否是训练集,是就是True,否则就是测试集False.
transform:使用了那种数据增强,要在前面先定义,这里就可以具体使用.
target_transform:这个一般使用默认即可(目标的).
download:是否下载,是True,否则False.

在这里插入图片描述在这里插入图片描述

  • 简单粗暴上代码:
import torchvision
# 训练集的定义
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
# 测试集的定义
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
#拿测试集的第一个数据进行测试,得到img(img为PIL格式)与target(这个target为目标的分类)
img, target = test_set[0]
print(img)
print(target)
# 将test_set的数据集,所有分类都打印出来
print(test_set.classes)
#得到测试集的第一张图像的类别
print(test_set.classes[target])

结果:
在这里插入图片描述 - 从结果中显示,数据已经下载好了并且得到了验证,接着就是测试集第一张图像print(img)是一个PIL格式,之后是测试集第一张图像的print(target)类别打印。


  • 查看图像,可以直接img.show()调用因为这个img是PIL格式
    简单粗暴上代码:
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

tran_tensor = transforms.ToTensor()
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
# train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=tran_tensor, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=tran_tensor, download=True)

img, target = train_set[0]
img.show()

结果:
在这里插入图片描述---------------------------------------------------------------------------------------------
在这里插入图片描述


  • 结合之前使用的transforms.ToTensor与SummaryWriter来使用

简单粗暴上代码:

import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

# 创建transforms.ToTensor()模板
tran_tensor = transforms.ToTensor()

# SummaryWriter模板
writer = SummaryWriter("logs")

# 训练集
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
# 测试集
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)

# 选训练集前十张图像来做transforms.ToTensor,再用Tensorboard打印出来
for i in range(10):
    img, target = train_set[i]
    img_tensor = tran_tensor(img)
    writer.add_image("img_tensor", img_tensor, i)
writer.close()

结果:
在这里插入图片描述------------------------------------------------------------------------------------------
在这里插入图片描述 ---------------------------------------------------------------------------------------
还有另一种方式放置transform,就是在数据集定义的时候就放进去

简单粗暴上代码:

import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("logs")
tran_tensor = transforms.ToTensor()
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=tran_tensor, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=tran_tensor, download=True)

for i in range(10):
    img_tensor, _ = train_set[i]
    writer.add_image("img_tensor", img_tensor, i)
    writer.close()

上一章 5.初识Pytorch使用常用的transforms
下一章 7.初识Pytorch使用Dataloader

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值