PyTorch里如何利用TensorBoard--详解

本文详细解析如何在PyTorch中使用TensorBoard,包括设置、记录图像、模型结构可视化、Projector使用、训练过程跟踪以及模型评估。通过实例展示了如何将训练损失、模型预测与实际结果对比等信息写入TensorBoard,便于模型理解和优化。
摘要由CSDN通过智能技术生成

通过这份文档的学习,我们会了解到如何往TensorBoard里面送入图片、图表、模型、scalars(损失值、权值、偏置等)、构建embeddings、PR曲线等,其中送入的图片或图表数据主要是多张图片合成的网格图片,利用torchvision.utils.make_grid函数或fig.add_subplot构建,细节内容请往下看。

本片文档来源于PyTorch官方教程,我仅其内容进行部分解读,多数解读是注释在代码行中。
声明:没有耐心看几句英文说明的可以试一下Ctrl + W,我建议大家静下心来学习,不要浮躁。

如果看明白了本文内容,想要更细致地了解Pytorch下TensorBoard的相关用法,可以看官方的Document(https://pytorch.org/docs/stable/tensorboard.html?highlight=tensorboard)

Let’s get started!

在这份文档中,将记录以下几点:

  1. 读取数据,并作适当的数据转换;
  2. 设置TensorBoard;
  3. 写入TensorBoard相关内容;
  4. 利用TensorBoard查看模型结构;
  5. 利用TensorBoard创建可视化的交互界面;

特别是在第5点中,我们将看到:

  • 查看训练数据的几种方式;
  • 在训练时如何追踪模型的性能;
  • 训练结束后,如何评估模型的性能。

本文所用数据集为 CIFAR-10。

# imports
import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision
import torchvision.transforms as transforms # 注意transforms是torchvision里面的工具,主要是为图像开发

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# transforms
transform = transforms.Compose(
    [transforms.ToTensor(),  # 转变成pytorch类型的tensor,针对图像进行转变,把载入的图像转变成Pytorch格式的tensor,结果为NCHW
    transforms.Normalize((0.5,), (0.5,))]) # 标准化操作

# datasets
trainset = torchvision.datasets.FashionMNIST('./data',# 下载“训练集”/“测试集”,并转变数据形式(对图片格式进行转变)
    download=True,
    train=True,
    transform=transform)
testset = torchvision.datasets.FashionMNIST('./data',  
    download=True,
    train=False,
    transform=transform)

# dataloaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                        shuffle=True, num_workers=2)


testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                        shuffle=False, num_workers=2)  # 这里面有一些讲究,尤其是多进程相关的,回过头来可以再看

# constant for classes
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# helper function to show an image
# (used in the `plot_classes_preds` function below)
def matplotlib_imshow(img, one_channel=False):
    if one_channel:         
        img = img.mean(dim=0) # 其实就是一种数据维度的压缩,可以替换为img = img.squeeze(0)
    img = img / 2 + 0.5     # unnormalize 反归一化,反向操作
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (</
  • 8
    点赞
  • 71
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值