4、pyorch模型转换为onnx、.pb、.h文件

#https://pypi.tuna.tsinghua.edu.cn/simple
import torch
import onnx
from onnx_tf.backend import prepare
from onnx2keras import onnx_to_keras
import keras
import tensorflow as tf
'''---------------------------------------------------------------------------
功能简述:
    1)将.pth文件转化为.onnx文件
    2)将.onnx文件转化为.pb文件
    3)将.onnx文件转化为.h5文件
-----------------------------------------------------------------------------'''
'''
时间:2021.7.29
各模块使用版本信息:
        python  3.8.5
        torchvision  0.10.0
        torch: 1.9.0
        onnxtime: 1.8.1
        keras: 2.5.0
        onnx: 1.9.0
        onnx2keras:0.0.24
        onnx-tf:1.8.0
'''
def pth_to_onnx(input_path,output_path):
    '''
    1)声明:使用本函数之前,必须保证你手上已经有了.pth模型文件.
    2)功能:本函数功能四将pytorch训练得到的.pth文件转化为onnx文件。
    '''
    torch_model = torch.load(input_path)       # pytorch模型加载,此处加载的模型包含图和参数
    #torch_model = selfmodel()  # 若只需要保存参数,可以换成这一种,其中selfmodel需要自己编写
    torch_model.eval()
    x = torch.randn(1,1,28,28)          # 输入一张28*28的灰度图像并生成张量
    export_onnx_file = output_path         #输出.onnx文件的文件路径及文件名
    torch.onnx.export(torch_model,
                      x,
                      export_onnx_file,
                      opset_version=9,    #操作的版本,稳定操作集为9
                      do_constant_folding=True,          # 是否执行常量折叠优化
                      input_names=["input"],        # 输入名
                      output_names=["output"],       # 输出名
                      dynamic_axes={"input": {0: "batch_size"},         # 批处理变量
                                    "output": {0: "batch_size"}}
                      )

    #onnx_model = onnx.load('model_all.onnx')    #加载.onnx文件
    #onnx.checker.check_model(onnx_model)
    #print(onnx.helper.printable_graph(onnx_model.graph))       #打印.onnx文件信息
def onnx_to_pb(output_path):
    '''
    将.onnx模型保存为.pb文件模型
    '''
    model = onnx.load(output_path) #加载.onnx模型文件
    tf_rep = prepare(model)
    tf_rep.export_graph('model_all.pb')    #保存最终的.pb文件
def onnx_to_h5(output_path ):
    '''
    将.onnx模型保存为.h5文件模型,并打印出模型的大致结构
    '''
    onnx_model = onnx.load(output_path)
    k_model = onnx_to_keras(onnx_model, ['input'])
    keras.models.save_model(k_model, 'kerasModel.h5', overwrite=True, include_optimizer=True)    #第二个参数是新的.h5模型的保存地址及文件名
    #下面内容是加载该模型,然后将该模型的结构打印出来
    model = tf.keras.models.load_model('kerasModel.h5')
    model.summary()
    print(model)
if __name__=='__main__':
    input_path = "model_all.pth"    #输入需要转换的.pth模型路径及文件名
    output_path = "model_all.onnx"  #转换为.onnx后文件的保存位置及文件名
    pth_to_onnx(input_path,output_path)  #执行pth转onnx函数,具体转换参数去该函数里面修改
    #onnx_pre(output_path)   #【可选项】若有需要,可以使用onnxruntime进行部署测试,看所转换模型是否可用,其中,output_path指加载进去的onnx格式模型所在路径及文件名
    #onnx_to_pb(output_path)   #将onnx模型转换为pb模型
    #onnx_to_h5(output_path )   #将onnx模型转换为h5模型


  • 4
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 9
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小树苗m

您的打赏,是我的动力。

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值