pytorch 网络结构可视化之netron
目录
一、netron
netron是一个深度学习模型可视化库,支持以下格式的模型存储文件:
ONNX (.onnx, .pb)
Keras (.h5, .keras)
CoreML (.mlmodel)
TensorFlow Lite (.tflite)
TensorFlow
但netron并不支持pytorch通过torch.save方法导出的模型文件,(及可视化过程中无法捕获模型的执行操作与结构)。
因此在pytorch保存模型的时候,可以用torch.onnx模块将其导出为onnx格式的模型文件,或用torch.jit.trace模块追踪模型在输入数据后的执行路径调用的操作。
整体的流程分为两步,第一步,基于pytorch两种方法导出模型文件。第二步,netron载入模型文件,进行可视化。
二、使用步骤
1.安装可视化工具netron
pip install netron
2.导出可视化模型文件
①导出onnx格式模型文件
import torchvision
import torch
data = torch.rand(1, 3, 224, 224)
model=torchvision.models.resnet50()
output = model(data)
# 导出为onnx格式
onnx_path = "onnx_model.onnx"
torch.onnx.export(model, data, onnx_path)
②torch.jit.trace转换模型文件
torch.jit.trace在跟踪遇到的计算步骤时通过函数或模块运行示例输入,并输出执行Tracing操作的基于图形的函数。Tracing非常适用于不涉及数据相关控制流的简单模块和功能,例如标准卷积神经网络。但是,如果Tracing具有依赖于数据的if语句和循环的函数,则仅记录由示例输入执行的执行路径调用的操作,即尽量避免转换代码中有if条件控制的模型。
import torchvision
import torch
data = torch.rand(1, 3, 224, 224)
model=torchvision.models.resnet50()
output = model(data)
trace_model = torch.jit.trace(model, data)
trace_model.save("mtrace.pt")
如果模型设计多个输入,需要将传入torch.onnx.export和torch.jit.trace中的data参数改为多输入张量元组,即data=(input1,input2 )
3.netron载入模型
如果能成功转换模型,在python代码调用netron库来载入模型进行可视化。
import netron
netron.start("mtrace.pth")
netron还做了一个在线demo网站,可以直接上传模型文件查看可视化结果,与代码调用netron库来载入模型一样。网址https://netron.app/
整体效果比美观
三、总结
在实际过程中,由于网络模型中复杂的结构以及调用,导出为onnx格式的模型时会出现各式各样的问题
但torch.jit.trace相对好用一些,能使我们快速便捷地了解复杂模型。