今天复习torchvision中的数据集使用
在PyTorch中,torchvision
是一个用于计算机视觉任务的库,它提供了许多预定义的数据集和数据转换工具。
现在开始:
import torchvision # 导入torchvision库,用于处理图像数据集
from torch.utils.tensorboard import SummaryWriter # 导入SummaryWriter,用于写入TensorBoard日志
# 定义数据集的转换操作,这里只使用了ToTensor(),将PIL图像或Numpy数组转换为Tensor
dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
# 加载训练集,指定数据集存放的文件夹,是否为训练集,转换操作,以及是否下载数据集(如果本地没有的话)
train_set = torchvision.datasets.CIFAR10("./dataset", train=True, transform=dataset_transform, download=True)
# 加载测试集,与训练集类似,但是指定train=False,表示加载测试数据
test_set = torchvision.datasets.CIFAR10("./dataset", train=False, transform=dataset_transform, download=True)
# 创建一个SummaryWriter实例(它将信息写入TensorBoard日志文件),指定日志存放的文件夹名称
writer = SummaryWriter("logs")
# 遍历测试集中的前10个样本
for i in range(10):
# 从数据集中获取第i个样本的图像和标签
img, target = test_set[i]
# 使用SummaryWriter的add_image方法将图像添加到TensorBoard日志中,"test_set"是标签,img是图像,i是全局步骤计数器
writer.add_image("test_set", img, i)
# 可以通过下面的代码打印出数据集中的某个样本的图像和标签信息
# print(test_set[0])
# print(test_set.classes)
# img, target = train_set[0]
# print(img)
# print(target)
# print(test_set.classes[target])
# img.show()
# 关闭SummaryWriter,确保所有日志数据都被写入
writer.close()
运行代码,这里我并没有提前下载数据集,所以他在网上下载了,并解压到指定的本地路径"./dataset"
然后我们在python终端中输入
tensorboard --logdir=logs
(注意如果以上口令得到的链接无法生成图像的话,就像我一样改成内容根路径,这里TransBoard无法生成图像的问题我在前面的文章有解决方法)
终端输入口令后得到下面结果
打开给的网页链接即可得到图像