今天继续更新torchvision中的数据集使用。
# 引入torchvision模块,该模块里包含了常见的预处理函数、数据集、模型等
import torchvision
# 从torch.utils.tensorboard模块中引入SummaryWriter,用于将数据写入TensorBoard
from torch.utils.tensorboard import SummaryWriter
# 定义数据集的预处理操作,此处只进行了一项操作:将PIL格式的图像转为Tensor格式
dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
# 加载CIFAR-10数据集的训练集和测试集,将数据集的预处理操作作为参数传入
train_set = torchvision.datasets.CIFAR10("./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10("./dataset", train=False, transform=dataset_transform, download=True)
# 可以通过下面的代码打印出数据集中的某个样本的图像和标签信息
# print(test_set[0])
# print(test_set.classes)
# img, target = train_set[0]
# print(img)
# print(target)
# print(test_set.classes[target])
# img.show()
# 创建SummaryWriter对象,指定输出目录为"logs",将数据写入TensorBoard
writer = SummaryWriter("logs")
# 遍历测试集,将前10张图片以及它们的标签信息通过SummaryWriter写入TensorBoard
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i)
# 关闭SummaryWriter,将所有数据写入TensorBoard
writer.close()
在Pytorch Terminal中输入:tensorboard --logdir=logs,即可得图像