torch 网络模型转换onnx格式,并可视化

1. 构建lenet5 网络

import torch.nn as nn
import torch.nn.functional as F
import torch
 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
class LeNet(nn.Module):
    def __init__(self,class_num=10,input_shape=(1,32,32)):
        super(LeNet, self).__init__()
        self.conv1 = nn.Sequential(                  #input_size=(1*28*28)
            nn.Conv2d(1, 6, 5, 1, 2),                #padding=2保证输入输出尺寸相同
            nn.ReLU(),                               #input_size=(6*28*28)
            nn.MaxPool2d(kernel_size=2, stride=2),   #output_size=(6*14*14)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, 5),                     #padding=1输出尺寸变化
            nn.ReLU(),                               #input_size=(16*10*10)
            nn.MaxPool2d(2, 2)                       #output_size=(16*5*5)
        )

        self.fc1 = nn.Sequential(
            nn.Linear(16 * ((input_shape[1]//2-4)//2)  * ((input_shape[2]//2-4)//2), 120),
            nn.ReLU()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.ReLU()
        )
        self.fc3 = nn.Linear(84, class_num)

                                                   # 定义前向传播过程,输入为x
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
                                                    # nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)

2. 转为onnx格式

input_shape = (1,100,100)   #输入数据
model = LeNet(input_shape=input_shape)

torch.save(model, './model_para.pth')
torch_model = torch.load("./model_para.pth") # pytorch模型加载
batch_size = 1  #批处理大小


# set the model to inference mode
torch_model.eval()

x = torch.randn(batch_size,*input_shape)		    # 生成张量
print (x.shape)
export_onnx_file = "lenet5.onnx"		        	# 目的ONNX文件名
torch.onnx.export(torch_model,
                   x,
                   export_onnx_file,
                   opset_version=10,
                   do_constant_folding=True,	# 是否执行常量折叠优化
                   input_names=["input"],		# 输入名
                   output_names=["output"],	# 输出名
                   dynamic_axes={"input":{0:"batch_size"},		# 批处理变量
                                   "output":{0:"batch_size"}})

3. 通过netron查看网络结构

3.1 netron安装

pip install netron

3.2 netron可视化

import netron
onnx_path = "lenet5.onnx"
netron.start(file=onnx_path, log=False, browse=True)

在这里插入图片描述

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

佐倉

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值