一文带你玩转 TensorBoard:可视化 FashionMNIST 数据集

简介

TensorBoard 作为 TensorFlow 官方的可视化工具,可以帮助我们直观地理解模型训练过程,快速发现并解决问题。在本篇文章中,将带你使用 TensorBoard 可视化 FashionMNIST 数据集,学习如何将图像数据、标量数据等以更直观的方式呈现,从而更好地理解模型性能。

亮点抢先看

        轻松掌握 TensorBoard 基本功能

        深入了解 FashionMNIST 数据集

        一步步实现图像网格、曲线图等可视化效果

        掌握可视化技巧,提升阅读体验

准备工作

        安装 TensorBoard 和 TensorFlow

pip install tensorflow tensorboard

TensorBoard 基本功能

1. 导入 TensorBoard 库:首先,确保您已经安装了 TensorFlow 和 TensorBoard,并在代码中导入相应的库。在 Python 中,可以使用以下语句导入 TensorBoard 库:

from torch.utils.tensorboard import SummaryWriter

2. 创建 SummaryWriter 对象:在代码中创建一个 SummaryWriter对象,用于将日志写入到 TensorBoard 日志目录中。可以指定一个目录作为日志保存的位置:

writer = SummaryWriter('logs')

3. 记录训练过程中的指标:在训练过程中,使用 add_scalar() 方法记录训练过程中的指标,比如损失函数值、准确率等。示例代码如:

for epoch in range(num_epochs):
    # 训练模型并计算损失
    train_loss = ...

    # 在 TensorBoard 中记录训练损失
    writer.add_scalar('Train/Loss', train_loss, epoch)

4. 可视化模型结构:使用 add_graph() 方法将模型的图结构添加到 TensorBoard 中,方便查看模型的层次结构和数据流向。示例代码如:

model = ...
input_data = ...
writer.add_graph(model, input_data)

5. 记录模型参数和梯度:使用 add_histogram()方法记录模型参数和梯度的直方图信息,以便分析参数的分布情况和变化趋势。示例代码如:

for name, param in model.named_parameters():
    writer.add_histogram(name, param, epoch)
    writer.add_histogram(name + '_grad', param.grad, epoch)

