【Pytorch】4.torchvision.datasets的使用

什么是torchvision.datasets、

pytorch官方给出的关于cv领域的训练数据集,我们可以用官方提供的数据集进行学习与训练

如何查看

我们可以进入Pytorch官网
在这里插入图片描述
在这里插入图片描述
切换一下版本到v0.9.0,就可以看到官方给出的数据集了
在这里插入图片描述
同时也有官方训练好的cv模型可以供我们学习使用
在这里插入图片描述

如何使用

以使用CIFAR10为例,首先导入torchvision

import torchvision

然后通过代码

train_set = torchvision.datasets.CIFAR10()

进入源码查看
在这里插入图片描述
同时在__init__方法中,我们也可以看到root参数是一定要赋值的,而其他的变量都有默认值
在这里插入图片描述
了解到参数的用法后,我们就可以来创建自己的数据集了

train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)

运行后就会自己下载数据集了
在这里插入图片描述
在这里插入图片描述
而生成的train_settest_set相当于是imglabel的键值对数据,这一点我们通过

print(train_set[0])

就可以看到
在这里插入图片描述
因为我们没有指定transforms所以图片的默认格式是PIL.Image,后面的6代表是在标签中的第六个位置,所有的类别存放在train_set对象中的classes
在这里插入图片描述
我们通过

img,label = train_set[0]
print(train_set.classes[label])

就能看到对应的标签了


如果我们想将图片的格式改为tensor格式,我们就可以预先配置好对应的对象,然后在数据集中配置,比如

# 配置tansform的对象
trans = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

#在transform中添加刚刚的配置
train_set = torchvision.datasets.CIFAR10(root="./dataset", transform=trans, train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", transform=trans, train=False, download=True)

# 查看运行结果
img, label = train_set[0]
print(type(img))

在这里插入图片描述
可以发现已经转化为tensor类型了


同时我们也可以引入TensorBoard,来对训练集的前十个数据进行可视化

writer = SummaryWriter("logs")
for i in range(10):
    img,label = train_set[i]
    writer.add_image("train", img, i)   # 将训练数据集中的img添加到writer中
writer.close()

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值