pytorch使用tensorboardX面板自动生成模型结构图和各类可视化图像

总结:

在原本代码中额外添加如下几行即可实现查看模型结构:

    from tensorboardX import SummaryWriter  # 用于进行可视化

    # 1. 来用tensorflow进行可视化
    with SummaryWriter("./log", comment="sample_model_visualization") as sw:  
        sw.add_graph(modelviz, sampledata)

操作步骤如下

安装完torch之后,再安装tensorboardX

pip install tensorboardX -i https://pypi.tuna.tsinghua.edu.cn/simple

运行下面代码 

import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorboardX import SummaryWriter  # 用于进行可视化 
 
class modelViz(nn.Module):
    def __init__(self):
        super(modelViz, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 64, 3, 1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 10, 3, 1, padding=1)
        self.bn3 = nn.BatchNorm2d(10)
 
    def forward(self, x):
        x = self.bn1(self.conv1(x))
        x = F.relu(x)
        x = self.bn2(self.conv2(x))
        x = F.relu(x)
        x = self.bn3(self.conv3(x))
        x = F.relu(x)
        return x
 
 
if __name__  == "__main__":
    # 首先来搭建一个模型
    modelviz = modelViz()
    # 创建输入
    sampledata = torch.rand(1, 3, 4, 4)
    # 看看输出结果对不对
    out = modelviz(sampledata)
    print(out)  # 测试有输出,网络没有问题
 
    # 1. 来用tensorflow进行可视化
    with SummaryWriter("./log", comment="sample_model_visualization") as sw:
        sw.add_graph(modelviz, sampledata)
 
    # # 2. 保存成pt文件后进行可视化
    # torch.save(modelviz, "./log/modelviz.pt")

运行代码后会在"./log"路径下生成一个tfevents文件,在终端中进入代码的主目录下执行命令

tensorboard --logdir=./ 

 然后会输出

(base) jie@dell:~/桌面/fno_task$ tensorboard --logdir=./
TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.16.2 at http://localhost:6006/ (Press CTRL+C to quit)

 http://localhost:6006/

然后按照提示打开浏览器,输入上面这个网址就可以看到我们搭建的网络结构了,如下图所示,可以双击打开每一个节点查看其内容。也可以查看详细的结构以及每一层的输入输出shape。通过双击模型的组件实现展示网络细节和收起细节。

结束!!!

官网详细和介绍使用链接:https://www.tensorflow.org/tensorboard/graphs?hl=zh-cn

tips:tensorboard是适用于tensorflow,而tensorboardX可以适用pytorch

tips: 如果你在虚拟环境cd到log的上一级文件夹,那么按照上面的路径就得不到你想要的可视化结果,路径不正确,应该输入

tensorboard --logdir=./log/
                        
参考链接:https://blog.csdn.net/Vertira/article/details/127326470  

  • 8
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
好的,我可以帮你解答这个问题! 首先,为了使用tensorboardx可视化Pytorch模型,你需要先安装tensorboardX库。你可以通过以下命令来安装: ``` pip install tensorboardX ``` 安装完成后,你需要在你的Pytorch代码中导入tensorboardX: ```python from tensorboardX import SummaryWriter ``` 接下来,在你的代码中创建一个`SummaryWriter`对象,它将负责记录你的模型的数据和可视化: ```python writer = SummaryWriter() ``` 现在你可以在代码中使用`writer`对象来记录任何你想要记录的数据。例如,你可以记录损失函数的值: ```python writer.add_scalar('Loss', loss_value, global_step) ``` 其中`loss_value`是损失函数的值,`global_step`是你的训练步数。 你也可以记录模型的权重和梯度: ```python writer.add_histogram('conv1/weights', conv1.weight, global_step) writer.add_histogram('conv1/grads', conv1.weight.grad, global_step) ``` 这将记录名为`conv1/weights`和`conv1/grads`的直方图,它们分别显示了`conv1`层的权重和梯度。 最后,在你的代码结束时,不要忘记关闭`SummaryWriter`对象: ```python writer.close() ``` 现在你可以在终端中输入以下命令来启动tensorboard: ``` tensorboard --logdir=/path/to/logs ``` 其中`/path/to/logs`是你保存日志文件的路径。然后在你的浏览器中访问`http://localhost:6006`,你将能够看到Pytorch模型可视化结果。 希望这能够帮助你!
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

热爱生活的五柒

谢谢你的打赏,人好心善的朋友!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值