6. 记录嵌入向量:如果您有高维数据需要降维可视化,可以使用 add_embedding()`方法记录嵌入向量,并在 TensorBoard 中展示。示例代码如:

embedding = ...
metadata = ...
writer.add_embedding(embedding, metadata)

7. 记录图像和媒体数据:使用 add_image()、add_audio()、add_video() 等方法记录图像、音频和视频数据,并在 TensorBoard 中展示。示例代码如:

image = ...
writer.add_image('Image', image, epoch)

8. 关闭 SummaryWriter 对象:在所有记录完成后,记得关闭 SummaryWriter对象,以确保所有日志都被写入到日志文件中:

writer.close()

通过完成这些步骤,就可以基本实现 TensorBoard 的各种功能,并在训练过程中实时监控模型的性能和结果。

FashionMNIST 数据集 

 

FashionMNIST 数据集是一个用于图像分类任务的流行数据集,类似于传统的手写数字 MNIST 数据集,但它包含的是 10 种不同类型的时尚服饰和配件的灰度图像。该数据集由 Zalando Research 创建,旨在成为机器学习领域的基准数据集之一,用于评估和比较图像分类算法的性能。

1. 类别:FashionMNIST 数据集共包含 10 个类别,分别是:T-shirt/top(T 恤/上衣)、Trouser(裤子)、Pullover(套衫)、Dress(连衣裙)、Coat(外套)、Sandal(凉鞋)、Shirt(衬衫)、Sneaker(运动鞋)、Bag(包)和 Ankle Boot(短靴)。

2. 图像内容:每个类别的图像都是 28x28 像素的灰度图像,展示了相应类别的服饰或配件。这些图像经过标准化处理,像素值范围在 0 到 255 之间。

3. 训练集和测试集:FashionMNIST 数据集包含 60,000 张训练图像和 10,000 张测试图像,用于训练和评估机器学习模型的性能。

4. 替代 MNIST 数据集:FashionMNIST 数据集被视为传统 MNIST 数据集的替代品,因为它提供了更具挑战性的图像分类任务,并且在实际应用中更具代表性。

5. 用途:FashionMNIST 数据集通常用于测试和比较图像分类算法的性能,以及进行深度学习模型的实验和验证。它也被用作教学和研究目的,帮助学习者理解和掌握图像分类任务的基本原理和方法。

TensorBoard 可视化 FashionMNIST 数据集

导入必要的库

import torch
from torch.utils.tensorboard import SummaryWriter  # 导入TensorBoard的SummaryWriter
import warnings
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader

 记录标量数据

warnings.filterwarnings('ignore')  # 忽略警告信息

# 创建一个名为'logs_one'的TensorBoard日志写入器
writer = SummaryWriter('logs_one')

for i in range(100):
    writer.add_scalar('my test', 3 * i, i)
writer.close()

 使用了 add_scalar() 方法将标量数据写入到名为 'my test' 的指标中,并在每个迭代步骤中记录了一个标量数据。

记录图像数据

for i in range(100):
    image_path = f'train/{i:01d}.jpg'
    # print(image_path)
    img_pil = Image.open(image_path)
    img_arr = np.array(img_pil)
    writer.add_image('train', img_arr, i,
                     dataformats='HW')

writer.close()

使用 add_image() 方法将图像数据写入到名为 'train' 的图像数据集中,并使用了 Image.open()np.array() 方法来读取和处理图像数据。

加载和预处理 FashionMNIST 数据集

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
torchvision.datasets.FashionMNIST('./data',download = True,train = True,transform = transform)
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
torchvision.datasets.FashionMNIST('./data',download = True,train = True,transform = transform)
train_set = torchvision.datasets.FashionMNIST('./data',download = True,train = True,transform = transform)
test_set = torchvision.datasets.FashionMNIST('./data',download = True,train = True,transform = transform)
train_loader = DataLoader(train_set, batch_size=64,shuffle = True)
dataiter = iter(train_loader)
images, labels = next(dataiter)
import matplotlib.pyplot as plt
_= plt.imshow(torch.squeeze(images[0]),cmap = 'gray')

classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

使用 SummaryWriter 创建了一个名为 'logs_one' 的 TensorBoard 日志写入器,并从每个类别中随机选择了 90 张图像,将它们合并成一个图像网格,并将该图像网格添加到 TensorBoard 日志中。

tensorboard中显示随机8*8图片

# (used in the `plot_classes_preds` function below)
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, True)
writer.add_image('64 images', img_grid)
writer.close()

定义了一个名为 matplotlib_imshow 的函数,用于在 Matplotlib 中显示图像。然后,通过 torchvision.utils.make_grid() 函数将一批图像 images 合并成一个图像网格,并使用 matplotlib_imshow 函数将该图像网格显示出来。最后,使用 SummaryWriteradd_image() 方法将这个图像网格添加到 TensorBoard 日志中,并关闭了日志写入器。

tensorboard中显示有规则30*30图片

有规则30*30图片要求:每类图片随机挑选90张,这90张图片分成三行,一行30张,10类图片,每类3行,共3*10 = 30行,这30行图片在一张大图里。生成的这一张图片保持本地,并在Tensorboard的前端网页显示这张图像

定义数据转换操作和加载 FashionMNIST 数据集

首先,定义一个数据转换操作 transform,将图像转换为张量,并进行标准化处理。然后使用 torchvision.datasets.FashionMNIST 加载 FashionMNIST 数据集,并传入了定义好的转换操作 transform

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 加载FashionMNIST数据集,如果不存在则下载
train_set = torchvision.datasets.FashionMNIST('data',
                                              download=True,
                                              train=True,
                                              transform=transform,
                                              )
选择图像并创建图像网格

接下来,从每个类别中选择了 90 张图像,并将它们合并成一个图像网格。通过设置随机种子,保证随机选择结果的可重复性。遍历每个类别,从中随机选择 90 张图像,并将它们添加到 selected_images 列表中。

selected_images = []

for class_id in range(10):
    class_images = [img for img, label in train_set if label == class_id]
    selected_class_images = random.sample(class_images, 90)
    selected_images.extend(selected_class_images)

然后,使用 torchvision.utils.make_grid() 将这些图像合并成一个图像网格:

img_grid = torchvision.utils.make_grid(selected_images, nrow=30)
显示和保存图像网格

将图像网格转换为 PIL 图像,并保存到本地文件中:

pil_img = transforms.ToPILImage()(img_grid)
pil_img.save("selected_images_grid.png")

然后使用 matplotlib_imshow 函数在 Matplotlib 中显示图像网格,并将其添加到 TensorBoard 日志中:

matplotlib_imshow(img_grid, True)
writer.add_image('image_one', img_grid)
writer.close()

tensorboard可视化展示

总结 

通过本篇文章便可了解了使用 TensorBoard 对深度学习模型进行数据记录和可视化分析,提高模型开发的效率和质量。同时,通过对 FashionMNIST 数据集的加载和展示,还可以学习到如何使用 PyTorch 处理和展示图像数据。有问题欢迎评论区交流~~

  • 19
    点赞
  • 32
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
Spark是一个快速通用的集群计算框架,它可以处理大规模数据,并且具有高效的内存计算能力。Spark可以用于各种计算任务,包括批处理、流处理、机器学习等。本文将你了解Spark计算框架的基本概念和使用方法。 一、Spark基础概念 1. RDD RDD(Resilient Distributed Datasets)是Spark的基本数据结构,它是一个分布式的、可容错的、不可变的数据集合。RDD可以从Hadoop、本地文件系统等数据源中读取数据,并且可以通过多个转换操作(如map、filter、reduce等)进行处理。RDD也可以被持久化到内存中,以便下次使用。 2. Spark应用程序 Spark应用程序是由一个驱动程序和多个执行程序组成的分布式计算应用程序。驱动程序是应用程序的主要入口点,它通常位于用户的本地计算机上,驱动程序负责将应用程序分发到执行程序上并收集结果。执行程序是运行在集群节点上的计算单元,它们负责执行驱动程序分配给它们的任务。 3. Spark集群管理器 Spark集群管理器负责管理Spark应用程序在集群中的运行。Spark支持多种集群管理器,包括Standalone、YARN、Mesos等。 二、Spark计算框架使用方法 1. 安装Spark 首先需要安装Spark,可以从Spark官网下载并解压缩Spark安装包。 2. 编写Spark应用程序 编写Spark应用程序通常需要使用Java、Scala或Python编程语言。以下是一个简单的Java代码示例,用于统计文本文件中单词的出现次数: ```java import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import java.util.Arrays; import java.util.Map; public class WordCount { public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("WordCount").setMaster("local"); JavaSparkContext sc = new JavaSparkContext(conf); JavaRDD<String> lines = sc.textFile("input.txt"); JavaRDD<String> words = lines.flatMap(line -> Arrays.asList(line.split(" ")).iterator()); Map<String, Long> wordCounts = words.countByValue(); for (Map.Entry<String, Long> entry : wordCounts.entrySet()) { System.out.println(entry.getKey() + " : " + entry.getValue()); } sc.stop(); } } ``` 3. 运行Spark应用程序 将编写好的Spark应用程序打包成jar包,并通过以下命令运行: ```bash spark-submit --class WordCount /path/to/wordcount.jar input.txt ``` 其中,--class参数指定应用程序的主类,后面跟上打包好的jar包路径,input.txt是输入文件路径。 4. 查看运行结果 Spark应用程序运行完毕后,可以查看应用程序的输出结果,例如上述示例中的单词出现次数。 以上就是Spark计算框架的基本概念和使用方法。通过学习Spark,我们可以更好地处理大规模数据,并且提高计算效率。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

归栀1102

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值