pytorch 网络结构可视化之netron

pytorch 网络结构可视化之netron



一、netron

netron是一个深度学习模型可视化库,支持以下格式的模型存储文件:

ONNX (.onnx, .pb)
Keras (.h5, .keras)
CoreML (.mlmodel)
TensorFlow Lite (.tflite)
TensorFlow

但netron并不支持pytorch通过torch.save方法导出的模型文件,(及可视化过程中无法捕获模型的执行操作与结构)。
因此在pytorch保存模型的时候,可以用torch.onnx模块将其导出为onnx格式的模型文件,或用torch.jit.trace模块追踪模型在输入数据后的执行路径调用的操作。

整体的流程分为两步,第一步,基于pytorch两种方法导出模型文件。第二步,netron载入模型文件,进行可视化。

二、使用步骤

1.安装可视化工具netron

pip install netron

2.导出可视化模型文件

①导出onnx格式模型文件

import torchvision
import torch
data = torch.rand(1, 3, 224, 224)
model=torchvision.models.resnet50()
output = model(data)
# 导出为onnx格式
onnx_path = "onnx_model.onnx"
torch.onnx.export(model, data, onnx_path)

②torch.jit.trace转换模型文件

torch.jit.trace在跟踪遇到的计算步骤时通过函数或模块运行示例输入,并输出执行Tracing操作的基于图形的函数。Tracing非常适用于不涉及数据相关控制流的简单模块和功能,例如标准卷积神经网络。但是,如果Tracing具有依赖于数据的if语句和循环的函数,则仅记录由示例输入执行的执行路径调用的操作,即尽量避免转换代码中有if条件控制的模型。

import torchvision
import torch
data = torch.rand(1, 3, 224, 224)
model=torchvision.models.resnet50()
output = model(data)
trace_model = torch.jit.trace(model, data)
trace_model.save("mtrace.pt")

如果模型设计多个输入,需要将传入torch.onnx.export和torch.jit.trace中的data参数改为多输入张量元组,即data=(input1,input2 )

3.netron载入模型

如果能成功转换模型,在python代码调用netron库来载入模型进行可视化。

import netron
netron.start("mtrace.pth")

netron还做了一个在线demo网站,可以直接上传模型文件查看可视化结果,与代码调用netron库来载入模型一样。网址https://netron.app/
在这里插入图片描述
整体效果比美观
在这里插入图片描述

三、总结

在实际过程中,由于网络模型中复杂的结构以及调用,导出为onnx格式的模型时会出现各式各样的问题
在这里插入图片描述
但torch.jit.trace相对好用一些,能使我们快速便捷地了解复杂模型。
在这里插入图片描述

  • 10
    点赞
  • 40
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值