前言
在深度学习的模型训练中,可视化数据和模型结构是非常关键的一环,尤其是对于新手,理解“模型到底在学什么”、“输入图片是什么样”等问题是入门的第一步。
今天我们就用最简单的 PyTorch 和 TensorBoard,加载 CIFAR-10 数据集并将其图像写入 TensorBoard 可视化界面。
一、导入必要库
import torchvision
from torch.utils.tensorboard import SummaryWriter
二、TensorBoard 可视化 CIFAR-10 数据集
1.Compose[]数据预处理
dataset_transforms=torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
2.加载 CIFAR-10 数据集
train_set = torchvision.datasets.CIFAR10(
root="./database",
train=True,
transform=dataset_transforms,
download=True
)
test_set = torchvision.datasets.CIFAR10(
root="./database",
train=False,
transform=dataset_transforms,
download=True
)
root:指定保存数据集的根目录。(最好用 “./XXXX”,这样数据文件夹会自动下载到当前文件夹)
train:布尔值,True 表示加载训练集,False 表示加载测试集。
transform:对图像进行的转换操作,常见有 ToTensor()、Normalize()、RandomCrop() 等。
download:是否从网上下载数据集(一般为True)。
更多详细信息请前往下面的pytorch帮助文档查看https://docs.pytorch.org/vision/stable/generated/torchvision.datasets.MNIST.html?highlight=mnis#torchvision.datasets.MNIST
3.查看数据
# print(test_set[0]) # 打印第一张图片的信息
# print(train_set[0])
# print(test_set.classes) # 查看所有类别标签,如 'airplane', 'automobile', ...
# img, target = test_set[0]
# print(img) # 打印 Tensor 图像信息
# print(target) # 打印对应的标签索引
# print(test_set.classes[target]) # 打印具体类别名
# img.show() # 显示图片(仅适用于 PIL 图像)
4. 使用 TensorBoard 可视化训练集图像
Writer = SummaryWriter("../logs2") # 创建日志文件夹 logs4
for i in range(10):
img, target = train_set[i] # 依次获取前10张图片
Writer.add_image("testset", img, i) # 将图像写入 TensorBoard
Writer.close() # 关闭 writer
依旧是查看帮助文档这一”招:(按住ctrl+CIFAR-10)
你可以在cifar.py文件的124行看见如下图:
![]()
这样你就会理解我为什么写下 :img, target = train_set[i]
因为我知道我从CIFAR-10获取的信息是这两个,所以直接读取。
5.启动 TensorBoard 查看图像
在终端中运行以下命令,打开 TensorBoard:
tensorboard --logdir=../logs2
7.可视化展示
像文档中解释的那样,显示出的图片是32*32像素,所以才画质感人,大家可以自己前往https://www.cs.toronto.edu/~kriz/cifar.html去了解更多关于CIFAR-10数据集。
总结
本代码用于学习 PyTorch 中 torchvision.datasets.CIFAR10 的用法。
使用 ToTensor() 转换图片并加载至训练集。
借助 TensorBoard 可视化图像数据,便于理解模型输入和预处理效果。