简介
TensorBoard 作为 TensorFlow 官方的可视化工具,可以帮助我们直观地理解模型训练过程,快速发现并解决问题。在本篇文章中,将带你使用 TensorBoard 可视化 FashionMNIST 数据集,学习如何将图像数据、标量数据等以更直观的方式呈现,从而更好地理解模型性能。
亮点抢先看
轻松掌握 TensorBoard 基本功能
深入了解 FashionMNIST 数据集
一步步实现图像网格、曲线图等可视化效果
掌握可视化技巧,提升阅读体验
准备工作
安装 TensorBoard 和 TensorFlow
pip install tensorflow tensorboard
TensorBoard 基本功能
- TensorBoard 官方文档:https://www.tensorflow.org/tensorboard
1. 导入 TensorBoard 库:首先,确保您已经安装了 TensorFlow 和 TensorBoard,并在代码中导入相应的库。在 Python 中,可以使用以下语句导入 TensorBoard 库:
from torch.utils.tensorboard import SummaryWriter
2. 创建 SummaryWriter 对象:在代码中创建一个 SummaryWriter对象,用于将日志写入到 TensorBoard 日志目录中。可以指定一个目录作为日志保存的位置:
writer = SummaryWriter('logs')
3. 记录训练过程中的指标:在训练过程中,使用 add_scalar() 方法记录训练过程中的指标,比如损失函数值、准确率等。示例代码如:
for epoch in range(num_epochs):
# 训练模型并计算损失
train_loss = ...
# 在 TensorBoard 中记录训练损失
writer.add_scalar('Train/Loss', train_loss, epoch)
4. 可视化模型结构:使用 add_graph() 方法将模型的图结构添加到 TensorBoard 中,方便查看模型的层次结构和数据流向。示例代码如:
model = ...
input_data = ...
writer.add_graph(model, input_data)
5. 记录模型参数和梯度:使用 add_histogram()方法记录模型参数和梯度的直方图信息,以便分析参数的分布情况和变化趋势。示例代码如:
for name, param in model.named_parameters():
writer.add_histogram(name, param, epoch)
writer.add_histogram(name + '_grad', param.grad, epoch)
6. 记录嵌入向量:如果您有高维数据需要降维可视化,可以使用 add_embedding()`方法记录嵌入向量,并在 TensorBoard 中展示。示例代码如:
embedding = ...
metadata = ...
writer.add_embedding(embedding, metadata)
7. 记录图像和媒体数据:使用 add_image()、add_audio()、add_video() 等方法记录图像、音频和视频数据,并在 TensorBoard 中展示。示例代码如:
image = ...
writer.add_image('Image', image, epoch)
8. 关闭 SummaryWriter 对象:在所有记录完成后,记得关闭 SummaryWriter对象,以确保所有日志都被写入到日志文件中:
writer.close()
通过完成这些步骤,就可以基本实现 TensorBoard 的各种功能,并在训练过程中实时监控模型的性能和结果。
FashionMNIST 数据集
- FashionMNIST 数据集:Fashion MNIST | Kaggle
FashionMNIST 数据集是一个用于图像分类任务的流行数据集,类似于传统的手写数字 MNIST 数据集,但它包含的是 10 种不同类型的时尚服饰和配件的灰度图像。该数据集由 Zalando Research 创建,旨在成为机器学习领域的基准数据集之一,用于评估和比较图像分类算法的性能。
1. 类别:FashionMNIST 数据集共包含 10 个类别,分别是:T-shirt/top(T 恤/上衣)、Trouser(裤子)、Pullover(套衫)、Dress(连衣裙)、Coat(外套)、Sandal(凉鞋)、Shirt(衬衫)、Sneaker(运动鞋)、Bag(包)和 Ankle Boot(短靴)。
2. 图像内容:每个类别的图像都是 28x28 像素的灰度图像,展示了相应类别的服饰或配件。这些图像经过标准化处理,像素值范围在 0 到 255 之间。
3. 训练集和测试集:FashionMNIST 数据集包含 60,000 张训练图像和 10,000 张测试图像,用于训练和评估机器学习模型的性能。
4. 替代 MNIST 数据集:FashionMNIST 数据集被视为传统 MNIST 数据集的替代品,因为它提供了更具挑战性的图像分类任务,并且在实际应用中更具代表性。
5. 用途:FashionMNIST 数据集通常用于测试和比较图像分类算法的性能,以及进行深度学习模型的实验和验证。它也被用作教学和研究目的,帮助学习者理解和掌握图像分类任务的基本原理和方法。
TensorBoard 可视化 FashionMNIST 数据集
导入必要的库
import torch
from torch.utils.tensorboard import SummaryWriter # 导入TensorBoard的SummaryWriter
import warnings
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader
记录标量数据
warnings.filterwarnings('ignore') # 忽略警告信息
# 创建一个名为'logs_one'的TensorBoard日志写入器
writer = SummaryWriter('logs_one')
for i in range(100):
writer.add_scalar('my test', 3 * i, i)
writer.close()
使用了 add_scalar()
方法将标量数据写入到名为 'my test'
的指标中,并在每个迭代步骤中记录了一个标量数据。
记录图像数据
for i in range(100):
image_path = f'train/{i:01d}.jpg'
# print(image_path)
img_pil = Image.open(image_path)
img_arr = np.array(img_pil)
writer.add_image('train', img_arr, i,
dataformats='HW')
writer.close()
使用 add_image()
方法将图像数据写入到名为 'train'
的图像数据集中,并使用了 Image.open()
和 np.array()
方法来读取和处理图像数据。
加载和预处理 FashionMNIST 数据集
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
torchvision.datasets.FashionMNIST('./data',download = True,train = True,transform = transform)
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
torchvision.datasets.FashionMNIST('./data',download = True,train = True,transform = transform)
train_set = torchvision.datasets.FashionMNIST('./data',download = True,train = True,transform = transform)
test_set = torchvision.datasets.FashionMNIST('./data',download = True,train = True,transform = transform)
train_loader = DataLoader(train_set, batch_size=64,shuffle = True)
dataiter = iter(train_loader)
images, labels = next(dataiter)
import matplotlib.pyplot as plt
_= plt.imshow(torch.squeeze(images[0]),cmap = 'gray')
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
使用 SummaryWriter
创建了一个名为 'logs_one'
的 TensorBoard 日志写入器,并从每个类别中随机选择了 90 张图像,将它们合并成一个图像网格,并将该图像网格添加到 TensorBoard 日志中。
tensorboard中显示随机8*8图片
# (used in the `plot_classes_preds` function below)
def matplotlib_imshow(img, one_channel=False):
if one_channel:
img = img.mean(dim=0)
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
if one_channel:
plt.imshow(npimg, cmap="Greys")
else:
plt.imshow(np.transpose(npimg, (1, 2, 0)))
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, True)
writer.add_image('64 images', img_grid)
writer.close()
定义了一个名为 matplotlib_imshow
的函数,用于在 Matplotlib 中显示图像。然后,通过 torchvision.utils.make_grid()
函数将一批图像 images
合并成一个图像网格,并使用 matplotlib_imshow
函数将该图像网格显示出来。最后,使用 SummaryWriter
的 add_image()
方法将这个图像网格添加到 TensorBoard 日志中,并关闭了日志写入器。
tensorboard中显示有规则30*30图片
有规则30*30图片要求:每类图片随机挑选90张,这90张图片分成三行,一行30张,10类图片,每类3行,共3*10 = 30行,这30行图片在一张大图里。生成的这一张图片保持本地,并在Tensorboard的前端网页显示这张图像
定义数据转换操作和加载 FashionMNIST 数据集
首先,定义一个数据转换操作 transform
,将图像转换为张量,并进行标准化处理。然后使用 torchvision.datasets.FashionMNIST
加载 FashionMNIST 数据集,并传入了定义好的转换操作 transform
:
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 加载FashionMNIST数据集,如果不存在则下载
train_set = torchvision.datasets.FashionMNIST('data',
download=True,
train=True,
transform=transform,
)
选择图像并创建图像网格
接下来,从每个类别中选择了 90 张图像,并将它们合并成一个图像网格。通过设置随机种子,保证随机选择结果的可重复性。遍历每个类别,从中随机选择 90 张图像,并将它们添加到 selected_images
列表中。
selected_images = []
for class_id in range(10):
class_images = [img for img, label in train_set if label == class_id]
selected_class_images = random.sample(class_images, 90)
selected_images.extend(selected_class_images)
然后,使用 torchvision.utils.make_grid()
将这些图像合并成一个图像网格:
img_grid = torchvision.utils.make_grid(selected_images, nrow=30)
显示和保存图像网格
将图像网格转换为 PIL 图像,并保存到本地文件中:
pil_img = transforms.ToPILImage()(img_grid)
pil_img.save("selected_images_grid.png")
然后使用 matplotlib_imshow
函数在 Matplotlib 中显示图像网格,并将其添加到 TensorBoard 日志中:
matplotlib_imshow(img_grid, True)
writer.add_image('image_one', img_grid)
writer.close()
tensorboard可视化展示
总结
通过本篇文章便可了解了使用 TensorBoard 对深度学习模型进行数据记录和可视化分析,提高模型开发的效率和质量。同时,通过对 FashionMNIST 数据集的加载和展示,还可以学习到如何使用 PyTorch 处理和展示图像数据。有问题欢迎评论区交流~~