使用pytorchviz和Netron可视化pytorch网络结构

一 使用pytorchviz可视化

 

  • 安装依赖和pytorchviz

pip install graphviz
pip install tochviz (或pip install git+https://github.com/szagoruyko/pytorchviz)

 

Graphviz 是 AT&T 开发的一款开源的图形可视化软件,可以根据dot脚本语言中绘制的无向图(显示了对象间最简单的关系)画出直观的树形图。
Graphviz在Windows中的安装需要下载Release包,并配置环境变量,否则会报错:

graphviz.backend.ExecutableNotFound: failed to execute [‘dot’, ‘-Tpng’, ‘-O’, ‘tmp’], make sure the Graphviz executables are on your systems’ PATH

 

Graphviz下载地址 https://graphviz.gitlab.io/_pages/Download/Download_windows.html

下载之后解压出来是一个“release”文件夹,把“release\bin”目录添加到系统环境变量,之后在终端中输入“dot -V”,显示以下信息表示Graphviz配置成功:

 

  • torchviz可视化torch网络结构

# Created by 牧野 CSDN
import torch
from torch import nn
from torchviz import make_dot, make_dot_from_trace

model = nn.Sequential()
model.add_module('W0', nn.Linear(8, 16))
model.add_module('tanh', nn.Tanh())
model.add_module('W1', nn.Linear(16, 1))

x = torch.randn(1,8)

vis_graph = make_dot(model(x), params=dict(model.named_parameters()))
vis_graph.view()  # 会在当前目录下保存一个“Digraph.gv.pdf”文件,并在默认浏览器中打开

with torch.onnx.set_training(model, False):
    trace, _ = torch.jit.get_trace_graph(model, args=(x,))
make_dot_from_trace(trace)

 

调用“make_dot”方法创建一个dot对象,使用“view”方法显示出来。

pytorch1.2和1.3版本中使用“torch.jit.get_trace_graph”可能会报错,1.1版本ok。

AttributeError: 'torch._C.Value' object has no attribute 'uniqueName'

 

可视化结果:

 

二 使用Netron可视化

 

Netron开源地址: https://github.com/lutzroeder/Netron
Netron的开发者是Lutz Roeder,一位来自微软Visual Studio团队的帅哥:

 

Netron是一款支持离线查看“各种”神经网络框架的模型可视化神器,其中的“各种”包括:

  1. ONNX (.onnx, .pb, .pbtxt)
  2. Keras (.h5, .keras)
  3. Core ML (.mlmodel)
  4. Caffe (.caffemodel, .prototxt)
  5. Caffe2 (predict_net.pb, predict_net.pbtxt)
  6. MXNet (.model, -symbol.json)
  7. NCNN (.param)
  8. TensorFlow Lite (.tflite)
  9. TorchScript (.pt, .pth)
  10. PyTorch (.pt, .pth)
  11. Torch (.t7)
  12. Arm NN (.armnn)
  13. BigDL (.bigdl, .model)
  14. Chainer, (.npz, .h5)
  15. CNTK (.model, .cntk)
  16. Deeplearning4j (.zip)
  17. Darknet (.cfg)
  18. ML.NET (.zip)
  19. MNN (.mnn)
  20. OpenVINO (.xml)
  21. PaddlePaddle (.zip, __model__)
  22. scikit-learn (.pkl)
  23. TensorFlow.js (model.json, .pb)
  24. TensorFlow (.pb, .meta, .pbtxt)

嗯,够多了。

Netron使用很简单,作者提供了各个平台的安装包,安装之后打开,把保存的模型文件拖入就可以了。
还以上边的模型为例,先把pytorch模型保存出来:

import torch
from torch import nn
from torchviz import make_dot, make_dot_from_trace

model = nn.Sequential()
model.add_module('W0', nn.Linear(8, 16))
model.add_module('tanh', nn.Tanh())
model.add_module('W1', nn.Linear(16, 1))

torch.save(model, 'model.pth')  # 保存模型

之后用Netron打开保存的“model.pth”:

 

网络结构很清晰,一目了然,右侧还能显示操作的进一步信息。

如果你懒得安装,还可以使用作者提供的在线Netron查看器,地址:https://lutzroeder.github.io/netron/

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值