什么是torchvision.datasets、
是pytorch
官方给出的关于cv领域的训练数据集,我们可以用官方提供的数据集进行学习与训练
如何查看
我们可以进入Pytorch官网
切换一下版本到v0.9.0
,就可以看到官方给出的数据集了
同时也有官方训练好的cv
模型可以供我们学习使用
如何使用
以使用CIFAR10
为例,首先导入torchvision
import torchvision
然后通过代码
train_set = torchvision.datasets.CIFAR10()
进入源码查看
同时在__init__
方法中,我们也可以看到root
参数是一定要赋值的,而其他的变量都有默认值
了解到参数的用法后,我们就可以来创建自己的数据集了
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
运行后就会自己下载数据集了
而生成的train_set
和test_set
相当于是img
和label
的键值对数据,这一点我们通过
print(train_set[0])
就可以看到
因为我们没有指定transforms
所以图片的默认格式是PIL.Image
,后面的6
代表是在标签中的第六个位置,所有的类别存放在train_set
对象中的classes
中
我们通过
img,label = train_set[0]
print(train_set.classes[label])
就能看到对应的标签了
如果我们想将图片的格式改为tensor
格式,我们就可以预先配置好对应的对象,然后在数据集中配置,比如
# 配置tansform的对象
trans = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
#在transform中添加刚刚的配置
train_set = torchvision.datasets.CIFAR10(root="./dataset", transform=trans, train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", transform=trans, train=False, download=True)
# 查看运行结果
img, label = train_set[0]
print(type(img))
可以发现已经转化为tensor
类型了
同时我们也可以引入TensorBoard
,来对训练集的前十个数据进行可视化
writer = SummaryWriter("logs")
for i in range(10):
img,label = train_set[i]
writer.add_image("train", img, i) # 将训练数据集中的img添加到writer中
writer.close()