References:
【A detailed guide to visualizing training with TensorboardX in PyTorch projects】
【PyTorch official site】
【An expert uploader on Bilibili】
PyTorch integrates with TensorBoard, a tool designed for visualizing the results of neural network training runs.
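1. Installing tensorboardX
Running the official example unexpectedly raised an error; a search suggested that installing tensorboardX would fix it.
Changing the import statement still failed.
Install tensorboardX: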
pip install tensorboardX
That resolved the problem.
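The fallback described above can be sketched as a small probe that tries PyTorch's built-in `torch.utils.tensorboard` first and then `tensorboardX`. This is only a sketch: the helper name `find_summary_writer` and the `CANDIDATES` tuple are hypothetical names introduced for illustration, and the probe treats any import-time failure as "backend not usable" rather than diagnosing it.

```python
import importlib

# Candidate modules that expose a SummaryWriter class, in order of preference.
CANDIDATES = ("torch.utils.tensorboard", "tensorboardX")


def find_summary_writer(candidates=CANDIDATES):
    """Return (module_name, SummaryWriter) for the first usable backend, or None."""
    for module_name in candidates:
        try:
            module = importlib.import_module(module_name)
        except Exception:
            # Treat any import-time failure as "this backend is unusable here".
            continue
        writer_cls = getattr(module, "SummaryWriter", None)
        if writer_cls is not None:
            return module_name, writer_cls
    return None


if __name__ == "__main__":
    found = find_summary_writer()
    if found is None:
        print("No backend found; try: pip install tensorboardX")
    else:
        print("Using SummaryWriter from", found[0])
```

If the probe returns `None`, neither package could be imported, which matches the situation that `pip install tensorboardX` fixed here.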
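2. Getting started
(1) Example: the FashionMNIST dataset, showing four images in TensorBoard with add_image
Looking back after working through everything: how do add_image and add_figure differ, and how are they related?
According to this article, add_image can take a tensor, a NumPy array, or a string?
Judging from the source code, add_image appears to need the pillow library (images opened via PIL.Image), while add_figure mainly targets figures created with matplotlib.pyplot.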
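A rough sketch of that distinction, assuming the writer follows the usual SummaryWriter interface: `add_image` receives pixel data directly, while `add_figure` receives a matplotlib Figure and, as far as I can tell from the source, renders it to an array before logging it as an image. The function names `log_image` and `log_figure` and the tags `demo/image` and `demo/figure` are made up for this example.

```python
import numpy as np


def log_image(writer, step=0):
    """Log raw pixel data: a float array in [0, 1] with shape (C, H, W)."""
    ramp = np.linspace(0.0, 1.0, 64 * 64, dtype=np.float32).reshape(64, 64)
    img_chw = np.stack([ramp, 1.0 - ramp, np.zeros_like(ramp)])
    writer.add_image("demo/image", img_chw, global_step=step, dataformats="CHW")
    return img_chw


def log_figure(writer, step=0):
    """Log a matplotlib Figure; the writer turns it into pixels itself."""
    # Imported lazily so log_image still works without matplotlib installed.
    import matplotlib
    matplotlib.use("Agg")  # headless backend, no display required
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(3, 3))
    ax.plot([0, 1, 2, 3], [0, 1, 4, 9])
    ax.set_title("y = x^2")
    writer.add_figure("demo/figure", fig, global_step=step)
    return fig
```

With a real writer, for example one created by `SummaryWriter()` from tensorboardX, calling `log_image(writer)` and `log_figure(writer)` and then `writer.close()` should give one entry each under the Images tab.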
def add_image(
        self,
        tag: str,
        img_tensor: numpy_compatible,
        global_step: Optional[int] = None,
        walltime: Optional[float] = None,
        dataformats: Optional[str] = 'CHW'):
    """Add image data to summary.

    Note that this requires the ``pillow`` package.

    Args:
        tag: Data identifier
        img_tensor: An `uint8` or `float` Tensor of shape `[channel, height, width]`
            where `channel` is 1, 3, or 4.
            The elements in img_tensor can either have values
            in [0, 1] (float32) or [0, 255] (uint8).
            Users are responsible to scale the data in the correct range/type.
        global_step: Global step value to record
        walltime: Optional override default walltime (time.time()) of event.
        dataformats: This parameter specifies the meaning of each dimension of the input tensor.

    Shape:
        img_tensor: Default is :math:`(3, H, W)`. You can use ``torchvision.utils.make_grid()`` to
        convert a batch of tensor into 3xHxW format or use ``add_images()`` and let us do the job.
        Tensor with :math:`(1, H, W)`, :math:`(H, W)`, :math:`(H, W, 3)` is also suitible as long as
        corresponding ``dataformats`` argument is passed. e.g. CHW, HWC, HW.

    Examples::

        from tensorboardX import SummaryWriter
        import numpy as np
        img = np.zeros((3, 100, 100))
        img[0] = np.arange(0, 10000).reshape(100, 100) / 10000
        img[1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000

        img_HWC = np.zeros((100, 100, 3))
        img_HWC[:, :, 0] = np.arange(0, 10000).reshape(100, 100) / 10000
        img_HWC[:, :, 1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000

        writer = SummaryWriter()
        writer.add_image('my_image', img, 0)

        # If you have non-default dimension setting, set the dataformats argument.
        writer.add_image('my_image_HWC', img_HWC, 0, dataformats='HWC')
        writer.close()

    Expected result:

    .. image:: _static/img/tensorboard/add_image.png
       :scale: 50 %
    """
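To make the `dataformats` and value-range notes in the docstring concrete, here is a NumPy-only sketch of the layout and scaling they describe. It is not the library's implementation; `to_chw_float` is a hypothetical helper written for this post, and it only handles the three layouts the docstring names (CHW, HWC, HW).

```python
import numpy as np


def to_chw_float(img, dataformats="CHW"):
    """Return a float32 array in (C, H, W) layout with values in [0, 1]."""
    arr = np.asarray(img)
    if dataformats == "HW":
        arr = arr[np.newaxis, :, :]          # (H, W) -> (1, H, W)
    elif dataformats == "HWC":
        arr = np.transpose(arr, (2, 0, 1))   # (H, W, C) -> (C, H, W)
    elif dataformats != "CHW":
        raise ValueError(f"unsupported dataformats: {dataformats!r}")
    if arr.dtype == np.uint8:
        return arr.astype(np.float32) / 255.0  # [0, 255] -> [0, 1]
    return arr.astype(np.float32)              # assumed to be in [0, 1] already


if __name__ == "__main__":
    hwc = np.zeros((100, 100, 3), dtype=np.uint8)
    hwc[:, :, 0] = 255                         # pure red, channels last
    print(to_chw_float(hwc, "HWC").shape)
```

The point is that the same picture can arrive in several layouts and two value ranges, and `dataformats` is how the caller says which layout was passed; scaling to the right range stays the caller's job.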
# imports
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# ************** basic environment setup ******************
# transforms
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])
# datasets: FashionMNIST is downloaded automatically into ./data
# (in the author's setup the saved files end in .pt)
trainset = torchvision.datasets.FashionMNIST('./data',
                                             download=True,
                                             train=True,
                                             transform=transform)
testset = torchvision.datasets.FashionMNIST('./data',
                                            download=True,
                                            train=False,
                                            transform=transform)
# dataloaders: wrap each dataset in a DataLoader
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)
# constant for classes
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
# helper function to show an image
# (used in the `plot_classes_preds` function below)
# img holds one batch of images (batch_size=4 here, so four images) and is
# expected to be in the format produced by torchvision.utils.make_grid()
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
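The `img / 2 + 0.5` step in the helper above undoes `transforms.Normalize((0.5,), (0.5,))`, which maps a pixel x in [0, 1] to (x - 0.5) / 0.5 in [-1, 1]. A small sketch of that round trip, using NumPy instead of torch so it runs without PyTorch; `normalize` and `unnormalize` are illustrative names, not part of the script.

```python
import numpy as np

MEAN, STD = 0.5, 0.5  # the values passed to transforms.Normalize in the script


def normalize(x):
    """What Normalize((0.5,), (0.5,)) does to each pixel value."""
    return (x - MEAN) / STD


def unnormalize(y):
    """Inverse mapping; identical to y / 2 + 0.5 for these MEAN and STD."""
    return y * STD + MEAN


pixels = np.array([0.0, 0.25, 0.5, 1.0], dtype=np.float32)
normalized = normalize(pixels)      # values now lie in [-1, 1]
restored = unnormalize(normalized)  # back in [0, 1]
print(normalized, restored)
```

Without this step `plt.imshow` would receive negative values, so the image would not display with the intended brightness.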
# build an ordinary convolutional neural network
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))