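I. Using torchvision datasets
torchvision.datasets provides many ready-made image and video datasets. The wider torchvision package also ships popular model architectures and pretrained weights.
1. The CIFAR10 dataset
First, download the dataset: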
import torchvision
from torch.utils.tensorboard import SummaryWriter
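# root: directory where the data is stored
# train: True selects the training set, False the test set
# transform: transform applied to each sample
# download: True downloads the data automatically if it is missing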
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
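Inspect a single sample:
import torchvision
# root: directory where the data is stored
# train: True selects the training set, False the test set
# transform: transform applied to each sample
# download: True downloads the data automatically if it is missing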
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
print(test_set[0])
Output:
Files already downloaded and verified
Files already downloaded and verified
(<PIL.Image.Image image mode=RGB size=32x32 at 0x23686A83198>, 3)
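Indexing test_set[0] returns an (image, label) tuple because torchvision datasets follow the map-style protocol: `__getitem__` returns one sample and `__len__` returns the number of samples. The sketch below imitates that protocol in plain Python. `TinyDataset`, its in-memory samples, and the way its `transform` hook is wired are hypothetical stand-ins for illustration, not torchvision's actual implementation.

```python
from typing import Callable, List, Optional, Tuple


class TinyDataset:
    """Hypothetical map-style dataset: indexable and sized, like CIFAR10."""

    def __init__(self, samples: List[Tuple[str, int]],
                 transform: Optional[Callable[[str], str]] = None) -> None:
        self.samples = samples        # (image placeholder, label index) pairs
        self.transform = transform    # optional per-sample transform

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Tuple[str, int]:
        img, target = self.samples[index]
        if self.transform is not None:
            # Applied lazily, each time a sample is requested
            img = self.transform(img)
        return img, target


# Strings stand in for images so the sketch needs no image library
raw = TinyDataset([("cat.png", 3), ("ship.png", 8)])
upper = TinyDataset([("cat.png", 3), ("ship.png", 8)], transform=str.upper)

print(raw[0])      # ('cat.png', 3)
print(upper[0])    # ('CAT.PNG', 3)
print(len(raw))    # 2
```

The transform runs inside `__getitem__`, so the stored samples stay unchanged and only the returned copy is converted.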
List the classes in the dataset:
print(test_set.classes)
Output:
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
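The dataset contains images of the classes above, each paired with a label.
To look at a sample, map its numeric label to a class name, for example: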
img, target = test_set[0]
print(target)
print(test_set.classes[target])
img.show()
Output:
3
cat
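Each image is only 32*32 pixels.
classes is a list; the target label is an index into that list, so classes[target] gives the class name.
II. Combining torchvision datasets with transforms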
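The index-to-name lookup can be sketched without downloading anything. In this minimal example the `classes` list is copied from the output above, while the `targets` list is hypothetical, as if the labels had been read from a few samples. torchvision's CIFAR10 also exposes a `class_to_idx` dictionary for the reverse direction; here it is rebuilt by hand.

```python
from collections import Counter

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

# Reverse lookup: class name -> label index
class_to_idx = {name: idx for idx, name in enumerate(classes)}

# Hypothetical labels, as if read from a few samples
targets = [3, 8, 8, 0, 6]
names = [classes[t] for t in targets]
counts = Counter(names)

print(names)                 # ['cat', 'ship', 'ship', 'airplane', 'frog']
print(class_to_idx['cat'])   # 3
print(counts['ship'])        # 2
```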
# Convert every image in the dataset to a tensor
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root="./dataset", transform=dataset_transform, train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", transform=dataset_transform, train=False, download=True)
print(test_set[0])
The converted sample:
(tensor([[[0.6196, 0.6235, 0.6471, ..., 0.5373, 0.4941, 0.4549],
[0.5961, 0.5922, 0.6235, ..., 0.5333, 0.4902, 0.4667],
[0.5922, 0.5922, 0.6196, ..., 0.5451, 0.5098, 0.4706],
...,
[0.2667, 0.1647, 0.1216, ..., 0.1490, 0.0510, 0.1569],
[0.2392, 0.1922, 0.1373, ..., 0.1020, 0.1137, 0.0784],
[0.2118, 0.2196, 0.1765, ..., 0.0941, 0.1333, 0.0824]],
[[0.4392, 0.4353, 0.4549, ..., 0.3725, 0.3569, 0.3333],
[0.4392, 0.4314, 0.4471, ..., 0.3725, 0.3569, 0.3451],
[0.4314, 0.4275, 0.4353, ..., 0.3843, 0.3725, 0.3490],
...,
[0.4863, 0.3922, 0.3451, ..., 0.3804, 0.2510, 0.3333],
[0.4549, 0.4000, 0.3333, ..., 0.3216, 0.3216, 0.2510],
[0.4196, 0.4118, 0.3490, ..., 0.3020, 0.3294, 0.2627]],
[[0.1922, 0.1843, 0.2000, ..., 0.1412, 0.1412, 0.1294],
[0.2000, 0.1569, 0.1765, ..., 0.1216, 0.1255, 0.1333],
[0.1843, 0.1294, 0.1412, ..., 0.1333, 0.1333, 0.1294],
...,
[0.6941, 0.5804, 0.5373, ..., 0.5725, 0.4235, 0.4980],
[0.6588, 0.5804, 0.5176, ..., 0.5098, 0.4941, 0.4196],
[0.6275, 0.5843, 0.5176, ..., 0.4863, 0.5059, 0.4314]]]), 3)
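ToTensor turns a PIL image (height x width x channel, integers 0 to 255) into a float tensor laid out as channel x height x width with values in [0, 1]. The sketch below reproduces that rescaling and reordering on nested Python lists. The helper `to_chw_unit_range` is hypothetical, and the raw pixel (158, 112, 49) is an assumption, chosen because it rescales to the first entry of each channel printed above.

```python
from typing import List


def to_chw_unit_range(image_hwc: List[List[List[int]]]) -> List[List[List[float]]]:
    """Rescale 0-255 integers to [0, 1] and reorder H x W x C into C x H x W."""
    height = len(image_hwc)
    width = len(image_hwc[0])
    channels = len(image_hwc[0][0])
    return [
        [[image_hwc[y][x][c] / 255 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


# Hypothetical 1 x 2 RGB image; (158, 112, 49) is an assumed raw pixel value
image = [[[158, 112, 49], [255, 0, 128]]]
chw = to_chw_unit_range(image)

# Shape is now channels x height x width
print(len(chw), len(chw[0]), len(chw[0][0]))          # 3 1 2
# First pixel of each channel, rounded to four decimals like the tensor printout
print([round(channel[0][0], 4) for channel in chw])   # [0.6196, 0.4392, 0.1922]
```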
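Write the images to TensorBoard:
from torch.utils.tensorboard import SummaryWriter

# Log to "logs" so the directory matches the --logdir argument used afterwards
writer = SummaryWriter("logs")
for i in range(10):
    img, target = test_set[i]
    # img is a tensor here because the dataset was built with ToTensor
    writer.add_image("test_set", img, i)
writer.close()
Then start TensorBoard from a terminal:
tensorboard --logdir="logs" --port=6007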
Result: the first ten test images appear under the test_set tag in TensorBoard, one per step.