tensorbord可视化-pytorch

 

TensorBoard是一款优秀的基于浏览器的机器学习可视化工具。之前是tensorflow的御用可视化工具,由于tensorboard并不是直接读取tf张量,而是读取log进行可视化所以,其他框架只需生成tensorboard可读的log,即可完成可视化。

之前,我一直用visdom做pytorch可视化,也是非常易用。不过现在跟tensorboard对比,我还是更推荐tensorboard

visdom相比tensorboard只有一个优点,那就是自动实时刷新。而tensorboard无论从可视化美观性、可视化数据多样性等多个方面,都碾压visdom。甚至,tensorboard更加易用一些。

先给一个官方文档链接:https://pytorch.org/docs/stable/tensorboard.html

tensorboard的安装就不在这篇文章讲述了。安装1.15以上的版本即可。


1. 标量(scalars)数据可视化

标量就是数字,咱们训练过程中的loss值,测试集的accuracy,包括precision和recall等等都可以通过这个方式画出曲线。可以更直观地反映模型的训练情况。还在拿matplotlib可视化loss曲线的童鞋可以换家伙什儿了。

用tensorboard做标量可视化非常简单:


from torch.utils.tensorboard import SummaryWriter


log_writer = SummaryWriter()


def train(xxx):

for epoch in epochs:

loss = xxx

log_writer.add_scalar('Loss/train', float(loss), epoch)

首先咱们需要一个写log的东西——log_writer, 然后直接add_scalar就可以了。

add_scalar('Loss/train', float(loss), epoch),第一个参数是名称,第二个参数是y值,第三个参数是x值。(用x,y画图,不用我解释x,y是啥吧?)

也就比原来的训练代码多了三行,即可收集训练的loss,用以可视化。

咱们可以先把加了三行的训练代码跑起来,中途生成的loss都会保存在当前目录下一个名为'runs/'的文件夹中。当然这个文件夹可以自定义,'runs/'只是默认名。

第二步,咱们就打开tensorboard瞅一瞅。再打开一个terminal,输入:

tensorboard --logdir=runs/

运行以后,会给你一个链接,用浏览器打开即可。一般为 https://127.0.0.1:6006,6006端口被占用的话会是另一个端口。

我实际运行的一个loss如下:

看着还是很酷炫的,可以点击右上角刷新,以查看实时训练情况。

2. 图(GRAPH)数据可视化

这个可以用来可视化网络结构,不太涉及动态变化,所以甚至比标量可视化更加简单。直接用add_graph就可以完成,

需要注意的是要定义输入的shape,类似于tf的placeholder。我们看一个官方栗子:


import torch

import torchvision

from torch.utils.tensorboard import SummaryWriter

from torchvision import datasets, transforms


# Writer will output to ./runs/ directory by default

writer = SummaryWriter()


transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

trainset = datasets.MNIST('mnist_train', train=True, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

model = torchvision.models.resnet50(False)

# Have ResNet model take in grayscale rather than RGB

model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

images, labels = next(iter(trainloader))


grid = torchvision.utils.make_grid(images)

writer.add_image('images', grid, 0)

writer.add_graph(model, images)

writer.close()

注意,上面的add_graph有2个输入参数,一个是模型,另一个就是类似于placeholder的东西,用来描述输入的shape因为可视化网络结构的时候,后台会帮你计算出每一层feature map的尺寸,但这尺寸都与输入的shape有关。

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

DLANDML

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值