Pytorch与Onnx模型的保存、转换与操作

Open Neural Network Exchange(ONNX,开放神经网络交换)格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移。

一、pytorch模型保存/加载

有两种方式可用于保存/加载pytorch模型 1)文件中保存模型结构和权重参数 2)文件只保留模型权重.

1、文件中保存模型结构和权重参数

模型保存与调用方式一(只保存权重):

保存:

torch.save(model.state_dict(), mymodel.pth)#只保存模型权重参数,不保存模型结构

调用:

model = My_model(*args, **kwargs)  #这里需要重新创建模型,My_model
model.load_state_dict(torch.load(mymodel.pth))#这里根据模型结构,导入存储的模型参数
model.eval()

模型保存与调用方式二(保存完整模型):

保存:

torch.save(model, mymodel.pth)#保存整个model的状态

调用:

model=torch.load(mymodel.pth)#这里已经不需要重构模型结构了,直接load就可以
model.eval()

.pt表示pytorch的模型,.onnx表示onnx的模型,后缀名为.pt, .pth, .pkl的pytorch模型文件之间其实没有任何区别

二、pytorch模型转ONNX模型

1、文件中保存模型结构和权重参数

import torch
torch_model = torch.load("save.pt") # pytorch模型加载
batch_size = 1  #批处理大小
input_shape = (3,244,244)   #输入数据


#set the model to inference mode
torch_model.eval()

x = torch.randn(batch_size,*input_shape)		# 生成张量(模型输入格式)
export_onnx_file = "test.onnx"					# 目的ONNX文件名

// 导出export:pt->onnx
torch.onnx.export(torch_model,					# pytorch模型
                    x,							# 生成张量(模型输入格式)
                    export_onnx_file,			# 目的ONNX文件名
                    do_constant_folding=True,	# 是否执行常量折叠优化
                    input_names=["input"],		# 输入名(可略)
                    output_names=["output"],	# 输出名(可略)
                    dynamic_axes={"input":{0:"batch_size"},		# 批处理变量(可略)
                                    "output":{0:"batch_size"}}) 

注:dynamic_axes字段用于批处理.若不想支持批处理或固定批处理大小,移除dynamic_axes字段即可.

2、文件中只保留模型权重

import torch
torch_model = selfmodel()  					# 由研究员提供python.py文件
batch_size = 1 								# 批处理大小
input_shape = (3, 244, 244) 				# 输入数据

#set the model to inference mode
torch_model.eval()

x = torch.randn(batch_size,*input_shape) 	# 生成张量(模型输入格式)
export_onnx_file = "test.onnx" 				# 目的ONNX文件名

// 导出export:pt->onnx
torch.onnx.export(torch_model,					# pytorch模型
                    x,							# 生成张量(模型输入格式)
                    export_onnx_file,			# 目的ONNX文件名
                    do_constant_folding=True,	# 是否执行常量折叠优化
                    input_names=["input"],		# 输入名(可略)
                    output_names=["output"],	# 输出名(可略)
                    dynamic_axes={"input":{0:"batch_size"},	# 批处理变量(可略)
                                    "output":{0:"batch_size"}})

3、onnx文件操作

3.1 加载onnx文件

# "加载load"
model=onnx.load('net.onnx')

检查模型格式是否完整及正确

onnx.checker.check_model(model)

3.2 打印onnx模型文件信息

session=onnxruntime.InferenceSession('net.onnx')
inp=session.get_inputs()[0]


#conv1=session.get_inputs()['conv1']
#out1=session.get_outputs()[1]
out=session.get_provider_options()
#print(inp,conv1,out1)
print(inp)
#print(out)
"打印图信息:字符串信息"
graph=onnx.helper.printable_graph(model.graph)
print(type(graph))

3.3 获取onnx模型输入输出层

input=model.graph.input
output = model.graph.output
"""输入输出层"""
print(input,output)

3.4 推断结果

"""推断"""
session=onnxruntime.InferenceSession('net.onnx')
input_name = session.get_inputs()
print(input_name)
output_name=session.get_outputs()[0].name
res=session.run([output_name],{input_name[0].name:inputs.numpy()})
print(res)
  • 1
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Yuezero_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值