通过这份文档的学习,我们会了解到如何往TensorBoard里面送入图片、图表、模型、scalars(损失值、权值、偏置等)、构建embeddings、PR曲线等,其中送入的图片或图表数据主要是多张图片合成的网格图片,利用torchvision.utils.make_grid
函数或fig.add_subplot
构建,细节内容请往下看。
本片文档来源于PyTorch官方教程,我仅其内容进行部分解读,多数解读是注释在代码行中。
声明:没有耐心看几句英文说明的可以试一下Ctrl + W
,我建议大家静下心来学习,不要浮躁。
如果看明白了本文内容,想要更细致地了解Pytorch下TensorBoard的相关用法,可以看官方的Document(https://pytorch.org/docs/stable/tensorboard.html?highlight=tensorboard)
Let’s get started!
在这份文档中,将记录以下几点:
- 读取数据,并作适当的数据转换;
- 设置TensorBoard;
- 写入TensorBoard相关内容;
- 利用TensorBoard查看模型结构;
- 利用TensorBoard创建可视化的交互界面;
特别是在第5点中,我们将看到:
- 查看训练数据的几种方式;
- 在训练时如何追踪模型的性能;
- 训练结束后,如何评估模型的性能。
本文所用数据集为 CIFAR-10。
# imports
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms # 注意transforms是torchvision里面的工具,主要是为图像开发
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# transforms
transform = transforms.Compose(
[transforms.ToTensor(), # 转变成pytorch类型的tensor,针对图像进行转变,把载入的图像转变成Pytorch格式的tensor,结果为NCHW
transforms.Normalize((0.5,), (0.5,))]) # 标准化操作
# datasets
trainset = torchvision.datasets.FashionMNIST('./data',# 下载“训练集”/“测试集”,并转变数据形式(对图片格式进行转变)
download=True,
train=True,
transform=transform)
testset = torchvision.datasets.FashionMNIST('./data',
download=True,
train=False,
transform=transform)
# dataloaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2) # 这里面有一些讲究,尤其是多进程相关的,回过头来可以再看
# constant for classes
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
# helper function to show an image
# (used in the `plot_classes_preds` function below)
def matplotlib_imshow(img, one_channel=False):
if one_channel:
img = img.mean(dim=0) # 其实就是一种数据维度的压缩,可以替换为img = img.squeeze(0)
img = img / 2 + 0.5 # unnormalize 反归一化,反向操作
npimg = img.numpy()
if one_channel:
plt.imshow(npimg, cmap="Greys")
else:
plt.imshow(np.transpose(npimg, (</